Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Made FlexAttention rewrite getitem calls to use aten.index in score_mod (#124799)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124799
Approved by: https://github.com/drisspg
ghstack dependencies: #124444
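For context, a minimal sketch of the user-facing effect of this change (the buffer, shapes, and driver lines below are illustrative assumptions, not code from the PR): a score_mod can now use plain tensor indexing on captured buffers, and the getitem call is rewritten to aten.index for score_mod (by Dynamo, and by the TransformGetItemToIndex mode added in this diff for the non-Dynamo paths), so the tests no longer spell it as an explicit index(...) call.

import torch

# Hypothetical captured buffer, mirroring the updated tests in this diff.
bias = torch.randn(16, 16)

def bias_mod(score, b, h, q, kv):
    # Previously written as index(bias, [q, kv]); plain getitem is now
    # rewritten for score_mod to torch.ops.aten.index(bias, [q, kv]).
    return score + bias[q, kv]

# Indices arrive at score_mod as scalar tensors (e.g. under torch.vmap):
score = torch.tensor(0.0)
q, kv = torch.tensor(3), torch.tensor(5)
print(bias_mod(score, 0, 0, q, kv))  # same value as score + torch.ops.aten.index(bias, [q, kv])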
@@ -12,7 +12,7 @@ const char* get_cuda_check_suffix() noexcept {
  } else {
    return "\nCUDA kernel errors might be asynchronously reported at some"
           " other API call, so the stacktrace below might be incorrect."
           "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.";
           "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1";
  }
}

std::mutex* getFreeMutex() {

@@ -4,7 +4,7 @@ import functools
from collections import namedtuple
from typing import Callable

from unittest import skip, skipUnless
from unittest import expectedFailure, skip, skipUnless
from unittest.mock import patch

import torch
@@ -125,7 +125,7 @@ class TestTemplatedSDPA(InductorTestCase):
        head_offset = torch.rand(H, device="cuda", dtype=dtype)

        def score_mod(score, b, h, m, n):
            return score + index(head_offset, [h])
            return score + head_offset[h]

        self.run_test(score_mod, dtype)

@@ -136,9 +136,7 @@ class TestTemplatedSDPA(InductorTestCase):
        seq_idx[S // 2 :] = 1

        def seq_mask_mod(score, b, h, q, kv):
            return torch.where(
                index(seq_idx, [q]) == index(seq_idx, [kv]), score, float("-inf")
            )
            return torch.where(seq_idx[q] == seq_idx[kv], score, float("-inf"))

        self.run_test(seq_mask_mod, dtype)

@@ -148,7 +146,7 @@ class TestTemplatedSDPA(InductorTestCase):
        bias = torch.randn(S, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + index(bias, [q, kv])
            return score + bias[q, kv]

        self.run_test(bias_mod, dtype)

@@ -158,7 +156,7 @@ class TestTemplatedSDPA(InductorTestCase):
        bias = torch.randn(B, S, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + index(bias, [b, q, kv])
            return score + bias[b, q, kv]

        self.run_test(bias_mod, dtype)

@@ -168,7 +166,7 @@ class TestTemplatedSDPA(InductorTestCase):
        bias = torch.randn(B, H, S, S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + index(bias, [b, h, q, kv])
            return score + bias[b, h, q, kv]

        self.run_test(bias_mod, dtype)

@@ -178,7 +176,7 @@ class TestTemplatedSDPA(InductorTestCase):
        rel_bias = torch.randn(2 * S, device="cuda", dtype=dtype)

        def bias_mod(score, b, h, q, kv):
            return score + index(rel_bias, [(q - kv) + S])
            return score + rel_bias[(q - kv) + S]

        self.run_test(bias_mod, dtype)

@@ -189,7 +187,7 @@ class TestTemplatedSDPA(InductorTestCase):

        def bias_mod(score, b, h, q, kv):
            causal_attention = q >= kv
            cur_num_bidirectional = index(num_bidirectional, (b,))
            cur_num_bidirectional = num_bidirectional[b]
            bidirectional_attention_on_video = (q <= cur_num_bidirectional) & (
                kv <= cur_num_bidirectional
            )
@@ -201,6 +199,38 @@ class TestTemplatedSDPA(InductorTestCase):

        self.run_test(bias_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_natten_2d(self, dtype):
        H = 32
        W = S // H
        WINDOW = 3
        assert W * H == S

        def get_x_y(idx):
            # This should be a floor divide, but we don't support that properly
            return idx / W, idx % W

        def natten_mask(score, b, h, q, kv):
            q_x, q_y = get_x_y(q)
            kv_x, kv_y = get_x_y(kv)
            return torch.where(
                ((q_x - kv_x).abs() <= WINDOW) | ((q_y - kv_y).abs() <= WINDOW),
                score,
                float("-inf"),
            )

        self.run_test(natten_mask, dtype)

    @supported_platform
    @expectedFailure
    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_silu_on_score(self, dtype):
        def silu_score(score, b, h, q, kv):
            return torch.nn.functional.silu(score)

        self.run_test(silu_score, dtype)

    @supported_platform
    @skip("Triton bug ")  # https://github.com/pytorch/pytorch/issues/124571
    @common_utils.parametrize("dtype", test_dtypes)
@@ -214,8 +244,8 @@ class TestTemplatedSDPA(InductorTestCase):

    def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
        def njt_score_mod(qk, b, h, q, kv):
            q_nested = q - index(offsets, [index(seq_idx, [q])])
            kv_nested = kv - index(offsets, [index(seq_idx, [kv])])
            q_nested = q - offsets[seq_idx[q]]
            kv_nested = kv - offsets[seq_idx[kv]]
            return orig_score_mod(qk, b, h, q_nested, kv_nested)

        return njt_score_mod
@@ -274,9 +304,9 @@ class TestTemplatedSDPA(InductorTestCase):
        tok_scale = torch.randn(S, device="cuda")

        def bias_mod(score, batch, head, token_q, token_kv):
            score = score + index(tok_scale, [token_q])
            score = score + index(batch_scale, [batch])
            score = score + index(head_scale, [head])
            score = score + tok_scale[token_q]
            score = score + batch_scale[batch]
            score = score + head_scale[head]
            return score

        self.run_test(bias_mod)

@@ -1387,6 +1387,28 @@ class TestTorchFunctionMode(TestCase):

        self.assertTrue(called)

    def test_getitem_call(self):
        # This failed because the parser thinks the function is called to()
        # but it's actually called _parse_to()

        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        a = torch.zeros(5)
        b = torch.tensor(0)
        with A():
            a[b]

        self.assertTrue(called)


    def test_distributions_bernoulli(self):
        # This failed because improper use of has_torch_function when
        # is_tensor_like should have been used instead, inside the

@@ -1475,6 +1475,7 @@ class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
        self, tx, query: "VariableTracker", score_function: "VariableTracker"
    ):
        from torch._dynamo.symbolic_convert import InstructionTranslator
        from torch._higher_order_ops.templated_attention import TransformGetItemToIndex
        from .builder import SourcelessBuilder

        tx: InstructionTranslator = tx
@@ -1499,19 +1500,21 @@ class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):

        bhmn = [create_scalar() for _ in range(4)]
        new_args = [score, *bhmn]
        (
            (body_output, body_treespec),
            body_graph,
            body_lifted_freevars,
        ) = speculate_subgraph(
            tx,
            score_function,
            new_args,
            {},  # expect only args no kwargs for now
            description="templated_attention",
            source_target=self.value,
            set_subgraph_inputs="flatten_manual",
        )

        with TransformGetItemToIndex():
            (
                (body_output, body_treespec),
                body_graph,
                body_lifted_freevars,
            ) = speculate_subgraph(
                tx,
                score_function,
                new_args,
                {},  # expect only args no kwargs for now
                description="templated_attention",
                source_target=self.value,
                set_subgraph_inputs="flatten_manual",
            )

        body_name = add_subgraph(
            tx,

@@ -178,7 +178,7 @@ def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_di
        raise ValueError(
            f"vmap({name}, ...): `{name}` must only return "
            f"Tensors, got type {type(batched_output)}. "
            "Did you mean to set out_dim= to None for output?"
            "Did you mean to set out_dims= to None for output?"
        )

    return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)

@@ -1,4 +1,4 @@
from typing import Callable, Tuple
from typing import Any, Callable, Tuple

import torch
import torch.utils._pytree as pytree
@@ -16,6 +16,29 @@ from torch.fx.experimental.proxy_tensor import (
    track_tensor_tree,
)

from torch.overrides import TorchFunctionMode


def transform_getitem_args(x: torch.Tensor, index_args) -> Tuple[Any, ...]:
    if isinstance(index_args, tuple):
        return (x, list(index_args))
    elif not isinstance(index_args, (list, tuple)):
        return (x, [index_args])
    return (x, index_args)


class TransformGetItemToIndex(TorchFunctionMode):
    # This is needed since we want to support calling
    # A[q_idx], where q_idx is a scalar tensor in score_mod.
    # Today, when q_idx is a scalar tensor, we implicitly convert it to a python
    # scalar and create a view. We do not want that behavior in this case, so we
    # use this torchfunctionmode to override that behavior for score_mod
    # wherever we're running it.
    def __torch_function__(self, func, types, args, kwargs=None):
        if func == torch.Tensor.__getitem__:
            return torch.ops.aten.index(*transform_getitem_args(*args))
        return func(*args, **(kwargs or {}))


class TemplatedAttentionHOP(HigherOrderOperator):
    def __init__(self):
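As a usage note for the TorchFunctionMode added in the hunk above, here is a minimal sketch (my own illustration, assuming the import path shown in this diff; the module may have moved in later releases): inside the mode, Tensor.__getitem__ with a scalar-tensor index dispatches to torch.ops.aten.index instead of converting the index to a Python int and building a view, which is what lets score_mod bodies use ordinary indexing during vmap and tracing.

import torch
from torch._higher_order_ops.templated_attention import TransformGetItemToIndex

bias = torch.randn(8)
q = torch.tensor(3)  # score_mod sees its indices as scalar tensors (e.g. under vmap)

with TransformGetItemToIndex():
    # Tensor.__getitem__ is rerouted to torch.ops.aten.index(bias, [q]),
    # so q is never turned into a Python int and no view is created.
    out = bias[q]

# Equivalent call outside the mode:
same = torch.ops.aten.index(bias, [q])
assert torch.equal(out, same)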
@@ -73,7 +96,10 @@ def math_attention(
    score_mod = torch.vmap(score_mod, in_dims=(0, None, 0, None, None) + in_dim_buffers)
    score_mod = torch.vmap(score_mod, in_dims=(0, 0, None, None, None) + in_dim_buffers)

    scores = score_mod(scores, b, h, m, n, *other_buffers).to(torch.float32)
    # todo: We wouldn't need these overrides in this file if Dynamo always did the
    # rewriting.
    with TransformGetItemToIndex():
        scores = score_mod(scores, b, h, m, n, *other_buffers).to(torch.float32)

    # TODO Unconditionally return logsumexp for backwards
    # if any(t.requires_grad for t in (query, key, value)):
@@ -122,7 +148,8 @@ def trace_templated_attention(
    example_vals = [
        torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
    ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
    score_graph = make_fx(score_mod)(*example_vals, *other_buffers)
    with TransformGetItemToIndex():
        score_graph = make_fx(score_mod)(*example_vals, *other_buffers)
    proxy_mode.tracer.root.register_module("sdpa_score", score_graph)
    node_args = (query, key, value, score_graph, *other_buffers)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
@@ -187,9 +214,10 @@ def templated_attention_functionalize(
    with ctx.redispatch_to_next() as m:
        functional_score_mod = ctx.functionalize(score_mod)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        mutates = _has_potential_branch_input_mutation(
            functional_score_mod, example_vals, pre_dispatch
        )
        with TransformGetItemToIndex():
            mutates = _has_potential_branch_input_mutation(
                functional_score_mod, example_vals, pre_dispatch
            )
        # The only care about mutations of existing buffers since we can't replay these.
        # However, we can just error if anything is detected
        if mutates:

@@ -85,6 +85,7 @@ torch_function_passthrough = {
    torch.Tensor.__format__,
    torch.Tensor.__repr__,
    torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
    torch.Tensor.__getitem__,
}


@@ -32,8 +32,7 @@
using namespace at;
using namespace torch::autograd::utils;

namespace torch {
namespace autograd {
namespace torch::autograd {

Py_ssize_t THPVariable_length(PyObject* self) {
  HANDLE_TH_ERRORS
@@ -69,7 +68,7 @@ static inline int64_t count_specified_dimensions(PyObject* index) {
  for (Py_ssize_t i = 0; i < size; i++) {
    PyObject* obj = PyTuple_GET_ITEM(
        index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
    if (!THPVariable_CheckExact(obj) && check_has_torch_function(obj))
    if (check_has_torch_function(obj))
      return -1;
    if (THPVariable_Check(obj)) {
      const auto& var = THPVariable_Unpack(obj);
@@ -341,7 +340,7 @@ static inline THPObjectPtr wrapTuple(PyObject* index) {
// indexing is needed, it calls C++ `at::indexing::dispatch_index`.
PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
  HANDLE_TH_ERRORS
  if (!THPVariable_CheckExact(self) && check_has_torch_function(self)) {
  if (check_has_torch_function(self)) {
    return handle_torch_function_indexing(self, index);
  }
  const auto& self_ = THPVariable_Unpack(self);
@@ -438,9 +437,8 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
  if (py_value == nullptr) {
    throw TypeError("Tensor does not support deleting items");
  }
  if ((!THPVariable_CheckExact(self) && check_has_torch_function(self)) ||
      (!THPVariable_CheckExact(py_value) &&
       check_has_torch_function(py_value))) {
  if ((check_has_torch_function(self)) ||
      (check_has_torch_function(py_value))) {
    py::object ret = py::reinterpret_steal<py::object>(
        handle_torch_function_indexing(self, index, py_value));
    return 0;
@@ -553,5 +551,4 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
  END_HANDLE_TH_ERRORS_RET(-1)
}

} // namespace autograd
} // namespace torch
} // namespace torch::autograd

@@ -22133,9 +22133,9 @@ python_ref_db = [
        torch_opinfo_name="roll",
        validate_view_consistency=False,
        skips=(
            # RuntimeError: no _refs support for torch.Tensor.__getitem__
            # Leaving it as a ref because fftshift uses it
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
            # # RuntimeError: no _refs support for torch.Tensor.__getitem__
            # # Leaving it as a ref because fftshift uses it
            # DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
        ),
    ),
    PythonRefInfo(

@@ -767,18 +767,10 @@ python_ref_db: List[OpInfo] = [
        "_refs.fft.fftshift",
        op_db=op_db,
        torch_opinfo_name="fft.fftshift",
        skips=(
            # TODO Move fftshift to decomps
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
        ),
    ),
    PythonRefInfo(
        "_refs.fft.ifftshift",
        op_db=op_db,
        torch_opinfo_name="fft.ifftshift",
        skips=(
            # TODO Move ifftshift to decomps
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
        ),
    ),
]

@@ -2389,8 +2389,6 @@ python_ref_db: List[OpInfo] = [
        supports_out=True,
        op_db=op_db,
        skips=(
            # no _refs support for Tensor.__getitem__
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
            # TODO: is this really needed?
            DecorateInfo(
                unittest.expectedFailure, "TestCommon", "test_python_ref_errors"