Compare commits


2 Commits

Author SHA1 Message Date
46f2a96d21 [WIP] flex_attention x functorch.grad 2025-02-07 07:45:43 -08:00
da56b7e5e7 Add torch.func.debug_unwrap
Use it to unwrap any functorch-wrapped tensor. I don't recommend using
the output in a program since it breaks the semantics of the transforms,
but it seems useful for debugging.

I will note that some people have wanted to get intermediate values out
of a transform (e.g. grad), so this might be a way to do that.

Test Plan:
- tests

ghstack-source-id: 982b82ea13d39a09e7ac245f677c5eea15d2cd14
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146528
2025-02-06 07:42:55 -08:00
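
As a quick illustration of the use case the commit message above describes, here is a minimal sketch (not part of the diff) that uses the new torch.func.debug_unwrap to peek at an intermediate value from inside a grad transform:

# Sketch only: inspect an intermediate value inside torch.func.grad.
# Inside the transform, y is a functorch-wrapped (GradTrackingTensor) tensor;
# debug_unwrap recursively strips the wrappers so the plain tensor can be
# printed. The unwrapped value is only inspected, never fed back into the
# transformed computation.
import torch

seen = []

def f(x):
    y = x.sin()
    seen.append(torch.func.debug_unwrap(y))  # debugging only
    return y.sum()

grad_x = torch.func.grad(f)(torch.randn(3))
print(seen[0])  # underlying value of the intermediate y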
9 changed files with 169 additions and 54 deletions

View File

@@ -76,3 +76,12 @@ guidance here
    :maxdepth: 1

    func.batch_norm

Debug utilities
---------------

.. autosummary::
    :toctree: generated
    :nosignatures:

    debug_unwrap

View File

@@ -3261,6 +3261,18 @@ class TestHelpers(TestCase):
        out = A.apply(x, y)
        out.backward()

    def test_debug_unwrap(self):
        stuff = []

        def f(x):
            stuff.append(torch.func.debug_unwrap(x))
            return x.sin()

        x = torch.randn(2, 3)
        _ = vmap(vmap(f))(x)
        self.assertEqual(stuff[0], x)
        self.assertTrue(stuff[0] is x)

    def test_reductify_leaf(self, device):
        reductify_leaf = torch._functorch.autograd_function.reductify_leaf
        B = 2

View File

@@ -2282,6 +2282,26 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
            )
        )

    @common_utils.parametrize(
        "score_mod", [_identity, _causal, _times_two, _squared, _trig, _trig2]
    )
    def test_functorch_grad(self, score_mod):
        make_tensor = functools.partial(
            torch.randn,
            (2, 2, 11, 4),
            device="cpu",
            dtype=torch.float64,
            requires_grad=True,
        )
        query, key, value = make_tensor(), make_tensor(), make_tensor()

        def fn(query, key, value):
            return flex_attention(query, key, value, score_mod)[0].sum()

        expected = torch.autograd.grad(fn(query, key, value), (query, key, value))
        result = torch.func.grad(fn, argnums=(0, 1, 2))(query, key, value)
        self.assertEqual(result, expected)

    @supported_platform
    def test_eager_backward_strides(self):
        class Repro(torch.nn.Module):

View File

@@ -26,6 +26,8 @@ from torch._C._functorch import (
    _wrap_for_grad,
    _wrap_functional_tensor,
    get_inplace_requires_grad_allowed,
    get_unwrapped,
    is_functorch_wrapped_tensor,
    set_inplace_requires_grad_allowed,
)
from torch._functorch.utils import argnums_t, exposed_in
@@ -1795,3 +1797,19 @@ def linearize(func: Callable, *primals) -> tuple[Any, Callable]:
        return tree_unflatten(flat_output, output_spec)

    return output, jvp_fn


@exposed_in("torch.func")
def debug_unwrap(tensor: torch.Tensor, *, recurse=True) -> torch.Tensor:
    """Unwraps a functorch tensor (e.g. BatchedTensor, GradTrackingTensor) to its underlying tensor.

    This function should only be used in a debug setting (e.g. trying to print the
    value of a Tensor in a debugger). Otherwise, using the result of this function
    inside of a function being transformed will lead to undefined behavior.
    """
    if not is_functorch_wrapped_tensor(tensor):
        return tensor
    result = get_unwrapped(tensor)
    if recurse:
        return debug_unwrap(result)
    return result

View File

@@ -6,6 +6,8 @@ from typing import Any
import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
    _unwrap_for_grad,
    _wrap_for_grad,
    CFunctionalizeInterpreterPtr,
    CGradInterpreterPtr,
    CInterpreter,
@@ -252,6 +254,28 @@ def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")


def HOPAutoDispatchBelowAutograd(func, *args, **kwargs):
    """Wrapper around torch._C._AutoDispatchBelowAutograd that works with functorch.grad

    Please use this when you're implementing an autograd py_impl for a HOP.
    """
    interpreter = torch._C._functorch.peek_interpreter_stack()
    if interpreter is None:
        with torch._C._AutoDispatchBelowAutograd():
            return func(*args, **kwargs)
    interpreter = retrieve_current_functorch_interpreter()
    assert isinstance(interpreter, GradInterpreter)
    level = interpreter.level()
    args, kwargs = pytree.tree_map_only(
        torch.Tensor, lambda x: _unwrap_for_grad(x, level), (args, kwargs)
    )
    with interpreter.lower():
        result = func(*args, **kwargs)
    result = pytree.tree_map_only(
        torch.Tensor, lambda x: _wrap_for_grad(x, level), result
    )
    return result


def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter:
    interpreter = torch._C._functorch.peek_interpreter_stack()
    assert interpreter is not None
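
A minimal sketch (not part of the diff) of the no-interpreter path of HOPAutoDispatchBelowAutograd; torch.sin stands in here for a HOP call site. Under functorch.grad the helper instead unwraps the GradTrackingTensors, lowers the interpreter, runs func, and re-wraps the result, as the implementation above shows:

# Sketch only: with no functorch interpreter on the stack, the call below is
# equivalent to running torch.sin(x) under torch._C._AutoDispatchBelowAutograd(),
# so no autograd graph is recorded for this call. In real use the caller
# (e.g. a _SingleLevelFunction.forward, as in the flex_attention hunk below)
# supplies the backward.
import torch
from torch._functorch.pyfunctorch import HOPAutoDispatchBelowAutograd

x = torch.randn(3, requires_grad=True)
y = HOPAutoDispatchBelowAutograd(torch.sin, x)
assert y.grad_fn is None  # the Autograd key was skipped for this call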

View File

@@ -1,11 +1,17 @@
import math
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Literal, Optional, Union

import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._C import DispatchKey
from torch._C._functorch import TransformType
from torch._functorch.pyfunctorch import (
    FuncTorchInterpreter,
    HOPAutoDispatchBelowAutograd,
)
from torch._functorch.utils import enable_single_level_autograd_function
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_mutation,
    _maybe_reenter_make_fx,
@@ -489,7 +495,7 @@ def create_fw_bw_graph(
    # All of these imports need to be here in order to avoid circular dependencies
    from torch._dispatch.python import suspend_functionalization
    from torch._functorch.aot_autograd import AOTConfig, create_joint
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch._subclasses.functional_tensor import disable_functional_mode
    from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
@@ -532,10 +538,10 @@ def create_fw_bw_graph(
        unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
        unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)

        assert all(
            isinstance(t, (FakeTensor, int, torch.SymInt))
            for t in unwrapped_score_mod_indexes + unwrapped_other_buffers
        )
        # assert all(
        #     isinstance(t, (FakeTensor, int, torch.SymInt))
        #     for t in unwrapped_score_mod_indexes + unwrapped_other_buffers
        # )

        example_flat_out = pytree.tree_map(
            _from_fun,
@@ -577,7 +583,7 @@ def create_fw_bw_graph(
        return score_mod, joint_graph


class FlexAttentionAutogradOp(torch.autograd.Function):
class FlexAttentionAutogradOp(torch.autograd.function._SingleLevelFunction):
    @staticmethod
    def forward(
        ctx: Any,
@@ -606,18 +612,19 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
        ctx.scale = scale
        ctx.kernel_options = kernel_options
        ctx._score_mod_other_buffers_len = len(score_mod_other_buffers)
        with torch._C._AutoDispatchBelowAutograd():
            out, logsumexp = flex_attention(
                query,
                key,
                value,
                fw_graph,
                block_mask,
                scale,
                kernel_options,
                score_mod_other_buffers,
                mask_mod_other_buffers,
            )
        out, logsumexp = HOPAutoDispatchBelowAutograd(
            flex_attention,
            query,
            key,
            value,
            fw_graph,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )

        save_tensors_and_symints_for_backward(
            ctx,
@@ -710,7 +717,17 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
@flex_attention.py_impl(DispatchKey.Autograd)
def _(*args: Any, **kwargs: Any) -> Any:
    return flex_attention_autograd(None, *args, **kwargs)


@flex_attention.py_impl(TransformType.Grad)
def _(_ignored: FuncTorchInterpreter, *args: Any, **kwargs: Any) -> Any:
    return flex_attention_autograd(None, *args, **kwargs)


def flex_attention_autograd(
    _ignored: Literal[None],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
@@ -741,18 +758,19 @@ def flex_attention_autograd(
            )
        else:
            fw_graph, bw_graph = score_mod, None
        out, logsumexp = FlexAttentionAutogradOp.apply(
            query,
            key,
            value,
            fw_graph,
            bw_graph,
            block_mask,
            scale,
            kernel_options,
            mask_mod_other_buffers,
            *score_mod_other_buffers,
        )
        with enable_single_level_autograd_function():
            out, logsumexp = FlexAttentionAutogradOp.apply(  # type: ignore[attr-defined]
                query,
                key,
                value,
                fw_graph,
                bw_graph,
                block_mask,
                scale,
                kernel_options,
                mask_mod_other_buffers,
                *score_mod_other_buffers,
            )

    return out, logsumexp
@@ -1163,3 +1181,9 @@ def flex_attention_backward_fake_tensor_mode(
flex_attention_backward.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(flex_attention_backward, deferred_error=True)
)


flex_attention_backward.py_impl(TransformType.Grad)(
    autograd_not_implemented(
        flex_attention_backward, deferred_error=True, functorch=True
    )
)

View File

@@ -7,6 +7,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._functorch.pyfunctorch import HOPAutoDispatchBelowAutograd
from torch._guards import detect_fake_mode
from torch._ops import OperatorBase
from torch._subclasses.fake_tensor import FakeTensor
@@ -36,34 +37,38 @@ def autograd_not_implemented_inner(
    Raises:
        RuntimeError: If autograd is enabled and any of the arguments to the Operator
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )
    result = HOPAutoDispatchBelowAutograd(operator, *args, **kwargs)
    flat_operands = pytree.arg_tree_leaves(*args)
    if torch.is_grad_enabled() and any(
        f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
    ):
        if delayed_error:
            err_fn = torch._C._functions.DelayedError(
                f"Autograd not implemented for {str(operator)}",
                1,
            )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

            def fake_requires_grad(tensor):
                if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                    tensor = tensor.detach()
                    tensor.requires_grad = True
                return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result
            return pytree.tree_map_only(
                torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
            )
        else:
            raise RuntimeError(f"Autograd not implemented for {str(operator)}")
    return result


def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
def autograd_not_implemented(
    op: OperatorBase, deferred_error: bool, *, functorch: bool = False
) -> Callable:
    def inner(*args, **kwargs):
        if functorch:
            # ignore first arg, which is the functorch interpreter
            args = args[1:]
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner

View File

@@ -141,6 +141,7 @@ class OperatorBase:
                raise RuntimeError(
                    f"Trying to override a python impl for {k} on operator {self.name()}"
                )
            self.py_kernels[k] = fn
            self._dispatch_cache.clear()
            return fn

View File

@@ -1,6 +1,7 @@
from torch._functorch.apis import grad, grad_and_value, vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import (
    debug_unwrap,
    functionalize,
    hessian,
    jacfwd,
@@ -26,4 +27,5 @@ __all__ = [
    "vjp",
    "functional_call",
    "stack_module_state",
    "debug_unwrap",
]