[Inductor] Add decomposeK as an autotuning choice for mm (#150654)

Building on the addition of subgraphs as an Inductor autotuning choice (https://github.com/pytorch/pytorch/pull/149761) and on FP32 output support for PyTorch GEMMs with FP16/BF16 inputs (https://github.com/pytorch/pytorch/pull/150812), this PR adds decompose_k as an autotuning choice Inductor can use when generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.
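For context, decompose-K splits the K dimension of an `(M, K) @ (K, N)` matmul into `k_splits` batches, computes a batched matmul over those slices with an FP32 output, and reduces over the batch dimension. Below is a minimal sketch of the decomposition, mirroring the `decomposeK` helper added in this PR (the function name here is illustrative; the `out_dtype` argument relies on the bmm FP32-output support from https://github.com/pytorch/pytorch/pull/150812):

```python
import torch

def decompose_k_sketch(a: torch.Tensor, b: torch.Tensor, k_splits: int) -> torch.Tensor:
    # Illustrative restatement of the decomposition, not the Inductor code path.
    # a: (M, K), b: (K, N); K must be divisible by k_splits.
    m, k = a.shape
    n = b.shape[1]
    k_part = k // k_splits
    # (M, K) -> (k_splits, M, k_part) and (K, N) -> (k_splits, k_part, N)
    a_split = a.reshape(m, k_splits, k_part).permute(1, 0, 2)
    b_split = b.reshape(k_splits, k_part, n)
    # Batched matmul with an FP32 output, then reduce the partial products over the split dim
    partials = torch.bmm(a_split, b_split, out_dtype=torch.float32)
    return partials.sum(dim=0).to(a.dtype)
```

For FP16/BF16 inputs this matches `a @ b` up to accumulation-order differences; it is this subgraph (traced with `make_fx`) that gets benchmarked against the existing Aten and Triton mm choices.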

Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton templates as well, without adding significantly more compile time (e.g. via async compilation). Anecdotal evidence shows that Triton bmm usually outperforms aten bmm
* Add support for addmm
* Enable for inference and AOTI

Below are the results of running TritonBench on split-K shapes, comparing aten performance against pt2_triton, which now autotunes over decompose_k. On average we see a >10% speedup over aten, and for some shapes more than 3x the performance of the best previously available Triton mm:

<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />

TorchInductor Benchmark Dashboard:
<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />

We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.
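For reference, a minimal sketch of how the new choice gets exercised under max-autotune (mirroring the `test_max_autotune_decompose_k` test added below; the config knobs shown are a subset of the ones the test patches, and the shape is one of its parametrized sizes):

```python
import torch
import torch._inductor.config as inductor_config

# K is much larger than M and N, so the decompose_k choice is eligible during autotuning.
a = torch.randn(64, 177147, dtype=torch.float16, device="cuda")
b = torch.randn(177147, 64, dtype=torch.float16, device="cuda")

with inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON"):
    compiled_mm = torch.compile(lambda x, y: x @ y)
    out = compiled_mm(a, b)  # autotunes over Triton mm templates and decompose_k subgraph choices
```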

Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
Author: PaulZhang12
Date: 2025-05-02 14:00:23 -07:00
Committed by: PyTorch MergeBot
Commit: 84aa0985fb (parent 5e9682719f)
12 changed files with 306 additions and 21 deletions

View File

@@ -1194,6 +1194,95 @@ class TestMaxAutotune(TestCase):
actual = (opt_f(x, y), x.grad, linear.weight.grad, linear.bias.grad)
assert same(expect, actual, tol=1e-2), f"ref:\n{expect}\nact:\n{actual}"
@skipIfXpu
@unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
@unittest.skipIf(
config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
)
@parametrize("dynamic", (True, False))
@parametrize("sizes", ((32, 32, 32768), (64, 128, 200000), (64, 64, 177147)))
@config.patch(
max_autotune=True,
max_autotune_gemm_backends="TRITON",
autotune_fallback_to_aten=False,
)
def test_max_autotune_decompose_k(self, sizes, dynamic):
M, N, K = sizes
a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True)
b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True)
possible_splits = range(2, min(K // M, K // N) + 1)
divisors = {split for split in possible_splits if K % split == 0}
def check_divisors(code):
for kernel in code:
if "decompose_k" in kernel:
divisor_found = False
for divisor in divisors:
if f"{divisor}_split" in kernel:
divisor_found = True
break
self.assertTrue(
divisor_found,
f"Could not find a split in {divisors} in {kernel}",
)
compiled_func = torch.compile(lambda a, b: a @ b, dynamic=dynamic)
# We assume with the large k dim relative to m, n, decompose_k will be most performant
out, code = run_and_get_code(compiled_func, a, b)
if dynamic:
FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
"decompose_k"
).run(code[0])
else:
FileCheck().check("extern_kernels.bmm_dtype").check_regex(
"triton_.*_fused_0.run"
).check("decompose_k").run(code[0])
check_divisors(code)
torch.testing.assert_close(out, a @ b, atol=1e-2, rtol=1e-2)
# Test adding epilogue also equivalent to eager
compiled_func = torch.compile(lambda a, b: (a @ b).relu(), dynamic=dynamic)
out, code = run_and_get_code(compiled_func, a, b)
if dynamic:
FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
"decompose_k"
).run(code[0])
else:
FileCheck().check("extern_kernels.bmm_dtype").check_regex(
"triton_.*_fused_0.run"
).check("decompose_k").run(code[0])
check_divisors(code)
torch.testing.assert_close(
compiled_func(a, b), (a @ b).relu(), atol=1e-2, rtol=1e-2
)
# Test adding reinterpret view before subgraph
a = a.transpose(0, 1)
compiled_func = torch.compile(
lambda a, b: (a.transpose(0, 1) @ b).relu(), dynamic=dynamic
)
out, code = run_and_get_code(compiled_func, a, b)
if dynamic:
FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
"decompose_k"
).run(code[0])
else:
FileCheck().check("extern_kernels.bmm_dtype").check_regex(
"triton_.*_fused_0.run"
).check("decompose_k").run(code[0])
check_divisors(code)
torch.testing.assert_close(
compiled_func(a, b),
(a.transpose(0, 1) @ b).relu(),
atol=1e-2,
rtol=1e-2,
)
@instantiate_parametrized_tests
class TestMaxAutotuneRemoteCache(TestCase):

View File

@@ -521,8 +521,8 @@ class PadMMTest(TestCase):
return x @ y
args = [
torch.randn(2**4, 2**14 - 1, device="cuda", dtype=torch.float16),
torch.randn(2**14 - 1, 2**4, device="cuda", dtype=torch.float16),
torch.randn(2**4, 2**8 - 1, device="cuda", dtype=torch.float16),
torch.randn(2**8 - 1, 2**4, device="cuda", dtype=torch.float16),
]
counters.clear()
@@ -534,6 +534,7 @@ class PadMMTest(TestCase):
ret, code = run_and_get_code(opt_fn, *args)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
code = [c for c in code if "decompose_k" not in c]
# The mm kernel should use a template (because we set max_autotune_gemm_backends = TRITON).
# Its name should contain `mm` because `mm` was the original aten op where the mm came from.
FileCheck().check("def triton_tem_fused_mm").run(code[0])

View File

@@ -1,5 +1,6 @@
# Owner(s): ["module: inductor"]
import functools
import unittest
import torch
from torch._dispatch.python import enable_python_dispatcher
@@ -13,6 +14,7 @@ from torch._inductor.select_algorithm import (
)
from torch._inductor.test_case import run_tests, TestCase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import skipIfXpu, TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
@@ -26,6 +28,8 @@ class TestSubgraphChoice(TestCase):
layout=FixedLayout(torch.device(f"{GPU_TYPE}:0"), dtype=dtype, size=shape),
)
@skipIfXpu
@unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
def test_subgraph_decompose_k(self):
from torch._inductor.kernel.mm import aten_mm
from torch._inductor.kernel.mm_common import mm_args
@@ -46,7 +50,7 @@ class TestSubgraphChoice(TestCase):
B = k // kPartitions
a_reshaped = torch.permute(a.reshape(m, B, kPartitions), (1, 0, 2))
b_reshaped = b.reshape(B, kPartitions, n)
result = torch.bmm(a_reshaped, b_reshaped)
result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
result_fp32 = result.to(torch.float32)
reduced_buf = torch.sum(result_fp32, 0)
return reduced_buf.to(a.dtype)

View File

@@ -1,3 +1,4 @@
import itertools
import logging
from typing import Any, Callable
@@ -63,6 +64,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
bm_graph_lowering.run(*self.example_inputs)
mod = bm_graph_lowering.compile_to_module()
bm_func = mod.call
bm_func([*args])
return benchmarker.benchmark_gpu(lambda: bm_func([*args]))
@@ -76,6 +78,11 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
for arg in self.example_inputs
if isinstance(arg, torch.Tensor)
],
*[
str(arg.stride())
for arg in self.example_inputs
if isinstance(arg, torch.Tensor)
],
str(self.gm.graph),
]
)
@@ -111,6 +118,8 @@ class SubgraphTemplate(KernelTemplate):
optimal implementations for complex operations.
"""
index_counter = itertools.count()
def __init__(
self,
name: str,
@@ -123,7 +132,7 @@ class SubgraphTemplate(KernelTemplate):
name: The name of this template
graph: The FX graph
"""
self.name = name
self.name = f"{name}_{next(SubgraphTemplate.index_counter)}"
self.make_fx_graph = make_fx_graph
def generate( # type: ignore[override]

View File

@@ -2644,7 +2644,7 @@ class PythonWrapperCodegen(CodeGen):
if (
name in V.graph.removed_buffers
or name in self.allocated
or isinstance(buffer, ir.DonatedBuffer)
or isinstance(buffer, (ir.DonatedBuffer, ir.SubgraphBuffer))
):
return
self.allocated.add(name)

View File

@@ -265,6 +265,7 @@ def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor:
def bmm(
self: torch.Tensor,
batch2: torch.Tensor,
out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
# TODO: Re-enable for mps once our reductions are performant enough
# (https://github.com/pytorch/pytorch/issues/150121)
@@ -291,6 +292,7 @@ def addmm(
self: torch.Tensor,
mat1: torch.Tensor,
mat2: torch.Tensor,
out_dtype: Optional[torch.dtype] = None,
beta: torch.types.Number = 1,
alpha: torch.types.Number = 1,
) -> torch.Tensor:
@@ -319,6 +321,7 @@ def addmm(
def mm(
self: torch.Tensor,
input2: torch.Tensor,
out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
# Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
# todo: Look into why and fix it (hopefully)

View File

@@ -5720,6 +5720,7 @@ class ExternKernel(InputsKernel):
return
size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
wrapper.writeline(
f"assert_size_stride({self.get_name()}, {size}, {stride})"
)
@@ -6015,7 +6016,6 @@ class SubgraphBuffer(ExternKernel):
self.subgraph = V.graph.make_subgraph(
self.gm, self.example_inputs, subgraph_name
)
import torch._inductor.config as inductor_config
with V.set_graph_handler(self.subgraph):
@@ -6033,9 +6033,11 @@
self.graph = graph
self.name = graph.name
outer_inputs = [t.codegen_reference() for t in self.inputs]
wrapper.codegen_subgraph_with_flattened_outputs(
CodegenGraph(self.subgraph),
[*[buffer.get_name() for buffer in self.inputs]],
outer_inputs,
[self.name],
)

View File

@@ -121,13 +121,22 @@ bmm_template = TritonTemplate(
)
aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
aten_bmm_dtype = ExternKernelChoice(
torch.bmm,
"at::_bmm_out_dtype_cuda",
name="bmm_dtype",
op_overload=aten.bmm.dtype_out,
)
aten_baddbmm = ExternKernelChoice(
torch.baddbmm, "at::baddbmm_out", op_overload=aten.baddbmm.out
)
@L.register_lowering(aten.bmm)
def tuned_bmm(mat1, mat2, *, layout=None):
def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
"""
Lowering for autotuning aten.bmm with different backends (Aten, Triton, CUTLASS, etc.)
"""
if all(x.get_device().type == "cpu" for x in [mat1, mat2]):
# decompose to small ops when memory bound
if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1:
@@ -165,7 +174,9 @@ def tuned_bmm(mat1, mat2, *, layout=None):
meta_mat2 = V.graph.current_node.args[1]
mat2 = may_require_contiguous(mat2, meta_mat2)
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
m, n, k, layout, mat1, mat2 = mm_args(
mat1, mat2, layout=layout, out_dtype=out_dtype
)
# below is for getting an overview logging info of inductor mms
counters["aten_mm_info"][f"aten.bmm_{m}_{n}_{k}"] += 1
@@ -179,13 +190,21 @@ def tuned_bmm(mat1, mat2, *, layout=None):
layout,
)
if out_dtype:
assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA"
aten_func = aten_bmm_dtype.bind((mat1, mat2), layout, out_dtype=out_dtype)
else:
aten_func = aten_bmm.bind((mat1, mat2), layout)
# options to tune from
choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
choices = [aten_func] if use_aten_gemm_kernels() else []
device_type = ir.get_device_type(mat1)
bmm_configs = V.choices.get_base_mm_configs(device_type)
if use_triton_template(layout):
# TODO: add out_dtype support for Triton Template
assert out_dtype is None, "out_dtype is not supported for Triton"
for config in bmm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
@@ -200,7 +219,7 @@ def tuned_bmm(mat1, mat2, *, layout=None):
if batch_stride_largest and is_nonzero and use_cutlass_template(layout, m, n, k):
from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) # type: ignore[arg-type]
if use_cpp_bmm_template(layout, mat1, mat2):
from ..codegen.cpp_bmm_template import CppBmmTemplate

View File

@@ -3,6 +3,8 @@ import functools
import logging
from typing import Any, Optional
import sympy
import torch
from torch._dynamo.utils import counters
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
@@ -14,12 +16,14 @@ from torch._inductor.autoheuristic.autoheuristic_utils import (
)
from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.torch_version import TorchVersion
from .. import config as inductor_config, ir
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from ..ir import FlexibleLayout, is_triton
from ..codegen.subgraph import SubgraphTemplate
from ..ir import FlexibleLayout, ir_node_to_tensor, is_triton
from ..lowering import (
add_layout_constraint,
constrain_to_fx_strides,
@@ -33,11 +37,13 @@ from ..select_algorithm import (
TritonTemplate,
)
from ..utils import (
get_k_splits,
get_tma_workspace_arg,
use_aten_gemm_kernels,
use_ck_gemm_template,
use_cpp_gemm_template,
use_cutlass_template,
use_decompose_k_choice,
use_max_autotune,
use_triton_template,
use_triton_tma_template,
@@ -68,6 +74,7 @@ except ImportError:
log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
mm_template = TritonTemplate(
name="mm",
@@ -584,8 +591,25 @@ def check_supported_striding(mat_a, mat_b) -> None:
aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
def decomposeK(a, b, k_splits):
m = a.shape[0]
n = b.shape[1]
k = a.shape[1]
k_parts = k // k_splits
B = k_splits
a_reshaped = torch.permute(a.reshape(m, B, k_parts), (1, 0, 2))
b_reshaped = b.reshape(B, k_parts, n)
result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
reduced_buf = torch.sum(result, 0)
return reduced_buf.to(a.dtype)
@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
"""
Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)
"""
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
device_type = ir.get_device_type(mat1)
name = "mm"
@@ -620,7 +644,10 @@ def tuned_mm(mat1, mat2, *, layout=None):
if is_nonzero and use_triton_template(layout):
for config in mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
m,
n,
k,
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
):
mm_template.maybe_append_choice(
choices,
@@ -628,9 +655,13 @@ def tuned_mm(mat1, mat2, *, layout=None):
layout=layout,
**mm_options(config, m, n, k, layout),
)
if use_triton_tma_template(mat1, mat2):
for config in persistent_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
m,
n,
k,
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
):
persistent_tma_mm_template.maybe_append_choice(
choices,
@@ -643,6 +674,40 @@ def tuned_mm(mat1, mat2, *, layout=None):
**mm_options(config, m, n, k, layout),
**persistent_mm_options(mat1, mat2),
)
# Only do split-k optimization if K is much larger than m, n and m, n are small
if use_decompose_k_choice(m, n, k):
from torch._dispatch.python import enable_python_dispatcher
from ..decomposition import select_decomp_table
k_splits = get_k_splits(m, n, k)
for k_split in k_splits:
if not V.graph.sizevars.is_expr_static_and_true(
sympy.Eq(sympy.Mod(k, k_split), 0)
):
continue
with enable_python_dispatcher():
decompositions = select_decomp_table()
decompose_k_subgraph_template = SubgraphTemplate(
name=f"decompose_k_mm_{k_split}_split",
make_fx_graph=make_fx(
functools.partial(decomposeK, k_splits=k_split),
decompositions,
),
)
with V.fake_mode:
mat1_tensor = ir_node_to_tensor(mat1)
mat2_tensor = ir_node_to_tensor(mat2)
decompose_k_subgraph_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
layout=layout,
example_inputs=[mat1_tensor, mat2_tensor],
)
if is_nonzero and use_cutlass_template(layout, m, n, k):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
@@ -706,7 +771,6 @@ def tuned_mm(mat1, mat2, *, layout=None):
for k in inductor_config.external_matmul:
choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout))
if should_fallback_to_aten(choices):
return aten_mm.bind((mat1, mat2), aten_layout).output_node()

View File

@@ -50,6 +50,7 @@ from .codegen.common import (
WorkspaceZeroMode,
)
from .codegen.simd_kernel_features import SIMDKernelFeatures
from .codegen.subgraph import SubgraphChoiceCaller
from .codegen.triton import (
gen_common_triton_imports,
texpr,
@@ -2195,7 +2196,7 @@ class AlgorithmSelectorCache(PersistentCache):
def benchmark_choice_in_current_process(
choice: ChoiceCaller, autotune_args: AutotuneArgs
) -> float:
is_extern = isinstance(choice, ExternKernelCaller)
is_extern = isinstance(choice, (ExternKernelCaller, SubgraphChoiceCaller))
benchmark_tensors = autotune_args.get_benchmark_tensors(is_extern)
inpts, output = benchmark_tensors.unpack()
output.zero_()

View File

@@ -43,6 +43,7 @@ from typing_extensions import (
dataclass_transform,
ParamSpec,
Self,
TypeAlias,
TypeGuard,
)
from unittest import mock
@@ -1560,6 +1561,85 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
return res
decompose_k_threshold = 32
# To limit compile time
k_splits_limit = 5
# Hand-tuned
default_k_splits = [16, 32, 64, 128, 256]
_IntLike: TypeAlias = Union[int, sympy.Expr]
def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
from torch._inductor.virtualized import V
return (
V.graph.sizevars.is_expr_static_and_true(
sympy.And(
sympy.Ge(k, decompose_k_threshold * m),
sympy.Ge(k, decompose_k_threshold * n),
)
)
and not V.graph.aot_mode # TODO: Support AOTI for decomposeK
and not V.graph.cpp_wrapper
)
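# Illustrative note (not part of this commit): with decompose_k_threshold = 32,
# a (64, 64, 177147) mm qualifies since 177147 >= 32 * 64, while something like
# (1024, 1024, 4096) does not since 4096 < 32 * 1024; decompose_k is also skipped
# in AOT mode and under cpp_wrapper per the checks above.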
@functools.lru_cache(None)
def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
# If k is a sympy expression, we can't do any splitting
if isinstance(k, sympy.Expr) and not k.is_number:
return default_k_splits
if (isinstance(m, sympy.Expr) and not m.is_number) or (
isinstance(n, sympy.Expr) and not n.is_number
):
max_k_split = 256
else:
max_k_split = min(k // m, k // n)
min_k_split = 2
# Get all divisors of k, k has to be divisible by kPart
divisors = sympy.divisors(k)
divisors = [
divisor
for divisor in divisors
if divisor <= max_k_split and divisor >= min_k_split
]
pow_of_2_divisors, mul_of_32_divisors, rest_of_splits = [], [], []
for d in divisors:
kPart = k // d
# Smaller than 128 might not even fit in a single tile, BLOCK_K can be 128
if kPart < 128:
continue
# Power of 2 divisors are best performing, conform to hardware
if (kPart & kPart - 1) == 0 and kPart >= 128:
pow_of_2_divisors.append(d)
# Else check if creates a multiple of 32
elif kPart % 32 == 0:
mul_of_32_divisors.append(d)
# otherwise, take the smallest values
else:
rest_of_splits.append(d)
# If the # of power of 2 divisors are greater than k_splits_limit, return all
# This should be ok for compile time, all perfect squares between 128 and min(k / m, k / n)
# should never be a massive amount
if len(pow_of_2_divisors) >= k_splits_limit:
return pow_of_2_divisors
else:
best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits
# Otherwise, conform results to k_splits_limit
return best_splits[:k_splits_limit]
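# Illustrative note (not part of this commit): for the test shape M = N = 32, K = 32768,
# max_k_split = 32768 // 32 = 1024 and every divisor d in [2, 256] keeps
# kPart = K // d >= 128 as a power of two, so all eight power-of-two splits
# [2, 4, 8, 16, 32, 64, 128, 256] are returned (exceeding k_splits_limit is allowed
# for the power-of-two bucket). For K = 177147 = 3**11 with M = N = 64, no split
# gives a power-of-two or multiple-of-32 kPart, so the five smallest valid divisors
# [3, 9, 27, 81, 243] are returned instead.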
@functools.lru_cache(None)
def _rocm_native_device_arch_name(device: str) -> str:
return torch.cuda.get_device_properties(device).gcnArchName

View File

@@ -4298,7 +4298,7 @@ def meta_alias(self):
return self.view(self.shape)
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype=None):
torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
@@ -4316,10 +4316,18 @@ def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
)
# TODO: handle out
output = batch2.new_empty(output_size)
if out_dtype:
supported_out_dtype = (
batch1.dtype == torch.float16 or batch1.dtype == torch.bfloat16
) and out_dtype == torch.float32
torch._check(
out_dtype == batch1.dtype or supported_out_dtype,
lambda: "out_dtype only supported for torch.float32 output with float16/bfloat16 inputs or same as input dtypes",
)
output = batch2.new_empty(output_size).to(out_dtype)
else:
# TODO: handle out
output = batch2.new_empty(output_size)
if not is_bmm and self_baddbmm is not None:
torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
@@ -4336,6 +4344,11 @@ def meta_bmm(self, mat2):
return common_meta_baddbmm_bmm(self, mat2, True)
@register_meta(aten.bmm.dtype)
def meta_bmm_dtype(self, mat2, out_dtype):
return common_meta_baddbmm_bmm(self, mat2, True, out_dtype=out_dtype)
def div_rtn(x, y):
q = x // y
r = x % y