1- import json
21import os
32import subprocess
43
5- from tools .stats .s3_stat_parser import (
6- get_previous_reports_for_branch ,
7- Report ,
8- Version2Report ,
9- HAVE_BOTO3 ,
10- )
114from tools .stats .import_test_stats import get_disabled_tests , get_slow_tests
125
13- from typing import Any , Dict , List , Optional , Tuple , cast
14- from typing_extensions import TypedDict
15-
16-
# Schema of the JSON blob written to / read from the test-times file:
# the commit and CI job it was generated for, plus per-test-file runtimes
# in seconds.
JobTimeJSON = TypedDict(
    "JobTimeJSON",
    {"commit": str, "JOB_BASE_NAME": str, "job_times": Dict[str, float]},
)
22-
23- def _get_stripped_CI_job () -> str :
24- return os .environ .get ("BUILD_ENVIRONMENT" , "" )
25-
26-
def _get_job_times_json(job_times: Dict[str, float]) -> "JobTimeJSON":
    """Bundle *job_times* with the current git commit and CI job name."""
    current_commit = subprocess.check_output(
        ["git", "rev-parse", "HEAD"], encoding="ascii"
    ).strip()
    return {
        "commit": current_commit,
        "JOB_BASE_NAME": _get_stripped_CI_job(),
        "job_times": job_times,
    }
35-
36-
37- def _calculate_job_times (reports : List ["Report" ]) -> Dict [str , float ]:
38- """Compute test runtime by filename: ("test_file_name" -> (current_avg, # values))"""
39- jobs_to_times : Dict [str , Tuple [float , int ]] = dict ()
40- for report in reports :
41- v_report = cast (Version2Report , report )
42- assert (
43- "format_version" in v_report .keys () and v_report .get ("format_version" ) == 2
44- ), "S3 format currently handled is version 2 only"
45- files : Dict [str , Any ] = v_report ["files" ]
46- for name , test_file in files .items ():
47- if name not in jobs_to_times :
48- jobs_to_times [name ] = (test_file ["total_seconds" ], 1 )
49- else :
50- curr_avg , curr_count = jobs_to_times [name ]
51- new_count = curr_count + 1
52- new_avg = (
53- curr_avg * curr_count + test_file ["total_seconds" ]
54- ) / new_count
55- jobs_to_times [name ] = (new_avg , new_count )
56-
57- return {job : time for job , (time , _ ) in jobs_to_times .items ()}
6+ from typing import Dict , List , Tuple
587
598
609def calculate_shards (
@@ -91,63 +40,6 @@ def calculate_shards(
9140 return sharded_jobs
9241
9342
def _pull_job_times_from_S3() -> Dict[str, float]:
    """Fetch historic per-test-file runtimes for the current CI job from S3 reports."""
    s3_reports: List["Report"] = []
    if HAVE_BOTO3:
        s3_reports = get_previous_reports_for_branch(
            "origin/viable/strict", _get_stripped_CI_job()
        )
    else:
        print(
            "Uh oh, boto3 is not found. Either it is not installed or we failed to import s3_stat_parser."
        )
        print(
            "If not installed, please install boto3 for automatic sharding and test categorization."
        )

    if not s3_reports:
        print("::warning:: Gathered no reports from S3. Please proceed without them.")
        return {}

    return _calculate_job_times(s3_reports)
114-
115-
def _query_past_job_times(test_times_file: Optional[str] = None) -> Dict[str, float]:
    """Read historic test job times from a file.

    If the file is missing, or was generated for a different commit or CI job,
    fresh data is downloaded from S3 and exported to the file instead.
    """
    if test_times_file and os.path.exists(test_times_file):
        with open(test_times_file) as file:
            cached: JobTimeJSON = json.load(file)

        curr_commit = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], encoding="ascii"
        ).strip()
        curr_ci_job = _get_stripped_CI_job()
        file_commit = cached.get("commit", "")
        file_ci_job = cached.get("JOB_BASE_NAME", "N/A")
        if curr_commit != file_commit:
            print(f"Current test times file is from different commit {file_commit}.")
        elif curr_ci_job != file_ci_job:
            print(f"Current test times file is for different CI job {file_ci_job}.")
        else:
            print(
                f"Found stats for current commit: {curr_commit} and job: {curr_ci_job}. Proceeding with those values."
            )
            return cached.get("job_times", {})

        # Found file, but commit or CI job in JSON doesn't match
        print(
            f"Overwriting current file with stats based on current commit: {curr_commit} and CI job: {curr_ci_job}"
        )

    return export_S3_test_times(test_times_file)
