Compare commits

...

2 Commits

Author SHA1 Message Date
46f2a96d21 [WIP] flex_attention x functorch.grad 2025-02-07 07:45:43 -08:00
da56b7e5e7 Add torch.func.debug_unwrap
Use it to unwrap any functorch-wrapped tensor. I don't recommend using
the output in a program since it breaks the semantics of the transforms,
but it seems useful for debugging.

I will note that some people have wanted to get intermediate values out of,
e.g., a grad transform, so this might be a way to do that (see the sketch below the commit list)...

Test Plan:
- tests

ghstack-source-id: 982b82ea13d39a09e7ac245f677c5eea15d2cd14
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146528
2025-02-06 07:42:55 -08:00
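
A minimal sketch of the "intermediate values out of a grad transform" idea from the commit message above, assuming a PyTorch build that includes commit da56b7e5e7 (torch.func.debug_unwrap). The unwrapped tensor is for inspection only and should not be fed back into the transformed computation:

import torch

captured = []

def f(x):
    hidden = x.sin()
    # Unwrap the GradTrackingTensor so the raw value can be inspected/printed.
    captured.append(torch.func.debug_unwrap(hidden))
    return hidden.sum()

x = torch.randn(3)
torch.func.grad(f)(x)
print(captured[0])  # plain tensor holding the intermediate value of `hidden`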
9 changed files with 169 additions and 54 deletions

View File

@@ -76,3 +76,12 @@ guidance here
    :maxdepth: 1

    func.batch_norm
+
+Debug utilities
+---------------
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    debug_unwrap

View File

@@ -3261,6 +3261,18 @@ class TestHelpers(TestCase):
        out = A.apply(x, y)
        out.backward()

+    def test_debug_unwrap(self):
+        stuff = []
+
+        def f(x):
+            stuff.append(torch.func.debug_unwrap(x))
+            return x.sin()
+
+        x = torch.randn(2, 3)
+        _ = vmap(vmap(f))(x)
+        self.assertEqual(stuff[0], x)
+        self.assertTrue(stuff[0] is x)
+
    def test_reductify_leaf(self, device):
        reductify_leaf = torch._functorch.autograd_function.reductify_leaf
        B = 2

View File

@@ -2282,6 +2282,26 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
            )
        )

+    @common_utils.parametrize(
+        "score_mod", [_identity, _causal, _times_two, _squared, _trig, _trig2]
+    )
+    def test_functorch_grad(self, score_mod):
+        make_tensor = functools.partial(
+            torch.randn,
+            (2, 2, 11, 4),
+            device="cpu",
+            dtype=torch.float64,
+            requires_grad=True,
+        )
+        query, key, value = make_tensor(), make_tensor(), make_tensor()
+
+        def fn(query, key, value):
+            return flex_attention(query, key, value, score_mod)[0].sum()
+
+        expected = torch.autograd.grad(fn(query, key, value), (query, key, value))
+        result = torch.func.grad(fn, argnums=(0, 1, 2))(query, key, value)
+        self.assertEqual(result, expected)
+
    @supported_platform
    def test_eager_backward_strides(self):
        class Repro(torch.nn.Module):

View File

@@ -26,6 +26,8 @@ from torch._C._functorch import (
    _wrap_for_grad,
    _wrap_functional_tensor,
    get_inplace_requires_grad_allowed,
+    get_unwrapped,
+    is_functorch_wrapped_tensor,
    set_inplace_requires_grad_allowed,
)
from torch._functorch.utils import argnums_t, exposed_in
@@ -1795,3 +1797,19 @@ def linearize(func: Callable, *primals) -> tuple[Any, Callable]:
        return tree_unflatten(flat_output, output_spec)

    return output, jvp_fn
+
+
+@exposed_in("torch.func")
+def debug_unwrap(tensor: torch.Tensor, *, recurse=True) -> torch.Tensor:
+    """Unwraps a functorch tensor (e.g. BatchedTensor, GradTrackingTensor) to its underlying tensor.
+
+    This function should only be used in a debug setting (e.g. trying to print the
+    value of a Tensor in a debugger). Otherwise, using the result of this function
+    inside of a function being transformed will lead to undefined behavior.
+    """
+    if not is_functorch_wrapped_tensor(tensor):
+        return tensor
+    result = get_unwrapped(tensor)
+    if recurse:
+        return debug_unwrap(result)
+    return result
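
A short illustration of the recurse flag added above (a sketch, assuming a build that includes this change): under nested vmap, recurse=False peels a single wrapper level, while the default unwraps all the way to the plain tensor:

import torch
from torch._C._functorch import is_functorch_wrapped_tensor
from torch.func import vmap

def f(x):
    once = torch.func.debug_unwrap(x, recurse=False)   # removes one vmap level
    fully = torch.func.debug_unwrap(x)                  # removes both levels
    print(is_functorch_wrapped_tensor(once))   # True: still a BatchedTensor
    print(is_functorch_wrapped_tensor(fully))  # False: plain torch.Tensor
    return x.sin()

vmap(vmap(f))(torch.randn(2, 3))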

View File

@@ -6,6 +6,8 @@ from typing import Any
import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
+    _unwrap_for_grad,
+    _wrap_for_grad,
    CFunctionalizeInterpreterPtr,
    CGradInterpreterPtr,
    CInterpreter,
@@ -252,6 +254,28 @@ def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")


+def HOPAutoDispatchBelowAutograd(func, *args, **kwargs):
+    """Wrapper around torch._C._AutoDispatchBelowAutograd that works with functorch.grad
+
+    Please use this when you're implementing an autograd py_impl for a HOP.
+    """
+    interpreter = torch._C._functorch.peek_interpreter_stack()
+    if interpreter is None:
+        with torch._C._AutoDispatchBelowAutograd():
+            return func(*args, **kwargs)
+    interpreter = retrieve_current_functorch_interpreter()
+    assert isinstance(interpreter, GradInterpreter)
+    level = interpreter.level()
+    args, kwargs = pytree.tree_map_only(
+        torch.Tensor, lambda x: _unwrap_for_grad(x, level), (args, kwargs)
+    )
+    with interpreter.lower():
+        result = func(*args, **kwargs)
+    result = pytree.tree_map_only(
+        torch.Tensor, lambda x: _wrap_for_grad(x, level), result
+    )
+    return result


def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter:
    interpreter = torch._C._functorch.peek_interpreter_stack()
    assert interpreter is not None
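
As a quick sanity check of the helper above (a sketch, assuming a build that includes this change): when no functorch interpreter is on the stack it simply runs the callable under torch._C._AutoDispatchBelowAutograd(), so autograd records nothing for the call; inside torch.func.grad it additionally unwraps and rewraps grad-tracked tensors at the current level, as the code above shows.

import torch
from torch._functorch.pyfunctorch import HOPAutoDispatchBelowAutograd

x = torch.randn(3, requires_grad=True)
# No functorch transform is active, so this is equivalent to calling
# torch.sin(x) under torch._C._AutoDispatchBelowAutograd(): the Autograd
# dispatch key is skipped and no grad_fn is attached to the result.
y = HOPAutoDispatchBelowAutograd(torch.sin, x)
print(y.grad_fn)        # None
print(y.requires_grad)  # False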

View File

@@ -1,11 +1,17 @@
import math
from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union

import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._C import DispatchKey
+from torch._C._functorch import TransformType
+from torch._functorch.pyfunctorch import (
+    FuncTorchInterpreter,
+    HOPAutoDispatchBelowAutograd,
+)
+from torch._functorch.utils import enable_single_level_autograd_function
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_mutation,
    _maybe_reenter_make_fx,
@@ -489,7 +495,7 @@ def create_fw_bw_graph(
    # All of these imports need to be here in order to avoid circular dependencies
    from torch._dispatch.python import suspend_functionalization
    from torch._functorch.aot_autograd import AOTConfig, create_joint
-    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch._subclasses.functional_tensor import disable_functional_mode
    from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
@@ -532,10 +538,10 @@ def create_fw_bw_graph(
            unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
            unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)

-            assert all(
-                isinstance(t, (FakeTensor, int, torch.SymInt))
-                for t in unwrapped_score_mod_indexes + unwrapped_other_buffers
-            )
+            # assert all(
+            #     isinstance(t, (FakeTensor, int, torch.SymInt))
+            #     for t in unwrapped_score_mod_indexes + unwrapped_other_buffers
+            # )

            example_flat_out = pytree.tree_map(
                _from_fun,
@@ -577,7 +583,7 @@ def create_fw_bw_graph(
        return score_mod, joint_graph


-class FlexAttentionAutogradOp(torch.autograd.Function):
+class FlexAttentionAutogradOp(torch.autograd.function._SingleLevelFunction):
    @staticmethod
    def forward(
        ctx: Any,
@@ -606,18 +612,19 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
        ctx.scale = scale
        ctx.kernel_options = kernel_options
        ctx._score_mod_other_buffers_len = len(score_mod_other_buffers)
-        with torch._C._AutoDispatchBelowAutograd():
-            out, logsumexp = flex_attention(
-                query,
-                key,
-                value,
-                fw_graph,
-                block_mask,
-                scale,
-                kernel_options,
-                score_mod_other_buffers,
-                mask_mod_other_buffers,
-            )
+        out, logsumexp = HOPAutoDispatchBelowAutograd(
+            flex_attention,
+            query,
+            key,
+            value,
+            fw_graph,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )

        save_tensors_and_symints_for_backward(
            ctx,
@@ -710,7 +717,17 @@ class FlexAttentionAutogradOp(torch.autograd.Function):

@flex_attention.py_impl(DispatchKey.Autograd)
+def _(*args: Any, **kwargs: Any) -> Any:
+    return flex_attention_autograd(None, *args, **kwargs)
+
+
+@flex_attention.py_impl(TransformType.Grad)
+def _(_ignored: FuncTorchInterpreter, *args: Any, **kwargs: Any) -> Any:
+    return flex_attention_autograd(None, *args, **kwargs)
+
+
def flex_attention_autograd(
+    _ignored: Literal[None],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
@@ -741,18 +758,19 @@ def flex_attention_autograd(
        )
    else:
        fw_graph, bw_graph = score_mod, None

-    out, logsumexp = FlexAttentionAutogradOp.apply(
-        query,
-        key,
-        value,
-        fw_graph,
-        bw_graph,
-        block_mask,
-        scale,
-        kernel_options,
-        mask_mod_other_buffers,
-        *score_mod_other_buffers,
-    )
+    with enable_single_level_autograd_function():
+        out, logsumexp = FlexAttentionAutogradOp.apply(  # type: ignore[attr-defined]
+            query,
+            key,
+            value,
+            fw_graph,
+            bw_graph,
+            block_mask,
+            scale,
+            kernel_options,
+            mask_mod_other_buffers,
+            *score_mod_other_buffers,
+        )

    return out, logsumexp
@@ -1163,3 +1181,9 @@ def flex_attention_backward_fake_tensor_mode(
flex_attention_backward.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(flex_attention_backward, deferred_error=True)
)
+
+flex_attention_backward.py_impl(TransformType.Grad)(
+    autograd_not_implemented(
+        flex_attention_backward, deferred_error=True, functorch=True
+    )
+)

View File

@@ -7,6 +7,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
+from torch._functorch.pyfunctorch import HOPAutoDispatchBelowAutograd
from torch._guards import detect_fake_mode
from torch._ops import OperatorBase
from torch._subclasses.fake_tensor import FakeTensor
@@ -36,34 +37,38 @@ def autograd_not_implemented_inner(
    Raises:
        RuntimeError: If autograd is enabled and any of the arguments to the Operator
    """
-    with torch._C._AutoDispatchBelowAutograd():
-        result = operator(*args, **kwargs)
-        flat_operands = pytree.arg_tree_leaves(*args)
-        if torch.is_grad_enabled() and any(
-            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
-        ):
-            if delayed_error:
-                err_fn = torch._C._functions.DelayedError(
-                    f"Autograd not implemented for {str(operator)}",
-                    1,
-                )
-
-                def fake_requires_grad(tensor):
-                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
-                        tensor = tensor.detach()
-                        tensor.requires_grad = True
-                    return tensor
-
-                return pytree.tree_map_only(
-                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
-                )
-            else:
-                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
-        return result
+    result = HOPAutoDispatchBelowAutograd(operator, *args, **kwargs)
+    flat_operands = pytree.arg_tree_leaves(*args)
+    if torch.is_grad_enabled() and any(
+        f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
+    ):
+        if delayed_error:
+            err_fn = torch._C._functions.DelayedError(
+                f"Autograd not implemented for {str(operator)}",
+                1,
+            )
+
+            def fake_requires_grad(tensor):
+                if torch.is_floating_point(tensor) or torch.is_complex(tensor):
+                    tensor = tensor.detach()
+                    tensor.requires_grad = True
+                return tensor
+
+            return pytree.tree_map_only(
+                torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
+            )
+        else:
+            raise RuntimeError(f"Autograd not implemented for {str(operator)}")
+    return result


-def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
+def autograd_not_implemented(
+    op: OperatorBase, deferred_error: bool, *, functorch: bool = False
+) -> Callable:
    def inner(*args, **kwargs):
+        if functorch:
+            # ignore first arg, which is the functorch interpreter
+            args = args[1:]
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner

View File

@@ -141,6 +141,7 @@ class OperatorBase:
                raise RuntimeError(
                    f"Trying to override a python impl for {k} on operator {self.name()}"
                )
+
            self.py_kernels[k] = fn
            self._dispatch_cache.clear()
            return fn

View File

@@ -1,6 +1,7 @@
from torch._functorch.apis import grad, grad_and_value, vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import (
+    debug_unwrap,
    functionalize,
    hessian,
    jacfwd,
@@ -26,4 +27,5 @@ __all__ = [
    "vjp",
    "functional_call",
    "stack_module_state",
+    "debug_unwrap",
]