Port two dynamo test cases for Intel GPU (#156056)

As part of https://github.com/pytorch/pytorch/issues/114850, we are porting more test cases to Intel GPU. This PR ports two Dynamo test cases. We adopt `torch.accelerator.current_accelerator()` to determine the backend, add XPU support to decorators such as `@requires_gpu`, and enable XPU on some test paths.
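
For context, a minimal sketch of the device-detection approach described above (illustrative only; the helper names below are not necessarily the ones used in the test suite):

```python
import unittest

import torch

# torch.accelerator.current_accelerator() returns the accelerator device
# (e.g. cuda or xpu) that this PyTorch build supports, or None for a
# CPU-only build.
acc = torch.accelerator.current_accelerator()
device_type = acc.type if acc is not None else "cpu"

# A backend-agnostic skip decorator in the spirit of @requires_gpu.
requires_gpu = unittest.skipUnless(
    torch.cuda.is_available() or torch.xpu.is_available(), "requires cuda or xpu"
)
```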

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156056
Approved by: https://github.com/guangyey, https://github.com/jansel
Daisy Deng
2025-06-19 12:48:58 +00:00
committed by PyTorch MergeBot
parent a8fe982993
commit ccb1f687d6
3 changed files with 46 additions and 34 deletions


@@ -3,14 +3,17 @@ import unittest
 import torch._dynamo
 from torch._dynamo.test_minifier_common import MinifierTestBase
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import skipIfNNModuleInlined
-requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")
+requires_gpu = unittest.skipUnless(
+    torch.cuda.is_available() or torch.xpu.is_available(), "requires cuda or xpu"
+)
 class MinifierTests(MinifierTestBase):
-    # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA)
+    # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA/XPU)
     def _test_after_dynamo(self, device, backend, expected_error):
         run_code = f"""\
 @torch.compile(backend={backend!r})
@@ -41,22 +44,22 @@ inner(torch.randn(20, 20).to("{device}"))
             "cpu", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
         )
-    @requires_cuda
-    def test_after_dynamo_cuda_compile_error(self):
+    @requires_gpu
+    def test_after_dynamo_cuda_compile_error(self, device):
         self._test_after_dynamo(
-            "cuda", "relu_compile_error_TESTING_ONLY", "ReluCompileError"
+            device, "relu_compile_error_TESTING_ONLY", "ReluCompileError"
         )
-    @requires_cuda
-    def test_after_dynamo_cuda_runtime_error(self):
+    @requires_gpu
+    def test_after_dynamo_cuda_runtime_error(self, device):
         self._test_after_dynamo(
-            "cuda", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
+            device, "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
         )
-    @requires_cuda
-    def test_after_dynamo_cuda_accuracy_error(self):
+    @requires_gpu
+    def test_after_dynamo_cuda_accuracy_error(self, device):
         self._test_after_dynamo(
-            "cuda", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
+            device, "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
         )
     def test_after_dynamo_non_leaf_compile_error(self):
@@ -94,38 +97,38 @@ inner(torch.randn(20, 20, requires_grad=True) + 1)
             "cpu", "relu_accuracy_error_TESTING_ONLY"
         )
-    @requires_cuda
-    def test_after_dynamo_cuda_compile_backend_passes(self):
+    @requires_gpu
+    def test_after_dynamo_cuda_compile_backend_passes(self, device):
         self._test_after_dynamo_backend_passes(
-            "cuda", "relu_compile_error_TESTING_ONLY"
+            device, "relu_compile_error_TESTING_ONLY"
         )
-    @requires_cuda
-    def test_after_dynamo_cuda_runtime_backend_passes(self):
+    @requires_gpu
+    def test_after_dynamo_cuda_runtime_backend_passes(self, device):
         self._test_after_dynamo_backend_passes(
-            "cuda", "relu_runtime_error_TESTING_ONLY"
+            device, "relu_runtime_error_TESTING_ONLY"
         )
-    @requires_cuda
-    def test_after_dynamo_cuda_accuracy_backend_passes(self):
+    @requires_gpu
+    def test_after_dynamo_cuda_accuracy_backend_passes(self, device):
         self._test_after_dynamo_backend_passes(
-            "cuda", "relu_accuracy_error_TESTING_ONLY"
+            device, "relu_accuracy_error_TESTING_ONLY"
         )
-    # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd
+    # Test that a module with mixed cpu/(cuda|xpu) parts with an error after dynamo can be repro'd
     @skipIfNNModuleInlined()
-    @requires_cuda
-    def test_cpu_cuda_module_after_dynamo(self):
+    @requires_gpu
+    def test_cpu_cuda_module_after_dynamo(self, device):
         backend_name = "relu_compile_error_TESTING_ONLY"
         run_code = f"""\
 class CpuCudaModule(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.m_x = torch.nn.Linear(20, 20).cuda()
+        self.m_x = torch.nn.Linear(20, 20).to(device)
         self.m_y = torch.nn.Linear(20, 20)
-        self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda())
+        self.p_x = torch.nn.Parameter(torch.randn(20, 20).to(device))
         self.p_y = torch.nn.Parameter(torch.randn(20, 20))
-        self.b_x = torch.nn.Buffer(torch.ones(20, 20).cuda())
+        self.b_x = torch.nn.Buffer(torch.ones(20, 20).to(device))
         self.b_y = torch.nn.Buffer(torch.ones(20, 20))
     def forward(self, x, y):
@@ -135,12 +138,12 @@ mod = CpuCudaModule()
 @torch.compile(backend={backend_name!r})
 def inner(x1, y1):
-    x2 = torch.randn(20, 20).cuda()
+    x2 = torch.randn(20, 20).to(device)
     y2 = torch.randn(20, 20)
     x3, y3 = mod(x1 + x2, y1 + y2)
     return torch.relu(x3.cpu() + y3)
-inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
+inner(torch.randn(20, 20).to(device), torch.randn(20, 20))
 """
         res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)
@@ -151,18 +154,18 @@ inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
 class Repro(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda()
+        self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).to(device)
         self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True)
-        self.register_buffer('G__mod___b_x', torch.randn([20, 20], dtype=torch.float32).cuda())
+        self.register_buffer('G__mod___b_x', torch.randn([20, 20], dtype=torch.float32).to(device))
         self.register_buffer('G__mod___b_y', torch.randn([20, 20], dtype=torch.float32))
-        self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32, device="cuda"))
+        self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32, device=device))
         self.G__mod___p_y = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32))
     def forward(self, L_x1_ : torch.Tensor, L_y1_ : torch.Tensor):
         l_x1_ = L_x1_
         l_y1_ = L_y1_
         randn = torch.randn(20, 20)
-        x2 = randn.cuda(); randn = None
+        x2 = randn.to(device); randn = None
         y2 = torch.randn(20, 20)
         add = l_x1_ + x2; l_x1_ = x2 = None
         add_1 = l_y1_ + y2; l_y1_ = y2 = None
@@ -213,6 +216,11 @@ class Repro(torch.nn.Module):
         )
+devices = ["cuda", "xpu", "cpu"]
+instantiate_device_type_tests(
+    MinifierTests, globals(), only_for=devices, allow_xpu=True
+)
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
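
For reference, a hedged sketch (not part of this PR; class and test names are made up for illustration) of the `instantiate_device_type_tests` pattern used above: the generic class is expanded into per-device classes, and each generated test receives the device string it should run on.

```python
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests


class ExampleDeviceTests(TestCase):
    def test_relu_matches_cpu(self, device):
        # "device" is e.g. "cpu", "cuda:0", or "xpu:0", depending on which
        # generated class (ExampleDeviceTestsCPU / ...CUDA / ...XPU) is running.
        x = torch.randn(8, device=device)
        self.assertEqual(torch.relu(x).cpu(), torch.relu(x.cpu()))


devices = ["cuda", "xpu", "cpu"]
instantiate_device_type_tests(
    ExampleDeviceTests, globals(), only_for=devices, allow_xpu=True
)

if __name__ == "__main__":
    run_tests()
```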


@@ -883,8 +883,10 @@ class UnspecTestsDevice(torch._dynamo.test_case.TestCase):
         self.assertEqual(ref.device, res.device)
-devices = ["cuda", "hpu"]
-instantiate_device_type_tests(UnspecTestsDevice, globals(), only_for=devices)
+devices = ["cuda", "hpu", "xpu"]
+instantiate_device_type_tests(
+    UnspecTestsDevice, globals(), only_for=devices, allow_xpu=True
+)
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests


@@ -73,6 +73,8 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
+        if not os.path.exists(cls.DEBUG_DIR):
+            cls.DEBUG_DIR = tempfile.mkdtemp()
         cls._exit_stack.enter_context(  # type: ignore[attr-defined]
             torch._dynamo.config.patch(debug_dir_root=cls.DEBUG_DIR)
         )