Port three dynamo tests to Intel GPU (#156575)
For https://github.com/pytorch/pytorch/issues/114850, we are porting test cases to Intel GPU. Two dynamo test files were ported in PR [#156056](https://github.com/pytorch/pytorch/pull/156056); this PR ports three more dynamo test files. We enable Intel GPU with the following methods while keeping the original code style as much as possible (a sketch of the overall pattern is shown below, before the diff):

- use instantiate_device_type_tests()
- use torch.accelerator.current_accelerator() to determine the accelerator backend
- add XPU support in decorators such as @requires_gpu
- enable XPU for some test paths

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156575
Approved by: https://github.com/guangyey, https://github.com/jansel
Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Committed by: PyTorch MergeBot
Parent: 51853b358e
Commit: 1155c53e7d
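To make the porting pattern concrete, here is a minimal, self-contained sketch of the device-generic test structure the PR moves these files toward. It is illustrative only: the class and test names are hypothetical and not part of the diff, while the APIs (instantiate_device_type_tests with allow_xpu, torch.accelerator.current_accelerator) are the ones the diff itself uses.

```python
# Hypothetical sketch of the device-generic test pattern (illustrative names).
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase

# Module-level backend string ("cuda", "xpu", "hpu", ... or "cpu") for places
# that need a plain device name, as in the test_torch_function_mode changes below.
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)


class ExampleDeviceTests(TestCase):
    def test_add_one(self, device):
        # `device` is injected per backend by instantiate_device_type_tests().
        x = torch.ones(4, device=device)
        self.assertEqual((x + 1).sum().item(), 8.0)


# Generates ExampleDeviceTestsCUDA / ExampleDeviceTestsXPU / ... variants;
# allow_xpu=True opts the class into the XPU instantiation path.
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
    ExampleDeviceTests, globals(), only_for=devices, allow_xpu=True
)

if __name__ == "__main__":
    run_tests()
```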
@@ -90,6 +90,7 @@ from torch.testing._internal.common_utils import (
     skipIfNNModuleInlined,
     skipIfWindows,
     TEST_HPU,
+    TEST_XPU,
     wrapDeterministicFlagAPITest,
 )
 from torch.testing._internal.jit_utils import JitTestCase
@@ -6904,7 +6905,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         self.assertTrue(guard_failure is not None)
         self.assertIn("""tensor 'rank' size mismatch at index 0""", guard_failure[0])
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Test requires CUDA or XPU.")
     def test_symint_as_device_kwarg_non_strict_export(self):
         class Mod(torch.nn.Module):
             def forward(self, x):
@@ -12771,7 +12772,7 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
 
     def test_torch_device_is_available(self, device):
         def fn(x):
-            if TEST_HPU or TEST_CUDA:
+            if torch.accelerator.is_available():
                 return x + 1
             else:
                 return x - 1
@@ -12874,27 +12875,23 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
     def test_cuda_set_device(self, device):
         def fn():
             a = torch.ones(2, device=device)
-            torch.cuda.set_device(1)
+            torch.get_device_module(device).set_device(1)
             return a + 1
 
-        with torch.cuda.device(0):
+        with torch.get_device_module(device).device(0):
             counter = CompileCounter()
             opt_fn = torch.compile(fn, backend=counter)
             res = opt_fn()
-            self.assertEqual(res.device.type, "cuda")
+            self.assertEqual(res.device.type, device)
             self.assertEqual(res.device.index, 0)
             self.assertEqual(counter.frame_count, 2)
 
-    def test_torch_device_python_type(self):
+    def test_torch_device_python_type(self, device):
+        device_type = torch.device(device).type
         for device, device_type, index in [
             ("cpu", "cpu", None),
-            ("cuda:0", "cuda", 0),
-            ("hpu:0", "hpu", 0),
+            (device, device_type, 0),
         ]:
-            if (device == "cuda:0" and not TEST_CUDA) or (
-                device == "hpu:0" and not TEST_HPU
-            ):
-                continue
 
             def fn(target):
                 target_device = target.device
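The test_cuda_set_device hunk above swaps direct torch.cuda calls for torch.get_device_module(device), which returns the matching backend module (torch.cuda, torch.xpu, ...). A small sketch of that idiom follows; the helper name is hypothetical and not part of the PR.

```python
import torch


def ones_plus_one_on_device_zero(device: str) -> torch.Tensor:
    # Hypothetical helper: torch.get_device_module("cuda") returns torch.cuda,
    # torch.get_device_module("xpu") returns torch.xpu, so the same code path
    # can drive either backend.
    mod = torch.get_device_module(device)
    with mod.device(0):
        mod.set_device(0)
        return torch.ones(2, device=device) + 1


# Usage (assuming the backend is present): ones_plus_one_on_device_zero("xpu")
```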
@@ -12956,8 +12953,10 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
         f(x, y)
 
 
-devices = ("cuda", "hpu")
-instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices)
+devices = ("cuda", "hpu", "xpu")
+instantiate_device_type_tests(
+    MiscTestsDevice, globals(), only_for=devices, allow_xpu=True
+)
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
 
@@ -7,7 +7,7 @@ import torch._dynamo.test_case
 import torch._dynamo.testing
 from torch._dynamo.testing import same
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import TEST_HPU, TestCase
+from torch.testing._internal.common_utils import TestCase
 
 
 try:
@@ -359,11 +359,11 @@ class TestModelOutputBert(TestCase):
         )
 
 
-devices = ["cpu", "cuda"]
-if TEST_HPU:
-    devices.append("hpu")
+devices = ["cpu", "cuda", "xpu", "hpu"]
 
-instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices)
+instantiate_device_type_tests(
+    TestModelOutputBert, globals(), only_for=devices, allow_xpu=True
+)
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -12,11 +12,16 @@ from torch._C import (
     _push_on_torch_function_stack,
 )
 from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
-from torch.testing._internal.triton_utils import requires_cuda
+from torch.testing._internal.triton_utils import requires_gpu
 from torch.utils._device import DeviceContext
 from torch.utils._python_dispatch import TorchDispatchMode
 
 
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
+
 class TestMode(BaseTorchFunctionMode):
     def __torch_function__(self, func, types, args, kwargs=None):
         if not kwargs:
@@ -613,12 +618,12 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
 
             func(torch.randn(3))
 
-    @requires_cuda
+    @requires_gpu
     def test_flex_attention(self):
         import torch
         from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 
-        torch.set_default_device("cuda")
+        torch.set_default_device(device_type)
 
         flex_attention = torch.compile(flex_attention, dynamic=False)
 
@@ -628,7 +633,9 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
             return prefix_lengths[b] >= kv
 
         # This runs in fullgraph already
-        create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
+        create_block_mask(
+            prefix_lm, 8, None, 512, 512, _compile=True, device=device_type
+        )
 
     def test_register_hook(self):
         import functools
@@ -16,7 +16,7 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
 )
-from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU
 
 
 @functorch_config.patch("bundled_autograd_cache", True)
@@ -28,10 +28,13 @@ class TestPackage(torch._inductor.test_case.TestCase):
         return path
 
     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_basic_fn(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")
+
         ctx = DiskDynamoStore()
 
         def fn(x):
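The XPU guards added to test_basic_fn (and repeated in the hunks below) all follow the same shape. A sketch of that guard factored into a standalone helper follows; the helper name is hypothetical, while HAS_CUDA and HAS_XPU are the flags imported in this diff.

```python
# Hypothetical helper mirroring the per-test guards in this diff: skip the
# current test when the parametrized device has no Triton-capable backend.
import unittest

from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU


def skip_if_no_triton_backend(device: str) -> None:
    if device == "cuda" and not HAS_CUDA:
        raise unittest.SkipTest("Requires CUDA/Triton")
    if device == "xpu" and not HAS_XPU:
        raise unittest.SkipTest("Requires XPU/Triton")
```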
@@ -69,10 +72,12 @@ class TestPackage(torch._inductor.test_case.TestCase):
         self.assertEqual(expected, compiled_fn(*args))
 
     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_graph_break_bomb(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")
 
         ctx = DiskDynamoStore()
 
@@ -131,10 +136,13 @@ class TestPackage(torch._inductor.test_case.TestCase):
         compiled_fn(torch.tensor(N), 0, N - 1)
 
     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_dynamic_shape(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")
+
         ctx = DiskDynamoStore()
 
         def fn(x):