Describe the bug
I am trying to test the DeepSpeed framework. The following code (main.py) includes a model I generated for testing, training data (with the random seed fixed at 42 to avoid nondeterminism), and the DeepSpeed configuration.
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepspeed
import os
import torch.distributed as dist
from torch.utils.data import Dataset
import random
import numpy as np


def set_seed(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(42)


def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")


class RandomDataset(Dataset):
    def __init__(self, num_samples=3200, input_dim=32, num_classes=10):
        self.x = torch.randn(num_samples, input_dim)
        self.y = torch.randint(0, num_classes, (num_samples,))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class RandomNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        layers = []
        dim = input_dim
        for _ in range(10):
            next_dim = 64
            layers.append(nn.Linear(dim, next_dim))
            layers.append(nn.ReLU())
            dim = next_dim
        layers.append(nn.Linear(dim, output_dim))
        self.net = nn.Sequential(*layers)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, labels=None):
        logits = self.net(x)
        if labels is not None:
            return self.criterion(logits, labels)
        return logits


def generate_fake_data(batch_size, input_dim, output_dim):
    x = torch.randn(batch_size, input_dim)
    y = torch.randint(0, output_dim, (batch_size,))
    return x, y


ds_config = {
    "train_batch_size": 32,
    "train_micro_batch_size_per_gpu": 4,
    "optimizer": {
        "type": "AdamW"
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0.10873162,
            "warmup_max_lr": 0.146191
        }
    },
    "fp16": {
        "enabled": True,
        "auto_cast": True
    },
    "zero_optimization": {
        "stage": 1
    }
}


def main():
    input_dim = 32
    output_dim = 10
    model = RandomNet(input_dim, output_dim)
    trainset = RandomDataset(num_samples=32 * 100, input_dim=32, num_classes=10)
    model_engine, _, trainloader, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
        training_data=trainset
    )
    rank = int(os.getenv("LOCAL_RANK", "0"))
    try:
        for step, batch in enumerate(trainloader):
            x, y = batch
            x = x.to(model_engine.device)
            y = y.to(model_engine.device)
            loss = model_engine(x, labels=y)
            model_engine.backward(loss)
            model_engine.step()
            print(f"[Rank {rank}] Step {step} | Loss = {loss.item():.4f}")
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()
When I run the command deepspeed --num_gpus=4 main.py, the training hangs after 16 steps until NCCL times out. However, if I remove any one of scheduler, fp16, or zero_optimization from the configuration, the training completes successfully (one such working variant is sketched below).
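For concreteness, this is one of the working variants (a sketch: the same ds_config as above with only the fp16 block removed and nothing else changed; removing scheduler or zero_optimization instead behaves the same way). The name ds_config_no_fp16 is just for illustration.

ds_config_no_fp16 = {
    "train_batch_size": 32,
    "train_micro_batch_size_per_gpu": 4,
    "optimizer": {
        "type": "AdamW"
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0.10873162,
            "warmup_max_lr": 0.146191
        }
    },
    # "fp16" block removed; with this config the run finishes all steps.
    "zero_optimization": {
        "stage": 1
    }
}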
To Reproduce
Steps to reproduce the behavior:
- Run deepspeed --num_gpus=4 main.py
- The training hangs after 16 steps until NCCL times out
Expected behavior
The training process should not hang. If there is an error, it should at least raise an exception instead of stalling.
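In case it helps triage, the hang can be made to surface much sooner than the default NCCL timeout by initializing the process group with a short timeout before deepspeed.initialize(). This is only a debugging aid, not a fix, and it assumes deepspeed.init_distributed accepts a timeout argument (as torch.distributed.init_process_group does):

from datetime import timedelta
import deepspeed

# Debugging aid only (assumption: deepspeed.init_distributed forwards `timeout`
# to the underlying process group): abort after ~2 minutes instead of waiting
# for the default NCCL timeout when a collective gets stuck.
deepspeed.init_distributed(dist_backend="nccl", timeout=timedelta(minutes=2))

# ...then call deepspeed.initialize(...) exactly as in main.py above.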
ds_report output
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
gds .................... [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.5
[WARNING] using untested triton version (3.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/yanzhen/miniconda3/envs/deepspeed/lib/python3.10/site-packages/torch']
torch version .................... 2.5.1
deepspeed install path ........... ['/home/yanzhen/miniconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.17.6+533e834b, 533e834b, master
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 12.4
deepspeed wheel compiled w. ...... torch 2.5, cuda 11.8
shared memory (/dev/shm) size .... 503.83 GB
System info (please complete the following information):
- Ubuntu 22.04
- one machine with 4x RTX 4090s
- Python 3.10.18
Launcher context
deepspeed
Docker context
No