Change aot_module_simplified to take arguments directly (#89669)

This is extracted from voz's #89392

Previously, the implementation did some half-assed caching where it
returned a callable that, when invoked for the first time, actually
performed the compilation.  Delaying the compilation like this seems
totally unnecessary.  To make matters worse, it has a cost (we have to
check whether we hit the cache) and is unsound (the compiled function
may not be valid for other arguments).

So instead, we ask the user to provide the arguments up front and
compile everything immediately.
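
For illustration, a minimal sketch of the new calling pattern (the module
F, the tensors, and the functorch.compile import path are illustrative,
not part of this diff; nop is the no-op compiler used by the tests):

    import torch
    from functorch.compile import aot_module_simplified, nop

    class F(torch.nn.Module):
        def forward(self, x, y):
            return (x + y,)

    x = torch.randn(3, 3, requires_grad=True)
    y = torch.randn(3, 3, requires_grad=True)

    # Example inputs are now passed up front; compilation happens
    # immediately inside this call rather than on the first invocation.
    fxy = aot_module_simplified(F(), (x, y), nop)

    # The returned callable takes the runtime inputs positionally.
    out = fxy(x, y)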

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89669
Approved by: https://github.com/voznesenskym, https://github.com/Chillee
Author: Edward Z. Yang
Date: 2022-11-28 14:57:42 +00:00
Committed by: PyTorch MergeBot
Parent: b589e726d9
Commit: abf91562bd

4 changed files with 37 additions and 30 deletions

@@ -1854,6 +1854,7 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
 def aot_module_simplified(
     mod: nn.Module,
+    args,
     fw_compiler: Callable,
     bw_compiler: Optional[Callable] = None,
     partition_fn: Callable = default_partition,
@@ -1934,27 +1935,30 @@ def aot_module_simplified(
         aot_id=next(AOT_COUNTER),
     )
 
-    compiled_fn = None
+    full_args = []
+    full_args.extend(params_flat)
+    full_args.extend(args)
 
-    @wraps(functional_call)
-    def compiled_f(*args):
-        nonlocal compiled_fn
-        if compiled_fn is None:
-            compiled_fn = create_aot_dispatcher_function(
-                functional_call,
-                args,
-                aot_config,
-            )
-        return compiled_fn(args)
+    compiled_fn = create_aot_dispatcher_function(
+        functional_call,
+        full_args,
+        aot_config,
+    )
 
-    def forward(*args):
-        return compiled_f(
-            *params_flat,
-            *args,
-        )
+    # TODO: There is something deeply wrong here; compiled_fn running with
+    # the boxed calling convention, but aot_module_simplified somehow
+    # historically returned a function that was not the boxed calling
+    # convention. This should get fixed...
+    def forward(*runtime_args):
+        full_args = []
+        full_args.extend(params_flat)
+        full_args.extend(runtime_args)
+        return compiled_fn(full_args)
 
     # Just for convenience
     forward.zero_grad = mod.zero_grad
     forward.named_parameters = mod.named_parameters
 
     return forward
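
Note on the TODO above: create_aot_dispatcher_function returns a "boxed"
callable, i.e. one that takes a single flat list of arguments (parameters
first, then the runtime inputs), while the forward returned to callers keeps
the historical unboxed *args signature. A minimal sketch of that adapter
pattern, with hypothetical names (not code from this diff):

    def make_unboxed(boxed_fn, params_flat):
        # boxed_fn expects one flat list: [*params, *runtime_args]
        def forward(*runtime_args):
            return boxed_fn([*params_flat, *runtime_args])
        return forward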

@@ -380,6 +380,7 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
     return aot_module_simplified(
         gm,
+        example_inputs,
         fw_compiler=graph_saver_forward,
         bw_compiler=graph_saver_backward,
         partition_fn=graph_saver_joint,
@@ -387,6 +388,7 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
     )
 
+# WARNING: This isn't tested anywhere!!
 def graph_dumper_aot(current_name, folder_name, dump_example_input=False):
     """
     Dump the forward, backward, and joint computation graph.

@@ -995,11 +995,11 @@ def forward(self, primals_1, primals_2):
         x = torch.randn(3, 3, requires_grad=True)
         y = torch.randn(3, 3, requires_grad=True)
 
-        fxy = aot_module_simplified(F(), nop)
+        fxy = aot_module_simplified(F(), (x, y), nop)
         fxy(x, y)
         fxy(x, x)  # is ok!
 
-        fxx = aot_module_simplified(F(), nop)
+        fxx = aot_module_simplified(F(), (x, x), nop)
         fxx(x, x)
         self.assertExpectedRaisesInline(
             AssertionError, lambda: fxx(x, y),
@@ -1024,11 +1024,11 @@ def forward(self, primals_1, primals_2):
             self.assertEqual(r1, r2)
             self.assertEqual(g1, g2)
 
-        fxy = aot_module_simplified(F(), nop)
+        fxy = aot_module_simplified(F(), (x, y), nop)
         compare(F(), fxy, (x, y))
         compare(F(), fxy, (x, z))
 
-        fxz = aot_module_simplified(F(), nop)
+        fxz = aot_module_simplified(F(), (x, z), nop)
         compare(F(), fxz, (x, z))
         self.assertExpectedRaisesInline(
             AssertionError, lambda: fxz(x, y),
@@ -1537,9 +1537,9 @@ class TestAOTModuleSimplified(AOTTestCase):
         ref = mod(*inputs)
         ref[0].sum().backward()
 
-        aot_mod = aot_module_simplified(mod, nop)
-        aot_mod.zero_grad()
-        res = aot_mod(*cloned_inputs)
+        compiled_f = aot_module_simplified(mod, cloned_inputs, nop)
+        mod.zero_grad()
+        res = compiled_f(*cloned_inputs)
         res[0].sum().backward()
 
         assert torch.allclose(ref[0], res[0])
@@ -1577,12 +1577,12 @@ class TestAOTModuleSimplified(AOTTestCase):
             assert 'test_aotdispatch.py' in node.stack_trace
             return gm.forward  # return a python callable
 
-        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=assert_compiler)
 
         x = torch.randn(128, 20, requires_grad=True)
         y = torch.randn(128, 30, requires_grad=True)
         inputs = [x, y]
-        res = aot_mod(*inputs)
+        compiled_f = aot_module_simplified(mod, inputs, fw_compiler=assert_compiler, bw_compiler=assert_compiler)
+        res = compiled_f(*inputs)
         res[0].sum().backward()
 
     def test_aot_module_simplified_fake_tensor_gm_raises(self):
@@ -1615,14 +1615,14 @@ class TestAOTModuleSimplified(AOTTestCase):
         mod_fake = torch.fx.GraphModule(tracer.root, graph)
 
         self.assertExpectedRaisesInline(
-            AssertionError, lambda: aot_module_simplified(mod_fake, nop),
+            AssertionError, lambda: aot_module_simplified(mod_fake, (real_x,), nop),
             """Unexpected fake buffer y"""
         )
 
         # Counterfactual to ensure that the raise is only due to real vs fake
         # Run the same exact thing except with a real buffer.
         graph = tracer.trace(MockModule(real_y))
         mod_real = torch.fx.GraphModule(tracer.root, graph)
-        aot_module_simplified(MockModule(real_y), nop)
+        aot_module_simplified(MockModule(real_y), (real_x,), nop)
 
     def test_aot_module_deepcopy_fake_tensor_gm_raises(self):
         class MockModule(torch.nn.Module):
@@ -1644,7 +1644,7 @@ class TestAOTModuleSimplified(AOTTestCase):
         mod_fake = torch._dynamo.utils.deepcopy_to_fake_tensor(MockModule(real_y), fake_mode)
 
         self.assertExpectedRaisesInline(
-            AssertionError, lambda: aot_module_simplified(mod_fake, nop),
+            AssertionError, lambda: aot_module_simplified(mod_fake, (real_x,), nop),
             """Unexpected fake param linear.weight"""
         )

@@ -77,7 +77,8 @@ def aot_autograd(**kwargs):
         kwargs["bw_compiler"] = _wrapped_bw_compiler
 
         try:
-            cg = aot_module_simplified(gm, **kwargs)
+            # NB: NOT cloned!
+            cg = aot_module_simplified(gm, example_inputs, **kwargs)
             counters["aot_autograd"]["ok"] += 1
             return eval_frame.disable(cg)
         except Exception: