CERTAIN (Context Uncertainty-aware One-Shot Adaptation) implements the framework described in “Context Uncertainty-aware One-Shot Adaptation for COMRL”. It provides:
- A heteroscedastic uncertainty estimation module to identify ambiguous or OOD transitions during task representation learning.
- An uncertainty-aware context-collecting policy that prioritizes low-uncertainty transitions when gathering a one-shot adaptation trajectory.
- An uncertainty-weighted fusion mechanism to produce more reliable task representations from a single collected trajectory.
- Plug-and-play compatibility with most existing COMRL pipelines (classifier-based, reconstruction-based, or contrastive-based).

By jointly training an uncertainty network alongside the context encoder, CERTAIN significantly improves both one-shot and zero-shot adaptation performance over standard baselines.
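The uncertainty-weighted fusion above can be sketched as inverse-variance weighting of per-transition embeddings: transitions the uncertainty network flags as ambiguous contribute less to the fused task representation. This is a minimal illustration with a hypothetical function name (`fuse_context`), not the repository's actual API:

```python
import numpy as np

def fuse_context(z_per_step, sigma2, eps=1e-8):
    """Fuse per-transition task embeddings into a single task representation.

    z_per_step: (T, D) array of embeddings, one per transition.
    sigma2:     (T,) array of predicted heteroscedastic variances.
    Transitions with lower predicted uncertainty receive higher weight.
    """
    w = 1.0 / (np.asarray(sigma2) + eps)  # inverse-variance weights
    w = w / w.sum()                        # normalize to sum to 1
    return (w[:, None] * np.asarray(z_per_step)).sum(axis=0)  # (D,)
```

With this weighting, a handful of out-of-distribution transitions in the collected trajectory barely move the fused representation, which is what makes a single trajectory sufficient for task inference.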
To install locally, first install MuJoCo. For task distributions in which the reward function varies (Cheetah, Ant, Humanoid), install MuJoCo200. Set LD_LIBRARY_PATH to point to the MuJoCo binaries ($HOME/.mujoco/mujoco200/bin).
For the remaining dependencies, create a conda environment:

```bash
conda env create -f environment.yaml
```
For Walker and Hopper environments, MuJoCo131 is required; install it the same way as MuJoCo200. To switch between MuJoCo versions:

```bash
# Use MuJoCo 1.31
export MUJOCO_PY_MJPRO_PATH=~/.mujoco/mjpro131
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mjpro131/bin

# Use MuJoCo 2.00
export MUJOCO_PY_MJPRO_PATH=~/.mujoco/mujoco200
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin
```
The environments use the module rand_param_envs, included as a git submodule of this repository (https://github.com/dennisl88/rand_param_envs). We modified some environment parameters in rand_param_envs.
The whole pipeline consists of three stages: Data Generation, Task Representation Training, and Policy Training & One-Shot Adaptation.
CERTAIN requires fixed offline datasets for meta-training and meta-testing, which are generated by pre-trained SAC behavior policies. Experiments at this stage are configured via train.yaml and train_point.yaml located in ./rlkit/torch/sac/pytorch_sac/config/.
The following command divides all environments into 8 parts; all environments in part 0 are trained on GPU 0:

```bash
python policy_train.py --config ./configs/[ENV].json --split 8 --split_idx 0
```
Generated data will be saved in ./offline_dataset/
Experiments are configured via json configuration files located in ./configs. Basic settings are defined and described in ./configs/default.py. To train task encoder and uncertainty network, run:
```bash
python launch_experiment.py configs/point-robot.json --exp_name classifier_mix_z0_hvar_pre --gpu 0,1,2,3 --seed 0,1,2,3 --algo_type CLASSIFIER --pretrain true --z_strategy mean --train_z0_policy true --use_hvar true --mujoco_version 200
```
This script will:
- Load the offline dataset specified in configs/point-robot.json.
- Train, simultaneously, the following modules:
  - Context encoder $q_\phi(z \mid c)$
  - Uncertainty network $h_\psi(\sigma \mid z)$
  - Any networks associated with the chosen task-representation loss (classifier, reconstruction, or contrastive)
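Conceptually, the uncertainty network is trained with a heteroscedastic Gaussian negative log-likelihood, so a transition can reduce its loss only by predicting a larger variance, and the predicted variance then serves as the uncertainty signal. A minimal sketch of such a loss (the exact objective and network interfaces in this repo may differ):

```python
import numpy as np

def heteroscedastic_nll(pred, target, log_var):
    """Gaussian NLL with a learned per-sample log-variance.

    A large error on a sample can be "explained away" by raising log_var,
    but the +log_var term penalizes inflating variance everywhere, so the
    network learns high variance only where the data is genuinely ambiguous.
    """
    inv_var = np.exp(-log_var)
    return 0.5 * np.mean(inv_var * (pred - target) ** 2 + log_var)
```

Predicting `log_var` rather than the variance itself keeps the variance positive without extra constraints and is numerically stable during training.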
Output files will be written to ./logs/[ENV]/[EXP NAME]/seed[seed]. The file progress.csv contains statistics logged over the course of training. We recommend TensorBoard for visualizing learning curves.
Before starting this stage, edit your JSON config (e.g., configs/point-robot.json) so that the field [algo_params][ALGO TYPE][pretrained_agent_path] points to the models you just trained during Task Representation Training, i.e., ./logs/[ENV]/[EXP NAME].
Then launch:
```bash
python launch_experiment.py configs/point-robot.json --exp_name classifier_mix_z0_hvar_weighted --gpu 0,1,2,3 --seed 0,1,2,3 --algo_type CLASSIFIER --pretrain false --z_strategy weighted --train_z0_policy true --use_hvar true --mujoco_version 200
```
What this script does:
- Loads pretrained models (encoder.pth, uncertainty.pth, etc.) from the previous stage.
- Trains the meta-policy $\pi_\theta(a \mid s, z)$ and the context-collecting policy $\pi_\theta(a \mid s, z_0)$ using the uncertainty-weighted fusion strategy (z_strategy = weighted).
- Periodically evaluates one-shot online adaptation: the collector policy samples a single trajectory in a held-out test environment, and the average return after task inference is computed.
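During one-shot evaluation, the collected trajectory can be filtered or re-weighted so that low-uncertainty transitions dominate task inference. The toy sketch below illustrates the filtering idea with a hypothetical helper (`select_low_uncertainty`); the repo's actual strategy may re-weight rather than drop transitions:

```python
import numpy as np

def select_low_uncertainty(transitions, sigma2, keep_frac=0.5):
    """Keep the keep_frac fraction of transitions with the lowest
    predicted variance before running task inference on them."""
    k = max(1, int(len(transitions) * keep_frac))
    idx = np.argsort(np.asarray(sigma2))[:k]   # indices of most certain transitions
    return [transitions[i] for i in idx]
```

Because the context-collecting policy already steers toward low-uncertainty regions, this post-hoc selection mainly guards against the few ambiguous transitions that remain in the single collected trajectory.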
TensorBoard logs are recorded under:
./logs/[ENV]/classifier_mix_z0_hvar_weighted/seed[seed]/
If you use this code or our CERTAIN algorithm in your research, please cite our paper:
This project is licensed under the MIT License. See the LICENSE file for details.