make nanogpt work with both compiled autograd and _LazyGraphModule (#118981)

@xmfan and @fegin reported that _LazyGraphModule ( https://github.com/pytorch/pytorch/pull/117911 ) makes nanogpt training fail with compiled autograd.

We have a repro:

```
python benchmarks/dynamo/torchbench.py --training --backend=inductor --disable-cudagraphs --accuracy --only nanogpt --repeat 1 --compiled-autograd
```

but it is still unclear how to trigger the issue with a toy model.

The error message for the failure is in https://gist.github.com/shunting314/6402a6388b3539956090b6bc098952fb. In compile_fx we call `detect_fake_mode`, which looks for an active FakeTensorMode in both the TracingContext and the example inputs. The error is triggered because these two sources yield different FakeTensorModes.
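For context, here is a minimal sketch of the kind of consistency check that trips. It is simplified pseudologic under hypothetical names (`detect_fake_mode_sketch`, `tracing_context_fake_mode`), not the actual `detect_fake_mode` implementation:

```python
from torch._subclasses.fake_tensor import FakeTensor


def detect_fake_mode_sketch(tracing_context_fake_mode, example_inputs):
    # Gather candidate FakeTensorModes from the two sources mentioned above:
    # the TracingContext (if it carries one) and any FakeTensor example inputs.
    candidates = []
    if tracing_context_fake_mode is not None:
        candidates.append(("tracing context", tracing_context_fake_mode))
    for i, inp in enumerate(example_inputs):
        if isinstance(inp, FakeTensor):
            candidates.append((f"example input {i}", inp.fake_mode))

    if not candidates:
        return None

    # The failure in this issue: the sources disagree on which mode is active.
    first_desc, first_mode = candidates[0]
    for desc, mode in candidates[1:]:
        assert mode is first_mode, (
            f"fake mode from {first_desc} does not match fake mode from {desc}"
        )
    return first_mode
```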

Although I don't know what actually causes the FakeTensorMode discrepancy above, the fix here is to force _LazyGraphModule recompilation when compiled autograd is enabled. Most of the time this does not hurt compilation time, because when compiled autograd is enabled we call the graph module in the backward pass anyway: 855d5f144e/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py (L705)
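In code, the fix boils down to the guarded eager recompile below. This is just a condensed restatement of the diff in jit_compile_runtime_wrappers.py further down; the helper name `force_bw_recompile_if_needed` is mine, and `bw_module` is the backward GraphModule produced by the partitioner:

```python
import torch._dynamo.compiled_autograd as compiled_autograd
from torch.fx._lazy_graph_module import _LazyGraphModule


def force_bw_recompile_if_needed(bw_module):
    # Only pay the recompilation cost up front when compiled autograd is on;
    # in that case the backward graph module gets called (and hence needs its
    # real recompile) anyway, so this is not extra work overall.
    if compiled_autograd.compiled_autograd_enabled_count:
        _LazyGraphModule.force_recompile(bw_module)
```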

Let me know if we can have a better fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118981
Approved by: https://github.com/jansel
Shunting Zhang
2024-02-02 16:07:56 -08:00
committed by PyTorch MergeBot
parent d670dfb7ae
commit a72190fd51
3 changed files with 55 additions and 3 deletions


@@ -42,7 +42,9 @@ def hook3(gI, gO):
 class TestCompiledAutograd(TestCase):
-    def check_output_and_recompiles(self, fn, count=1, compiler_fn=compiler_fn):
+    def check_output_and_recompiles(
+        self, fn, count=1, compiler_fn=compiler_fn, compile_fn=False
+    ):
         with torch.autograd.set_multithreading_enabled(False):
             torch._dynamo.reset()
             counters["compiled_autograd"].clear()
@@ -50,7 +52,8 @@ class TestCompiledAutograd(TestCase):
             expected = list(fn())
             torch.manual_seed(123)
             with compiled_autograd.enable(compiler_fn):
-                actual = list(fn())
+                opt_fn = torch.compile(fn) if compile_fn else fn
+                actual = list(opt_fn())
             self.assertEqual(expected, actual)
             self.assertEqual(counters["compiled_autograd"]["captures"], count)
             self.assertEqual(counters["compiled_autograd"]["compiles"], count)
@@ -633,6 +636,25 @@ class TestCompiledAutograd(TestCase):
         self.check_output_and_recompiles(fn, 3)
 
+    def test_mismatch_fake_tensor_mode(self):
+        """
+        Repro the failure of training nanogpt with both compiled-autograd
+        and _LazyGraphModule. Check https://github.com/pytorch/pytorch/pull/118981
+        for more context.
+        """
+        x = torch.rand(2, 16)
+        y = nn.Parameter(torch.rand(2, 16))
+
+        def f():
+            out = x + y
+
+            # make sure the backward call does not trigger any error when
+            # compiling the backward graph
+            out.sum().backward()
+            return out, y.grad
+
+        self.check_output_and_recompiles(f, compile_fn=True)
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent


@@ -222,18 +222,35 @@ class AutogradCompilerInstance:
 compiled_autograd_enabled = False
 
+# We may have code like:
+#   with enable(compiler_fn):
+#       ...
+#       with disable():
+#           ...
+#       ...
+# The disable() call just wants to disable compiled autograd temporarily.
+# But overall the feature is enabled.
+#
+# The code covered by the disable context manager has no way to know if
+# compiled autograd is overall enabled. Use another variable
+# compiled_autograd_enabled_count to indicate how many times compiled
+# autograd has been enabled in the call stack for this purpose.
+compiled_autograd_enabled_count = 0
+
 
 @contextlib.contextmanager
 def enable(compiler_fn):
     prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
         functools.partial(AutogradCompilerInstance, compiler_fn)
     )
-    global compiled_autograd_enabled
+    global compiled_autograd_enabled, compiled_autograd_enabled_count
     compiled_autograd_enabled = True
+    compiled_autograd_enabled_count += 1
     try:
         with torch.autograd.set_multithreading_enabled(False):
             yield
     finally:
+        compiled_autograd_enabled_count -= 1
         if not prior:
             compiled_autograd_enabled = False
         torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
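As a usage illustration of the new counter, here is a hedged sketch of the nesting pattern the comment above describes; `my_compiler_fn` is a placeholder, and `disable()` is the pre-existing context manager the comment refers to:

```python
import torch._dynamo.compiled_autograd as compiled_autograd


def my_compiler_fn(gm):
    # Placeholder compiler: return the captured autograd graph unchanged.
    return gm


with compiled_autograd.enable(my_compiler_fn):
    # compiled_autograd_enabled is True; compiled_autograd_enabled_count is 1.
    with compiled_autograd.disable():
        # Code running here cannot rely on compiled_autograd_enabled to tell
        # whether the feature is on overall, but the count still reflects the
        # outer enable(), which is what aot_dispatch_autograd checks below.
        assert compiled_autograd.compiled_autograd_enabled_count == 1
```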


@@ -182,6 +182,19 @@ def aot_dispatch_autograd(
     fw_module, bw_module = aot_config.partition_fn(
         fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
     )
+
+    # Compiled autograd will run the bw_module in the backward pass, so its
+    # recompilation needs to happen anyway if the backward pass is ever
+    # called.
+    #
+    # The reason we do the GraphModule recompilation here is that lazy
+    # recompilation causes issues in the backward pass with compiled
+    # autograd.
+    if torch._dynamo.compiled_autograd.compiled_autograd_enabled_count:
+        from torch.fx._lazy_graph_module import _LazyGraphModule
+
+        _LazyGraphModule.force_recompile(bw_module)
+
     fw_outs = next(n for n in fw_module.graph.nodes if n.op == "output").args[0]
     # we only need to bookkeep the symints that are saved for bw, not any symints
     # the user forward might have returned in its own output