[ROCm] fix carveout feature (#164303)

Fixes #164271.

Carveout had been applied with an opposite bitmask. Besides being incorrect, this lead to flaky unit test behavior due to carveout being too high.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164303
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Jeff Daily
2025-10-01 19:25:41 +00:00
committed by PyTorch MergeBot
parent 315ffdc1e4
commit 7304b9e7d2
2 changed files with 26 additions and 11 deletions

View File

@ -932,6 +932,9 @@ class TestFP8Matmul(TestCase):
x_fp8 = to_fp8_saturated(x / x_scales, e4m3_type)
y_fp8 = to_fp8_saturated(y / y_scales, e4m3_type)
cu_count = torch.cuda.get_device_properties().multi_processor_count
carveout = 66 if torch.version.cuda else cu_count // 8
with tempfile.NamedTemporaryFile() as f:
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
self.assertIsNone(torch._C._get_sm_carveout_experimental())
@ -952,16 +955,24 @@ class TestFP8Matmul(TestCase):
# events were returned out of order; need to be sorted on "ts" timestamp
events = sorted(events, key=lambda x: x['ts'])
# ROCm carveout is invisible except for kernels running slower on fewer CUs
no_carveout, carveout_0, carveout_66, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
self.assertTrue(no_carveout < carveout_66)
self.assertTrue(carveout_0 < carveout_66)
self.assertTrue(no_carveout_again < carveout_66)
no_carveout, carveout_0, carveout, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
if True or not (no_carveout < carveout and carveout_0 < carveout and no_carveout_again < carveout):
# something went wrong, print more info to help debug flaky test
print("ROCm debug info for test_honor_sm_carveout")
print("cu_count", cu_count)
print("no_carveout", no_carveout)
print("carveout_0", carveout_0)
print("carveout", carveout)
print("no_carveout_again", no_carveout_again)
self.assertTrue(no_carveout < carveout)
self.assertTrue(carveout_0 < carveout)
self.assertTrue(no_carveout_again < carveout)
# ROCm carveout will create new streams when enabled, and go back to the original stream when disabled
no_carveout, carveout_0, carveout_66, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
no_carveout, carveout_0, carveout, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
self.assertTrue(no_carveout == no_carveout_again)
self.assertTrue(no_carveout != carveout_0)
self.assertTrue(no_carveout != carveout_66)
self.assertTrue(carveout_0 != carveout_66)
self.assertTrue(no_carveout == carveout_0)
self.assertTrue(no_carveout != carveout)
self.assertTrue(carveout_0 != carveout)
else:
no_carveout, carveout_0, carveout_66, no_carveout_again = [
math.prod(evt.get("args", {}).get("grid", []))