
hardmaru/mdn_jax_tutorial


Tutorial: Mixture Density Networks with JAX

April 2020

Tutorial Notebook: mixture_density_networks_jax.ipynb

Reference paper: Mixture Density Networks (Bishop, 1994)

Related posts:

This tutorial is based on a recent PyTorch notebook, with many improvements added by kylemcdonald.

Note: Compared to the previous tutorials, this notebook uses a slightly different loss formulation that is much more numerically stable. It is the formulation used in most of my other recent projects that needed MDNs.
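The notebook's exact loss is not reproduced here. As an assumption, the sketch below shows one standard way to make the MDN negative log-likelihood numerically stable: stay in log space throughout, using a log-softmax for the mixture weights and `logsumexp` to combine the components, rather than multiplying and summing small probabilities that can underflow to zero. The function name `mdn_nll`, the parameterization (`logits`, `mu`, `log_sigma`) and the 1-D target are illustrative choices, not the notebook's API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# log(sqrt(2*pi)), the constant term of the Gaussian log-density.
LOG_SQRT_2PI = 0.5 * jnp.log(2.0 * jnp.pi)


def mdn_nll(logits, mu, log_sigma, y):
    """Mean negative log-likelihood of y under a 1-D Gaussian mixture.

    Hypothetical helper for illustration.
    logits, mu, log_sigma: arrays of shape (batch, K), one column per component.
    y: array of shape (batch,).
    """
    # Log mixture weights via log-softmax, so tiny weights never become 0.
    log_pi = jax.nn.log_softmax(logits, axis=-1)
    # Per-component Gaussian log-density, computed directly in log space.
    z = (y[:, None] - mu) / jnp.exp(log_sigma)
    log_normal = -0.5 * z ** 2 - log_sigma - LOG_SQRT_2PI
    # log sum_k pi_k * N(y | mu_k, sigma_k), combined stably with logsumexp.
    log_lik = logsumexp(log_pi + log_normal, axis=-1)
    return -jnp.mean(log_lik)
```

With a single standard-normal component and `y = 0`, the loss equals `log(sqrt(2*pi))`. For a target far from every component, a naive `log(sum(pi * exp(...)))` would evaluate `log(0) = -inf`, while this version returns a large but finite value, which keeps gradients usable.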

JAX is a minimal framework for automatically computing the gradients of native Python and NumPy / SciPy functions. It is a nice tool in the machine learning research toolbox.
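As a small illustration of the point above: in practice the NumPy-style code is written against `jax.numpy` (JAX's reimplementation of the NumPy API), and `jax.grad` turns an ordinary Python function into one that returns its gradient. The function `f` is a toy example chosen only because its derivative is easy to check by hand.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Plain Python with NumPy-style calls, written against jax.numpy.
    return jnp.sum(jnp.sin(x) ** 2)


# jax.grad returns a new function that computes df/dx.
df = jax.grad(f)

x = jnp.array([0.0, jnp.pi / 4, jnp.pi / 2])
# Analytically, d/dx sin(x)^2 = 2 sin(x) cos(x) = sin(2x).
print(df(x))
```

The same mechanism is what lets an MDN be trained in JAX: write the loss as a Python function of the network parameters and call `jax.grad` on it.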

Recommended JAX Tutorials: Getting started with JAX and You don't know JAX.

License

MIT

About

Mixture Density Networks (Bishop, 1994) tutorial in JAX
