For https://github.com/pytorch/pytorch/issues/114850, we are porting Dynamo test cases to Intel GPU. Six Dynamo test files were ported in PRs [#156056](https://github.com/pytorch/pytorch/pull/156056) and [#156575](https://github.com/pytorch/pytorch/pull/156575); this PR ports four more Dynamo test files. We enable Intel GPU with the following methods, keeping the original code style wherever possible:

- use instantiate_device_type_tests()
- use torch.accelerator.current_accelerator() to determine the accelerator backend
- add XPU support in decorators such as @requires_gpu
- enable XPU for some test paths
- add xfailIfXPU to skip XPU tests where there is a known bug

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157779
Approved by: https://github.com/guangyey, https://github.com/jansel
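For reference, a minimal sketch of the device-agnostic pattern these ports rely on. The `device_type` expression and the skip condition are taken verbatim from the file below; the test class itself is hypothetical and only illustrates how they combine:

```python
import unittest

import torch
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import TEST_XPU

# Resolve the test device once: prefer the current accelerator ("cuda",
# "xpu", ...) and fall back to "cpu" when none is available.
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)


class ExampleTest(unittest.TestCase):  # hypothetical test, for illustration only
    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "requires cuda or xpu")
    def test_on_accelerator(self):
        # Allocate on whichever accelerator backend is available.
        x = torch.randn(2, 3, device=device_type)
        self.assertEqual(x.device.type, device_type)
```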
# Owner(s): ["module: dynamo"]
import unittest

import torch
import torch._dynamo.test_case
from torch._dynamo.testing import CompileCounter, EagerAndRecordGraphs, normalize_gm
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import TEST_XPU


# Use the current accelerator backend (e.g. "cuda" or "xpu") when one is
# available, otherwise fall back to "cpu".
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)


class PythonDispatcherTests(torch._dynamo.test_case.TestCase):
    def test_dispatch_key1(self):
        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn(x):
            x = x + 1
            return torch._C._dispatch_keys(x)

        x = torch.randn(2, 3)
        self.assertTrue(fn(x).raw_repr() == torch._C._dispatch_keys(x + 1).raw_repr())

    def test_dispatch_key2(self):
        # TwoTensor is a tensor subclass used for testing; its dispatch key
        # set differs from a plain tensor's (it includes the Python key).
        from torch.testing._internal.two_tensor import TwoTensor

        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn(x):
            x = x.sin()
            return torch._C._dispatch_keys(x)

        x = torch.randn(3)
        y = torch.randn(3)
        z = TwoTensor(x, y)
        self.assertTrue(fn(z).raw_repr() == torch._C._dispatch_keys(z.sin()).raw_repr())

    def test_dispatch_key3(self):
        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn(x):
            key_set = torch._C._dispatch_tls_local_include_set()
            return torch.sin(x + 1), key_set

        x = torch.randn(2, 3)
        self.assertEqual(fn(x)[0], torch.sin(x + 1))
        self.assertTrue(
            fn(x)[1].raw_repr() == torch._C._dispatch_tls_local_include_set().raw_repr()
        )
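
    # DispatchKeySet supports set algebra (| for union, - for difference) and
    # highestPriorityTypeId(); under torch.compile the branch below is resolved
    # at trace time, so only the taken branch appears in the captured graph.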
    def test_dispatch_key4(self):
        eager = EagerAndRecordGraphs()

        @torch.compile(backend=eager, fullgraph=True)
        def fn(x):
            key_set = torch._C._dispatch_tls_local_include_set()
            key_set = key_set | torch._C._dispatch_keys(x)
            key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
            if key_set.highestPriorityTypeId() == torch.DispatchKey.PythonDispatcher:
                return torch.sin(x + 1)
            else:
                return torch.sin(x - 1)

        x = torch.randn(2, 3)
        self.assertEqual(fn(x), torch.sin(x - 1))

        graph = eager.graphs[0]
        actual = normalize_gm(graph.print_readable(False))

        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[2, 3]"):
        l_x_ = L_x_

        sub: "f32[2, 3]" = l_x_ - 1; l_x_ = None
        sin: "f32[2, 3]" = torch.sin(sub); sub = None
        return (sin,)
""",  # NOQA: B950
        )
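
    # The next test checks Dynamo's guarding on DispatchKeySet arguments: a set
    # with the same keys reuses the compiled frame, while a set taken from an
    # accelerator tensor (which lacks the "CPU" key) forces a recompile.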
    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "requires cuda or xpu")
    def test_dispatch_key_set_guard(self):
        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True)
        def fn(x, dks):
            if dks.has("CPU"):
                return torch.sin(x + 1)
            else:
                return torch.sin(x - 1)

        x1 = torch.randn(2, 3)
        dks1 = torch._C._dispatch_keys(x1)
        self.assertEqual(fn(x1, dks1), torch.sin(x1 + 1))
        self.assertEqual(counter.frame_count, 1)

        x2 = torch.randn(2, 3)
        dks2 = torch._C._dispatch_keys(x2)
        self.assertEqual(fn(x2, dks2), torch.sin(x2 + 1))
        # No recompile since the dispatch key set is the same even though the
        # tensor is different.
        self.assertEqual(counter.frame_count, 1)

        x3 = torch.randn(2, 3, device=device_type)
        dks3 = torch._C._dispatch_keys(x3)
        self.assertEqual(fn(x3, dks3), torch.sin(x3 - 1))
        # Recompile since the dispatch key set is different.
        self.assertEqual(counter.frame_count, 2)
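
    # The functorch test below inspects the active transform interpreter inside
    # a compiled region; the vmap level feeds into the computation, yet calls
    # with new inputs of the same shape reuse the compiled frame.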
    def test_functorch_interpreter(self):
        counter = CompileCounter()

        def square_and_add(x, y):
            interpreter = (
                torch._functorch.pyfunctorch.retrieve_current_functorch_interpreter()
            )
            level = interpreter.level()
            if interpreter.key() == torch._C._functorch.TransformType.Vmap:
                return (x**2 + y) * level
            else:
                return x**2 * level

        @torch.compile(backend=counter, fullgraph=True)
        def fn(x, y):
            return torch.vmap(square_and_add)(x, y)

        x = torch.tensor([1, 2, 3, 4])
        y = torch.tensor([10, 20, 30, 40])
        self.assertEqual(fn(x, y), torch.tensor([11, 24, 39, 56]))
        self.assertEqual(counter.frame_count, 1)

        x = torch.tensor([1, 2, 3, 1])
        y = torch.tensor([10, 20, 30, 10])
        self.assertEqual(fn(x, y), torch.tensor([11, 24, 39, 11]))
        # No recompile
        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()