[compiled autograd] support .backward(inputs=) (#128252)

autograd already marks nodes as needed or not before calling calling compiled autograd. so our worklist already skips nodes not specified in the `inputs` kwarg.

For the .backward(inputs=) case, I'm keeping the grads as outputs, just like for .grad(inputs=), this is to still guard on graph_output when we collect the nodes. This does not get DCE'd rn, and is ignored in the post graph bytecode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128252
Approved by: https://github.com/jansel
This commit is contained in:
Simon Fan
2024-06-07 16:26:47 -07:00
committed by PyTorch MergeBot
parent 583a56d5a8
commit 2176ef7dfa
3 changed files with 142 additions and 28 deletions

View File

@ -581,7 +581,7 @@ main()
self.check_output_and_recompiles(fn)
def test_output_nodes(self):
def test_output_nodes_all_leaves(self):
def fn():
y = torch.randn(1, 4, requires_grad=True)
z = torch.randn(1, 4, requires_grad=True)
@ -593,7 +593,7 @@ main()
x = torch.randn([1, 4])
result = model(x).sum()
gy, gz = torch.autograd.grad(result, [y, z])
gy, gz = torch.autograd.grad(result, inputs=[y, z])
assert y.grad is None
assert z.grad is None
yield gy
@ -601,6 +601,111 @@ main()
self.check_output_and_recompiles(fn)
def test_output_nodes_some_leaves(self):
def fn():
class UnreachableBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, gO):
raise RuntimeError
y = torch.randn(1, 4, requires_grad=True)
z = torch.randn(1, 4, requires_grad=True)
def model(x):
return torch.sigmoid(UnreachableBwd.apply(y) * z)
for _ in range(3):
x = torch.randn([1, 4])
result = model(x).sum()
gz = torch.autograd.grad(result, inputs=[z])
assert y.grad is None
assert z.grad is None
yield gz
self.check_output_and_recompiles(fn)
def test_no_output_nodes_all_leaves(self):
def fn():
y = torch.randn(1, 4, requires_grad=True)
z = torch.randn(1, 4, requires_grad=True)
def model(x):
return torch.sigmoid(x * z + torch.sin(y) + torch.cos(y))
for _ in range(3):
x = torch.randn([1, 4])
result = model(x).sum()
out = result.backward()
assert out is None
assert y.grad is not None
assert z.grad is not None
yield y.grad
yield z.grad
y.grad = None
z.grad = None
self.check_output_and_recompiles(fn)
def test_no_output_nodes_some_leaves(self):
def fn():
class UnreachableBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, gO):
raise RuntimeError
y = torch.randn(1, 4, requires_grad=True)
z = torch.randn(1, 4, requires_grad=True)
a = torch.randn(1, 4, requires_grad=True)
def model(x):
return torch.sigmoid(x * y * z * UnreachableBwd.apply(a))
for _ in range(3):
x = torch.randn([1, 4])
result = model(x).sum()
out = result.backward(inputs=[y, z])
assert out is None
assert y.grad is not None
assert z.grad is not None
assert a.grad is None
yield y.grad
yield z.grad
y.grad = None
z.grad = None
self.check_output_and_recompiles(fn)
def test_no_output_nodes_different_leaves_will_recompile(self):
def fn():
def fwd(x, y, z):
out = x * y # MulBackward0
out2 = out * z # MulBackward0
return out2.sum() # SumBackward0
x = torch.randn(5, requires_grad=True)
y = torch.randn(5, requires_grad=True)
z = torch.randn(5, requires_grad=True)
loss = fwd(x, y, z)
torch.compile(lambda: torch.autograd.backward(loss, inputs=[x]))()
yield x.grad
x.grad = None
loss = fwd(x, y, z)
torch.compile(lambda: torch.autograd.backward(loss, inputs=[y]))()
yield y.grad
# Guarded by TensorArg id, mismatch on last MulBackward0
self.check_output_and_recompiles(fn, 2)
def test_dynamic_shapes(self):
def fn():
model = torch.nn.Sequential(
@ -1986,7 +2091,18 @@ def wrap_test_class(orig_cls):
return cls
known_graph_breaks_tests = {}
known_graph_breaks_tests = {
"test_hook_none", # uses assert in hook
"test_post_accumulate_grad_hook_e2e", # optim.Adam manually graph breaks
"test_tensor_hooks_inplace", # uses assert in hook
"test_tensor_hooks_inplace_over_view", # uses assert in hook
"test_grad_fn_prehooks", # uses assert in hook
"test_grad_fn_prehooks_multiple_outputs", # uses assert in hook
"test_grad_fn_prehooks_remove_hooks", # uses handle.remove() in hook
"test_tensor_hooks_inplace_multiple_outputs", # uses assert in hook
"test_hooks", # uses assert in hook
"test_accumulate_grad_posthooks_can_observe_tensor_prehook", # allclose
}
# These groups of tests aren't supported yet
known_failures_re = re.compile(
@ -2004,23 +2120,14 @@ known_failing_tests = {
"test_saved_variable_saved_original_inplace_detach", # AssertionError: RuntimeError not raised
"test_saving_variable_to_disk", # Cannot call numel() on tensor with symbolic sizes/strides
"test_setitem_mask", # torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: It appears that you're
"test_tensor_hooks_inplace_over_view", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
"test_tensor_hooks_inplace", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
"test_wrapped_number_saved_variable_hooks", # RuntimeError: this hook should not be called
"test_accumulate_grad_posthooks_can_observe_tensor_prehook", # data dependent operator: aten.allclose.default
"test_accumulate_grad_tensor_reference", # backend='inner_compiler' raised:
"test_anomaly_grad_warnings", # "one of the variables needed for gradient computation has been modified by an...
"test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
"test_backward_with_inputs", # specifying inputs= with .backward() not yet implemented for compiled autograd
"test_current_node", # TorchDispatchMode not yet implemented for compiled autograd
"test_custom_function_exception", # "Simulate error on backward pass" does not match "type object 'SimulateBackwa...
"test_grad_batched_grad", # Cannot access storage of BatchedTensorImpl
"test_grad_unreachable_discovery", # specifying inputs= with .backward() not yet implemented for compiled autograd
"test_index_backward_does_not_save_tensor", # dynamic shape operator: aten.nonzero.default
"test_post_accumulate_grad_hook_e2e", # tensor_post_acc_grad_hooks not implemented for compiled autograd
"test_post_accumulate_grad_hook_gets_cleaned_up", # tensor_post_acc_grad_hooks not implemented for compiled autograd
"test_post_accumulate_grad_hook_multiple_hooks", # tensor_post_acc_grad_hooks not implemented for compiled autograd
"test_post_accumulate_grad_hook_multiple_tensors", # tensor_post_acc_grad_hooks not implemented for compiled autograd
"test_post_accumulate_grad_hook_ordering", # tensor_post_acc_grad_hooks not implemented for compiled autograd
"test_post_accumulate_grad_hook_returns_not_None", # "hooks should return None." does not match
"test_reentrant_child_error", # "Simulate error" does not match "type object 'ReentrantFunc' has no attribute...
@ -2052,21 +2159,20 @@ known_failing_tests = {
"test_hessian_vector", # RuntimeError: compiled_autograd does not support create_graph
"test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_False", # AttributeError: type object
"test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_True", # AttributeError: type object
"test_hook_edge_case_when_called_with_grad", # RuntimeError: specifying inputs= with .backward() not yet
"test_hooks", # torch._dynamo.exc.Unsupported: inline in skipfiles
"test_hook_edge_case_when_called_with_grad", # retains_grad_hooks NYI
"test_inplace_on_view_backward", # RuntimeError: compiled_autograd does not support create_graph
"test_multi_grad_any_hooks", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd
"test_multi_grad_all_hooks", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd
"test_multi_grad_any_hooks", # register_multi_grad_hook NYI
"test_multi_grad_all_hooks", # retains_grad_hooks NYI
"test_nested_anomaly_detect_nan", # RuntimeError: compiled_autograd does not support create_graph
"test_nested_anomaly_printstack_cleanup", # RuntimeError: compiled_autograd does not support create_graph
"test_once_differentiable", # RuntimeError: compiled_autograd does not support create_graph
"test_prehook_ordering", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd
"test_prehook_ordering", # retains_grad_hooks NYI
"test_retain_grad", # RuntimeError: retains_grad_hooks not implemented for compiled autograd
"test_saved_variable_packing_unpacking_saved_original_with_hooks", # RuntimeError: compiled_autograd
"test_select_sum", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
"test_unrelated_inputs", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
"test_will_engine_execute_node", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd
"test_backward_to_node", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd
"test_will_engine_execute_node", # retains_grad_hooks NYI
"test_backward_to_node", # retains_grad_hooks NYI
"test_anomaly_detect_nan", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function aten.add.Tensor(
"test_autograd_multiple_views_python", # torch._dynamo.exc.Unsupported: call_function args: TensorVariable(
"test_autograd_node_isinstance", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsInstance
@ -2083,11 +2189,7 @@ known_failing_tests = {
"test_deep_reentrant", # torch._dynamo.exc.InternalTorchDynamoError: '<' not supported between instances of
"test_dont_materialize_grads", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNone
"test_function_returns_undefined_tensor", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function
"test_grad_fn_prehooks", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
"test_grad_fn_prehooks_multiple_outputs", # torch._dynamo.exc.Unsupported: 'inline in skipfiles:
"test_grad_fn_prehooks_remove_hooks", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: RemovableHandle.remove
"test_grad_mode_restored_reentrant", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertTrue
"test_hook_none", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNotNone
"test_invalid_gradients", # AssertionError: "expected shape" does not match "The size of tensor a (5) must match
"test_mark_non_differentiable_mixed", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertTrue
"test_materialize_grads", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
@ -2107,7 +2209,6 @@ known_failing_tests = {
"test_set_materialize_non_diff_grads", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNone
"test_setup_context_when_forward_has_default_args", # torch._dynamo.exc.Unsupported: call_function args
"test_simple_reentrant", # torch._dynamo.exc.Unsupported: call_method SkipFunctionVariable() sum [] {}
"test_tensor_hooks_inplace_multiple_outputs", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
"test_lobpcg", # torch._dynamo.exc.Unsupported: 'call_function LOBPCGAutogradFunction.backward in skip_files
"test_backward_dict_grad_for_nontensor", # AssertionError: "non-Tensor-like types" does not match "'skip function
"test_backward_dict_invalid_keys", # AssertionError: "to have keys {'x'}" does not match "'skip function
@ -2120,7 +2221,6 @@ known_failing_tests = {
"test_backward_tensorlist_input_requires_list_grads_none_or_Tensor", # AssertionError: "None or Tensor"
"test_backward_tensorlist_input_requires_list_grads_with_same_numel", # AssertionError: "3 gradients
"test_save_for_backward_inputs_are_namedtuple", # torch._dynamo.exc.Unsupported: 'skip function
"test_autograd_function_backed_op", # RuntimeError: compiled_args not implemented
"test_setitem", # AssertionError: Tensor-likes are not close!
"test_grad_nonleaf_register_hook", # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
"test_scalar_grad_mixed_device", # Fake Tensors aren't propagating device properly for 0-dim grads

View File

@ -1342,6 +1342,23 @@ class TestAutograd(TestCase):
b.backward()
def test_accumulate_grad_posthooks_should_not_execute(self):
def tensor_prehook(g):
raise RuntimeError
def posthook(gO, gI):
raise RuntimeError
a = torch.tensor(1.0, requires_grad=True)
a.register_hook(tensor_prehook)
b = torch.tensor(1.0, requires_grad=True)
c = a.clone()
acc = c.grad_fn.next_functions[0][0]
acc.register_hook(posthook)
out = a + b + c
out.sum().backward(inputs=[b])
def test_hook_edge_case_when_called_with_grad(self):
# grad executes the tensor hooks of the next node but not
# grad_fn pre hooks or the post hooks

View File

@ -630,9 +630,6 @@ variable_list compiled_autograd(
GraphTask& graph_task,
bool accumulate_grad,
const edge_list& output_edges) {
TORCH_CHECK(
output_edges.empty() || !accumulate_grad,
"specifying inputs= with .backward() not yet implemented for compiled autograd")
TORCH_CHECK(
c10::impl::TorchDispatchModeTLS::stack_len() == 0,
"TorchDispatchMode not yet implemented for compiled autograd")