Made FlexAttention rewrite getitem calls to use aten.index in score_mod (#124799)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124799
Approved by: https://github.com/drisspg
ghstack dependencies: #124444
chilli
2024-04-25 10:47:54 -07:00
committed by PyTorch MergeBot
parent 7321005dd8
commit 9bccafc31c
11 changed files with 129 additions and 58 deletions
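
In short: inside the templated attention kernel, score_mod is run under torch.vmap, so b, h, q and kv arrive as scalar tensors, and indexing a captured buffer previously required the explicit index(...) helper. With this change, ordinary __getitem__ inside score_mod is rerouted to torch.ops.aten.index, both when Dynamo traces score_mod and in the higher-order op's math/tracing/functionalization paths (via the TransformGetItemToIndex mode added below). A rough before/after sketch, not part of the patch, with made-up sizes and assuming the tests' index helper refers to torch.ops.aten.index:

import torch

B, S = 2, 16  # made-up sizes for this sketch
bias = torch.randn(B, S, S)

def bias_mod_old(score, b, h, q, kv):
    # Old spelling: explicit aten.index, since b/q/kv are 0-dim tensors.
    return score + torch.ops.aten.index(bias, [b, q, kv])

def bias_mod_new(score, b, h, q, kv):
    # New spelling: plain indexing; inside score_mod this __getitem__ is
    # now rewritten to torch.ops.aten.index(bias, [b, q, kv]).
    return score + bias[b, q, kv]

score = torch.tensor(0.0)
b, h, q, kv = (torch.tensor(i) for i in (0, 0, 3, 5))
assert torch.equal(bias_mod_old(score, b, h, q, kv),
                   bias_mod_new(score, b, h, q, kv))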

View File

@@ -12,7 +12,7 @@ const char* get_cuda_check_suffix() noexcept {
} else {
return "\nCUDA kernel errors might be asynchronously reported at some"
" other API call, so the stacktrace below might be incorrect."
"\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.";
"\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1";
}
}
std::mutex* getFreeMutex() {

View File

@@ -4,7 +4,7 @@ import functools
from collections import namedtuple
from typing import Callable
from unittest import skip, skipUnless
from unittest import expectedFailure, skip, skipUnless
from unittest.mock import patch
import torch
@@ -125,7 +125,7 @@ class TestTemplatedSDPA(InductorTestCase):
head_offset = torch.rand(H, device="cuda", dtype=dtype)
def score_mod(score, b, h, m, n):
return score + index(head_offset, [h])
return score + head_offset[h]
self.run_test(score_mod, dtype)
@@ -136,9 +136,7 @@ class TestTemplatedSDPA(InductorTestCase):
seq_idx[S // 2 :] = 1
def seq_mask_mod(score, b, h, q, kv):
return torch.where(
index(seq_idx, [q]) == index(seq_idx, [kv]), score, float("-inf")
)
return torch.where(seq_idx[q] == seq_idx[kv], score, float("-inf"))
self.run_test(seq_mask_mod, dtype)
@@ -148,7 +146,7 @@ class TestTemplatedSDPA(InductorTestCase):
bias = torch.randn(S, S, device="cuda", dtype=dtype)
def bias_mod(score, b, h, q, kv):
return score + index(bias, [q, kv])
return score + bias[q, kv]
self.run_test(bias_mod, dtype)
@@ -158,7 +156,7 @@ class TestTemplatedSDPA(InductorTestCase):
bias = torch.randn(B, S, S, device="cuda", dtype=dtype)
def bias_mod(score, b, h, q, kv):
return score + index(bias, [b, q, kv])
return score + bias[b, q, kv]
self.run_test(bias_mod, dtype)
@@ -168,7 +166,7 @@ class TestTemplatedSDPA(InductorTestCase):
bias = torch.randn(B, H, S, S, device="cuda", dtype=dtype)
def bias_mod(score, b, h, q, kv):
return score + index(bias, [b, h, q, kv])
return score + bias[b, h, q, kv]
self.run_test(bias_mod, dtype)
@@ -178,7 +176,7 @@ class TestTemplatedSDPA(InductorTestCase):
rel_bias = torch.randn(2 * S, device="cuda", dtype=dtype)
def bias_mod(score, b, h, q, kv):
return score + index(rel_bias, [(q - kv) + S])
return score + rel_bias[(q - kv) + S]
self.run_test(bias_mod, dtype)
@@ -189,7 +187,7 @@ class TestTemplatedSDPA(InductorTestCase):
def bias_mod(score, b, h, q, kv):
causal_attention = q >= kv
cur_num_bidirectional = index(num_bidirectional, (b,))
cur_num_bidirectional = num_bidirectional[b]
bidirectional_attention_on_video = (q <= cur_num_bidirectional) & (
kv <= cur_num_bidirectional
)
@@ -201,6 +199,38 @@ class TestTemplatedSDPA(InductorTestCase):
self.run_test(bias_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_natten_2d(self, dtype):
H = 32
W = S // H
WINDOW = 3
assert W * H == S
def get_x_y(idx):
# This should be a floor divide, but we don't support that properly
return idx / W, idx % W
def natten_mask(score, b, h, q, kv):
q_x, q_y = get_x_y(q)
kv_x, kv_y = get_x_y(kv)
return torch.where(
((q_x - kv_x).abs() <= WINDOW) | ((q_y - kv_y).abs() <= WINDOW),
score,
float("-inf"),
)
self.run_test(natten_mask, dtype)
@supported_platform
@expectedFailure
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_silu_on_score(self, dtype):
def silu_score(score, b, h, q, kv):
return torch.nn.functional.silu(score)
self.run_test(silu_score, dtype)
@supported_platform
@skip("Triton bug ") # https://github.com/pytorch/pytorch/issues/124571
@common_utils.parametrize("dtype", test_dtypes)
@@ -214,8 +244,8 @@ class TestTemplatedSDPA(InductorTestCase):
def create_njt_wrapper(orig_score_mod, offsets, seq_idx):
def njt_score_mod(qk, b, h, q, kv):
q_nested = q - index(offsets, [index(seq_idx, [q])])
kv_nested = kv - index(offsets, [index(seq_idx, [kv])])
q_nested = q - offsets[seq_idx[q]]
kv_nested = kv - offsets[seq_idx[kv]]
return orig_score_mod(qk, b, h, q_nested, kv_nested)
return njt_score_mod
@@ -274,9 +304,9 @@ class TestTemplatedSDPA(InductorTestCase):
tok_scale = torch.randn(S, device="cuda")
def bias_mod(score, batch, head, token_q, token_kv):
score = score + index(tok_scale, [token_q])
score = score + index(batch_scale, [batch])
score = score + index(head_scale, [head])
score = score + tok_scale[token_q]
score = score + batch_scale[batch]
score = score + head_scale[head]
return score
self.run_test(bias_mod)

View File

@@ -1387,6 +1387,28 @@ class TestTorchFunctionMode(TestCase):
self.assertTrue(called)
def test_getitem_call(self):
# Check that indexing (Tensor.__getitem__) dispatches through
# __torch_function__; the C++ getitem fast path used to skip the
# torch_function check for plain tensors.
called = False
class A(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
nonlocal called
if kwargs is None:
kwargs = {}
called = True
return func(*args, **kwargs)
a = torch.zeros(5)
b = torch.tensor(0)
with A():
a[b]
self.assertTrue(called)
def test_distributions_bernoulli(self):
# This failed because improper use of has_torch_function when
# is_tensor_like should have been used instead, inside the

View File

@@ -1475,6 +1475,7 @@ class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
self, tx, query: "VariableTracker", score_function: "VariableTracker"
):
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._higher_order_ops.templated_attention import TransformGetItemToIndex
from .builder import SourcelessBuilder
tx: InstructionTranslator = tx
@@ -1499,19 +1500,21 @@ class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
bhmn = [create_scalar() for _ in range(4)]
new_args = [score, *bhmn]
(
(body_output, body_treespec),
body_graph,
body_lifted_freevars,
) = speculate_subgraph(
tx,
score_function,
new_args,
{}, # expect only args no kwargs for now
description="templated_attention",
source_target=self.value,
set_subgraph_inputs="flatten_manual",
)
with TransformGetItemToIndex():
(
(body_output, body_treespec),
body_graph,
body_lifted_freevars,
) = speculate_subgraph(
tx,
score_function,
new_args,
{}, # expect only args no kwargs for now
description="templated_attention",
source_target=self.value,
set_subgraph_inputs="flatten_manual",
)
body_name = add_subgraph(
tx,

View File

@@ -178,7 +178,7 @@ def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):
raise ValueError(
f"vmap({name}, ...): `{name}` must only return "
f"Tensors, got type {type(batched_output)}. "
"Did you mean to set out_dim= to None for output?"
"Did you mean to set out_dims= to None for output?"
)
return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)

View File

@@ -1,4 +1,4 @@
from typing import Callable, Tuple
from typing import Any, Callable, Tuple
import torch
import torch.utils._pytree as pytree
@@ -16,6 +16,29 @@ from torch.fx.experimental.proxy_tensor import (
track_tensor_tree,
)
from torch.overrides import TorchFunctionMode
def transform_getitem_args(x: torch.Tensor, index_args) -> Tuple[Any, ...]:
if isinstance(index_args, tuple):
return (x, list(index_args))
elif not isinstance(index_args, (list, tuple)):
return (x, [index_args])
return (x, index_args)
class TransformGetItemToIndex(TorchFunctionMode):
# This is needed since we want to support calling
# A[q_idx], where q_idx is a scalar tensor in score_mod.
# Today, when q_idx is a scalar tensor, we implicitly convert it to a python
# scalar and create a view. We do not want that behavior in this case, so we
# use this torchfunctionmode to override that behavior for score_mod
# wherever we're running it.
def __torch_function__(self, func, types, args, kwargs=None):
if func == torch.Tensor.__getitem__:
return torch.ops.aten.index(*transform_getitem_args(*args))
return func(*args, **(kwargs or {}))
class TemplatedAttentionHOP(HigherOrderOperator):
def __init__(self):
@@ -73,7 +96,10 @@ def math_attention(
score_mod = torch.vmap(score_mod, in_dims=(0, None, 0, None, None) + in_dim_buffers)
score_mod = torch.vmap(score_mod, in_dims=(0, 0, None, None, None) + in_dim_buffers)
scores = score_mod(scores, b, h, m, n, *other_buffers).to(torch.float32)
# todo: We wouldn't need these overrides in this file if Dynamo always did the
# rewriting.
with TransformGetItemToIndex():
scores = score_mod(scores, b, h, m, n, *other_buffers).to(torch.float32)
# TODO Unconditionally return logsumexp for backwards
# if any(t.requires_grad for t in (query, key, value)):
@@ -122,7 +148,8 @@ def trace_templated_attention(
example_vals = [
torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
score_graph = make_fx(score_mod)(*example_vals, *other_buffers)
with TransformGetItemToIndex():
score_graph = make_fx(score_mod)(*example_vals, *other_buffers)
proxy_mode.tracer.root.register_module("sdpa_score", score_graph)
node_args = (query, key, value, score_graph, *other_buffers)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
@@ -187,9 +214,10 @@ def templated_attention_functionalize(
with ctx.redispatch_to_next() as m:
functional_score_mod = ctx.functionalize(score_mod)
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
mutates = _has_potential_branch_input_mutation(
functional_score_mod, example_vals, pre_dispatch
)
with TransformGetItemToIndex():
mutates = _has_potential_branch_input_mutation(
functional_score_mod, example_vals, pre_dispatch
)
# We only care about mutations of existing buffers, since we can't replay these.
# However, we can just error if anything is detected
if mutates:
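
The TransformGetItemToIndex mode above is what performs the rewrite outside of Dynamo: math_attention runs the vmapped score_mod under it, trace_templated_attention traces score_mod under it with make_fx, and the functionalization path runs its mutation check under it. It works together with the __getitem__-related changes further down (python_variable_indexing.cpp and torch/overrides.py), which let indexing on plain tensors dispatch through the torch-function machinery. A minimal standalone sketch of the same pattern, with the mode reimplemented locally and all names and sizes made up for illustration:

import torch
from torch.overrides import TorchFunctionMode

class GetItemToIndex(TorchFunctionMode):
    # Same idea as TransformGetItemToIndex: reroute Tensor.__getitem__ to
    # aten.index so scalar-tensor indices are not turned into Python ints.
    # (Simplified: handles a single index; the real helper also normalizes
    # tuples and lists of indices.)
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func == torch.Tensor.__getitem__:
            tensor, idx = args
            return torch.ops.aten.index(tensor, [idx])
        return func(*args, **(kwargs or {}))

bias = torch.randn(8)

def score_mod(score, q):
    # q is a 0-dim tensor per sample once this runs under vmap.
    return score + bias[q]

scores = torch.zeros(8)
q_idx = torch.arange(8)

# Mirrors math_attention above: run the vmapped score_mod under the mode so
# the getitem inside it becomes aten.index, which vmap can batch over.
with GetItemToIndex():
    out = torch.vmap(score_mod)(scores, q_idx)

assert torch.allclose(out, bias)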

View File

@@ -85,6 +85,7 @@ torch_function_passthrough = {
torch.Tensor.__format__,
torch.Tensor.__repr__,
torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined]
torch.Tensor.__getitem__,
}

View File

@@ -32,8 +32,7 @@
using namespace at;
using namespace torch::autograd::utils;
namespace torch {
namespace autograd {
namespace torch::autograd {
Py_ssize_t THPVariable_length(PyObject* self) {
HANDLE_TH_ERRORS
@@ -69,7 +68,7 @@ static inline int64_t count_specified_dimensions(PyObject* index) {
for (Py_ssize_t i = 0; i < size; i++) {
PyObject* obj = PyTuple_GET_ITEM(
index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
if (!THPVariable_CheckExact(obj) && check_has_torch_function(obj))
if (check_has_torch_function(obj))
return -1;
if (THPVariable_Check(obj)) {
const auto& var = THPVariable_Unpack(obj);
@@ -341,7 +340,7 @@ static inline THPObjectPtr wrapTuple(PyObject* index) {
// indexing is needed, it calls C++ `at::indexing::dispatch_index`.
PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
HANDLE_TH_ERRORS
if (!THPVariable_CheckExact(self) && check_has_torch_function(self)) {
if (check_has_torch_function(self)) {
return handle_torch_function_indexing(self, index);
}
const auto& self_ = THPVariable_Unpack(self);
@@ -438,9 +437,8 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
if (py_value == nullptr) {
throw TypeError("Tensor does not support deleting items");
}
if ((!THPVariable_CheckExact(self) && check_has_torch_function(self)) ||
(!THPVariable_CheckExact(py_value) &&
check_has_torch_function(py_value))) {
if ((check_has_torch_function(self)) ||
(check_has_torch_function(py_value))) {
py::object ret = py::reinterpret_steal<py::object>(
handle_torch_function_indexing(self, index, py_value));
return 0;
@@ -553,5 +551,4 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
END_HANDLE_TH_ERRORS_RET(-1)
}
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@@ -22133,9 +22133,9 @@ python_ref_db = [
torch_opinfo_name="roll",
validate_view_consistency=False,
skips=(
# RuntimeError: no _refs support for torch.Tensor.__getitem__
# Leaving it as a ref because fftshift uses it
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
# # RuntimeError: no _refs support for torch.Tensor.__getitem__
# # Leaving it as a ref because fftshift uses it
# DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
),
),
PythonRefInfo(

View File

@@ -767,18 +767,10 @@ python_ref_db: List[OpInfo] = [
"_refs.fft.fftshift",
op_db=op_db,
torch_opinfo_name="fft.fftshift",
skips=(
# TODO Move fftshift to decomps
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
),
),
PythonRefInfo(
"_refs.fft.ifftshift",
op_db=op_db,
torch_opinfo_name="fft.ifftshift",
skips=(
# TODO Move ifftshift to decomps
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
),
),
]

View File

@@ -2389,8 +2389,6 @@ python_ref_db: List[OpInfo] = [
supports_out=True,
op_db=op_db,
skips=(
# no _refs support for Tensor.__getitem__
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
# TODO: is this really needed?
DecorateInfo(
unittest.expectedFailure, "TestCommon", "test_python_ref_errors"