[Reland] Fix inlining module-scoped store global (#132439)

Reland https://github.com/pytorch/pytorch/pull/132224

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132439
Approved by: https://github.com/anijain2305
This commit is contained in:
Michael Lazos
2024-08-02 09:13:52 +00:00
committed by PyTorch MergeBot
parent a4ea776881
commit d2e9a8bf6d
4 changed files with 36 additions and 4 deletions

View File

@ -0,0 +1,11 @@
global_flag = False
def set_flag_true():
global global_flag
global_flag = True
def set_flag_false():
global global_flag
global_flag = False

View File

@ -224,6 +224,28 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
self.assertEqual(s0, "v0v1")
reset_name()
def test_store_global_crossfile_inline(self):
try:
from . import mock_store_global_crossfile_inline
except ImportError:
import mock_store_global_crossfile_inline
@torch.compile()
def fn(x):
mock_store_global_crossfile_inline.set_flag_true()
mock_store_global_crossfile_inline.set_flag_false()
return x + 1
@torch.compile()
def fn_set_true(x):
mock_store_global_crossfile_inline.set_flag_true()
return x + 1
fn_set_true(torch.ones(2, 2))
self.assertTrue(mock_store_global_crossfile_inline.global_flag)
fn(torch.ones(2, 2))
self.assertFalse(mock_store_global_crossfile_inline.global_flag)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -3258,9 +3258,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
unimplemented("Storing handles in globals - NYI")
name = inst.argval
fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name)
fglobals_vt = self.output.side_effects.track_object_existing(
fglobals_value, fglobals_vt
)
self.output.side_effects.store_attr(fglobals_vt, name, value)

View File

@ -932,10 +932,12 @@ class VariableBuilder:
# type(torch.backends.cudnn) -> <class 'torch.backends.cudnn.CudnnModule'>
elif isinstance(value, (types.ModuleType, replay_record.DummyModule)):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return PythonModuleVariable(
result = PythonModuleVariable(
value,
source=self.source,
)
self.tx.output.side_effects.track_object_existing(value, result)
return result
elif isinstance(value, types.MethodType) and isinstance(
value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
):