Compare commits

...

30 Commits

Author SHA1 Message Date
002b64bc31  Update [ghstack-poisoned]  2025-08-12 16:13:04 +00:00
2de0642e3b  Update [ghstack-poisoned]  2025-08-12 10:56:44 +00:00
7023af4cb8  Update [ghstack-poisoned]  2025-08-11 06:01:16 +00:00
ffdd74d656  Update (base update) [ghstack-poisoned]  2025-08-11 06:01:16 +00:00
eb0fb88520  Update [ghstack-poisoned]  2025-05-09 06:13:55 +00:00
40f9208084  Update (base update) [ghstack-poisoned]  2025-05-09 06:13:55 +00:00
a401654df8  Update [ghstack-poisoned]  2025-04-15 22:54:47 -07:00
b43bcf040c  Update (base update) [ghstack-poisoned]  2025-04-15 22:54:47 -07:00
25e10dbd87  Update [ghstack-poisoned]  2025-04-07 18:55:56 -07:00
8204178267  Update (base update) [ghstack-poisoned]  2025-04-07 18:55:56 -07:00
e310e1658e  Update [ghstack-poisoned]  2025-04-07 00:54:36 -07:00
7c9e4c7732  Update (base update) [ghstack-poisoned]  2025-04-07 00:54:36 -07:00
0020cb0d70  Update [ghstack-poisoned]  2025-04-01 02:59:47 -07:00
b39492179d  Update [ghstack-poisoned]  2025-03-30 05:00:38 +00:00
a842a41788  Update (base update) [ghstack-poisoned]  2025-03-30 02:47:47 +00:00
9bd7a06242  Update [ghstack-poisoned]  2025-03-30 02:47:47 +00:00
c0cb65034a  Update (base update) [ghstack-poisoned]  2025-03-30 02:04:03 +00:00
3d3adf27fe  Update [ghstack-poisoned]  2025-03-30 02:04:03 +00:00
81cc05ee09  Update (base update) [ghstack-poisoned]  2025-03-18 19:28:27 -07:00
d44375a722  Update [ghstack-poisoned]  2025-03-18 19:28:27 -07:00
0700d70fff  Update (base update) [ghstack-poisoned]  2025-03-18 02:13:50 +00:00
1694a1c816  Update [ghstack-poisoned]  2025-03-18 02:13:50 +00:00
c2dacffa5e  Update (base update) [ghstack-poisoned]  2025-02-24 23:12:01 -08:00
aca33d6ef8  Update [ghstack-poisoned]  2025-02-24 23:12:01 -08:00
c97db2b49a  Update [ghstack-poisoned]  2025-02-19 10:13:20 +00:00
726b0e2e3b  Update [ghstack-poisoned]  2025-02-19 06:12:38 +00:00
8378acfea4  Update [ghstack-poisoned]  2025-02-18 02:46:29 -08:00
d921f396e8  Update [ghstack-poisoned]  2025-02-18 01:01:15 -08:00
eacc1d96db  Update (base update) [ghstack-poisoned]  2025-02-18 00:07:35 -08:00
ea0a35d4ce  Update [ghstack-poisoned]  2025-02-18 00:07:35 -08:00
11 changed files with 73 additions and 93 deletions

View File

@ -209,14 +209,6 @@ Tensor mkldnn_linear_pointwise(
std::string_view attr,
c10::List<std::optional<at::Scalar>> scalars,
std::optional<std::string_view> algorithm) {
auto aprop_kind = ideep::prop_kind::forward;
bool maybe_backward = GradMode::is_enabled() &&
(input_t.requires_grad() || weight_t.requires_grad() ||
(bias_opt.has_value() && bias_opt->defined() &&
bias_opt->requires_grad()));
if (!maybe_backward) {
aprop_kind = ideep::prop_kind::forward_inference;
}
auto input = input_t.contiguous();
auto input_size = input.sizes();
@ -228,14 +220,14 @@ Tensor mkldnn_linear_pointwise(
dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
output_size.push_back(weight_t.size(0));
output_size.push_back(weight_t.size(1));
auto output = at::empty(output_size, input.options());
if (output.sym_numel() == 0) {
return output;
}
if (dim != 2) {
std::vector<int64_t> output_size_reshaped = {input_reshaped.size(0),
weight_t.size(0)};
weight_t.size(1)};
output = output.reshape(output_size_reshaped);
}
@ -250,7 +242,7 @@ Tensor mkldnn_linear_pointwise(
std::optional<ideep::tensor> mkldnn_bias{std::nullopt};
if (bias.defined()) {
mkldnn_bias = itensor_from_tensor(bias);
mkldnn_bias = itensor_from_tensor(bias.reshape({1, weight_t.size(1)}));
}
const ideep::tensor w = itensor_from_tensor(weight_t);
@ -268,20 +260,22 @@ Tensor mkldnn_linear_pointwise(
op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32);
}
if (mkldnn_bias.has_value()) {
ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
ideep::matmul_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
mkldnn_input,
w,
mkldnn_bias.value(),
mkldnn_output,
op_attr,
aprop_kind);
1.0f,
1.0f,
op_attr);
} else {
ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
ideep::matmul_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
mkldnn_input,
w,
mkldnn_output,
op_attr,
aprop_kind);
1.0f,
1.0f,
op_attr);
}
if (dim != 2) {
@ -319,7 +313,7 @@ Tensor mkldnn_linear_pointwise_binary(
dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
output_size.push_back(weight_t.size(0));
output_size.push_back(weight_t.size(1));
auto output = at::empty(output_size, input.options());
if (output.sym_numel() == 0) {
return output;
@ -329,7 +323,7 @@ Tensor mkldnn_linear_pointwise_binary(
if (dim != 2) {
std::vector<int64_t> output_size_reshaped = {
input_reshaped.size(0), weight_t.size(0)};
input_reshaped.size(0), weight_t.size(1)};
output = output.reshape(output_size_reshaped);
other_reshaped = other_reshaped.reshape(output_size_reshaped);
TORCH_CHECK(
@ -348,13 +342,12 @@ Tensor mkldnn_linear_pointwise_binary(
std::optional<ideep::tensor> mkldnn_bias{std::nullopt};
if (bias.defined()) {
mkldnn_bias = itensor_from_tensor(bias);
mkldnn_bias = itensor_from_tensor(bias.reshape({1, weight_t.size(1)}));
}
const ideep::tensor w = itensor_from_tensor(weight_t);
auto other_desc = mkldnn_other.get_desc();
auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc);
auto aprop_kind = ideep::prop_kind::forward_inference;
if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
@ -365,17 +358,17 @@ Tensor mkldnn_linear_pointwise_binary(
}
if (mkldnn_bias.has_value()) {
ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
ideep::matmul_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
mkldnn_input,
mkldnn_other,
w,
mkldnn_bias.value(),
mkldnn_output,
op_attr,
aprop_kind);
1.0f,
op_attr);
} else {
ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
mkldnn_input, mkldnn_other, w, mkldnn_output, op_attr, aprop_kind);
ideep::matmul_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
mkldnn_input, mkldnn_other, w, mkldnn_output, 1.0f, op_attr);
}
if (dim != 2) {

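Taken together, the hunks above move mkldnn_linear_pointwise from ideep::inner_product_forward to ideep::matmul_forward: the prop_kind selection is dropped, the output channel count is read from weight_t.size(1), the bias is passed as a (1, out_features) row, and the matmul interface takes 1.0f scaling factors plus op_attr in place of a prop_kind. A minimal sketch of the numerical contract this implies, using only public PyTorch APIs (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 16)    # (batch, in_features)
    w = torch.randn(32, 16)   # nn.Linear weight layout: (out_features, in_features)
    b = torch.randn(32)

    # The matmul-based kernel consumes the weight in (in_features, out_features)
    # form and a bias broadcast as (1, out_features); numerically this is just:
    y_matmul = x @ w.t() + b.reshape(1, -1)

    # Reference semantics of the linear op it replaces:
    y_ref = F.linear(x, w, b)

    torch.testing.assert_close(y_matmul, y_ref)
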
View File

@ -260,8 +260,8 @@ static Tensor mkldnn_reorder_linear_weight(
const Tensor& self,
std::optional<int64_t> batch_size_opt) {
mkldnn_check_low_precision(self.scalar_type(), "mkldnn_reorder_linear_weight");
auto out_features = self.size(0);
auto in_features = self.size(1);
auto in_features = self.size(0);
auto out_features = self.size(1);
auto self_ = self.contiguous();
auto w = itensor_from_tensor(self_);
ideep::dims input_size;
@ -269,12 +269,11 @@ static Tensor mkldnn_reorder_linear_weight(
if (batch_size_opt.has_value()) {
input_size = {batch_size_opt.value(), in_features};
}
auto packed_desc = ideep::inner_product_forward::expected_weights_desc(
{out_features, in_features},
auto packed_desc = ideep::matmul_forward::expected_weights_desc(
{in_features, out_features},
input_size,
/* weight dtype */ dtype,
/* src dtype */ dtype,
ideep::prop_kind::forward_inference);
/* src dtype */ dtype);
ideep::tensor result;
result.init(packed_desc);
result.feed_from(w);

View File

@ -220,7 +220,7 @@ class TestMkldnnFusion(JitTestCase):
scalars = pointwise_info.scalars
algorithm = pointwise_info.algorithm
fused = torch.ops.mkldnn._linear_pointwise(
v, mod.linear.weight, mod.linear.bias, attr, scalars, algorithm
v, mod.linear.weight.t(), mod.linear.bias, attr, scalars, algorithm
)
self.assertEqual(ref, fused)
@ -338,7 +338,7 @@ class TestMkldnnFusion(JitTestCase):
ref = mod(v, other)
attr = pointwise_name
fused = torch.ops.mkldnn._linear_pointwise(
v, other, mod.linear.weight, mod.linear.bias, attr
v, other, mod.linear.weight.t(), mod.linear.bias, attr
)
self.assertEqual(ref, fused)

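As the updated tests show, callers now pass the transposed weight (mod.linear.weight.t()) straight to the fused op. A hedged sketch of the same call outside the test harness; it assumes a CPU build with MKLDNN enabled and this patch applied, and torch.ops.mkldnn._linear_pointwise is a private operator whose signature is taken from the test above:

    import torch

    linear = torch.nn.Linear(16, 32).eval()
    v = torch.randn(4, 16)

    with torch.no_grad():
        ref = torch.relu(linear(v))
        if torch.backends.mkldnn.is_available():
            # attr / scalars / algorithm values are illustrative: "relu",
            # no scalars, no algorithm string.
            fused = torch.ops.mkldnn._linear_pointwise(
                v, linear.weight.t(), linear.bias, "relu", [], ""
            )
            torch.testing.assert_close(ref, fused)
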
View File

@ -15,9 +15,10 @@ _DNNL_RUNTIME_OMP = {
"#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
"#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "/* undef DNNL_EXPERIMENTAL_UKERNEL */",
"#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
"#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#undef DNNL_EXPERIMENTAL_SPARSE",
"#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH",
"#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING",
"#cmakedefine DNNL_EXPERIMENTAL_LOGGING": "#undef DNNL_EXPERIMENTAL_LOGGING",
"#cmakedefine DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER": "#undef DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER",
"#cmakedefine DNNL_DISABLE_GPU_REF_KERNELS": "#undef DNNL_DISABLE_GPU_REF_KERNELS",
"#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
"#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0",
@ -49,8 +50,6 @@ _DNNL_RUNTIME_OMP = {
"#cmakedefine01 BUILD_AVX512": "#define BUILD_AVX512 0",
"#cmakedefine01 BUILD_AMX": "#define BUILD_AMX 0",
"#cmakedefine01 BUILD_PRIMITIVE_GPU_ISA_ALL": "#define BUILD_PRIMITIVE_GPU_ISA_ALL 1",
"#cmakedefine01 BUILD_GEN9": "#define BUILD_GEN9 0",
"#cmakedefine01 BUILD_GEN11": "#define BUILD_GEN11 0",
"#cmakedefine01 BUILD_XELP": "#define BUILD_XELP 0",
"#cmakedefine01 BUILD_XEHPG": "#define BUILD_XEHPG 0",
"#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0",
@ -70,8 +69,8 @@ template_rule(
out = "include/oneapi/dnnl/dnnl_version.h",
substitutions = {
"@DNNL_VERSION_MAJOR@": "3",
"@DNNL_VERSION_MINOR@": "7",
"@DNNL_VERSION_PATCH@": "1",
"@DNNL_VERSION_MINOR@": "9",
"@DNNL_VERSION_PATCH@": "0",
},
)
@ -86,7 +85,7 @@ template_rule(
name = "include_dnnl_version_hash",
src = "include/oneapi/dnnl/dnnl_version_hash.h.in",
out = "include/oneapi/dnnl/dnnl_version_hash.h",
substitutions = {"@DNNL_VERSION_HASH@": "8d263e693366ef8db40acc569cc7d8edf644556d",}
substitutions = {"@DNNL_VERSION_HASH@": "56e10537b8d046f9a3a7c971e48d394948150b4a",}
)
cc_library(
@ -99,6 +98,7 @@ cc_library(
"src/cpu/aarch64/**/*.cpp",
"src/cpu/rv64/**/*.cpp",
"src/cpu/sycl/**/*.cpp",
"src/cpu/ppc64/**/*.cpp",
]),
hdrs = glob([
"include/oneapi/dnnl/*.h",
@ -110,13 +110,16 @@ cc_library(
"src/cpu/**/**/*.h",
"src/common/*.hpp",
"src/common/**/**/*.h",
"src/common/ittnotify/jitprofiling.h",
"third_party/xbyak/*.h",
"third_party/ittnotify/jitprofiling.h",
"third_party/spdlog/**/*.h",
], exclude=[
"src/cpu/aarch64/**/*.hpp",
"src/cpu/aarch64/**/*.h",
"src/cpu/rv64/**/*.hpp",
"src/cpu/rv64/**/*.h",
"src/cpu/sycl/**/*.hpp",
"src/cpu/ppc64/**/*.hpp",
]) + [
"include/oneapi/dnnl/dnnl_config.h",
"include/oneapi/dnnl/dnnl_version.h",
@ -141,7 +144,7 @@ cc_library(
"src/",
"src/common/",
"src/cpu/",
"src/cpu/x64/xbyak/",
"third_party/",
],
visibility = ["//visibility:public"],
linkopts = [

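The Bazel changes track the oneDNN bump from 3.7.1 to 3.9.0: new version hash, extra cmakedefine knobs, GEN9/GEN11 ISA defines dropped, ppc64 sources and headers added to the glob exclusions, and xbyak/ittnotify/spdlog headers now taken from third_party/. A quick, build-agnostic way to confirm which oneDNN a rebuilt torch actually links (the exact wording of the version line varies between releases):

    import torch

    # Prints the build configuration; after rebuilding against the bumped
    # library this should report oneDNN/MKL-DNN 3.9.0 somewhere in the output.
    print(torch.__config__.show())
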
View File

@ -414,11 +414,13 @@ def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
if not isinstance(B, ir.TensorBox):
B = ir.TensorBox(B)
assert hasattr(X, "get_size")
B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
if len(B.get_size()) == 1:
B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
else:
assert isinstance(B, torch.Tensor)
assert isinstance(X, torch.Tensor)
B = B.expand(X.shape[0], B.shape[-1])
if len(B.shape) == 1:
B = B.expand(X.shape[0], B.shape[-1])
return B
@ -978,8 +980,6 @@ class CppGemmTemplate(CppTemplate):
view_size, view_stride, view_offset
)
if not trans_w:
return new_inputs, layout_or_out
X = new_inputs[0]
W = new_inputs[1]
B = new_inputs[2] if has_bias else None

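expand_bias now expands only a genuinely 1-D bias to (M, N); a bias that is already 2-D (such as the (1, out_features) row used elsewhere in this stack) is passed through unchanged and relies on ordinary broadcasting. A small sketch of the tensor-level behaviour the new branch encodes (names and shapes are illustrative):

    import torch

    x = torch.randn(8, 16)            # (M, K)
    bias_1d = torch.randn(32)         # (N,)  -> the branch that still expands
    bias_2d = bias_1d.reshape(1, 32)  # (1, N) -> left as-is, broadcasts on its own

    expanded = bias_1d.expand(x.shape[0], bias_1d.shape[-1])
    assert expanded.shape == (8, 32)
    assert torch.equal(bias_2d.expand(8, 32), expanded)
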
View File

@ -239,8 +239,6 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
layout_or_out: _U,
) -> tuple[list[_T], _U]:
new_inputs: list[_T] = list(inputs)
if not trans_w:
return new_inputs, layout_or_out
X = new_inputs[0]
for wgt_idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
new_input = new_inputs[wgt_idx]

View File

@ -45,9 +45,6 @@ if torch._C._has_mkldnn:
_conv_transpose_args = [Arg() for _ in range(11)]
class MkldnnDeviceOpBase:
def get_linear_transpose_weight(self, weight_node):
raise NotImplementedError
def pack_conv_weight(
self,
graph,
@ -58,9 +55,7 @@ if torch._C._has_mkldnn:
):
raise NotImplementedError
def pack_linear_weight(
self, graph, is_lp_weight, transpose_weight_node, batch_size
):
def pack_linear_weight(self, graph, is_lp_weight, weight_node, batch_size):
raise NotImplementedError
def pack_linear(
@ -69,13 +64,6 @@ if torch._C._has_mkldnn:
raise NotImplementedError
class CpuMkldnnDeviceOp(MkldnnDeviceOpBase):
def get_linear_transpose_weight(self, weight_node):
packed_weight_node = weight_node
assert packed_weight_node.target == mkldnn._reorder_linear_weight
transpose_weight_node = packed_weight_node.args[0]
assert transpose_weight_node.target == aten.permute.default
return transpose_weight_node
def pack_conv_weight(
self,
graph,
@ -94,12 +82,10 @@ if torch._C._has_mkldnn:
"call_function", packed_weight_op, args=packed_weight_inputs
)
def pack_linear_weight(
self, graph, is_lp_weight, transpose_weight_node, batch_size
):
def pack_linear_weight(self, graph, is_lp_weight, weight_node, batch_size):
# For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
packed_weight_inputs = (
transpose_weight_node,
weight_node,
batch_size.node.shape_env.size_hint(batch_size.node.expr)
if has_free_symbols(batch_size)
else batch_size,
@ -127,12 +113,12 @@ if torch._C._has_mkldnn:
self, graph, is_lp_weight, batch_size, input, packed_weight_node, bias
):
packed_linear_inputs: tuple[Any, ...] = (input, packed_weight_node)
transpose_weight_node = packed_weight_node.args[0]
weight_node = packed_weight_node.args[0]
if is_lp_weight or mkldnn._is_mkldnn_acl_supported() or V.aot_compilation:
packed_linear_inputs += (bias, "none", [], "")
packed_linear_op: Callable[..., Any] = mkldnn._linear_pointwise.default
else:
packed_linear_inputs += (transpose_weight_node, bias, batch_size)
packed_linear_inputs += (weight_node, bias, batch_size)
packed_linear_op = torch.ops.mkl._mkl_linear
return graph.create_node(
@ -1049,12 +1035,9 @@ if torch._C._has_mkldnn:
def is_linear_add_bias(match):
add_node = match.output_node()
linear_node = add_node.args[0]
device_type = add_node.meta.get("val").device.type
mkldnn_device_op = _get_mkldnn_device_op(device_type)
transpose_weight_node = mkldnn_device_op.get_linear_transpose_weight(
linear_node.args[1]
)
weight_meta = transpose_weight_node.args[0].meta.get("val")
packed_weight_node = linear_node.args[1]
assert packed_weight_node.target == mkldnn._reorder_linear_weight
weight_meta = packed_weight_node.args[0].meta.get("val")
bias_node = add_node.args[1]
if isinstance(bias_node, int):
# we only folding bias if it is a constant
@ -1444,9 +1427,6 @@ if torch._C._has_mkldnn:
device_type = input.meta.get("val").device.type
mkldnn_device_op = _get_mkldnn_device_op(device_type)
with graph.inserting_before(linear_node):
transpose_weight_node = graph.create_node(
"call_function", aten.permute.default, (weight, (1, 0))
)
weight_dtype = weight.meta.get("val").dtype
is_lp_weight = weight_dtype in (
torch.bfloat16,
@ -1464,8 +1444,19 @@ if torch._C._has_mkldnn:
assert compute_with_lp or mkldnn._is_mkldnn_acl_supported(), (
f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
)
weight_node = (
weight
if (
compute_with_lp
or mkldnn._is_mkldnn_acl_supported()
or V.aot_compilation
)
else graph.create_node(
"call_function", aten.permute.default, (weight, (1, 0))
)
)
packed_weight_node = mkldnn_device_op.pack_linear_weight(
graph, compute_with_lp, transpose_weight_node, batch_size
graph, compute_with_lp, weight_node, batch_size
)
packed_linear_node = mkldnn_device_op.pack_linear(
graph, compute_with_lp, batch_size, input, packed_weight_node, bias

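Net effect of the pattern-pass changes: get_linear_transpose_weight goes away, is_linear_add_bias reads the weight metadata directly off the _reorder_linear_weight node, and the prepack step only inserts an aten.permute when it falls back to torch.ops.mkl._mkl_linear. A schematic paraphrase of that last decision, not the actual pass code (the helper name and tensor arguments are made up for illustration, and the weight is taken here to be the (in_features, out_features) operand seen in the graph):

    import torch

    def choose_weight_for_prepack(weight, compute_with_lp, acl_supported, aot_compilation):
        # The oneDNN matmul prepack now takes the graph's (in_features, out_features)
        # weight unchanged; only the fp32 torch.ops.mkl._mkl_linear fallback still
        # wants the (out_features, in_features) layout, hence the permute.
        if compute_with_lp or acl_supported or aot_compilation:
            return weight
        return weight.permute(1, 0)

    w_graph = torch.randn(16, 32)  # (in_features, out_features)
    assert choose_weight_for_prepack(w_graph, True, False, False).shape == (16, 32)
    assert choose_weight_for_prepack(w_graph, False, False, False).shape == (32, 16)
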
View File

@ -869,7 +869,7 @@ class LinearUnary(ExternKernelAlloc):
w = cls.require_contiguous(cls.realize_input(w))
*m, _ic = x.get_size()
oc, _ic = w.get_size()
_ic, oc = w.get_size()
output_size = list(m) + [oc]
inputs = [x, w]
constant_args = [attr, scalars if scalars else [-1], algorithm]
@ -926,7 +926,7 @@ class LinearBinary(ExternKernelAlloc):
w = cls.require_contiguous(cls.realize_input(w))
*m, _ic = x.get_size()
oc, _ic = w.get_size()
_ic, oc = w.get_size()
output_size = list(m) + [oc]
inputs = [x, y, w]
constant_args = [attr]

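In LinearUnary and LinearBinary the output channel count is now read from the second dimension of the weight, i.e. the weight arriving at these IR nodes is (in_features, out_features). The resulting output-shape rule, spelled out on plain tensors (shapes illustrative):

    import torch

    x = torch.randn(2, 5, 16)   # (*m, in_features)
    w = torch.randn(16, 32)     # (in_features, out_features) under the new convention

    *m, _ic = x.shape
    _ic, oc = w.shape
    output_size = list(m) + [oc]
    assert output_size == [2, 5, 32]
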
View File

@ -145,11 +145,11 @@ def grouped_gemm_lowering(
b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b]
choices: list[ChoiceCaller] = []
*_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)
*_, layout, x, _ = mm_args(x, w[0], layout=layout)
kwargs = {
"has_bias": [bias is not None for bias in b],
"trans_w": True,
"trans_w": False,
"epilogue_creator": None,
"act_mapping": dict.fromkeys(range(num_gemm), x),
}
@ -344,9 +344,8 @@ def register_onednn_fusion_ops():
b = ir.ExternKernel.realize_input(b) # type: ignore[assignment]
choices: list[ChoiceCaller] = []
if config.max_autotune or config.max_autotune_gemm:
transposed_w = permute(w, [1, 0])
*_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
if use_cpp_gemm_template(layout, x, transposed_w):
*_, layout, x, w = mm_args(x, w, layout=layout)
if use_cpp_gemm_template(layout, x, w):
def epilogue_creator(buf):
return create_epilogue_with_attr(
@ -355,7 +354,7 @@ def register_onednn_fusion_ops():
kwargs = {
"has_bias": b is not None,
"trans_w": True,
"trans_w": False,
"epilogue_creator": (
None if attr == "none" else epilogue_creator
),
@ -409,18 +408,15 @@ def register_onednn_fusion_ops():
b = ir.ExternKernel.realize_input(b) # type: ignore[assignment]
choices: list[ChoiceCaller] = []
if config.max_autotune or config.max_autotune_gemm:
transposed_w = permute(w, [1, 0])
*_, layout, x, transposed_w, y = mm_args(
x, transposed_w, y, layout=layout
)
if use_cpp_gemm_template(layout, x, transposed_w):
*_, layout, x, w, y = mm_args(x, w, y, layout=layout)
if use_cpp_gemm_template(layout, x, w):
def epilogue_creator(buf):
return create_epilogue_with_attr(buf, attr, other=y)
kwargs = {
"has_bias": b is not None,
"trans_w": True,
"trans_w": False,
"epilogue_creator": epilogue_creator,
}

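These lowerings now hand the weight to mm_args untransposed and pass trans_w=False, so the C++ GEMM template sees W in (K, N) form directly. The path is only taken under GEMM autotuning; a hedged repro sketch (whether the C++ template is actually selected still depends on the build, dtype, and heuristics):

    import torch
    import torch._inductor.config as inductor_config

    # Illustrative toggles only: the lowerings in this diff are gated on
    # inductor's max-autotune flags (see the `config.max_autotune or
    # config.max_autotune_gemm` checks above).
    inductor_config.max_autotune = True
    inductor_config.max_autotune_gemm = True

    linear = torch.nn.Linear(16, 32).eval()
    compiled = torch.compile(linear)
    with torch.no_grad():
        out = compiled(torch.randn(8, 16))
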
View File

@ -2503,7 +2503,7 @@ if torch._C._has_mkldnn:
def meta_linear_pointwise_default(
input_tensor, weight, bias, attr, scalars, algorithm
):
return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))
return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[1]))
if torch._C.has_mkl:
_meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
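Finally, the meta kernel (used for fake-tensor shape inference) now reports the last output dimension as weight.shape[1]. The same rule on meta tensors, with illustrative shapes:

    import torch

    inp = torch.empty(3, 7, 16, device="meta")   # (*batch, in_features)
    w = torch.empty(16, 32, device="meta")       # (in_features, out_features)

    # Shape rule from the updated meta function: keep every leading input
    # dimension, append weight.shape[1].
    out = inp.new_empty((*inp.shape[:-1], w.shape[1]))
    assert out.shape == (3, 7, 32)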