Add print statements to debug sharding error (#102713)

Sharding on ROCm is broken. I can't replicate it on dummy PRs even though it seems to happen pretty often on main, so I'm adding this to increase my sample size. Hopefully this is enough print statements...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102713
Approved by: https://github.com/huydhn
Catherine Lee
2023-06-01 22:38:24 +00:00
committed by PyTorch MergeBot
parent cf0aa38005
commit c7873522c2
2 changed files with 17 additions and 1 deletion

@@ -1434,7 +1434,11 @@ def get_selected_tests(options) -> List[ShardedTest]:
     # Do sharding
     test_file_times_config = test_file_times.get(test_config, {})
     shards = calculate_shards(
-        num_shards, selected_tests, test_file_times_config, must_serial=must_serial
+        num_shards,
+        selected_tests,
+        test_file_times_config,
+        must_serial=must_serial,
+        debug=TEST_WITH_ROCM,
     )
     _, tests_from_shard = shards[which_shard - 1]
     selected_tests = tests_from_shard
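
For context, here is a minimal sketch of what a debug flag like this could do on the sharding side. The name calculate_shards_sketch, the greedy packing strategy, and the exact print output are illustrative assumptions, not PyTorch's actual implementation; the real calculate_shards lives in the tools/testing helpers and may differ.

# Hypothetical sketch, not PyTorch's actual implementation: a sharder that
# dumps its inputs when debug=True so an intermittent sharding failure can
# be reproduced offline from CI logs.
from typing import Dict, List, Tuple


def calculate_shards_sketch(
    num_shards: int,
    tests: List[str],
    test_file_times: Dict[str, float],
    debug: bool = False,
) -> List[Tuple[float, List[str]]]:
    if debug:
        # Print everything that influences shard assignment.
        print(f"num_shards: {num_shards}")
        print(f"tests: {tests}")
        print(f"test_file_times: {test_file_times}")

    # Greedy bin packing: hand the next-longest test to the currently
    # shortest shard. (A common sharding strategy; the real code may differ.)
    shards: List[Tuple[float, List[str]]] = [(0.0, []) for _ in range(num_shards)]
    for test in sorted(tests, key=lambda t: -test_file_times.get(t, 0.0)):
        idx, (total, members) = min(enumerate(shards), key=lambda s: s[1][0])
        shards[idx] = (total + test_file_times.get(test, 0.0), members + [test])
    return shards


# Example: three tests packed into two shards, with inputs printed first.
print(calculate_shards_sketch(2, ["a", "b", "c"], {"a": 3.0, "b": 2.0, "c": 1.0}, debug=True))
# -> [(3.0, ['a']), (3.0, ['b', 'c'])]

The point of the extra prints is that a flaky, CI-only failure leaves behind the exact inputs to the sharder, so a run that produced a bad shard assignment can be replayed locally.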