# EMERALD

This is the official repository of EMERALD (Efficient MaskEd latent tRAnsformer worLD model).

Read the EMERALD paper on OpenReview | arXiv.
# Installation

Clone the GitHub repository and set up the environment:

```shell
git clone https://github.com/burchim/EMERALD && cd EMERALD
./install.sh
```
# Usage

Run an experiment:

```shell
run_name=crafter python3 main.py
```

Training logs, the replay buffer, and checkpoints will be saved to `callbacks/<run_name>`.
Model config hyperparameters can be overridden via the `override_config` environment variable:

```shell
run_name=crafter override_config='{"num_envs": 1, "epochs": 100, "eval_env_params": {"episode_saving_path": "./videos"}}' python3 main.py
```
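As a rough sketch of how a JSON override like this could be merged into a nested config dict (the helper name `deep_update` and the default values shown are illustrative, not taken from the repository):

```python
import json
import os

def deep_update(base, overrides):
    """Recursively merge override values into a nested config dict."""
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value
    return base

# Hypothetical defaults for illustration; the real ones live in configs/emerald.py.
config = {"num_envs": 16, "epochs": 500,
          "eval_env_params": {"episode_saving_path": None, "seed": 0}}

# Parse the override_config environment variable, if set, and merge it in.
overrides = json.loads(os.environ.get("override_config", "{}"))
deep_update(config, overrides)
```

Nested keys such as `eval_env_params.episode_saving_path` are merged rather than replaced wholesale, so unrelated defaults in the same sub-dict are preserved.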
Monitor training with TensorBoard:

```shell
tensorboard --logdir ./callbacks
```
`--mode evaluation` can be used to evaluate agents. The `--load_last` flag will scan the log directory and load the last checkpoint, while `--checkpoint` can be used to load a specific `.ckpt` checkpoint file.

```shell
run_name=crafter python3 main.py --load_last --mode evaluation
```
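The checkpoint scan that `--load_last` performs could look roughly like the sketch below (`find_last_checkpoint` is a hypothetical helper, not the repository's actual code, which may instead parse step numbers from filenames):

```python
from pathlib import Path

def find_last_checkpoint(log_dir):
    """Return the most recently modified .ckpt file in log_dir, or None.

    Sketch of the behavior described for --load_last; the actual logic
    in main.py may differ.
    """
    ckpts = sorted(Path(log_dir).glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return ckpts[-1] if ckpts else None
```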
We provide the Crafter mean return and achievement scores obtained over 20 different seeds with this repository in `results/EMERALD.json`.
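Aggregating per-seed returns into a mean and standard deviation can be sketched as follows (the schema assumed here, a mapping from seed names to returns, is illustrative; the actual structure of `results/EMERALD.json` may differ):

```python
import statistics

def summarize(results):
    """Mean and sample standard deviation of per-seed returns."""
    values = list(results.values())
    return statistics.mean(values), statistics.stdev(values)

# Illustrative data only, not actual EMERALD results.
example = {"seed_0": 12.3, "seed_1": 11.8, "seed_2": 12.9}
mean, std = summarize(example)
```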
# Args

- `-c` / `--config_file` (str, default: `configs/emerald.py`): Python configuration file containing model hyperparameters
- `-m` / `--mode` (str, default: `training`): Mode: `training`, `evaluation`, `pass`
- `-i` / `--checkpoint` (str, default: `None`): Load model from checkpoint name
- `--cpu` (flag): Load model on CPU
- `--load_last` (flag): Load last model checkpoint
- `--wandb` (flag): Initialize wandb logging
- `--verbose_progress_bar` (int, default: `1`): Verbosity level of the progress bar display

# Training

- `--saving_period_epoch` (int, default: `1`): Save the model every n epochs
- `--log_figure_period_step` (int, default: `None`): Log a figure every n steps
- `--log_figure_period_epoch` (int, default: `1`): Log a figure every n epochs
- `--step_log_period` (int, default: `100`): Training step log period
- `--keep_last_k` (int, default: `3`): Keep only the last k checkpoints

# Eval

- `--eval_period_epoch` (int, default: `5`): Evaluate the model every n epochs
- `--eval_period_step` (int, default: `None`): Evaluate the model every n steps

# Info

- `--show_dict` (flag): Show model state dict summary
- `--show_modules` (flag): Show model named modules
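The pruning behavior suggested by `--keep_last_k` can be sketched as below (`prune_checkpoints` is a hypothetical helper; the repository's actual checkpoint management may differ, e.g. by ordering on epoch numbers in filenames rather than modification time):

```python
from pathlib import Path

def prune_checkpoints(log_dir, keep_last_k=3):
    """Delete all but the keep_last_k most recent .ckpt files in log_dir.

    Sketch of the behavior suggested by --keep_last_k; not the
    repository's actual implementation.
    """
    ckpts = sorted(Path(log_dir).glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    # With keep_last_k == 0, [:-0] would be empty, so handle it explicitly.
    stale = ckpts[:-keep_last_k] if keep_last_k > 0 else ckpts
    for old in stale:
        old.unlink()
```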
If this code or paper is helpful in your research, please use the following citation:

```bibtex
@inproceedings{burchiaccurate,
  title={Accurate and Efficient World Modeling with Masked Latent Transformers},
  author={Burchi, Maxime and Timofte, Radu},
  booktitle={Forty-second International Conference on Machine Learning}
}
```
This project is licensed under the CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) License; see the LICENSE file for details. This project is intended for academic research, and any commercial use requires contacting the owner: Computer Vision Lab, University of Würzburg.
- Official DreamerV3 implementation: https://github.com/danijar/dreamerv3
- Official MaskGIT implementation: https://github.com/google-research/maskgit
- Crafter repository: https://github.com/danijar/crafter

