mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
support ConvBinaryInplace in Inductor cpp wrapper (#101394)
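This PR changes the schema of the in-place op `mkldnn::_convolution_pointwise_.binary` so that the mutated `at::Tensor&` argument (`other`) comes first, since the boxing code referenced below expects the `at::Tensor&` to be the first argument: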
87f9160b67/aten/src/ATen/core/boxing/impl/boxing.h (L305-L341)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101394
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/desertfire
This commit is contained in:
committed by
PyTorch MergeBot
parent
cdfba6fca7
commit
0d2e7a1888
@ -545,8 +545,8 @@ Tensor mkldnn_convolution_pointwise_binary(
|
||||
// op, such as "hardtanh" has scalar parameters "gelu" has algorithm parameters.
|
||||
|
||||
Tensor& mkldnn_convolution_pointwise_binary_(
|
||||
const Tensor& input_t,
|
||||
Tensor& other_t,
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const c10::optional<Tensor>& bias_opt,
|
||||
IntArrayRef padding,
|
||||
|
@ -50,7 +50,7 @@ TORCH_LIBRARY(mkldnn, m) {
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"mkldnn::_convolution_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor Y"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"mkldnn::_convolution_pointwise_.binary(Tensor X, Tensor(a!) other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y"));
|
||||
"mkldnn::_convolution_pointwise_.binary(Tensor(a!) other, Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
"mkldnn::_convolution_transpose_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA(
|
||||
|
@ -362,6 +362,7 @@ ALLOW_LIST = [
|
||||
("aten::_nested_view_from_buffer_copy", datetime.date(2023, 5, 1)),
|
||||
("aten::_nested_view_from_buffer", datetime.date(2023, 5, 1)),
|
||||
("aten::_scaled_dot_product_flash_attention_backward", datetime.date(2023, 6, 1)),
|
||||
("mkldnn::_convolution_pointwise_.binary", datetime.date(2023, 7, 1)),
|
||||
# These ops were moved to python under the c10d_functional namespace
|
||||
("aten::wait_tensor", datetime.date(9999, 1, 30)),
|
||||
("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
|
||||
|
@ -72,6 +72,9 @@ test_failures_cpp_wrapper = {
|
||||
"test_conv2d_binary_inplace_fusion_failed_cpu_dynamic_shapes": test_torchinductor.TestFailure(
|
||||
("cpp_wrapper",), is_skip=True
|
||||
),
|
||||
"test_conv2d_binary_inplace_fusion_pass_cpu_dynamic_shapes": test_torchinductor.TestFailure(
|
||||
("cpp_wrapper",), is_skip=True
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@ -129,6 +132,16 @@ if RUN_CPU:
|
||||
["op_convolution_pointwise_binary_.call"],
|
||||
],
|
||||
),
|
||||
BaseTest(
|
||||
"test_conv2d_binary_inplace_fusion_pass",
|
||||
"cpu",
|
||||
test_mkldnn_pattern_matcher.TestPaternMatcher(),
|
||||
condition=torch._C.has_mkldnn,
|
||||
func_inputs=[
|
||||
["op_convolution_pointwise_binary_.call"],
|
||||
["op_convolution_pointwise_binary.call"],
|
||||
],
|
||||
),
|
||||
BaseTest(
|
||||
"test_conv2d_unary",
|
||||
"cpu",
|
||||
|
@ -345,7 +345,9 @@ class TestPaternMatcher(TestCase):
|
||||
v = torch.randn(1, 3, 28, 28)
|
||||
self._test_common(mod, (v,), 0, 0)
|
||||
|
||||
def test_conv2d_binary_inplace_fusion_pass(self):
|
||||
def test_conv2d_binary_inplace_fusion_pass_cpu(
|
||||
self, include_ops=None, exclude_ops=None
|
||||
):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -362,8 +364,12 @@ class TestPaternMatcher(TestCase):
|
||||
torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
|
||||
]
|
||||
mod = Model().to(memory_format=torch.channels_last).eval()
|
||||
include_ops = ["mkldnn._convolution_pointwise_.binary"]
|
||||
exclude_ops = ["mkldnn._convolution_pointwise.binary"]
|
||||
|
||||
if include_ops is None:
|
||||
include_ops = ["mkldnn._convolution_pointwise_.binary"]
|
||||
if exclude_ops is None:
|
||||
exclude_ops = ["mkldnn._convolution_pointwise.binary"]
|
||||
|
||||
self._test_code_common(mod, inputs, include_ops, exclude_ops)
|
||||
|
||||
def test_conv2d_binary_inplace_fusion_failed_cpu(
|
||||
|
@ -295,7 +295,7 @@ class TestMkldnnFusion(JitTestCase):
|
||||
# for binary add, we support inplace version.
|
||||
if attr == "add":
|
||||
fused_inplace = torch.ops.mkldnn._convolution_pointwise_(
|
||||
x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
|
||||
other, x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
|
||||
mod.conv.groups, attr, None, unary_attr, [], None
|
||||
)
|
||||
self.assertEqual(ref, other)
|
||||
|
@ -3689,21 +3689,50 @@ class ConvolutionBinary(ExternKernelAlloc):
|
||||
|
||||
|
||||
class ConvolutionBinaryInplace(ExternKernelAlloc):
|
||||
kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_layout,
|
||||
inputs,
|
||||
constant_args=(),
|
||||
kernel="torch.ops.mkldnn._convolution_pointwise_.binary",
|
||||
):
|
||||
super().__init__(kernel_layout, inputs, constant_args)
|
||||
self.kernel = kernel
|
||||
# Due to constrain of op.call, other (Tensor&) should be at input[0]
|
||||
reordered_inputs = [inputs[1], inputs[0]] + inputs[2:]
|
||||
|
||||
super().__init__(
|
||||
kernel_layout,
|
||||
reordered_inputs,
|
||||
constant_args,
|
||||
None,
|
||||
kernel="torch.ops.mkldnn._convolution_pointwise_.binary",
|
||||
cpp_kernel="mkldnn::_convolution_pointwise_",
|
||||
)
|
||||
self.cpp_kernel_overlad_name = "binary"
|
||||
self.cpp_kernel_key = "convolution_pointwise_binary_"
|
||||
# TODO: op.call: input[0] should be at::Tensor&
|
||||
self.cpp_op_schema = """
|
||||
at::Tensor&(
|
||||
at::Tensor& other_t,
|
||||
const at::Tensor& input_t,
|
||||
const at::Tensor& weight_t,
|
||||
const c10::optional<at::Tensor>& bias_opt,
|
||||
at::IntArrayRef padding,
|
||||
at::IntArrayRef stride,
|
||||
at::IntArrayRef dilation,
|
||||
int64_t groups,
|
||||
c10::string_view binary_attr,
|
||||
c10::optional<at::Scalar> alpha,
|
||||
c10::optional<c10::string_view> unary_attr,
|
||||
torch::List<c10::optional<at::Scalar>> unary_scalars,
|
||||
c10::optional<c10::string_view> unary_algorithm)"""
|
||||
|
||||
def codegen(self, wrapper):
|
||||
wrapper.writeline(
|
||||
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
|
||||
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
|
||||
self.get_name(),
|
||||
self.kernel,
|
||||
self.codegen_args(),
|
||||
self.cpp_op_schema,
|
||||
self.cpp_kernel_key,
|
||||
self.cpp_kernel_overlad_name,
|
||||
)
|
||||
|
||||
def get_mutation_names(self):
|
||||
@ -3727,7 +3756,6 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
|
||||
unary_scalars: Optional[List[Any]],
|
||||
unary_algorithm: Optional[str],
|
||||
):
|
||||
kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
|
||||
(
|
||||
inputs,
|
||||
constant_args,
|
||||
@ -3738,18 +3766,20 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
|
||||
)
|
||||
other = cls.require_stride_order(other, req_stride_order)
|
||||
inputs.insert(1, other)
|
||||
optional_scalar = OptionalScalar()
|
||||
optional_string = OptionalString()
|
||||
optional_list = OptionalList()
|
||||
constant_args = constant_args + [
|
||||
binary_attr,
|
||||
binary_alpha,
|
||||
unary_attr,
|
||||
unary_scalars,
|
||||
unary_algorithm,
|
||||
may_convert_to_optional(optional_scalar, binary_alpha),
|
||||
may_convert_to_optional(optional_string, unary_attr),
|
||||
may_convert_to_optional(optional_list, unary_scalars),
|
||||
may_convert_to_optional(optional_string, unary_algorithm),
|
||||
]
|
||||
return ConvolutionBinaryInplace(
|
||||
kernel_layout=MutationLayout(inputs[1]),
|
||||
inputs=inputs,
|
||||
constant_args=constant_args,
|
||||
kernel=kernel,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user