mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add print statements to debug sharding error (#102713)
Sharding on ROCm is broken. I can't reproduce it on dummy PRs even though it seems to happen pretty often on main, so I'm adding these print statements to increase my sample size. Hopefully this is enough print statements... Pull Request resolved: https://github.com/pytorch/pytorch/pull/102713 Approved by: https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
cf0aa38005
commit
c7873522c2
@ -91,9 +91,16 @@ def calculate_shards(
|
||||
tests: List[str],
|
||||
test_file_times: Dict[str, float],
|
||||
must_serial: Optional[Callable[[str], bool]] = None,
|
||||
debug: bool = False,
|
||||
) -> List[Tuple[float, List[ShardedTest]]]:
|
||||
must_serial = must_serial or (lambda x: True)
|
||||
|
||||
if debug:
|
||||
print(test_file_times)
|
||||
print(tests)
|
||||
print(num_shards)
|
||||
print([x for x in tests if must_serial(x)])
|
||||
|
||||
known_tests = [x for x in tests if x in test_file_times]
|
||||
unknown_tests: List[str] = [x for x in tests if x not in known_tests]
|
||||
|
||||
@ -117,6 +124,11 @@ def calculate_shards(
|
||||
for unknown_test in unknown_tests:
|
||||
sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
|
||||
index = (index + 1) % num_shards
|
||||
|
||||
if debug:
|
||||
for j in sharded_jobs:
|
||||
print(j.convert_to_tuple()[1])
|
||||
|
||||
return [job.convert_to_tuple() for job in sharded_jobs]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user