pytorch/test/inductor/test_gpu_cpp_wrapper.py
xinan.lin b75bb64eb4 [AOTI XPU] Rename test_cuda_cpp_wrapper.py to test_gpu_cpp_wrapper.py, (#135320)
[Inductor] Rename test_cuda_cpp_wrapper.py to test_gpu_cpp_wrapper.py, since the test suite is shared by CUDA and XPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135320
Approved by: https://github.com/jansel, https://github.com/EikanWang, https://github.com/desertfire
ghstack dependencies: #135318
2024-11-27 14:08:06 +00:00


# Owner(s): ["module: inductor"]
import itertools
import sys
import unittest
from typing import NamedTuple

import torch
from torch._inductor import config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import is_gpu
from torch.testing._internal.common_device_type import (
    get_desired_device_type_test_bases,
)
from torch.testing._internal.common_utils import slowTest, TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


try:
    try:
        from . import (
            test_combo_kernels,
            test_foreach,
            test_pattern_matcher,
            test_select_algorithm,
            test_torchinductor,
            test_torchinductor_dynamic_shapes,
        )
    except ImportError:
        import test_combo_kernels  # @manual=fbcode//caffe2/test/inductor:combo_kernels-library
        import test_foreach  # @manual=fbcode//caffe2/test/inductor:foreach-library
        import test_pattern_matcher  # @manual=fbcode//caffe2/test/inductor:pattern_matcher-library
        import test_select_algorithm  # @manual=fbcode//caffe2/test/inductor:select_algorithm-library
        import test_torchinductor  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
        import test_torchinductor_dynamic_shapes  # @manual=fbcode//caffe2/test/inductor:test_inductor-library_dynamic_shapes
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise


_desired_test_bases = get_desired_device_type_test_bases(allow_xpu=True)
RUN_GPU = (
    HAS_GPU
    and any(is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases)
    and not TEST_WITH_ASAN
)
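
# RUN_GPU (above) gates every test definition below: a usable CUDA or XPU
# device must be present, at least one of the requested device-type test bases
# must be a GPU, and ASAN builds are excluded.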


class GpuWrapperTemplate:
    pass


class TestGpuWrapper(InductorTestCase):
    device = GPU_TYPE


class DynamicShapesGpuWrapperGpuTests(InductorTestCase):
    device = GPU_TYPE
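
# GpuWrapperTemplate is only a container for the generated test functions;
# copy_tests() at the bottom of this file clones them onto the two concrete
# TestCase classes above (static shapes and dynamic shapes).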


test_failures_gpu_wrapper = {
    "test_mm_plus_mm2_cuda_dynamic_shapes": test_torchinductor.TestFailure(
        ("gpu_wrapper",), is_skip=True
    ),
    "test_randint_xpu": test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=False),
    "test_randint_xpu_dynamic_shapes": test_torchinductor.TestFailure(
        ("gpu_wrapper",), is_skip=False
    ),
    # ATen op scaled_dot_product_efficient_attention is not implemented on XPU.
    "test_scaled_dot_product_efficient_attention_xpu": test_torchinductor.TestFailure(
        ("gpu_wrapper",), is_skip=False
    ),
    "test_scaled_dot_product_efficient_attention_xpu_dynamic_shapes": test_torchinductor.TestFailure(
        ("gpu_wrapper",), is_skip=False
    ),
}
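
# Note: TestFailure entries with is_skip=False are treated as expected
# failures (xfail), while is_skip=True skips the test outright.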


def make_test_case(
    name,
    device,
    tests,
    condition=True,
    slow=False,
    func_inputs=None,
    code_string_count=None,
    check_code=True,
):
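    """Wrap an existing Inductor test so it runs with cpp_wrapper=True.

    The generated test re-runs ``tests.<name>_<device>`` under the C++ wrapper
    config, captures the generated wrapper code, and (when ``check_code`` is
    True) asserts that it went through CppWrapperCodeCache and contains each
    string in ``code_string_count`` the expected number of times.
    """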
    test_name = f"{name}_{device}" if device else name
    if code_string_count is None:
        code_string_count = {}

    func = getattr(tests, test_name)
    assert callable(func), "not a callable"
    func = slowTest(func) if slow else func

    @config.patch(cpp_wrapper=True, search_autotune_cache=False)
    def fn(self):
        tests.setUpClass()
        tests.setUp()
        try:
            with torch._C._PreserveDispatchKeyGuard():
                torch._C._dispatch_tls_set_dispatch_key_included(
                    torch._C.DispatchKey.Dense, True
                )

                _, code = test_torchinductor.run_and_get_cpp_code(
                    func, *func_inputs if func_inputs else []
                )
                if check_code:
                    self.assertEqual("CppWrapperCodeCache" in code, True)
                    self.assertTrue(
                        all(
                            code.count(string) == code_string_count[string]
                            for string in code_string_count
                        )
                    )
        finally:
            tests.tearDown()
            tests.tearDownClass()

    fn.__name__ = test_name
    import copy

    fn.__dict__ = copy.deepcopy(func.__dict__)
    if condition:
        setattr(
            GpuWrapperTemplate,
            test_name,
            fn,
        )
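
# Illustration (hypothetical call, mirroring the loop below):
# make_test_case("test_relu", "cuda", test_torchinductor.GPUTests()) attaches a
# "test_relu_cuda" method to GpuWrapperTemplate that re-runs
# GPUTests.test_relu_cuda with cpp_wrapper enabled and checks the generated
# C++ wrapper code.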


if RUN_GPU:

    class BaseTest(NamedTuple):
        name: str
        device: str = GPU_TYPE
        tests: InductorTestCase = test_torchinductor.GPUTests()
        check_code: bool = True

    # Not yet implemented on XPU.
    XPU_BASE_TEST_SKIP = [
        "test_foreach_cpp_wrapper",
        "test_enable_dynamic_shapes_cpp_wrapper",
        "test_dynamic_shapes_persistent_reduction_mixed_x_dim",
        "test_cat_slice_cat",
        "test_mm_plus_mm2",
        "test_mm_plus_mm3",
        "test_addmm",
        "test_linear_relu",
        "test_fft_real_input",
        "test_fft_real_input_real_output",
    ]

    # Maintain two separate test lists for cuda and cpp for now
    for item in [
        BaseTest("test_add_complex"),
        BaseTest("test_add_complex4"),
        BaseTest("test_as_strided"),  # buffer reuse
        BaseTest("test_batch_norm_2d_2"),
        BaseTest("test_bernoulli1"),
        BaseTest("test_bitwise"),  # int32
        BaseTest("test_bmm1"),
        BaseTest("test_bmm2"),
        BaseTest("test_buffer_use_after_remove"),
        BaseTest("test_cat"),  # alias
        BaseTest("test_convolution1"),
        BaseTest("test_conv_backward"),
        BaseTest("test_custom_op_1"),
        BaseTest("test_custom_op_2"),
        BaseTest("test_custom_op_3"),
        BaseTest("test_embedding_bag"),  # test default FallbackKernel
        BaseTest("test_index_put_deterministic_fallback"),
        BaseTest("test_adding_tensor_offsets"),
        BaseTest("test_index_tensor"),
        BaseTest("test_inductor_layout_optimization_input_mutations"),
        BaseTest("test_insignificant_strides"),
        BaseTest("test_layer_norm"),
        BaseTest("test_linear1"),
        BaseTest("test_linear2"),
        BaseTest("test_mm_views"),
        BaseTest("test_multi_device"),
        BaseTest("test_multi_threading"),
        BaseTest("test_pow3"),
        BaseTest("test_profiler_mark_wrapper_call"),
        BaseTest("test_randint"),
        BaseTest("test_reduction1"),  # Reduction
        BaseTest("test_relu"),  # multiple inputs
        BaseTest("test_repeat_interleave_2"),
        BaseTest("test_roi_align"),
        BaseTest("test_scalar_input"),
        BaseTest("test_scaled_dot_product_attention"),
        BaseTest("test_scaled_dot_product_efficient_attention"),
        BaseTest("test_sort"),
        BaseTest("test_silu"),  # single input, single output
        BaseTest("test_sum_dtype"),  # float64
        BaseTest("test_sum_int"),  # bool, int64, int8, uint8
        BaseTest("test_transpose"),  # multiple outputs, buffer clear
        *[
            BaseTest(f"test_unspec_inputs_{str(dtype)[6:]}")
            for dtype in test_torchinductor.test_dtypes
        ],
        BaseTest("test_consecutive_split_cumprod"),
        BaseTest("test_pointwise_hermite_polynomial_he"),
        BaseTest("test_pointwise_hermite_polynomial_h"),
        BaseTest(
            "test_foreach_cpp_wrapper",
            tests=test_foreach.ForeachTests(),
        ),  # test foreach
        BaseTest(
            "test_enable_dynamic_shapes_cpp_wrapper",
            tests=test_foreach.ForeachTests(),
        ),
        BaseTest(
            "test_dynamic_shapes_persistent_reduction_mixed_x_dim",
            tests=test_combo_kernels.ComboKernelDynamicShapesTests(),
        ),
        BaseTest(
            "test_cat_slice_cat",
            tests=test_pattern_matcher.TestPatternMatcher(),
        ),
        # TODO: Re-enable this test after fixing the cuda wrapper for conv Triton
        # templates with dynamic shapes. The test is unstable: it succeeds when an
        # ATen kernel is used and fails when a Triton kernel is used. Currently it
        # passes on CI (an ATen kernel is chosen) and fails locally (a Triton
        # kernel is chosen). Ideally, it should succeed regardless of which kernel
        # is chosen.
        # BaseTest(
        #     "test_convolution1",
        #     device=None,
        #     tests=test_select_algorithm.TestSelectAlgorithm(),
        # ),
        BaseTest(
            "test_mm_plus_mm2",
            tests=test_select_algorithm.TestSelectAlgorithm(),
        ),
        BaseTest(
            "test_mm_plus_mm3",
            tests=test_select_algorithm.TestSelectAlgorithm(),
        ),
        BaseTest("test_fft_real_input"),
        BaseTest("test_fft_real_input_real_output"),
        *[
            # Some dtypes may raise an exception and be skipped in test_dtypeview,
            # so disable code checking here.
            BaseTest(
                f"test_dtypeview_{str(dtype_x)[6:]}_{str(dtype_y)[6:]}",
                check_code=False,
            )
            for dtype_x, dtype_y in itertools.product(
                test_torchinductor.test_dtypes, test_torchinductor.test_dtypes
            )
        ],
        BaseTest("test_dtypeview_fusion"),
        # Skipped below if the GPU does not have enough SMs.
        BaseTest(
            "test_addmm",
            tests=test_select_algorithm.TestSelectAlgorithm(),
        ),
        # Skipped below if the GPU does not have enough SMs.
        BaseTest(
            "test_linear_relu",
            tests=test_select_algorithm.TestSelectAlgorithm(),
        ),
    ]:
        if item.device == "xpu" and item.name in XPU_BASE_TEST_SKIP:
            continue
        make_test_case(item.name, item.device, item.tests, check_code=item.check_code)

    from torch._inductor.utils import is_big_gpu

    if GPU_TYPE == "cuda" and not is_big_gpu(0):
        skip_list = ["test_addmm", "test_linear_relu"]
        # Need to skip instead of omit, otherwise fbcode CI can be flaky.
        for test_name in skip_list:
            test_failures_gpu_wrapper[
                f"{test_name}_cuda"
            ] = test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True)
            test_failures_gpu_wrapper[
                f"{test_name}_gpu_dynamic_shapes"
            ] = test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True)

    test_torchinductor.copy_tests(
        GpuWrapperTemplate, TestGpuWrapper, "gpu_wrapper", test_failures_gpu_wrapper
    )

    DynamicShapesGpuWrapperTemplate = (
        test_torchinductor_dynamic_shapes.make_dynamic_cls(GpuWrapperTemplate)
    )

    test_torchinductor.copy_tests(
        DynamicShapesGpuWrapperTemplate,
        DynamicShapesGpuWrapperGpuTests,
        "gpu_wrapper",
        test_failures_gpu_wrapper,
        xfail_prop="_expected_failure_dynamic_wrapper",
    )


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if RUN_GPU:
        run_tests(needs="filelock")
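
# Usage sketch (assumed invocation; requires a CUDA or XPU build of PyTorch):
#   python test/inductor/test_gpu_cpp_wrapper.py -k test_relu
# runs the generated cpp_wrapper variants of test_relu for the detected GPU type.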