# Owner(s): ["module: dynamo"]
# flake8: noqa: B950

import functools
import itertools
from unittest import mock

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch import _inductor as inductor
from torch._dynamo import compiled_autograd
from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped
from torch._dynamo.testing import normalize_gm
from torch.fx.experimental.proxy_tensor import make_fx


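# Helper hooks used throughout these tests: _multiply squares its input, and
# _multiply_invoke routes a gradient through trace_wrapped so the hook body is
# traceable in the backward graph.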
def _multiply(x):
    return x * x


def _multiply_invoke(grad):
    return trace_wrapped(grad, fn=_multiply)


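# Tests for the trace_wrapped higher-order op used by backward tensor hooks,
# both in eager mode and under torch.compile / compiled autograd.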
class BackwardHigherOrderOpTests(torch._dynamo.test_case.TestCase):
    def test_invoke_in_eager(self):
        x = torch.tensor([0.5, 0.5], requires_grad=True)
        y = torch.tensor([0.5, 0.5], requires_grad=True)

        def fn(x, y):
            x.register_hook(_multiply_invoke)
            return x * y

        out = fn(x, y)
        grad_out = torch.tensor([2.0, 2.0])
        out.backward(grad_out)
        self.assertEqual(x.grad, y * grad_out)

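    # Same hook, but fn is compiled by each backend before running backward.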
    def test_invoke_in_pt2(self):
        for backend in ["eager", "aot_eager", "inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            def fn(x, y):
                x.register_hook(_multiply_invoke)
                return x * y

            fn = torch.compile(fn, backend=backend)
            out = fn(x, y)
            grad_out = torch.tensor([2.0, 2.0])
            out.backward(grad_out)
            self.assertEqual(x.grad, grad_out * y)

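    # trace_wrapped should also be traceable directly through make_fx in a forward pass.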
    def test_invoke_make_fx_forward_contrived(self):
        x = torch.tensor([0.5, 0.5], requires_grad=True)
        out = make_fx(_multiply_invoke)(x)
        self.assertEqual(out(x), torch.tensor([0.25, 0.25]))
        actual = normalize_gm(out.print_readable(False))
        self.assertExpectedInline(
            actual,
            """\
class _multiply_invoke(torch.nn.Module):
    def forward(self, grad_1: "f32[2]"):
        trace_wrapped: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None
        return trace_wrapped
""",
        )

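    # make_fx over the hook applied to a real gradient produced by eager autograd.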
    def test_invoke_make_bw(self):
        x = torch.tensor([0.5, 0.5], requires_grad=True)

        def fwd(x):
            z = x * x
            return z + z

        res = fwd(x)
        res.backward(torch.tensor([1.0, 1.0]))
        out = make_fx(_multiply_invoke)(x.grad)
        self.assertEqual(out(x.grad), torch.tensor([4.0, 4.0]))
        actual = normalize_gm(out.print_readable(False))

        self.assertExpectedInline(
            actual,
            """\
class _multiply_invoke(torch.nn.Module):
    def forward(self, grad_1: "f32[2]"):
        trace_wrapped: "f32[2]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None
        return trace_wrapped
""",
        )

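    # End-to-end: the hook runs under compiled autograd, and the captured backward
    # graph is checked against expected output for the aot_eager and inductor backends.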
    @mock.patch(
        "torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count
    )
    def test_invoke_in_pt2_compiled_autograd(self, _):
        graph = None

        def compiler_fn(gm):
            def inner_compiler(gm_, example_inputs_):
                nonlocal graph
                self.assertEqual(graph, None)
                graph = gm_
                return inductor.compile(gm_, example_inputs_)

            return torch.compile(
                gm, backend=inner_compiler, fullgraph=True, dynamic=True
            )

        for backend in ["eager", "aot_eager", "inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            def fn(x, y):
                x.register_hook(_multiply_invoke)
                return x + y

            fn = torch.compile(fn, backend=backend)
            out = fn(x, y)
            grad_out = torch.tensor([2.0, 2.0])
            with compiled_autograd._enable(compiler_fn):
                out.backward(grad_out)
            actual = normalize_gm(graph.print_readable(False))
            self.assertEqual(x.grad, grad_out * grad_out)
            if backend == "aot_eager":
                self.assertExpectedInline(
                    actual,
                    """\
class GraphModule(torch.nn.Module):
    def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"):
        l_inputs_ = L_inputs_
        l_sizes_0_ = L_sizes_0_

        getitem: "f32[s21]" = l_inputs_[0]
        getitem_1: "f32[s21]" = l_inputs_[1]
        getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None

        size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None

        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
        getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None

        call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
        aot1_tangents_1: "f32[s21]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None

        accumulate_grad = torch__dynamo_compiled_autograd_ops_AccumulateGrad([aot1_tangents_1], getitem_1, None, False); getitem_1 = None
        getitem_11: "f32[s21]" = accumulate_grad[0]; accumulate_grad = None

        result: "f32[s21]" = aot1_tangents_1 * aot1_tangents_1; aot1_tangents_1 = None

        accumulate_grad_1 = torch__dynamo_compiled_autograd_ops_AccumulateGrad([result], getitem_2, None, False); result = getitem_2 = None
        getitem_12: "f32[s21]" = accumulate_grad_1[0]; accumulate_grad_1 = None
        return (getitem_11, getitem_12)
""",
                )
            elif backend == "inductor":
                self.assertExpectedInline(
                    actual,
                    """\
class GraphModule(torch.nn.Module):
    def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"):
        l_inputs_ = L_inputs_
        l_sizes_0_ = L_sizes_0_

        getitem: "f32[s21]" = l_inputs_[0]
        getitem_1: "f32[s21]" = l_inputs_[1]
        getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None

        size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None

        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
        getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None

        call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
        aot3_tangents_1: "f32[s21]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None

        accumulate_grad = torch__dynamo_compiled_autograd_ops_AccumulateGrad([aot3_tangents_1], getitem_1, None, False); getitem_1 = None
        getitem_11: "f32[s21]" = accumulate_grad[0]; accumulate_grad = None

        result: "f32[s21]" = aot3_tangents_1 * aot3_tangents_1; aot3_tangents_1 = None

        accumulate_grad_1 = torch__dynamo_compiled_autograd_ops_AccumulateGrad([result], getitem_2, None, False); result = getitem_2 = None
        getitem_12: "f32[s21]" = accumulate_grad_1[0]; accumulate_grad_1 = None
        return (getitem_11, getitem_12)
""",
                )

            graph = None

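    # Same as above, but the hook mutates Python state (obj.counter); compiled
    # autograd must capture the side effect and return the updated counter.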
    @mock.patch(
        "torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count
    )
    def test_invoke_in_pt2_compiled_autograd_side_effect(self, _):
        def _side_effect_stateful_fn2(x, obj):
            obj.counter = obj.counter + 1
            return _multiply(x)

        def _side_effectful_invoke2(grad, fn):
            return trace_wrapped(grad, fn=fn)

        graph = None

        def compiler_fn(gm):
            def inner_compiler(gm_, example_inputs_):
                nonlocal graph
                self.assertEqual(graph, None)
                graph = gm_
                return inductor.compile(gm_, example_inputs_)

            return torch.compile(
                gm, backend=inner_compiler, fullgraph=True, dynamic=True
            )

        for backend in ["inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            class MyObj:
                def __init__(self) -> None:
                    self.counter = 0

            obj = MyObj()
            inner_fn = functools.partial(_side_effect_stateful_fn2, obj=obj)
            hook_fn = functools.partial(_side_effectful_invoke2, fn=inner_fn)
            x.register_hook(hook_fn)

            def fn(x, y):
                return x + y

            fn = torch.compile(fn, backend=backend, fullgraph=True)
            out = fn(x, y)
            grad_out = torch.tensor([2.0, 2.0])
            with compiled_autograd._enable(compiler_fn):
                out.backward(grad_out)
            actual = normalize_gm(graph.print_readable(False))
            self.assertEqual(obj.counter, 1)
            self.assertEqual(x.grad, grad_out + grad_out)
            if backend in ["aot_eager", "inductor"]:
                self.assertExpectedInline(
                    actual,
                    """\
class GraphModule(torch.nn.Module):
    def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
        l_inputs_ = L_inputs_
        l_sizes_0_ = L_sizes_0_
        l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter

        getitem: "f32[s21]" = l_inputs_[0]
        getitem_1: "f32[s21]" = l_inputs_[1]
        getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None

        size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None

        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False, 6)]); getitem = size = None
        getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None

        call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
        aot0_tangents_1: "f32[s21]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None

        accumulate_grad = torch__dynamo_compiled_autograd_ops_AccumulateGrad([aot0_tangents_1], getitem_1, None, False); getitem_1 = None
        getitem_11: "f32[s21]" = accumulate_grad[0]; accumulate_grad = None

        add: "Sym(s45 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None

        result: "f32[s21]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None

        accumulate_grad_1 = torch__dynamo_compiled_autograd_ops_AccumulateGrad([result], getitem_2, None, False); result = getitem_2 = None
        getitem_12: "f32[s21]" = accumulate_grad_1[0]; accumulate_grad_1 = None
        return (getitem_11, getitem_12, add)
""",
                )

            out = fn(x, y)
            out.backward(grad_out)
            self.assertEqual(obj.counter, 2)

            out = fn(x, y)
            out.backward(grad_out)
            self.assertEqual(obj.counter, 3)
            graph = None

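    # A hook that graph-breaks (via print) cannot be captured with fullgraph=True,
    # so running backward under compiled autograd should raise Unsupported.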
    def test_invoke_in_pt2_compiled_autograd_graph_breaks(self):
        def _graph_breaking_fn(x):
            print("Boo!")
            return _multiply(x)

        def _graph_break_invoke(grad):
            return trace_wrapped(grad, fn=_graph_breaking_fn)

        def compiler_fn(gm):
            return torch.compile(gm, backend="inductor", fullgraph=True, dynamic=True)

        for backend in ["eager", "aot_eager", "inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            def fn(x, y):
                x.register_hook(_graph_break_invoke)
                return x + y

            fn = torch.compile(fn, backend=backend, fullgraph=True)
            out = fn(x, y)
            grad_out = torch.tensor([2.0, 2.0])
            with self.assertRaisesRegex(
                torch._dynamo.exc.Unsupported,
                "print",
            ):
                with compiled_autograd._enable(compiler_fn):
                    out.backward(grad_out)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()