
SFT always stopped after 1 epoch #254

@songdezhao

Description


Hello, I am doing SFT with Llama-3.1-8B-Instruct. In the config I set it to train for 3 epochs; however, training always stops after the first epoch and prints the message "Training finished".

Here is the code. Am I doing anything wrong here?

Here is how to reproduce.

Environment:
transformers==4.54.1
deepspeed==0.17.4
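
For reference, the installed versions can be confirmed like this (both packages expose __version__):

import transformers
import deepspeed

# Confirm the installed versions match the pins above.
print(transformers.__version__)  # expected: 4.54.1
print(deepspeed.__version__)     # expected: 0.17.4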

Step 1: I installed ArcticTraining with this command:

git clone https://github.com/snowflakedb/ArcticTraining.git && cd ArcticTraining && pip install -e .

Step 2: I downloaded "meta-llama/Llama-3.1-8B-Instruct" to a local directory: "custom_model".
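
For reference, the download can be done with huggingface_hub, roughly like this (this is only a sketch; the repo is gated, so an accepted license and a Hugging Face token are required):

from huggingface_hub import snapshot_download

# Fetch the model weights, tokenizer, and config into ./custom_model.
# Note: meta-llama/Llama-3.1-8B-Instruct is a gated repo, so you must have
# accepted the license and be logged in with a Hugging Face token.
snapshot_download(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",
    local_dir="custom_model",
)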

Step 3: I ran the following training code:

from arctic_training import get_config, SFTTrainer
from arctic_training.config.enums import DType

if __name__ == '__main__':
    config_dict = {
        "type": "sft",
        "seed": 1000,
        "epochs": 3,
        "micro_batch_size": 512,
        "activation_checkpoint_cpu_offload": False,
        "tiled_mlp_compute": True,
        "sequence_parallel_size": 8,
        "deepspeed": {
            "zero_optimization": {
                "stage": 3
            },
            "seq_parallel_communication_data_type": "bf16"
        },
        "optimizer": {
            "type": "fusedadam",
            "learning_rate": 0.000001
        },
        "model": {
            "type": "liger",
            "name_or_path": "./custom_model",
            "dtype": DType.BF16,
            "attn_implementation": "flash_attention_2"
        },
        "data": {
            "type": "sft",
            "dl_num_workers": 8,
            "max_length": 500,
            "pack_samples": False,
            "sources": [
                {
                    "type": "huggingface_instruct",
                    "name_or_path": "HuggingFaceH4/ultrachat_200k",
                    "split": "train_sft",
                    "role_mapping": {
                        "user": "messages.role.user",
                        "assistant": "messages.role.assistant"
                    }
                }
            ]
        },
        "logger": {
            "level": "INFO",
            "output_dir": "logs",
            "print_output_ranks": [0]
        },
        "checkpoint": [
            {
                "type": "huggingface",
                "save_every_n_epochs": 1,
                "save_end_of_training": False,
                "output_dir": "./"
            }
        ]
    }

    config = get_config(config_dict)
    trainer = SFTTrainer(config)
    trainer.train()
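
One sanity check I can think of is printing the parsed config before training to confirm the epoch count survives parsing. This assumes the object returned by get_config exposes the epochs field as an attribute, which I have not verified against ArcticTraining's config classes:

# Sanity check (assumption: the object returned by get_config mirrors the
# input dict and exposes `epochs` as an attribute; not verified in the
# ArcticTraining source). Uses config_dict as defined in the script above.
config = get_config(config_dict)
print("parsed epochs:", config.epochs)  # expected: 3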
