Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[HOP] Reworked DispatchKey.Autograd (#151107)
This PR reworks the dispatching of the autograd key. Previously, DispatchKey.Autograd of the HOPs was triggered even if none of the operands of the HOP had `requires_grad=True`. With this rework, autograd is bypassed if none of the operands require gradients, and it is only invoked if at least one operand requires gradients.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151107
Approved by: https://github.com/ydwu4
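For illustration, the new dispatch decision can be summarized by the check below (a minimal sketch; `should_run_autograd` is a hypothetical helper written only to demonstrate the condition, not code added by this PR):

    import torch
    import torch.utils._pytree as pytree

    def should_run_autograd(*operands):
        # The Autograd key is only exercised when grad mode is enabled AND at
        # least one tensor operand has requires_grad=True; otherwise dispatch
        # drops below the Autograd key.
        return torch.is_grad_enabled() and not pytree.tree_all_only(
            torch.Tensor, lambda t: not t.requires_grad, operands
        )

    x = torch.randn(3)
    y = torch.randn(3, requires_grad=True)
    print(should_run_autograd(x))     # False -> DispatchKey.Autograd is bypassed
    print(should_run_autograd(x, y))  # True  -> the autograd implementation runs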
Committed by: PyTorch MergeBot
Parent: 19a33b20c2
Commit: a2632d5241
@@ -401,7 +401,7 @@ def associative_scan_op_dense(combine_fn, xs, additional_inputs):
     return generic_associative_scan(combine_fn, xs, additional_inputs=additional_inputs)


-associative_scan_op.py_impl(DispatchKey.Autograd)(
+associative_scan_op.py_autograd_impl(
     autograd_not_implemented(associative_scan_op, deferred_error=True)
 )

@@ -56,7 +56,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):

         # Set up the registrations
         # If you want to override any of these, override them in your subclass.
-        self.py_impl(DispatchKey.Autograd)(self._call_Autograd)
+        self.py_autograd_impl(self._call_Autograd)
         self.py_functionalize_impl(self._call_Functionalize)
         self.py_impl(ProxyTorchDispatchMode)(self._call_ProxyTorchDispatchMode)
         self.py_impl(FakeTensorMode)(self._call_FakeTensorMode)
@@ -76,13 +76,6 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
     def _call_Autograd(self, subgraph, *operands, **kwargs):
         if isinstance(subgraph, torch.fx.GraphModule):
             pass
-        if not torch.is_grad_enabled() or pytree.tree_all_only(
-            torch.Tensor,
-            lambda t: not t.requires_grad,  # type: ignore[union-attr]
-            operands,
-        ):
-            with torch._C._AutoDispatchBelowAutograd():
-                return self(subgraph, *operands, **kwargs)

         # We assume the subgraph doesn't mutate inputs and there is no aliasing.
         # In the PT2 stack, this is Dynamo's responsibility to figure out.
@@ -391,7 +391,7 @@ class CondAutogradOp(torch.autograd.Function):
 # As long as one of the tensors in pred or operands requires grad,
 # all the output would require grad with backward fn set to be the CondAutogradOp.
 # This is consistent with autograd.Function's semantic.
-@cond_op.py_impl(DispatchKey.Autograd)
+@cond_op.py_autograd_impl
 def cond_autograd(pred, true_fn, false_fn, operands):
     return CondAutogradOp.apply(
         pred,
@@ -87,7 +87,7 @@ def call_delegate_cpu(lowered_module, *args):
     return lowered_module.original_module.module()(*new_args)


-@executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd)
+@executorch_call_delegate.py_autograd_impl
 # pyre-ignore
 def call_delegate_autograd(lowered_module, *args):
     # TODO: support autograd
@@ -709,6 +709,7 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
         return grad_query, grad_key, grad_value, *none_grads, *grad_score_mod_captured


+# TODO: Rework DispatchKey.Autograd to py_autograd_impl
 @flex_attention.py_impl(DispatchKey.Autograd)
 def flex_attention_autograd(
     query: torch.Tensor,
@@ -1190,6 +1191,6 @@ def flex_attention_backward_fake_tensor_mode(
     return grad_query, grad_key, grad_value, grad_score_mod_captured


-flex_attention_backward.py_impl(DispatchKey.Autograd)(
+flex_attention_backward.py_autograd_impl(
     autograd_not_implemented(flex_attention_backward, deferred_error=True)
 )
@@ -77,7 +77,7 @@ def hints_wrapper_dense(body_fn, args, kwargs, hints):
     return body_fn(*args, **kwargs)


-hints_wrapper.py_impl(DispatchKey.Autograd)(
+hints_wrapper.py_autograd_impl(
     autograd_not_implemented(hints_wrapper, deferred_error=True)
 )

@@ -447,22 +447,8 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
         return None, None, None, *grads


-@invoke_subgraph.py_impl(DispatchKey.Autograd)
+@invoke_subgraph.py_autograd_impl
 def _(subgraph, identifier, operands):
-    if not torch.is_grad_enabled():
-        with torch._C._AutoDispatchBelowAutograd():
-            return invoke_subgraph(subgraph, identifier, operands)
-
-    # A shortcut for the case where all inputs don't require gradient,
-    # we skip tracing the forward and backward graph.
-    if pytree.tree_all_only(
-        torch.Tensor,
-        lambda t: not t.requires_grad,  # type: ignore[union-attr]
-        operands,
-    ):
-        with torch._C._AutoDispatchBelowAutograd():
-            return invoke_subgraph(subgraph, identifier, operands)
-
     # Check if we have already traced the subgraph.
     invoke_subgraph_cache = get_invoke_subgraph_cache()
     if invoke_subgraph_cache:
@@ -221,6 +221,7 @@ def map_dense(f, xs, pos_args):
     return _stack_pytree(pytrees)


+# TODO: Rework DispatchKey.Autograd to py_autograd_impl
 @map_impl.py_impl(DispatchKey.Autograd)
 def map_autograd(f, xs, pos_args):
     num_mapped_args = len(xs)
@@ -130,9 +130,7 @@ def out_dtype_fallback(op, output_dtype, *args):
     return res


-out_dtype.py_impl(DispatchKey.Autograd)(
-    autograd_not_implemented(out_dtype, deferred_error=True)
-)
+out_dtype.py_autograd_impl(autograd_not_implemented(out_dtype, deferred_error=True))


 @out_dtype.py_impl(ProxyTorchDispatchMode)
@@ -42,7 +42,7 @@ def run_const_graph_functional(ctx, graph, args):
     return ctx.wrap_tensors(out)


-run_const_graph.py_impl(DispatchKey.Autograd)(
+run_const_graph.py_autograd_impl(
     autograd_not_implemented(run_const_graph, deferred_error=True)
 )

@@ -804,7 +804,7 @@ class ScanAutogradOp(torch.autograd.Function):
         return *[None] * 4, *g_init, *g_xs, *g_additional_inputs


-@scan_op.py_impl(DispatchKey.Autograd)
+@scan_op.py_autograd_impl
 def scan_autograd(combine_fn, init, xs, additional_inputs):
     if not any(
         el.requires_grad
@@ -60,7 +60,7 @@ def strict_mode_op_dense(callable, operands):
     return callable(*operands)


-strict_mode_op.py_impl(DispatchKey.Autograd)(
+strict_mode_op.py_autograd_impl(
     autograd_not_implemented(strict_mode_op, deferred_error=True)
 )

@@ -150,7 +150,7 @@ def call_torchbind_fake(mode, *args, **kwargs):
     )


-call_torchbind.py_impl(DispatchKey.Autograd)(
+call_torchbind.py_autograd_impl(
     autograd_not_implemented(call_torchbind, deferred_error=True)
 )

@@ -215,7 +215,7 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
     return carried_vals


-while_loop_op.py_impl(DispatchKey.Autograd)(
+while_loop_op.py_autograd_impl(
     autograd_not_implemented(while_loop_op, deferred_error=True)
 )

@@ -306,6 +306,25 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
         self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
         return super().py_impl(k)

+    def py_autograd_impl(
+        self,
+        fn: Callable[_P, _T],
+    ) -> Callable[_P, _T]:
+        def maybe_run_autograd(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+            if not torch.is_grad_enabled() or pytree.tree_all_only(
+                torch.Tensor,
+                lambda t: not t.requires_grad,  # type: ignore[union-attr]
+                (*args, kwargs),
+            ):
+                with torch._C._AutoDispatchBelowAutograd():
+                    return self(*args, **kwargs)
+
+            return fn(*args, **kwargs)
+
+        self.py_impl(DispatchKey.Autograd)(maybe_run_autograd)
+
+        return fn
+
     @property
     def namespace(self):
         return self._ns