Conda Environment:
- Create an environment with:
conda create -n serl python=3.10
- Activate it before installing the packages below:
conda activate serl
Recommended:
- This assumes the machine has the latest NVIDIA drivers and CUDA 12.1 or 11.x installed.
- Run:
pip install --upgrade pip
pip install -e .

Then install JAX with only the command that matches your CUDA major version:

# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
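Choosing between the two JAX commands only depends on the CUDA major version (for example, as shown in the header of `nvidia-smi`). The sketch below wraps that choice in shell functions. The names `jax_extra_for_cuda`, `install_jax_for_cuda` and `check_jax_backend` are hypothetical helpers for illustration, not part of serl or JAX; the pip extras and wheel URL are the ones from the commands above.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of serl): map a CUDA version string such as
# "12.1" or "11.8" to the JAX pip extra used in the install commands.
jax_extra_for_cuda() {
  local version="$1"
  local major="${version%%.*}"   # keep only the part before the first dot
  case "$major" in
    12) echo "cuda12_pip" ;;
    11) echo "cuda11_pip" ;;
    *)
      echo "unsupported CUDA version: $version" >&2
      return 1
      ;;
  esac
}

# Hypothetical helper: install JAX for the given CUDA version from the same
# wheel index as the commands above (Linux only).
install_jax_for_cuda() {
  local extra
  extra="$(jax_extra_for_cuda "$1")" || return 1
  pip install --upgrade "jax[${extra}]" \
    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
}

# Sanity check after installing: prints "gpu" if JAX picked up a CUDA device,
# "cpu" if it fell back to the CPU.
check_jax_backend() {
  python -c "import jax; print(jax.default_backend())"
}
```

For example, `install_jax_for_cuda 12.1 && check_jax_backend` installs the CUDA 12 wheels and then reports which backend JAX selected. Versions other than 11.x and 12.x are rejected instead of guessed, since the original instructions only cover those two.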
Alternatively, see the JAX installation instructions for installing JAX against locally installed CUDA and cuDNN; this route can be more complicated.
For running vision-based experiments, also clone and install (with pip install -e .) the library at https://github.com/Leo428/efficientnet-jax:
git clone https://github.com/Leo428/efficientnet-jax
cd efficientnet-jax
pip install -e .
It is forked from https://github.com/rwightman/efficientnet-jax to support learning with pre-trained visual encoders (EfficientNet and MobileNet) in JAX and Flax.
This folder contains example uses of serl, as presented in the paper.