make nanogpt work with both compiled autograd and _LazyGraphModule (#118981)
@xmfan and @fegin reported that _LazyGraphModule (https://github.com/pytorch/pytorch/pull/117911) makes nanogpt training fail with compiled autograd.

We have a repro:

```
python benchmarks/dynamo/torchbench.py --training --backend=inductor --disable-cudagraphs --accuracy --only nanogpt --repeat 1 --compiled-autograd
```

However, it is still unclear how to trigger the issue with a toy model.
The error message for the failure is in https://gist.github.com/shunting314/6402a6388b3539956090b6bc098952fb . In compile_fx we call `detect_fake_mode`, which looks for an active FakeTensorMode in both the TracingContext and the example inputs. The error is triggered because these two sources yield different FakeTensorModes.
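For intuition, here is a minimal sketch of that consistency check (the function and parameter names are hypothetical; the real `detect_fake_mode` in torch handles more cases). It collects candidate FakeTensorModes from the tracing context and from fake example inputs and complains when they disagree, which is the mismatch described above:

```
# Hypothetical sketch, not the real detect_fake_mode implementation.
from torch._subclasses.fake_tensor import FakeTensor

def sketch_detect_fake_mode(tracing_context_mode, example_inputs):
    candidates = []
    if tracing_context_mode is not None:
        candidates.append(("tracing context", tracing_context_mode))
    for i, inp in enumerate(example_inputs):
        if isinstance(inp, FakeTensor):
            candidates.append((f"example input {i}", inp.fake_mode))
    if not candidates:
        return None
    first_src, first_mode = candidates[0]
    for src, mode in candidates[1:]:
        # Two different FakeTensorModes reachable from the same compilation
        # is exactly the failure mode reported here.
        assert mode is first_mode, f"fake modes differ: {first_src} vs {src}"
    return first_mode
```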
Although I don't know what really causes the FakeTensorMode discrepancy above, the fix here is to force _LazyGraphModule recompilation when compiled autograd is enabled. Most of the time this does not hurt compilation time, because the graph module will be called in the backward pass anyway when compiled autograd is enabled: 855d5f144e/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py (L705)
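For context, here is a rough sketch of the lazy-recompile idea (a toy class, not the real _LazyGraphModule): code generation for a GraphModule is deferred until the module is actually called, and forcing recompilation simply does that work eagerly, which is what the fix does when compiled autograd is on.

```
import torch
from torch.fx import symbolic_trace

class ToyLazyGraphModule:
    """Toy stand-in for the lazy-recompile behavior (not the real class)."""

    def __init__(self, gm):
        self._gm = gm
        self._recompiled = False

    def force_recompile(self):
        # Eagerly regenerate the Python code backing the FX graph.
        self._gm.recompile()
        self._recompiled = True

    def __call__(self, *args):
        if not self._recompiled:
            # Lazy path: only generate code when the module is first called.
            self.force_recompile()
        return self._gm(*args)

def f(x):
    return x.sin() + 1

lazy = ToyLazyGraphModule(symbolic_trace(f))
lazy.force_recompile()  # analogous to forcing recompilation up front
print(lazy(torch.randn(3)))
```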
Let me know if we can have a better fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118981
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: d670dfb7ae
Commit: a72190fd51
```
@@ -42,7 +42,9 @@ def hook3(gI, gO):
 
 
 class TestCompiledAutograd(TestCase):
-    def check_output_and_recompiles(self, fn, count=1, compiler_fn=compiler_fn):
+    def check_output_and_recompiles(
+        self, fn, count=1, compiler_fn=compiler_fn, compile_fn=False
+    ):
         with torch.autograd.set_multithreading_enabled(False):
             torch._dynamo.reset()
             counters["compiled_autograd"].clear()
@@ -50,7 +52,8 @@ class TestCompiledAutograd(TestCase):
             expected = list(fn())
             torch.manual_seed(123)
             with compiled_autograd.enable(compiler_fn):
-                actual = list(fn())
+                opt_fn = torch.compile(fn) if compile_fn else fn
+                actual = list(opt_fn())
             self.assertEqual(expected, actual)
             self.assertEqual(counters["compiled_autograd"]["captures"], count)
             self.assertEqual(counters["compiled_autograd"]["compiles"], count)
@@ -633,6 +636,25 @@ class TestCompiledAutograd(TestCase):
 
         self.check_output_and_recompiles(fn, 3)
 
+    def test_mismatch_fake_tensor_mode(self):
+        """
+        Repro the failure of training nanogpt with both compiled-autograd
+        and _LazyGraphModule. Check https://github.com/pytorch/pytorch/pull/118981
+        for more context.
+        """
+        x = torch.rand(2, 16)
+        y = nn.Parameter(torch.rand(2, 16))
+
+        def f():
+            out = x + y
+
+            # make sure the backward call does not trigger any error when
+            # compiling the backward graph
+            out.sum().backward()
+            return out, y.grad
+
+        self.check_output_and_recompiles(f, compile_fn=True)
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent
```
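The diff view above does not show file names, so the path below is an assumption; with a pytest install, the new test could presumably be run on its own along these lines:

```
python -m pytest test/inductor/test_compiled_autograd.py -k test_mismatch_fake_tensor_mode
```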
```
@@ -222,18 +222,35 @@ class AutogradCompilerInstance:
 
 compiled_autograd_enabled = False
 
+# We may have code like:
+# with enable(compiler_fn):
+#     ...
+#     with disable():
+#         ...
+#     ...
+# The disable() call just wants to disable compiled autograd temporarily.
+# But overall the feature is enabled.
+#
+# The code covered by the disable context manager has no way to know if
+# compiled autograd is overall enabled. Use another variable
+# compiled_autograd_enabled_count to indicate how many times compiled
+# autograd has been enabled in the call stack for this purpose.
+compiled_autograd_enabled_count = 0
+
 
 @contextlib.contextmanager
 def enable(compiler_fn):
     prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
         functools.partial(AutogradCompilerInstance, compiler_fn)
     )
-    global compiled_autograd_enabled
+    global compiled_autograd_enabled, compiled_autograd_enabled_count
     compiled_autograd_enabled = True
+    compiled_autograd_enabled_count += 1
     try:
         with torch.autograd.set_multithreading_enabled(False):
             yield
     finally:
+        compiled_autograd_enabled_count -= 1
         if not prior:
             compiled_autograd_enabled = False
         torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
```
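As a standalone illustration of why a counter is kept alongside the boolean flag (toy names, not the real module-level state): inside a temporary disable() region the flag reads False, while the counter still records that compiled autograd is enabled somewhere up the call stack.

```
import contextlib

# Toy stand-ins for compiled_autograd's flag and counter.
enabled_flag = False
enabled_count = 0

@contextlib.contextmanager
def enable_sketch():
    global enabled_flag, enabled_count
    prior = enabled_flag
    enabled_flag = True
    enabled_count += 1
    try:
        yield
    finally:
        enabled_count -= 1
        enabled_flag = prior

@contextlib.contextmanager
def disable_sketch():
    global enabled_flag
    prior = enabled_flag
    enabled_flag = False
    try:
        yield
    finally:
        enabled_flag = prior

with enable_sketch():
    with disable_sketch():
        # The flag says "disabled" here, but the counter still shows that
        # compiled autograd is enabled overall.
        assert enabled_flag is False
        assert enabled_count == 1
```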
```
@@ -182,6 +182,19 @@ def aot_dispatch_autograd(
     fw_module, bw_module = aot_config.partition_fn(
         fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
     )
+
+    # Compiled autograd will run the bw_module in the backward pass,
+    # so recompilation needs to happen anyway if the backward pass is
+    # ever called.
+    #
+    # The reason we do the GraphModule recompilation here is because
+    # the lazy recompilation will cause issues in the backward pass
+    # with compiled autograd.
+    if torch._dynamo.compiled_autograd.compiled_autograd_enabled_count:
+        from torch.fx._lazy_graph_module import _LazyGraphModule
+
+        _LazyGraphModule.force_recompile(bw_module)
+
     fw_outs = next(n for n in fw_module.graph.nodes if n.op == "output").args[0]
     # we only need to bookkeep the symints that are saved for bw, not any symints
     # the user forward might have returned in its own output
```