This repository contains the JAX implementation of the Solaris multiplayer world model for Minecraft. It supports GCP TPU training and inference, and GPU inference. It also contains the source code for the VLM-as-a-judge multiplayer self-consistency metric.
```
conda env create -f environment.yml
conda activate solaris
pip install -r requirements_gpu.txt
pip install -e .
```

Download the pretrained model weights:

```
hf download nyu-visionx/solaris --local-dir ./pretrained
```

See the nyu-visionx/solaris HF model repo for all available model weights.

Download the evaluation datasets:

```
hf download nyu-visionx/solaris-eval-datasets --local-dir ./datasets --repo-type dataset
```

See the nyu-visionx/solaris-eval-datasets HF dataset repo for all available evaluation datasets.
For the simplest scenario, run:

```
CUDA_VISIBLE_DEVICES=0 python src/inference.py experiment_name=solaris device.eval_num_samples=1
```

This assumes the datasets are in ./datasets and uses the pretrained model weights at ./pretrained/solaris.pt. It will generate one video per eval dataset and write the generated videos to ./output/. To run on multiple GPUs, adjust the CUDA_VISIBLE_DEVICES environment variable, making sure device.eval_num_samples is divisible by the number of visible devices. Inference always uses a per-device batch size of 1, which requires each GPU to have at least 48GB of memory. Refer to the sharding section for details.
GPU warnings
You might see the following GPU log messages:
```
2026-02-25 08:28:29.343101: E external/xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-25 08:28:29.472418: E external/xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-02-25 08:28:34.231109: W external/xla/xla/tsl/framework/bfc_allocator.cc:310] Allocator (GPU_0_bfc) ran out of memory trying to allocate 36.68GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
```

These are warnings and can safely be disregarded.
The code for the VLM evaluation metric lives under vlm_eval/. Refer to vlm_eval/README.md for how to run it and the implementation details.
To get the FID number, check the inference script log file. It outputs FID numbers as log messages by default.
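If you want to collect FID values programmatically rather than reading the log by eye, a small parser like the one below could scan the log file. The exact log line format used here is an assumption for illustration, not the codebase's actual format; adjust the regex to match the real log messages.

```python
import re

# Hypothetical log format: lines such as "... FID (eval_rotation): 12.34"
FID_PATTERN = re.compile(r"FID \((?P<dataset>\w+)\): (?P<value>[0-9.]+)")

def extract_fid_scores(log_text: str) -> dict:
    """Collect the last reported FID per dataset from inference log text."""
    scores = {}
    for match in FID_PATTERN.finditer(log_text):
        scores[match.group("dataset")] = float(match.group("value"))
    return scores

if __name__ == "__main__":
    sample_log = (
        "2026-02-25 08:30:01 INFO FID (eval_rotation): 12.34\n"
        "2026-02-25 08:31:07 INFO FID (eval_translation): 9.87\n"
    )
    print(extract_fid_scores(sample_log))
```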
Only TPU training is supported; it requires a device with at least 95GB of memory (v5p) and uses a per-device batch size of 1. Refer to the sharding section for details.
```
conda env create -f environment.yml
conda activate solaris
pip install -r requirements_tpu.txt
pip install -e .
```

In a multi-host TPU setting, you will need your conda environment on all hosts, which can be achieved by wrapping each installation command in gcloud alpha compute tpus tpu-vm ssh --command {COMMAND}.
There are many ways to store data on GCP TPUs, such as Persistent Disks or GCS buckets. Refer to the official guide for how to set it up. Note that your storage option will need to support writing as well to save training checkpoints and generated outputs.
```
hf download nyu-visionx/solaris --local-dir YOUR_STORAGE_PATH/pretrained
```

See the nyu-visionx/solaris HF model repo for all available model weights.

```
hf download nyu-visionx/solaris-eval-datasets --local-dir YOUR_STORAGE_PATH/datasets --repo-type dataset
hf download nyu-visionx/solaris-training-dataset --local-dir YOUR_STORAGE_PATH/datasets
```

The multiplayer Duet dataset is stored in sharded form on HuggingFace. The commands above download it into YOUR_STORAGE_PATH/datasets/duet_sharded. Run the command below to unshard it into the original format that this codebase works with:

```
python unshard_dataset.py --shards YOUR_STORAGE_PATH/datasets/duet_sharded --out YOUR_STORAGE_PATH/datasets/duet
```

The full training pipeline requires training on the single-player VPT dataset. Refer to vpt_datasets/README.md for instructions on how to set it up.
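Conceptually, unsharding just reassembles ordered shard files into the original layout. The sketch below shows the generic idea; the actual duet_sharded layout, shard naming, and the logic in unshard_dataset.py may differ (for example, per-episode files rather than a single blob), so treat this purely as an illustration.

```python
from pathlib import Path

def unshard(shards_dir: str, out_file: str) -> None:
    """Concatenate ordered shard files into one output file.

    Illustrative only: the real unshard_dataset.py may restore a
    directory of episodes instead of a single file.
    """
    # sorted() keeps shards in index order given zero-padded names
    shard_paths = sorted(Path(shards_dir).glob("*.shard"))
    out_path = Path(out_file)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("wb") as out:
        for shard in shard_paths:
            out.write(shard.read_bytes())
```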
The training pipeline consists of four stages, each backed by a dedicated runner:
- Stage 1 — Single-player bidirectional pretraining
- Stage 2 — Multiplayer bidirectional training
- Stage 3 — Multiplayer causal training
- Stage 4 — Multiplayer self-forcing training
Below are example commands for each of the four training stages. Edit the folder paths to match your setup and, in a multi-host setting, run each command as part of gcloud alpha compute tpus tpu-vm ssh --command {COMMAND}.
Note that running training automatically runs inference on the test split of the datasets. The training step and inference are JIT-compiled functions that can take time to compile on their first execution, so the script may appear to hang at the beginning of training and at the first evaluation.
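The warm-up behavior can be illustrated with a minimal compile-once cache, a toy stand-in for what XLA does when a jitted function is first traced (this is not the actual JAX machinery, just the caching pattern that explains the apparent hang):

```python
import time

class CompileOnceFunction:
    """Toy stand-in for a JIT-compiled function: the first call pays a
    one-time 'compilation' cost, later calls hit the cache."""

    def __init__(self, fn, compile_seconds=0.1):
        self.fn = fn
        self.compile_seconds = compile_seconds
        self.compiled = False

    def __call__(self, *args):
        if not self.compiled:
            time.sleep(self.compile_seconds)  # stands in for XLA compilation
            self.compiled = True
        return self.fn(*args)

step = CompileOnceFunction(lambda x: x * 2)

t0 = time.perf_counter(); step(1); first_call = time.perf_counter() - t0
t0 = time.perf_counter(); step(1); later_call = time.perf_counter() - t0
# first_call is dominated by the one-time compile; later_call is fast
```

Setting device.jax_cache_dir lets JAX persist compiled executables across runs, which shortens this warm-up on subsequent launches.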
This stage pretrains the initial Matrix Game 2.0 weights (available as matrix-game-init) on the VPT dataset, extending the action space.
```
python src/train.py \
    runner=trainer_sp \
    model=single_player \
    dataset=vpt \
    +dataset@eval_datasets.vpt=vpt \
    ~dataset@eval_datasets.duet \
    experiment_name=sp_bidirectional_pretrain \
    wandb_entity="YOUR_WANDB_ENTITY" \
    device.batch_size=64 \
    device.eval_num_samples=64 \
    device.data_dir="YOUR_DATASETS_DIR" \
    device.pretrained_model_dir="YOUR_PRETRAINED_MODEL_DIR" \
    device.output_dir="YOUR_OUTPUT_DIR" \
    device.checkpoint_dir="YOUR_CHECKPOINT_DIR" \
    device.jax_cache_dir="YOUR_JAX_CACHE_DIR"
```

It will train for 120K steps. The final model weights are the initialization for Stage 2 and are saved to YOUR_PRETRAINED_MODEL_DIR/sp_bidirectional_pretrain_120000.pt.
This stage trains the multiplayer bidirectional model on the Duet dataset obtained from SolarisEngine, starting from the pretrained single-player model.
```
python src/train.py \
    runner=trainer_mp_bidirectional \
    experiment_name=mp_bidirectional \
    wandb_entity="YOUR_WANDB_ENTITY" \
    device.data_dir="YOUR_DATASETS_DIR" \
    device.pretrained_model_dir="YOUR_PRETRAINED_MODEL_DIR" \
    device.output_dir="YOUR_OUTPUT_DIR" \
    device.checkpoint_dir="YOUR_CHECKPOINT_DIR" \
    device.jax_cache_dir="YOUR_JAX_CACHE_DIR"
```

It starts from the model weights at YOUR_PRETRAINED_MODEL_DIR/sp_bidirectional_pretrain_120000.pt and trains for 120K steps. Its final model weights are the initialization for Stage 3 and for the teacher and critic in Stage 4. It will save them to YOUR_PRETRAINED_MODEL_DIR/mp_bidirectional_120000.pt.
This stage converts the multiplayer bidirectional model to causal using the Diffusion Forcing objective and a causal attention mask, training on the same Duet dataset.
```
python src/train.py \
    runner=trainer_mp_causal \
    experiment_name=mp_causal \
    wandb_entity="YOUR_WANDB_ENTITY" \
    device.data_dir="YOUR_DATASETS_DIR" \
    device.pretrained_model_dir="YOUR_PRETRAINED_MODEL_DIR" \
    device.output_dir="YOUR_OUTPUT_DIR" \
    device.checkpoint_dir="YOUR_CHECKPOINT_DIR" \
    device.jax_cache_dir="YOUR_JAX_CACHE_DIR"
```

It starts from the model weights at YOUR_PRETRAINED_MODEL_DIR/mp_bidirectional_120000.pt and trains for 60K steps. Its final model weights are the initialization for the student in Stage 4. It will save them to YOUR_PRETRAINED_MODEL_DIR/mp_causal_60000.pt.
This stage finetunes the multiplayer causal model (student) on its own rollouts, distilling from the multiplayer bidirectional model (teacher). This stage removes the test time distribution mismatch and makes the final multiplayer causal model a few-step diffusion model.
```
python src/train.py \
    runner=trainer_mp_sf \
    experiment_name=mp_sf \
    wandb_entity="YOUR_WANDB_ENTITY" \
    device.data_dir="YOUR_DATASETS_DIR" \
    device.pretrained_model_dir="YOUR_PRETRAINED_MODEL_DIR" \
    device.output_dir="YOUR_OUTPUT_DIR" \
    device.checkpoint_dir="YOUR_CHECKPOINT_DIR" \
    device.jax_cache_dir="YOUR_JAX_CACHE_DIR" \
    save_model_state_to="YOUR_PRETRAINED_MODEL_DIR/solaris.pt"
```

It initializes the student from YOUR_PRETRAINED_MODEL_DIR/mp_causal_60000.pt, the teacher and critic from YOUR_PRETRAINED_MODEL_DIR/mp_bidirectional_120000.pt, and trains for 1.2K steps. It will save the final model weights to YOUR_PRETRAINED_MODEL_DIR/solaris.pt, which can be used for inference and evaluation.
TPU Inference requires the same setup as TPU training, except that it doesn't need the training datasets part.
Edit the folder paths to where you've set them up and run the below command as part of gcloud alpha compute tpus tpu-vm ssh --command {COMMAND} in a multi-host TPU setting:
```
python src/inference.py \
    device=tpu \
    experiment_name=solaris \
    device.data_dir="YOUR_DATASETS_DIR" \
    device.pretrained_model_dir="YOUR_PRETRAINED_MODEL_DIR" \
    device.output_dir="YOUR_OUTPUT_DIR" \
    device.checkpoint_dir="YOUR_CHECKPOINT_DIR" \
    device.jax_cache_dir="YOUR_JAX_CACHE_DIR"
```

It will use the model weights at YOUR_PRETRAINED_MODEL_DIR/solaris.pt for inference.
This project uses Hydra for configuration. The configs live in config/. They customize runners, model architectures, and datasets.
All training stages and the inference code are built around runners, which use inheritance for abstraction and code sharing.
| Runner | Description |
|---|---|
| Base Runner | Abstract base for all runners. Handles common evaluation and utility logic. |
| Base Trainer | Base trainer for all training stages. Adds a training loop, checkpointing, and logging on top of BaseRunner. |
| Base MP Runner | Base multiplayer runner with evaluation and utilities (batching, rollouts, FID computation). |
| Base SSL Trainer | Base SSL trainer with utilities used by both SP and MP trainers. |
| Trainer SP | Single-player bidirectional training used in Stage 1. |
| Trainer MP | Multiplayer trainer used in Stage 2 and 3. |
| Trainer MP SF | Multiplayer self-forcing trainer used in Stage 4. |
| Inference | Inference-only runner for rollouts and metrics. |
Below is the class inheritance diagram for all runners.
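The hierarchy in the table can be sketched as plain Python classes. The class names and exact parentage below are inferred from the table's descriptions, so the real definitions in src/ may wire things differently (for example, which trainers also mix in the multiplayer runner):

```python
class BaseRunner:
    """Abstract base: common evaluation and utility logic."""

class BaseTrainer(BaseRunner):
    """Adds a training loop, checkpointing, and logging."""

class BaseMPRunner(BaseRunner):
    """Multiplayer evaluation utilities: batching, rollouts, FID."""

class BaseSSLTrainer(BaseTrainer):
    """Utilities shared by the SP and MP trainers."""

class TrainerSP(BaseSSLTrainer):
    """Stage 1: single-player bidirectional pretraining."""

class TrainerMP(BaseSSLTrainer, BaseMPRunner):
    """Stages 2 and 3, switched by a bidirectional flag."""
    def __init__(self, bidirectional: bool):
        self.bidirectional = bidirectional

class TrainerMPSF(BaseTrainer, BaseMPRunner):
    """Stage 4: multiplayer self-forcing training."""

class Inference(BaseMPRunner):
    """Inference-only runner for rollouts and metrics."""
```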
Here is a summary of which runner each training stage uses:

- Stage 1 — Single-player bidirectional pretraining: Trainer SP
- Stage 2 — Multiplayer bidirectional training: Trainer MP with `bidirectional=True`
- Stage 3 — Multiplayer causal training: Trainer MP with `bidirectional=False`
- Stage 4 — Multiplayer self-forcing training: Trainer MP SF
The codebase supports three model architectures:
| Model | Description | Config file | JAX Module file |
|---|---|---|---|
| Solaris | Multiplayer world model using the multiplayer attention method | config/model/solaris.yaml | src/models/multiplayer/world_model.py |
| Single Player | Single-player world model following the Matrix Game 2.0 architecture, with the keyboard action dimension increased to 23 | config/model/single_player.yaml | src/models/singleplayer/world_model.py |
| Concat | Multiplayer model using the channel concatenation method | config/model/concat.yaml | src/models/multiplayer/world_model.py |
This repository supports two types of datasets: training and evaluation datasets. The former is used for training and test loss calculation, and the latter for inference and metrics calculation.
vpt and duet are two datasets that are both training and evaluation datasets, where inference for evaluation happens on their test splits. There are 7 evaluation-only datasets: eval_structure (Building), eval_turn_to_look_opposite (Consistency), eval_turn_to_look (Consistency), eval_one_looks_away (Grounding), eval_both_look_away (Memory), eval_rotation (Movement), and eval_translation (Movement). Every dataset has a corresponding config file in config/dataset/, and every dataset that is used for evaluation has a dedicated eval_ids file in src/data/eval_ids/. The eval ids file together with EvalBatchSampler() defined in src/data/batch_sampler.py ensure that evaluation always happens on the same episode segments regardless of the number of GPU/TPU devices used for inference.
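The device-count invariance can be sketched as follows. This is a simplified stand-in for EvalBatchSampler, not its actual implementation: the real sampler may pad, batch, or order episodes differently, but the core idea is that a fixed eval-id list is split deterministically, so the union of evaluated episodes never depends on how many devices run inference.

```python
def assign_eval_ids(eval_ids, num_devices):
    """Deterministically stride a fixed list of eval episode ids
    across devices (sketch; see src/data/batch_sampler.py for the
    real EvalBatchSampler)."""
    return [eval_ids[d::num_devices] for d in range(num_devices)]

ids = ["ep_003", "ep_007", "ep_011", "ep_042"]
# Regardless of device count, the union of evaluated episodes is identical.
for n in (1, 2, 4):
    shards = assign_eval_ids(ids, n)
    assert sorted(sum(shards, [])) == sorted(ids)
```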
Below is a table summarizing all datasets in the codebase:
| Name | Config | Training | Evaluation |
|---|---|---|---|
| vpt | config/dataset/vpt.yaml | ✓ | ✓ |
| duet | config/dataset/duet.yaml | ✓ | ✓ |
| eval_structure | config/dataset/eval_structure.yaml | | ✓ |
| eval_turn_to_look | config/dataset/eval_turn_to_look.yaml | | ✓ |
| eval_turn_to_look_opposite | config/dataset/eval_turn_to_look_opposite.yaml | | ✓ |
| eval_one_looks_away | config/dataset/eval_one_looks_away.yaml | | ✓ |
| eval_both_look_away | config/dataset/eval_both_look_away.yaml | | ✓ |
| eval_rotation | config/dataset/eval_rotation.yaml | | ✓ |
| eval_translation | config/dataset/eval_translation.yaml | | ✓ |
This codebase doesn't implement FSDP and fully replicates the optimizer and model states across all devices. We found this setup sufficient for inference on a 48GB GPU and for training on a 95GB TPU (v5p) with a per-device batch size of 1. However, with this setup, training runs out of memory on an 80GB GPU, so GPU training is not supported.
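The memory implication of full replication can be made concrete with a bit of arithmetic. The parameter count and bytes-per-parameter below are illustrative placeholders, not Solaris's actual numbers:

```python
def per_device_state_gb(num_params_billions: float, bytes_per_param: int,
                        num_devices: int, fsdp: bool) -> float:
    """Model+optimizer state per device, in GB. With full replication
    the per-device footprint is independent of device count; FSDP
    would divide it across devices."""
    total_gb = num_params_billions * 1e9 * bytes_per_param / 1e9
    return total_gb / num_devices if fsdp else total_gb

# Illustrative: 5B params at 16 bytes/param (fp32 weights, grads,
# and two Adam moments) on 8 devices.
print(per_device_state_gb(5, 16, 8, fsdp=False))  # 80.0 -> same on every device
print(per_device_state_gb(5, 16, 8, fsdp=True))   # 10.0 -> sharded across devices
```

This is why adding more GPUs doesn't help here: replication keeps the full state on every device, so a device either fits the whole state or it doesn't.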
The codebase is covered with tests that live under src/tests/. Refer to src/tests/README.md for how to run them and what they cover.