Currently, every time we construct a GLOBAL_STATE guard we create a fresh guard from the current global state. For precompile, we instead want to be able to construct the GLOBAL_STATE guard from an external source, e.g. a serialized global state. The same mechanism also covers the normal case where the global state guard is simply passed in from Python. Differential Revision: [D77400988](https://our.internmc.facebook.com/intern/diff/D77400988/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157285 Approved by: https://github.com/jansel
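The test file below exercises this save/load round-trip end to end. As a rough, condensed sketch of the flow (adapted from the `_test_serialization` helper in this file; `frame_code`, `tracer`, and `frame_locals` are hypothetical stand-ins for the traced frame's code object, a finished `InstructionTranslator`, and the frame's locals):

```python
# Condensed sketch of the guard save/load round-trip exercised by the tests
# below. `frame_code`, `tracer`, and `frame_locals` are placeholders, not a
# standalone API example.
import pickle

from torch._dynamo.guards import CheckFunctionManager

# "save" mode: build guards from the traced output graph and serialize them.
saved = CheckFunctionManager(
    frame_code,
    tracer.output,
    guards_serialization_mode="save",
)
guards_state = pickle.loads(saved.guards_state)

# "load" mode: rebuild a guard manager purely from the serialized state, so
# guards like GLOBAL_STATE reflect the saved state rather than whatever the
# global state happens to be at load time (see test_grad_mode_loading below).
loaded = CheckFunctionManager(
    frame_code,
    guards_state.output_graph,
    guards_serialization_mode="load",
    shape_code_parts=guards_state.shape_code_parts,
)
assert loaded.guard_manager.check(frame_locals) == saved.guard_manager.check(frame_locals)
```
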
# Owner(s): ["module: dynamo"]

import dataclasses
import importlib
import pickle
import sys
import types
import unittest
from collections.abc import Iterator
from unittest.mock import patch

import torch
import torch._dynamo.testing
import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.bytecode_transformation import transform_code_object
from torch._dynamo.exc import PackageError
from torch._dynamo.guards import CheckFunctionManager, CompileId
from torch._dynamo.symbolic_convert import (
    ExceptionStack,
    InstructionTranslator,
    SpeculationLog,
)
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext, tracing
from torch.overrides import TorchFunctionMode
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils import _pytree as pytree


@dataclasses.dataclass
class _FrameState:
    f_locals: dict
    f_globals: dict
    f_code: types.CodeType
    f_builtins: dict


class GlobalModule(torch.nn.Module):
    def forward(self, x):
        return x + 1


def global_func(x):
    return x + 1


class GlobalTorchFunctionMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)


class SubclassWithMeta(torch.Tensor):
    @staticmethod
    def __new__(cls, a, extra, outer_size=None, outer_stride=None):
        if outer_size is None:
            outer_size = a.size()
        if outer_stride is None:
            outer_stride = a.stride()

        shape = outer_size
        kwargs = {}
        kwargs["strides"] = outer_stride
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, a, extra, outer_size=None, outer_stride=None):
        self.a = a
        self.extra = extra

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(SubclassWithMeta, lambda x: x.a, args)
        kwargs_a = pytree.tree_map_only(SubclassWithMeta, lambda x: x.a, kwargs)
        out_a = func(*args_a, **kwargs_a)
        if isinstance(out_a, torch.Tensor):
            assert isinstance(args[0], SubclassWithMeta)
            return SubclassWithMeta(out_a, extra=args[0].extra)
        return out_a

    def __tensor_flatten__(self):
        # store extra in meta
        return ["a"], {"extra": self.extra}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert isinstance(meta, dict)
        a = inner_tensors["a"]
        # pull out extra from meta
        extra = meta["extra"]
        if type(a) is torch.Tensor:
            assert outer_size is not None
            assert outer_stride is not None
        return SubclassWithMeta(a, extra, outer_size, outer_stride)


class SubclassWithCustomMetadataGuard(torch.Tensor):
    @staticmethod
    def __new__(cls, a, extra, outer_size=None, outer_stride=None):
        if outer_size is None:
            outer_size = a.size()
        if outer_stride is None:
            outer_stride = a.stride()

        shape = outer_size
        kwargs = {}
        kwargs["strides"] = outer_stride
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, a, extra, outer_size=None, outer_stride=None):
        self.a = a
        self.extra = extra

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(
            SubclassWithCustomMetadataGuard, lambda x: x.a, args
        )
        kwargs_a = pytree.tree_map_only(
            SubclassWithCustomMetadataGuard, lambda x: x.a, kwargs
        )
        out_a = func(*args_a, **kwargs_a)
        if isinstance(out_a, torch.Tensor):
            assert isinstance(args[0], SubclassWithCustomMetadataGuard)
            return SubclassWithCustomMetadataGuard(out_a, extra=args[0].extra)
        return out_a

    @classmethod
    def __metadata_guard__(cls, meta1, meta2):
        # Define custom metadata guard logic that only looks at "bar" to determine
        # metadata equivalence. This is purposely more lax than the default
        # guard behavior.
        return meta1["extra"]["bar"] == meta2["extra"]["bar"]

    def __tensor_flatten__(self):
        # store extra in meta
        return ["a"], {"extra": self.extra}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert isinstance(meta, dict)
        a = inner_tensors["a"]
        # pull out extra from meta
        extra = meta["extra"]
        if type(a) is torch.Tensor:
            assert outer_size is not None
            assert outer_stride is not None
        return SubclassWithCustomMetadataGuard(a, extra, outer_size, outer_stride)


class SubclassWithSubclassInnerTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, extra, outer_size=None, outer_stride=None):
        if outer_size is None:
            outer_size = a.size()
        if outer_stride is None:
            outer_stride = a.stride()

        shape = outer_size
        kwargs = {}
        kwargs["strides"] = outer_stride
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, a, extra, outer_size=None, outer_stride=None):
        self.a = a
        self.inner_sub = SubclassWithMeta(a + 1, extra=extra)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(
            SubclassWithSubclassInnerTensor, lambda x: x.a, args
        )
        kwargs_a = pytree.tree_map_only(
            SubclassWithSubclassInnerTensor, lambda x: x.a, kwargs
        )
        out_a = func(*args_a, **kwargs_a)
        if isinstance(out_a, torch.Tensor):
            assert isinstance(args[0], SubclassWithSubclassInnerTensor)
            return SubclassWithSubclassInnerTensor(out_a, extra=args[0].inner_sub.extra)
        return out_a

    def __tensor_flatten__(self):
        return ["a", "inner_sub"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        a = inner_tensors["a"]
        extra = inner_tensors["inner_sub"].extra
        if type(a) is torch.Tensor:
            assert outer_size is not None
            assert outer_stride is not None
        return SubclassWithSubclassInnerTensor(a, extra, outer_size, outer_stride)


# defines a custom __eq__() / __hash__() to be registered as a pytree constant type
class CustomConstantType:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __eq__(self, other):
        # custom eq ignores b
        return self.a == other.a

    def __hash__(self):
        # custom hash ignores b
        return hash(self.a)


pytree.register_constant(CustomConstantType)


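# The tests below all follow the same recipe: _test_serialization traces a call
# to the target function, builds guards for one specific guard type in "save"
# mode, round-trips the pickled guards state, and rebuilds the guard manager in
# "load" mode; _test_check_fn then asserts that the reference and the reloaded
# guard managers agree on a given set of example frame locals.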
class TestGuardSerialization(torch._inductor.test_case.TestCase):
    def test_function_locals(self):
        def foo(x):
            return x + 1

        def fn(x, g):
            return g(x) + 1

        self._test_serialization("TENSOR_MATCH", fn, torch.randn(3), foo)

    # sys.settrace callback: capture the state of the first frame entered under tracing.
    def _tracefunc(self, frame, event, arg):
        if event != "call":
            return

        if self._frame_state is not None:
            return

        self._frame_state = _FrameState(
            f_locals=dict(frame.f_locals),
            f_globals=dict(frame.f_globals),
            f_code=frame.f_code,
            f_builtins=frame.f_builtins,
        )

    def _test_serialization(self, guard_type, fn, *args, **kwargs):
        # kwargs might contain a callable that generates kwargs
        kwarg_gen_fn = kwargs.get("_gen_fn", None)
        if kwarg_gen_fn is not None:
            kwargs = kwarg_gen_fn()

        self._frame_state = None
        sys.settrace(self._tracefunc)
        if isinstance(fn, torch.nn.Module):
            fn = fn.forward
        try:
            fn(*args, **kwargs)
        finally:
            sys.settrace(None)

        assert self._frame_state is not None

        # Set f_locals from regenerated kwargs to handle exhausted input iterators
        # NB: This is super janky and might cause unforeseen problems
        if kwarg_gen_fn is not None:
            kwargs = kwarg_gen_fn()
            for key in self._frame_state.f_locals.keys():
                if key in kwargs and isinstance(kwargs[key], Iterator):
                    self._frame_state.f_locals[key] = kwargs[key]

        def guard_filter_fn(guards):
            ret = [
                g.guard_type == guard_type or guard_type in g.derived_guard_types
                for g in guards
            ]
            self.assertTrue(any(ret))
            return ret

        ref_gm = None
        loaded_gm = None

        def transform(instructions: list, code_options: dict[str, object]):
            """
            The goal here is not to reimplement dynamo, but just to have a
            simplified version to extract the state from symbolic convert.
            It won't work in all cases, but it should work on the simple
            functions in this test file.
            """
            nonlocal ref_gm
            nonlocal loaded_gm

            torch._dynamo.convert_frame.initial_global_state = (
                torch._C._dynamo.guards.GlobalStateGuard()
            )
            tracer = InstructionTranslator(
                instructions,
                self._frame_state.f_code,
                self._frame_state.f_locals,
                self._frame_state.f_globals,
                self._frame_state.f_builtins,
                fn.__closure__ or (),
                torch.overrides._get_current_function_mode_stack(),
                code_options,
                torch._dynamo.lookup_backend("eager"),
                one_graph=False,
                export=False,
                export_constraints=None,
                frame_state=None,
                speculation_log=SpeculationLog(),
                exn_vt_stack=ExceptionStack(),
                distributed_state=None,
                package=None,
            )
            with (
                compile_context(CompileContext(CompileId(0, 0))),
                tracing(tracer.output.tracing_context),
                tracer.set_current_tx(),
                get_metrics_context(),
                dynamo_timed(""),
            ):
                tracer.run()

            check_fn_manager = CheckFunctionManager(
                self._frame_state.f_code,
                tracer.output,
                guard_filter_fn=guard_filter_fn,
                guards_serialization_mode="save",
            )
            ref_gm = check_fn_manager.guard_manager
            guards_state = check_fn_manager.guards_state
            self._cached_guards_state = guards_state
            self._cached_f_code = self._frame_state.f_code
            self.assertIsNotNone(guards_state)
            guards_state = pickle.loads(guards_state)

            check_fn_manager = CheckFunctionManager(
                self._frame_state.f_code,
                guards_state.output_graph,
                guards_serialization_mode="load",
                shape_code_parts=guards_state.shape_code_parts,
            )
            loaded_gm = check_fn_manager.guard_manager

        try:
            transform_code_object(self._frame_state.f_code, transform)
        finally:
            torch._dynamo.convert_frame.initial_global_state = None
            self._frame_state = None

        self.assertIsNotNone(ref_gm)
        self.assertIsNotNone(loaded_gm)
        return ref_gm, loaded_gm

    def _test_check_fn(self, ref, loaded, inputs, expected):
        self.assertIsInstance(inputs, dict)
        self.assertEqual(ref.check(inputs), expected)
        self.assertEqual(ref.check(inputs), loaded.check(inputs))

    def test_tensor_match(self):
        def f(x: torch.Tensor):
            return x + 1

        ref, loaded = self._test_serialization(
            "TENSOR_MATCH", f, torch.ones(2, dtype=torch.float32)
        )
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(2, dtype=torch.float32)}, True
        )
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(3, dtype=torch.float32)}, False
        )
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(2, dtype=torch.float64)}, False
        )
        self._test_check_fn(ref, loaded, {"x": None}, False)

    def test_not_present_in_generic_dict(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x + 1

        m = Module()

        def fn(x):
            return m(x)

        ref, loaded = self._test_serialization(
            "NOT_PRESENT_IN_GENERIC_DICT", fn, torch.ones(2, dtype=torch.float32)
        )
        self._test_check_fn(ref, loaded, {"m": m}, True)

        m.forward = types.MethodType(lambda x: x + 2, m)
        self._test_check_fn(ref, loaded, {"m": m}, False)

    def test_hasattr_serialization(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = 1

            def forward(self, x: torch.Tensor):
                if hasattr(self, "a"):
                    return x + self.a
                else:
                    return x + 2

        m = Module()

        def fn(x):
            return m(x)

        ref, loaded = self._test_serialization("HASATTR", fn, torch.randn(3))
        self._test_check_fn(ref, loaded, {"m": m}, True)
        delattr(m, "a")
        self._test_check_fn(ref, loaded, {"m": m}, False)

    def test_type_match(self):
        class LocalModule(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x + 1

        m = LocalModule()

        def fn(m, x):
            return m(x)

        with self.assertRaisesRegex(
            TypeError, "Please define the class at global scope"
        ):
            self._test_serialization("TYPE_MATCH", fn, m, torch.randn(3))

        m = GlobalModule()
        ref, loaded = self._test_serialization("TYPE_MATCH", fn, m, torch.randn(3))
        self._test_check_fn(ref, loaded, {"m": m}, True)
        self._test_check_fn(ref, loaded, {"m": GlobalModule()}, True)
        self._test_check_fn(ref, loaded, {"m": torch.nn.Module()}, False)

    def test_tensor_subclass_metadata_match(self):
        class LocalSubclass(torch.Tensor):
            @staticmethod
            def __new__(cls, a, outer_size=None, outer_stride=None):
                if outer_size is None:
                    outer_size = a.size()
                if outer_stride is None:
                    outer_stride = a.stride()

                shape = outer_size
                kwargs = {}
                kwargs["strides"] = outer_stride
                kwargs["storage_offset"] = a.storage_offset()
                kwargs["device"] = a.device
                kwargs["layout"] = a.layout
                kwargs["requires_grad"] = a.requires_grad
                kwargs["dtype"] = a.dtype
                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

            def __init__(self, a, outer_size=None, outer_stride=None):
                self.a = a

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if kwargs is None:
                    kwargs = {}
                args_a = pytree.tree_map_only(LocalSubclass, lambda x: x.a, args)
                kwargs_a = pytree.tree_map_only(LocalSubclass, lambda x: x.a, kwargs)
                out_a = func(*args_a, **kwargs_a)
                if isinstance(out_a, torch.Tensor):
                    return LocalSubclass(out_a)
                return out_a

            def __tensor_flatten__(self):
                return ["a"], None

            @staticmethod
            def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
                assert meta is None
                a = inner_tensors["a"]
                if type(a) is torch.Tensor:
                    assert outer_size is not None
                    assert outer_stride is not None
                return LocalSubclass(a, outer_size, outer_stride)

        def fn(x):
            return x * 2

        # === example subclass defined locally (error) ===
        local_sub = LocalSubclass(torch.randn(3))
        with self.assertRaisesRegex(
            PackageError, "Please define the class at global scope"
        ):
            self._test_serialization("TENSOR_SUBCLASS_METADATA_MATCH", fn, local_sub)

        # === example subclass with None extra metadata ===
        from torch.testing._internal.two_tensor import TwoTensor

        tt = TwoTensor(torch.randn(3), torch.randn(3))
        ref, loaded = self._test_serialization("TENSOR_SUBCLASS_METADATA_MATCH", fn, tt)
        self._test_check_fn(ref, loaded, {"x": tt}, True)
        self._test_check_fn(ref, loaded, {"x": torch.ones_like(tt)}, True)

        # used below for convenience; returned func accepts some metadata and whether the
        # guard is expected to pass for the given subclass type
        def _get_meta_test_check_fn(ref, loaded, subclass_type):
            def _f(meta, expected, ref=ref, loaded=loaded, subclass_type=subclass_type):
                self._test_check_fn(
                    ref,
                    loaded,
                    {"x": subclass_type(torch.randn(3), extra=meta)},
                    expected,
                )

            return _f

        # === example subclass with extra metadata ===
        extra_meta = {
            "foo": 5,
            "bar": "hello",
        }
        sub = SubclassWithMeta(torch.randn(3), extra=extra_meta)
        ref, loaded = self._test_serialization(
            "TENSOR_SUBCLASS_METADATA_MATCH", fn, sub
        )
        self._test_check_fn(ref, loaded, {"x": sub}, True)
        check_with_meta = _get_meta_test_check_fn(ref, loaded, SubclassWithMeta)
        check_with_meta(dict(extra_meta), True)
        # different "foo"
        check_with_meta({"foo": 6, "bar": "hello"}, False)
        # different "bar"
        check_with_meta({"foo": 5, "bar": "world"}, False)

        # === example subclass with custom metadata guard logic ===
        sub = SubclassWithCustomMetadataGuard(torch.randn(3), extra=extra_meta)
        ref, loaded = self._test_serialization(
            "TENSOR_SUBCLASS_METADATA_MATCH", fn, sub
        )
        self._test_check_fn(ref, loaded, {"x": sub}, True)
        check_with_meta = _get_meta_test_check_fn(
            ref, loaded, SubclassWithCustomMetadataGuard
        )
        check_with_meta(dict(extra_meta), True)
        # different "foo"; custom logic says this is okay
        check_with_meta({"foo": 6, "bar": "hello"}, True)
        # different "bar"
        check_with_meta({"foo": 5, "bar": "world"}, False)

        # === example subclass with subclass inner tensor ===
        sub = SubclassWithSubclassInnerTensor(torch.randn(3), extra=extra_meta)
        ref, loaded = self._test_serialization(
            "TENSOR_SUBCLASS_METADATA_MATCH", fn, sub
        )
        self._test_check_fn(ref, loaded, {"x": sub}, True)
        check_with_meta = _get_meta_test_check_fn(
            ref, loaded, SubclassWithSubclassInnerTensor
        )
        check_with_meta(dict(extra_meta), True)
        # different "foo"
        check_with_meta({"foo": 6, "bar": "hello"}, False)
        # different "bar"
        check_with_meta({"foo": 5, "bar": "world"}, False)

    def test_equals_match(self):
        def fn(x, y):
            # CustomConstantType is registered as a pytree constant so this should
            # result in an EQUALS_MATCH guard.
            if x in y:
                return torch.zeros(3)
            return torch.ones(3)

        x = CustomConstantType(4, 5)
        y = [CustomConstantType(2, 3), CustomConstantType(4, 5)]
        ref, loaded = self._test_serialization("EQUALS_MATCH", fn, x, y)
        self._test_check_fn(ref, loaded, {"x": x, "y": y}, True)
        # custom __eq__ says that CustomConstantType(4, 5) == CustomConstantType(4, 9)
        self._test_check_fn(
            ref,
            loaded,
            {
                "x": CustomConstantType(4, 5),
                "y": [CustomConstantType(2, 3), CustomConstantType(4, 9)],
            },
            True,
        )
        self._test_check_fn(ref, loaded, {"x": x, "y": []}, False)
        self._test_check_fn(
            ref,
            loaded,
            {
                "x": x,
                "y": [CustomConstantType(2, 3), CustomConstantType(6, 7)],
            },
            False,
        )

    def test_constant_match(self):
        # === bool constant ===
        def fn(x, y):
            if y:
                return x + 1
            return x + 2

        x = torch.randn(3)
        y = True

        ref, loaded = self._test_serialization("CONSTANT_MATCH", fn, x, y)
        self._test_check_fn(ref, loaded, {"x": x, "y": y}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": True}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(4), "y": True}, True)
        # guard should fail for different y value
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": False}, False)

        # === None constant ===
        def fn(x, y):
            if y is None:
                return x + 1
            return x + 2

        x = torch.randn(3)
        y = None

        ref, loaded = self._test_serialization("CONSTANT_MATCH", fn, x, y)
        self._test_check_fn(ref, loaded, {"x": x, "y": y}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": None}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(4), "y": None}, True)
        # guard should fail for non-None y value
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": 5}, False)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": True}, False)

        # === int constant ===
        def fn(x, y):
            return x + y

        x = torch.randn(3)
        y = 5

        ref, loaded = self._test_serialization("CONSTANT_MATCH", fn, x, y)
        self._test_check_fn(ref, loaded, {"x": x, "y": y}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": 5}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(4), "y": 5}, True)
        # guard should fail for different y value
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": 6}, False)

    def test_nn_module(self):
        def fn(m, x):
            return m(x)

        m = GlobalModule()
        x = torch.randn(3)

        # config setting controls whether the NN_MODULE guard is installed
        with patch("torch._dynamo.config.inline_inbuilt_nn_modules", False):
            # we don't support NN_MODULE because it adds an ID_MATCH guard, and we don't
            # support that in serialization
            with self.assertRaisesRegex(
                PackageError, "NN_MODULE guard cannot be serialized."
            ):
                self._test_serialization("NN_MODULE", fn, m, x)

    def test_function_match(self):
        def fn(x):
            # usage of this context manager installs a FUNCTION_MATCH guard
            with torch.no_grad():
                y = x * 2
            return y

        x = torch.randn(3)

        # we don't support FUNCTION_MATCH because it adds an ID_MATCH guard, and we don't
        # support that in serialization
        with self.assertRaisesRegex(
            PackageError, "FUNCTION_MATCH guard cannot be serialized."
        ):
            self._test_serialization("FUNCTION_MATCH", fn, x)

    def test_closure_match(self):
        def fn(x):
            # usage of this global function installs a CLOSURE_MATCH guard
            return global_func(x)

        x = torch.randn(3)

        # we don't support CLOSURE_MATCH because it adds a FUNCTION_MATCH guard, and we don't
        # support that in serialization
        with self.assertRaisesRegex(
            PackageError, "CLOSURE_MATCH guard cannot be serialized."
        ):
            self._test_serialization("CLOSURE_MATCH", fn, x)

    def test_sequence_length(self):
        # tuple input installs a SEQUENCE_LENGTH guard
        def fn(t, x):
            return t[1] + x

        t = tuple(torch.randn(3) for _ in range(3))
        x = torch.randn(3)

        ref, loaded = self._test_serialization("SEQUENCE_LENGTH", fn, t, x)
        self._test_check_fn(ref, loaded, {"x": x, "t": t}, True)
        self._test_check_fn(
            ref,
            loaded,
            {
                "x": torch.randn(3),
                "t": tuple(torch.randn(3) for _ in range(3)),
            },
            True,
        )
        # different types in tuple of same length shouldn't fail SEQUENCE_LENGTH guard
        # (it should fail the separate TYPE_MATCH guard but that isn't tested here)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "t": (0, 1, 2)}, True)
        # different length tuple
        self._test_check_fn(
            ref,
            loaded,
            {
                "x": torch.randn(3),
                "t": tuple(torch.randn(3) for _ in range(4)),
            },
            False,
        )

    def test_tuple_iterator_len(self):
        def fn(t, x):
            if len(list(t)) > 2:
                return x * 2
            return x + 1

        tup = (1, 2, 3)
        x = torch.randn(3)

        # func to generate kwargs; useful for avoiding iterator exhaustion issues
        def _gen_kwargs(tup=tup, x=x):
            return {"t": iter(tup), "x": x}

        ref, loaded = self._test_serialization(
            "TUPLE_ITERATOR_LEN", fn, _gen_fn=_gen_kwargs
        )

        # same tuple
        self._test_check_fn(ref, loaded, {"t": iter(tup), "x": x}, True)
        self._test_check_fn(ref, loaded, {"t": iter(tup), "x": torch.randn(4)}, True)
        # same length tuple, different contents
        self._test_check_fn(ref, loaded, {"t": iter((3, 2, 1)), "x": x}, True)
        self._test_check_fn(
            ref, loaded, {"t": iter((3, 2, 1)), "x": torch.randn(4)}, True
        )
        # different tuple lengths
        self._test_check_fn(ref, loaded, {"t": iter((1, 2)), "x": x}, False)
        self._test_check_fn(
            ref, loaded, {"t": iter((1, 2)), "x": torch.randn(4)}, False
        )
        self._test_check_fn(ref, loaded, {"t": iter((1, 2, 3, 4)), "x": x}, False)
        self._test_check_fn(
            ref, loaded, {"t": iter((1, 2, 3, 4)), "x": torch.randn(4)}, False
        )

    def test_range_iterator_match(self):
        def fn(x, r):
            y = x
            for val in r:
                y = x + val
            return y

        x = torch.randn(3)

        def _gen_kwargs(x=x):
            return {"x": x, "r": iter(range(2, 15, 3))}

        ref, loaded = self._test_serialization(
            "RANGE_ITERATOR_MATCH", fn, _gen_fn=_gen_kwargs
        )

        # same range
        self._test_check_fn(ref, loaded, {"x": x, "r": iter(range(2, 15, 3))}, True)
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(4), "r": iter(range(2, 15, 3))}, True
        )
        # equivalent even with different end
        self._test_check_fn(ref, loaded, {"x": x, "r": iter(range(2, 16, 3))}, True)
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(4), "r": iter(range(2, 16, 3))}, True
        )
        # different start
        self._test_check_fn(ref, loaded, {"x": x, "r": iter(range(1, 15, 3))}, False)
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(4), "r": iter(range(1, 15, 3))}, False
        )
        # different end resulting in different values
        self._test_check_fn(ref, loaded, {"x": x, "r": iter(range(2, 18, 3))}, False)
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(4), "r": iter(range(2, 18, 3))}, False
        )
        # different step
        self._test_check_fn(ref, loaded, {"x": x, "r": iter(range(2, 15, 4))}, False)
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(4), "r": iter(range(2, 15, 4))}, False
        )

    def test_dict_version(self):
        def fn(x):
            return pytree.tree_leaves(x)[0] + 1

        with self.assertRaisesRegex(
            PackageError, "DICT_VERSION guard cannot be serialized."
        ):
            self._test_serialization("DICT_VERSION", fn, {"t": torch.randn(3)})

    def test_dict_contains(self):
        def fn(x):
            if x.__contains__("t"):
                return x["t"] + 1
            else:
                return torch.ones(3)

        ref, loaded = self._test_serialization(
            "DICT_CONTAINS", fn, {"t": torch.randn(3)}
        )

        self._test_check_fn(ref, loaded, {"x": {"t": torch.randn(3)}}, True)
        self._test_check_fn(ref, loaded, {"x": {}}, False)
        self._test_check_fn(
            ref, loaded, {"x": {"t": torch.randn(3), "d": torch.randn(3)}}, True
        )

    def test_bool_match(self):
        def fn(x, b):
            if b:
                return x + 1
            else:
                return x + 2

        ref, loaded = self._test_serialization("BOOL_MATCH", fn, torch.randn(3), True)

        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": True}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": False}, False)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": None}, False)

    def test_none_match(self):
        def fn(x, b):
            if b is None:
                return x + 1
            else:
                return x + 2

        ref, loaded = self._test_serialization("NONE_MATCH", fn, torch.randn(3), None)

        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": None}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": False}, False)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3), "b": True}, False)

    def test_id_match(self):
        def fn(x):
            return x + id(x)

        with self.assertRaisesRegex(
            PackageError, "ID_MATCH guard cannot be serialized."
        ):
            self._test_serialization("ID_MATCH", fn, torch.randn(3))

    def test_dispatch_key_set_match(self):
        def fn(x, dks):
            if dks.has("CPU"):
                return torch.sin(x + 1)
            else:
                return torch.sin(x - 1)

        x = torch.randn(3)
        dks = torch._C._dispatch_keys(x)
        ref, loaded = self._test_serialization("DISPATCH_KEY_SET_MATCH", fn, x, dks)

        self._test_check_fn(ref, loaded, {"x": x, "dks": dks}, True)

        x = torch.randn(3, device="meta")
        dks = torch._C._dispatch_keys(x)
        self._test_check_fn(ref, loaded, {"x": x, "dks": dks}, False)

    def test_name_match(self):
        def fn(x, y):
            return torch.cond(x, lambda x: y + 1, lambda x: y - 1, (y,))

        x = torch.tensor(True)
        y = torch.randn(3)
        ref, loaded = self._test_serialization("NAME_MATCH", fn, x, y)

        self._test_check_fn(ref, loaded, {"x": x, "y": y}, True)

        op = importlib.import_module("torch._higher_order_ops.cond").cond_op
        prev, op.__name__ = op.__name__, ""
        try:
            self._test_check_fn(ref, loaded, {"x": x, "y": y}, False)
        finally:
            op.__name__ = prev

    def test_dual_level(self):
        def fn(x):
            with torch.autograd.forward_ad.dual_level():
                return x + 1

        x = torch.randn(3)
        ref, loaded = self._test_serialization("DUAL_LEVEL", fn, x)

        self._test_check_fn(ref, loaded, {"x": x}, True)
        with torch.autograd.forward_ad.dual_level():
            self._test_check_fn(ref, loaded, {"x": x}, False)

    def test_functorch_stack_match(self):
        # Test when functorch stack is empty.
        def fn(x):
            return torch.func.jvp(torch.sin, (x,), (x,))

        x = torch.randn(3, 4)
        ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)

        self._test_check_fn(ref, loaded, {"x": x}, True)
        with torch._functorch.vmap.vmap_increment_nesting(2, "error"):
            self._test_check_fn(ref, loaded, {"x": x}, False)

        def fn(x):
            def g(x):
                return torch.vmap(torch.func.grad(torch.sin))(x)

            return torch.vmap(g)(x)

        x = torch.randn(4, 5)
        ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
        self._test_check_fn(ref, loaded, {"x": x}, True)
        with torch._functorch.eager_transforms.grad_increment_nesting():
            self._test_check_fn(ref, loaded, {"x": x}, False)

        # Test when there are more than 0 functorch layers.
        # Simulate the case where torch.compile is nested inside eager transforms.

        # Case 1: vmap
        def fn(x):
            return x.sum()

        ref = loaded = None

        def run(x):
            nonlocal ref, loaded
            # Turn off automatic dynamic shapes so that functionalization
            # doesn't produce extra SymInts to serialize.
            with torch._dynamo.config.patch(automatic_dynamic_shapes=False):
                ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
            return fn(x)

        torch.vmap(run)(x)

        self._test_check_fn(ref, loaded, {"x": x}, False)
        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
            self._test_check_fn(ref, loaded, {"x": x}, True)
            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
                self._test_check_fn(ref, loaded, {"x": x}, False)

        with torch._functorch.eager_transforms.grad_increment_nesting():
            self._test_check_fn(ref, loaded, {"x": x}, False)

        # Case 2: grad
        x = torch.randn(3, 2)
        ref = loaded = None
        torch.func.grad(run)(x)
        self._test_check_fn(ref, loaded, {"x": x}, False)
        with torch._functorch.eager_transforms.grad_increment_nesting():
            self._test_check_fn(ref, loaded, {"x": x}, True)
            with torch._functorch.eager_transforms.grad_increment_nesting():
                self._test_check_fn(ref, loaded, {"x": x}, False)

        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
            self._test_check_fn(ref, loaded, {"x": x}, False)

        # Case 3: jvp + vmap
        x = torch.randn(3, 4)
        ref = loaded = None

        def fn(x):
            return torch.func.jvp(torch.sin, (x,), (x,))

        torch.func.jvp(torch.vmap(run), (x,), (x,))
        self._test_check_fn(ref, loaded, {"x": x}, False)

        with torch._functorch.eager_transforms.jvp_increment_nesting():
            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
                self._test_check_fn(ref, loaded, {"x": x}, True)

        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
            with torch._functorch.eager_transforms.jvp_increment_nesting():
                self._test_check_fn(ref, loaded, {"x": x}, False)

        # Case 4: functionalize
        x = torch.randn(3, 2)
        ref = loaded = None
        torch.func.functionalize(run)(x)
        self._test_check_fn(ref, loaded, {"x": x}, False)

        torch._C._functorch._func_increment_nesting(True)
        try:
            self._test_check_fn(ref, loaded, {"x": x}, True)
        finally:
            torch._C._functorch._func_decrement_nesting()

        with torch._functorch.eager_transforms.jvp_increment_nesting():
            self._test_check_fn(ref, loaded, {"x": x}, False)

        # Case 5: vmap + grad
        def fn(x):
            return x.sum()

        x = torch.randn(3, 2)
        ref = loaded = None
        torch.vmap(torch.func.grad(run))(x)
        self._test_check_fn(ref, loaded, {"x": x}, False)
        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
            with torch._functorch.eager_transforms.grad_increment_nesting():
                self._test_check_fn(ref, loaded, {"x": x}, True)

        with torch._functorch.eager_transforms.grad_increment_nesting():
            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
                self._test_check_fn(ref, loaded, {"x": x}, False)

        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
            self._test_check_fn(ref, loaded, {"x": x}, False)

        with torch._functorch.eager_transforms.grad_increment_nesting():
            self._test_check_fn(ref, loaded, {"x": x}, False)

    def test_duplicate_input(self):
        def fn(x, x_):
            return x + x_

        x = torch.randn(3, 2)
        with self.assertRaisesRegex(
            PackageError, "DUPLICATE_INPUT guard cannot be serialized"
        ):
            self._test_serialization("DUPLICATE_INPUT", fn, x, x)

    def test_weakref_alive(self):
        mod = torch.nn.Linear(10, 10, bias=False)
        for p in mod.parameters():
            p.grad = torch.rand_like(p)

        opt = torch.optim.SGD(mod.parameters(), lr=0.1)

        def fn():
            params = []
            opt._init_group(opt.param_groups[0], params, [], [])
            return params[0].sum()

        with self.assertRaisesRegex(
            PackageError, "WEAKREF_ALIVE guard cannot be serialized"
        ):
            with torch.set_grad_enabled(False):
                self._test_serialization("WEAKREF_ALIVE", fn)

    def test_mapping_keys_check(self):
        def fn(mp):
            return mp["a"] + 1

        mp = types.MappingProxyType({"a": torch.randn(3, 2), "b": torch.randn(3, 2)})
        ref, loaded = self._test_serialization("MAPPING_KEYS_CHECK", fn, mp)
        self._test_check_fn(ref, loaded, {"mp": mp}, True)
        self._test_check_fn(
            ref,
            loaded,
            {
                "mp": types.MappingProxyType(
                    {"b": torch.randn(3, 2), "a": torch.randn(3, 2)}
                )
            },
            False,
        )
        self._test_check_fn(
            ref, loaded, {"mp": types.MappingProxyType({"a": torch.randn(3, 2)})}, False
        )

    def test_dict_keys_match(self):
        def fn(x):
            ret = 1
            for k in x:
                ret += x[k]
            return ret

        x = {"a": torch.randn(3, 2), "b": torch.randn(3, 2)}
        ref, loaded = self._test_serialization("DICT_KEYS_MATCH", fn, x)
        self._test_check_fn(ref, loaded, {"x": x}, True)
        self._test_check_fn(
            ref,
            loaded,
            {"x": {"b": torch.randn(3, 2), "a": torch.randn(3, 2)}},
            False,
        )
        self._test_check_fn(ref, loaded, {"x": {"a": torch.randn(3, 2)}}, False)

    @torch._dynamo.config.patch("skip_nnmodule_hook_guards", False)
    def test_empty_nn_module_hooks_dict(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x + 1

        m = Module()

        def fn(x):
            return m(x)

        x = torch.ones(2, dtype=torch.float32)
        ref, loaded = self._test_serialization("EMPTY_NN_MODULE_HOOKS_DICT", fn, x)
        self._test_check_fn(ref, loaded, {"m": m, "x": x}, True)

        h = m.register_forward_hook(lambda *args, **kwargs: None)
        self._test_check_fn(ref, loaded, {"m": m, "x": x}, False)
        h.remove()

        h = m.register_forward_pre_hook(lambda *args, **kwargs: None)
        self._test_check_fn(ref, loaded, {"m": m, "x": x}, False)
        h.remove()

        h = m.register_backward_hook(lambda *args, **kwargs: None)
        self._test_check_fn(ref, loaded, {"m": m, "x": x}, False)
        h.remove()

    def test_grad_mode(self):
        def fn(x):
            return x + 1

        x = torch.randn(3, 2)
        with torch.enable_grad():
            ref, loaded = self._test_serialization("GRAD_MODE", fn, x)
        with torch.no_grad():
            self._test_check_fn(ref, loaded, {"x": x}, False)
        with torch.enable_grad():
            self._test_check_fn(ref, loaded, {"x": x}, True)

    def test_grad_mode_loading(self):
        def fn(x):
            return x + 1

        x = torch.randn(3, 2)
        with torch.enable_grad():
            ref, _ = self._test_serialization("GRAD_MODE", fn, x)
        with torch.no_grad():
            # Ensure guards state loading is not affected by the current global grad mode.
            guards_state = pickle.loads(self._cached_guards_state)
            check_fn_manager = CheckFunctionManager(
                self._cached_f_code,
                guards_state.output_graph,
                guards_serialization_mode="load",
                shape_code_parts=guards_state.shape_code_parts,
            )
            loaded = check_fn_manager.guard_manager
            self._test_check_fn(ref, loaded, {"x": x}, False)

    def test_deterministic_algorithms(self):
        def fn(x):
            return x + 1

        deterministic_restore = torch.are_deterministic_algorithms_enabled()
        try:
            x = torch.randn(3, 2)
            torch.use_deterministic_algorithms(True)
            ref, loaded = self._test_serialization("DETERMINISTIC_ALGORITHMS", fn, x)
            torch.use_deterministic_algorithms(False)
            self._test_check_fn(ref, loaded, {"x": x}, False)
            torch.use_deterministic_algorithms(True)
            self._test_check_fn(ref, loaded, {"x": x}, True)
        finally:
            torch.use_deterministic_algorithms(deterministic_restore)

    def test_torch_function_state(self):
        def fn(x):
            return x + 1

        x = torch.randn(3, 2)

        class LocalTorchFunctionMode(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return func(*args, **kwargs)

        with GlobalTorchFunctionMode():
            ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)
            self._test_check_fn(ref, loaded, {"x": x}, True)
        self._test_check_fn(ref, loaded, {"x": x}, False)
        with GlobalTorchFunctionMode():
            with torch._C.DisableTorchFunction():
                self._test_check_fn(ref, loaded, {"x": x}, False)
        with self.assertRaisesRegex(
            PackageError,
            "defined in local scope. Please define the class at global scope",
        ):
            with LocalTorchFunctionMode():
                ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_training_state(self):
        from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
        from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup

        param_group = FSDPParamGroup(
            [],  # params: List[nn.Parameter],
            (torch.nn.Linear(1, 1),),  # module: nn.Module,
            None,  # mesh_info: FSDPMeshInfo,
            None,  # post_forward_mesh_info: Optional[FSDPMeshInfo],
            torch.device("cpu"),  # device: torch.device,
            None,  # shard_placement_fn: Optional[Callable],
            None,  # mp_policy: MixedPrecisionPolicy,
            None,  # offload_policy: OffloadPolicy,
        )

        def fn(x):
            with param_group.use_training_state(TrainingState.FORWARD):
                if param_group._training_state == TrainingState.FORWARD:
                    return x + 1
                else:
                    return x - 1

        x = torch.randn(3, 2)

        with torch.enable_grad():
            ref, loaded = self._test_serialization("FSDP_TRAINING_STATE", fn, x)
        with torch.no_grad():
            self._test_check_fn(ref, loaded, {"x": x}, False)
        with torch.enable_grad():
            self._test_check_fn(ref, loaded, {"x": x}, True)

    def test_default_device(self):
        device = torch.get_default_device()

        def fn(x):
            return x + 1

        x = torch.randn(3, 2)
        try:
            torch.set_default_device("cpu")
            ref, loaded = self._test_serialization("DEFAULT_DEVICE", fn, x)
            torch.set_default_device("meta")
            self._test_check_fn(ref, loaded, {"x": x}, False)
            torch.set_default_device("cpu")
            self._test_check_fn(ref, loaded, {"x": x}, True)
        finally:
            torch.set_default_device(device)

    def test_shape_env(self):
        def fn(x):
            return x + 1

        x = torch.randn(3, 2)
        ref, loaded = self._test_serialization("SHAPE_ENV", fn, x)
        self._test_check_fn(ref, loaded, {"x": x}, True)

        x = torch.randn(3, 2)
        torch._dynamo.mark_dynamic(x, 0, min=3, max=10)
        ref, loaded = self._test_serialization("SHAPE_ENV", fn, x)
        self._test_check_fn(ref, loaded, {"x": torch.randn(4, 2)}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(10, 2)}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(11, 2)}, False)
        self._test_check_fn(ref, loaded, {"x": torch.randn(2, 2)}, False)

        x = torch.randn(3, 3, 2)
        torch._dynamo.mark_dynamic(x, 1, min=3, max=10)
        ref, loaded = self._test_serialization("SHAPE_ENV", fn, x)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3, 4, 2)}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3, 10, 2)}, True)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3, 11, 2)}, False)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3, 2, 2)}, False)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()