@@ -61,7 +61,13 @@ def get_ucx_tls():
6161
6262 Pre-Hopper GPUs need cuda_ipc excluded from UCX transports.
6363 """
64- if get_sm_version () < 90 :
64+ sm = get_sm_version ()
65+ """
66+ ON some gb300 cluster, we need to set `cuda_copy,cuda_ipc,sm,self,tcp` for UCX_TLS
67+ """
68+ if sm == 103 :
69+ return "cuda_copy,cuda_ipc,sm,self,tcp"
70+ if sm < 90 :
6571 return "^cuda_ipc,ib,gdr_copy"
6672 return "^ib,gdr_copy"
6773
@@ -607,10 +613,13 @@ def run_disaggregated_test(example_dir,
607613 "https://nvbugs/5584607 Ray orchestrator is not supported with NIXL(DEFAULT) cache transceiver backend."
608614 )
609615
616+ run_env = env .copy () if env else os .environ .copy ()
617+ run_env ["UCX_TLS" ] = get_ucx_tls ()
618+
610619 config_file = get_test_config (test_desc , example_dir ,
611620 os .path .dirname (__file__ ))
612621 config , ctx_workers , gen_workers , disagg_server , server_port , work_dir = \
613- setup_disagg_cluster (config_file , model_name = model_path , env = env , cwd = cwd ,
622+ setup_disagg_cluster (config_file , model_name = model_path , env = run_env , cwd = cwd ,
614623 schedule_style = disagg_schedule_style )
615624
616625 server_host = config .get ("hostname" , "localhost" )
@@ -637,7 +646,7 @@ def run_disaggregated_test(example_dir,
637646 client_config_file ,
638647 test_desc ,
639648 num_iters ,
640- env ,
649+ run_env ,
641650 300 , # timeout
642651 prompt_file ,
643652 extra_endpoints_test ,
@@ -781,6 +790,7 @@ def test_disaggregated_benchmark_gen_only_insufficient_kv(
781790 env = llm_venv ._new_env .copy ()
782791 env ['TRTLLM_DISAGG_BENCHMARK_GEN_ONLY' ] = '1'
783792 env ['TLLM_BENCHMARK_REQ_QUEUES_SIZE' ] = '64'
793+ env ["UCX_TLS" ] = get_ucx_tls ()
784794
785795 config_file = get_test_config ("gen_only_insufficient_kv" ,
786796 disaggregated_example_root ,
@@ -2356,7 +2366,7 @@ def extract_logprobs(result, api_type):
23562366
23572367 env = llm_venv ._new_env .copy ()
23582368 env ["TRTLLM_USE_UCX_KVCACHE" ] = "1"
2359- env ["UCX_TLS" ] = "^ib,gdr_copy"
2369+ env ["UCX_TLS" ] = get_ucx_tls ()
23602370 ctx_workers , gen_workers , disagg_server , work_dir = [], [], None , None
23612371 config , ctx_workers , gen_workers , disagg_server , server_port , work_dir = \
23622372 setup_disagg_cluster (config_file , env = env ,
@@ -2520,8 +2530,7 @@ def test_disaggregated_mamba_conc_greater_than_mbs(disaggregated_example_root,
25202530 os .path .dirname (__file__ ))
25212531
25222532 env = llm_venv ._new_env .copy ()
2523- # Need to set UCX_TLS to ^ib to avoid hangs on CI B200 cluster.
2524- env ["UCX_TLS" ] = "^ib"
2533+ env ["UCX_TLS" ] = get_ucx_tls ()
25252534 e2el , ttft = run_disaggregated_benchmark (
25262535 disaggregated_example_root ,
25272536 config_file ,
0 commit comments