[HOP] Reworked DispatchKey.Autograd (#151107)

This PR reworks the dispatching of the autograd key for HOPs.
Previously, DispatchKey.Autograd of a HOP was triggered even if none of the HOP's operands had `requires_grad=True`. With this rework, autograd is bypassed when none of the operands require gradients and is only invoked when at least one operand does.
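Concretely, a minimal standalone sketch of the gating this description refers to (the helper `needs_autograd` is made up here for illustration; the actual check lives in the new `HigherOrderOperator.py_autograd_impl` in the last hunk of this diff):

```python
import torch
import torch.utils._pytree as pytree


def needs_autograd(*args, **kwargs) -> bool:
    # The autograd rule is only needed when grad mode is enabled and at least
    # one tensor operand has requires_grad=True.
    return torch.is_grad_enabled() and not pytree.tree_all_only(
        torch.Tensor,
        lambda t: not t.requires_grad,
        (*args, kwargs),
    )


x = torch.randn(3)                      # requires_grad defaults to False
y = torch.randn(3, requires_grad=True)
print(needs_autograd(x))     # False -> DispatchKey.Autograd is bypassed
print(needs_autograd(x, y))  # True  -> the registered autograd impl runs
```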

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151107
Approved by: https://github.com/ydwu4
Author: Thomas Bohnstingl
Date: 2025-04-15 19:55:42 +00:00
Committed by: PyTorch MergeBot
Parent: 19a33b20c2
Commit: a2632d5241
15 changed files with 34 additions and 36 deletions

View File

@ -401,7 +401,7 @@ def associative_scan_op_dense(combine_fn, xs, additional_inputs):
return generic_associative_scan(combine_fn, xs, additional_inputs=additional_inputs)
-associative_scan_op.py_impl(DispatchKey.Autograd)(
+associative_scan_op.py_autograd_impl(
autograd_not_implemented(associative_scan_op, deferred_error=True)
)

View File

@ -56,7 +56,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
# Set up the registrations
# If you want to override any of these, override them in your subclass.
-self.py_impl(DispatchKey.Autograd)(self._call_Autograd)
+self.py_autograd_impl(self._call_Autograd)
self.py_functionalize_impl(self._call_Functionalize)
self.py_impl(ProxyTorchDispatchMode)(self._call_ProxyTorchDispatchMode)
self.py_impl(FakeTensorMode)(self._call_FakeTensorMode)
@ -76,13 +76,6 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
def _call_Autograd(self, subgraph, *operands, **kwargs):
if isinstance(subgraph, torch.fx.GraphModule):
pass
-if not torch.is_grad_enabled() or pytree.tree_all_only(
-    torch.Tensor,
-    lambda t: not t.requires_grad,  # type: ignore[union-attr]
-    operands,
-):
-    with torch._C._AutoDispatchBelowAutograd():
-        return self(subgraph, *operands, **kwargs)
# We assume the subgraph doesn't mutate inputs and there is no aliasing.
# In the PT2 stack, this is Dynamo's responsibility to figure out.

View File

@ -391,7 +391,7 @@ class CondAutogradOp(torch.autograd.Function):
# As long as one of the tensors in pred or operands requires grad,
# all the output would require grad with backward fn set to be the CondAutogradOp.
# This is consistent with autograd.Function's semantic.
-@cond_op.py_impl(DispatchKey.Autograd)
+@cond_op.py_autograd_impl
def cond_autograd(pred, true_fn, false_fn, operands):
return CondAutogradOp.apply(
pred,

View File

@ -87,7 +87,7 @@ def call_delegate_cpu(lowered_module, *args):
return lowered_module.original_module.module()(*new_args)
-@executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd)
+@executorch_call_delegate.py_autograd_impl
# pyre-ignore
def call_delegate_autograd(lowered_module, *args):
# TODO: support autograd

View File

@ -709,6 +709,7 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
return grad_query, grad_key, grad_value, *none_grads, *grad_score_mod_captured
+# TODO: Rework DispatchKey.Autograd to py_autograd_impl
@flex_attention.py_impl(DispatchKey.Autograd)
def flex_attention_autograd(
query: torch.Tensor,
@ -1190,6 +1191,6 @@ def flex_attention_backward_fake_tensor_mode(
return grad_query, grad_key, grad_value, grad_score_mod_captured
-flex_attention_backward.py_impl(DispatchKey.Autograd)(
+flex_attention_backward.py_autograd_impl(
autograd_not_implemented(flex_attention_backward, deferred_error=True)
)

View File

@ -77,7 +77,7 @@ def hints_wrapper_dense(body_fn, args, kwargs, hints):
return body_fn(*args, **kwargs)
-hints_wrapper.py_impl(DispatchKey.Autograd)(
+hints_wrapper.py_autograd_impl(
autograd_not_implemented(hints_wrapper, deferred_error=True)
)

View File

@ -447,22 +447,8 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
return None, None, None, *grads
-@invoke_subgraph.py_impl(DispatchKey.Autograd)
+@invoke_subgraph.py_autograd_impl
def _(subgraph, identifier, operands):
-if not torch.is_grad_enabled():
-    with torch._C._AutoDispatchBelowAutograd():
-        return invoke_subgraph(subgraph, identifier, operands)
-# A shortcut for the case where all inputs don't require gradient,
-# we skip tracing the forward and backward graph.
-if pytree.tree_all_only(
-    torch.Tensor,
-    lambda t: not t.requires_grad,  # type: ignore[union-attr]
-    operands,
-):
-    with torch._C._AutoDispatchBelowAutograd():
-        return invoke_subgraph(subgraph, identifier, operands)
# Check if we have already traced the subgraph.
invoke_subgraph_cache = get_invoke_subgraph_cache()
if invoke_subgraph_cache:

View File

@ -221,6 +221,7 @@ def map_dense(f, xs, pos_args):
return _stack_pytree(pytrees)
+# TODO: Rework DispatchKey.Autograd to py_autograd_impl
@map_impl.py_impl(DispatchKey.Autograd)
def map_autograd(f, xs, pos_args):
num_mapped_args = len(xs)

View File

@ -130,9 +130,7 @@ def out_dtype_fallback(op, output_dtype, *args):
return res
-out_dtype.py_impl(DispatchKey.Autograd)(
-    autograd_not_implemented(out_dtype, deferred_error=True)
-)
+out_dtype.py_autograd_impl(autograd_not_implemented(out_dtype, deferred_error=True))
@out_dtype.py_impl(ProxyTorchDispatchMode)

View File

@ -42,7 +42,7 @@ def run_const_graph_functional(ctx, graph, args):
return ctx.wrap_tensors(out)
-run_const_graph.py_impl(DispatchKey.Autograd)(
+run_const_graph.py_autograd_impl(
autograd_not_implemented(run_const_graph, deferred_error=True)
)

View File

@ -804,7 +804,7 @@ class ScanAutogradOp(torch.autograd.Function):
return *[None] * 4, *g_init, *g_xs, *g_additional_inputs
-@scan_op.py_impl(DispatchKey.Autograd)
+@scan_op.py_autograd_impl
def scan_autograd(combine_fn, init, xs, additional_inputs):
if not any(
el.requires_grad

View File

@ -60,7 +60,7 @@ def strict_mode_op_dense(callable, operands):
return callable(*operands)
-strict_mode_op.py_impl(DispatchKey.Autograd)(
+strict_mode_op.py_autograd_impl(
autograd_not_implemented(strict_mode_op, deferred_error=True)
)

View File

@ -150,7 +150,7 @@ def call_torchbind_fake(mode, *args, **kwargs):
)
-call_torchbind.py_impl(DispatchKey.Autograd)(
+call_torchbind.py_autograd_impl(
autograd_not_implemented(call_torchbind, deferred_error=True)
)

View File

@ -215,7 +215,7 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
return carried_vals
-while_loop_op.py_impl(DispatchKey.Autograd)(
+while_loop_op.py_autograd_impl(
autograd_not_implemented(while_loop_op, deferred_error=True)
)

View File

@ -306,6 +306,25 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
return super().py_impl(k)
+def py_autograd_impl(
+    self,
+    fn: Callable[_P, _T],
+) -> Callable[_P, _T]:
+    def maybe_run_autograd(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+        if not torch.is_grad_enabled() or pytree.tree_all_only(
+            torch.Tensor,
+            lambda t: not t.requires_grad,  # type: ignore[union-attr]
+            (*args, kwargs),
+        ):
+            with torch._C._AutoDispatchBelowAutograd():
+                return self(*args, **kwargs)
+        return fn(*args, **kwargs)
+    self.py_impl(DispatchKey.Autograd)(maybe_run_autograd)
+    return fn
@property
def namespace(self):
return self._ns
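As a usage sketch of the new helper: `my_hop` below is a hypothetical HOP (not part of this PR), and the import path of `autograd_not_implemented` is assumed from the call sites in the hunks above; only `py_autograd_impl` itself comes from this diff.

```python
import torch
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator


class MyHop(HigherOrderOperator):
    # Hypothetical HOP, defined only to show the registration pattern.
    def __init__(self):
        super().__init__("my_hop")


my_hop = MyHop()

# HOPs without a real backward defer the error; after this PR the deferred
# NotImplementedError is only reachable when grad mode is enabled and at
# least one operand requires gradients; otherwise dispatch drops below the
# Autograd key and re-invokes my_hop there.
my_hop.py_autograd_impl(autograd_not_implemented(my_hop, deferred_error=True))
```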