support ConvBinaryInplace in Inductor cpp wrapper (#101394)

This PR has changed the OP schema since `at::Tensor&` should be the FirstArg:
87f9160b67/aten/src/ATen/core/boxing/impl/boxing.h (L305-L341)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101394
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/desertfire
This commit is contained in:
chunyuan
2023-05-31 13:50:37 +00:00
committed by PyTorch MergeBot
parent cdfba6fca7
commit 0d2e7a1888
7 changed files with 69 additions and 19 deletions

View File

@ -545,8 +545,8 @@ Tensor mkldnn_convolution_pointwise_binary(
// op, such as "hardtanh" has scalar parameters "gelu" has algorithm parameters.
Tensor& mkldnn_convolution_pointwise_binary_(
const Tensor& input_t,
Tensor& other_t,
const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
IntArrayRef padding,

View File

@ -50,7 +50,7 @@ TORCH_LIBRARY(mkldnn, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"mkldnn::_convolution_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA(
"mkldnn::_convolution_pointwise_.binary(Tensor X, Tensor(a!) other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y"));
"mkldnn::_convolution_pointwise_.binary(Tensor(a!) other, Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y"));
m.def(TORCH_SELECTIVE_SCHEMA(
"mkldnn::_convolution_transpose_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA(

View File

@ -362,6 +362,7 @@ ALLOW_LIST = [
("aten::_nested_view_from_buffer_copy", datetime.date(2023, 5, 1)),
("aten::_nested_view_from_buffer", datetime.date(2023, 5, 1)),
("aten::_scaled_dot_product_flash_attention_backward", datetime.date(2023, 6, 1)),
("mkldnn::_convolution_pointwise_.binary", datetime.date(2023, 7, 1)),
# These ops were moved to python under the c10d_functional namespace
("aten::wait_tensor", datetime.date(9999, 1, 30)),
("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),

View File

@ -72,6 +72,9 @@ test_failures_cpp_wrapper = {
"test_conv2d_binary_inplace_fusion_failed_cpu_dynamic_shapes": test_torchinductor.TestFailure(
("cpp_wrapper",), is_skip=True
),
"test_conv2d_binary_inplace_fusion_pass_cpu_dynamic_shapes": test_torchinductor.TestFailure(
("cpp_wrapper",), is_skip=True
),
}
@ -129,6 +132,16 @@ if RUN_CPU:
["op_convolution_pointwise_binary_.call"],
],
),
BaseTest(
"test_conv2d_binary_inplace_fusion_pass",
"cpu",
test_mkldnn_pattern_matcher.TestPaternMatcher(),
condition=torch._C.has_mkldnn,
func_inputs=[
["op_convolution_pointwise_binary_.call"],
["op_convolution_pointwise_binary.call"],
],
),
BaseTest(
"test_conv2d_unary",
"cpu",

View File

@ -345,7 +345,9 @@ class TestPaternMatcher(TestCase):
v = torch.randn(1, 3, 28, 28)
self._test_common(mod, (v,), 0, 0)
def test_conv2d_binary_inplace_fusion_pass(self):
def test_conv2d_binary_inplace_fusion_pass_cpu(
self, include_ops=None, exclude_ops=None
):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
@ -362,8 +364,12 @@ class TestPaternMatcher(TestCase):
torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
]
mod = Model().to(memory_format=torch.channels_last).eval()
include_ops = ["mkldnn._convolution_pointwise_.binary"]
exclude_ops = ["mkldnn._convolution_pointwise.binary"]
if include_ops is None:
include_ops = ["mkldnn._convolution_pointwise_.binary"]
if exclude_ops is None:
exclude_ops = ["mkldnn._convolution_pointwise.binary"]
self._test_code_common(mod, inputs, include_ops, exclude_ops)
def test_conv2d_binary_inplace_fusion_failed_cpu(

View File

@ -295,7 +295,7 @@ class TestMkldnnFusion(JitTestCase):
# for binary add, we support inplace version.
if attr == "add":
fused_inplace = torch.ops.mkldnn._convolution_pointwise_(
x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
other, x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
mod.conv.groups, attr, None, unary_attr, [], None
)
self.assertEqual(ref, other)

View File

@ -3689,21 +3689,50 @@ class ConvolutionBinary(ExternKernelAlloc):
class ConvolutionBinaryInplace(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
def __init__(
self,
kernel_layout,
inputs,
constant_args=(),
kernel="torch.ops.mkldnn._convolution_pointwise_.binary",
):
super().__init__(kernel_layout, inputs, constant_args)
self.kernel = kernel
# Due to constrain of op.call, other (Tensor&) should be at input[0]
reordered_inputs = [inputs[1], inputs[0]] + inputs[2:]
super().__init__(
kernel_layout,
reordered_inputs,
constant_args,
None,
kernel="torch.ops.mkldnn._convolution_pointwise_.binary",
cpp_kernel="mkldnn::_convolution_pointwise_",
)
self.cpp_kernel_overlad_name = "binary"
self.cpp_kernel_key = "convolution_pointwise_binary_"
# TODO: op.call: input[0] should be at::Tensor&
self.cpp_op_schema = """
at::Tensor&(
at::Tensor& other_t,
const at::Tensor& input_t,
const at::Tensor& weight_t,
const c10::optional<at::Tensor>& bias_opt,
at::IntArrayRef padding,
at::IntArrayRef stride,
at::IntArrayRef dilation,
int64_t groups,
c10::string_view binary_attr,
c10::optional<at::Scalar> alpha,
c10::optional<c10::string_view> unary_attr,
torch::List<c10::optional<at::Scalar>> unary_scalars,
c10::optional<c10::string_view> unary_algorithm)"""
def codegen(self, wrapper):
wrapper.writeline(
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.kernel,
self.codegen_args(),
self.cpp_op_schema,
self.cpp_kernel_key,
self.cpp_kernel_overlad_name,
)
def get_mutation_names(self):
@ -3727,7 +3756,6 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
unary_scalars: Optional[List[Any]],
unary_algorithm: Optional[str],
):
kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
(
inputs,
constant_args,
@ -3738,18 +3766,20 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
)
other = cls.require_stride_order(other, req_stride_order)
inputs.insert(1, other)
optional_scalar = OptionalScalar()
optional_string = OptionalString()
optional_list = OptionalList()
constant_args = constant_args + [
binary_attr,
binary_alpha,
unary_attr,
unary_scalars,
unary_algorithm,
may_convert_to_optional(optional_scalar, binary_alpha),
may_convert_to_optional(optional_string, unary_attr),
may_convert_to_optional(optional_list, unary_scalars),
may_convert_to_optional(optional_string, unary_algorithm),
]
return ConvolutionBinaryInplace(
kernel_layout=MutationLayout(inputs[1]),
inputs=inputs,
constant_args=constant_args,
kernel=kernel,
)