[ROCm] fix carveout feature (#164303)

Fixes #164271.

Carveout had been applied with an opposite (complemented) bitmask. Besides being incorrect, this led to flaky unit test behavior because the effective carveout was too high.
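
A minimal Python sketch of the corrected mask construction may help: before the fix the mask vector started as all ones and bits were cleared, so the stream was restricted to the complement of the intended CU set; after the fix the vector starts as all zeros and exactly `value` bits are set. The CU count of 304 below is only an example value.

```python
def build_cu_mask(value: int, cu_count: int) -> list[int]:
    """Python rendering of the fixed mask construction in _getCarveoutStream."""
    mask_size = (cu_count + 31) // 32    # one uint32 word per 32 CUs
    mask = [0x00000000] * mask_size      # start all-zero (the fix)
    full_shifts = value // 32
    remainder = value % 32
    for i in range(full_shifts):         # enable 32 CUs per fully covered word
        mask[i] = 0xFFFFFFFF
    if remainder:                        # guarded here; the C++ line below shifts unconditionally
        mask[full_shifts] = (0xFFFFFFFF << (32 - remainder)) & 0xFFFFFFFF
    return mask

# Exactly `value` bits end up set, so the masked stream sees `value` CUs:
assert sum(bin(w).count("1") for w in build_cu_mask(38, 304)) == 38
```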

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164303
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Author: Jeff Daily
Date: 2025-10-01 19:25:41 +00:00
Committed by: PyTorch MergeBot
Parent: 315ffdc1e4
Commit: 7304b9e7d2
2 changed files with 26 additions and 11 deletions


@@ -191,6 +191,10 @@ uint32_t _getAlignment(uintptr_t address) {
 #ifdef USE_ROCM
 static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) {
   // 0 is default value, meaning full CUs i.e. no mask
+  if (value == 0) {
+    return at::cuda::getCurrentCUDAStream();
+  }
   static int32_t last_value = 0;
   static hipStream_t stream;
   if (last_value == 0) {
@@ -209,15 +213,15 @@ static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) {
   int32_t CUs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
   // how many uint32_t do we need to cover all CUs, fill bitmask with 1
   uint32_t mask_size = static_cast<uint32_t>((CUs + 32 - 1) / 32);
-  std::vector<uint32_t> mask(mask_size, uint32_t{0xffffffff});
+  std::vector<uint32_t> mask(mask_size, uint32_t{0x00000000});
   // starting from lowest order bits, in 32-bit chunks
   // set bits to 0 based on how many CUs to carve out
   int32_t full_shifts = value / 32;
   int32_t remainder = value % 32;
   for (int32_t i = 0; i < full_shifts; i++) {
-    mask[i] = uint32_t{0x00000000};
+    mask[i] = uint32_t{0xffffffff};
   }
-  mask[full_shifts] = uint32_t{0xffffffff} << remainder;
+  mask[full_shifts] = uint32_t{0xffffffff} << (32 - remainder);
   // finally, create masked stream
   AT_CUDA_CHECK(hipExtStreamCreateWithCUMask(&stream, mask_size, &mask[0]));
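
The new early return for `value == 0` is what the updated test assertions key on: a zero carveout now stays on the caller's stream instead of creating an extra unmasked stream. A rough Python paraphrase of that control flow, reusing `build_cu_mask` from the sketch above (the stream objects here are placeholders, not the real HIP API, and the `last_value` caching is condensed away):

```python
def get_carveout_stream(value: int, current_stream, cu_count: int):
    """Illustrative paraphrase of the fixed _getCarveoutStream; not a real API."""
    if value == 0:
        # Default carveout: no CU mask, so stay on the current stream.
        # Hence the test now expects carveout_0 and no_carveout to share a tid.
        return current_stream
    # Otherwise restrict the stream to `value` CUs; the C++ above does this
    # via hipExtStreamCreateWithCUMask.
    mask = build_cu_mask(value, cu_count)
    return ("masked_stream", tuple(mask))  # placeholder for a HIP stream handle
```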


@@ -932,6 +932,9 @@ class TestFP8Matmul(TestCase):
         x_fp8 = to_fp8_saturated(x / x_scales, e4m3_type)
         y_fp8 = to_fp8_saturated(y / y_scales, e4m3_type)
+        cu_count = torch.cuda.get_device_properties().multi_processor_count
+        carveout = 66 if torch.version.cuda else cu_count // 8
         with tempfile.NamedTemporaryFile() as f:
             with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
                 self.assertIsNone(torch._C._get_sm_carveout_experimental())
@@ -952,16 +955,24 @@ class TestFP8Matmul(TestCase):
             # events were returned out of order; need to be sorted on "ts" timestamp
             events = sorted(events, key=lambda x: x['ts'])
             # ROCm carveout is invisible except for kernels running slower on fewer CUs
-            no_carveout, carveout_0, carveout_66, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
-            self.assertTrue(no_carveout < carveout_66)
-            self.assertTrue(carveout_0 < carveout_66)
-            self.assertTrue(no_carveout_again < carveout_66)
+            no_carveout, carveout_0, carveout, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
+            if True or not (no_carveout < carveout and carveout_0 < carveout and no_carveout_again < carveout):
+                # something went wrong, print more info to help debug flaky test
+                print("ROCm debug info for test_honor_sm_carveout")
+                print("cu_count", cu_count)
+                print("no_carveout", no_carveout)
+                print("carveout_0", carveout_0)
+                print("carveout", carveout)
+                print("no_carveout_again", no_carveout_again)
+            self.assertTrue(no_carveout < carveout)
+            self.assertTrue(carveout_0 < carveout)
+            self.assertTrue(no_carveout_again < carveout)
             # ROCm carveout will create new streams when enabled, and go back to the original stream when disabled
-            no_carveout, carveout_0, carveout_66, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
+            no_carveout, carveout_0, carveout, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
             self.assertTrue(no_carveout == no_carveout_again)
-            self.assertTrue(no_carveout != carveout_0)
-            self.assertTrue(no_carveout != carveout_66)
-            self.assertTrue(carveout_0 != carveout_66)
+            self.assertTrue(no_carveout == carveout_0)
+            self.assertTrue(no_carveout != carveout)
+            self.assertTrue(carveout_0 != carveout)
         else:
             no_carveout, carveout_0, carveout_66, no_carveout_again = [
                 math.prod(evt.get("args", {}).get("grid", []))
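
To round out the picture, a hedged sketch of driving the same experimental hooks outside the test harness. The shapes, scales, and fp8 dtype are illustrative only (some ROCm GPUs use a different fp8 dtype), and resetting via `None` mirrors the test's no-carveout phases:

```python
import torch

# Default state: no carveout configured.
assert torch._C._get_sm_carveout_experimental() is None

cu_count = torch.cuda.get_device_properties().multi_processor_count
carveout = 66 if torch.version.cuda else cu_count // 8

x = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn)
y = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn).t().contiguous().t()
scale = torch.ones((), device="cuda")

torch._scaled_mm(x, y, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)   # full device

torch._C._set_sm_carveout_experimental(0)         # explicit 0: no mask after this fix
torch._scaled_mm(x, y, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)

torch._C._set_sm_carveout_experimental(carveout)  # genuinely restricted; should run slower
torch._scaled_mm(x, y, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)

torch._C._set_sm_carveout_experimental(None)      # back to the default
```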