[Inductor UT] Generalize inductor UT for intel GPU (Part 2) (#134556)

[Inductor UT] Reuse Inductor test cases for Intel GPU:
Reuse `test/inductor/test_torchinductor_opinfo.py`
Reuse `test/inductor/test_minifier_isolate.py`
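For context, the generalization follows the same pattern as Part 1: hard-coded `"cuda"` strings and `requires_cuda` gates are replaced with the backend-agnostic `GPU_TYPE` and `requires_gpu` helpers, so one test body covers CUDA and XPU alike. A minimal sketch of the pattern, using a hypothetical `_check_runtime_error` helper in place of the real test logic:

```python
import unittest

from torch.testing._internal.inductor_utils import GPU_TYPE  # e.g. "cuda" or "xpu"
from torch.testing._internal.triton_utils import requires_gpu


class MinifierSketch(unittest.TestCase):
    # Hypothetical stand-in for the real _test_after_aot_runtime_error helper.
    def _check_runtime_error(self, device: str, expected_error: str) -> None:
        self.assertIn(device, ("cuda", "xpu"))

    # Previously gated on @requires_cuda with a hard-coded "cuda" device;
    # now it runs on whichever GPU backend is available.
    @requires_gpu
    def test_after_aot_gpu_runtime_error(self):
        self._check_runtime_error(GPU_TYPE, "device-side assert")


if __name__ == "__main__":
    unittest.main()
```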

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134556
Approved by: https://github.com/etaf, https://github.com/eellison
Author: xingyuan li
Date: 2024-09-13 05:16:26 +00:00
Committed by: PyTorch MergeBot
Parent: e54b559e88
Commit: b38be727eb
2 changed files with 571 additions and 174 deletions

test/inductor/test_minifier_isolate.py

@@ -8,12 +8,11 @@ from torch.testing._internal.common_utils import (
IS_MACOS,
skipIfRocm,
skipIfWindows,
skipIfXpu,
TEST_WITH_ASAN,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.triton_utils import requires_gpu
# These minifier tests are slow, because they must be run in separate
@@ -41,10 +40,11 @@ inner(torch.randn(2, 2).to("{device}"))
self._test_after_aot_runtime_error("cpu", "")
@skipIfRocm
@requires_cuda
@skipIfXpu
@requires_gpu
@inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "runtime_error")
def test_after_aot_cuda_runtime_error(self):
self._test_after_aot_runtime_error("cuda", "device-side assert")
def test_after_aot_gpu_runtime_error(self):
self._test_after_aot_runtime_error(GPU_TYPE, "device-side assert")
if __name__ == "__main__":

test/inductor/test_torchinductor_opinfo.py

@@ -26,6 +26,7 @@ from torch.testing._internal.common_device_type import (
ops,
skipCPUIf,
skipCUDAIf,
skipXPUIf,
)
from torch.testing._internal.common_methods_invocations import op_db, skipOps
from torch.testing._internal.common_utils import (
@@ -40,7 +41,7 @@ from torch.testing._internal.common_utils import (
TEST_WITH_ASAN,
TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_XPU
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
@@ -207,6 +208,8 @@ if TEST_WITH_ROCM:
inductor_skips["cuda"]["logcumsumexp"] = {f32}
inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}
inductor_skips["xpu"] = {}
inductor_expected_failures_single_sample = defaultdict(dict)
inductor_expected_failures_single_sample["cpu"] = {
@@ -256,6 +259,109 @@ inductor_expected_failures_single_sample["cuda"] = {
}, # NYI: could not find kernel for aten.view.default at dispatch key DispatchKey.SparseCUDA
}
inductor_expected_failures_single_sample["xpu"] = {
"_upsample_bilinear2d_aa": {f16, f32, f64},
"cholesky": {f32, f64},
"multinomial": {f16, f32, f64},
("normal", "in_place"): {f16, f32, f64},
("normal", "number_mean"): {f16, f32, f64},
"normal": {f16, f32, f64},
"sparse.sampled_addmm": {f32, f64},
"tan": {f16},
"torch.ops.aten._flash_attention_forward": {f16},
"torch.ops.aten._efficient_attention_forward": {f16, f32},
"to_sparse": {f16, f32, f64, b8, i32, i64},
"linalg.eig": {f32, f64},
"linalg.eigvals": {f32, f64},
# Double and complex datatype matmul is not supported in oneDNN
"__rmatmul__": {f64},
("addmm", "decomposed"): {f64},
"addr": {f64},
"baddbmm": {f64},
"bmm": {f64},
"byte": {f16, f32},
"cdist": {f64},
"corrcoef": {f64},
"cov": {f64},
"einsum": {f64},
"inner": {f64},
"linalg.cholesky_ex": {f64},
"linalg.cholesky": {f64},
("linalg.det", "singular"): {f64},
"linalg.ldl_factor_ex": {f64},
"linalg.ldl_factor": {f64},
"linalg.ldl_solve": {f64},
"linalg.matrix_power": {f64},
"linalg.multi_dot": {f64},
"matmul": {f64},
"mm": {f64},
"mv": {f64},
"nn.functional.bilinear": {f64},
"nn.functional.linear": {f64},
"pca_lowrank": {f64},
"svd_lowrank": {f64},
"tensordot": {f64},
"triangular_solve": {f64},
"svd": {f64},
"qr": {f64},
"pinverse": {f64},
"ormqr": {f64},
("norm", "nuc"): {f64},
"lu": {f64},
"lu_solve": {f64},
"logdet": {f64},
"linalg.tensorsolve": {f64},
"linalg.tensorinv": {f64},
"linalg.svdvals": {f64},
"linalg.svd": {f64},
"linalg.solve": {f64},
"linalg.solve_triangular": {f64},
"linalg.solve_ex": {f64},
"linalg.slogdet": {f64},
"linalg.qr": {f64},
"linalg.pinv": {f64},
("linalg.pinv", "hermitian"): {f64},
"linalg.norm": {f64},
("linalg.norm", "subgradients_at_zero"): {f64},
"linalg.matrix_rank": {f64},
("linalg.matrix_rank", "hermitian"): {f64},
"linalg.matrix_norm": {f64},
"linalg.lu": {f64},
"linalg.lu_solve": {f64},
"linalg.lu_factor": {f64},
"linalg.lu_factor_ex": {f64},
"linalg.lstsq": {f64},
("linalg.lstsq", "grad_oriented"): {f64},
"linalg.inv": {f64},
"linalg.inv_ex": {f64},
"linalg.householder_product": {f64},
"linalg.eigvalsh": {f64},
"linalg.eigh": {f64},
"linalg.det": {f64},
"linalg.cond": {f64},
"geqrf": {f64},
"cholesky_solve": {f64},
"cholesky_inverse": {f64},
# could not create a primitive
"addbmm": {f16, f32, f64},
"addmm": {f16, f32, f64},
"addmv": {f32, f64},
# could not create a primitive descriptor for
# a deconvolution forward propagation primitive
"nn.functional.conv_transpose2d": {f32, f64},
"nn.functional.conv_transpose3d": {f32, f64},
# rrelu not supported on XPU now
"nn.functional.rrelu": {f16, f32, f64},
"histc": {i32, i64},
# not implemented for 'Half'
"nn.functional.multilabel_margin_loss": {f16},
"nn.functional.multi_margin_loss": {f16},
"nn.functional.avg_pool3d": {f16},
"nn.functional.adaptive_max_pool3d": {f16},
# not implemented for 'Bool'
"nn.functional.unfold": {b8},
}
# intentionally not handled
intentionally_not_handled = {
@@ -278,11 +384,13 @@ if not functorch_config.view_replay_for_aliased_outputs:
}
inductor_expected_failures_single_sample["cuda"].update(intentionally_not_handled)
inductor_expected_failures_single_sample["xpu"].update(intentionally_not_handled)
inductor_gradient_expected_failures_single_sample = defaultdict(dict)
inductor_gradient_expected_failures_single_sample["cuda"] = {}
inductor_gradient_expected_failures_single_sample["xpu"] = {}
if not TEST_MKL:
inductor_expected_failures_single_sample["cpu"].update({})
@@ -290,6 +398,7 @@ if not TEST_MKL:
inductor_should_fail_with_exception = defaultdict(dict)
inductor_should_fail_with_exception["cpu"] = {}
inductor_should_fail_with_exception["cuda"] = {}
inductor_should_fail_with_exception["xpu"] = {}
def get_skips_and_xfails(from_dict, xfails=True):
@@ -324,9 +433,10 @@ torch.testing._internal.common_methods_invocations.wrapper_set_seed = (
wrapper_noop_set_seed
)
# key can be either op_name, or (op_name, dtype)
inductor_override_kwargs = defaultdict(dict)
# key can be either op_name, or (op_name, device_type), or (op_name, device_type, dtype)
inductor_override_kwargs = {
inductor_override_kwargs["cpu"] = {
# the return value of empty is undefined
"empty": {"assert_equal": False},
"empty_permuted": {"assert_equal": False},
@@ -335,92 +445,222 @@ inductor_override_kwargs = {
"empty_strided": {"assert_equal": False},
"new_empty_strided": {"assert_equal": False},
"randn": {"assert_equal": False},
("cross", "cuda", f16): {"reference_in_float": True},
("linalg.cross", "cuda", f16): {"reference_in_float": True},
("addr", "cuda", f16): {"reference_in_float": True},
("baddbmm", "cuda", f16): {"atol": 2e-3, "rtol": 0.002}, # decomp affects accuracy
("angle", "cuda", f64): {"reference_in_float": True},
("asin", "cuda", f16): {"reference_in_float": True},
("atanh", "cuda", f16): {"reference_in_float": True},
("cauchy", "cuda"): {"reference_in_float": True},
("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002},
("cumsum", "cuda", f16): {"reference_in_float": True},
("cumprod", "cuda"): {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002},
("logcumsumexp", "cuda"): {"grad_atol": 8e-4, "grad_rtol": 0.001},
("exponential", "cuda"): {"reference_in_float": True},
("geometric", "cuda"): {"reference_in_float": True},
("kron", "cuda", f16): {"reference_in_float": True},
("log_normal", "cuda"): {"reference_in_float": True},
("masked.softmin", "cuda", f16): {"atol": 1e-4, "rtol": 0.01},
("nn.functional.batch_norm", "cuda", f16): {"reference_in_float": True},
("nn.functional.batch_norm.without_cudnn", "cuda", f16): {
"reference_in_float": True
},
("nn.functional.cosine_similarity", "cuda", f16): {"reference_in_float": True},
("nn.functional.instance_norm", "cuda", f16): {"reference_in_float": True},
("nn.functional.local_response_norm", "cuda", f16): {"reference_in_float": True},
("nn.functional.normalize", "cuda", f16): {"atol": 1e-3, "rtol": 0.05},
("nn.functional.rms_norm", "cuda", f16): {"reference_in_float": True},
("nn.functional.soft_margin_loss", "cuda", f16): {"reference_in_float": True},
("nn.functional.softmin", "cuda", f16): {"atol": 1e-4, "rtol": 0.01},
("nn.functional.softsign", "cuda", f16): {"reference_in_float": True},
("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001},
("nn.functional.multilabel_soft_margin_loss", "cpu", f16): {
("nn.functional.multilabel_soft_margin_loss", f16): {
"atol": 3e-4,
"rtol": 0.002,
},
("outer", "cuda", f16): {"reference_in_float": True},
("round.decimals_3", "cuda", f16): {"reference_in_float": True},
("nn.functional.triplet_margin_loss", "cuda", f16): {"atol": 1e-4, "rtol": 0.02},
("nn.functional.triplet_margin_with_distance_loss", "cuda", f16): {
"atol": 1e-4,
"rtol": 0.02,
},
("sinc", "cuda", f16): {"atol": 0.008, "rtol": 0.002},
("torch.ops.aten._safe_softmax.default", "cuda", f16): {"atol": 5e-4, "rtol": 0.02},
("softmax", "cpu", f16): {"atol": 1e-4, "rtol": 0.02},
("softmax", "cuda", f16): {"atol": 1e-4, "rtol": 0.02},
("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002},
("special.log_ndtr", "cuda", f64): {"atol": 1e-6, "rtol": 1e-5},
("polygamma.polygamma_n_0", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
("polygamma.polygamma_n_1", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
("polygamma.polygamma_n_2", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
("polygamma.polygamma_n_3", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
("polygamma.polygamma_n_4", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
("special.polygamma.special_polygamma_n_0", "cpu", f32): {
("softmax", f16): {"atol": 1e-4, "rtol": 0.02},
("polygamma.polygamma_n_0", f32): {"atol": 1e-3, "rtol": 1e-4},
("polygamma.polygamma_n_1", f32): {"atol": 1e-3, "rtol": 1e-4},
("polygamma.polygamma_n_2", f32): {"atol": 1e-3, "rtol": 1e-4},
("polygamma.polygamma_n_3", f32): {"atol": 1e-3, "rtol": 1e-4},
("polygamma.polygamma_n_4", f32): {"atol": 1e-3, "rtol": 1e-4},
("special.polygamma.special_polygamma_n_0", f32): {
"atol": 1e-3,
"rtol": 1e-4,
},
("std_mean.unbiased", "cuda", f16): {"reference_in_float": True},
("uniform", "cuda"): {"reference_in_float": True},
("_unsafe_masked_index_put_accumulate", "cuda", f16): {"atol": 1e-4, "rtol": 0.01},
("_unsafe_masked_index_put_accumulate", "cpu", f16): {"atol": 1e-4, "rtol": 0.01},
("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-4, "rtol": 0.01},
# The following tests fail under strict comparison, but atol=1 is acceptable due to rounding errors
("nn.functional.interpolate.bilinear", "cpu", u8): {"atol": 1, "rtol": 0},
("nn.functional.upsample_bilinear", "cpu", u8): {"atol": 1, "rtol": 0},
("nn.functional.interpolate.bicubic", "cpu", u8): {"atol": 1, "rtol": 0},
("nn.functional.interpolate.bilinear", u8): {"atol": 1, "rtol": 0},
("nn.functional.upsample_bilinear", u8): {"atol": 1, "rtol": 0},
("nn.functional.interpolate.bicubic", u8): {"atol": 1, "rtol": 0},
# High atol due to precision loss
("nn.functional.interpolate.bilinear", "cuda", f64): {"atol": 5e-4, "rtol": 0},
("nn.functional.upsample_bilinear", "cuda", f64): {"atol": 5e-4, "rtol": 0},
("nn.functional.interpolate.bicubic", "cpu", f32): {"atol": 5e-3, "rtol": 0},
("nn.functional.interpolate.bicubic", "cuda", f64): {"atol": 1e-3, "rtol": 0},
# Unreasonably high atol requirement:
("index_reduce.mean", "cuda", f16): {"check_gradient": False},
("index_reduce.mean", "cuda", f32): {"check_gradient": False},
("index_reduce.mean", "cuda", f64): {"check_gradient": False},
# Gradient contains non-finite entries:
("index_reduce.amin", "cuda", f64): {"check_gradient": False},
("index_reduce.amin", "cuda", f32): {"check_gradient": False},
("index_reduce.amin", "cuda", f16): {"check_gradient": False},
("index_reduce.amax", "cuda", f64): {"check_gradient": False},
("index_reduce.amax", "cuda", f32): {"check_gradient": False},
("index_reduce.amax", "cuda", f16): {"check_gradient": False},
("tanh", "cuda", f16): {"atol": 1e-4, "rtol": 1e-2},
("nn.functional.interpolate.bicubic", f32): {"atol": 5e-3, "rtol": 0},
}
inductor_override_kwargs["cuda"] = {
# the return value of empty is undefined
"empty": {"assert_equal": False},
"empty_permuted": {"assert_equal": False},
"empty_like": {"assert_equal": False},
"new_empty": {"assert_equal": False},
"empty_strided": {"assert_equal": False},
"new_empty_strided": {"assert_equal": False},
"randn": {"assert_equal": False},
("cross", f16): {"reference_in_float": True},
("linalg.cross", f16): {"reference_in_float": True},
("addr", f16): {"reference_in_float": True},
("baddbmm", f16): {"atol": 2e-3, "rtol": 0.002}, # decomp affects accuracy
("angle", f64): {"reference_in_float": True},
("asin", f16): {"reference_in_float": True},
("atanh", f16): {"reference_in_float": True},
"cauchy": {"reference_in_float": True},
("cummax", f16): {"atol": 5e-4, "rtol": 0.002},
("cumsum", f16): {"reference_in_float": True},
"cumprod": {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002},
"logcumsumexp": {"grad_atol": 8e-4, "grad_rtol": 0.001},
"exponential": {"reference_in_float": True},
"geometric": {"reference_in_float": True},
("kron", f16): {"reference_in_float": True},
"log_normal": {"reference_in_float": True},
("masked.softmin", f16): {"atol": 1e-4, "rtol": 0.01},
("nn.functional.batch_norm", f16): {"reference_in_float": True},
("nn.functional.batch_norm.without_cudnn", f16): {"reference_in_float": True},
("nn.functional.cosine_similarity", f16): {"reference_in_float": True},
("nn.functional.instance_norm", f16): {"reference_in_float": True},
("nn.functional.local_response_norm", f16): {"reference_in_float": True},
("nn.functional.normalize", f16): {"atol": 1e-3, "rtol": 0.05},
("nn.functional.rms_norm", f16): {"reference_in_float": True},
("nn.functional.soft_margin_loss", f16): {"reference_in_float": True},
("nn.functional.softmin", f16): {"atol": 1e-4, "rtol": 0.01},
("nn.functional.softsign", f16): {"reference_in_float": True},
("nn.functional.tanhshrink", f16): {"atol": 3e-4, "rtol": 0.001},
("outer", f16): {"reference_in_float": True},
("round.decimals_3", f16): {"reference_in_float": True},
("nn.functional.triplet_margin_loss", f16): {"atol": 1e-4, "rtol": 0.02},
("nn.functional.triplet_margin_with_distance_loss", f16): {
"atol": 1e-4,
"rtol": 0.02,
},
("sinc", f16): {"atol": 0.008, "rtol": 0.002},
("torch.ops.aten._safe_softmax.default", f16): {"atol": 5e-4, "rtol": 0.02},
("softmax", f16): {"atol": 1e-4, "rtol": 0.02},
("_softmax_backward_data", f16): {"atol": 0.008, "rtol": 0.002},
("special.log_ndtr", f64): {"atol": 1e-6, "rtol": 1e-5},
("std_mean.unbiased", f16): {"reference_in_float": True},
"uniform": {"reference_in_float": True},
("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-4, "rtol": 0.01},
# High atol due to precision loss
("nn.functional.interpolate.bilinear", f64): {"atol": 5e-4, "rtol": 0},
("nn.functional.upsample_bilinear", f64): {"atol": 5e-4, "rtol": 0},
("nn.functional.interpolate.bicubic", f64): {"atol": 1e-3, "rtol": 0},
# Unreasonably high atol requirement:
("index_reduce.mean", f16): {"check_gradient": False},
("index_reduce.mean", f32): {"check_gradient": False},
("index_reduce.mean", f64): {"check_gradient": False},
# Gradient contains non-finite entries:
("index_reduce.amin", f64): {"check_gradient": False},
("index_reduce.amin", f32): {"check_gradient": False},
("index_reduce.amin", f16): {"check_gradient": False},
("index_reduce.amax", f64): {"check_gradient": False},
("index_reduce.amax", f32): {"check_gradient": False},
("index_reduce.amax", f16): {"check_gradient": False},
("tanh", f16): {"atol": 1e-4, "rtol": 1e-2},
}
inductor_override_kwargs["xpu"] = {
# the return value of empty is undefined
"empty": {"assert_equal": False},
"empty_permuted": {"assert_equal": False},
"empty_like": {"assert_equal": False},
"new_empty": {"assert_equal": False},
"empty_strided": {"assert_equal": False},
"new_empty_strided": {"assert_equal": False},
"randn": {"assert_equal": False},
# XPU
("cross", f16): {"reference_in_float": True},
("addr", f16): {"reference_in_float": True},
("baddbmm", f16): {"atol": 2e-3, "rtol": 0.002}, # decomp affects accuracy
("angle", f64): {"reference_in_float": True},
("asin", f16): {"reference_in_float": True},
("atanh", f16): {"reference_in_float": True},
"cauchy": {"reference_in_float": True},
("cummax", f16): {"atol": 5e-4, "rtol": 0.002},
("cumsum", f16): {"reference_in_float": True},
"cumprod": {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002},
("dot", f16): {"atol": 1e-5, "rtol": 0.002},
"logcumsumexp": {"grad_atol": 8e-4, "grad_rtol": 0.001},
"exponential": {"reference_in_float": True},
"geometric": {"reference_in_float": True},
("kron", f16): {"reference_in_float": True},
("linalg.cross", f16): {"reference_in_float": True},
("linalg.vecdot", f16): {"atol": 1e-5, "rtol": 2e-2},
"log_normal": {"reference_in_float": True},
("logsumexp", f16): {"atol": 1e-5, "rtol": 1e-2},
("masked.cumprod", f16): {"atol": 1e-5, "rtol": 5e-2},
("masked.cumsum", f16): {"atol": 1e-5, "rtol": 5e-3},
("masked.softmin", f16): {"atol": 1e-4, "rtol": 0.01},
("masked.softmax", f16): {"atol": 2e-4, "rtol": 0.01},
("masked.var", f16): {"atol": 2e-5, "rtol": 5e-3},
("native_batch_norm", f64): {"atol": 1e-7, "rtol": 1e-5},
("_native_batch_norm_legit", f64): {"atol": 1e-7, "rtol": 5e-6},
("_batch_norm_with_update", f64): {"atol": 1e-7, "rtol": 1e-6},
("native_layer_norm", f16): {"atol": 5e-3, "rtol": 5e-3},
("native_layer_norm", f32): {"atol": 5e-3, "rtol": 5e-3},
("nn.functional.batch_norm", f16): {"reference_in_float": True},
("nn.functional.batch_norm", f64): {"atol": 1e-6, "rtol": 1e-6},
("nn.functional.batch_norm.without_cudnn", f16): {"reference_in_float": True},
("nn.functional.conv1d", f16): {"atol": 1e-5, "rtol": 6e-3},
("nn.functional.conv3d", f16): {"atol": 1e-5, "rtol": 2e-3},
("nn.functional.conv_transpose2d", f16): {"atol": 1e-5, "rtol": 2e-3},
("nn.functional.conv_transpose3d", f16): {"atol": 1e-5, "rtol": 5e-3},
("nn.functional.cosine_embedding_loss", f16): {"atol": 1e-5, "rtol": 2e-3},
("nn.functional.cosine_similarity", f16): {
"reference_in_float": True,
"atol": 1e-5,
"rtol": 5e-3,
},
("nn.functional.instance_norm", f16): {"reference_in_float": True},
("nn.functional.instance_norm", f64): {"atol": 1e-6, "rtol": 1e-6},
("nn.functional.layer_norm", f16): {"atol": 5e-3, "rtol": 2e-3},
("nn.functional.layer_norm", f32): {"atol": 5e-5, "rtol": 2e-3},
("nn.functional.local_response_norm", f16): {"reference_in_float": True},
("nn.functional.multilabel_soft_margin_loss", f16): {
"atol": 3e-4,
"rtol": 2e-3,
},
("nn.functional.normalize", f16): {"atol": 1e-3, "rtol": 0.05},
("nn.functional.rms_norm", f16): {"reference_in_float": True},
("nn.functional.soft_margin_loss", f16): {"reference_in_float": True},
("nn.functional.softmin", f16): {"atol": 1e-4, "rtol": 0.01},
("nn.functional.softsign", f16): {
"reference_in_float": True,
"atol": 1e-5,
"rtol": 0.005,
},
("nn.functional.tanhshrink", f16): {"atol": 3e-4, "rtol": 0.001},
("outer", f16): {"reference_in_float": True},
("round.decimals_3", f16): {"reference_in_float": True},
("nn.functional.triplet_margin_loss", f16): {"atol": 1e-4, "rtol": 0.02},
("nn.functional.triplet_margin_with_distance_loss", f16): {
"atol": 1e-4,
"rtol": 0.02,
},
("remainder", f16): {"atol": 1e-4, "rtol": 0.005},
("nn.functional.upsample_bilinear", f16): {"atol": 1e-5, "rtol": 0.002},
("sinc", f16): {"atol": 0.008, "rtol": 0.002},
("softmax", f16): {"atol": 1e-4, "rtol": 0.02},
("_softmax_backward_data", f16): {"atol": 0.008, "rtol": 0.002},
("special.log_ndtr", f64): {"atol": 1e-6, "rtol": 1e-5},
("std_mean.unbiased", f16): {
"reference_in_float": True,
"atol": 5e-5,
"rtol": 5e-3,
},
("trapezoid", f16): {"atol": 1e-5, "rtol": 5e-3},
("trapz", f16): {"atol": 1e-5, "rtol": 5e-3},
"uniform": {"reference_in_float": True},
("var_mean", f16): {"atol": 1e-5, "rtol": 2e-3},
("var_mean.unbiased", f16): {"atol": 1e-5, "rtol": 2e-3},
("vdot", f16): {"atol": 1e-5, "rtol": 2e-3},
# The following tests fail under strict comparison, but atol=1 is acceptable due to rounding errors
# High atol due to precision loss
("nn.functional.interpolate.bilinear", f64): {"atol": 5e-4, "rtol": 0},
("nn.functional.upsample_bilinear", f64): {"atol": 5e-4, "rtol": 0},
("nn.functional.interpolate.bicubic", f64): {"atol": 1e-3, "rtol": 0},
# Unreasonably high atol requirement:
("index_reduce.mean", f16): {"check_gradient": False},
("index_reduce.mean", f32): {"check_gradient": False},
("index_reduce.mean", f64): {"check_gradient": False},
# Gradient contains non-finite entries:
("index_reduce.amin", f64): {"check_gradient": False},
("index_reduce.amin", f32): {"check_gradient": False},
("index_reduce.amin", f16): {"check_gradient": False},
("index_reduce.amax", f64): {"check_gradient": False},
("index_reduce.amax", f32): {"check_gradient": False},
("index_reduce.amax", f16): {"check_gradient": False},
("tanh", f16): {"atol": 1e-4, "rtol": 1e-2},
("nn.functional.embedding_bag", f16): {"check_gradient": False},
("nn.functional.embedding_bag", f32): {"check_gradient": False},
("nn.functional.embedding_bag", f64): {"check_gradient": False},
("_unsafe_masked_index", f16): {"atol": 1e-5, "rtol": 2e-3},
("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-5, "rtol": 5e-3},
}
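To make the restructured tables easier to follow: `inductor_override_kwargs` is now keyed by device type first, and at test time an op-level entry is merged with an `(op_name, dtype)` entry, matching the updated lookup further down in this diff. A self-contained sketch with illustrative table contents (`f16` here is a placeholder for the real torch dtype):

```python
from collections import defaultdict

f16 = "f16"  # placeholder; the real tables use torch dtypes

# Per-device tables keyed by op_name or (op_name, dtype), as above.
inductor_override_kwargs = defaultdict(dict)
inductor_override_kwargs["xpu"] = {("softmax", f16): {"atol": 1e-4, "rtol": 0.02}}

def overridden_kwargs_for(device_type, op_name, dtype):
    kwargs = {}
    # The op-level entry applies first; an (op_name, dtype) entry refines it.
    kwargs.update(inductor_override_kwargs.get(device_type, {}).get(op_name, {}))
    kwargs.update(inductor_override_kwargs.get(device_type, {}).get((op_name, dtype), {}))
    return kwargs

assert overridden_kwargs_for("xpu", "softmax", f16) == {"atol": 1e-4, "rtol": 0.02}
assert overridden_kwargs_for("cuda", "softmax", f16) == {}
```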
# Test with only one sample for the following ops
inductor_one_sample = {
inductor_one_sample = defaultdict(dict)
inductor_one_sample["cpu"] = {
"_segment_reduce.lengths": {f16},
"_segment_reduce.offsets": {f16},
"addmv": {f16},
@@ -456,89 +696,244 @@ inductor_one_sample = {
"normal": {f16, f32, f64},
"put": {f16, f32, f64},
"take": {b8, f16, f32, f64, i32, i64},
("__rdiv__", "cuda"): {f16},
("__rmod__", "cuda"): {f16, i64},
("__rmul__", "cuda"): {f16},
("__rpow__", "cuda"): {f16},
("_unsafe_masked_index", "cuda"): {f16},
("_unsafe_masked_index_put_accumulate", "cuda"): {f16},
("addcdiv", "cuda"): {f16},
("addcmul", "cuda"): {f16},
("atan2", "cuda"): {f16},
("cumsum", "cuda"): {f16},
("cumulative_trapezoid", "cuda"): {f16},
("dist", "cuda"): {f16},
("div.no_rounding_mode", "cuda"): {f16},
("fmod", "cuda"): {f16},
("grid_sampler_2d", "cuda"): {f16},
("index_fill", "cuda"): {f16, f32, f64},
("ldexp", "cuda"): {f16},
("lerp", "cuda"): {f16},
("linalg.householder_product", "cuda"): {f32},
("linalg.matrix_norm", "cuda"): {f16},
("linalg.vector_norm", "cuda"): {f16},
("logspace", "cuda"): {i32, i64},
("masked.cumsum", "cuda"): {f16},
("masked.logsumexp", "cuda"): {f16},
("masked.mean", "cuda"): {b8},
("masked.normalize", "cuda"): {f16},
("masked.prod", "cuda"): {f16},
("masked.std", "cuda"): {f16},
("masked.var", "cuda"): {f16},
("mul", "cuda"): {f16},
("nn.functional.alpha_dropout", "cuda"): {f16, f32, f64},
("nn.functional.avg_pool1d", "cuda"): {f16, f32, f64},
("nn.functional.avg_pool2d", "cuda"): {f16, f32, f64},
("nn.functional.avg_pool3d", "cuda"): {f16, f32, f64},
("nn.functional.binary_cross_entropy", "cuda"): {f16},
("nn.functional.binary_cross_entropy_with_logits", "cuda"): {f16},
("nn.functional.conv2d", "cuda"): {f16},
("nn.functional.cosine_embedding_loss", "cuda"): {f16},
("nn.functional.dropout2d", "cuda"): {f16, f32, f64},
("nn.functional.dropout3d", "cuda"): {f16, f32, f64},
("nn.functional.dropout", "cuda"): {f16, f32, f64},
("nn.functional.feature_alpha_dropout.with_train", "cuda"): {f16, f32, f64},
("nn.functional.fractional_max_pool2d", "cuda"): {f16, f32, f64},
("nn.functional.fractional_max_pool3d", "cuda"): {f16, f32, f64},
("nn.functional.grid_sample", "cuda"): {f16},
("nn.functional.group_norm", "cuda"): {f16},
("nn.functional.hinge_embedding_loss", "cuda"): {f16},
}
inductor_one_sample["cuda"] = {
"_segment_reduce.lengths": {f16},
"_segment_reduce.offsets": {f16},
"addmv": {f16},
"as_strided.partial_views": {f16},
"corrcoef": {f16},
"diff": {f16},
"einsum": {f16, i32},
"gradient": {f16},
"histogram": {f32, f64},
"histogramdd": {f32, f64},
"index_put": {f16, f32, f64},
"linalg.eig": {f32, f64},
"linspace": {f16, i32, i64},
"linspace.tensor_overload": {f16, f32, f64, i32, i64},
"logspace": {f16, i32, i64},
"logspace.tensor_overload": {f16, f32, f64, i32, i64},
"masked_logsumexp": {i64},
"max_pool2d_with_indices_backward": {f16, f32, f64},
"new_empty_strided": {f16},
"nn.functional.adaptive_avg_pool3d": {f16},
"nn.functional.adaptive_max_pool1d": {f16, f32},
"nn.functional.adaptive_max_pool2d": {f16, f32},
"nn.functional.bilinear": {f16},
"nn.functional.conv_transpose1d": {f16},
"nn.functional.conv_transpose2d": {f16},
"nn.functional.conv_transpose3d": {f16},
"nn.functional.cosine_similarity": {f16},
"nn.functional.cross_entropy": {f16, f32, f64},
"nn.functional.gaussian_nll_loss": {f16},
"nn.functional.grid_sample": {f16, f32, f64},
"nn.functional.interpolate.area": {f16},
"nn.functional.nll_loss": {f16, f32, f64},
"normal": {f16, f32, f64},
"put": {f16, f32, f64},
"take": {b8, f16, f32, f64, i32, i64},
"__rdiv__": {f16},
"__rmod__": {f16, i64},
"__rmul__": {f16},
"__rpow__": {f16},
"_unsafe_masked_index": {f16},
"_unsafe_masked_index_put_accumulate": {f16},
"addcdiv": {f16},
"addcmul": {f16},
"atan2": {f16},
"cumsum": {f16},
"cumulative_trapezoid": {f16},
"dist": {f16},
"div.no_rounding_mode": {f16},
"fmod": {f16},
"grid_sampler_2d": {f16},
"index_fill": {f16, f32, f64},
"ldexp": {f16},
"lerp": {f16},
"linalg.householder_product": {f32},
"linalg.matrix_norm": {f16},
"linalg.vector_norm": {f16},
"masked.cumsum": {f16},
"masked.logsumexp": {f16},
"masked.mean": {b8},
"masked.normalize": {f16},
"masked.prod": {f16},
"masked.std": {f16},
"masked.var": {f16},
"mul": {f16},
"nn.functional.alpha_dropout": {f16, f32, f64},
"nn.functional.avg_pool1d": {f16, f32, f64},
"nn.functional.avg_pool2d": {f16, f32, f64},
"nn.functional.avg_pool3d": {f16, f32, f64},
"nn.functional.binary_cross_entropy": {f16},
"nn.functional.binary_cross_entropy_with_logits": {f16},
"nn.functional.conv2d": {f16},
"nn.functional.cosine_embedding_loss": {f16},
"nn.functional.dropout2d": {f16, f32, f64},
"nn.functional.dropout3d": {f16, f32, f64},
"nn.functional.dropout": {f16, f32, f64},
"nn.functional.feature_alpha_dropout.with_train": {f16, f32, f64},
"nn.functional.fractional_max_pool2d": {f16, f32, f64},
"nn.functional.fractional_max_pool3d": {f16, f32, f64},
"nn.functional.group_norm": {f16},
"nn.functional.hinge_embedding_loss": {f16},
# Enabling all samples for this op fails randomly
# See https://github.com/pytorch/pytorch/issues/129238
("nn.functional.huber_loss", "cuda"): {f16},
("nn.functional.interpolate.bicubic", "cuda"): {f16},
("nn.functional.interpolate.bilinear", "cuda"): {f16},
("nn.functional.interpolate.trilinear", "cuda"): {f16},
("nn.functional.kl_div", "cuda"): {f16},
("nn.functional.margin_ranking_loss", "cuda"): {f16},
("nn.functional.max_pool1d", "cuda"): {f16, f32, f64},
("nn.functional.max_pool3d", "cuda"): {f16},
("nn.functional.mse_loss", "cuda"): {f16},
("nn.functional.multi_margin_loss", "cuda"): {f16},
("nn.functional.multilabel_margin_loss", "cuda"): {f16},
("nn.functional.multilabel_soft_margin_loss", "cuda"): {f16},
("nn.functional.normalize", "cuda"): {f16},
("nn.functional.pad.replicate", "cuda"): {f16, f32, f64},
("nn.functional.pad.reflect", "cuda"): {f16},
("nn.functional.pairwise_distance", "cuda"): {f16},
("nn.functional.poisson_nll_loss", "cuda"): {f16},
("nn.functional.rms_norm", "cuda"): {f16},
("norm", "cuda"): {f16},
("pow", "cuda"): {f16},
("prod", "cuda"): {f16},
("scatter_reduce.amax", "cuda"): {f16, f32, f64},
("scatter_reduce.amin", "cuda"): {f16, f32, f64},
("scatter_reduce.mean", "cuda"): {f16, f32, f64},
("special.xlog1py", "cuda"): {f16},
("std", "cuda"): {f16},
("std_mean", "cuda"): {f16},
("svd_lowrank", "cuda"): {f32, f64},
("trapezoid", "cuda"): {f16},
("trapz", "cuda"): {f16},
("true_divide", "cuda"): {f16},
("var", "cuda"): {f16},
("var_mean", "cuda"): {f16},
("xlogy", "cuda"): {f16},
"nn.functional.huber_loss": {f16},
"nn.functional.interpolate.bicubic": {f16},
"nn.functional.interpolate.bilinear": {f16},
"nn.functional.interpolate.trilinear": {f16},
"nn.functional.kl_div": {f16},
"nn.functional.margin_ranking_loss": {f16},
"nn.functional.max_pool1d": {f16, f32, f64},
"nn.functional.max_pool3d": {f16},
"nn.functional.mse_loss": {f16},
"nn.functional.multi_margin_loss": {f16},
"nn.functional.multilabel_margin_loss": {f16},
"nn.functional.multilabel_soft_margin_loss": {f16},
"nn.functional.normalize": {f16},
"nn.functional.pad.replicate": {f16, f32, f64},
"nn.functional.pad.reflect": {f16},
"nn.functional.pairwise_distance": {f16},
"nn.functional.poisson_nll_loss": {f16},
"nn.functional.rms_norm": {f16},
"norm": {f16},
"pow": {f16},
"prod": {f16},
"scatter_reduce.amax": {f16, f32, f64},
"scatter_reduce.amin": {f16, f32, f64},
"scatter_reduce.mean": {f16, f32, f64},
"special.xlog1py": {f16},
"std": {f16},
"std_mean": {f16},
"svd_lowrank": {f32, f64},
"trapezoid": {f16},
"trapz": {f16},
"true_divide": {f16},
"var": {f16},
"var_mean": {f16},
"xlogy": {f16},
}
inductor_one_sample["xpu"] = {
"_segment_reduce.lengths": {f16},
"_segment_reduce.offsets": {f16},
"addmv": {f16},
"as_strided.partial_views": {f16},
"corrcoef": {f16},
"diff": {f16},
"einsum": {f16, i32},
"gradient": {f16},
"histogram": {f32, f64},
"histogramdd": {f32, f64},
"index_put": {f16, f32, f64},
"linalg.eig": {f32, f64},
"linspace": {f16, i32, i64},
"linspace.tensor_overload": {f16, f32, f64, i32, i64},
"logspace": {f16, i32, i64},
"logspace.tensor_overload": {f16, f32, f64, i32, i64},
"masked_logsumexp": {i64},
"max_pool2d_with_indices_backward": {f16, f32, f64},
"new_empty_strided": {f16},
"nn.functional.adaptive_avg_pool3d": {f16},
"nn.functional.adaptive_max_pool1d": {f16, f32},
"nn.functional.adaptive_max_pool2d": {f16, f32},
"nn.functional.bilinear": {f16},
"nn.functional.conv_transpose1d": {f16},
"nn.functional.conv_transpose2d": {f16},
"nn.functional.conv_transpose3d": {f16},
"nn.functional.cosine_similarity": {f16},
"nn.functional.cross_entropy": {f16, f32, f64},
"nn.functional.gaussian_nll_loss": {f16},
"nn.functional.grid_sample": {f16, f32, f64},
"nn.functional.interpolate.area": {f16},
"nn.functional.nll_loss": {f16, f32, f64},
"normal": {f16, f32, f64},
"put": {f16, f32, f64},
"take": {b8, f16, f32, f64, i32, i64},
"__rdiv__": {f16},
"__rmod__": {f16, i64},
"__rmul__": {f16},
"__rpow__": {f16},
"_unsafe_masked_index": {f16},
"_unsafe_masked_index_put_accumulate": {f16},
"addcdiv": {f16},
"addcmul": {f16},
"atan2": {f16},
"cumsum": {f16},
"cumulative_trapezoid": {f16},
"dist": {f16},
"div.no_rounding_mode": {f16},
"fmod": {f16},
"grid_sampler_2d": {f16},
"index_fill": {f16, f32, f64},
"ldexp": {f16},
"lerp": {f16},
"linalg.householder_product": {f32},
"linalg.matrix_norm": {f16},
"linalg.vector_norm": {f16},
"masked.cumsum": {f16},
"masked.logsumexp": {f16},
"masked.mean": {b8},
"masked.normalize": {f16},
"masked.prod": {f16},
"masked.std": {f16},
"masked.var": {f16},
"mul": {f16},
"nn.functional.alpha_dropout": {f16, f32, f64},
"nn.functional.avg_pool1d": {f16, f32, f64},
"nn.functional.avg_pool2d": {f16, f32, f64},
"nn.functional.avg_pool3d": {f16, f32, f64},
"nn.functional.binary_cross_entropy": {f16},
"nn.functional.binary_cross_entropy_with_logits": {f16},
"nn.functional.conv2d": {f16},
"nn.functional.cosine_embedding_loss": {f16},
"nn.functional.dropout2d": {f16, f32, f64},
"nn.functional.dropout3d": {f16, f32, f64},
"nn.functional.dropout": {f16, f32, f64},
"nn.functional.feature_alpha_dropout.with_train": {f16, f32, f64},
"nn.functional.fractional_max_pool2d": {f16, f32, f64},
"nn.functional.fractional_max_pool3d": {f16, f32, f64},
"nn.functional.group_norm": {f16},
"nn.functional.hinge_embedding_loss": {f16},
# Enabling all samples for this op fails randomly
# See https://github.com/pytorch/pytorch/issues/129238
"nn.functional.huber_loss": {f16},
"nn.functional.interpolate.bicubic": {f16},
"nn.functional.interpolate.bilinear": {f16},
"nn.functional.interpolate.trilinear": {f16},
"nn.functional.kl_div": {f16},
"nn.functional.margin_ranking_loss": {f16},
"nn.functional.max_pool1d": {f16, f32, f64},
"nn.functional.max_pool3d": {f16},
"nn.functional.mse_loss": {f16},
"nn.functional.multi_margin_loss": {f16},
"nn.functional.multilabel_margin_loss": {f16},
"nn.functional.multilabel_soft_margin_loss": {f16},
"nn.functional.normalize": {f16},
"nn.functional.pad.replicate": {f16, f32, f64},
"nn.functional.pad.reflect": {f16},
"nn.functional.pairwise_distance": {f16},
"nn.functional.poisson_nll_loss": {f16},
"nn.functional.rms_norm": {f16},
"norm": {f16},
"pow": {f16},
"prod": {f16},
"scatter_reduce.amax": {f16, f32, f64},
"scatter_reduce.amin": {f16, f32, f64},
"scatter_reduce.mean": {f16, f32, f64},
"special.xlog1py": {f16},
"std": {f16},
"std_mean": {f16},
"svd_lowrank": {f32, f64},
"trapezoid": {f16},
"trapz": {f16},
"true_divide": {f16},
"var": {f16},
"var_mean": {f16},
"xlogy": {f16},
}
@@ -572,6 +967,7 @@ class TestInductorOpInfo(TestCase):
True
) # inductor kernels failing this test intermittently
@skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found")
@skipXPUIf(not HAS_XPU, "Skipped! Supported XPU compiler not found")
@skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found")
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@skipIfTorchDynamo("Test uses dynamo already")
@@ -593,6 +989,8 @@ class TestInductorOpInfo(TestCase):
# TODO: should we move empty_cache to the common device interface
if device_type == "cuda":
torch.cuda.empty_cache()
elif device == "xpu":
torch.xpu.empty_cache()
op_name = op.name
if op.variant_test_name:
op_name += f".{op.variant_test_name}"
@@ -630,12 +1028,12 @@ class TestInductorOpInfo(TestCase):
test_expect = ExpectedTestResult.SUCCESS
overridden_kwargs = {}
if op_name in inductor_override_kwargs:
overridden_kwargs = inductor_override_kwargs[op_name]
elif (op_name, device_type) in inductor_override_kwargs:
overridden_kwargs = inductor_override_kwargs[(op_name, device_type)]
elif (op_name, device_type, dtype) in inductor_override_kwargs:
overridden_kwargs = inductor_override_kwargs[(op_name, device_type, dtype)]
overridden_kwargs.update(
inductor_override_kwargs.get(device_type, {}).get(op_name, {})
)
overridden_kwargs.update(
inductor_override_kwargs.get(device_type, {}).get((op_name, dtype), {})
)
func = op.get_op()
def fn(*args, **kwargs):
@@ -653,8 +1051,7 @@ class TestInductorOpInfo(TestCase):
samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)
if (
dtype in inductor_one_sample.get(op_name, {})
or dtype in inductor_one_sample.get((op_name, device_type), {})
dtype in inductor_one_sample.get(device_type, {}).get(op_name, {})
) and not ALL_SAMPLES:
if isinstance(samples, (list, tuple)):
samples = [samples[0]]
@@ -796,7 +1193,7 @@ class TestInductorOpInfo(TestCase):
# print(f"SUCCEEDED OP {op_name} on {device_type} with {dtype}", flush=True, file=f)
instantiate_device_type_tests(TestInductorOpInfo, globals())
instantiate_device_type_tests(TestInductorOpInfo, globals(), allow_xpu=True)
if __name__ == "__main__":
run_tests()
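On the final change: passing `allow_xpu=True` to `instantiate_device_type_tests` lets the generic `TestInductorOpInfo` body be instantiated for XPU as well, producing classes like `TestInductorOpInfoXPU` alongside the CPU/CUDA variants. A rough sketch of the idea, not PyTorch's actual implementation:

```python
import unittest


class GenericOpInfoSketch:
    # Generic test body; device_type is filled in per generated class.
    device_type = ""

    def test_device_is_known(self):
        assert self.device_type in ("cpu", "cuda", "xpu")


def instantiate_for_devices(devices, scope):
    # Roughly what instantiate_device_type_tests does: stamp out one
    # concrete unittest class per available backend.
    for dev in devices:
        name = f"GenericOpInfoSketch{dev.upper()}"
        scope[name] = type(
            name, (GenericOpInfoSketch, unittest.TestCase), {"device_type": dev}
        )


# With allow_xpu=True, "xpu" joins the list of instantiated backends.
instantiate_for_devices(["cpu", "xpu"], globals())

if __name__ == "__main__":
    unittest.main()
```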