[AOTI][reland] Fix assert_function call in cpu autotune template (#135920)
Summary: Reland https://github.com/pytorch/pytorch/pull/135086. In the ABI-compatible mode, assert_function should be AOTI_TORCH_CHECK.

Test Plan: CI

Differential Revision: D62500592

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135920
Approved by: https://github.com/chenyang78
Committed by: PyTorch MergeBot
Parent: 2f53d570fe
Commit: ea2ecab15b
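The substance of the fix is visible in the assert_function hunk further down: the assertion macro emitted into the generated C++ is now chosen by config.abi_compatible rather than V.graph.aot_mode. A minimal sketch of that selection, with a hypothetical helper name rather than the actual CppKernel API:

# Minimal sketch, not the real CppKernel code: pick the assertion macro the
# generated C++ should use. AOTI_TORCH_CHECK is the ABI-compatible form,
# TORCH_CHECK the regular one.
def pick_assert_macro(abi_compatible: bool) -> str:
    return "AOTI_TORCH_CHECK" if abi_compatible else "TORCH_CHECK"

assert pick_assert_macro(True) == "AOTI_TORCH_CHECK"
assert pick_assert_macro(False) == "TORCH_CHECK"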
@@ -3270,7 +3270,9 @@ class AOTInductorTestsTemplate:
             model, example_inputs_list, dynamic_shapes=dynamic_shapes
         )
 
-    @common_utils.parametrize("max_autotune", [False, True])
+    # max_autotune is disabled due to https://github.com/pytorch/pytorch/issues/135106
+    # @common_utils.parametrize("max_autotune", [False, True])
+    @common_utils.parametrize("max_autotune", [False])
     def test_misc_1(self, max_autotune):
         if self.device == "cpu" and IS_MACOS and max_autotune:
             raise unittest.SkipTest("max_autotune not supported on macos")
@@ -87,13 +87,7 @@ if TEST_WITH_ROCM:
         }
     )
 if config.abi_compatible:
-    xfail_list = [
-        *[
-            func
-            for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU())
-            if func.startswith("test_linear_with_pointwise")
-        ],
-    ]
+    xfail_list = []
     for test_name in xfail_list:
         test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure(
             ("cpp_wrapper",), is_skip=False
@@ -103,6 +97,11 @@ if config.abi_compatible:
     ] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=False)
     skip_list = [
         "test_multihead_attention_cpu",
+        *[
+            func
+            for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU())
+            if func.startswith("test_linear_with_pointwise")
+        ],
     ]
     for test_name in skip_list:
         test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure(
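Both the old xfail_list and the new skip_list gather every test_linear_with_pointwise* method with a dir()-based comprehension; the reland only moves those names from expected-failure to skip. A small self-contained illustration of that filtering pattern, where FakeTests is a hypothetical stand-in for test_cpu_select_algorithm.TestSelectAlgorithmCPU:

# Hypothetical stand-in class; the real code filters TestSelectAlgorithmCPU().
class FakeTests:
    def test_linear_with_pointwise_relu(self): ...
    def test_linear_with_pointwise_gelu(self): ...
    def test_conv_something_else(self): ...

skip_list = [
    "test_multihead_attention_cpu",
    *[
        func
        for func in dir(FakeTests())
        if func.startswith("test_linear_with_pointwise")
    ],
]
print(skip_list)
# ['test_multihead_attention_cpu', 'test_linear_with_pointwise_gelu', 'test_linear_with_pointwise_relu']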
@@ -2146,9 +2146,7 @@ class CppKernel(Kernel):
 
     @property
     def assert_function(self) -> str:
-        if V.graph.aot_mode:
-            # TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models
-            # compared with JIT Inductor which uses TORCH_CHECK
+        if config.abi_compatible:
             return "AOTI_TORCH_CHECK"
         else:
             return "TORCH_CHECK"
@@ -332,8 +332,8 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
 
     TEMPLATE_ENTRY = r"""
 {{declare_kernel}} {
-    TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
-    TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
+    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
+    {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
     // TODO(jgong5): loop unroll for M and N
     for (int64_t m = 0; m < M; m += {{block_m}}) {
         int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
@@ -364,7 +364,7 @@ class CppMicroGemmFP32Vec(CppMicroGemm):
             break;
 {%- endfor %}
         default:
-            {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m);
+            {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}");
         }
     }
 }
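The micro-GEMM TEMPLATE_ENTRY strings above are Jinja-style templates, so replacing the hard-coded TORCH_CHECK with {{kernel.assert_function}} lets the same template emit either macro depending on the mode. A sketch of how such a placeholder resolves, assuming the jinja2 package is available and using FakeKernel as a stand-in for the real kernel object:

from jinja2 import Template  # assumption: jinja2 is installed

class FakeKernel:
    # stand-in exposing the same attribute the template reads
    def __init__(self, abi_compatible: bool):
        self.assert_function = "AOTI_TORCH_CHECK" if abi_compatible else "TORCH_CHECK"

line = Template(
    '{{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");'
)
print(line.render(kernel=FakeKernel(abi_compatible=True), block_n=32))
# AOTI_TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");
print(line.render(kernel=FakeKernel(abi_compatible=False), block_n=32))
# TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");

The default-branch change in the preceding hunk follows the same idea: block_m is baked into the message string at render time instead of being passed as an extra runtime argument to the assertion macro.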
@@ -509,8 +509,8 @@ class CppMicroGemmAMX(CppMicroGemm):
 
     TEMPLATE_ENTRY = r"""
 {{declare_kernel}} {
-    TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
-    TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
+    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
+    {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2");
     // TODO(jgong5): loop unroll for M and N
     for (int64_t m = 0; m < M; m += {{block_m}}) {
         int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
@@ -26,6 +26,7 @@
 #include <c10/util/generic_math.h>
 #include <c10/util/Half.h>
 #include <c10/util/TypeCast.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
 
 #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
 #define INDUCTOR_USE_VECTOR_TYPES() 1
@@ -111,11 +111,10 @@ class CppTemplate(KernelTemplate):
     def header(self) -> IndentedBuffer:
         res = IndentedBuffer()
         res.writeline(codecache.cpp_prefix())
-        res.splice(
-            """
-                #include "c10/util/Unroll.h"
-            """
-        )
+        # TODO: add c10::ForcedUnroll test to test_aoti_abi_check
+        res.splice("""#include <c10/util/Unroll.h>""")
+        if config.abi_compatible:
+            res.splice("""#include <torch/csrc/inductor/aoti_torch/c/shim.h>""")
         enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
             "linux",
             "win32",
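The two header hunks above make the shim include available wherever the generated C++ may expand AOTI_TORCH_CHECK: the C++ prefix header gains it unconditionally, while CppTemplate.header() splices it in only when config.abi_compatible is set. A hedged sketch of that conditional emission, where header_lines and the boolean flag are illustrative stand-ins rather than the IndentedBuffer API:

# Illustrative only: mirrors the shape of CppTemplate.header() after the change.
def header_lines(abi_compatible: bool) -> list:
    lines = ["#include <c10/util/Unroll.h>"]
    if abi_compatible:
        # The shim header is assumed to provide what the ABI-compatible
        # assertions need; it is skipped otherwise.
        lines.append("#include <torch/csrc/inductor/aoti_torch/c/shim.h>")
    return lines

print("\n".join(header_lines(abi_compatible=True)))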
@@ -854,7 +854,6 @@ class GraphLowering(torch.fx.Interpreter):
     def allocate_non_dup_const_name(
         self, name: Optional[str], data: Union[Tensor]
    ) -> str:
-        orig_name = name
         if not config.aot_inductor.use_runtime_constant_folding:
             for constant_name, value in self.constants.items():
                 if (
@@ -871,7 +870,7 @@ class GraphLowering(torch.fx.Interpreter):
 
         if name is None:
             name = f"constant{len(self.constants)}"
-        assert name is not None
+        orig_name = name
         if name[0].isdigit():
             name = f"constant_{name}"
         name = self.qualify_name(name)
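The GraphLowering hunks move orig_name = name below the default-name assignment; the assert name is not None it replaces was already guaranteed by that preceding default, and capturing orig_name afterwards means it now always holds a concrete name. A simplified, hypothetical sketch of the resulting flow, with qualify standing in for self.qualify_name:

# Simplified stand-in for GraphLowering.allocate_non_dup_const_name's naming flow.
def allocate_name(name, num_constants, qualify=lambda n: f"g0_{n}"):
    if name is None:
        name = f"constant{num_constants}"
    orig_name = name  # captured after the default is filled in, so it is never None
    if name[0].isdigit():
        name = f"constant_{name}"
    return orig_name, qualify(name)

print(allocate_name(None, 3))        # ('constant3', 'g0_constant3')
print(allocate_name("0_weight", 3))  # ('0_weight', 'g0_constant_0_weight')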