Skip to content

Commit 558382f

Browse files
Add warm up fix.
Signed-off-by: Wangshanshan <[email protected]>
1 parent 94f9489 commit 558382f

1 file changed

Lines changed: 113 additions & 15 deletions

File tree

tests/integration/defs/stress_test/stress_test.py

Lines changed: 113 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,46 @@ def check_server_health(server_url: str,
349349
return False, f"Unexpected error during health check: {str(e)}"
350350

351351

352+
def warmup_inference(server_url: str,
                     model_name: str,
                     timeout: float = 300.0,
                     num_warmup_requests: int = 2) -> bool:
    """
    Prime the server's inference path with a handful of tiny completions.

    A passing /health check only shows that the HTTP server is up; the first
    real inference can be orders of magnitude slower (JIT, CUDA graphs,
    memory pools). Issuing a few cheap requests here keeps that one-time cost
    out of the window that aiperf times.

    Args:
        server_url: Base URL of the server; "/v1/completions" is appended.
        model_name: Value sent as the "model" field of every request.
        timeout: Timeout in seconds applied to each individual POST.
        num_warmup_requests: How many requests to send, one after another.

    Returns:
        True when every request came back with HTTP 200. False as soon as
        one request raises requests.RequestException or returns any other
        status; remaining requests are not attempted.
    """
    completions_url = f"{server_url}/v1/completions"
    # Deliberately tiny and deterministic: short prompt, few tokens, greedy.
    request_body = dict(
        model=model_name,
        prompt="Hello",
        max_tokens=8,
        temperature=0.0,
    )

    for i in range(num_warmup_requests):
        try:
            print_info(
                f"Sending inference warmup request {i+1}/{num_warmup_requests}..."
            )
            resp = requests.post(completions_url,
                                 json=request_body,
                                 timeout=timeout)
            if resp.status_code != 200:
                # Only the first 200 characters of the body are logged.
                print_warning(
                    f"Warmup request {i+1} returned status {resp.status_code}: "
                    f"{resp.text[:200]}")
                return False
            print_info(
                f"Warmup request {i+1}/{num_warmup_requests} succeeded")
        except requests.RequestException as e:
            print_warning(f"Warmup request {i+1} failed: {e}")
            return False

    print_info("Inference warmup complete")
    return True
390+
391+
352392
def is_port_available(port: int,
353393
host: str = "localhost") -> Tuple[bool, Optional[str]]:
354394
"""
@@ -753,6 +793,14 @@ def stress_test(config,
753793
print_info(
754794
f"Server is running with model {model_name}. Starting tests...")
755795

796+
# Warm up the inference pipeline before benchmarking.
797+
# The /health endpoint only confirms the HTTP layer is up; the
798+
# first real inference triggers JIT/CUDA-graph compilation that
799+
# can take much longer than aiperf's per-request timeout.
800+
if not warmup_inference(test_server_config.url, model_name):
801+
print_warning("Inference warmup failed -- proceeding anyway, "
802+
"but aiperf may hit startup timeouts")
803+
756804
# Run baseline accuracy test first if enabled
757805
baseline_accuracy_success = True
758806
if stress_config and stress_config.enable_accuracy_test:
@@ -852,7 +900,8 @@ def create_aiperf_command(model_name,
852900
input_len_std=PerformanceParams.input_len_std,
853901
output_len_mean=PerformanceParams.output_len_mean,
854902
output_len_std=PerformanceParams.output_len_std,
855-
warmup_request_count=10):
903+
warmup_request_count=10,
904+
request_timeout_seconds=120.0):
856905
"""
857906
Create a command list for aiperf with standardized parameters.
858907
@@ -867,6 +916,7 @@ def create_aiperf_command(model_name,
867916
output_len_mean: Mean output length
868917
output_len_std: Standard deviation of output length
869918
warmup_request_count: Number of warmup requests
919+
request_timeout_seconds: Per-request timeout in seconds for aiperf
870920
871921
Returns:
872922
List of command-line arguments for aiperf
@@ -898,6 +948,8 @@ def create_aiperf_command(model_name,
898948
str(concurrency),
899949
"--warmup-request-count",
900950
str(warmup_request_count),
951+
"--request-timeout-seconds",
952+
str(request_timeout_seconds),
901953
# "--verbose",
902954
]
903955

@@ -910,6 +962,9 @@ def run_aiperf_process(cmd,
910962
"""
911963
Run an aiperf process and monitor both the process and server health.
912964
965+
Captures stdout/stderr so that aiperf's error output is visible in the
966+
pytest report when it exits with a non-zero code.
967+
913968
Args:
914969
cmd: Command list to execute aiperf
915970
test_start_time: Start time of the test
@@ -920,29 +975,53 @@ def run_aiperf_process(cmd,
920975
Returns:
921976
Boolean indicating whether the process completed successfully
922977
"""
923-
# Start aiperf process with our context manager
924-
with launch_process(cmd,
925-
start_new_session=True,
926-
filter_pattern=None,
927-
request_counter=request_counter) as process:
928-
# Set monitoring parameters
978+
stdout_lines = []
979+
stderr_lines = []
980+
stdout_lock = threading.Lock()
981+
stderr_lock = threading.Lock()
982+
983+
def _capture_and_print(pipe, line_buffer, lock):
984+
try:
985+
for line in iter(pipe.readline, ''):
986+
print(line, end='', flush=True)
987+
with lock:
988+
line_buffer.append(line)
989+
except (BrokenPipeError, IOError, ValueError):
990+
pass
991+
992+
process = Popen(cmd,
993+
start_new_session=True,
994+
stdout=subprocess.PIPE,
995+
stderr=subprocess.PIPE,
996+
bufsize=1,
997+
universal_newlines=True)
998+
print_info(f"Process started with PID: {process.pid}")
999+
1000+
stdout_reader = threading.Thread(target=_capture_and_print,
1001+
args=(process.stdout, stdout_lines,
1002+
stdout_lock),
1003+
daemon=True)
1004+
stderr_reader = threading.Thread(target=_capture_and_print,
1005+
args=(process.stderr, stderr_lines,
1006+
stderr_lock),
1007+
daemon=True)
1008+
stdout_reader.start()
1009+
stderr_reader.start()
1010+
1011+
try:
9291012
last_health_check = time.time()
9301013
process_completed = False
9311014

932-
# Monitor both the server and aiperf process
9331015
while process.poll() is None:
9341016
current_time = time.time()
9351017

936-
# Check if aiperf is still running but exceeded timeout
9371018
elapsed_time = current_time - test_start_time
9381019
if elapsed_time > test_timeout:
9391020
cleanup_process_tree(process, has_session=True)
9401021
raise RuntimeError(
9411022
f"aiperf test timed out after {test_timeout} seconds")
9421023

943-
# Check server health periodically
9441024
if current_time - last_health_check > server_config.health_check_timeout:
945-
9461025
is_healthy, error_msg = check_server_health(
9471026
server_config.url, server_config.health_check_timeout)
9481027

@@ -951,31 +1030,50 @@ def run_aiperf_process(cmd,
9511030
f"Server health check passed after {elapsed_time:.1f} seconds of test"
9521031
)
9531032
else:
954-
# Raise an exception to stop the test
9551033
print_warning(f"Server health check failed: {error_msg}")
9561034
cleanup_process_tree(process, has_session=True)
9571035
raise RuntimeError(
9581036
f"Server health check failed during test: {error_msg}")
9591037

960-
# Update last health check time
9611038
last_health_check = current_time
9621039

9631040
time.sleep(0.5)
9641041

965-
# Check final status of aiperf process
1042+
stdout_reader.join(timeout=5)
1043+
stderr_reader.join(timeout=5)
1044+
9661045
retcode = process.poll()
9671046
if retcode is not None:
9681047
if retcode != 0:
1048+
with stderr_lock:
1049+
captured_stderr = ''.join(stderr_lines[-50:])
1050+
with stdout_lock:
1051+
captured_stdout = ''.join(stdout_lines[-50:])
9691052
cleanup_process_tree(process, has_session=True)
9701053
raise RuntimeError(
971-
f"aiperf exited with non-zero code: {retcode}")
1054+
f"aiperf exited with non-zero code: {retcode}\n"
1055+
f"--- aiperf stdout (last 50 lines) ---\n"
1056+
f"{captured_stdout}\n"
1057+
f"--- aiperf stderr (last 50 lines) ---\n"
1058+
f"{captured_stderr}")
9721059
else:
9731060
print_info("aiperf completed successfully")
9741061
process_completed = True
9751062
else:
9761063
cleanup_process_tree(process, has_session=True)
9771064
raise RuntimeError(
9781065
"aiperf did not complete normally, will terminate")
1066+
finally:
1067+
if process.poll() is None:
1068+
process.terminate()
1069+
try:
1070+
process.wait(timeout=GRACEFUL_TERMINATION_TIMEOUT)
1071+
except subprocess.TimeoutExpired:
1072+
cleanup_process_tree(process, has_session=True)
1073+
if process.stdout:
1074+
process.stdout.close()
1075+
if process.stderr:
1076+
process.stderr.close()
9791077

9801078
return process_completed
9811079

0 commit comments

Comments
 (0)