Port three dynamo tests to Intel GPU (#156575)

For https://github.com/pytorch/pytorch/issues/114850, we are porting test cases to Intel GPU. Two dynamo test files were ported in PR [#156056](https://github.com/pytorch/pytorch/pull/156056); this PR ports three more dynamo test files.
We enable Intel GPU with the following methods, keeping the original code style as much as possible (an illustrative sketch of the resulting pattern follows the list):

- use `instantiate_device_type_tests()` to generate device-specific variants of the tests
- use `torch.accelerator.current_accelerator()` to determine the accelerator backend
- add XPU support to decorators such as `@requires_gpu`
- enable XPU on some test paths
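The sketch below is not code from this PR; it is a minimal, hypothetical example (class and test names invented) of the device-generic pattern the ported files converge on: a module-level `device_type` resolved via `torch.accelerator`, device-parameterized test bodies, and `instantiate_device_type_tests(..., allow_xpu=True)` at the bottom of the file.

```python
# Illustrative sketch (hypothetical names); not code from this PR.
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase

# Resolve the accelerator backend generically (cuda/xpu/hpu/...), falling back to CPU.
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)


class ExampleDeviceTests(TestCase):
    def test_add_one(self, device):
        # `device` (e.g. "cuda:0" or "xpu:0") is injected per backend by
        # instantiate_device_type_tests, so the test body never hard-codes "cuda".
        x = torch.ones(2, device=device)
        self.assertEqual((x + 1).sum().item(), 4.0)

    def test_default_device_context(self, device):
        # device_type replaces hard-coded "cuda" strings where only a backend
        # name, not a specific device index, is needed.
        with torch.device(device_type):
            y = torch.zeros(3)
        self.assertEqual(y.device.type, device_type)


# Generate per-backend copies of the tests; allow_xpu=True opts them into XPU.
instantiate_device_type_tests(
    ExampleDeviceTests, globals(), only_for=("cuda", "hpu", "xpu"), allow_xpu=True
)

if __name__ == "__main__":
    run_tests()
```

The diffs below apply this same pattern to the existing tests rather than introducing new ones.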

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156575
Approved by: https://github.com/guangyey, https://github.com/jansel

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Author: Daisy Deng
Date: 2025-06-27 05:56:18 +00:00
Committed by: PyTorch MergeBot
Parent: 51853b358e
Commit: 1155c53e7d

4 changed files with 41 additions and 27 deletions


@@ -90,6 +90,7 @@ from torch.testing._internal.common_utils import (
     skipIfNNModuleInlined,
     skipIfWindows,
     TEST_HPU,
+    TEST_XPU,
     wrapDeterministicFlagAPITest,
 )
 from torch.testing._internal.jit_utils import JitTestCase
@@ -6904,7 +6905,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         self.assertTrue(guard_failure is not None)
         self.assertIn("""tensor 'rank' size mismatch at index 0""", guard_failure[0])

-    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Test requires CUDA or XPU.")
     def test_symint_as_device_kwarg_non_strict_export(self):
         class Mod(torch.nn.Module):
             def forward(self, x):
@@ -12771,7 +12772,7 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):

     def test_torch_device_is_available(self, device):
         def fn(x):
-            if TEST_HPU or TEST_CUDA:
+            if torch.accelerator.is_available():
                 return x + 1
             else:
                 return x - 1
@@ -12874,27 +12875,23 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
     def test_cuda_set_device(self, device):
         def fn():
             a = torch.ones(2, device=device)
-            torch.cuda.set_device(1)
+            torch.get_device_module(device).set_device(1)
             return a + 1

-        with torch.cuda.device(0):
+        with torch.get_device_module(device).device(0):
             counter = CompileCounter()
             opt_fn = torch.compile(fn, backend=counter)
             res = opt_fn()
-            self.assertEqual(res.device.type, "cuda")
+            self.assertEqual(res.device.type, device)
             self.assertEqual(res.device.index, 0)
             self.assertEqual(counter.frame_count, 2)

-    def test_torch_device_python_type(self):
+    def test_torch_device_python_type(self, device):
+        device_type = torch.device(device).type
         for device, device_type, index in [
             ("cpu", "cpu", None),
-            ("cuda:0", "cuda", 0),
-            ("hpu:0", "hpu", 0),
+            (device, device_type, 0),
         ]:
-            if (device == "cuda:0" and not TEST_CUDA) or (
-                device == "hpu:0" and not TEST_HPU
-            ):
-                continue

             def fn(target):
                 target_device = target.device
@@ -12956,8 +12953,10 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
             f(x, y)


-devices = ("cuda", "hpu")
-instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices)
+devices = ("cuda", "hpu", "xpu")
+instantiate_device_type_tests(
+    MiscTestsDevice, globals(), only_for=devices, allow_xpu=True
+)

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests


@@ -7,7 +7,7 @@ import torch._dynamo.test_case
 import torch._dynamo.testing
 from torch._dynamo.testing import same
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import TEST_HPU, TestCase
+from torch.testing._internal.common_utils import TestCase


 try:
@@ -359,11 +359,11 @@ class TestModelOutputBert(TestCase):
         )


-devices = ["cpu", "cuda"]
-if TEST_HPU:
-    devices.append("hpu")
+devices = ["cpu", "cuda", "xpu", "hpu"]

-instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices)
+instantiate_device_type_tests(
+    TestModelOutputBert, globals(), only_for=devices, allow_xpu=True
+)

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests


@@ -12,11 +12,16 @@ from torch._C import (
     _push_on_torch_function_stack,
 )
 from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
-from torch.testing._internal.triton_utils import requires_cuda
+from torch.testing._internal.triton_utils import requires_gpu
 from torch.utils._device import DeviceContext
 from torch.utils._python_dispatch import TorchDispatchMode


+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
+
 class TestMode(BaseTorchFunctionMode):
     def __torch_function__(self, func, types, args, kwargs=None):
         if not kwargs:
@@ -613,12 +618,12 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
             func(torch.randn(3))

-    @requires_cuda
+    @requires_gpu
     def test_flex_attention(self):
         import torch
         from torch.nn.attention.flex_attention import create_block_mask, flex_attention

-        torch.set_default_device("cuda")
+        torch.set_default_device(device_type)

         flex_attention = torch.compile(flex_attention, dynamic=False)
@@ -628,7 +633,9 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
             return prefix_lengths[b] >= kv

         # This runs in fullgraph already
-        create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
+        create_block_mask(
+            prefix_lm, 8, None, 512, 512, _compile=True, device=device_type
+        )

     def test_register_hook(self):
         import functools


@@ -16,7 +16,7 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
 )
-from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU


 @functorch_config.patch("bundled_autograd_cache", True)
@@ -28,10 +28,13 @@ class TestPackage(torch._inductor.test_case.TestCase):
         return path

     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_basic_fn(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")
+
         ctx = DiskDynamoStore()

         def fn(x):
@@ -69,10 +72,12 @@ class TestPackage(torch._inductor.test_case.TestCase):
         self.assertEqual(expected, compiled_fn(*args))

     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_graph_break_bomb(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")

         ctx = DiskDynamoStore()

@@ -131,10 +136,13 @@ class TestPackage(torch._inductor.test_case.TestCase):
         compiled_fn(torch.tensor(N), 0, N - 1)

     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_dynamic_shape(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")
+
         ctx = DiskDynamoStore()

         def fn(x):