mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[CI] Use job name to index into test times json (#147154)
When the test times are generated, it doesn't know what the build environment is because it's an environment variable. But when we index into the test times, we (previously) didn't know what the job name is. These are usually the same but sometimes they're different and when they're different it ends up using default, which can have unbalanced sharding I think job name was added at some point to most of the CI environments but I didn't realize, so we can now update this code to use the job name instead so the generation and the indexing match also upload stats workflow for mps Checked that inductor_amx doesn't use default Pull Request resolved: https://github.com/pytorch/pytorch/pull/147154 Approved by: https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
e8fbc86de0
commit
0d16188c06
2
.github/workflows/upload-test-stats.yml
vendored
2
.github/workflows/upload-test-stats.yml
vendored
@ -2,7 +2,7 @@ name: Upload test stats
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: [pull, trunk, periodic, inductor, unstable, slow, unstable-periodic, inductor-periodic, rocm, inductor-micro-benchmark, inductor-micro-benchmark-x86, inductor-cu124, inductor-rocm]
|
||||
workflows: [pull, trunk, periodic, inductor, unstable, slow, unstable-periodic, inductor-periodic, rocm, inductor-micro-benchmark, inductor-micro-benchmark-x86, inductor-cu124, inductor-rocm, mac-mps]
|
||||
types:
|
||||
- completed
|
||||
|
||||
|
||||
@ -1915,21 +1915,26 @@ def load_test_times_from_file(file: str) -> dict[str, Any]:
|
||||
|
||||
with open(path) as f:
|
||||
test_times_file = cast(dict[str, Any], json.load(f))
|
||||
build_environment = os.environ.get("BUILD_ENVIRONMENT")
|
||||
job_name = os.environ.get("JOB_NAME")
|
||||
if job_name is None or job_name == "":
|
||||
# If job name isn't available, use build environment as a backup
|
||||
job_name = os.environ.get("BUILD_ENVIRONMENT")
|
||||
else:
|
||||
job_name = job_name.split(" / test (")[0]
|
||||
test_config = os.environ.get("TEST_CONFIG")
|
||||
if test_config in test_times_file.get(build_environment, {}):
|
||||
if test_config in test_times_file.get(job_name, {}):
|
||||
print_to_stderr("Found test times from artifacts")
|
||||
return test_times_file[build_environment][test_config]
|
||||
return test_times_file[job_name][test_config]
|
||||
elif test_config in test_times_file["default"]:
|
||||
print_to_stderr(
|
||||
f"::warning:: Gathered no stats from artifacts for {build_environment} build env"
|
||||
f" and {test_config} test config. Using default build env and {test_config} test config instead."
|
||||
f"::warning:: Gathered no stats from artifacts for {job_name} build env"
|
||||
f" and {test_config} test config. Using default job name and {test_config} test config instead."
|
||||
)
|
||||
return test_times_file["default"][test_config]
|
||||
else:
|
||||
print_to_stderr(
|
||||
f"::warning:: Gathered no stats from artifacts for build env {build_environment} build env"
|
||||
f" and {test_config} test config. Using default build env and default test config instead."
|
||||
f"::warning:: Gathered no stats from artifacts for job name {job_name} build env"
|
||||
f" and {test_config} test config. Using default job name and default test config instead."
|
||||
)
|
||||
return test_times_file["default"]["default"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user