Add print statements to debug sharding error (#102713)

Sharding on ROCm is broken. I can't replicate it on dummy PRs even though it seems to happen pretty often on main, so I'm adding these prints to increase my sample size. Hopefully this is enough print statements...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102713
Approved by: https://github.com/huydhn
Catherine Lee
2023-06-01 22:38:24 +00:00
committed by PyTorch MergeBot
parent cf0aa38005
commit c7873522c2
2 changed files with 17 additions and 1 deletion


@@ -91,9 +91,16 @@ def calculate_shards(
     tests: List[str],
     test_file_times: Dict[str, float],
     must_serial: Optional[Callable[[str], bool]] = None,
+    debug: bool = False,
 ) -> List[Tuple[float, List[ShardedTest]]]:
     must_serial = must_serial or (lambda x: True)
+    if debug:
+        print(test_file_times)
+        print(tests)
+        print(num_shards)
+        print([x for x in tests if must_serial(x)])
     known_tests = [x for x in tests if x in test_file_times]
     unknown_tests: List[str] = [x for x in tests if x not in known_tests]
@@ -117,6 +124,11 @@ def calculate_shards(
     for unknown_test in unknown_tests:
         sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
         index = (index + 1) % num_shards
+    if debug:
+        for j in sharded_jobs:
+            print(j.convert_to_tuple()[1])
     return [job.convert_to_tuple() for job in sharded_jobs]
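
For reference, a minimal sketch of how the new debug flag might be exercised. This is not the actual run_test.py call site; the import path, the example test names, and the timings below are assumptions for illustration only.

    # Hypothetical call, assuming calculate_shards lives in tools/testing/test_selections.py.
    from tools.testing.test_selections import calculate_shards

    test_times = {"test_ops": 120.0, "test_nn": 300.0}  # made-up per-file timings
    tests = ["test_ops", "test_nn", "test_brand_new"]   # "test_brand_new" has no recorded time

    shards = calculate_shards(
        2,             # num_shards
        tests,
        test_times,
        must_serial=lambda name: name == "test_nn",  # force this file into the serial bucket
        debug=True,    # flag added by this PR: prints the inputs and each shard's test list
    )
    for total_time, sharded_tests in shards:
        print(total_time, sharded_tests)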