47 changes: 47 additions & 0 deletions docs/guides/sft.md
@@ -161,3 +161,50 @@ As long as your custom dataset has the `formatted_ds` and `task_spec` attributes
## Evaluate the Trained Model

Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities.


## LoRA Configuration

NeMo RL supports LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning. LoRA reduces the number of trainable parameters by learning low-rank weight-update matrices while keeping the base model frozen. A minimal sketch of the mechanism follows the notes below.

Notes:
- LoRA is supported with the DTensor v2 and Megatron backends. DTensor v1 does not support LoRA, so ensure `policy.dtensor_cfg._v2=true` when using DTensor.
- Triton kernels are only used in the DTensor v2 path. For TP > 1, Automodel currently does not support Triton kernels (see the note below).
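
To build intuition for the mechanism, here is a minimal, self-contained PyTorch sketch of a LoRA-wrapped linear layer. This is illustrative only (NeMo RL applies LoRA through Automodel's `LinearLoRA` modules, not this class), and the layer sizes are arbitrary:

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen base linear plus a trainable low-rank update: y = Wx + (alpha/dim) * B(A(x))."""

    def __init__(self, base: nn.Linear, dim: int = 8, alpha: int = 32):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # the base model stays frozen
        self.lora_A = nn.Linear(base.in_features, dim, bias=False)
        self.lora_B = nn.Linear(dim, base.out_features, bias=False)
        nn.init.xavier_normal_(self.lora_A.weight)
        nn.init.zeros_(self.lora_B.weight)  # zero init: training starts from the base model
        self.scaling = alpha / dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))


layer = LoRALinear(nn.Linear(4096, 4096), dim=8, alpha=32)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 65536 trainable parameters vs. ~16.8M frozen ones in the base layer
```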

### Configuration Parameters

The LoRA configuration is specified under the `policy.dtensor_cfg.lora_cfg` section:

```yaml
policy:
  dtensor_cfg:
    lora_cfg:
      enabled: false            # Set to true to enable LoRA fine-tuning
      target_modules: []        # List of module names to apply LoRA to
      exclude_modules: []       # List of module names to exclude from LoRA
      match_all_linear: true    # Apply LoRA to all linear layers
      dim: 8                    # LoRA rank (r): controls adaptation capacity
      alpha: 32                 # LoRA scaling factor (the update is scaled by alpha/dim)
      dropout: 0.0              # Dropout probability for LoRA layers
      dropout_position: "post"  # Dropout position: "pre" or "post"
      lora_A_init: "xavier"     # Initialization method: "xavier" or "uniform"
      use_triton: true          # Use Triton-optimized kernels (DTensor v2 path)
```
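
To adapt only specific projections instead of every linear layer, set `match_all_linear` to `false` and list the modules in `target_modules`. As a sketch, here is such a configuration expressed as the `LoRAConfig` TypedDict this PR adds under `nemo_rl.models.policy`; the module names `q_proj` and `v_proj` are hypothetical and depend on your model architecture:

```python
from nemo_rl.models.policy import LoRAConfig

# Hypothetical: adapt only the attention query/value projections.
attention_only_lora: LoRAConfig = {
    "enabled": True,
    "target_modules": ["q_proj", "v_proj"],  # model-dependent module names
    "exclude_modules": [],
    "match_all_linear": False,  # must be false so target_modules takes effect
    "dim": 8,
    "alpha": 32,
    "dropout": 0.0,
    "dropout_position": "post",
    "lora_A_init": "xavier",
    "use_triton": True,
}
```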

### Parameter Details
- **`enabled`** (bool): Whether to enable LoRA training
- **`target_modules`** (list): Specific module names to apply LoRA to. An empty list with `match_all_linear=true` applies LoRA to all linear layers
- **`exclude_modules`** (list): Module names to exclude from LoRA
- **`match_all_linear`** (bool): When `true`, applies LoRA to all linear layers (overrides `target_modules`)
- **`dim`** (int): LoRA rank (r). Lower values mean fewer trainable parameters but less capacity; typical values are 4, 8, 16, 32, and 64 (see the sketch after this list)
- **`alpha`** (int): LoRA scaling factor; the LoRA update is scaled by `alpha/dim`. Typical values are 16, 32, and 64
- **`dropout`** (float): Dropout probability for regularization
- **`dropout_position`** (str): Apply dropout before (`"pre"`) or after (`"post"`) the LoRA transformation
- **`lora_A_init`** (str): Initialization method for the LoRA A matrix (`"xavier"` or `"uniform"`)
- **`use_triton`** (bool): Use Triton-optimized kernels for better performance (DTensor v2 only). **Note**: [Automodel does not support Triton for TP > 1](https://github.com/NVIDIA-NeMo/Automodel/blob/b2db55eee98dfe81a8bfe5e23ac4e57afd8ab261/nemo_automodel/recipes/llm/train_ft.py#L199). Set to `false` when `tensor_parallel_size > 1` to avoid compatibility issues
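
As referenced in the `dim` entry above, here is a back-of-the-envelope sketch of how rank affects the added parameter count for a single linear layer. Actual totals depend on which modules LoRA targets and on the model's hidden sizes; the 4096x4096 projection below is just an example:

```python
def lora_overhead(in_features: int, out_features: int, dim: int) -> float:
    """Fraction of one linear layer's weights that LoRA adds at rank `dim`."""
    added = dim * (in_features + out_features)  # A is (dim, in), B is (out, dim)
    return added / (in_features * out_features)


# Example: a square 4096x4096 projection (4096 is Llama-3.1-8B's hidden size).
for dim in (4, 8, 16, 32, 64):
    print(f"rank {dim:>2}: LoRA adds {lora_overhead(4096, 4096, dim):.2%} of the layer's weights")
```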

### Example Usage

```bash
uv run examples/run_sft.py policy.dtensor_cfg.lora_cfg.enabled=true
```

For more details on LoRA, see [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685).
@@ -0,0 +1,41 @@
defaults: ../../sft.yaml
sft:
  max_num_steps: 350
  val_period: 20
  val_global_batch_size: 128
  val_micro_batch_size: 2
checkpointing:
  checkpoint_dir: results/sft-tmblog-llama3.1-8b
  save_period: 20
policy:
  model_name: meta-llama/Llama-3.1-8B
  tokenizer:
    name: meta-llama/Llama-3.1-8B-Instruct
    chat_template: default
  dtensor_cfg:
    lora_cfg:
      enabled: true
      dim: 128
      alpha: 128
  train_global_batch_size: 128
  max_total_sequence_length: 4096
  make_sequence_length_divisible_by: 2
  optimizer:
    kwargs:
      lr: 2.0e-05
      weight_decay: 0.01
      eps: 1.0e-08
data:
  dataset_name: tulu3
  add_generation_prompt: true
  seed: 42
logger:
  log_dir: logs/sft-tmblog-llama3.1-8b
  tensorboard_enabled: false
  wandb:
    project: nemo-rl
    name: sft-tmblog-llama3.1-8b
  tensorboard:
    log_dir: tb_logs-sft-dev-tulu3
cluster:
  gpus_per_node: 8
14 changes: 14 additions & 0 deletions examples/configs/sft.yaml
@@ -36,6 +36,7 @@ policy:
  offload_optimizer_for_logprob: false

  dtensor_cfg:
    _v2: true
    enabled: true
    env_vars: {}
    cpu_offload: False
@@ -44,6 +45,19 @@ policy:
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null

    # LoRA (Low-Rank Adaptation) Configuration
    lora_cfg:
      enabled: False # Set to True to enable LoRA fine-tuning
      target_modules: [] # List of module names to apply LoRA to (an empty list with match_all_linear=true applies LoRA to all linear layers)
      exclude_modules: [] # List of module names to exclude from LoRA
      match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules)
      dim: 8 # LoRA rank (r): lower rank = fewer parameters but less capacity. Typical values: 4, 8, 16, 32, 64
      alpha: 32 # LoRA scaling factor: the LoRA update is scaled by alpha/dim. Typical values: 16, 32, 64
      dropout: 0.0 # Dropout probability applied to LoRA layers (0.0 = no dropout)
      dropout_position: "post" # Where to apply dropout: "pre" (before LoRA) or "post" (after LoRA)
      lora_A_init: "xavier" # Initialization method for LoRA A matrix: "xavier" or "uniform"
      use_triton: true # Use Triton-optimized kernels for LoRA (faster but requires flash-attn). Disable when tensor_parallel_size > 1

  dynamic_batching:
    enabled: false
18 changes: 18 additions & 0 deletions nemo_rl/models/policy/__init__.py
@@ -17,6 +17,23 @@
from nemo_rl.models.generation.interfaces import GenerationConfig


class LoRAConfigDisabled(TypedDict):
    enabled: Literal[False]


class LoRAConfig(TypedDict):
    enabled: Literal[True]
    target_modules: list[str]
    exclude_modules: list[str]
    match_all_linear: NotRequired[bool]
    dim: int
    alpha: int
    dropout: float
    dropout_position: Literal["pre", "post"]
    lora_A_init: str
    use_triton: NotRequired[bool]


class DTensorConfigDisabled(TypedDict):
    enabled: Literal[False]

@@ -32,6 +49,7 @@ class DTensorConfig(TypedDict):
    context_parallel_size: int
    custom_parallel_plan: str | None
    clear_cache_every_n_steps: NotRequired[int | None]
    lora_cfg: NotRequired[LoRAConfig | LoRAConfigDisabled]


class SequencePackingConfigDisabled(TypedDict):
4 changes: 4 additions & 0 deletions nemo_rl/models/policy/lm_policy.py
@@ -112,6 +112,10 @@ def __init__(
        if use_v2:
            worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2"
        else:
            assert (
                config["dtensor_cfg"].get("lora_cfg", {}).get("enabled", False)
                is False
            ), "LoRA is not supported for DTensorPolicyWorker V1"
            worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker"

        tp_size = config["dtensor_cfg"]["tensor_parallel_size"]
40 changes: 40 additions & 0 deletions nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
@@ -14,18 +14,24 @@

import gc
import itertools
import math
import os
import warnings
from collections import defaultdict
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import Any, Generator, Optional, cast

import nemo_automodel.components._peft.lora as _lora_mod
import ray
import torch
from accelerate import init_empty_weights
from nemo_automodel import (
    NeMoAutoModelForSequenceClassification,
)
from nemo_automodel.components._peft.lora import (
    PeftConfig,
    apply_lora_to_linear_modules,
)
from nemo_automodel.components.distributed.cp_utils import (
    create_context_parallel_ctx,
    get_train_context,
@@ -93,6 +99,15 @@
from nemo_rl.utils.packed_tensor import packed_broadcast_producer


# TODO: @ruit remove this once Automodel is bumped to 2d20e33a19d5e53a271b1403b507475e68ad14dc (https://github.com/NVIDIA-NeMo/RL/issues/1586)
def _patched_init_lora_weights(self, init_method: str):
    # "xavier" uses Xavier-normal init for lora_A; any other value falls back to
    # Kaiming-uniform, matching nn.Linear's default.
    if init_method == "xavier":
        nn.init.xavier_normal_(self.lora_A.weight.data)
    else:
        nn.init.kaiming_uniform_(self.lora_A.weight.data, a=math.sqrt(5))
    # lora_B is zeroed so the LoRA delta is zero at initialization.
    self.lora_B.weight.data.zero_()


@ray.remote(
    runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2")
)  # pragma: no cover
@@ -222,6 +237,23 @@ def __init__(

        full_state_dict = None
        model_state_dict_keys = None

        # LoRA config
        lora_cfg = self.cfg["dtensor_cfg"].get("lora_cfg", None)
        self.peft_config = None
        self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"]
        # Patch init_lora_weights so the "xavier" initialization is supported (see TODO above).
        _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights
        if self.lora_enabled:
            if self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1:
                assert not lora_cfg["use_triton"], (
                    "Triton is not supported when tensor_parallel_size > 1"
                )
            # Always use float32 since FSDP requires all parameters to be in the same dtype.
            # autocast should cast the weights to the correct dtype during the forward pass.
            cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"}
            self.peft_config = PeftConfig.from_dict(cfg_dict_with_dtype)

        if self.rank == 0:
            print(f"[Rank {self.rank}] Loading model {model_name} on CPU...")
            model = model_class.from_pretrained(
@@ -233,6 +265,9 @@
                torch_dtype=str(model_config.torch_dtype),
            )

            if self.lora_enabled:
                apply_lora_to_linear_modules(model, self.peft_config)

            full_state_dict = model.state_dict()
            # Store the original model state dict keys before any parallelization
            model_state_dict_keys = list(full_state_dict.keys())
@@ -255,6 +290,8 @@
                trust_remote_code=True,
                torch_dtype=str(model_config.torch_dtype),
            )
            if self.lora_enabled:
                apply_lora_to_linear_modules(self.model, self.peft_config)

        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = tokenizer.pad_token_id
@@ -1857,6 +1894,9 @@ def save_checkpoint(
"peft_config",
}
}
if self.lora_enabled:
checkpoint_kwargs["is_peft"] = True
checkpoint_kwargs["peft_config"] = self.peft_config

save_checkpoint(
model=self.model,
46 changes: 46 additions & 0 deletions tests/functional/test_automodel_lora_sft.sh
@@ -0,0 +1,46 @@
#!/bin/bash

# clean up checkpoint directory on exit
trap "rm -rf /tmp/lora_sft_checkpoints" EXIT

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory $PROJECT_ROOT

set -eou pipefail

EXP_NAME=$(basename $0 .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

rm -rf $EXP_DIR $LOG_DIR
mkdir -p $EXP_DIR $LOG_DIR

cd $PROJECT_ROOT
uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
    $PROJECT_ROOT/examples/run_sft.py \
    policy.model_name=Qwen/Qwen3-0.6B \
    cluster.gpus_per_node=2 \
    sft.max_num_steps=3 \
    sft.val_batches=1 \
    sft.val_period=3 \
    policy.dtensor_cfg.lora_cfg.enabled=true \
    logger.tensorboard_enabled=true \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=false \
    logger.monitor_gpus=true \
    checkpointing.enabled=true \
    checkpointing.save_period=3 \
    checkpointing.checkpoint_dir=/tmp/lora_sft_checkpoints \
    "$@" \
    2>&1 | tee $RUN_LOG

uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

uv run tests/check_metrics.py $JSON_METRICS \
    'data["train/loss"]["3"] < 5.9'

@@ -0,0 +1,43 @@
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=50
MAX_STEPS=50
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
NUM_MINUTES=30
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_sft.py \
    --config $CONFIG_PATH \
    sft.max_num_steps=$MAX_STEPS \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=True \
    logger.wandb.project=ruit_personal_debug \
    logger.wandb.name=$EXP_NAME \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir=$CKPT_DIR \
    "$@" \
    2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# TODO: memory check will fail due to OOM tracked here https://github.com/NVIDIA-NeMo/RL/issues/263

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py $JSON_METRICS \
        'data["train/loss"]["1"] < 1.0' \
        'data["train/loss"]["50"] < 0.8' \
        'max(data["ray/node.0.gpu.0.mem_gb"]) < 50' \
        'mean(data["timing/train/total_step_time"], 2) < 10'
fi
3 changes: 3 additions & 0 deletions tests/test_suites/nightly.txt
@@ -66,6 +66,9 @@ tests/test_suites/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.sh
tests/test_suites/llm/sft-llama3.1-8b-1n8g-fsdp2tp2.sh
# dynamic batching
tests/test_suites/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.sh
# lora
# Tulu3 dataset is not supported yet. Re-enable this test once PR https://github.com/NVIDIA-NeMo/RL/pull/1506 is merged.
# tests/test_suites/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-lora.sh

# Functional 32b test
tests/test_suites/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.sh