pytorch/test/dynamo/test_python_dispatcher.py
Daisy Deng 8088958793 port 4 dynamo test files to Intel GPU (#157779)
For https://github.com/pytorch/pytorch/issues/114850, we are porting test cases to Intel GPU. Six dynamo test files were ported in PRs [#156056](https://github.com/pytorch/pytorch/pull/156056) and [#156575](https://github.com/pytorch/pytorch/pull/156575). In this PR we port 4 more dynamo test files.
We enable Intel GPU support with the following methods, keeping the original code style as much as possible (a minimal sketch follows the list):

- Use `instantiate_device_type_tests()`
- Use `torch.accelerator.current_accelerator()` to determine the accelerator backend
- Add XPU support in decorators such as `@requires_gpu`
- Enable XPU for some test paths
- Add `xfailIfXPU` to skip XPU tests where there is a known bug

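As a rough illustration of the pattern above (not taken from the ported files), the sketch below shows how a device-generic test can pick up the XPU backend. `ExampleDeviceGenericTests` and its test body are hypothetical; `instantiate_device_type_tests()` and `torch.accelerator.current_accelerator()` are the real PyTorch helpers named in the list.

```python
# Hypothetical example of the device-generic test pattern described above.
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests

# Resolve the accelerator backend ("cuda", "xpu", ...); fall back to CPU when none is available.
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)


class ExampleDeviceGenericTests(TestCase):
    def test_add(self, device):
        # `device` is injected per backend by instantiate_device_type_tests().
        x = torch.ones(4, device=device)
        self.assertEqual((x + 1).sum().item(), 8.0)


# Generates e.g. ExampleDeviceGenericTestsXPU when an Intel GPU is the current accelerator.
instantiate_device_type_tests(ExampleDeviceGenericTests, globals(), only_for=device_type)

if __name__ == "__main__":
    run_tests()
```
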
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157779
Approved by: https://github.com/guangyey, https://github.com/jansel
2025-07-11 10:11:49 +00:00


# Owner(s): ["module: dynamo"]
import unittest

import torch
import torch._dynamo.test_case
from torch._dynamo.testing import CompileCounter, EagerAndRecordGraphs, normalize_gm
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import TEST_XPU


device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)


class PythonDispatcherTests(torch._dynamo.test_case.TestCase):
    def test_dispatch_key1(self):
        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn(x):
            x = x + 1
            return torch._C._dispatch_keys(x)

        x = torch.randn(2, 3)
        self.assertTrue(fn(x).raw_repr() == torch._C._dispatch_keys(x + 1).raw_repr())

    def test_dispatch_key2(self):
        from torch.testing._internal.two_tensor import TwoTensor

        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn(x):
            x = x.sin()
            return torch._C._dispatch_keys(x)

        x = torch.randn(3)
        y = torch.randn(3)
        z = TwoTensor(x, y)
        self.assertTrue(fn(z).raw_repr() == torch._C._dispatch_keys(z.sin()).raw_repr())

    def test_dispatch_key3(self):
        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn(x):
            key_set = torch._C._dispatch_tls_local_include_set()
            return torch.sin(x + 1), key_set

        x = torch.randn(2, 3)
        self.assertEqual(fn(x)[0], torch.sin(x + 1))
        self.assertTrue(
            fn(x)[1].raw_repr() == torch._C._dispatch_tls_local_include_set().raw_repr()
        )

    def test_dispatch_key4(self):
        eager = EagerAndRecordGraphs()

        @torch.compile(backend=eager, fullgraph=True)
        def fn(x):
            key_set = torch._C._dispatch_tls_local_include_set()
            key_set = key_set | torch._C._dispatch_keys(x)
            key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
            if key_set.highestPriorityTypeId() == torch.DispatchKey.PythonDispatcher:
                return torch.sin(x + 1)
            else:
                return torch.sin(x - 1)

        x = torch.randn(2, 3)
        self.assertEqual(fn(x), torch.sin(x - 1))
        graph = eager.graphs[0]
        actual = normalize_gm(graph.print_readable(False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[2, 3]"):
        l_x_ = L_x_
        sub: "f32[2, 3]" = l_x_ - 1; l_x_ = None
        sin: "f32[2, 3]" = torch.sin(sub); sub = None
        return (sin,)
""",  # NOQA: B950
        )

    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "requires cuda or xpu")
    def test_dispatch_key_set_guard(self):
        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True)
        def fn(x, dks):
            if dks.has("CPU"):
                return torch.sin(x + 1)
            else:
                return torch.sin(x - 1)

        x1 = torch.randn(2, 3)
        dks1 = torch._C._dispatch_keys(x1)
        self.assertEqual(fn(x1, dks1), torch.sin(x1 + 1))
        self.assertEqual(counter.frame_count, 1)

        x2 = torch.randn(2, 3)
        dks2 = torch._C._dispatch_keys(x2)
        self.assertEqual(fn(x2, dks2), torch.sin(x2 + 1))
        # No recompile since the dispatch key set is the same though the tensor is different.
        self.assertEqual(counter.frame_count, 1)

        x3 = torch.randn(2, 3, device=device_type)
        dks3 = torch._C._dispatch_keys(x3)
        self.assertEqual(fn(x3, dks3), torch.sin(x3 - 1))
        # Re-compile since the dispatch key set is different.
        self.assertEqual(counter.frame_count, 2)

    def test_functorch_interpreter(self):
        counter = CompileCounter()

        def square_and_add(x, y):
            interpreter = (
                torch._functorch.pyfunctorch.retrieve_current_functorch_interpreter()
            )
            level = interpreter.level()
            if interpreter.key() == torch._C._functorch.TransformType.Vmap:
                return (x**2 + y) * level
            else:
                return x**2 * level

        @torch.compile(backend=counter, fullgraph=True)
        def fn(x, y):
            return torch.vmap(square_and_add)(x, y)

        x = torch.tensor([1, 2, 3, 4])
        y = torch.tensor([10, 20, 30, 40])
        self.assertEqual(fn(x, y), torch.tensor([11, 24, 39, 56]))
        self.assertEqual(counter.frame_count, 1)

        x = torch.tensor([1, 2, 3, 1])
        y = torch.tensor([10, 20, 30, 10])
        self.assertEqual(fn(x, y), torch.tensor([11, 24, 39, 11]))
        # No recompile
        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()