Compare commits

...

63 Commits

SHA1 / Message / Date (every commit body is the ghstack marker line "[ghstack-poisoned]"; the Author column was empty in this view)

70518444a8  Update                2025-10-31 10:18:29 +00:00
55a42d3a3e  Update (base update)  2025-10-31 10:18:29 +00:00
793a5e2f86  Update                2025-10-30 15:49:36 +00:00
e7a108c32b  Update (base update)  2025-10-30 15:49:36 +00:00
dd5a8d3fc8  Update                2025-10-30 15:44:01 +00:00
1d82e429d7  Update (base update)  2025-10-30 15:44:01 +00:00
03e312c85e  Update                2025-10-30 14:29:34 +00:00
2540eaff4d  Update (base update)  2025-10-30 14:29:34 +00:00
b1e44a9ff1  Update                2025-10-30 13:05:51 +00:00
7c585b11f9  Update (base update)  2025-10-30 13:05:51 +00:00
77330f39e4  Update                2025-10-29 14:33:47 +00:00
eafb84aebd  Update (base update)  2025-10-29 14:33:47 +00:00
bf5e9cc835  Update                2025-10-29 13:16:06 +00:00
83370ee71f  Update (base update)  2025-10-29 13:16:06 +00:00
8e2e74b12a  Update                2025-10-29 10:50:26 +00:00
dbb5565e17  Update (base update)  2025-10-29 10:50:26 +00:00
3a41807da9  Update                2025-10-29 10:45:32 +00:00
670828a5bb  Update (base update)  2025-10-29 10:45:32 +00:00
b0c4e5ce92  Update                2025-10-28 16:04:05 +00:00
23d74eb617  Update (base update)  2025-10-28 16:04:05 +00:00
3cc8b64300  Update                2025-10-28 15:45:49 +00:00
5282147127  Update (base update)  2025-10-28 15:45:49 +00:00
367f40a7e0  Update                2025-10-28 15:39:23 +00:00
c67c516653  Update (base update)  2025-10-28 15:39:23 +00:00
a451258d9c  Update                2025-10-28 15:20:44 +00:00
ef90141cbf  Update (base update)  2025-10-28 15:20:44 +00:00
5d84e13851  Update                2025-10-28 15:12:10 +00:00
ba63727c2e  Update (base update)  2025-10-28 15:12:10 +00:00
02a255ae8e  Update                2025-10-28 15:08:06 +00:00
329d47c055  Update (base update)  2025-10-28 15:08:06 +00:00
295a042e39  Update                2025-10-28 14:49:48 +00:00
82e7131068  Update (base update)  2025-10-28 14:49:48 +00:00
b61a11e8d9  Update                2025-10-28 14:07:59 +00:00
00a615d1e2  Update (base update)  2025-10-28 14:07:59 +00:00
9f582f55af  Update                2025-10-28 13:58:07 +00:00
c4fc3c53e1  Update (base update)  2025-10-28 13:58:07 +00:00
cd82b0f7d9  Update                2025-10-28 13:44:50 +00:00
139222da06  Update (base update)  2025-10-28 13:44:50 +00:00
2450d02e97  Update                2025-10-28 12:02:22 +00:00
0820b97e78  Update (base update)  2025-10-28 12:02:22 +00:00
b421538f59  Update                2025-10-28 11:48:27 +00:00
3efbfb3f6f  Update (base update)  2025-10-28 11:48:27 +00:00
2325197448  Update                2025-10-27 17:15:01 +00:00
a849ab3e44  Update (base update)  2025-10-27 17:15:01 +00:00
994fe49902  Update                2025-10-27 15:12:21 +00:00
bc85bf7ed1  Update (base update)  2025-10-27 15:12:21 +00:00
ba1fe373be  Update                2025-10-27 15:00:12 +00:00
28a754f37b  Update (base update)  2025-10-27 15:00:12 +00:00
cc7c1c81f6  Update                2025-10-27 12:38:28 +00:00
f78c4dee42  Update (base update)  2025-10-27 12:38:28 +00:00
2514f9d62f  Update                2025-10-27 12:28:17 +00:00
51122c815f  Update (base update)  2025-10-27 12:28:17 +00:00
d7fd08839f  Update                2025-10-27 12:04:23 +00:00
8f73b7cb35  Update (base update)  2025-10-27 12:04:23 +00:00
c587a960fb  Update                2025-10-27 11:48:10 +00:00
25d8411fb5  Update (base update)  2025-10-27 11:48:10 +00:00
eabecd05c5  Update                2025-10-27 11:41:21 +00:00
edca2c8698  Update (base update)  2025-10-27 11:28:48 +00:00
f2de9313f4  Update                2025-10-27 11:28:48 +00:00
3ac256e289  Update (base update)  2025-10-24 16:55:04 +00:00
94055d73d4  Update                2025-10-24 16:55:04 +00:00
dc09f97271  Update (base update)  2025-10-24 16:48:45 +00:00
841b8a27b8  Update                2025-10-24 16:48:45 +00:00
9 changed files with 182 additions and 45 deletions

View File

@@ -170,10 +170,14 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
 #if defined(CUDA_VERSION) || defined(USE_ROCM)
   const auto scalar_type = mat1.scalar_type();
   return (beta.toComplexDouble() == 1.0
-    // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
-    // is to use lt interface only when self is bias.
-    && self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
     && result.dim() == 2 && result.is_contiguous()
+    // Conditions for bias to be fusable
+    && (
+         self.is_contiguous() &&
+         // NOTE: fine to have 1-len dims to the left from the right-most one
+         (self.dim() == 1 || self.squeeze().dim() == 1) &&
+         self.sizes().back() == mat2_sizes[1]
+       )
     && ( // some dtype restrictions
 #ifndef USE_ROCM
       scalar_type == at::ScalarType::Double ||
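
For orientation, here is a minimal Python sketch of just the new bias-shape condition above (the real check additionally constrains result, dtypes, and the CUDA/ROCm build). The helper name bias_is_fusable is made up for illustration and is not part of the diff.

import torch

def bias_is_fusable(bias: torch.Tensor, mat2: torch.Tensor) -> bool:
    # Contiguous bias whose only non-unit dimension is the trailing one,
    # and that trailing dim matches the output's column count (mat2.size(1)).
    return (
        bias.is_contiguous()
        and (bias.dim() == 1 or bias.squeeze().dim() == 1)
        and bias.shape[-1] == mat2.size(1)
    )

mat2 = torch.randn(32, 64)
print(bias_is_fusable(torch.randn(64), mat2))      # True:  (n,) vector bias
print(bias_is_fusable(torch.randn(1, 64), mat2))   # True:  (1, n) row bias, newly accepted
print(bias_is_fusable(torch.randn(16, 64), mat2))  # False: full (m, n) bias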

View File

@@ -553,7 +553,7 @@ class TestPatternMatcher(TestCase):
                 torch.randn(16, 16, device=GPU_TYPE),
                 torch.randn(16, 16, device=GPU_TYPE),
                 torch.randn(16, 16, device=GPU_TYPE),
-                True,
+                False,
             ),
             (
                 torch.randn(8, device=GPU_TYPE),
@@ -687,17 +687,20 @@ class TestPatternMatcher(TestCase):
         FileCheck().check("call").check_not(".run").run(code[0])

     def test_cat_addmm(self):
-        def fn(a, b, c):
+        def fn(b1, b2, b3, mat1, mat2, mat3):
             return torch.cat(
                 [
-                    torch.addmm(a, b, c),
-                    torch.addmm(b, c, a),
-                    torch.addmm(c, a, b),
+                    torch.addmm(b1, mat1, mat2),
+                    torch.addmm(b2, mat1, mat3),
+                    torch.addmm(b3, mat2, mat3),
                 ],
                 1,
             )

         args = [
+            torch.randn(16, device=GPU_TYPE),
+            torch.randn(16, device=GPU_TYPE),
+            torch.randn(16, device=GPU_TYPE),
             torch.randn(16, 16, device=GPU_TYPE),
             torch.randn(16, 16, device=GPU_TYPE),
             torch.randn(16, 16, device=GPU_TYPE),
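
The reworked test now feeds each addmm a separate 1-D bias plus two 2-D matrices. A quick standalone illustration of the broadcasting being exercised (CPU tensors, shapes only; not part of the diff):

import torch

b1 = torch.randn(16)            # 1-D bias, broadcast across the rows of the product
mat1 = torch.randn(16, 16)
mat2 = torch.randn(16, 16)

out = torch.addmm(b1, mat1, mat2)    # shape (16, 16)
ref = torch.mm(mat1, mat2) + b1      # explicit broadcast, same result
print(out.shape, torch.allclose(out, ref, atol=1e-5))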

View File

@@ -15270,7 +15270,7 @@ if RUN_GPU:
                 ),
                 (
                     fn3,
-                    "triton_poi_fused_native_layer_norm_relu",
+                    "triton_poi_fused_addmm_native_layer_norm",
                     (torch.randn(4, 4, device=GPU_TYPE),),
                 ),
             ]
@@ -15283,7 +15283,7 @@ if RUN_GPU:
                 ),
                 (
                     fn3,
-                    "triton_poi_fused_LayerNorm_ReLU",
+                    "triton_poi_fused_LayerNorm_Linear_ReLU",
                     (torch.randn(4, 4, device=GPU_TYPE),),
                 ),
             ]

View File

@@ -7328,9 +7328,11 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         m2 = torch.randn(50, 25, device=device).to(dtype)
         self._test_addmm_addmv(func, M, m1, m2, activation=activation)

-        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
+        # vector-shaped bias (or with 1-len dims on the left from the leading dim)
+        # and beta=1 result in epilogue fusion in CUDA
         V = torch.randn(25, device=device).to(dtype)
         self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)
+        self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, activation=activation)

         # Test 0-strided
         M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
@@ -7357,8 +7359,9 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
                     self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)
                     if t1:
-                        # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
+                        # use vector/(1 by k)-shaped V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
                         self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)
+                        self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, transpose_out=t4, activation=activation,)

     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
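
The added assertions check that a (1, n) bias behaves exactly like the (n,) vector bias. A small standalone sketch of that equivalence, with shapes borrowed from the test and device/dtype handling omitted:

import torch

m1 = torch.randn(10, 50)
m2 = torch.randn(50, 25)
V = torch.randn(25)

out_vec = torch.addmm(V, m1, m2)                # (25,) bias
out_row = torch.addmm(V.unsqueeze(0), m1, m2)   # (1, 25) bias, row-broadcast
print(torch.allclose(out_vec, out_row))         # True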

View File

@ -5,6 +5,7 @@ import time
import unittest
from itertools import product
from functools import partial
from typing import Callable
import torch
@ -90,14 +91,21 @@ class TestMatmulCuda(InductorTestCase):
torch.backends.cuda.matmul.allow_tf32 = True
super().tearDown()
def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False):
def cublas_addmm(
self,
size: int,
dtype: torch.dtype,
reduced_precision: bool = False,
fp16_accumulate: bool = False,
bias_shape_modifier: Callable | None = None,
):
#
# Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
# results from the CUDA invocation of torch.addmm and the CPU invocation
# (which does not use CUDA backend).
#
# Get dims
n, m, p = (size + 1, size, size + 2)
m, k, n = (size + 1, size, size + 2)
# Disable reduced precision reductions in BFloat16 to bypass some kernels
# which fail the threshold check
orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
@ -109,10 +117,12 @@ class TestMatmulCuda(InductorTestCase):
# Make random tensors on CPU (seed set on common_utils.py import)
# (Not using numpy because it does not support bfloat16)
make_arg = partial(make_tensor, dtype=dtype, device="cpu")
bias_shape_modifier = (lambda shape: shape) if bias_shape_modifier is None else bias_shape_modifier
m_input = make_arg(bias_shape_modifier((m, n)))
m_1 = make_arg((m, k))
m_2 = make_arg((k, n))
m_beta = make_arg(1)
m_input = make_arg((n, p))
m_1 = make_arg((n, m))
m_2 = make_arg((m, p))
# scale to abate overflows in fp16 accum
if fp16_accumulate:
m_1 = m_1 / 100
@ -179,6 +189,25 @@ class TestMatmulCuda(InductorTestCase):
with blas_library_context(backend):
self.cublas_addmm(size, dtype, True)
@onlyCUDA
# imported 'tol' as 'xtol' to avoid aliasing in code above
@toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=1e-4),
torch.bfloat16: xtol(atol=1e-3, rtol=1e-4),
torch.float32: xtol(atol=1e-3, rtol=1e-4)})
@dtypes(torch.bfloat16, torch.float16, torch.float32)
@parametrize("size", [128])
@parametrize("backend", ["cublas", "cublaslt"])
def test_cublas_addmm_bias_shapes(self, size: int, dtype: torch.dtype, backend):
with blas_library_context(backend):
# 2D bias
self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: shape)
# 1D bias which is row-broadcast to 2D
self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (1, shape[-1]))
# 1D bias which row-broadcasts
self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (shape[-1],))
@onlyCUDA
@dtypes(torch.float16)
# m == 4 chooses OUTPUT_TYPE reduction on H200
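
To make the new parametrization concrete, the three bias_shape_modifier lambdas map the full (m, n) bias shape to the shapes below. The dims mirror the test's (m, k, n) = (size + 1, size, size + 2) with size = 128; this snippet is illustrative only:

m, k, n = 129, 128, 130
full = (m, n)                      # shape of a full 2-D bias

modifiers = [
    ("2D bias", lambda shape: shape),
    ("(1, n) row bias", lambda shape: (1, shape[-1])),
    ("1D bias", lambda shape: (shape[-1],)),
]
for name, modifier in modifiers:
    print(name, modifier(full))    # (129, 130), (1, 130), (130,)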

View File

@@ -1693,7 +1693,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
             )

         from torch._higher_order_ops.utils import _maybe_fake_tracing
-        from torch._inductor.utils import is_pointwise_use
+        from torch._inductor.utils import has_only_pointwise_uses

         with tx.fake_mode:
             sub_args_fake = [
@@ -1712,9 +1712,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
             for node in fx.graph.nodes:
                 # Check that the combine_fn is pointwise, if combine_mode='pointwise'
-                if not all(
-                    is_pointwise_use(use) or use.op == "output" for use in node.users
-                ):
+                if not has_only_pointwise_uses(node, select_output=True):
                     raise RuntimeError(
                         "For combine_mode='pointwise', the combine_fn needs to be pointwise"
                     )

View File

@@ -51,8 +51,8 @@ from ..utils import (
     decode_device,
     get_all_devices,
     get_gpu_type,
+    has_uses_tagged_as,
     is_gpu,
-    is_pointwise_use,
     OPTIMUS_EXCLUDE_POST_GRAD,
 )
 from ..virtualized import V
@@ -1505,13 +1505,39 @@ def view_to_reshape(gm):
             nd.target = torch.ops.aten.reshape.default


+# Relevant for addmm and (add + mm)/(mm + add)
+# Follows the dispatch logic for cuBLASLt at
+# aten/src/ATen/native/cuda/Blas.cpp::isInputCompliesAddmmCudaLt
+def _cublaslt_can_fuse_bias_epilogue(inp, mat1, mat2):
+    if config.max_autotune_gemm:
+        return False
+
+    # match the dispatch logic for cuBLASLT at aten/src/ATen/native/cuda/Blas.cpp
+    if not (inp.is_cuda and inp.dim() == 1 and inp.is_contiguous()):
+        return False
+
+    if not (mat1.dim() == 2 and mat2.dim() == 2):
+        return False
+
+    if inp.size(0) != mat2.size(1):
+        return False
+
+    if inp.dtype != mat1.dtype or inp.dtype != mat2.dtype:
+        return False
+
+    return True
+
+
 def should_prefer_unfused_addmm(match):
     inp = match.kwargs["inp"]

     if not is_gpu(inp.meta["val"].device.type):
         return False

-    output = match.output_node()
-    return all(is_pointwise_use(use) for use in output.users)
+    if has_uses_tagged_as(match.output_node(), (torch.Tag.pointwise, torch.Tag.reduction)):
+        return True
+    else:
+        args_val = (arg.meta["val"] for arg in (inp, *match.args))
+        return not _cublaslt_can_fuse_bias_epilogue(*args_val)


 @register_graph_pattern(

View File

@@ -80,9 +80,9 @@ from .ir import (
 from .utils import (
     ceildiv,
     decode_device,
+    has_only_pointwise_uses,
     is_dynamic,
     is_gpu,
-    is_pointwise_use,
     is_view,
     needs_fallback_due_to_atomic_add_limitations,
     pad_listlike,
@@ -1850,10 +1850,7 @@ def cat(inputs, dim=0):
         (len(inputs) <= config.max_pointwise_cat_inputs)
         and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
     ):
-        pointwise_uses = all(
-            is_pointwise_use(use, additional_pointwise_ops)
-            for use in V.current_node.users
-        )
+        pointwise_uses = has_only_pointwise_uses(V.current_node)
         # fuse in case we will be used in a pointwise node, and there are any inputs we
         # we can prevent materialization of.
         fuse_pointwise_use = (

View File

@@ -529,28 +529,105 @@ def is_view(op: torch._ops.OpOverload) -> bool:
     return any(a.alias_info is not None for a in op._schema.arguments)


-def is_pointwise_use(
-    use: Node,
-    is_pointwise_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
+class LogicalConnective(enum.Enum):
+    OR = enum.auto()
+    AND = enum.auto()
+
+
+def has_uses(
+    target: Node,
+    use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
+    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
+    *,
+    select_output: bool = False,
 ) -> bool:
     """
-    Do all uses of this op have torch.Tag.pointwise or return True for optional `is_pointwise_fn`
+    Given a target, explore the uses of `target` by applying `use_selector_fn`
+    on them, and then aggregate these booleans with the `use_aggregate_type`
+    logical connective.
+
+    Uses in view ops will follow the views uses.
+    """
+
+    def get_use_aggregate_fn(
+        use_aggregate_type: LogicalConnective,
+    ) -> Callable[[Iterator[Any]], bool]:
+        match use_aggregate_type:
+            case LogicalConnective.AND:
+                return all
+            case LogicalConnective.OR:
+                return any
+            case _:
+                return any
+
+    use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
+
+    def has_uses_impl(use: Node) -> bool:
+        if select_output and use.op == "output":
+            return True
+        if use.op != "call_function":
+            return False
+        if not (
+            isinstance(use.target, torch._ops.OpOverload)
+            or use.target is operator.getitem
+        ):
+            return False
+
+        target = cast(torch._ops.OpOverload, use.target)
+        # Process getitem and view
+        if target is operator.getitem or is_view(target):
+            return use_aggregate_fn(has_uses_impl(user) for user in use.users)
+
+        return use_selector_fn(target)
+
+    return use_aggregate_fn(has_uses_impl(user) for user in target.users)
+
+
+def has_only_uses(
+    target: Node,
+    use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
+    *,
+    select_output: bool = False,
+) -> bool:
+    return has_uses(target, use_selector_fn, LogicalConnective.AND, select_output=select_output)
+
+
+def has_uses_tagged_as(
+    target: Node,
+    use_tags: Collection[torch.Tag],
+    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
+    *,
+    select_output: bool = False,
+) -> bool:
+    """
+    Is there a use with given tags?
+    """
+    return has_uses(
+        target,
+        lambda use: any(tag in use_tags for tag in use.tags),
+        use_aggregate_type,
+        select_output=select_output,
+    )
+
+
+def has_only_pointwise_uses(
+    target: Node,
+    *,
+    select_output: bool = False,
+) -> bool:
+    """
+    Do all uses of target have torch.Tag.pointwise?

     Uses in views ops will follow the views uses
     """
-    if use.op != "call_function":
-        return False
-    if not (
-        isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
-    ):
-        return False
-
-    target = cast(torch._ops.OpOverload, use.target)
-    if target is operator.getitem or is_view(target):
-        return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)
-
-    return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
+    return has_uses_tagged_as(
+        target,
+        use_tags=(torch.Tag.pointwise,),
+        use_aggregate_type=LogicalConnective.AND,
+        select_output=select_output,
+    )


 def gen_gm_and_inputs(
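
To illustrate the intent of the new helpers, here is a small self-contained sketch on a traced FX graph. It re-implements the "are all uses pointwise?" question with a hand-rolled allowlist instead of torch.Tag.pointwise, so it runs on any recent PyTorch; the allowlist and the function name all_uses_pointwise are illustrative and not part of this diff.

import torch
import torch.fx as fx

def f(bias, m1, m2):
    y = torch.addmm(bias, m1, m2)   # the producer we inspect
    return torch.relu(y)            # its single, pointwise consumer

gm = fx.symbolic_trace(f)

POINTWISE = {torch.relu, torch.sigmoid, torch.add}   # stand-in for torch.Tag.pointwise

def all_uses_pointwise(node: fx.Node, *, select_output: bool = False) -> bool:
    # Mirrors the AND-aggregated has_only_pointwise_uses / select_output idea above:
    # every user must be a pointwise call (or the graph output, when selected).
    def ok(user: fx.Node) -> bool:
        if select_output and user.op == "output":
            return True
        return user.op == "call_function" and user.target in POINTWISE
    return all(ok(user) for user in node.users)

addmm_node = next(n for n in gm.graph.nodes if n.target is torch.addmm)
relu_node = next(n for n in gm.graph.nodes if n.target is torch.relu)
print(all_uses_pointwise(addmm_node))                          # True: only relu uses it
print(all_uses_pointwise(relu_node))                           # False: its use is the output
print(all_uses_pointwise(relu_node, select_output=True))       # True once output counts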