Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
port 4 dynamo test files to Intel GPU (#157779)
For https://github.com/pytorch/pytorch/issues/114850 we are porting test cases to Intel GPU. Six dynamo test files were ported in PRs [#156056](https://github.com/pytorch/pytorch/pull/156056) and [#156575](https://github.com/pytorch/pytorch/pull/156575). This PR ports four more dynamo test files. Intel GPU is enabled with the following methods, keeping as close to the original code style as possible:

- use instantiate_device_type_tests() (with allow_xpu=True)
- use torch.accelerator.current_accelerator() to determine the accelerator backend
- add XPU support to decorators such as @requires_gpu
- enable XPU for some test paths
- mark known XPU failures with xfailIf(TEST_XPU), referencing the tracking issue

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157779
Approved by: https://github.com/guangyey, https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: e1a20988f3
Commit: 8088958793
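The hunks below apply one recurring pattern to each file. As a rough illustration, here is a minimal, self-contained sketch of that pattern, assuming a PyTorch build with the device-generic test machinery; the `PlainGpuTests` and `MyDeviceTests` classes and their test bodies are invented for this example and are not part of the PR:

```python
# Hypothetical example module illustrating the porting pattern; not part of the PR.
import unittest

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU

# Resolve the active accelerator once ("cuda", "xpu", ...); fall back to CPU.
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)

# Backend-neutral replacement for a CUDA-only @requires_cuda decorator.
requires_gpu = unittest.skipUnless(HAS_CUDA or HAS_XPU, "requires cuda or xpu")


class PlainGpuTests(TestCase):
    @requires_gpu
    def test_add(self):
        # Allocate on whatever accelerator is present instead of hard-coding "cuda".
        x = torch.ones(4, 4, device=device_type)
        self.assertEqual((x + x).sum().item(), 32.0)


class MyDeviceTests(TestCase):
    def test_mul(self, device):
        # `device` is injected by instantiate_device_type_tests for each backend.
        x = torch.full((2, 2), 3.0, device=device)
        self.assertEqual((x * 2).sum().item(), 24.0)


# allow_xpu=True makes the device-generic machinery emit XPU variants as well.
devices = ["cuda", "xpu"]
instantiate_device_type_tests(MyDeviceTests, globals(), only_for=devices, allow_xpu=True)

if __name__ == "__main__":
    run_tests()
```

The diff that follows applies exactly this substitution to the existing tests.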
```diff
@@ -23,8 +23,10 @@ from torch.testing._internal.common_utils import (
     find_free_port,
     munge_exc,
     skipIfTorchDynamo,
+    TEST_XPU,
+    xfailIf,
 )
-from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU
 from torch.testing._internal.logging_utils import (
     LoggingTestCase,
     make_logging_test,
@@ -33,10 +35,15 @@ from torch.testing._internal.logging_utils import (
 
 
-requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
+requires_gpu = unittest.skipUnless(HAS_CUDA or HAS_XPU, "requires cuda or xpu")
 requires_distributed = functools.partial(
     unittest.skipIf, not dist.is_available(), "requires distributed"
 )
 
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
 
 def munge_shape_guards(s: str) -> str:
     SHAPE_GUARD_REGEX = (
@@ -71,7 +78,7 @@ def inductor_error_fn(a):
 
 
 def inductor_schedule_fn(a):
-    output = a.add(torch.ones(1000, 1000, device="cuda"))
+    output = a.add(torch.ones(1000, 1000, device=device_type))
     return output
 
 
@@ -108,19 +115,19 @@ class LoggingTests(LoggingTestCase):
     test_output_code = multi_record_test(3, output_code=True)
     test_aot_graphs = multi_record_test(3, aot_graphs=True)
 
-    @requires_cuda
+    @requires_gpu
     @make_logging_test(schedule=True)
     def test_schedule(self, records):
         fn_opt = torch.compile(inductor_schedule_fn, backend="inductor")
-        fn_opt(torch.ones(1000, 1000, device="cuda"))
+        fn_opt(torch.ones(1000, 1000, device=device_type))
         self.assertGreater(len(records), 0)
         self.assertLess(len(records), 5)
 
-    @requires_cuda
+    @requires_gpu
     @make_logging_test(fusion=True)
     def test_fusion(self, records):
         fn_opt = torch.compile(inductor_schedule_fn, backend="inductor")
-        fn_opt(torch.ones(1000, 1000, device="cuda"))
+        fn_opt(torch.ones(1000, 1000, device=device_type))
         self.assertGreater(len(records), 0)
         self.assertLess(len(records), 8)
 
@@ -128,7 +135,7 @@ class LoggingTests(LoggingTestCase):
     @make_logging_test(cudagraphs=True)
     def test_cudagraphs(self, records):
         fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn)
-        fn_opt(torch.ones(1000, 1000, device="cuda"))
+        fn_opt(torch.ones(1000, 1000, device=device_type))
         self.assertGreater(len(records), 0)
         self.assertLess(len(records), 8)
 
@@ -763,10 +770,11 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
         self.assertGreater(len(records), 0)
         self.assertLess(len(records), 4)
 
+    @xfailIf(TEST_XPU)  # https://github.com/pytorch/pytorch/issues/157778
     @make_logging_test(perf_hints=True)
-    @requires_cuda
+    @requires_gpu
     def test_optimizer_non_static_param(self, records):
-        params = [torch.randn(10, 10, device="cuda") for _ in range(2)]
+        params = [torch.randn(10, 10, device=device_type) for _ in range(2)]
         for param in params:
             param.grad = torch.zeros_like(param)
         opt = torch.optim.Adam(params)
@@ -776,7 +784,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
         self.assertLess(len(records), 3)
 
     @make_logging_test(autotuning=True)
-    @requires_cuda
+    @requires_gpu
     @unittest.skipIf(not SM90OrLater, "requires H100+ GPU")
     def test_autotuning(self, records):
         with torch._inductor.utils.fresh_cache():
@@ -785,7 +793,10 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
                 return torch.mm(a, b)
 
             f = torch.compile(f, mode="max-autotune-no-cudagraphs")
-            f(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
+            f(
+                torch.randn(10, 10, device=device_type),
+                torch.randn(10, 10, device=device_type),
+            )
             self.assertGreater(len(records), 0)
             self.assertLess(len(records), 40)
 
```
```diff
@@ -3386,8 +3386,10 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
         compiled_mod(x)
 
 
-devices = ["cuda", "hpu"]
-instantiate_device_type_tests(NNModuleTestsDevice, globals(), only_for=devices)
+devices = ["cuda", "hpu", "xpu"]
+instantiate_device_type_tests(
+    NNModuleTestsDevice, globals(), only_for=devices, allow_xpu=True
+)
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
```
```diff
@@ -5,6 +5,12 @@ import torch
 import torch._dynamo.test_case
 from torch._dynamo.testing import CompileCounter, EagerAndRecordGraphs, normalize_gm
 from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_utils import TEST_XPU
+
+
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
 
 
 class PythonDispatcherTests(torch._dynamo.test_case.TestCase):
@@ -74,7 +80,7 @@ class GraphModule(torch.nn.Module):
 """,  # NOQA: B950
         )
 
-    @unittest.skipIf(not TEST_CUDA, "requires cuda")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "requires cuda or xpu")
     def test_dispatch_key_set_guard(self):
         counter = CompileCounter()
 
@@ -96,7 +102,7 @@ class GraphModule(torch.nn.Module):
         # No recompile since the dispatch key set is the same though the tensor is different.
         self.assertEqual(counter.frame_count, 1)
 
-        x3 = torch.randn(2, 3, device="cuda")
+        x3 = torch.randn(2, 3, device=device_type)
         dks3 = torch._C._dispatch_keys(x3)
         self.assertEqual(fn(x3, dks3), torch.sin(x3 - 1))
         # Re-compile since the dispatch key set is different.
```
```diff
@@ -12,6 +12,11 @@ from torch._dynamo.exc import FailOnRecompileLimitHit
 from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings
 
 
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
+
 class RecompileUxTests(torch._dynamo.test_case.TestCase):
     # TODO(whc) dynamo actually recompiles one more time than the cache limit
     cache_limit = 1
@@ -101,7 +106,10 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
             .startswith("torch._dynamo hit config.recompile_limit")
         )
 
-    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    @unittest.skipIf(
+        not torch.cuda.is_available() and not torch.xpu.is_available(),
+        "requires cuda or xpu",
+    )
     def test_nvfuser_guards(self):
         # we may want to model dynamo's guards sufficiently after nvfuser's ProfilingExecutor guards
         # such that we ensure dynamo is in charge of all the recompilations at the top level,
@@ -109,11 +117,11 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
         def func(a, b, c):
             return a + b * c
 
-        a = torch.rand(3, 4, 5, device="cuda")
-        b = torch.rand(3, 4, 5, device="cuda")
-        b_v = torch.rand(3, 5, 4, device="cuda").view(3, 4, 5)
-        b_p = torch.rand(3, 5, 4, device="cuda").permute(0, 2, 1)
-        c = torch.rand(3, 4, 5, device="cuda")
+        a = torch.rand(3, 4, 5, device=device_type)
+        b = torch.rand(3, 4, 5, device=device_type)
+        b_v = torch.rand(3, 5, 4, device=device_type).view(3, 4, 5)
+        b_p = torch.rand(3, 5, 4, device=device_type).permute(0, 2, 1)
+        c = torch.rand(3, 4, 5, device=device_type)
         compile_counter = torch._dynamo.testing.CompileCounter()
 
         with torch._dynamo.config.patch("recompile_limit", 2):
```
```diff
@@ -23,7 +23,7 @@ from torch.nn import Buffer, Parameter
 from torch.testing._internal import opinfo
 from torch.testing._internal.common_utils import \
     (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, MACOS_VERSION, IS_CI,
-     NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests)
+     NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests, xfailIf)
 from torch.testing._internal.common_mps import mps_ops_modifier, mps_ops_grad_modifier, mps_ops_error_inputs_modifier
 from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
@@ -71,14 +71,6 @@ _ref_test_ops = tuple(
     )
 )
 
-def xfailIf(condition):
-    def wrapper(func):
-        if condition:
-            return unittest.expectedFailure(func)
-        else:
-            return func
-    return wrapper
-
 # Same logic as test_cuda.py
 if not torch.backends.mps.is_available():
     print('MPS not available, skipping tests', file=sys.stderr)
```
```diff
@@ -1957,6 +1957,14 @@ def runOnRocmArch(arch: tuple[str, ...]):
 def xfailIfS390X(func):
     return unittest.expectedFailure(func) if IS_S390X else func
 
+def xfailIf(condition):
+    def wrapper(func):
+        if condition:
+            return unittest.expectedFailure(func)
+        else:
+            return func
+    return wrapper
+
 def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"):
     def dec_fn(fn):
         reason = f"skipIfXpu: {msg}"
```
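For completeness, a small hedged sketch of how the relocated xfailIf helper combines with TEST_XPU to mark a known XPU failure, mirroring the @xfailIf(TEST_XPU) usage in the test_logging.py hunk above; the `SampleTests` class and its test body are illustrative only:

```python
# Illustrative only; mirrors how the ported tests use xfailIf with TEST_XPU.
import unittest

import torch
from torch.testing._internal.common_utils import TEST_XPU, run_tests, xfailIf


class SampleTests(unittest.TestCase):
    @xfailIf(TEST_XPU)  # e.g. https://github.com/pytorch/pytorch/issues/157778
    def test_known_xpu_gap(self):
        # Stand-in for a check that currently fails on XPU; on other backends
        # the decorator is a no-op and the test runs normally.
        self.assertTrue(torch.ones(2, 2).is_contiguous())


if __name__ == "__main__":
    run_tests()
```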