mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm] fix carveout feature (#164303)
Fixes #164271. Carveout had been applied with an opposite bitmask. Besides being incorrect, this led to flaky unit test behavior due to carveout being too high. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164303 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
315ffdc1e4
commit
7304b9e7d2
@ -191,6 +191,10 @@ uint32_t _getAlignment(uintptr_t address) {
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) {
|
||||
// 0 is default value, meaning full CUs i.e. no mask
|
||||
if (value == 0) {
|
||||
return at::cuda::getCurrentCUDAStream();
|
||||
}
|
||||
static int32_t last_value = 0;
|
||||
static hipStream_t stream;
|
||||
if (last_value == 0) {
|
||||
@ -209,15 +213,15 @@ static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) {
|
||||
int32_t CUs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||
// how many uint32_t do we need to cover all CUs, fill bitmask with 1
|
||||
uint32_t mask_size = static_cast<uint32_t>((CUs + 32 - 1) / 32);
|
||||
std::vector<uint32_t> mask(mask_size, uint32_t{0xffffffff});
|
||||
std::vector<uint32_t> mask(mask_size, uint32_t{0x00000000});
|
||||
// starting from lowest order bits, in 32-bit chunks
|
||||
// set bits to 0 based on how many CUs to carve out
|
||||
int32_t full_shifts = value / 32;
|
||||
int32_t remainder = value % 32;
|
||||
for (int32_t i = 0; i < full_shifts; i++) {
|
||||
mask[i] = uint32_t{0x00000000};
|
||||
mask[i] = uint32_t{0xffffffff};
|
||||
}
|
||||
mask[full_shifts] = uint32_t{0xffffffff} << remainder;
|
||||
mask[full_shifts] = uint32_t{0xffffffff} << (32 - remainder);
|
||||
|
||||
// finally, create masked stream
|
||||
AT_CUDA_CHECK(hipExtStreamCreateWithCUMask(&stream, mask_size, &mask[0]));
|
||||
|
||||
@ -932,6 +932,9 @@ class TestFP8Matmul(TestCase):
|
||||
x_fp8 = to_fp8_saturated(x / x_scales, e4m3_type)
|
||||
y_fp8 = to_fp8_saturated(y / y_scales, e4m3_type)
|
||||
|
||||
cu_count = torch.cuda.get_device_properties().multi_processor_count
|
||||
carveout = 66 if torch.version.cuda else cu_count // 8
|
||||
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
|
||||
self.assertIsNone(torch._C._get_sm_carveout_experimental())
|
||||
@ -952,16 +955,24 @@ class TestFP8Matmul(TestCase):
|
||||
# events were returned out of order; need to be sorted on "ts" timestamp
|
||||
events = sorted(events, key=lambda x: x['ts'])
|
||||
# ROCm carveout is invisible except for kernels running slower on fewer CUs
|
||||
no_carveout, carveout_0, carveout_66, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
|
||||
self.assertTrue(no_carveout < carveout_66)
|
||||
self.assertTrue(carveout_0 < carveout_66)
|
||||
self.assertTrue(no_carveout_again < carveout_66)
|
||||
no_carveout, carveout_0, carveout, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
|
||||
if True or not (no_carveout < carveout and carveout_0 < carveout and no_carveout_again < carveout):
|
||||
# something went wrong, print more info to help debug flaky test
|
||||
print("ROCm debug info for test_honor_sm_carveout")
|
||||
print("cu_count", cu_count)
|
||||
print("no_carveout", no_carveout)
|
||||
print("carveout_0", carveout_0)
|
||||
print("carveout", carveout)
|
||||
print("no_carveout_again", no_carveout_again)
|
||||
self.assertTrue(no_carveout < carveout)
|
||||
self.assertTrue(carveout_0 < carveout)
|
||||
self.assertTrue(no_carveout_again < carveout)
|
||||
# ROCm carveout will create new streams when enabled, and go back to the original stream when disabled
|
||||
no_carveout, carveout_0, carveout_66, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
|
||||
no_carveout, carveout_0, carveout, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
|
||||
self.assertTrue(no_carveout == no_carveout_again)
|
||||
self.assertTrue(no_carveout != carveout_0)
|
||||
self.assertTrue(no_carveout != carveout_66)
|
||||
self.assertTrue(carveout_0 != carveout_66)
|
||||
self.assertTrue(no_carveout == carveout_0)
|
||||
self.assertTrue(no_carveout != carveout)
|
||||
self.assertTrue(carveout_0 != carveout)
|
||||
else:
|
||||
no_carveout, carveout_0, carveout_66, no_carveout_again = [
|
||||
math.prod(evt.get("args", {}).get("grid", []))
|
||||
|
||||
Reference in New Issue
Block a user