Compare commits

...

7 Commits

Author SHA1 Message Date
381da56181 up 2025-07-25 13:01:09 -07:00
8e3a250ea7 up 2025-07-25 13:01:09 -07:00
93af908459 up 2025-07-25 13:01:09 -07:00
a0198671a4 try fix AOTI again 2025-07-25 13:01:09 -07:00
242a92605e try fix again 2025-07-25 13:01:09 -07:00
367d694153 try fix AOTI 2025-07-25 13:01:09 -07:00
c62d3eb32e Add torch.compile support for torch.mm(out_dtype=...) 2025-07-25 13:01:09 -07:00
3 changed files with 84 additions and 6 deletions

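For context, the feature being wired up here is the new out_dtype keyword on torch.mm (per the commit "Add torch.compile support for torch.mm(out_dtype=...)"). A minimal eager-mode sketch, assuming a CUDA device as the test below requires; shapes and the asserted values are illustrative:

import torch

# bf16 inputs, fp32 output via the new out_dtype keyword (CUDA-only per the test below)
a = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")
b = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
out = torch.mm(a, b, out_dtype=torch.float32)
assert out.dtype == torch.float32 and out.shape == (64, 64)
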
View File

@@ -3778,6 +3778,57 @@ class CommonTemplate:
            check_lowp=False,
        )

    @skip_if_cpp_wrapper(
        "Unable to make AOTI fallback_ops.py generate aoti_torch_cuda__mm_dtype_out_cuda"
    )
    def test_mm_out_dtype(self):
        if self.device != "cuda":
            self.skipTest("out_dtype is only supported on CUDA")

        # Test bf16 -> fp32 (upcast to higher precision)
        def fn_bf16_to_fp32(a, b, bias):
            out = torch.mm(a, b, out_dtype=torch.float32)
            out = out + bias
            return out

        self.common(
            fn_bf16_to_fp32,
            (
                torch.randn(64, 128, dtype=torch.bfloat16),
                torch.randn(128, 64, dtype=torch.bfloat16),
                torch.randn(64, dtype=torch.float32),
            ),
            check_lowp=False,
        )

        # Test with different shapes
        self.common(
            fn_bf16_to_fp32,
            (
                torch.randn(1, 256, dtype=torch.bfloat16),
                torch.randn(256, 128, dtype=torch.bfloat16),
                torch.randn(128, dtype=torch.float32),
            ),
            check_lowp=False,
        )

        # Test float16 -> float32
        def fn_fp16_to_fp32(a, b, bias):
            out = torch.mm(a, b, out_dtype=torch.float32)
            out = out + bias
            return out

        if torch.cuda.get_device_capability()[0] >= 7:
            self.common(
                fn_fp16_to_fp32,
                (
                    torch.randn(32, 64, dtype=torch.float16),
                    torch.randn(64, 32, dtype=torch.float16),
                    torch.randn(32, dtype=torch.float32),
                ),
                check_lowp=False,
            )

    @skipIfPy312  # segfaults
    @skipCUDAIf(not SM80OrLater, "Requires sm80")
    def test_mixed_mm(self):

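self.common in the test above compiles the function with Inductor and checks it against eager execution. A simplified standalone equivalent of what test_mm_out_dtype exercises might look like this (a sketch; the tolerances are illustrative, not the ones used by the test harness):

import torch

def fn_bf16_to_fp32(a, b, bias):
    # mm on bf16 inputs with an fp32 result, followed by an fp32 bias add
    out = torch.mm(a, b, out_dtype=torch.float32)
    return out + bias

a = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")
b = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(64, dtype=torch.float32, device="cuda")

eager = fn_bf16_to_fp32(a, b, bias)
compiled = torch.compile(fn_bf16_to_fp32)(a, b, bias)
torch.testing.assert_close(compiled, eager, rtol=1e-3, atol=1e-3)
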
View File

@@ -570,6 +570,13 @@ def lazy_register_extern_choice(fn):

aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")

aten_mm_dtype = ExternKernelChoice(
    torch.mm,
    "at::_mm_dtype_out_cuda",
    name="mm_dtype",
    op_overload=aten.mm.dtype_out,
)

aten_addmm = ExternKernelChoice(
    torch.addmm, "at::addmm_out", op_overload=aten.addmm.default
)

@@ -661,11 +668,14 @@ def decomposeK(a, b, k_splits):

@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
@register_lowering(aten.mm.dtype, type_promotion_kind=None)
def tuned_mm(mat1, mat2, out_dtype=None, *, layout=None):
    """
    Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)
    """
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    m, n, k, layout, mat1, mat2 = mm_args(
        mat1, mat2, layout=layout, out_dtype=out_dtype
    )
    device_type = ir.get_device_type(mat1)
    name = "mm"

@@ -688,9 +698,13 @@ def tuned_mm(mat1, mat2, *, layout=None):
        )

    # options to tune from
    choices = (
        [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
    )
    if out_dtype:
        assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA"
        aten_func = aten_mm_dtype.bind((mat1, mat2), aten_layout, out_dtype=out_dtype)
    else:
        aten_func = aten_mm.bind((mat1, mat2), aten_layout)
    choices = [aten_func] if use_aten_gemm_kernels() else []

    static_shape, is_nonzero = _is_static_problem(layout)
    mm_configs = V.choices.get_base_mm_configs(device_type)

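The lowering change above routes the ATen fallback through the new aten.mm.dtype overload whenever out_dtype is passed. A quick sanity check of that overload against the keyword form, assuming the mm.dtype(Tensor, Tensor, ScalarType) signature implied by meta_mm_dtype below, could look like:

import torch

a = torch.randn(8, 16, dtype=torch.bfloat16, device="cuda")
b = torch.randn(16, 8, dtype=torch.bfloat16, device="cuda")

# The overload the lowering binds for the fallback kernel (assumed callable like this)
via_overload = torch.ops.aten.mm.dtype(a, b, torch.float32)
via_keyword = torch.mm(a, b, out_dtype=torch.float32)
torch.testing.assert_close(via_overload, via_keyword)
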
View File

@@ -2238,7 +2238,7 @@ def meta__fused_moving_avg_obs_fq_helper(
    return (torch.empty_like(self), mask)

@register_meta(aten.mm)
@register_meta([aten.mm.default, aten.mm.out])
@out_wrapper(exact_dtype=True)
def meta_mm(a, b):
    torch._check(a.dim() == 2, lambda: "a must be 2D")

@@ -2252,6 +2252,19 @@ def meta_mm(a, b):
    return a.new_empty(N, P)

@register_meta(aten.mm.dtype)
def meta_mm_dtype(a, b, out_dtype):
    torch._check(a.dim() == 2, lambda: "a must be 2D")
    torch._check(b.dim() == 2, lambda: "b must be 2D")
    N, M1 = a.shape
    M2, P = b.shape
    torch._check(
        M1 == M2,
        lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].",
    )
    return torch.empty(N, P, dtype=out_dtype, device=a.device)

def _compute_reduction_shape(self, dims, keepdim):
    if keepdim:
        return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
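
meta_mm_dtype only computes output metadata (shape, dtype, device), which is what torch.compile's fake tensors need during tracing. A rough illustration on the "meta" device, assuming the overload dispatches to the registration above:

import torch

a = torch.empty(64, 128, dtype=torch.bfloat16, device="meta")
b = torch.empty(128, 64, dtype=torch.bfloat16, device="meta")

# No real computation happens on meta tensors; only shape/dtype propagation
out = torch.ops.aten.mm.dtype(a, b, torch.float32)
print(out.shape, out.dtype)  # torch.Size([64, 64]) torch.float32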