149-
150-
15143def _query_changed_test_files () -> List [str ]:
15244 default_branch = f"origin/{ os .environ .get ('GIT_DEFAULT_BRANCH' , 'master' )} "
15345 cmd = ["git" , "diff" , "--name-only" , default_branch , "HEAD" ]
@@ -161,47 +53,6 @@ def _query_changed_test_files() -> List[str]:
16153 return lines
16254
16355
# Get sharded test allocation based on historic S3 data.
def get_shard_based_on_S3(
    which_shard: int, num_shards: int, tests: List[str], test_times_file: str
) -> List[str]:
    # A single shard gets everything; skip the lookup entirely.
    if num_shards == 1:
        return tests

    job_times = _query_past_job_times(test_times_file)

    if not job_times:
        # No stats available: fall back to a simple round-robin split.
        print(
            "::warning:: Gathered no stats from S3. Proceeding with default sharding plan."
        )
        return tests[which_shard - 1 :: num_shards]

    _, shard_tests = calculate_shards(num_shards, tests, job_times)[which_shard - 1]
    return shard_tests
184-
185-
def get_slow_tests_based_on_S3(
    test_list: List[str], td_list: List[str], slow_test_threshold: int
) -> List[str]:
    """Get list of slow tests based on historic S3 data.

    Args:
        test_list: candidate test file names.
        td_list: test files to exclude from consideration.
        slow_test_threshold: runtime (seconds) above which a test is "slow".

    Returns:
        Tests from *test_list* (in order) whose recorded runtime exceeds the
        threshold, or [] when no stats could be gathered.
    """
    jobs_to_times: Dict[str, float] = _query_past_job_times()

    # Got no stats from S3, returning early to save runtime
    if len(jobs_to_times) == 0:
        print("::warning:: Gathered no stats from S3. No new slow tests calculated.")
        return []

    # Set gives O(1) membership instead of rescanning td_list per test.
    excluded = set(td_list)
    return [
        test
        for test in test_list
        if test not in excluded
        and test in jobs_to_times
        and jobs_to_times[test] > slow_test_threshold
    ]
203-
204-
20556def get_reordered_tests (tests : List [str ]) -> List [str ]:
20657 """Get the reordered test filename list based on github PR history or git changed file."""
20758 prioritized_tests : List [str ] = []
@@ -242,20 +93,6 @@ def get_reordered_tests(tests: List[str]) -> List[str]:
24293 return tests
24394
24495
# TODO Refactor this and unify with tools.stats.export_slow_tests
def export_S3_test_times(test_times_filename: Optional[str] = None) -> Dict[str, float]:
    """Pull test times from S3 and, when a filename is given, export them as JSON."""
    test_times: Dict[str, float] = _pull_job_times_from_S3()
    if test_times_filename is None:
        return test_times

    print(f"Exporting S3 test stats to {test_times_filename}.")
    if os.path.exists(test_times_filename):
        print(f"Overwriting existent file: {test_times_filename}")
    with open(test_times_filename, "w+") as file:
        job_times_json = _get_job_times_json(test_times)
        json.dump(job_times_json, file, indent="    ", separators=(",", ": "))
        file.write("\n")
    return test_times
257-
258-
def get_test_case_configs(dirpath: str) -> None:
    """Materialize the slow-test and disabled-test config files under *dirpath*."""
    # Order matters only in that it mirrors the original call sequence.
    for fetch in (get_slow_tests, get_disabled_tests):
        fetch(dirpath=dirpath)
0 commit comments