Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-11 22:34:53 +08:00
Add support for capturing tensors with score_mod (#124444)
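The example below defines an ALiBi-style `score_mod` that reads from a captured `head_scale` tensor, compiles `_templated_attention`, checks the compiled output against the eager implementation, and benchmarks it against `F.scaled_dot_product_attention` with and without an explicit attention bias: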
```
import torch
import torch.nn.functional as F
from torch.nn.attention._templated_attention import _templated_attention as templated_attention
from triton.testing import do_bench

index = torch.ops.aten.index
torch.manual_seed(0)

B = 16
H = 16
S = 2048
D = 64

# Captured tensor: one ALiBi slope per head, read inside score_mod.
head_scale = torch.randn(H, device="cuda")

def alibi(score, batch, head, token_q, token_kv):
    return score + index(head_scale, [head]) * (token_q - token_kv)

bias = torch.randn(H, S, S, dtype=torch.float16, device="cuda")
query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

compiled = torch.compile(templated_attention)
out = compiled(query, key, value, score_mod=alibi)
out2 = templated_attention(query, key, value, score_mod=alibi)
print((out - out2).abs().mean())
assert (out - out2).abs().mean() < 1e-3

print("Flash (no mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value)))
print("Flash (mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value, attn_mask=bias)))
print("flexattention: ", do_bench(lambda: compiled(query, key, value, score_mod=alibi)))
```
![Benchmark results](https://github.com/pytorch/pytorch/assets/6355099/18c175d0-2720-4dfd-8747-85b8a8f609f5)
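Captured tensors are picked up from the `score_mod` closure rather than passed explicitly at the call site. A minimal sketch of the same pattern, mirroring the relative-bias test added below (the `rel_bias` and `rel_bias_mod` names are illustrative, not part of this PR):
```
import torch
from torch.nn.attention._templated_attention import _templated_attention as templated_attention

S = 2048
# Captured buffer: a relative-position bias table, indexed inside score_mod.
rel_bias = torch.randn(2 * S, device="cuda", dtype=torch.float16)

def rel_bias_mod(score, b, h, q, kv):
    # (q - kv) + S shifts the relative offset into the valid [0, 2*S) range.
    return score + torch.ops.aten.index(rel_bias, [(q - kv) + S])

q, k, v = (torch.randn(4, 8, S, 64, device="cuda", dtype=torch.float16) for _ in range(3))
out = torch.compile(templated_attention)(q, k, v, score_mod=rel_bias_mod)
```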
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124444
Approved by: https://github.com/jansel, https://github.com/drisspg
@@ -4,7 +4,7 @@ import functools
 from collections import namedtuple
 from typing import Callable
 
-from unittest import expectedFailure, skipUnless
+from unittest import skip, skipUnless
 from unittest.mock import patch
 
 import torch
@@ -28,6 +28,8 @@ supported_platform = skipUnless(
 Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
 torch.set_float32_matmul_precision("high")
 
+index = torch.ops.aten.index
+
 
 def create_attention(score_mod):
     return functools.partial(_templated_attention, score_mod=score_mod)
@@ -39,6 +41,8 @@ test_dtypes = (
     else [torch.float16, torch.float32]
 )
 
+test_dtypes_fast = [torch.float16]
+
 # TODO float16 was causing ERRORs for tests on ROCm
 # See https://github.com/pytorch/pytorch/issues/123531
 if common_utils.TEST_WITH_ROCM:
@@ -53,13 +57,19 @@ def _causal_mod(score, b, h, token_q, token_kv):
     return torch.where(token_q >= token_kv, score, float("-inf"))
 
 
+B = 4
+H = 8
+S = 2048
+D = 64
+
+
 class TestTemplatedSDPA(InductorTestCase):
     def run_test(self, score_mod: Callable, dtype: torch.dtype = torch.float16):
         sdpa_partial = create_attention(score_mod)
         compiled_sdpa = torch.compile(sdpa_partial)
-        q = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
-        k = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
-        v = torch.randn((4, 8, 2048, 64), dtype=dtype, device="cuda")
+        q = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
         golden_out = sdpa_partial(
             q.to(torch.float64), k.to(torch.float64), v.to(torch.float64)
         )
@@ -147,23 +157,116 @@ class TestTemplatedSDPA(InductorTestCase):
 
         self.run_test(composed_score_mod, dtype)
 
-    # TODO We are currently not capturing free variables in the closure correctly
-    @expectedFailure
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     def test_captured_buffers(self, dtype: torch.dtype):
-        head_offset = torch.rand(8, device="cuda", dtype=dtype)
+        head_offset = torch.rand(H, device="cuda", dtype=dtype)
 
         def score_mod(score, b, h, m, n):
-            return score + head_offset[h]
+            return score + index(head_offset, [h])
 
         self.run_test(score_mod, dtype)
 
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_seq_masking(self, dtype):
+        seq_idx = torch.zeros(S, device="cuda", dtype=torch.bool)
+        seq_idx[S // 2 :] = 1
+
+        def seq_mask_mod(score, b, h, q, kv):
+            return torch.where(
+                index(seq_idx, [q]) == index(seq_idx, [kv]), score, float("-inf")
+            )
+
+        self.run_test(seq_mask_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_from_bias_seq_only(self, dtype):
+        bias = torch.randn(S, S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(bias, [q, kv])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_from_bias_seq_batch(self, dtype):
+        bias = torch.randn(B, S, S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(bias, [b, q, kv])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_from_bias_head_seq_batch(self, dtype):
+        bias = torch.randn(B, H, S, S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(bias, [b, h, q, kv])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_load_rel_bias(self, dtype):
+        rel_bias = torch.randn(2 * S, device="cuda", dtype=dtype)
+
+        def bias_mod(score, b, h, q, kv):
+            return score + index(rel_bias, [(q - kv) + S])
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_dependent_causal_bidirectional(self, dtype):
+        num_bidirectional = torch.randint(0, S, (B,), device="cuda", dtype=torch.int32)
+
+        def bias_mod(score, b, h, q, kv):
+            causal_attention = q >= kv
+            cur_num_bidirectional = index(num_bidirectional, (b,))
+            bidirectional_attention_on_video = (q <= cur_num_bidirectional) & (
+                kv <= cur_num_bidirectional
+            )
+            return torch.where(
+                bidirectional_attention_on_video | causal_attention,
+                score,
+                -float("inf"),
+            )
+
+        self.run_test(bias_mod, dtype)
+
+    @supported_platform
+    @skip("Triton bug ")  # https://github.com/pytorch/pytorch/issues/124571
+    @common_utils.parametrize("dtype", test_dtypes)
+    def test_njt_causal(self, dtype):
+        offsets = torch.tensor(
+            [0, 1024, 1024 + 512, S], device="cuda", dtype=torch.int32
+        )
+        seq_idx = torch.zeros(S, device="cuda", dtype=torch.int32)
+        for idx in range(len(offsets) - 1):
+            seq_idx[offsets[idx] : offsets[idx + 1]] = idx
+
+        def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
+            def njt_score_mod(qk, b, h, q, kv):
+                q_nested = q - index(offsets, [index(seq_idx, [q])])
+                kv_nested = kv - index(offsets, [index(seq_idx, [kv])])
+                return orig_score_mod(qk, b, h, q_nested, kv_nested)
+
+            return njt_score_mod
+
+        causal_njt = create_njt_wrapper(_causal_mod, offsets, seq_idx)
+
+        self.run_test(causal_njt, dtype)
+
     @supported_platform
     def test_backwards_fails(self):
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
             dtype=torch.float32,
             device="cuda",
             requires_grad=True,
@@ -177,9 +280,9 @@ class TestTemplatedSDPA(InductorTestCase):
 
     @supported_platform
     def test_mixed_dtypes_fails(self):
-        query = torch.randn((1, 1, 2048, 64), dtype=torch.float32, device="cuda")
-        key = torch.randn((1, 1, 2048, 64), dtype=torch.float16, device="cuda")
-        value = torch.randn((1, 1, 2048, 64), dtype=torch.float16, device="cuda")
+        query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda")
+        key = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
+        value = torch.randn((1, 1, 1024, 64), dtype=torch.float16, device="cuda")
         with self.assertRaisesRegex(
             ValueError, "Expected query, key, and value to have the same dtype"
         ):
@@ -201,6 +304,21 @@ class TestTemplatedSDPA(InductorTestCase):
 
         self.run_test(score_mod)
 
+    @supported_platform
+    @patch.object(torch._inductor.config, "max_autotune", True)
+    def test_max_autotune_with_captured(self):
+        head_scale = torch.randn(H, device="cuda")
+        batch_scale = torch.randn(B, device="cuda")
+        tok_scale = torch.randn(S, device="cuda")
+
+        def bias_mod(score, batch, head, token_q, token_kv):
+            score = score + index(tok_scale, [token_q])
+            score = score + index(batch_scale, [batch])
+            score = score + index(head_scale, [head])
+            return score
+
+        self.run_test(bias_mod)
+
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     @common_utils.parametrize("score_mod", [_identity_mod, _causal_mod])
@@ -211,7 +329,7 @@ class TestTemplatedSDPA(InductorTestCase):
 
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
             dtype=dtype,
             device="cuda",
             requires_grad=True,
@@ -253,7 +371,7 @@ class TestTemplatedSDPA(InductorTestCase):
     def test_logsumexp_only_return(self):
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
             dtype=torch.float32,
             device="cuda",
             requires_grad=True,
@@ -274,7 +392,7 @@ class TestTemplatedSDPA(InductorTestCase):
     def test_logsumexp_is_not_fused(self):
         make_tensor = functools.partial(
             torch.randn,
-            (4, 8, 2048, 64),
+            (B, H, S, D),
            dtype=torch.float32,
             device="cuda",
             requires_grad=True,
@@ -1535,12 +1535,10 @@ class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
     ) -> "VariableTracker":
         from .builder import wrap_fx_proxy
 
-        query, key, value, score_mod, *other_buffers = self.normalize_to_args(
-            args, kwargs
-        )
+        query, key, value, score_mod = self.normalize_to_args(args, kwargs)
 
         p_args, p_kwargs = self.create_wrapped_node(tx, query, score_mod)
-        proxied_args = [query, key, value, *other_buffers]
+        proxied_args = [query, key, value]
 
         # Store the invocation as a call
         # Norm_kwargs contains the score_function and we dont want to proxy this because
@@ -60,7 +60,7 @@ def math_attention(
     """
     assert len(other_buffers) == 0, "Other buffers are not yet supported."
 
-    scores = query @ key.transpose(-2, -1)
+    scores = (query @ key.transpose(-2, -1)).to(dtype=torch.float32)
 
     b = torch.arange(0, scores.size(0), device=scores.device)
     h = torch.arange(0, scores.size(1), device=scores.device)
@@ -179,9 +179,11 @@ def templated_attention_functionalize(
     assert isinstance(other_buffers_unwrapped, tuple)
     assert all(isinstance(item, torch.Tensor) for item in other_buffers_unwrapped)
 
-    example_vals = [torch.zeros((), dtype=query.dtype)] + [
-        torch.zeros((), dtype=torch.int) for _ in range(4)
-    ]
+    example_vals = (
+        [torch.zeros((), dtype=query.dtype)]
+        + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+        + list(other_buffers_unwrapped)
+    )
     with ctx.redispatch_to_next() as m:
         functional_score_mod = ctx.functionalize(score_mod)
         pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
@@ -3412,22 +3412,14 @@ class TritonScheduling(BaseScheduling):
             buffer_names.update(node.used_buffer_names())
 
         # Get buffers objects
-        def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]:
-            if name in V.graph.name_to_buffer:
-                return V.graph.name_to_buffer[name]
-            elif name in V.graph.graph_inputs:
-                return V.graph.graph_inputs[name]
-            elif name in V.graph.constants:
-                data = V.graph.constants[name]
-                return ir.ConstantBuffer(
-                    name,
-                    ir.FixedLayout(
-                        data.device, data.dtype, *V.graph.static_sizes_strides(data)
-                    ),
-                )
-            raise RuntimeError(f"Failed to find buffer matching name {name}")
-
-        buffers = [_get_buffer(name) for name in buffer_names]
+        def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]:
+            buf = V.graph.get_buffer(name)
+            if buf is None:
+                raise RuntimeError(f"Failed to find buffer matching name {name}")
+            return buf
+
+        buffers = [V.graph.get_buffer(name) for name in buffer_names]
 
         # In theory we can separately check xnumel and rnumel are <= int_max
         # but some indexers do use the full linear index so we need to be
@@ -660,6 +660,14 @@ class GraphLowering(torch.fx.Interpreter):
             return self.name_to_buffer[buffer_name]
         if buffer_name in self.graph_inputs:
             return self.graph_inputs[buffer_name]
+        if buffer_name in self.constants:
+            data = V.graph.constants[buffer_name]
+            return ir.ConstantBuffer(
+                buffer_name,
+                ir.FixedLayout(
+                    data.device, data.dtype, *V.graph.static_sizes_strides(data)
+                ),
+            )
         return None
 
     def get_dtype(self, buffer_name: str):
@@ -3,6 +3,7 @@ import logging
 from typing import Any, List
 
 import torch
+from .. import config
 from ..lowering import empty_strided, lowerings, register_lowering
 from ..select_algorithm import autotune_select_algorithm, TritonTemplate
 
@@ -114,12 +115,14 @@ sdpa_template = TritonTemplate(
     qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
     qk = tl.dot(q, k.to(MATMUL_PRECISION), acc=qk)
     # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
+    m = offs_m[:, None]
+    n = start_n + offs_n[None, :]
     {{ modification(
         score="qk",
         b="off_hz // H",
         h="off_hz % H",
-        m="offs_m[:, None]",
-        n="start_n + offs_n[None, :]",
+        m="m",
+        n="n",
         out="qk"
     ) | indent_except_first(2) }}
     # TODO: In the case that score_mod is linear, this can be LICMed
@@ -170,7 +173,8 @@ sdpa_template = TritonTemplate(
 )
 
 
-@register_lowering(torch.ops.higher_order.templated_attention)
+# TODO: We probably also need a layout constraint?
+@register_lowering(torch.ops.higher_order.templated_attention, type_promotion_kind=None)
 def templated_attention(*args, **kwargs):
     from torch._prims_common import make_contiguous_strides_for
     from ..ir import (
@@ -182,7 +186,7 @@ def templated_attention(*args, **kwargs):
         TensorBox,
     )
 
-    query, key, value, subgraph = args
+    query, key, value, subgraph, *other_buffers = args
 
     def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer:
         return TensorBox.create(
@@ -272,17 +276,22 @@ def templated_attention(*args, **kwargs):
     configs: List[Any] = []
     if query.get_dtype() == torch.float32:
         configs.append((64, 64, 4, 3))
-    configs += [
-        (128, 64, 4, 3),
-        (128, 128, 4, 3),
-        (128, 128, 8, 2),
-        (64, 128, 4, 3),
-    ]
-
+    else:
+        configs.append((128, 64, 4, 3))
+    if config.max_autotune:
+        configs += [
+            (128, 64, 4, 3),
+            (128, 128, 4, 3),
+            (128, 128, 8, 2),
+            (64, 128, 4, 3),
+        ]
+    # Note, we don't need to pass in the captured buffers explicitly
+    # because they're implicitly added by the score_mod function
+    # We do need to explicitly pass it in for autotuning though.
     for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
         sdpa_template.maybe_append_choice(
             choices=choices,
-            input_nodes=(query, key, value, logsumexp),
+            input_nodes=[query, key, value, logsumexp],
             layout=layout,
             subgraphs=subgraph_buffer,
             mutated_inputs=[
@@ -298,9 +307,10 @@ def templated_attention(*args, **kwargs):
             ROWS_GUARANTEED_SAFE=False,
             OUTPUT_LOGSUMEXP=True,
         )
+    inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers)
     return (
         autotune_select_algorithm(
-            "sdpa", choices, [query, key, value, logsumexp], layout
+            "sdpa", choices, inputs_for_autotuning, layout
         ),
         logsumexp,
     )
@@ -37,7 +37,14 @@ from .exc import CUDACompileError
 from .ir import ChoiceCaller, PrimitiveInfoType
 from .runtime.hints import DeviceProperties
 from .runtime.runtime_utils import do_bench
-from .utils import get_dtype_size, Placeholder, sympy_dot, sympy_product, unique
+from .utils import (
+    get_dtype_size,
+    Placeholder,
+    sympy_dot,
+    sympy_index_symbol,
+    sympy_product,
+    unique,
+)
 from .virtualized import V
 
 log = logging.getLogger(__name__)
@@ -269,20 +276,23 @@ class TritonTemplateKernel(TritonKernel):
         potential multiple modifications
         """
 
+        def add_input(name):
+            return self.args.input(name)
+
         class PlaceholderSubstitution(V.WrapperHandler):  # type: ignore[name-defined]
             self.name = "PlaceholderSubstitution"
 
             def load(self, name: str, index: sympy.Expr):
                 if name not in fixed_inputs:
-                    raise AssertionError(
-                        f"All loads should be coming from fixed inputs - {name}"
-                    )
+                    # If it's not a fixed input, it's a load from a captured
+                    # tensor
+                    var = add_input(name)
+                    return f"tl.load({var} + {index})"
+
                 return f"({fixed_inputs[name]})"
 
-            # TODO Doesn't work yet
             def indirect_indexing(self, index_var, size, check):
-                return self._inner.indirect_indexing(index_var, size, False)
-                # return sympy_symbol(str(index_var))
+                return sympy_index_symbol(str(index_var))
 
         # if self.modification_cache is None:
         with V.set_ops_handler(PlaceholderSubstitution(V.ops)):
@@ -589,16 +599,25 @@ class TritonTemplate(KernelTemplate):
             + "-"
         )
         mod = PyCodeCache.load(code, extra)
-        _, call_args, _ = kernel.args.python_argdefs()
 
-        expected_args = list(unique(x.get_name() for x in input_nodes))
-        expected_args.extend([fake_out.get_name()])
-        assert list(call_args)[: len(expected_args)] == expected_args, (
-            call_args,
-            expected_args,
+        input_call_args = tuple(kernel.args.input_buffers.keys())
+        output_call_args = tuple(kernel.args.output_buffers.keys())
+
+        # We expect the input_buffer order to be [*input_nodes, *captured_buffers]
+        expected_input_args = tuple(unique(x.get_name() for x in input_nodes))
+        expected_output_args = (fake_out.get_name(),)
+        assert input_call_args[: len(expected_input_args)] == expected_input_args, (
+            input_call_args,
+            expected_input_args,
+        )
+        assert output_call_args == expected_output_args, (
+            output_call_args,
+            expected_output_args,
         )
+
+        full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args])
         extra_args = V.graph.sizevars.size_hints(
-            map(sympy.expand, call_args[len(expected_args) :]),
+            map(sympy.expand, tuple(kernel.args.sizevars.keys())),
             fallback=config.unbacked_symint_fallback,
         )
 
@@ -636,13 +655,13 @@ class TritonTemplate(KernelTemplate):
             num_stages=num_stages,
             num_warps=num_warps,
             matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
-            input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
+            input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),
             output_tensor_meta=TensorMeta.from_irnodes(layout),
         )
 
         return TritonTemplateCaller(
             kernel_hash_name,
-            input_nodes,
+            full_input_nodes,
             layout,
             make_kernel_render,
             extra.strip("-").replace("-", ", "),