mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Reland] Fix inlining module-scoped store global (#132439)
Reland https://github.com/pytorch/pytorch/pull/132224 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132439 Approved by: https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
a4ea776881
commit
d2e9a8bf6d
11
test/dynamo/mock_store_global_crossfile_inline.py
Normal file
11
test/dynamo/mock_store_global_crossfile_inline.py
Normal file
@ -0,0 +1,11 @@
|
||||
global_flag = False
|
||||
|
||||
|
||||
def set_flag_true():
|
||||
global global_flag
|
||||
global_flag = True
|
||||
|
||||
|
||||
def set_flag_false():
|
||||
global global_flag
|
||||
global_flag = False
|
@ -224,6 +224,28 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(s0, "v0v1")
|
||||
reset_name()
|
||||
|
||||
def test_store_global_crossfile_inline(self):
|
||||
try:
|
||||
from . import mock_store_global_crossfile_inline
|
||||
except ImportError:
|
||||
import mock_store_global_crossfile_inline
|
||||
|
||||
@torch.compile()
|
||||
def fn(x):
|
||||
mock_store_global_crossfile_inline.set_flag_true()
|
||||
mock_store_global_crossfile_inline.set_flag_false()
|
||||
return x + 1
|
||||
|
||||
@torch.compile()
|
||||
def fn_set_true(x):
|
||||
mock_store_global_crossfile_inline.set_flag_true()
|
||||
return x + 1
|
||||
|
||||
fn_set_true(torch.ones(2, 2))
|
||||
self.assertTrue(mock_store_global_crossfile_inline.global_flag)
|
||||
fn(torch.ones(2, 2))
|
||||
self.assertFalse(mock_store_global_crossfile_inline.global_flag)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -3258,9 +3258,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
unimplemented("Storing handles in globals - NYI")
|
||||
name = inst.argval
|
||||
fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name)
|
||||
fglobals_vt = self.output.side_effects.track_object_existing(
|
||||
fglobals_value, fglobals_vt
|
||||
)
|
||||
self.output.side_effects.store_attr(fglobals_vt, name, value)
|
||||
|
||||
|
||||
|
@ -932,10 +932,12 @@ class VariableBuilder:
|
||||
# type(torch.backends.cudnn) -> <class 'torch.backends.cudnn.CudnnModule'>
|
||||
elif isinstance(value, (types.ModuleType, replay_record.DummyModule)):
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return PythonModuleVariable(
|
||||
result = PythonModuleVariable(
|
||||
value,
|
||||
source=self.source,
|
||||
)
|
||||
self.tx.output.side_effects.track_object_existing(value, result)
|
||||
return result
|
||||
elif isinstance(value, types.MethodType) and isinstance(
|
||||
value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
|
||||
):
|
||||
|
Reference in New Issue
Block a user