Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Inductor Intel GPU backend Upstream] Reuse inductor test for Intel GPU (PART 1) (#122866)
Reuse the Inductor test suite for Intel GPU, including: test_torchinductor.py, test_triton_wrapper.py, test_metrics.py, test_codecache.py, test_codegen_triton.py, test_kernel_benchmark.py, test_triton_heuristics.py. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122866 Approved by: https://github.com/EikanWang, https://github.com/jansel
committed by PyTorch MergeBot
parent 4dd33a1c2b
commit 78a1693266
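The overall pattern of the change: instead of hard-coding "cuda", tests read the active GPU backend from torch.testing._internal.inductor_utils, so the same test body runs on CUDA or XPU. A minimal sketch of a test written against that pattern (the class and test names here are illustrative, not part of the PR):

import unittest

import torch
from torch._inductor.test_case import TestCase, run_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class ToyGPUTest(TestCase):
    # GPU_TYPE resolves to "cuda" or "xpu" on the current machine,
    # so no device string is hard-coded in the test body.
    device_type = GPU_TYPE

    @unittest.skipIf(not HAS_GPU, "requires a Triton-capable GPU")
    def test_add_compiles(self):
        def fn(x, y):
            return x + y

        compiled = torch.compile(fn, fullgraph=True)
        x = torch.randn(4, device=GPU_TYPE)
        y = torch.randn(4, device=GPU_TYPE)
        self.assertEqual(compiled(x, y), fn(x, y))


if __name__ == "__main__":
    if HAS_GPU:
        run_tests(needs="filelock")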
@@ -394,8 +394,8 @@ class TestFxGraphCache(TestCase):
        compiled_fn = torch.compile(fn, fullgraph=True)

-        x = torch.randn(4, device="cuda")
-        y = torch.randn(4, device="cuda")
+        x = torch.randn(4, device=GPU_TYPE)
+        y = torch.randn(4, device=GPU_TYPE)

        compiled_fn(x, y)

        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
@@ -11,10 +11,13 @@ from torch._inductor.codecache import PyCodeCache
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.testing import FileCheck
+from torch.testing._internal.common_device_type import expectedFailureXPU
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class TestKernelBenchmark(TestCase):
+    device_type = GPU_TYPE
+
    @classmethod
    def setUpClass(cls):
        cls.exit_stack = contextlib.ExitStack()
@@ -95,6 +98,7 @@ class TestKernelBenchmark(TestCase):
        f(a, b)
        self.verify_compiled_kernels()

+    @expectedFailureXPU
    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    @fresh_inductor_cache()
    def test_mm_triton_kernel_benchmark(self):
@@ -294,7 +298,14 @@ class TestKernelBenchmark(TestCase):
        # num_gb = x0 + 2 * size_slice_c + size_out
        # num_gb = (1000 * 1000 + 2 * 1000 * 1000 + 1000 * 1000) * 2/ 1e9
        #        = 0.008
-        self.check_bandwidth(compiled_module, "0.008")
+        num_gb = "0.008"
+        if GPU_TYPE == "xpu":
+            # In XPU backend, mm + add + add will be fused as admm + add
+            # And CUDA prefer not fuse add + mm, please check in function
+            # `should_prefer_unfused_addmm` in torch/_inductor/fx_passes/post_grad.py
+            num_gb = "0.006"
+
+        self.check_bandwidth(compiled_module, num_gb)

    def test_mm_slice_add_bandwidth_computation_2(self):
        M, N, K = 1000, 1000, 30
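The num_gb values above are an estimate of the bytes the kernel reads and writes, expressed in GB. One plausible reading of the fusion comment is that the admm fusion on XPU drops one read of the sliced operand, which is exactly the gap between 0.008 and 0.006. A quick sanity check of that arithmetic (plain Python, not part of the test file; assumes 2-byte elements and M = N = 1000 as in the test):

# 2-byte elements (fp16/bf16), M = N = 1000.
bytes_per_elem = 2
m = n = 1000

# Unfused path: read x0, read the slice of c twice, write the output.
num_gb_unfused = (m * n + 2 * m * n + m * n) * bytes_per_elem / 1e9
assert f"{num_gb_unfused:.3f}" == "0.008"

# Fused path (XPU): one slice read disappears after mm + add + add -> addmm + add.
num_gb_fused = (m * n + 1 * m * n + m * n) * bytes_per_elem / 1e9
assert f"{num_gb_fused:.3f}" == "0.006"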
@@ -322,6 +333,7 @@ class TestKernelBenchmark(TestCase):
        # have the same index.
        self.check_bandwidth(compiled_module, "0.006")

+    @expectedFailureXPU
    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_slice_mm_bandwidth_computation(self):
        M, N, K = 1000, 2000, 3000
@@ -16,7 +16,7 @@ example_kernel = """
    triton_meta={
        'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'},
        'device': 0,
-        'device_type': 'cuda',
+        'device_type': 'GPU_TYPE',
        'constants': {},
        'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2, 3))]},
    inductor_meta={
@@ -50,7 +50,9 @@ def triton_red_fused_add_sum_2(in_out_ptr0, in_ptr0, xnumel, rnumel, XBLOCK : tl
        tmp5 = tmp4 + tmp2
        tl.debug_barrier()
        tl.store(in_out_ptr0 + (x0), tmp5, xmask)
-"""
+""".replace(
+    "GPU_TYPE", GPU_TYPE
+)


class TestMetrics(TestCase):
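The hunk above turns the hard-coded 'cuda' in the example kernel source into a placeholder that is substituted once at import time; the mechanism is ordinary string replacement on the module-level template. A trimmed stand-in showing the same idea:

from torch.testing._internal.inductor_utils import GPU_TYPE

# A shortened stand-in for the example_kernel template in test_metrics.py.
example_kernel = """
triton_meta={'device': 0, 'device_type': 'GPU_TYPE', 'constants': {}}
""".replace(
    "GPU_TYPE", GPU_TYPE
)

# On a CUDA machine this now reads 'device_type': 'cuda'; on XPU, 'xpu'.
assert f"'device_type': '{GPU_TYPE}'" in example_kernel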
@@ -58,7 +58,7 @@ from torch.testing._internal.common_cuda import (

from torch.testing._internal.common_device_type import (
    _has_sufficient_memory,
-    get_desired_device_type_test_bases,
+    expectedFailureXPU,
)
from torch.testing._internal.common_dtype import all_types, get_all_dtypes
from torch.testing._internal.common_utils import (
@@ -72,6 +72,7 @@ from torch.testing._internal.common_utils import (
    parametrize,
    serialTest,
    skipIfRocm,
+    skipIfXpu,
    subtest,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
@@ -115,9 +116,6 @@ from torch.testing._internal.inductor_utils import (
)

HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines
-_desired_test_bases = get_desired_device_type_test_bases()
-RUN_CPU = any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases)
-RUN_GPU = any(getattr(x, "device_type", "") == GPU_TYPE for x in _desired_test_bases)

aten = torch.ops.aten
requires_gpu = functools.partial(unittest.skipIf, not HAS_GPU, "requires gpu")
@@ -176,18 +174,19 @@ def _large_cumprod_input(shape, dim, dtype, device):
    return (t * sign).to(dtype)


-def define_custom_op_for_test(id_, fn_cpu, fn_cuda, fn_meta, tags=()):
+def define_custom_op_for_test(id_, fn_cpu, fn_cuda, fn_xpu, fn_meta, tags=()):
    global libtest
    global ids
    if id_ not in ids:
        libtest.define(f"{id_}(Tensor self) -> Tensor", tags=tags)
        libtest.impl(id_, fn_cpu, "CPU")
        libtest.impl(id_, fn_cuda, "CUDA")
+        libtest.impl(id_, fn_xpu, "XPU")
        libtest.impl(id_, fn_meta, "Meta")
        ids.add(id_)


-def define_custom_op_2_for_test(id_, fn_cpu, fn_cuda, fn_meta, tags=()):
+def define_custom_op_2_for_test(id_, fn_cpu, fn_cuda, fn_xpu, fn_meta, tags=()):
    global libtest
    global ids
    if id_ not in ids:
@@ -196,17 +195,19 @@ def define_custom_op_2_for_test(id_, fn_cpu, fn_cuda, fn_meta, tags=()):
        )
        libtest.impl(id_, fn_cpu, "CPU")
        libtest.impl(id_, fn_cuda, "CUDA")
+        libtest.impl(id_, fn_xpu, "XPU")
        libtest.impl(id_, fn_meta, "Meta")
        ids.add(id_)


-def define_custom_op_3_for_test(id_, fn_cpu, fn_cuda, fn_meta, tags=()):
+def define_custom_op_3_for_test(id_, fn_cpu, fn_cuda, fn_xpu, fn_meta, tags=()):
    global libtest
    global ids
    if id_ not in ids:
        libtest.define(f"{id_}(Tensor[] x) -> Tensor", tags=tags)
        libtest.impl(id_, fn_cpu, "CPU")
        libtest.impl(id_, fn_cuda, "CUDA")
+        libtest.impl(id_, fn_xpu, "XPU")
        libtest.impl(id_, fn_meta, "Meta")
        ids.add(id_)
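With the extra fn_xpu parameter, each helper now registers an implementation for the XPU dispatch key alongside CPU, CUDA, and Meta. A self-contained sketch of the same registration flow, using a hypothetical library namespace "toytest" (the names here are illustrative, not the PR's):

import torch

libtest = torch.library.Library("toytest", "FRAGMENT")  # hypothetical namespace
ids = set()


def define_custom_op_for_test(id_, fn_cpu, fn_cuda, fn_xpu, fn_meta, tags=()):
    # Same shape as the updated helper: one impl per dispatch key, XPU included.
    if id_ not in ids:
        libtest.define(f"{id_}(Tensor self) -> Tensor", tags=tags)
        libtest.impl(id_, fn_cpu, "CPU")
        libtest.impl(id_, fn_cuda, "CUDA")
        libtest.impl(id_, fn_xpu, "XPU")
        libtest.impl(id_, fn_meta, "Meta")
        ids.add(id_)


def foo_impl(x):
    return 3 * x


# In the real tests, CPU/CUDA/XPU implementations are separate functions so a
# backend-specific bug is easier to localize; here one impl stands in for all.
define_custom_op_for_test(
    "foo", foo_impl, foo_impl, foo_impl, lambda x: torch.empty_like(x)
)
print(torch.ops.toytest.foo(torch.ones(2)))  # tensor([3., 3.]) on CPU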
@@ -2258,8 +2259,9 @@ class CommonTemplate:

        # Can't use assertEqual as it expands broadcasted inputs
        del t
-        if torch.device(self.device).type == "cuda":
-            torch.cuda.empty_cache()
+        if torch.device(self.device).type == GPU_TYPE:
+            getattr(torch, GPU_TYPE).empty_cache()
+
        self.assertTrue((actual == 2).all())

    def test_large_offset_pointwise(self):
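getattr(torch, GPU_TYPE) is the device-module indirection used throughout the PR: torch.cuda and torch.xpu are assumed to expose the same helpers the tests rely on (is_available, device_count, empty_cache), so the backend module can be picked by string. A small sketch of the pattern:

import torch
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

if HAS_GPU:
    device_mod = getattr(torch, GPU_TYPE)  # torch.cuda or torch.xpu
    print(device_mod.device_count())
    # Release cached allocator blocks on whichever backend is active.
    device_mod.empty_cache()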
@@ -2430,6 +2432,7 @@ class CommonTemplate:

    # https://github.com/pytorch/pytorch/issues/98979
    @skipCUDAIf(True, "cuda failed for float64 linear")
+    @skipIfXpu(msg="Double and complex datatype matmul is not supported in oneDNN")
    def test_linear_float64(self):
        mod = torch.nn.Sequential(torch.nn.Linear(8, 16).to(torch.float64)).eval()
        with torch.no_grad():
@@ -2549,6 +2552,7 @@ class CommonTemplate:
            check_lowp=True,
        )

+    @expectedFailureXPU
    def test_mm_mixed_dtype(self):
        def fn(a, b):
            return torch.mm(a, b)
@@ -2562,6 +2566,7 @@ class CommonTemplate:
        with self.assertRaisesRegex(RuntimeError, msg):
            fn(t1, t2)

+    @expectedFailureXPU
    def test_linear_mixed_dtype(self):
        class Net(nn.Module):
            def __init__(self):
@@ -4233,7 +4238,7 @@ class CommonTemplate:
            a[:, 20:40] = a[:, 20:40] + 1
            a[:, 2:900025] = a[:, 1:900024] + 2

-        a = torch.rand((1, 1000000), device="cuda")
+        a = torch.rand((1, 1000000), device=GPU_TYPE)
        self.common(f, (a,))

    def test_gather_scatter(self):
@@ -4672,6 +4677,13 @@ class CommonTemplate:
            c = torch.cat((x, x1), 1)
            return (c,)

+        if self.device == "xpu":
+            atol = 3e-4
+            rtol = 1e-4
+        else:
+            # use default
+            atol = None
+            rtol = None
        self.common(
            fn,
            (
@@ -4680,6 +4692,8 @@ class CommonTemplate:
                torch.randn(1024, 1600),
                torch.randn(100, 256),
            ),
+            atol=atol,
+            rtol=rtol,
            check_lowp=False,  # accuracy issues with relatively large matmuls
        )

@@ -5173,6 +5187,7 @@ class CommonTemplate:
        )

    @skipCUDAIf(not TEST_CUDNN, "CUDNN not available")
+    @skipIfXpu
    @skipIfRocm
    def test_cudnn_rnn(self):
        if self.device == "cpu":
@@ -6423,6 +6438,10 @@ class CommonTemplate:
        if self.device == "cuda":
            raise unittest.SkipTest("unstable on sm86")

+        check_lowp = True
+        if self.device == "xpu":
+            check_lowp = False
+
        def fn(a, dim, index, b):
            return aten.scatter.reduce(a, dim, index, b, reduce="add")
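The check_lowp toggle that recurs in the scatter tests below disables the low-precision (float16) re-run that self.common otherwise performs, but only on XPU. A condensed, hypothetical form of the pattern the tests spell out longhand (the helper name is illustrative):

def _check_lowp_for(device: str) -> bool:
    # Mirrors the repeated "check_lowp = True / if device == 'xpu': check_lowp = False"
    # blocks: keep the float16 re-run everywhere except on XPU.
    return device != "xpu"


assert _check_lowp_for("cuda") is True
assert _check_lowp_for("xpu") is False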
@@ -6434,12 +6453,17 @@ class CommonTemplate:
                torch.zeros((64, 512), dtype=torch.int64),
                torch.ones(64, 512),
            ],
+            check_lowp=check_lowp,
        )

    def test_scatter3(self):
        def fn(a, dim, index, b):
            return aten.scatter(a, dim, index, b, reduce="add")

+        check_lowp = True
+        if self.device == "xpu":
+            check_lowp = False
+
        self.common(
            fn,
            [
@@ -6453,12 +6477,17 @@ class CommonTemplate:
            # Greatest relative difference: 0.0022371364653243847 at index (0, 0, 3) (up to 0.001 allowed)
            atol=2e-4,
            rtol=1e-3,
+            check_lowp=check_lowp,
        )

    def test_scatter4(self):
        def fn(x, ind, src):
            return torch.scatter(x, 0, ind, src)

+        check_lowp = True
+        if self.device == "xpu":
+            check_lowp = False
+
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                self.common(
@@ -6468,6 +6497,7 @@ class CommonTemplate:
                        torch.randint(196, (1, 992)),
                        torch.randn(1, 992),
                    ],
+                    check_lowp=check_lowp,
                )

    def test_scatter5(self):
@@ -6478,6 +6508,10 @@ class CommonTemplate:
            a1.scatter_(dim, index, b, reduce=reduce)
            return (a, a1)

+        check_lowp = True
+        if self.device == "xpu":
+            check_lowp = False
+
        for reduce in ["add", "multiply"]:
            self.common(
                fn,
@@ -6488,12 +6522,17 @@ class CommonTemplate:
                    torch.randn(4, 5),
                    reduce,
                ],
+                check_lowp=check_lowp,
            )

    def test_scatter6(self):
        def fn(a, dim, index, b):
            return aten.scatter(a, dim, index, b)

+        check_lowp = True
+        if self.device == "xpu":
+            check_lowp = False
+
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                self.common(
@@ -6504,6 +6543,7 @@ class CommonTemplate:
                        torch.tensor([[[3, 5, 7, 9]]]),
                        0.8,  # src can be a scalar
                    ],
+                    check_lowp=check_lowp,
                )

    @unittest.skip("Flaky test, needs debugging")
@@ -6511,6 +6551,10 @@ class CommonTemplate:
        def fn(a, dim, index, b):
            return aten.scatter_add(a, dim, index, b)

+        check_lowp = True
+        if self.device == "xpu":
+            check_lowp = False
+
        self.common(
            fn,
            [
@@ -6519,12 +6563,17 @@ class CommonTemplate:
                torch.tensor([[0]]),
                torch.randn(2, 3),
            ],
+            check_lowp=check_lowp,
        )

    def test_scatter_add2(self):
        def fn(a, dim, index, b):
            return aten.scatter_add(a, dim, index, b)

+        check_lowp = True
+        if self.device == "xpu":
+            check_lowp = False
+
        self.common(
            fn,
            [
@@ -6533,12 +6582,17 @@ class CommonTemplate:
                torch.tensor([[0, 0, 0], [1, 1, 1]]),
                torch.randn(2, 3),
            ],
+            check_lowp=check_lowp,
        )

    def test_scatter_add3(self):
        def fn(a, dim, index, b):
            return aten.scatter_add(a, dim, index, b)

+        check_lowp = True
+        if self.device == "xpu":
+            check_lowp = False
+
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                self.common(
@@ -6549,12 +6603,17 @@ class CommonTemplate:
                        torch.tensor([[[3, 5, 7, 9]]]),
                        torch.randn(1, 1, 10),
                    ],
+                    check_lowp=check_lowp,
                )

    def test_scatter_reduce1(self):
        def fn(a, dim, index, b):
            return aten.scatter_reduce(a, dim, index, b, "sum")

+        check_lowp = True
+        if self.device == "xpu":
+            check_lowp = False
+
        self.common(
            fn,
            [
@@ -6563,12 +6622,17 @@ class CommonTemplate:
                torch.tensor([[[3, 5, 7, 9]]]),
                torch.randn(1, 1, 10),
            ],
+            check_lowp=check_lowp,
        )

    def test_scatter_reduce2(self):
        def fn(a, dim, index, b, reduce):
            return aten.scatter_reduce(a, dim, index, b, reduce, include_self=False)

+        check_lowp = True
+        if self.device == "xpu":
+            check_lowp = False
+
        for reduce in ["sum", "amax"]:
            self.common(
                fn,
@@ -6579,6 +6643,7 @@ class CommonTemplate:
                    torch.randn(2, 3),
                    reduce,
                ],
+                check_lowp=check_lowp,
            )

    def test_scatter_reduce3(self):
@@ -6589,6 +6654,10 @@ class CommonTemplate:
            a1.scatter_reduce_(dim, index, b, reduce=reduce)
            return (a, a1)

+        check_lowp = True
+        if self.device == "xpu":
+            check_lowp = False
+
        for reduce in ["sum", "prod"]:
            self.common(
                fn,
@@ -6599,6 +6668,7 @@ class CommonTemplate:
                    torch.randn(4, 5),
                    reduce,
                ],
+                check_lowp=check_lowp,
            )

    def test_dense_mask_index(self):
@@ -6890,6 +6960,7 @@ class CommonTemplate:
        ):
            compiled_f = compile_fx_inner(mod, cloned_args)

+    @expectedFailureXPU
    def test_functionalize_rng_wrappers(self):
        # Ideally, we would like to use torch.compile for these operators. But
        # currently the plan is to introduce these operators at the partitioner
@@ -6935,6 +7006,7 @@ class CommonTemplate:
        self.assertEqual(a2, b2)

    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
+    @expectedFailureXPU
    def test_philox_rand(self):
        if self.device == "cpu":
            raise unittest.SkipTest(
@@ -7185,6 +7257,7 @@ class CommonTemplate:
        )
        assertGeneratedKernelCountEqual(self, 1)

+    @expectedFailureXPU
    def test_max_pool2d_with_indices_backward5(self):
        # Window size is too big. Should fallback
        def fn(a, b, c):
@@ -7768,7 +7841,7 @@ class CommonTemplate:
                1,
                dtype=torch.int32,
                layout=torch.strided,
-                device=device(type="cuda", index=0),
+                device=device(type=GPU_TYPE, index=0),
                pin_memory=False,
            )
@@ -7777,7 +7850,7 @@ class CommonTemplate:
                start=0,
                step=1,
                dtype=torch.int32,
-                device=device(type="cuda"),
+                device=device(type=GPU_TYPE),
                requires_grad=False,
            )
@@ -7787,7 +7860,7 @@ class CommonTemplate:
                start=0,
                step=1001,
                dtype=torch.int32,
-                device=device(type="cuda", index=0),
+                device=device(type=GPU_TYPE, index=0),
                requires_grad=False,
            )
            view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1])
@@ -7834,7 +7907,7 @@ class CommonTemplate:
                permute_1,
            ]

-        kwargs = aot_graph_input_parser(forward, device="cuda")
+        kwargs = aot_graph_input_parser(forward, device=GPU_TYPE)
        self.common(forward, [], kwargs=kwargs)

    def test_misaligned_address_issue1(self):
@@ -8684,8 +8757,8 @@ class CommonTemplate:
        for dtype in [torch.int32, torch.int64]:
            self.common(fn, (torch.ones(1, 1, 13, dtype=dtype),))

-    @unittest.skipIf(not HAS_CPU or not RUN_CPU, "requires C++ compiler")
-    def test_data_type_propagation(self):
+    @unittest.skipIf(not HAS_CPU, "requires C++ compiler")
+    def test_data_type_propogation(self):
        from torch._dynamo.utils import detect_fake_mode
        from torch._inductor.codegen.common import boolean_ops
        from torch._inductor.compile_fx import _shape_env_from_inputs
@@ -9033,6 +9106,7 @@ class CommonTemplate:
        )

    @skipIfRocm
+    @expectedFailureXPU
    def test_scaled_dot_product_efficient_attention(self):
        if self.device == "cpu":
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
@@ -9245,10 +9319,13 @@ class CommonTemplate:
        def foo_cuda(x):
            return 3 * x

+        def foo_xpu(x):
+            return 3 * x
+
        def foo_meta(x):
            return torch.empty_like(x)

-        define_custom_op_for_test("foo", foo_cpu, foo_cuda, foo_meta)
+        define_custom_op_for_test("foo", foo_cpu, foo_cuda, foo_xpu, foo_meta)

        def fn(x):
            a = torch.nn.functional.relu(x)
@@ -9268,10 +9345,13 @@ class CommonTemplate:
        def foo_cuda(x, scale: float):
            return scale * x, torch.cos(x)

+        def foo_xpu(x, scale: float):
+            return scale * x, torch.cos(x)
+
        def foo_meta(x, scale: float):
            return torch.empty_like(x), torch.empty_like(x)

-        define_custom_op_2_for_test("foo2", foo_cpu, foo_cuda, foo_meta)
+        define_custom_op_2_for_test("foo2", foo_cpu, foo_cuda, foo_xpu, foo_meta)

        def fn(x, scale: float):
            a = torch.nn.functional.relu(x)
@@ -9295,10 +9375,16 @@ class CommonTemplate:
                result += t
            return result

+        def foo_xpu(x):
+            result = torch.zeros_like(x[0])
+            for t in x:
+                result += t
+            return result
+
        def foo_meta(x):
            return torch.empty_like(x[0])

-        define_custom_op_3_for_test("foo3", foo_cpu, foo_cuda, foo_meta)
+        define_custom_op_3_for_test("foo3", foo_cpu, foo_cuda, foo_xpu, foo_meta)

        def fn(x):
            return torch.ops.test.foo3(x)
@@ -9316,8 +9402,8 @@ class CommonTemplate:
    def test_custom_op_fixed_layout_sequential(self):
        import torch.library

-        mod = nn.Conv2d(3, 128, 1, stride=1, bias=False).cuda()
-        inp = torch.rand(2, 3, 128, 128, device="cuda")
+        mod = nn.Conv2d(3, 128, 1, stride=1, bias=False).to(device=GPU_TYPE)
+        inp = torch.rand(2, 3, 128, 128, device=GPU_TYPE)
        expected_stride = mod(inp).stride()

        def bar_cpu(x):
@@ -9328,6 +9414,10 @@ class CommonTemplate:
            self.assertEqual(x.stride(), expected_stride)
            return x.clone()

+        def bar_xpu(x):
+            self.assertEqual(x.stride(), expected_stride)
+            return x.clone()
+
        def bar_meta(x):
            return torch.empty_like(x)

@@ -9335,6 +9425,7 @@ class CommonTemplate:
            "bar",
            bar_cpu,
            bar_cuda,
+            bar_xpu,
            bar_meta,
            tags=[torch._C.Tag.needs_fixed_stride_order],
        )
@@ -9373,8 +9464,8 @@ class CommonTemplate:
            return out

        model = Block()
-        model = model.to("cuda").to(memory_format=torch.channels_last)
-        input_t = torch.randn([1, 320, 128, 128], dtype=torch.float32, device="cuda")
+        model = model.to(GPU_TYPE).to(memory_format=torch.channels_last)
+        input_t = torch.randn([1, 320, 128, 128], dtype=torch.float32, device=GPU_TYPE)
        input_t = input_t.to(memory_format=torch.channels_last)
        expected_strides = model.helper(input_t).stride()

@@ -9386,6 +9477,10 @@ class CommonTemplate:
            self.assertEqual(expected_strides, x.stride())
            return x.clone()

+        def baz_xpu(x):
+            self.assertEqual(expected_strides, x.stride())
+            return x.clone()
+
        def baz_meta(x):
            return torch.empty_like(x)

@@ -9393,6 +9488,7 @@ class CommonTemplate:
            "baz",
            baz_cpu,
            baz_cuda,
+            baz_xpu,
            baz_meta,
            tags=[torch._C.Tag.needs_fixed_stride_order],
        )
@@ -9664,7 +9760,7 @@ class CommonTemplate:
    def test_pointwise(self, name, op):
        dtype = torch.float32
        check_lowp = True
-        if self.device == "cuda" and name in {
+        if self.device == GPU_TYPE and name in {
            "airy_ai",
            "bessel_i0",
            "bessel_i1",
@@ -9863,7 +9959,7 @@ def copy_tests(
        setattr(other_cls, f"{name}_{suffix}", new_test)


-if HAS_CPU and RUN_CPU:
+if HAS_CPU:

    class SweepInputsCpuTest(SweepInputs2, TestCase):
        gen = InputGen(10, "cpu")
@@ -9876,7 +9972,7 @@ if HAS_CPU and RUN_CPU:

    copy_tests(CommonTemplate, CpuTests, "cpu")

-if HAS_GPU and RUN_GPU and not TEST_WITH_ASAN:
+if HAS_GPU and not TEST_WITH_ASAN:

    class SweepInputsGPUTest(SweepInputs2, TestCase):
        gen = InputGen(10, GPU_TYPE)
@@ -9892,6 +9988,8 @@ if HAS_GPU and RUN_GPU and not TEST_WITH_ASAN:
    class TritonCodeGenTests(TestCase):
        from torch._inductor.runtime.triton_heuristics import CachingAutotuner

+        device_type = GPU_TYPE
+
        class NoOpCompilerBackend:
            def __init__(self):
                self.example_args = None
@@ -9948,6 +10046,7 @@ if HAS_GPU and RUN_GPU and not TEST_WITH_ASAN:

            return kernels

+        @expectedFailureXPU
        def test_divisible_by_16_covers_numel_args(self):
            torch._dynamo.reset()

@@ -9981,6 +10080,7 @@ if HAS_GPU and RUN_GPU and not TEST_WITH_ASAN:
            self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1))
            torch._dynamo.reset()

+        @expectedFailureXPU
        @config.patch(assume_aligned_inputs=False)
        def test_codegen_config_option_dont_assume_alignment(self):
            def fn(x: torch.Tensor) -> torch.Tensor:
@@ -10605,8 +10705,9 @@ if HAS_GPU and RUN_GPU and not TEST_WITH_ASAN:
            }
        )
        @skipIfRocm
+        @expectedFailureXPU
        @unittest.skipIf(
-            torch.cuda.get_device_capability() < (9, 0),
+            torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0),
            "Triton does not support fp8 on A100",
        )
        def test_red_followed_by_transposed_pointwise(self):
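Guarding the capability check behind torch.cuda.is_available() matters because the skipIf condition is evaluated when the decorator line runs (at class-body/import time), even on machines with no CUDA device such as an XPU-only runner, where get_device_capability() would raise. A minimal illustration of the guarded condition (the test class name is illustrative):

import unittest

import torch

# Evaluated at decoration time, so it must not raise on CUDA-less machines;
# the short-circuit `and` keeps get_device_capability() from being called.
_needs_skip = torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0)


class ToyFP8Test(unittest.TestCase):
    @unittest.skipIf(_needs_skip, "Triton does not support fp8 on A100")
    def test_placeholder(self):
        pass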
@@ -10661,6 +10762,8 @@ if HAS_GPU and RUN_GPU and not TEST_WITH_ASAN:
            print(p.key_averages().table(max_name_column_width=200))

    class RNNTest(TestCase):
+        device_type = GPU_TYPE
+
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
@@ -10669,6 +10772,7 @@ if HAS_GPU and RUN_GPU and not TEST_WITH_ASAN:
            def forward(self, x):
                return self.gru(x)

+        @expectedFailureXPU
        def test_rnn_compile_safe(self):
            device = torch.device(GPU_TYPE)
            model = RNNTest.Model().to(device)
@@ -10707,7 +10811,7 @@ if HAS_GPU and RUN_GPU and not TEST_WITH_ASAN:
            torch.compile(f)(x)


-if HAS_CPU and RUN_CPU:
+if HAS_CPU:

    class TestFull(TestCase):
        def test_full_dtype(self):
@@ -17,7 +17,7 @@ from torch.testing._internal.inductor_utils import (
    _check_has_dynamic_shape,
    GPU_TYPE,
    HAS_CPU,
-    HAS_GPU,
+    HAS_CUDA,
)

if IS_WINDOWS and IS_CI:
@@ -372,7 +372,7 @@ if HAS_CPU:
    )


-if HAS_GPU and not TEST_WITH_ASAN:
+if HAS_CUDA and not TEST_WITH_ASAN:

    class DynamicShapesCodegenGPUTests(TestCase):
        maxDiff = None
@@ -398,5 +398,5 @@ if HAS_GPU and not TEST_WITH_ASAN:
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

-    if HAS_CPU or HAS_GPU:
+    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")
@@ -30,7 +30,7 @@ from torch.testing._internal.common_utils import (
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
)
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU

if IS_WINDOWS and IS_CI:
    sys.stderr.write(
@@ -96,7 +96,7 @@ if HAS_CPU:
    copy_tests(DynamicShapesCommonTemplate, DynamicShapesCpuTests, "cpu", test_failures)


-if HAS_GPU and not TEST_WITH_ASAN:
+if HAS_CUDA and not TEST_WITH_ASAN:

    class DynamicShapesGPUTests(TestCase):
        common = check_model_gpu
@@ -794,5 +794,5 @@ if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068
-    if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN:
+    if (HAS_CPU or HAS_CUDA) and not TEST_WITH_ASAN:
        run_tests(needs="filelock")
@@ -4,9 +4,10 @@ import sys
import unittest

import torch
+from torch.testing._internal.common_device_type import expectedFailureXPU

from torch.testing._internal.common_utils import IS_LINUX
-from torch.testing._internal.inductor_utils import HAS_GPU
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

try:
    import triton  # noqa: F401
@@ -22,6 +23,8 @@ from torch._inductor.test_case import run_tests, TestCase


class TestTritonHeuristics(TestCase):
+    device_type = GPU_TYPE
+
    def test_triton_config(self):
        """
        Make sure block size does not exceed the maximum defined in inductor config.
@@ -54,9 +57,9 @@ class TestTritonHeuristics(TestCase):
        s1 = 512

        args = [
-            torch.rand([2, 4], device="cuda"),
-            torch.rand([2], device="cuda"),
-            torch.rand([s0, s1], device="cuda"),
+            torch.rand([2, 4], device=GPU_TYPE),
+            torch.rand([2], device=GPU_TYPE),
+            torch.rand([s0, s1], device=GPU_TYPE),
        ]
        torch._dynamo.mark_dynamic(args[-1], 0)
        foo_c = torch.compile(forward)
@@ -64,17 +67,18 @@ class TestTritonHeuristics(TestCase):
        self.assertEqual(forward(*args), foo_c(*args))

        args = [
-            torch.rand([2, 4], device="cuda"),
-            torch.rand([2], device="cuda"),
-            torch.rand([s0, s1], device="cuda"),
+            torch.rand([2, 4], device=GPU_TYPE),
+            torch.rand([2], device=GPU_TYPE),
+            torch.rand([s0, s1], device=GPU_TYPE),
        ]
        self.assertEqual(forward(*args), foo_c(*args))

    @unittest.skip("https://github.com/pytorch/pytorch/issues/123210")
+    @expectedFailureXPU
    def test_artificial_zgrid(self):
        self._test_artificial_zgrid()

    @unittest.skip("https://github.com/pytorch/pytorch/issues/123210")
+    @expectedFailureXPU
    @config.patch("cpp_wrapper", True)
    def test_artificial_grid_cpp_wrapper(self):
        self._test_artificial_zgrid()
@@ -15,6 +15,7 @@ from torch.testing._internal.common_utils import (
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_PRIVATEUSE1,
+    TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

@@ -34,7 +35,7 @@ def remove_build_path():


@unittest.skipIf(
-    IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM,
+    IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM or TEST_XPU,
    "Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
@@ -633,6 +633,7 @@ def get_device_type_test_bases():
        test_bases.append(CPUTestBase)
        if torch.cuda.is_available():
            test_bases.append(CUDATestBase)

        device_type = torch._C._get_privateuse1_backend_name()
        device_mod = getattr(torch, device_type, None)
        if hasattr(device_mod, "is_available") and device_mod.is_available():
@@ -1138,7 +1139,12 @@ class expectedFailure:

        @wraps(fn)
        def efail_fn(slf, *args, **kwargs):
-            if self.device_type is None or self.device_type == slf.device_type:
+            if not hasattr(slf, "device_type") and hasattr(slf, "device") and isinstance(slf.device, str):
+                target_device_type = slf.device
+            else:
+                target_device_type = slf.device_type
+
+            if self.device_type is None or self.device_type == target_device_type:
                try:
                    fn(slf, *args, **kwargs)
                except Exception:
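The rewritten efail_fn needs this fallback because the Inductor test classes in this PR carry a plain string device attribute (and sometimes a class-level device_type = GPU_TYPE) rather than the device_type attribute the instantiated device-type test bases provide; the expected-failure check now compares against whichever is present. A standalone sketch of that resolution logic (the helper and class names are illustrative):

def _resolve_device_type(test_instance) -> str:
    # Prefer an explicit device_type attribute; otherwise fall back to a plain
    # string `device` attribute, as the Inductor test classes use.
    if not hasattr(test_instance, "device_type") and isinstance(
        getattr(test_instance, "device", None), str
    ):
        return test_instance.device
    return test_instance.device_type


class _DeviceTypeStyle:
    device_type = "cuda"


class _InductorStyle:
    device = "xpu"


assert _resolve_device_type(_DeviceTypeStyle()) == "cuda"
assert _resolve_device_type(_InductorStyle()) == "xpu"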
@@ -1386,6 +1392,9 @@ def expectedFailureCPU(fn):
def expectedFailureCUDA(fn):
    return expectedFailure('cuda')(fn)

+def expectedFailureXPU(fn):
+    return expectedFailure('xpu')(fn)
+
def expectedFailureMeta(fn):
    return skipIfTorchDynamo()(expectedFailure('meta')(fn))
@@ -34,16 +34,16 @@ HAS_CUDA = torch.cuda.is_available() and has_triton()

HAS_XPU = torch.xpu.is_available() and has_triton()

-HAS_GPU = HAS_CUDA
+HAS_GPU = HAS_CUDA or HAS_XPU

-GPUS = ["cuda"]
+GPUS = ["cuda", "xpu"]

HAS_MULTIGPU = any(
    getattr(torch, gpu).is_available() and getattr(torch, gpu).device_count() >= 2
    for gpu in GPUS
)

-tmp_gpus = [x for x in ["cuda", "xpu"] if getattr(torch, x).is_available()]
+tmp_gpus = [x for x in GPUS if getattr(torch, x).is_available()]
assert len(tmp_gpus) <= 1
GPU_TYPE = "cuda" if len(tmp_gpus) == 0 else tmp_gpus.pop()
del tmp_gpus
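After this change HAS_GPU and GPU_TYPE cover either backend: the assert documents that a test runner is expected to expose at most one of CUDA or XPU, and GPU_TYPE falls back to "cuda" when neither is available. A condensed, self-contained sketch of that resolution, for reference:

# Mirrors the GPU_TYPE logic in torch/testing/_internal/inductor_utils.py after the PR.
def resolve_gpu_type(cuda_available: bool, xpu_available: bool) -> str:
    gpus = [
        name
        for name, available in (("cuda", cuda_available), ("xpu", xpu_available))
        if available
    ]
    assert len(gpus) <= 1, "runners are expected to expose a single GPU backend"
    return gpus[0] if gpus else "cuda"  # default matches the upstream fallback


assert resolve_gpu_type(True, False) == "cuda"
assert resolve_gpu_type(False, True) == "xpu"
assert resolve_gpu_type(False, False) == "cuda"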