Skip to content

lillian039/ELF

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

2 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

PyTorch ELF

PyTorch version of ELF: Embedded Language Flows.

Installation

Create a conda environment named elf and install the dependencies:

conda create -n elf python=3.10 -y
conda activate elf
pip install -r requirements.txt

Then log in to WandB to track your experiments if needed:

wandb login YOUR_WANDB_API_KEY

Converted Checkpoints

We provide PyTorch-converted versions of the official JAX checkpoints on HuggingFace:

Model Task Params HuggingFace Repo
ELF-B OpenWebText (unconditional) 105M embedded-language-flows/ELF-B-owt-torch
ELF-M OpenWebText (unconditional) 342M embedded-language-flows/ELF-M-owt-torch
ELF-L OpenWebText (unconditional) 652M embedded-language-flows/ELF-L-owt-torch
ELF-B XSum (summarization) 105M embedded-language-flows/ELF-B-xsum-torch
ELF-B WMT14 De-En (translation) 105M embedded-language-flows/ELF-B-de-en-torch

These are pulled automatically via --checkpoint_path <hf-repo-id> — no manual download needed.

Reference Results

The PyTorch port targets parity with the JAX reference numbers from the paper. Small differences (≲1 PPL, ≲0.5 BLEU/ROUGE) are expected due to bf16 vs. JAX TPU numerics and sampling stochasticity.

Unconditional generation (OpenWebText), expected:

Model Sampling Gen. PPL ↓ Entropy ↑
ELF-B (105M) 32-step SDE 24.1 5.15
ELF-M (342M) 64-step SDE 21.7 5.18
ELF-L (652M) 64-step SDE 23.3 5.28

Gen. PPL is computed under a frozen GPT-2 Large; entropy is unigram entropy over the generated tokens. Default sampling configs (src/configs/sampling_configs/uncond_sampling_configs.yml) use SC-CFG=3 and γ=1.5 (32-step) or γ=1.0 (64-step).

Conditional generation (ELF-B), expected on the validation set:

Task Metric Reference (paper, test) Validation
WMT14 De-En BLEU ↑ 26.4 ≈ 26.7
XSum ROUGE-1 ↑ 36.0 ≈ 36.3
XSum ROUGE-2 ↑ 12.2 ≈ 12.5
XSum ROUGE-L ↑ 27.8 ≈ 28.1

Default conditional sampling (src/configs/sampling_configs/cond_sampling_configs.yml): 64-step ODE, CFG=2, SC-CFG=1.

The paper numbers were computed on TPU v5p-64; numbers from this PyTorch port on 8× L40S / H200 should land within sampling noise (typically <1 PPL or <0.5 metric points).

Training

Launch single-GPU training:

bash scripts/launch.sh train src/configs/training_configs/train_owt_ELF-B.yml

Launch multi-GPU (single-host) training:

NGPU=8 bash scripts/launch.sh train src/configs/training_configs/train_owt_ELF-B.yml

Available training configs:

  • src/configs/training_configs/train_owt_ELF-B.yml — ELF-B on OpenWebText
  • src/configs/training_configs/train_owt_ELF-M.yml — ELF-M on OpenWebText
  • src/configs/training_configs/train_owt_ELF-L.yml — ELF-L on OpenWebText
  • src/configs/training_configs/train_de-en_ELF-B.yml — WMT14 De-En machine translation
  • src/configs/training_configs/train_xsum_ELF-B.yml — XSum abstractive summarization

Estimated wall-clock: ~4 h per epoch on 8× H200 (OpenWebText, ELF-B, global batch size 512, bf16). The default ELF-B OWT run is 5 epochs.

Evaluation

Run evaluation against the converted checkpoints on HuggingFace. We recommend passing use_bf16=true (matches the bf16 autocast used at training time) and use_compile=true (wraps the eval model in torch.compile) for a ~3–4× speedup on consumer GPUs:

Unconditional generation (OpenWebText):

# ELF-B (105M)
NGPU=8 bash scripts/launch.sh eval src/configs/training_configs/train_owt_ELF-B.yml \
    --checkpoint_path embedded-language-flows/ELF-B-owt-torch \
    --config_override use_bf16=true --config_override use_compile=true

# ELF-M (342M)
NGPU=8 bash scripts/launch.sh eval src/configs/training_configs/train_owt_ELF-M.yml \
    --checkpoint_path embedded-language-flows/ELF-M-owt-torch \
    --config_override use_bf16=true --config_override use_compile=true

# ELF-L (652M)
NGPU=8 bash scripts/launch.sh eval src/configs/training_configs/train_owt_ELF-L.yml \
    --checkpoint_path embedded-language-flows/ELF-L-owt-torch \
    --config_override use_bf16=true --config_override use_compile=true

Conditional generation (XSum / WMT14 De-En):

# XSum (ROUGE)
NGPU=8 bash scripts/launch.sh eval src/configs/training_configs/train_xsum_ELF-B.yml \
    --checkpoint_path embedded-language-flows/ELF-B-xsum-torch \
    --config_override use_bf16=true --config_override use_compile=true

# WMT14 De-En (BLEU)
NGPU=8 bash scripts/launch.sh eval src/configs/training_configs/train_de-en_ELF-B.yml \
    --checkpoint_path embedded-language-flows/ELF-B-de-en-torch \
    --config_override use_bf16=true --config_override use_compile=true

Eval config flags

Flag Default What it does
use_bf16 true Wraps the sampling forward in torch.amp.autocast('cuda', dtype=bfloat16). Mirrors the training-time precision; output heads stay fp32.
use_compile false Wraps the eval model in torch.compile. First batch is slower due to tracing; subsequent batches run materially faster.

Both flags are also editable in the YAML config under the same names. You can also run the standalone PPL script afterwards:

python scripts/eval_ppl.py \
    --input outputs/<run>/<sampling_dir>/all_generated_*.jsonl \
    --batch_size 16

About

No description, website, or topics provided.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages