Skip to content

Support Qwen 3.5 #1194

@YSQ-boop

Description

@YSQ-boop

我使用openrlhf-0.9.5训练qwen3.5-4b时出现下列错误:

forward: 0%| | 0/28 [00:00<?, ?it/s]
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "OpenRLHF-main/openrlhf/cli/train_ppo_ray.py", line 597, in <module>
train(args)
File "OpenRLHF-main/openrlhf/cli/train_ppo_ray.py", line 175, in train
ray.get(ppo_trainer.fit.remote())
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/ray/_private/worker.py", line 2858, in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/ray/_private/worker.py", line 958, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AcceleratorError): ray::PPOTrainer.fit() (pid=6704, ip=100.96.111.166, actor_id=62d5473dbee7d366ad05c96c01000000, repr=<openrlhf.trainer.ppo_trainer.PPOTrainer object at 0x7f6e8d9ee1e0>)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/shared-storage-user/OpenRLHF-main/openrlhf/trainer/ppo_trainer.py", line 324, in fit
status, global_step = self.train_step(rollout_samples, global_step)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/shared-storage-user/OpenRLHF-main/openrlhf/trainer/ppo_trainer.py", line 113, in train_step
experiences = self.experience_maker.make_experience_batch(rollout_samples)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/shared-storage-user/OpenRLHF-main/openrlhf/trainer/ppo_utils/experience_maker.py", line 544, in make_experience_batch
experiences = self.make_experience(samples_list)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/shared-storage-user/OpenRLHF-main/openrlhf/trainer/ppo_utils/experience_maker.py", line 597, in make_experience
ray.get(action_log_probs_ref)
^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError(AcceleratorError): ray::PolicyModelActor.execute_batch() (pid=6055, ip=100.96.111.166, actor_id=3e27d9ae1e44c1f940a826e401000000, repr=<openrlhf.trainer.ray.ppo_actor.PolicyModelActor object at 0x7f4e0cb79c40>)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/shared-storage-user/OpenRLHF-main/openrlhf/trainer/ray/launcher.py", line 98, in execute_batch
result = func(**sample_kwargs)
^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/shared-storage-user/OpenRLHF-main/openrlhf/trainer/ray/ppo_actor.py", line 541, in forward
action_log_probs = self.actor(
^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/shared-storage-user/OpenRLHF-main/openrlhf/models/actor.py", line 174, in forward
output = self.model(sequences, attention_mask=foward_attention_mask, position_ids=position_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
ret_val = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2316, in forward
loss = self.module(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
return inner()
^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1830, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/transformers/utils/generic.py", line 843, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/transformers/models/qwen3_5/modeling_qwen3_5.py", line 1845, in forward
outputs: BaseModelOutputWithPast = self.model(
^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
return inner()
^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1830, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/transformers/utils/generic.py", line 917, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/transformers/utils/output_capturing.py", line 253, in wrapper
outputs = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/transformers/models/qwen3_5/modeling_qwen3_5.py", line 1372, in forward
hidden_states = decoder_layer(
^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/transformers/modeling_layers.py", line 93, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
return inner()
^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1830, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/transformers/models/qwen3_5/modeling_qwen3_5.py", line 880, in forward
hidden_states = self.mlp(hidden_states)
^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
return inner()
^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1830, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/transformers/models/qwen3_5/modeling_qwen3_5.py", line 804, in forward
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
return inner()
^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1819, in inner
args_result = hook(self, args)
^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1181, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 300, in _pre_forward_module_hook
self.pre_sub_module_forward_function(module)
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 475, in pre_sub_module_forward_function
param_coordinator.fetch_sub_module(sub_module, forward=True)
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1181, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
ret_val = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 325, in fetch_sub_module
self._fetch_sub_module_impl(current_submodule, forward, is_leaf)
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 380, in _fetch_sub_module_impl
self.__ongoing_fetch_events.popleft().synchronize()
File "/data/conda_envs/openrlhf-9.5/lib/python3.12/site-packages/torch/cuda/streams.py", line 245, in synchronize
super().synchronize()
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions