mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Make OptimizedModule more robust in attribute reads and writes (#153637)
Fixes #138157.

Differential Revision: [D74834872](https://our.internmc.facebook.com/intern/diff/D74834872)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153637
Approved by: https://github.com/williamwen42
committed by PyTorch MergeBot
parent 1748fa529a
commit e4a636df80
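For context, a minimal sketch (not part of this commit) of the attribute forwarding the title refers to, assuming `torch.compile` on a module returns an `OptimizedModule` that delegates attribute reads and writes to the wrapped module; the class `M` and the attribute `scale` are made up for illustration:

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.scale = 2.0

        def forward(self, x):
            return x * self.scale

    mod = M()
    opt_mod = torch.compile(mod)  # returns an OptimizedModule wrapping `mod`

    # Reads on the wrapper fall through to the wrapped module ...
    assert opt_mod.scale == 2.0
    # ... and writes are reflected on the original module as well.
    opt_mod.scale = 3.0
    assert mod.scale == 3.0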
@@ -5853,6 +5853,40 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
        mod.eval()
        self.assertFalse(opt_mod.training)

    def test_optimized_module_patched_init(self):
        # A regression test for #138157; the pattern came from deepspeed.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.mul(5.0)

        def patch_init(init):
            @functools.wraps(init)
            def wrapper(module, *args, **kwargs):
                if not hasattr(module, "_ds_child_entered"):
                    # child's __init__ was called, since parents all see the same object they can now skip post_init
                    module._ds_child_entered = True
                init(module, *args, **kwargs)

            return wrapper

        def patch_init_for_class(cls):
            if "__init__" in cls.__dict__:
                cls._old_init = cls.__init__
                cls.__init__ = patch_init(cls.__init__)

        patch_init_for_class(MyModule)
        mod = MyModule()
        opt_mod = torch.compile(mod)

        x = torch.rand(10)
        ref = mod(x)
        res = opt_mod(x)

        self.assertEqual(ref, res)

    def test_os_fspath(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
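As a side note on the test helper above: `patch_init_for_class` stashes the original constructor on the class as `_old_init`, so the patch can be undone. The helper below, `unpatch_init_for_class`, is hypothetical and not part of this commit; it is only a sketch of how the saved constructor could be restored.

    def unpatch_init_for_class(cls):
        # Hypothetical inverse of patch_init_for_class: restore the saved
        # constructor and drop the bookkeeping attribute.
        if "_old_init" in cls.__dict__:
            cls.__init__ = cls._old_init
            del cls._old_init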
@@ -313,12 +313,23 @@ class OptimizedModule(torch.nn.Module):
        "_forward",
        "__dict__",
        "named_children_walk",
        "_super_module_initialized",
    }

    def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None:
        # NOTE: this must go first, because attribute reads/writes of `self`
        # use `_orig_mod`, and sometimes users override `Module.__init__` to
        # do attribute reads/writes on `self`.
        #
        # We also can't use regular setattr here, because `super().__setattr__`
        # complains about module values before `super().__init__()`.
        object.__setattr__(self, "_orig_mod", mod)
        self._super_module_initialized = False
        super().__init__()
        self._super_module_initialized = True

        # Installs the params/buffer
        self._orig_mod = mod  # `super().__setattr__` will register this module
        self.dynamo_ctx = dynamo_ctx
        self._initialize()
        self.training = self._orig_mod.training
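The `object.__setattr__` trick in the NOTE above can be seen in isolation: `nn.Module.__setattr__` refuses to store a `Module`-valued attribute before `super().__init__()` has created the `_modules` registry, while `object.__setattr__` writes straight into `__dict__`. The `Wrapper` class below is a standalone illustration, not code from this patch.

    import torch

    class Wrapper(torch.nn.Module):
        def __init__(self, mod):
            try:
                # nn.Module.__setattr__ runs here even though super().__init__()
                # has not been called yet, and raises
                # "cannot assign module before Module.__init__() call".
                self.inner = mod
            except AttributeError as e:
                print(f"regular setattr failed: {e}")
            # Bypassing nn.Module.__setattr__ works before initialization;
            # the attribute lands in __dict__ and is not registered as a submodule.
            object.__setattr__(self, "inner", mod)
            super().__init__()
            self.inner = mod  # now super().__setattr__ registers it as a submodule

    Wrapper(torch.nn.Linear(2, 2))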
@@ -379,12 +390,11 @@ class OptimizedModule(torch.nn.Module):

    @training.setter
    def training(self, value):
        try:
            super().__getattr__("_orig_mod")
            # Ignore the `training` mutation in `super().__init__()`, since that's
            # setting the default on `nn.Module`, but we are mirroring the
            # `training` attr in `self._orig_mod`.
            if self._super_module_initialized:
                self._orig_mod.training = value
        except AttributeError:
            # still initializing
            pass

    def __getattr__(self, name):
        if name == "_orig_mod":
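Together with the `training` property, the setter above keeps the flag mirrored between the wrapper and the wrapped module while ignoring the default assignment made inside `super().__init__()`. A minimal usage sketch, based on the setter and the assertion in the test context above:

    import torch

    mod = torch.nn.Linear(4, 4)
    opt_mod = torch.compile(mod)

    opt_mod.train()  # propagates to the wrapped module
    assert mod.training is True

    mod.eval()       # visible through the wrapper's `training` property
    assert opt_mod.training is False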