Skip to content

Megatron inference crash due to flash infer from latest mcore bump #1633

@terrykong

Description

@terrykong
uv run --group test bash tests/run_unit.sh -k test_megatron_policy_generation

error:

_______________________________________________________________________________________________________________________________ test_megatron_policy_generation[2gpu_dp2_megatron] ________________________________________________________________________________________________________________________________

generation_setup = (<nemo_rl.models.policy.lm_policy.Policy object at 0x774db1a87dd0>, <nemo_rl.distributed.virtual_cluster.RayVirtualClu...o, how are you?', 'The capital of France is', 'Write a short story about', 'Explain quantum physics in simple terms:'])

    @pytest.mark.timeout(240)
    @pytest.mark.parametrize(
        "generation_setup",
        [
            # (num_gpus, tp, pp, generation_backend)
            (2, 1, 1, "megatron"),
            (2, 2, 1, "megatron"),
        ],
        indirect=True,
        ids=["2gpu_dp2_megatron", "2gpu_tp2_megatron"],
    )
    def test_megatron_policy_generation(generation_setup):
        """Test Megatron policy generation with different backends."""
        policy, cluster, data, prompts = generation_setup
    
        # Verify resources were created properly
        assert policy is not None, "Generation policy was not created properly"
        assert cluster is not None, "Generation cluster was not created properly"
        assert data is not None, "Test data was not created properly"
    
        # Call prepare_for_generation
        print("Preparing for generation...")
        policy.prepare_for_generation()
    
        # Generate text
        print("Generating text...")
>       results = policy.generate(data, greedy=True)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

unit/models/policy/test_megatron_worker.py:576: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../nemo_rl/models/policy/lm_policy.py:578: in generate
    self.worker_group.get_all_worker_results(futures),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../nemo_rl/distributed/worker_groups.py:953: in get_all_worker_results
    return future_bundle.get_results(
../nemo_rl/distributed/worker_groups.py:103: in get_results
    all_results = ray.get(object_refs)
                  ^^^^^^^^^^^^^^^^^^^^
/opt/nemo_rl_venv/lib/python3.12/site-packages/ray/_private/auto_init_hook.py:22: in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
/opt/nemo_rl_venv/lib/python3.12/site-packages/ray/_private/client_mode_hook.py:104: in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
/opt/nemo_rl_venv/lib/python3.12/site-packages/ray/_private/worker.py:2882: in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <ray._private.worker.Worker object at 0x77501d8029c0>, object_refs = [ObjectRef(1e8ff6d2361327847db1c9eac7fcddf5dacf325a0100000001000000), ObjectRef(85748392bcd969ccf9c68f78d8701fd67c1f0e8d0100000001000000)], timeout = None, return_exceptions = False, skip_deserialization = False

    def get_objects(
        self,
        object_refs: list,
        timeout: Optional[float] = None,
        return_exceptions: bool = False,
        skip_deserialization: bool = False,
    ) -> Tuple[List[serialization.SerializedRayObject], bytes]:
        """Get the values in the object store associated with the IDs.
    
        Return the values from the local object store for object_refs. This
        will block until all the values for object_refs have been written to
        the local object store.
    
        Args:
            object_refs: A list of the object refs
                whose values should be retrieved.
            timeout: The maximum amount of time in
                seconds to wait before returning.
            return_exceptions: If any of the objects deserialize to an
                Exception object, whether to return them as values in the
                returned list. If False, then the first found exception will be
                raised.
            skip_deserialization: If true, only the buffer will be released and
                the object associated with the buffer will not be deserialized.
        Returns:
            list: List of deserialized objects or None if skip_deserialization is True.
            bytes: UUID of the debugger breakpoint we should drop
                into or b"" if there is no breakpoint.
        """
        # Make sure that the values are object refs.
        for object_ref in object_refs:
            if not isinstance(object_ref, ObjectRef):
                raise TypeError(
                    f"Attempting to call `get` on the value {object_ref}, "
                    "which is not an ray.ObjectRef."
                )
    
        timeout_ms = (
            int(timeout * 1000) if timeout is not None and timeout != -1 else -1
        )
        serialized_objects: List[
            serialization.SerializedRayObject
        ] = self.core_worker.get_objects(
            object_refs,
            timeout_ms,
        )
    
        debugger_breakpoint = b""
        for data, metadata, _ in serialized_objects:
            if metadata:
                metadata_fields = metadata.split(b",")
                if len(metadata_fields) >= 2 and metadata_fields[1].startswith(
                    ray_constants.OBJECT_METADATA_DEBUG_PREFIX
                ):
                    debugger_breakpoint = metadata_fields[1][
                        len(ray_constants.OBJECT_METADATA_DEBUG_PREFIX) :
                    ]
        if skip_deserialization:
            return None, debugger_breakpoint
    
        values = self.deserialize_objects(serialized_objects, object_refs)
        if not return_exceptions:
            # Raise exceptions instead of returning them to the user.
            for i, value in enumerate(values):
                if isinstance(value, RayError):
                    if isinstance(value, ray.exceptions.ObjectLostError):
                        global_worker.core_worker.log_plasma_usage()
                    if isinstance(value, RayTaskError):
>                       raise value.as_instanceof_cause()
E                       ray.exceptions.RayTaskError: ray::MegatronPolicyWorker.generate() (pid=2680520, ip=172.17.0.2, actor_id=7db1c9eac7fcddf5dacf325a01000000, repr=MegatronPolicyWorker[rank=0])
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/nemo_rl/utils/nsys.py", line 88, in wrapper
E                           ret = func(*args, **kwargs)
E                                 ^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.py", line 1879, in generate
E                           dynamic_engine = DynamicInferenceEngine(
E                                            ^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/engines/dynamic_engine.py", line 186, in __init__
E                           self.create_cuda_graphs()
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/engines/dynamic_engine.py", line 260, in create_cuda_graphs
E                           controller._dynamic_step_forward_logits(input_ids, position_ids)
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/text_generation_controllers/text_generation_controller.py", line 553, in _dynamic_step_forward_logits
E                           logits = self.inference_wrapped_model.run_one_forward_step(
E                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
E                           return func(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 389, in run_one_forward_step
E                           return self.forward_pass_without_pipeline_parallel(inference_input)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 213, in forward_pass_without_pipeline_parallel
E                           logits = self._forward(inference_input)
E                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 161, in _forward
E                           return self.model(
E                                  ^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
E                           return self._call_impl(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
E                           return forward_call(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/data_parallel_base.py", line 22, in forward
E                           return self.module(*inputs, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
E                           return self._call_impl(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
E                           return forward_call(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/module.py", line 443, in forward
E                           outputs = self.module(*inputs, **kwargs)
E                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
E                           return self._call_impl(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
E                           return forward_call(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/models/gpt/gpt_model.py", line 478, in forward
E                           hidden_states = self.decoder(
E                                           ^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/transformer_block.py", line 586, in __call__
E                           return super().__call__(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/module.py", line 319, in __call__
E                           return super().__call__(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
E                           return self._call_impl(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
E                           return forward_call(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/transformer_block.py", line 735, in forward
E                           hidden_states, context = layer(
E                                                    ^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 1015, in __call__
E                           return super().__call__(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/module.py", line 319, in __call__
E                           return super().__call__(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
E                           return self._call_impl(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
E                           return forward_call(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 456, in forward
E                           hidden_states, context = self._forward_attention(*args, **kwargs)
E                                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 530, in _forward_attention
E                           attention_output_with_bias = self.self_attention(
E                                                        ^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
E                           return self._call_impl(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
E                           return forward_call(*args, **kwargs)
E                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/attention.py", line 838, in forward
E                           self._adjust_key_value_for_inference(
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/attention.py", line 449, in _adjust_key_value_for_inference
E                           query, key = inference_context.apply_fused_qk_rotary_emb(
E                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/workspaces/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/contexts/dynamic_context.py", line 881, in apply_fused_qk_rotary_emb
E                           query_rope, key_rope = flashinfer.rope.apply_rope_with_cos_sin_cache(
E                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/flashinfer/rope.py", line 1184, in apply_rope_with_cos_sin_cache
E                           _apply_rope_pos_ids_cos_sin_cache(
E                         File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/flashinfer/rope.py", line 340, in _apply_rope_pos_ids_cos_sin_cache
E                           get_rope_module().apply_rope_pos_ids_cos_sin_cache(
E                         File "python/tvm_ffi/cython/function.pxi", line 904, in tvm_ffi.core.Function.__call__
E                       RuntimeError: Error in function 'BatchQKApplyRotaryPosIdsCosSinCache' at /opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/flashinfer/data/include/flashinfer/pos_enc.cuh:1262: Unsupported head_dim: 32

Metadata

Metadata

Labels

bug: Something isn't working
qa_rcca_done: when RCCA finished for the issue, the qa will mark with this label.

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions