mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Re-open 90266] [inductor] weight prepack for _convolution_transpose_pointwise (#91955)
Re-open https://github.com/pytorch/pytorch/pull/90266 since earlier pr on that stack got reverted. Depend on internal ideep upgrade. [Update]: internal ideep upgrade issue is resolved in https://github.com/pytorch/pytorch/pull/92239. Pull Request resolved: https://github.com/pytorch/pytorch/pull/91955 Approved by: https://github.com/jgong5, https://github.com/desertfire
This commit is contained in:
committed by
PyTorch MergeBot
parent
cc49f5abd3
commit
bd4a5b400a
@ -350,8 +350,8 @@ class TestMkldnnFusion(JitTestCase):
|
||||
for pointwise_name, pointwise_info in self._unary_list().items():
|
||||
for dim in [2]:
|
||||
channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
|
||||
options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last])
|
||||
for bias, dilation, groups, memory_format in options:
|
||||
options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last], [False, True])
|
||||
for bias, dilation, groups, memory_format, prepack_weight in options:
|
||||
oC = 32 * groups
|
||||
iC = 3 * groups
|
||||
x_shape = (1, iC) + input_shapes[dim]
|
||||
@ -363,6 +363,21 @@ class TestMkldnnFusion(JitTestCase):
|
||||
attr = pointwise_info.attr
|
||||
scalars = pointwise_info.scalars
|
||||
algorithm = pointwise_info.algorithm
|
||||
|
||||
if prepack_weight:
|
||||
packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
|
||||
mod.conv_transpose.weight.to_mkldnn(),
|
||||
mod.conv_transpose.padding,
|
||||
mod.conv_transpose.output_padding,
|
||||
mod.conv_transpose.stride,
|
||||
mod.conv_transpose.dilation,
|
||||
mod.conv_transpose.groups,
|
||||
x.size())
|
||||
mod.conv_transpose.weight = torch.nn.Parameter(
|
||||
packed_weight,
|
||||
requires_grad=mod.conv_transpose.weight.requires_grad,
|
||||
)
|
||||
|
||||
fused = torch.ops.mkldnn._convolution_transpose_pointwise(
|
||||
x,
|
||||
mod.conv_transpose.weight,
|
||||
|
Reference in New Issue
Block a user