mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[JIT] Partially support ForwardRef type annotations for NamedTuple attributes (#96933)
**Summary** NamedTuple attributes can be annotated to declare their type: ```python class MyNamedTuple(NamedTuple): x: int y: torch.Tensor z: MyOtherType ``` Normally in python you can also declare your types as strings, `x: 'int'`. But NamedTuples previously didn't support this, because their annotation evaluation process was slightly different. This PR updates the NamedTuple attribute type annotation evaluation method to support ForwardRef declarations (i.e. declaring as strings). **Details** Below I repeat the comment I left in _jit_internal.py: NamedTuple types are slightly different from normal types. Normally, annotations are evaluted like this (during jit.script): 1. Load strings of python code into c++ and parse. 2. Get annotations as strings 3. Use the PythonResolver's resolution callback (rcb) to convert the string into a python object 4. We call into annotations.py:ann_to_type to convert python obj from step 3 into a type that torchscript understands. NamedTuples are more complicated, because they have sub-types. Normally, once we have the NamedTuple type object from #3, we can just look at the annotation literal values and use ann_to_type directly on them. But sometimes, users will annotate with string literals, e.g. ``` x: 'int' ``` This also happens with PEP563 (from __forward__ import annotations) These annotations appear in the annotation dict as ForwardRef('int'). Then, we need to convert the string into a python object. This requires having local context for custom objects or imported types. rcb() is what gives us this. So, we plumb rcb through the stack so it can be used in this context for the if block below. FAQ: - Why do we need this special handling for NamedTuple but string annotations work fine for normal types? Normally, we parse the string directly and then call rcb() directly from C++. - Why not use ForwardRef._evaluate? For that, we need globals() and locals() for the local context where the NamedTuple was defined. rcb is what lets us look up into these. So, basically rcb does the hard work for us. - What is rcb? rcb is a ResolutionCallback - python callable that takes a string and returns a type. It's generated by `createResolutionCallback.*` in _jit_internal.py. **Why is this only partial support**: This only plumbs the rcb through some paths. In particular, the `toSugaredValue` path uses a fake rcb. **Alternatives**: We could also treat this the way we treat non-nn.Module classes: we evaluate them separately, ahead of time. That solution is probably better, but probably requires a more risky refactor for the way NamedTuples are handled. Fixes #95858 Pull Request resolved: https://github.com/pytorch/pytorch/pull/96933 Approved by: https://github.com/qihqi
This commit is contained in:
committed by
PyTorch MergeBot
parent
d850c33bfe
commit
a133b5081c
@ -17,7 +17,7 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -2084,6 +2084,62 @@ class TestNamedTuple(JitTestCase):
|
||||
for name in ['a', 'b', 'c']:
|
||||
self.assertEqual(getattr(out_loaded, name), getattr(out, name))
|
||||
|
||||
def test_namedtuple_inside_forwardref(self):
|
||||
class FeatureVector(NamedTuple):
|
||||
float_features: 'float'
|
||||
sequence_features: 'List[float]'
|
||||
time_since_first: 'float'
|
||||
|
||||
@torch.jit.script
|
||||
def foo(x) -> float:
|
||||
fv = FeatureVector(3.0, [3.0], 3.0)
|
||||
rv = fv.float_features
|
||||
for val in fv.sequence_features:
|
||||
rv += val
|
||||
rv *= fv.time_since_first
|
||||
return rv
|
||||
|
||||
self.assertEqual(foo(torch.rand(3, 4)), 18.0)
|
||||
|
||||
def test_namedtuple_input_forwardref(self):
|
||||
class MyNamedTuple(NamedTuple):
|
||||
a : int
|
||||
b : float
|
||||
c : torch.Tensor
|
||||
|
||||
make_global(MyNamedTuple)
|
||||
|
||||
nt = MyNamedTuple(4, 2.5, torch.rand((2, 2)))
|
||||
|
||||
def fn(obj: MyNamedTuple):
|
||||
return ((obj.c + obj.b) ** obj.a).sin()
|
||||
|
||||
expected = fn(nt)
|
||||
fn_s = torch.jit.script(fn)
|
||||
actual = fn_s(nt)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
# see #95858
|
||||
@unittest.expectedFailure
|
||||
def test_namedtuple_resolution_forwardref(self):
|
||||
class TheType(NamedTuple):
|
||||
t: 'int'
|
||||
|
||||
class MyModule(types.ModuleType):
|
||||
def __init__(self):
|
||||
super().__init__('MyModule')
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return TheType
|
||||
|
||||
some_module = MyModule()
|
||||
|
||||
def fn() -> some_module.Type:
|
||||
return some_module.Type(1)
|
||||
|
||||
self.checkScript(fn, [])
|
||||
|
||||
|
||||
class TestScriptDict(JitTestCase):
|
||||
"""
|
||||
This class contains a suite of tests for torch.jit.script, a
|
||||
|
@ -433,6 +433,24 @@ class TestSaveLoad(JitTestCase):
|
||||
output = m_loaded(FooTuple(a=5))
|
||||
self.assertEqual(output, torch.tensor(3))
|
||||
|
||||
def test_save_namedtuple_input_only_forwardref(self):
|
||||
"""
|
||||
Even if a NamedTuple is only used as an input argument, saving and
|
||||
loading should work correctly.
|
||||
"""
|
||||
global FooTuple # see [local resolution in python]
|
||||
|
||||
class FooTuple(NamedTuple):
|
||||
a: 'int'
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, x: FooTuple) -> torch.Tensor:
|
||||
return torch.tensor(3)
|
||||
|
||||
m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
|
||||
output = m_loaded(FooTuple(a=5))
|
||||
self.assertEqual(output, torch.tensor(3))
|
||||
|
||||
def test_save_namedtuple_output_only(self):
|
||||
"""
|
||||
Even if a NamedTuple is only used as an output argument, saving and
|
||||
|
@ -23,6 +23,7 @@ from typing import ( # noqa: F401
|
||||
Callable,
|
||||
Dict,
|
||||
Final,
|
||||
ForwardRef,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
@ -1199,7 +1200,7 @@ def _try_get_dispatched_fn(fn):
|
||||
|
||||
|
||||
def _get_named_tuple_properties(
|
||||
obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None
|
||||
obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None, rcb=None
|
||||
):
|
||||
if loc is None:
|
||||
loc = fake_range()
|
||||
@ -1225,7 +1226,53 @@ def _get_named_tuple_properties(
|
||||
annotations = []
|
||||
for field in obj._fields:
|
||||
if field in obj_annotations:
|
||||
the_type = torch.jit.annotations.ann_to_type(obj_annotations[field], loc)
|
||||
field_type = obj_annotations[field]
|
||||
# [Note: ForwardRef annotations in NamedTuple attributes]
|
||||
# NamedTuple types are slightly different from normal types.
|
||||
#
|
||||
# Normally, annotations are evaluted like this (during jit.script):
|
||||
# 1. Load strings of python code into c++ and parse.
|
||||
# 2. Get annotations as strings
|
||||
# 3. Use the PythonResolver's resolution callback (rcb) to convert
|
||||
# the string into a python object
|
||||
# 4. We call into annotations.py:ann_to_type to convert python obj
|
||||
# from step 3 into a type that torchscript understands.
|
||||
#
|
||||
# NamedTuples are more complicated, because it has sub-types.
|
||||
# Normally, once we have the NamedTuple type object from #3,
|
||||
# we can just look at the annotation literal values and use
|
||||
# ann_to_type directly on them.
|
||||
#
|
||||
# But sometimes, users will annotate with string literals, e.g.
|
||||
# x: 'int'
|
||||
# This also happens with PEP563 (from __forward__ import annotations)
|
||||
#
|
||||
# These annotations appear in the annotation dict as ForwardRef('int').
|
||||
#
|
||||
# Then, we need to convert the string into a python object. This
|
||||
# requires having local context for custom objects or imported types.
|
||||
# rcb() is what gives us this. So, we plumb rcb through the stack so
|
||||
# it can be used in this context for the if block below.
|
||||
#
|
||||
# FAQ:
|
||||
# - Why do we need this special handling for NamedTuple but string
|
||||
# annotations work fine for normal types? Normally, we parse the
|
||||
# string directly and then call rcb() directly from C++.
|
||||
# - Why not use ForwardRef._evaluate? For that, we need globals()
|
||||
# and locals() for the local context where the NamedTuple was defined.
|
||||
# rcb is what lets us look up into these. So, basically rcb does the
|
||||
# hard work for us.
|
||||
if isinstance(field_type, ForwardRef) and rcb is not None:
|
||||
rcb_type = rcb(field_type.__forward_arg__)
|
||||
# rcb returns None if it can't find anything.
|
||||
if rcb_type is None:
|
||||
raise ValueError(
|
||||
f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
|
||||
f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
|
||||
f" Issue occurred at {loc.highlight()}"
|
||||
)
|
||||
field_type = rcb_type
|
||||
the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
|
||||
annotations.append(the_type)
|
||||
else:
|
||||
annotations.append(torch._C.TensorType.getInferred())
|
||||
|
@ -57,6 +57,8 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using ResolutionCallback = std::function<py::object(std::string)>;
|
||||
|
||||
void clear_registered_instances(void* ptr);
|
||||
|
||||
TORCH_PYTHON_API IValue toIValue(
|
||||
|
@ -1006,13 +1006,19 @@ bool isNamedTupleClass(const py::object& obj) {
|
||||
return is_tuple_class == 1 && py::hasattr(obj, "_fields");
|
||||
}
|
||||
|
||||
TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) {
|
||||
TypePtr registerNamedTuple(
|
||||
const py::object& obj,
|
||||
const SourceRange& loc,
|
||||
const ResolutionCallback& rcb) {
|
||||
TORCH_INTERNAL_ASSERT(isNamedTupleClass(obj));
|
||||
auto qualifiedName = c10::QualifiedName(py::cast<std::string>(
|
||||
py::module::import("torch._jit_internal").attr("_qualified_name")(obj)));
|
||||
|
||||
py::object props = py::module::import("torch._jit_internal")
|
||||
.attr("_get_named_tuple_properties")(obj, loc);
|
||||
// Note: we need to pass rcb to resolve ForwardRef annotations. See
|
||||
// [Note: ForwardRef annotations in NamedTuple attributes]
|
||||
py::object props =
|
||||
py::module::import("torch._jit_internal")
|
||||
.attr("_get_named_tuple_properties")(obj, loc, py::cpp_function(rcb));
|
||||
|
||||
std::string unqualName;
|
||||
std::vector<std::string> field_names;
|
||||
@ -1290,7 +1296,14 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
||||
}
|
||||
|
||||
if (isNamedTupleClass(obj)) {
|
||||
auto tuple_type = registerNamedTuple(obj, loc)->expect<TupleType>();
|
||||
// The use of fakeRcb here prevents us from correctly resolving ForwardRef
|
||||
// annotations on NamedTuple attributes for instances whose types are
|
||||
// inferred. See #95858 for more details, as well as
|
||||
// [Note: ForwardRef annotations in NamedTuple attributes]
|
||||
auto fakeRcb =
|
||||
py::module::import("torch.jit.annotations").attr("_fake_rcb");
|
||||
auto tuple_type =
|
||||
registerNamedTuple(obj, loc, fakeRcb)->expect<TupleType>();
|
||||
return std::make_shared<NamedTupleConstructor>(tuple_type);
|
||||
}
|
||||
|
||||
|
@ -242,7 +242,10 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
|
||||
};
|
||||
|
||||
bool isNamedTupleClass(const py::object& obj);
|
||||
TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc);
|
||||
TypePtr registerNamedTuple(
|
||||
const py::object& obj,
|
||||
const SourceRange& loc,
|
||||
const ResolutionCallback& rcb);
|
||||
|
||||
void recurseThroughNestedModules(
|
||||
const SourceRange& loc,
|
||||
|
@ -76,7 +76,6 @@ namespace torch::jit {
|
||||
using ::c10::Argument;
|
||||
using ::c10::FunctionSchema;
|
||||
|
||||
using ResolutionCallback = std::function<py::object(std::string)>;
|
||||
using FunctionDefaults = std::unordered_map<std::string, py::object>;
|
||||
using ClassMethodDefaults = std::unordered_map<std::string, FunctionDefaults>;
|
||||
|
||||
@ -136,7 +135,7 @@ struct PythonResolver : public Resolver {
|
||||
}
|
||||
|
||||
if (isNamedTupleClass(obj)) {
|
||||
return registerNamedTuple(obj, loc);
|
||||
return registerNamedTuple(obj, loc, rcb_);
|
||||
}
|
||||
|
||||
auto qualifiedName = c10::QualifiedName(
|
||||
@ -157,8 +156,9 @@ struct PythonResolver : public Resolver {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto annotation_type = py::module::import("torch.jit.annotations")
|
||||
.attr("try_ann_to_type")(obj, loc);
|
||||
auto annotation_type =
|
||||
py::module::import("torch.jit.annotations")
|
||||
.attr("try_ann_to_type")(obj, loc, py::cpp_function(rcb_));
|
||||
if (!annotation_type.is_none()) {
|
||||
return py::cast<TypePtr>(annotation_type);
|
||||
}
|
||||
|
@ -315,8 +315,11 @@ def is_tensor(ann):
|
||||
return False
|
||||
|
||||
|
||||
def _fake_rcb(inp):
|
||||
return None
|
||||
|
||||
def try_ann_to_type(ann, loc):
|
||||
|
||||
def try_ann_to_type(ann, loc, rcb=None):
|
||||
if ann is inspect.Signature.empty:
|
||||
return TensorType.getInferred()
|
||||
if ann is None:
|
||||
@ -410,13 +413,13 @@ def try_ann_to_type(ann, loc):
|
||||
return torch.jit._script._recursive_compile_class(ann, loc)
|
||||
|
||||
# Maybe resolve a NamedTuple to a Tuple Type
|
||||
def fake_rcb(key):
|
||||
return None
|
||||
return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
|
||||
if rcb is None:
|
||||
rcb = _fake_rcb
|
||||
return torch._C._resolve_type_from_object(ann, loc, rcb)
|
||||
|
||||
|
||||
def ann_to_type(ann, loc):
|
||||
the_type = try_ann_to_type(ann, loc)
|
||||
def ann_to_type(ann, loc, rcb=None):
|
||||
the_type = try_ann_to_type(ann, loc, rcb)
|
||||
if the_type is not None:
|
||||
return the_type
|
||||
raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
|
||||
|
Reference in New Issue
Block a user