Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Port three dynamo tests to Intel GPU (#156575)
As part of https://github.com/pytorch/pytorch/issues/114850, we are porting test cases to Intel GPU. Two dynamo test files were ported in PR [#156056](https://github.com/pytorch/pytorch/pull/156056); this PR ports three more dynamo test files. We enable Intel GPU with the following methods while keeping the original code style as much as possible (a minimal sketch of the pattern follows the commit metadata below):

- use instantiate_device_type_tests()
- use torch.accelerator.current_accelerator() to determine the accelerator backend
- add XPU support to decorators such as @requires_gpu
- enable XPU for some test paths

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156575
Approved by: https://github.com/guangyey, https://github.com/jansel

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Committed by: PyTorch MergeBot
Parent: 51853b358e
Commit: 1155c53e7d
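As background for the diff below, the porting recipe in the commit message boils down to one pattern: resolve the available accelerator at runtime and let instantiate_device_type_tests() generate per-backend test classes. The following is a minimal sketch of that pattern, not code from this PR; the class MyDeviceTests and its test body are made up for illustration, and it assumes a recent PyTorch build that provides the torch.accelerator namespace.

```python
import unittest

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase

# Resolve the usable accelerator backend ("cuda", "xpu", "hpu", ...) or fall
# back to "cpu"; this replaces hard-coded "cuda" strings in the ported files.
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)


class MyDeviceTests(TestCase):
    # `device` is injected per backend by instantiate_device_type_tests().
    @unittest.skipIf(device_type == "cpu", "Test requires an accelerator.")
    def test_add(self, device):
        x = torch.ones(4, device=device)
        self.assertEqual((x + 1).device.type, torch.device(device).type)


# allow_xpu=True lets the helper emit the XPU variant alongside CUDA/HPU.
instantiate_device_type_tests(
    MyDeviceTests, globals(), only_for=("cuda", "hpu", "xpu"), allow_xpu=True
)

if __name__ == "__main__":
    run_tests()
```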
@@ -90,6 +90,7 @@ from torch.testing._internal.common_utils import (
     skipIfNNModuleInlined,
     skipIfWindows,
     TEST_HPU,
+    TEST_XPU,
     wrapDeterministicFlagAPITest,
 )
 from torch.testing._internal.jit_utils import JitTestCase
@@ -6904,7 +6905,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         self.assertTrue(guard_failure is not None)
         self.assertIn("""tensor 'rank' size mismatch at index 0""", guard_failure[0])

-    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Test requires CUDA or XPU.")
     def test_symint_as_device_kwarg_non_strict_export(self):
         class Mod(torch.nn.Module):
             def forward(self, x):
@@ -12771,7 +12772,7 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):

     def test_torch_device_is_available(self, device):
         def fn(x):
-            if TEST_HPU or TEST_CUDA:
+            if torch.accelerator.is_available():
                 return x + 1
             else:
                 return x - 1
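In the hunk above, the branch condition moves from build-time flags (TEST_HPU or TEST_CUDA) to a runtime query. torch.accelerator.is_available() reports whether an accelerator backend (e.g. CUDA or XPU) is usable, and Dynamo evaluates the condition while tracing, so the compiled function still specializes to a single branch. A small illustration, assuming a PyTorch build with the torch.accelerator namespace; the fn here mirrors the test body rather than quoting it:

```python
import torch


def fn(x):
    # Evaluated at trace time: Dynamo specializes on the boolean result,
    # so only one branch ends up in the compiled graph.
    if torch.accelerator.is_available():
        return x + 1
    return x - 1


compiled = torch.compile(fn, backend="eager")
print(compiled(torch.zeros(2)))  # +1 with an accelerator present, -1 without
```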
@@ -12874,27 +12875,23 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
     def test_cuda_set_device(self, device):
         def fn():
             a = torch.ones(2, device=device)
-            torch.cuda.set_device(1)
+            torch.get_device_module(device).set_device(1)
             return a + 1

-        with torch.cuda.device(0):
+        with torch.get_device_module(device).device(0):
             counter = CompileCounter()
             opt_fn = torch.compile(fn, backend=counter)
             res = opt_fn()
-            self.assertEqual(res.device.type, "cuda")
+            self.assertEqual(res.device.type, device)
             self.assertEqual(res.device.index, 0)
             self.assertEqual(counter.frame_count, 2)

-    def test_torch_device_python_type(self):
+    def test_torch_device_python_type(self, device):
+        device_type = torch.device(device).type
         for device, device_type, index in [
             ("cpu", "cpu", None),
-            ("cuda:0", "cuda", 0),
-            ("hpu:0", "hpu", 0),
+            (device, device_type, 0),
         ]:
-            if (device == "cuda:0" and not TEST_CUDA) or (
-                device == "hpu:0" and not TEST_HPU
-            ):
-                continue

             def fn(target):
                 target_device = target.device
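The key change above is torch.get_device_module(device), which returns the backend module matching a device string (torch.cuda for "cuda", torch.xpu for "xpu"), so the same test drives set_device/device on whichever backend it was instantiated for. A minimal sketch, assuming a build with at least one of these backends and a visible device 0; current_device_index is a hypothetical helper, not part of the PR:

```python
import torch


def current_device_index(device: str) -> int:
    # torch.get_device_module maps "cuda" -> torch.cuda, "xpu" -> torch.xpu, ...
    mod = torch.get_device_module(device)
    mod.set_device(0)        # analogous to torch.cuda.set_device(0)
    with mod.device(0):      # analogous to `with torch.cuda.device(0):`
        return mod.current_device()


# e.g. current_device_index("cuda") or current_device_index("xpu") -> 0
```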
@@ -12956,8 +12953,10 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
         f(x, y)


-devices = ("cuda", "hpu")
-instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices)
+devices = ("cuda", "hpu", "xpu")
+instantiate_device_type_tests(
+    MiscTestsDevice, globals(), only_for=devices, allow_xpu=True
+)
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

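For context on the allow_xpu=True argument: instantiate_device_type_tests() clones the template class into one concrete class per backend listed in only_for (and removes the template from the module namespace), but it only generates the XPU clone when allow_xpu is set. A quick, illustrative way to see what got generated; the exact class names follow the usual <Template><DEVICE> convention and are an assumption here, not output from this PR:

```python
import sys

# Run after the instantiate_device_type_tests(...) call above.
generated = sorted(
    name
    for name in vars(sys.modules[__name__])
    if name.startswith("MiscTestsDevice") and name != "MiscTestsDevice"
)
print(generated)  # roughly: ['MiscTestsDeviceCUDA', 'MiscTestsDeviceHPU', 'MiscTestsDeviceXPU']
```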
@@ -7,7 +7,7 @@ import torch._dynamo.test_case
 import torch._dynamo.testing
 from torch._dynamo.testing import same
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import TEST_HPU, TestCase
+from torch.testing._internal.common_utils import TestCase


 try:
@@ -359,11 +359,11 @@ class TestModelOutputBert(TestCase):
         )


-devices = ["cpu", "cuda"]
-if TEST_HPU:
-    devices.append("hpu")
+devices = ["cpu", "cuda", "xpu", "hpu"]

-instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices)
+instantiate_device_type_tests(
+    TestModelOutputBert, globals(), only_for=devices, allow_xpu=True
+)

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -12,11 +12,16 @@ from torch._C import (
     _push_on_torch_function_stack,
 )
 from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
-from torch.testing._internal.triton_utils import requires_cuda
+from torch.testing._internal.triton_utils import requires_gpu
 from torch.utils._device import DeviceContext
 from torch.utils._python_dispatch import TorchDispatchMode


+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
+
 class TestMode(BaseTorchFunctionMode):
     def __torch_function__(self, func, types, args, kwargs=None):
         if not kwargs:
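The added device_type one-liner uses a walrus expression; spelled out, it does the following. Passing True asks current_accelerator to also check that a device is actually usable, so the result is None (and device_type falls back to "cpu") on machines without a working accelerator. A sketch assuming the torch.accelerator namespace is present:

```python
import torch

# Equivalent to:
#   device_type = (
#       acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
#   )
accelerator = torch.accelerator.current_accelerator(True)  # torch.device or None
if accelerator is not None:
    device_type = accelerator.type  # e.g. "cuda" or "xpu"
else:
    device_type = "cpu"
```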
@@ -613,12 +618,12 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):

         func(torch.randn(3))

-    @requires_cuda
+    @requires_gpu
     def test_flex_attention(self):
         import torch
         from torch.nn.attention.flex_attention import create_block_mask, flex_attention

-        torch.set_default_device("cuda")
+        torch.set_default_device(device_type)

         flex_attention = torch.compile(flex_attention, dynamic=False)

@@ -628,7 +633,9 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
             return prefix_lengths[b] >= kv

         # This runs in fullgraph already
-        create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
+        create_block_mask(
+            prefix_lm, 8, None, 512, 512, _compile=True, device=device_type
+        )

     def test_register_hook(self):
         import functools
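create_block_mask materializes the mask on a specific device, which is why the ported call passes device=device_type instead of relying on its CUDA default. A minimal standalone sketch of the same prefix-LM masking pattern, under the assumption that a CUDA or XPU device with Triton support is available; the shapes and prefix_lengths values are illustrative:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)

# One prefix length per batch element (batch size 8 here).
prefix_lengths = torch.arange(8, device=device_type)


def prefix_lm(b, h, q, kv):
    # Each query may attend to key positions up to prefix_lengths[b].
    return prefix_lengths[b] >= kv


# Build the block mask on the same backend the attention call will run on.
block_mask = create_block_mask(prefix_lm, 8, None, 512, 512, device=device_type)

q = torch.randn(8, 4, 512, 64, device=device_type)
k = torch.randn(8, 4, 512, 64, device=device_type)
v = torch.randn(8, 4, 512, 64, device=device_type)
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([8, 4, 512, 64])
```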
@@ -16,7 +16,7 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
 )
-from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU


 @functorch_config.patch("bundled_autograd_cache", True)
@@ -28,10 +28,13 @@ class TestPackage(torch._inductor.test_case.TestCase):
         return path

     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_basic_fn(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")
+
         ctx = DiskDynamoStore()

         def fn(x):
@@ -69,10 +72,12 @@ class TestPackage(torch._inductor.test_case.TestCase):
         self.assertEqual(expected, compiled_fn(*args))

     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_graph_break_bomb(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")

         ctx = DiskDynamoStore()

@@ -131,10 +136,13 @@ class TestPackage(torch._inductor.test_case.TestCase):
         compiled_fn(torch.tensor(N), 0, N - 1)

     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_dynamic_shape(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")
+
         ctx = DiskDynamoStore()

         def fn(x):
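The same guard is repeated in each parametrized TestPackage test above. A small hypothetical helper (not part of this PR) that captures the pattern, using only the HAS_CUDA / HAS_XPU flags already imported in the hunk near the top of the file:

```python
import unittest

from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU


def skip_if_backend_unavailable(device: str) -> None:
    # Mirrors the inline guards above: the tests are parametrized over
    # ("cpu", "cuda", "xpu") and bail out when the Triton backend is missing.
    if device == "cuda" and not HAS_CUDA:
        raise unittest.SkipTest("Requires CUDA/Triton")
    if device == "xpu" and not HAS_XPU:
        raise unittest.SkipTest("Requires XPU/Triton")
```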