Compare commits

...

53 Commits

SHA1        Message                                    Date
3467d3e123  Update [ghstack-poisoned]                  2025-10-31 10:18:28 +00:00
f5ee26463a  Update (base update) [ghstack-poisoned]    2025-10-31 10:18:28 +00:00
ab62572b18  Update [ghstack-poisoned]                  2025-10-30 15:49:35 +00:00
f7f0cc0ace  Update (base update) [ghstack-poisoned]    2025-10-30 15:49:35 +00:00
59ca356557  Update [ghstack-poisoned]                  2025-10-30 14:29:17 +00:00
c3e8577183  Update (base update) [ghstack-poisoned]    2025-10-30 14:29:17 +00:00
35613ea658  Update [ghstack-poisoned]                  2025-10-30 13:05:51 +00:00
5be840ccc2  Update (base update) [ghstack-poisoned]    2025-10-30 13:05:51 +00:00
57b5d96fcd  Update [ghstack-poisoned]                  2025-10-29 15:42:53 +00:00
5d02965b7c  Update (base update) [ghstack-poisoned]    2025-10-29 15:42:53 +00:00
d220390880  Update [ghstack-poisoned]                  2025-10-29 14:33:46 +00:00
c6cfcf49e1  Update (base update) [ghstack-poisoned]    2025-10-29 14:33:46 +00:00
56c0ca21f0  Update [ghstack-poisoned]                  2025-10-29 13:16:06 +00:00
85b7edb52b  Update (base update) [ghstack-poisoned]    2025-10-29 13:16:06 +00:00
02fa1ad97a  Update [ghstack-poisoned]                  2025-10-29 10:50:25 +00:00
c2eb709432  Update (base update) [ghstack-poisoned]    2025-10-29 10:50:25 +00:00
c1e7268182  Update [ghstack-poisoned]                  2025-10-29 10:45:31 +00:00
acc92f8dc1  Update (base update) [ghstack-poisoned]    2025-10-29 10:45:31 +00:00
e50c1a04b7  Update [ghstack-poisoned]                  2025-10-28 16:04:05 +00:00
983443cd20  Update [ghstack-poisoned]                  2025-10-28 15:45:48 +00:00
b76d9cfc7f  Update (base update) [ghstack-poisoned]    2025-10-28 15:39:22 +00:00
d8c4903a3e  Update [ghstack-poisoned]                  2025-10-28 15:39:22 +00:00
7f855e5590  Update [ghstack-poisoned]                  2025-10-28 15:20:43 +00:00
7ba226eb14  Update [ghstack-poisoned]                  2025-10-28 15:12:09 +00:00
44bac1e070  Update (base update) [ghstack-poisoned]    2025-10-28 15:08:05 +00:00
24cdf875b8  Update [ghstack-poisoned]                  2025-10-28 15:08:05 +00:00
a9888afe19  Update [ghstack-poisoned]                  2025-10-28 14:49:44 +00:00
10df61b3c2  Update (base update) [ghstack-poisoned]    2025-10-28 14:07:59 +00:00
f7d934e8a7  Update [ghstack-poisoned]                  2025-10-28 14:07:59 +00:00
7af0937c58  Update (base update) [ghstack-poisoned]    2025-10-28 13:58:07 +00:00
b41d593878  Update [ghstack-poisoned]                  2025-10-28 13:58:07 +00:00
fcf212b2b7  Update (base update) [ghstack-poisoned]    2025-10-28 13:44:50 +00:00
8f71493b92  Update [ghstack-poisoned]                  2025-10-28 13:44:50 +00:00
a9117e9028  Update (base update) [ghstack-poisoned]    2025-10-28 12:02:21 +00:00
d7c68ae739  Update [ghstack-poisoned]                  2025-10-28 12:02:21 +00:00
172ff9f1d3  Update (base update) [ghstack-poisoned]    2025-10-28 11:48:26 +00:00
09ae386f48  Update [ghstack-poisoned]                  2025-10-28 11:48:26 +00:00
a137f705d2  Update (base update) [ghstack-poisoned]    2025-10-27 17:15:01 +00:00
9af5881598  Update [ghstack-poisoned]                  2025-10-27 17:15:01 +00:00
5123a3ad68  Update (base update) [ghstack-poisoned]    2025-10-27 15:12:20 +00:00
37da895a9b  Update [ghstack-poisoned]                  2025-10-27 15:12:20 +00:00
885d7b9f8d  Update [ghstack-poisoned]                  2025-10-27 12:38:27 +00:00
135a48757d  Update [ghstack-poisoned]                  2025-10-27 12:28:17 +00:00
072cef4b11  Update (base update) [ghstack-poisoned]    2025-10-27 12:04:22 +00:00
eaea290ced  Update [ghstack-poisoned]                  2025-10-27 12:04:22 +00:00
df91f285d6  Update [ghstack-poisoned]                  2025-10-27 11:48:10 +00:00
f69fad4130  Update (base update) [ghstack-poisoned]    2025-10-27 11:23:44 +00:00
3c20f6ba8d  Update [ghstack-poisoned]                  2025-10-27 11:23:44 +00:00
ae6f926ede  Update (base update) [ghstack-poisoned]    2025-10-24 16:55:03 +00:00
cec4bcda84  Update [ghstack-poisoned]                  2025-10-24 16:55:03 +00:00
37aa7f9c7e  Update [ghstack-poisoned]                  2025-10-24 14:27:27 +00:00
b4fffb32de  Update (base update) [ghstack-poisoned]    2025-10-24 14:08:43 +00:00
9f0c3473b0  Update [ghstack-poisoned]                  2025-10-24 14:08:43 +00:00
6 changed files with 117 additions and 15 deletions

View File

@@ -170,10 +170,14 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
#if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous()
// Conditions for bias to be fusable
&& (
self.is_contiguous() &&
// NOTE: fine to have 1-len dims to the left from the right-most one
(self.dim() == 1 || self.squeeze().dim() == 1) &&
self.sizes().back() == mat2_sizes[1]
)
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||
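
For reference, a minimal Python sketch of the relaxed bias condition above: the bias must be contiguous, effectively one-dimensional (length-1 leading dims are allowed), and its trailing dimension must match the output's column count. The helper name and standalone form are illustrative only; the real check is the C++ predicate isInputCompliesAddmmCudaLt, which additionally gates on beta, result layout, and dtype.

    import torch

    def bias_is_lt_fusable(bias: torch.Tensor, out_cols: int) -> bool:
        # Illustrative re-statement of the C++ condition above, not the real helper.
        # Contiguous, effectively 1-D (leading 1-len dims allowed), last dim == out_cols.
        return (
            bias.is_contiguous()
            and (bias.dim() == 1 or bias.squeeze().dim() == 1)
            and bias.shape[-1] == out_cols
        )

    # Bias shapes that now qualify (assuming the output has 25 columns):
    for shape in [(25,), (1, 25), (1, 1, 25)]:
        assert bias_is_lt_fusable(torch.randn(shape), 25)
    # A full matrix bias still does not:
    assert not bias_is_lt_fusable(torch.randn(10, 25), 25)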

View File

@@ -15270,7 +15270,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_native_layer_norm_relu",
"triton_poi_fused_addmm_native_layer_norm",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]
@@ -15283,7 +15283,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_LayerNorm_ReLU",
"triton_poi_fused_LayerNorm_Linear_ReLU",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]
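
The updated expected kernel names above reflect that the addmm (the Linear) now joins the fused epilogue. A rough way to reproduce this kind of check outside the test suite is sketched below; the fn3 stand-in is an assumption (its real definition is not shown here), the kernel-name string is taken from the diff, and a CUDA device is required.

    import torch
    from torch._inductor.utils import run_and_get_code

    # Hypothetical stand-in for the test's fn3: Linear -> LayerNorm -> ReLU.
    lin = torch.nn.Linear(4, 4, device="cuda")
    norm = torch.nn.LayerNorm(4, device="cuda")

    def fn3(x):
        return torch.relu(norm(lin(x)))

    _, codes = run_and_get_code(torch.compile(fn3), torch.randn(4, 4, device="cuda"))
    # With the addmm epilogue fused, the kernel name is expected to mention addmm:
    print(any("triton_poi_fused_addmm_native_layer_norm" in code for code in codes))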

View File

@@ -7328,9 +7328,11 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
m2 = torch.randn(50, 25, device=device).to(dtype)
self._test_addmm_addmv(func, M, m1, m2, activation=activation)
# vector-shaped bias and beta=1 result in epilogue fusion in CUDA
# vector-shaped bias (or with 1-len dims on the left from the leading dim)
# and beta=1 result in epilogue fusion in CUDA
V = torch.randn(25, device=device).to(dtype)
self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)
self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, activation=activation)
# Test 0-strided
M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
@@ -7357,8 +7359,9 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)
if t1:
# use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
# use vector/(1 by k)-shaped V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)
self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, transpose_out=t4, activation=activation,)
@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
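
The new V.unsqueeze(0) cases rely on addmm broadcasting a (1, n)-shaped bias exactly like an (n,)-shaped one. A quick CPU sanity check of that equivalence (illustrative, not part of the test):

    import torch
    from torch.testing import assert_close

    m1 = torch.randn(10, 50)
    m2 = torch.randn(50, 25)
    V = torch.randn(25)

    # (n,) bias and (1, n) bias broadcast to the same (10, 25) result.
    out_vec = torch.addmm(V, m1, m2, beta=1)
    out_row = torch.addmm(V.unsqueeze(0), m1, m2, beta=1)
    assert_close(out_vec, out_row)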

View File

@@ -5,6 +5,7 @@ import time
import unittest
from itertools import product
from functools import partial
from typing import Callable
import torch
@@ -90,14 +91,21 @@ class TestMatmulCuda(InductorTestCase):
torch.backends.cuda.matmul.allow_tf32 = True
super().tearDown()
def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False):
def cublas_addmm(
self,
size: int,
dtype: torch.dtype,
reduced_precision: bool = False,
fp16_accumulate: bool = False,
bias_shape_modifier: Callable | None = None,
):
#
# Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
# results from the CUDA invocation of torch.addmm and the CPU invocation
# (which does not use CUDA backend).
#
# Get dims
n, m, p = (size + 1, size, size + 2)
m, k, n = (size + 1, size, size + 2)
# Disable reduced precision reductions in BFloat16 to bypass some kernels
# which fail the threshold check
orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
@@ -109,10 +117,12 @@ class TestMatmulCuda(InductorTestCase):
# Make random tensors on CPU (seed set on common_utils.py import)
# (Not using numpy because it does not support bfloat16)
make_arg = partial(make_tensor, dtype=dtype, device="cpu")
bias_shape_modifier = (lambda shape: shape) if bias_shape_modifier is None else bias_shape_modifier
m_input = make_arg(bias_shape_modifier((m, n)))
m_1 = make_arg((m, k))
m_2 = make_arg((k, n))
m_beta = make_arg(1)
m_input = make_arg((n, p))
m_1 = make_arg((n, m))
m_2 = make_arg((m, p))
# scale to abate overflows in fp16 accum
if fp16_accumulate:
m_1 = m_1 / 100
@@ -179,6 +189,25 @@ class TestMatmulCuda(InductorTestCase):
        with blas_library_context(backend):
            self.cublas_addmm(size, dtype, True)

    @onlyCUDA
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=1e-4),
                        torch.bfloat16: xtol(atol=1e-3, rtol=1e-4),
                        torch.float32: xtol(atol=1e-3, rtol=1e-4)})
    @dtypes(torch.bfloat16, torch.float16, torch.float32)
    @parametrize("size", [128])
    @parametrize("backend", ["cublas", "cublaslt"])
    def test_cublas_addmm_bias_shapes(self, size: int, dtype: torch.dtype, backend):
        with blas_library_context(backend):
            # 2D bias
            self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: shape)
            # 1D bias which is row-broadcast to 2D
            self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (1, shape[-1]))
            # 1D bias which row-broadcasts
            self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (shape[-1],))

    @onlyCUDA
    @dtypes(torch.float16)
    # m == 4 chooses OUTPUT_TYPE reduction on H200
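
For context, a standalone sketch of the CPU-versus-CUDA deviation check that cublas_addmm performs, using one of the bias shapes the new test covers. Sizes follow the renamed (m, k, n) dims; the tolerances are illustrative rather than the test's toleranceOverride values, and a CUDA device is assumed.

    import torch
    from torch.testing import assert_close

    size = 128
    m, k, n = size + 1, size, size + 2
    dtype = torch.float32

    bias = torch.randn(1, n, dtype=dtype)   # (1, n) bias, one of the new cases
    m_1 = torch.randn(m, k, dtype=dtype)
    m_2 = torch.randn(k, n, dtype=dtype)

    ref = torch.addmm(bias, m_1, m_2)                              # CPU reference
    res = torch.addmm(bias.cuda(), m_1.cuda(), m_2.cuda()).cpu()   # CUDA result
    assert_close(ref, res, atol=1e-3, rtol=1e-4)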

View File

@@ -51,8 +51,8 @@ from ..utils import (
decode_device,
get_all_devices,
get_gpu_type,
has_uses_tagged_as,
is_gpu,
is_pointwise_use,
OPTIMUS_EXCLUDE_POST_GRAD,
)
from ..virtualized import V
@@ -1510,8 +1510,10 @@ def should_prefer_unfused_addmm(match):
if not is_gpu(inp.meta["val"].device.type):
return False
output = match.output_node()
return all(is_pointwise_use(use) for use in output.users)
return has_uses_tagged_as(
match.output_node(),
(torch.Tag.pointwise, torch.Tag.reduction),
)
@register_graph_pattern(
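
should_prefer_unfused_addmm now asks whether the addmm output's uses carry the pointwise or reduction tag, rather than relying on is_pointwise_use alone. Operator tags can be inspected directly on OpOverloads; a small illustration follows (which overloads carry which tags is defined by the installed PyTorch build's native_functions.yaml):

    import torch

    # relu is a pointwise-tagged op, so a pointwise check on its overload succeeds.
    print(torch.Tag.pointwise in torch.ops.aten.relu.default.tags)
    # Whether a reduction such as sum carries torch.Tag.reduction depends on how
    # that overload is tagged in the installed PyTorch version.
    print(torch.Tag.reduction in torch.ops.aten.sum.dim_IntList.tags)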

View File

@@ -553,6 +553,70 @@ def is_pointwise_use(
    return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)


class LogicalConnective(enum.Enum):
    OR = enum.auto()
    AND = enum.auto()


def has_uses(
    target: Node,
    use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
    """
    Given a target, explore the uses of `target` by applying `use_selector_fn`
    on them, and then aggregate these booleans with the `use_aggregate_type`
    logical connective.

    Uses in view ops will follow the views uses.
    """

    def get_use_aggregate_fn(
        use_aggregate_type: LogicalConnective,
    ) -> Callable[[Iterator[Any]], bool]:
        match use_aggregate_type:
            case LogicalConnective.AND:
                return all
            case LogicalConnective.OR:
                return any
            case _:
                return any

    use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)

    def has_uses_impl(use: Node) -> bool:
        if use.op != "call_function":
            return False
        if not (
            isinstance(use.target, torch._ops.OpOverload)
            or use.target is operator.getitem
        ):
            return False

        target = cast(torch._ops.OpOverload, use.target)
        # Process getitem and view
        if target is operator.getitem or is_view(target):
            return use_aggregate_fn(has_uses_impl(user) for user in use.users)

        return use_selector_fn(target)

    return use_aggregate_fn(has_uses_impl(user) for user in target.users)


def has_uses_tagged_as(
    target: Node,
    use_tags: Collection[torch.Tag],
    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
    """
    Is there a use with given tags?
    """
    return has_uses(
        target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
    )


def gen_gm_and_inputs(
    target: Any, args: list[Any], kwargs: dict[str, Any]
) -> tuple[GraphModule, list[torch.Tensor]]:
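
To make the behaviour of has_uses_tagged_as concrete, below is a simplified, self-contained sketch of the same use-walk over an FX graph traced with make_fx. Unlike the helper added above, it only inspects direct users and does not follow uses through views or operator.getitem; it is an illustration, not the new function itself.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def fn(b, x, w):
        y = torch.addmm(b, x, w)
        return torch.relu(y)

    # Trace down to aten OpOverloads so .tags is available on each call_function target.
    gm = make_fx(fn)(torch.randn(4), torch.randn(3, 5), torch.randn(5, 4))

    def users_tagged(node, tags):
        # any()-aggregation over direct users, mirroring LogicalConnective.OR.
        return any(
            user.op == "call_function"
            and isinstance(user.target, torch._ops.OpOverload)
            and any(t in user.target.tags for t in tags)
            for user in node.users
        )

    for node in gm.graph.nodes:
        if node.target == torch.ops.aten.addmm.default:
            # relu is pointwise-tagged, so the addmm node reports True here.
            print(users_tagged(node, (torch.Tag.pointwise, torch.Tag.reduction)))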