April 2020
Tutorial Notebook: mixture_density_networks_jax.ipynb
Reference paper: Mixture Density Networks (Bishop, 1994)
This tutorial is based on the recent PyTorch notebook with many improvements added by kylemcdonald.
Note: This notebook uses a slightly different loss formulation from the previous tutorials. It is much more numerically stable, and I have used it in most of my other recent projects that needed MDNs.
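The note above does not spell out the formulation, so the following is only a sketch of one common way to get a numerically stable MDN loss: keep everything in log space and combine the mixture components with `logsumexp`. The function name `mdn_nll`, its argument names, and the 1-D Gaussian parameterization (the network outputting `logits`, `mu`, and `log_sigma`) are assumptions for illustration and may differ from what the notebook does.

```python
import jax.numpy as jnp
from jax.nn import log_softmax
from jax.scipy.special import logsumexp


def mdn_nll(logits, mu, log_sigma, y):
    """Negative log-likelihood of a 1-D Gaussian mixture (hypothetical helper).

    logits, mu, log_sigma: arrays of shape (batch, K), the raw network outputs.
    y: array of shape (batch,), the targets.
    """
    # Log mixture weights, normalized without ever exponentiating large values.
    log_pi = log_softmax(logits, axis=-1)
    # Standardized residual per component; exp(-log_sigma) is 1 / sigma.
    z = (y[:, None] - mu) * jnp.exp(-log_sigma)
    # Log density of each Gaussian component.
    log_normal = -0.5 * z ** 2 - log_sigma - 0.5 * jnp.log(2.0 * jnp.pi)
    # log sum_k pi_k N(y | mu_k, sigma_k), computed stably in log space.
    log_prob = logsumexp(log_pi + log_normal, axis=-1)
    return -jnp.mean(log_prob)


# A target far from every component: a naive exp-then-log version would
# underflow to log(0), while the log-space version stays finite.
far_loss = mdn_nll(
    logits=jnp.zeros((1, 2)),
    mu=jnp.zeros((1, 2)),
    log_sigma=jnp.zeros((1, 2)),
    y=jnp.array([100.0]),
)
```

Predicting `log_sigma` rather than `sigma` is a design choice in this sketch: it keeps the standard deviation positive without a clamp and lets the `log(sigma)` term enter the loss directly.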
JAX is a minimal framework that automatically computes the gradients of native Python and NumPy / SciPy functions. It is a nice tool to have in the machine learning research toolbox.
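As a small illustration of what "automatically calculate the gradients" means in practice, here is a minimal sketch using `jax.grad` on an ordinary NumPy-style function. The function `f` and the input values are made up for this example.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A plain NumPy-style scalar function: sum of squares plus a sine term.
    return jnp.sum(x ** 2) + jnp.sin(x[0])


# jax.grad returns a new function that evaluates df/dx.
grad_f = jax.grad(f)

x = jnp.array([0.0, 1.0, 2.0])
g = grad_f(x)  # analytically: 2 * x, plus cos(x[0]) on the first entry
```

Note that `jax.grad` expects the function to return a scalar, which is the usual case for a loss function.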
Recommended JAX tutorials: "Getting started with JAX" and "You don't know JAX".
MIT
