Compare commits

...

7 Commits

SHA1        Message                                   Date
eebd57f4fd  Update [ghstack-poisoned]                 2025-11-07 17:04:23 +00:00
280d77bd86  Update [ghstack-poisoned]                 2025-11-07 16:30:18 +00:00
54c90ae440  Update (base update) [ghstack-poisoned]   2025-11-07 16:30:18 +00:00
53dc8a0875  Update [ghstack-poisoned]                 2025-10-10 17:00:28 +00:00
9e119dd8c4  Update [ghstack-poisoned]                 2025-09-29 17:16:55 +00:00
93a6e99edc  Update (base update) [ghstack-poisoned]   2025-09-26 07:34:19 +00:00
d0892c7792  Update [ghstack-poisoned]                 2025-09-26 07:34:19 +00:00
2 changed files with 73 additions and 5 deletions


@@ -856,6 +856,51 @@ class TestPatternMatcher(TestPatternMatcherBase):
                 )
                 torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
 
+    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
+    def test_linear_input_non_contiguous_3D(self, device="cpu"):
+        self.device = device
+
+        class M(torch.nn.Module):
+            def __init__(self, bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 3, bias=bias)
+
+            def forward(self, x):
+                x = torch.reshape(torch.permute(x, (0, 2, 3, 1)), (2, 12, 4))
+                return self.linear(x)
+
+        dtypes = [torch.float]
+        if is_mkldnn_bf16_supported(self.device):
+            dtypes.append(torch.bfloat16)
+        if is_mkldnn_fp16_supported(self.device):
+            dtypes.append(torch.float16)
+        for dtype, bias in itertools.product(dtypes, [True, False]):
+            mod = M(bias).eval()
+            v = torch.randn(2, 4, 3, 4)
+            torch._dynamo.reset()
+            autocast_enabled = dtype in [torch.bfloat16, torch.float16]
+            with (
+                torch.no_grad(),
+                torch.autocast(
+                    device_type="cpu",
+                    enabled=autocast_enabled,
+                    dtype=dtype,
+                ),
+            ):
+                expected = mod(v)
+                actual, (source_code,) = run_and_get_code(
+                    torch.compile(mod, fullgraph=True),
+                    v,
+                )
+                self.assertIn(
+                    "torch.ops.mkldnn._linear_pointwise.default"
+                    if autocast_enabled
+                    else "torch.ops.mkl._mkl_linear.default",
+                    source_code,
+                )
+                torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
+
     @skipIfXpu(
         msg="Different with CPU, two linears will be concat on XPU for better performance"
     )

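Note (editor's sketch, not part of the diff): the new test builds its input with torch.permute followed by torch.reshape, which yields a non-contiguous 3D tensor. Such a linear cannot be flattened into a plain 2D mm without a copy, which is presumably why the matched node is a bmm rather than an mm in the fusion change below. A minimal standalone check of that layout, using only stock PyTorch:

import torch

x = torch.randn(2, 4, 3, 4)
y = torch.permute(x, (0, 2, 3, 1))            # reorders strides only, no copy
z = torch.reshape(y, (2, 12, 4))              # still 3D and still non-contiguous
print(y.is_contiguous(), z.is_contiguous())   # False False
linear = torch.nn.Linear(4, 3)
print(linear(z).shape)                        # torch.Size([2, 12, 3])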

@@ -1251,11 +1251,14 @@ if torch._C._has_mkldnn:
             and not torch._C.has_mkl
         ):
             return False
+        if not is_lp_weight and linear_node.target == aten.bmm.default:
+            return False
         for meta_value in [input_meta_value, weight_meta_value]:
             if (
                 meta_value is None
                 or meta_value.device.type != "cpu"
-                or meta_value.dim() != 2
+                or meta_value.dim()
+                != (3 if linear_node.target == aten.bmm.default else 2)
             ):
                 return False
         if weight_idx == 2:
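Note (editor's sketch, assumptions flagged): the relaxed check above accepts 3D input/weight metadata only when the matched node is aten.bmm.default. In that case the graph-level weight is a 3D tensor (wgt_expand_node in the hunk below), presumably the 2D linear weight broadcast along the batch dimension; if so, every batch slice aliases the same storage, and selecting index 0 along dim 0, which is what the aten.select.int node below does, recovers the original 2D weight. An eager-mode illustration with a hypothetical weight:

import torch

w2d = torch.randn(4, 3)                       # hypothetical 2D weight, (in, out) as the mm sees it
w3d = w2d.unsqueeze(0).expand(2, 4, 3)        # broadcast to the bmm batch, no data copy
print(torch.equal(w3d.select(0, 0), w2d))     # True: slice 0 is the original 2D weight
print(w3d.untyped_storage().data_ptr() == w2d.untyped_storage().data_ptr())  # True: shared storage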
@@ -1434,27 +1437,47 @@
         extra_check=_is_packable_linear,
         pass_number=1,
     )
+    @register_freezing_graph_pattern(
+        CallFunction(aten.bmm.default, Arg(), Arg()),
+        extra_check=_is_packable_linear,
+        pass_number=1,
+    )
     def linear(match, *args, **kwargs):
         graph = match.graph
         linear_node = match.output_node()
-        input = args[0] if linear_node.target is aten.mm.default else args[1]
+        input = (
+            args[0]
+            if linear_node.target in [aten.mm.default, aten.bmm.default]
+            else args[1]
+        )
         bias = (
             None
-            if linear_node.target is aten.mm.default
+            if linear_node.target in [aten.mm.default, aten.bmm.default]
             or (
                 linear_node.target is aten.addmm.default
                 and linear_node.kwargs.get("beta", 1.0) == 0.0
             )
             else args[0]
         )
-        weight = args[1] if linear_node.target is aten.mm.default else args[2]
+        if linear_node.target is aten.mm.default:
+            weight = args[1]
+            weight_dtype = weight.meta.get("val").dtype
+        elif linear_node.target is aten.addmm.default:
+            weight = args[2]
+            weight_dtype = weight.meta.get("val").dtype
+        else:
+            assert linear_node.target is aten.bmm.default
+            wgt_expand_node = args[1]
+            weight = graph.create_node(
+                "call_function", aten.select.int, (wgt_expand_node, 0, 0)
+            )
+            weight_dtype = wgt_expand_node.meta.get("val").dtype
         device_type = input.meta.get("val").device.type
         mkldnn_device_op = _get_mkldnn_device_op(device_type)
         with graph.inserting_before(linear_node):
             transpose_weight_node = graph.create_node(
                 "call_function", aten.permute.default, (weight, (1, 0))
             )
-            weight_dtype = weight.meta.get("val").dtype
             is_lp_weight = weight_dtype in (
                 torch.bfloat16,
                 torch.float16,