Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
Change aot_module_simplified to take arguments directly (#89669)
This is extracted from voz's #89392.

Previously, the implementation did some half-assed caching: it returned a callable that, when invoked for the first time, actually performed the compilation. Delaying the compilation like this seems totally unnecessary. To make matters worse, it has a cost (we have to check whether we hit the cache) and it is unsound (the compiled function may not be valid for other arguments). So instead, we ask the user to provide the arguments up front and compile everything immediately.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89669
Approved by: https://github.com/voznesenskym, https://github.com/Chillee
Commit abf91562bd (parent b589e726d9), committed by PyTorch MergeBot.
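A minimal sketch of the old versus new call pattern, modeled on the test updates in this commit (the module F and the nop compiler mirror the tests below; this snippet is illustrative, not code from the PR):

import torch
from functorch.compile import aot_module_simplified, nop

class F(torch.nn.Module):
    def forward(self, x, y):
        # The tests in this PR return a tuple of tensors from forward.
        return (x * y,)

x = torch.randn(3, 3, requires_grad=True)
y = torch.randn(3, 3, requires_grad=True)

# Before this PR: no arguments up front, compilation deferred to the first call:
#     fxy = aot_module_simplified(F(), nop)
# After this PR: example arguments are passed directly and compilation happens
# immediately inside aot_module_simplified:
fxy = aot_module_simplified(F(), (x, y), nop)
res = fxy(x, y)
res[0].sum().backward()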
@@ -1854,6 +1854,7 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
 
 def aot_module_simplified(
     mod: nn.Module,
+    args,
     fw_compiler: Callable,
     bw_compiler: Optional[Callable] = None,
     partition_fn: Callable = default_partition,
@@ -1934,27 +1935,30 @@ def aot_module_simplified(
         aot_id=next(AOT_COUNTER),
     )
 
-    compiled_fn = None
+    full_args = []
+    full_args.extend(params_flat)
+    full_args.extend(args)
 
-    @wraps(functional_call)
-    def compiled_f(*args):
-        nonlocal compiled_fn
-        if compiled_fn is None:
-            compiled_fn = create_aot_dispatcher_function(
-                functional_call,
-                args,
-                aot_config,
-            )
-        return compiled_fn(args)
+    compiled_fn = create_aot_dispatcher_function(
+        functional_call,
+        full_args,
+        aot_config,
+    )
 
-    def forward(*args):
-        return compiled_f(
-            *params_flat,
-            *args,
-        )
+    # TODO: There is something deeply wrong here; compiled_fn running with
+    # the boxed calling convention, but aot_module_simplified somehow
+    # historically returned a function that was not the boxed calling
+    # convention. This should get fixed...
+    def forward(*runtime_args):
+        full_args = []
+        full_args.extend(params_flat)
+        full_args.extend(runtime_args)
+        return compiled_fn(full_args)
 
     # Just for convenience
     forward.zero_grad = mod.zero_grad
     forward.named_parameters = mod.named_parameters
 
     return forward
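The TODO in the hunk above refers to AOTAutograd's boxed calling convention: compiled_fn takes a single flat list of inputs (roughly, so the callee can drop references to them early), while the forward returned by aot_module_simplified still takes ordinary positional arguments. A minimal illustration of the difference, using made-up function names rather than PyTorch APIs:

from typing import List
import torch

def boxed_add(args: List[torch.Tensor]):
    # Boxed convention: one mutable list of inputs; the callee may clear it
    # so references to the inputs can be released early.
    a, b = args
    args.clear()
    return [a + b]

def unboxed_add(a: torch.Tensor, b: torch.Tensor):
    # Unboxed convention: plain positional arguments, like forward above.
    return a + b

inputs = [torch.randn(2), torch.randn(2)]
out_boxed = boxed_add(list(inputs))   # analogous to compiled_fn(full_args)
out_unboxed = unboxed_add(*inputs)    # analogous to forward(*runtime_args)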
@@ -380,6 +380,7 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs):
 
     return aot_module_simplified(
         gm,
+        example_inputs,
         fw_compiler=graph_saver_forward,
         bw_compiler=graph_saver_backward,
         partition_fn=graph_saver_joint,
@@ -387,6 +388,7 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs):
     )
 
 
 # WARNING: This isn't tested anywhere!!
 def graph_dumper_aot(current_name, folder_name, dump_example_input=False):
     """
     Dump the forward, backward, and joint computation graph.
@@ -995,11 +995,11 @@ def forward(self, primals_1, primals_2):
         x = torch.randn(3, 3, requires_grad=True)
         y = torch.randn(3, 3, requires_grad=True)
 
-        fxy = aot_module_simplified(F(), nop)
+        fxy = aot_module_simplified(F(), (x, y), nop)
         fxy(x, y)
         fxy(x, x)  # is ok!
 
-        fxx = aot_module_simplified(F(), nop)
+        fxx = aot_module_simplified(F(), (x, x), nop)
         fxx(x, x)
         self.assertExpectedRaisesInline(
             AssertionError, lambda: fxx(x, y),
@@ -1024,11 +1024,11 @@ def forward(self, primals_1, primals_2):
         self.assertEqual(r1, r2)
         self.assertEqual(g1, g2)
 
-        fxy = aot_module_simplified(F(), nop)
+        fxy = aot_module_simplified(F(), (x, y), nop)
         compare(F(), fxy, (x, y))
         compare(F(), fxy, (x, z))
 
-        fxz = aot_module_simplified(F(), nop)
+        fxz = aot_module_simplified(F(), (x, z), nop)
         compare(F(), fxz, (x, z))
         self.assertExpectedRaisesInline(
             AssertionError, lambda: fxz(x, y),
@@ -1537,9 +1537,9 @@ class TestAOTModuleSimplified(AOTTestCase):
         ref = mod(*inputs)
         ref[0].sum().backward()
 
-        aot_mod = aot_module_simplified(mod, nop)
-        aot_mod.zero_grad()
-        res = aot_mod(*cloned_inputs)
+        compiled_f = aot_module_simplified(mod, cloned_inputs, nop)
+        mod.zero_grad()
+        res = compiled_f(*cloned_inputs)
         res[0].sum().backward()
 
         assert torch.allclose(ref[0], res[0])
@@ -1577,12 +1577,12 @@ class TestAOTModuleSimplified(AOTTestCase):
             assert 'test_aotdispatch.py' in node.stack_trace
             return gm.forward  # return a python callable
 
-        aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=assert_compiler)
-
         x = torch.randn(128, 20, requires_grad=True)
         y = torch.randn(128, 30, requires_grad=True)
         inputs = [x, y]
-        res = aot_mod(*inputs)
+
+        compiled_f = aot_module_simplified(mod, inputs, fw_compiler=assert_compiler, bw_compiler=assert_compiler)
+        res = compiled_f(*inputs)
         res[0].sum().backward()
 
     def test_aot_module_simplified_fake_tensor_gm_raises(self):
@@ -1615,14 +1615,14 @@ class TestAOTModuleSimplified(AOTTestCase):
         mod_fake = torch.fx.GraphModule(tracer.root, graph)
 
         self.assertExpectedRaisesInline(
-            AssertionError, lambda: aot_module_simplified(mod_fake, nop),
+            AssertionError, lambda: aot_module_simplified(mod_fake, (real_x,), nop),
             """Unexpected fake buffer y"""
         )
         # Counterfactual to ensure that the raise is only due to real vs fake
         # Run the same exact thing except with a real buffer.
         graph = tracer.trace(MockModule(real_y))
         mod_real = torch.fx.GraphModule(tracer.root, graph)
-        aot_module_simplified(MockModule(real_y), nop)
+        aot_module_simplified(MockModule(real_y), (real_x,), nop)
 
     def test_aot_module_deepcopy_fake_tensor_gm_raises(self):
         class MockModule(torch.nn.Module):
@@ -1644,7 +1644,7 @@ class TestAOTModuleSimplified(AOTTestCase):
         mod_fake = torch._dynamo.utils.deepcopy_to_fake_tensor(MockModule(real_y), fake_mode)
 
         self.assertExpectedRaisesInline(
-            AssertionError, lambda: aot_module_simplified(mod_fake, nop),
+            AssertionError, lambda: aot_module_simplified(mod_fake, (real_x,), nop),
             """Unexpected fake param linear.weight"""
         )
 
@@ -77,7 +77,8 @@ def aot_autograd(**kwargs):
         kwargs["bw_compiler"] = _wrapped_bw_compiler
 
         try:
-            cg = aot_module_simplified(gm, **kwargs)
+            # NB: NOT cloned!
+            cg = aot_module_simplified(gm, example_inputs, **kwargs)
             counters["aot_autograd"]["ok"] += 1
             return eval_frame.disable(cg)
         except Exception:
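For context on the hunk above, where the backend wrapper now forwards example_inputs to aot_module_simplified: a TorchDynamo backend is called with (gm, example_inputs), so it can hand them straight through. Below is a minimal sketch of a custom backend doing the same thing; my_aot_backend and print_graph_compiler are illustrative names, not code from this PR.

import torch
from functorch.compile import aot_module_simplified

def print_graph_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Invoked by AOTAutograd with the partitioned forward/backward graphs.
    print(gm.graph)
    return gm.forward  # return a python callable

def my_aot_backend(gm: torch.fx.GraphModule, example_inputs):
    # example_inputs is passed positionally; compilation now happens here
    # rather than lazily on the first call of the returned function.
    return aot_module_simplified(gm, example_inputs, fw_compiler=print_graph_compiler)

@torch._dynamo.optimize(my_aot_backend)
def fn(x, y):
    return torch.sin(x) + torch.cos(y)

fn(torch.randn(4, requires_grad=True), torch.randn(4, requires_grad=True))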