[dynamo] Make OptimizedModule more robust in attribute reads and writes (#153637)

Fixes #138157.

Differential Revision: [D74834872](https://our.internmc.facebook.com/intern/diff/D74834872)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153637
Approved by: https://github.com/williamwen42
Author: Ryan Guo
Date: 2025-05-15 11:02:47 -07:00
Committed by: PyTorch MergeBot
Parent: 1748fa529a
Commit: e4a636df80

2 changed files with 50 additions and 6 deletions

@@ -5853,6 +5853,40 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
         mod.eval()
         self.assertFalse(opt_mod.training)

+    def test_optimized_module_patched_init(self):
+        # A regression test for #138157; the pattern came from deepspeed.
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.mul(5.0)
+
+        def patch_init(init):
+            @functools.wraps(init)
+            def wrapper(module, *args, **kwargs):
+                if not hasattr(module, "_ds_child_entered"):
+                    # child's __init__ was called, since parents all see the same object they can now skip post_init
+                    module._ds_child_entered = True
+                init(module, *args, **kwargs)
+
+            return wrapper
+
+        def patch_init_for_class(cls):
+            if "__init__" in cls.__dict__:
+                cls._old_init = cls.__init__
+                cls.__init__ = patch_init(cls.__init__)
+
+        patch_init_for_class(MyModule)
+        mod = MyModule()
+        opt_mod = torch.compile(mod)
+
+        x = torch.rand(10)
+        ref = mod(x)
+        res = opt_mod(x)
+        self.assertEqual(ref, res)
+
     def test_os_fspath(self):
         @torch.compile(backend="eager", fullgraph=True)
         def fn(x):

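The hazard the new test exercises: OptimizedModule forwards unknown attribute reads to the wrapped module via `_orig_mod`, so touching attributes on a wrapper that has not finished constructing (for example from a monkey-patched `__init__`) can fail before `_orig_mod` exists. Below is a minimal sketch of that failure mode, using a hypothetical `NaiveWrapper` that only attaches the wrapped module after `super().__init__()`; it is illustrative only, not the code in this PR.

import torch

class NaiveWrapper(torch.nn.Module):
    # Hypothetical stand-in for a forwarding wrapper such as OptimizedModule.
    def __init__(self, mod):
        super().__init__()    # a patched __init__ could read/write `self` attrs here
        self._orig_mod = mod  # ...but the wrapped module is only attached now

    def __getattr__(self, name):
        # Forward unknown attribute reads to the wrapped module.
        if name == "_orig_mod":
            return self._modules["_orig_mod"]
        return getattr(self._orig_mod, name)

# Simulate the "mid-construction" state: neither `_modules` nor `_orig_mod` exists
# yet, so even a plain hasattr() bounces between the two __getattr__ branches and
# recurses. Attaching `_orig_mod` up front, as this commit does, avoids that.
w = NaiveWrapper.__new__(NaiveWrapper)
try:
    hasattr(w, "_ds_child_entered")
except RecursionError:
    print("attribute read on a half-initialized wrapper recursed")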

@@ -313,12 +313,23 @@ class OptimizedModule(torch.nn.Module):
         "_forward",
         "__dict__",
         "named_children_walk",
+        "_super_module_initialized",
     }

     def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None:
+        # NOTE: this must go first, because attribute reads/writes on `self`
+        # use `_orig_mod`, and sometimes users override `Module.__init__` to
+        # do attribute reads/writes on `self`.
+        #
+        # We also can't use a regular setattr here, because `super().__setattr__`
+        # complains about assigning a module value before `super().__init__()`.
+        object.__setattr__(self, "_orig_mod", mod)
+
+        self._super_module_initialized = False
         super().__init__()
+        self._super_module_initialized = True
         # Installs the params/buffer
-        self._orig_mod = mod
+        self._orig_mod = mod  # `super().__setattr__` will register this module
         self.dynamo_ctx = dynamo_ctx
         self._initialize()
         self.training = self._orig_mod.training
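The `object.__setattr__` call is needed because `nn.Module.__setattr__` refuses to register a submodule until `nn.Module.__init__()` has created the internal `_modules` dict. A minimal sketch of that ordering, using a hypothetical `Wrapper` class that only reproduces the construction order, nothing else from OptimizedModule:

import torch

class Wrapper(torch.nn.Module):
    def __init__(self, mod):
        # A plain `self._orig_mod = mod` at this point raises roughly
        # "AttributeError: cannot assign module before Module.__init__() call",
        # so stash it directly in __dict__, bypassing nn.Module.__setattr__.
        object.__setattr__(self, "_orig_mod", mod)
        super().__init__()
        # The normal assignment now also registers `mod` as a submodule
        # (nn.Module.__setattr__ moves it from __dict__ into _modules).
        self._orig_mod = mod

w = Wrapper(torch.nn.Linear(2, 2))
print(dict(w.named_children()))  # {'_orig_mod': Linear(...)}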
@@ -379,12 +390,11 @@ class OptimizedModule(torch.nn.Module):
     @training.setter
     def training(self, value):
-        try:
-            super().__getattr__("_orig_mod")
+        # Ignore the `training` mutation in `super().__init__()`, since that's
+        # setting the default on `nn.Module`, but we are mirroring the
+        # `training` attr in `self._orig_mod`.
+        if self._super_module_initialized:
             self._orig_mod.training = value
-        except AttributeError:
-            # still initializing
-            pass

     def __getattr__(self, name):
         if name == "_orig_mod":
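With both changes in place, `training` stays mirrored between the wrapper and the wrapped module once construction completes, while the default set inside `nn.Module.__init__()` no longer touches `_orig_mod` prematurely. A small usage sketch of the resulting behavior:

import torch

mod = torch.nn.Linear(4, 4)
opt_mod = torch.compile(mod)   # returns an OptimizedModule wrapping `mod`

opt_mod.eval()                 # the `training` setter mirrors onto `_orig_mod`
assert mod.training is False

mod.train()
assert opt_mod.training is True  # reads also resolve through the wrapped module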