mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm] fix carveout feature (#164303)
Fixes #164271. Carveout had been applied with an opposite bitmask. Besides being incorrect, this led to flaky unit test behavior due to carveout being too high. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164303 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
315ffdc1e4
commit
7304b9e7d2
@ -932,6 +932,9 @@ class TestFP8Matmul(TestCase):
|
||||
x_fp8 = to_fp8_saturated(x / x_scales, e4m3_type)
|
||||
y_fp8 = to_fp8_saturated(y / y_scales, e4m3_type)
|
||||
|
||||
cu_count = torch.cuda.get_device_properties().multi_processor_count
|
||||
carveout = 66 if torch.version.cuda else cu_count // 8
|
||||
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
|
||||
self.assertIsNone(torch._C._get_sm_carveout_experimental())
|
||||
@ -952,16 +955,24 @@ class TestFP8Matmul(TestCase):
|
||||
# events were returned out of order; need to be sorted on "ts" timestamp
|
||||
events = sorted(events, key=lambda x: x['ts'])
|
||||
# ROCm carveout is invisible except for kernels running slower on fewer CUs
|
||||
no_carveout, carveout_0, carveout_66, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
|
||||
self.assertTrue(no_carveout < carveout_66)
|
||||
self.assertTrue(carveout_0 < carveout_66)
|
||||
self.assertTrue(no_carveout_again < carveout_66)
|
||||
no_carveout, carveout_0, carveout, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
|
||||
if True or not (no_carveout < carveout and carveout_0 < carveout and no_carveout_again < carveout):
|
||||
# something went wrong, print more info to help debug flaky test
|
||||
print("ROCm debug info for test_honor_sm_carveout")
|
||||
print("cu_count", cu_count)
|
||||
print("no_carveout", no_carveout)
|
||||
print("carveout_0", carveout_0)
|
||||
print("carveout", carveout)
|
||||
print("no_carveout_again", no_carveout_again)
|
||||
self.assertTrue(no_carveout < carveout)
|
||||
self.assertTrue(carveout_0 < carveout)
|
||||
self.assertTrue(no_carveout_again < carveout)
|
||||
# ROCm carveout will create new streams when enabled, and go back to the original stream when disabled
|
||||
no_carveout, carveout_0, carveout_66, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
|
||||
no_carveout, carveout_0, carveout, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
|
||||
self.assertTrue(no_carveout == no_carveout_again)
|
||||
self.assertTrue(no_carveout != carveout_0)
|
||||
self.assertTrue(no_carveout != carveout_66)
|
||||
self.assertTrue(carveout_0 != carveout_66)
|
||||
self.assertTrue(no_carveout == carveout_0)
|
||||
self.assertTrue(no_carveout != carveout)
|
||||
self.assertTrue(carveout_0 != carveout)
|
||||
else:
|
||||
no_carveout, carveout_0, carveout_66, no_carveout_again = [
|
||||
math.prod(evt.get("args", {}).get("grid", []))
|
||||
|
Reference in New Issue
Block a user