mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix rocm sharding (#102871)
Rocm queries for the number of processes it should use per machine, which might cause it be different across shards, which leads to inconsistencies when distributing tests among shards. My solution is to separate the vars used for shard calculations and the actual number of procs that can be used and to ensure that the var used for shard calculations is consistent across all shards for a test config + job. I believe that the only consequence is that rocm sharding might become unbalanced. Pull Request resolved: https://github.com/pytorch/pytorch/pull/102871 Approved by: https://github.com/huydhn, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
a867e6db85
commit
90fd90dd94
@ -14,14 +14,20 @@ from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
|
||||
|
||||
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
|
||||
|
||||
# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
|
||||
# to ensure that sharding is consistent, NUM_PROCS is the actual number of procs
|
||||
# used to run tests. If they are not equal, the only consequence should be
|
||||
# unequal shards.
|
||||
IS_ROCM = os.path.exists("/opt/rocm")
|
||||
NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 2
|
||||
NUM_PROCS_FOR_SHARDING_CALC = NUM_PROCS if not IS_ROCM or IS_MEM_LEAK_CHECK else 2
|
||||
THRESHOLD = 60 * 10 # 10 minutes
|
||||
|
||||
# See Note [ROCm parallel CI testing]
|
||||
# Special logic for ROCm GHA runners to query number of GPUs available.
|
||||
# torch.version.hip was not available to check if this was a ROCm self-hosted runner.
|
||||
# Must check for ROCm runner in another way. We look for /opt/rocm directory.
|
||||
if os.path.exists("/opt/rocm") and not IS_MEM_LEAK_CHECK:
|
||||
if IS_ROCM and not IS_MEM_LEAK_CHECK:
|
||||
try:
|
||||
# This is the same logic used in GHA health check, see .github/templates/common.yml.j2
|
||||
lines = (
|
||||
@ -58,7 +64,7 @@ class ShardJob:
|
||||
self.parallel: List[ShardedTest] = []
|
||||
|
||||
def get_total_time(self) -> float:
|
||||
procs = [0.0 for _ in range(NUM_PROCS)]
|
||||
procs = [0.0 for _ in range(NUM_PROCS_FOR_SHARDING_CALC)]
|
||||
for test in self.parallel:
|
||||
min_index = procs.index(min(procs))
|
||||
procs[min_index] += test.get_time()
|
||||
|
Reference in New Issue
Block a user