[JIT] Partially support ForwardRef type annotations for NamedTuple attributes (#96933)

**Summary** NamedTuple attributes can be annotated to declare their type:
```python
class MyNamedTuple(NamedTuple):
    x: int
    y: torch.Tensor
    z: MyOtherType
```
Normally in Python you can also declare your types as strings, e.g. `x: 'int'`. But NamedTuples previously didn't support this, because their annotation evaluation process is slightly different. This PR updates the NamedTuple attribute type annotation evaluation method to support ForwardRef declarations (i.e. types declared as strings).
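
For example, with this change the following now scripts successfully (a minimal sketch adapted from the `test_namedtuple_inside_forwardref` test added in this PR):

```python
from typing import List, NamedTuple

import torch

class FeatureVector(NamedTuple):
    float_features: 'float'             # ForwardRef (string) annotations
    sequence_features: 'List[float]'
    time_since_first: 'float'

@torch.jit.script
def foo(x) -> float:
    fv = FeatureVector(3.0, [3.0], 3.0)
    rv = fv.float_features              # 3.0
    for val in fv.sequence_features:
        rv += val                       # 3.0 + 3.0 = 6.0
    rv *= fv.time_since_first           # 6.0 * 3.0 = 18.0
    return rv

assert foo(torch.rand(3, 4)) == 18.0
```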

**Details**

Below I repeat the comment I left in `_jit_internal.py`:

NamedTuple types are slightly different from normal types.

Normally, annotations are evaluated like this (during `jit.script`):
1. Load strings of Python code into C++ and parse them.
2. Get the annotations as strings.
3. Use the PythonResolver's resolution callback (rcb) to convert each string into a Python object.
4. Call into `annotations.py:ann_to_type` to convert the Python object from step 3 into a type that TorchScript understands.

NamedTuples are more complicated, because they have sub-types. Normally, once we have the NamedTuple type object from step 3, we can just look at the annotation literal values and use `ann_to_type` directly on them.

But sometimes, users will annotate with string literals, e.g.
```python
x: 'int'
```
This also happens with PEP 563 (`from __future__ import annotations`).

These annotations appear in the annotation dict as `ForwardRef('int')`.

Then, we need to convert the string into a Python object. This requires having local context for custom objects or imported types. rcb() is what gives us this. So, we plumb rcb through the stack so it can be used in this context, in the ForwardRef-handling `if` block inside `_get_named_tuple_properties` (see the diff below).
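
The ForwardRef wrapping happens at class-creation time: `typing.NamedTuple` runs each string annotation through its type-checking machinery, which stores it as an unevaluated `ForwardRef`. A quick illustration (`Point` is a hypothetical example; behavior of recent CPython versions):

```python
from typing import ForwardRef, NamedTuple

class Point(NamedTuple):
    x: 'int'  # declared as a string

# The annotation dict holds an unevaluated ForwardRef, not the str 'int'
assert Point.__annotations__['x'] == ForwardRef('int')
```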

FAQ:
- Why do we need this special handling for NamedTuples when string annotations work fine for normal types? Because normally we parse the string directly and then call rcb() directly from C++.
- Why not use `ForwardRef._evaluate`? For that, we need the `globals()` and `locals()` of the context where the NamedTuple was defined; rcb is what lets us look these up. So, basically, rcb does the hard work for us.
- What is rcb? rcb is a ResolutionCallback: a Python callable that takes a string and returns a type. It's generated by `createResolutionCallback.*` in `_jit_internal.py`.
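
To make the FAQ concrete, here is a toy resolution callback. This is a simplified sketch of the idea (captured-scope lookup), not the actual `createResolutionCallback.*` implementation, and `make_rcb` is a hypothetical helper:

```python
import builtins

def make_rcb(frame_globals: dict, frame_locals: dict):
    """Build a toy resolution callback over a captured scope."""
    def rcb(name: str):
        # Check the captured local scope first, then globals, then builtins.
        if name in frame_locals:
            return frame_locals[name]
        if name in frame_globals:
            return frame_globals[name]
        # Returning None signals "name not found", matching how the real
        # rcbs report failure.
        return getattr(builtins, name, None)
    return rcb

rcb = make_rcb(globals(), locals())
assert rcb('int') is int          # resolves builtins
assert rcb('NoSuchType') is None  # unresolvable names yield None
```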

**Why is this only partial support**:

This only plumbs the rcb through some paths; in particular, the `toSugaredValue` path still uses a fake rcb.
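
Concretely, that path falls back to `_fake_rcb` (added to `torch/jit/annotations.py` in this PR), which resolves nothing:

```python
def _fake_rcb(inp):
    # Always fails to resolve; callers treat None as "name not found".
    return None
```

So ForwardRef annotations on NamedTuples whose types are inferred through `toSugaredValue` still fail to resolve; the `expectedFailure` test `test_namedtuple_resolution_forwardref` in the diff below tracks this.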

**Alternatives**:

We could also treat NamedTuples the way we treat non-`nn.Module` classes: evaluate them separately, ahead of time. That solution is probably better, but it would likely require a riskier refactor of the way NamedTuples are handled.

Fixes #95858

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96933
Approved by: https://github.com/qihqi
Author: David Berard
Date: 2023-03-22 15:20:38 +00:00
Committed by: PyTorch MergeBot
Commit: a133b5081c (parent: d850c33bfe)
8 changed files with 160 additions and 18 deletions


@@ -17,7 +17,7 @@ from torch.testing import FileCheck
 # Make the helper files in test/ importable
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(pytorch_test_dir)
-from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.jit_utils import JitTestCase, make_global
 from torch.testing._internal.common_utils import skipIfTorchDynamo

 if __name__ == '__main__':
@@ -2084,6 +2084,62 @@ class TestNamedTuple(JitTestCase):
         for name in ['a', 'b', 'c']:
             self.assertEqual(getattr(out_loaded, name), getattr(out, name))

+    def test_namedtuple_inside_forwardref(self):
+        class FeatureVector(NamedTuple):
+            float_features: 'float'
+            sequence_features: 'List[float]'
+            time_since_first: 'float'
+
+        @torch.jit.script
+        def foo(x) -> float:
+            fv = FeatureVector(3.0, [3.0], 3.0)
+            rv = fv.float_features
+            for val in fv.sequence_features:
+                rv += val
+            rv *= fv.time_since_first
+            return rv
+
+        self.assertEqual(foo(torch.rand(3, 4)), 18.0)
+
+    def test_namedtuple_input_forwardref(self):
+        class MyNamedTuple(NamedTuple):
+            a : int
+            b : float
+            c : torch.Tensor
+
+        make_global(MyNamedTuple)
+
+        nt = MyNamedTuple(4, 2.5, torch.rand((2, 2)))
+
+        def fn(obj: MyNamedTuple):
+            return ((obj.c + obj.b) ** obj.a).sin()
+
+        expected = fn(nt)
+        fn_s = torch.jit.script(fn)
+        actual = fn_s(nt)
+        self.assertEqual(expected, actual)
+
+    # see #95858
+    @unittest.expectedFailure
+    def test_namedtuple_resolution_forwardref(self):
+        class TheType(NamedTuple):
+            t: 'int'
+
+        class MyModule(types.ModuleType):
+            def __init__(self):
+                super().__init__('MyModule')
+
+            def __getattr__(self, attr):
+                return TheType
+
+        some_module = MyModule()
+
+        def fn() -> some_module.Type:
+            return some_module.Type(1)
+
+        self.checkScript(fn, [])
+

 class TestScriptDict(JitTestCase):
     """
     This class contains a suite of tests for torch.jit.script, a


@@ -433,6 +433,24 @@ class TestSaveLoad(JitTestCase):
         output = m_loaded(FooTuple(a=5))
         self.assertEqual(output, torch.tensor(3))

+    def test_save_namedtuple_input_only_forwardref(self):
+        """
+        Even if a NamedTuple is only used as an input argument, saving and
+        loading should work correctly.
+        """
+        global FooTuple  # see [local resolution in python]
+
+        class FooTuple(NamedTuple):
+            a: 'int'
+
+        class MyModule(torch.nn.Module):
+            def forward(self, x: FooTuple) -> torch.Tensor:
+                return torch.tensor(3)
+
+        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
+        output = m_loaded(FooTuple(a=5))
+        self.assertEqual(output, torch.tensor(3))
+
     def test_save_namedtuple_output_only(self):
         """
         Even if a NamedTuple is only used as an output argument, saving and

@@ -23,6 +23,7 @@ from typing import ( # noqa: F401
     Callable,
     Dict,
     Final,
+    ForwardRef,
     Generic,
     List,
     Optional,
@@ -1199,7 +1200,7 @@ def _try_get_dispatched_fn(fn):
 def _get_named_tuple_properties(
-    obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None
+    obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None, rcb=None
 ):
     if loc is None:
         loc = fake_range()
@@ -1225,7 +1226,53 @@ def _get_named_tuple_properties(
     annotations = []
     for field in obj._fields:
         if field in obj_annotations:
-            the_type = torch.jit.annotations.ann_to_type(obj_annotations[field], loc)
+            field_type = obj_annotations[field]
+            # [Note: ForwardRef annotations in NamedTuple attributes]
+            # NamedTuple types are slightly different from normal types.
+            #
+            # Normally, annotations are evaluated like this (during jit.script):
+            # 1. Load strings of python code into c++ and parse.
+            # 2. Get annotations as strings
+            # 3. Use the PythonResolver's resolution callback (rcb) to convert
+            #    the string into a python object
+            # 4. We call into annotations.py:ann_to_type to convert python obj
+            #    from step 3 into a type that torchscript understands.
+            #
+            # NamedTuples are more complicated, because they have sub-types.
+            # Normally, once we have the NamedTuple type object from step 3,
+            # we can just look at the annotation literal values and use
+            # ann_to_type directly on them.
+            #
+            # But sometimes, users will annotate with string literals, e.g.
+            #    x: 'int'
+            # This also happens with PEP563 (from __future__ import annotations)
+            #
+            # These annotations appear in the annotation dict as ForwardRef('int').
+            #
+            # Then, we need to convert the string into a python object. This
+            # requires having local context for custom objects or imported types.
+            # rcb() is what gives us this. So, we plumb rcb through the stack so
+            # it can be used in this context for the if block below.
+            #
+            # FAQ:
+            # - Why do we need this special handling for NamedTuple but string
+            #   annotations work fine for normal types? Normally, we parse the
+            #   string directly and then call rcb() directly from C++.
+            # - Why not use ForwardRef._evaluate? For that, we need globals()
+            #   and locals() for the local context where the NamedTuple was
+            #   defined. rcb is what lets us look these up. So, basically rcb
+            #   does the hard work for us.
+            if isinstance(field_type, ForwardRef) and rcb is not None:
+                rcb_type = rcb(field_type.__forward_arg__)
+                # rcb returns None if it can't find anything.
+                if rcb_type is None:
+                    raise ValueError(
+                        f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
+                        f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
+                        f" Issue occurred at {loc.highlight()}"
+                    )
+                field_type = rcb_type
+            the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
             annotations.append(the_type)
         else:
             annotations.append(torch._C.TensorType.getInferred())


@@ -57,6 +57,8 @@
 namespace torch {
 namespace jit {

+using ResolutionCallback = std::function<py::object(std::string)>;
+
 void clear_registered_instances(void* ptr);

 TORCH_PYTHON_API IValue toIValue(


@@ -1006,13 +1006,19 @@ bool isNamedTupleClass(const py::object& obj) {
   return is_tuple_class == 1 && py::hasattr(obj, "_fields");
 }

-TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) {
+TypePtr registerNamedTuple(
+    const py::object& obj,
+    const SourceRange& loc,
+    const ResolutionCallback& rcb) {
   TORCH_INTERNAL_ASSERT(isNamedTupleClass(obj));
   auto qualifiedName = c10::QualifiedName(py::cast<std::string>(
       py::module::import("torch._jit_internal").attr("_qualified_name")(obj)));

-  py::object props = py::module::import("torch._jit_internal")
-                         .attr("_get_named_tuple_properties")(obj, loc);
+  // Note: we need to pass rcb to resolve ForwardRef annotations. See
+  // [Note: ForwardRef annotations in NamedTuple attributes]
+  py::object props =
+      py::module::import("torch._jit_internal")
+          .attr("_get_named_tuple_properties")(obj, loc, py::cpp_function(rcb));

   std::string unqualName;
   std::vector<std::string> field_names;
@@ -1290,7 +1296,14 @@ std::shared_ptr<SugaredValue> toSugaredValue(
   }

   if (isNamedTupleClass(obj)) {
-    auto tuple_type = registerNamedTuple(obj, loc)->expect<TupleType>();
+    // The use of fakeRcb here prevents us from correctly resolving ForwardRef
+    // annotations on NamedTuple attributes for instances whose types are
+    // inferred. See #95858 for more details, as well as
+    // [Note: ForwardRef annotations in NamedTuple attributes]
+    auto fakeRcb =
+        py::module::import("torch.jit.annotations").attr("_fake_rcb");
+    auto tuple_type =
+        registerNamedTuple(obj, loc, fakeRcb)->expect<TupleType>();
     return std::make_shared<NamedTupleConstructor>(tuple_type);
   }


@@ -242,7 +242,10 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
 };

 bool isNamedTupleClass(const py::object& obj);
-TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc);
+TypePtr registerNamedTuple(
+    const py::object& obj,
+    const SourceRange& loc,
+    const ResolutionCallback& rcb);

 void recurseThroughNestedModules(
     const SourceRange& loc,


@@ -76,7 +76,6 @@ namespace torch::jit {
 using ::c10::Argument;
 using ::c10::FunctionSchema;

-using ResolutionCallback = std::function<py::object(std::string)>;
 using FunctionDefaults = std::unordered_map<std::string, py::object>;
 using ClassMethodDefaults = std::unordered_map<std::string, FunctionDefaults>;
@@ -136,7 +135,7 @@ struct PythonResolver : public Resolver {
     }

     if (isNamedTupleClass(obj)) {
-      return registerNamedTuple(obj, loc);
+      return registerNamedTuple(obj, loc, rcb_);
     }

     auto qualifiedName = c10::QualifiedName(
@@ -157,8 +156,9 @@ struct PythonResolver : public Resolver {
       return nullptr;
     }

-    auto annotation_type = py::module::import("torch.jit.annotations")
-                               .attr("try_ann_to_type")(obj, loc);
+    auto annotation_type =
+        py::module::import("torch.jit.annotations")
+            .attr("try_ann_to_type")(obj, loc, py::cpp_function(rcb_));
     if (!annotation_type.is_none()) {
       return py::cast<TypePtr>(annotation_type);
     }


@@ -315,8 +315,11 @@ def is_tensor(ann):
     return False


+def _fake_rcb(inp):
+    return None
+
+
-def try_ann_to_type(ann, loc):
+def try_ann_to_type(ann, loc, rcb=None):
     if ann is inspect.Signature.empty:
         return TensorType.getInferred()

     if ann is None:
@@ -410,13 +413,13 @@ def try_ann_to_type(ann, loc):
         return torch.jit._script._recursive_compile_class(ann, loc)

     # Maybe resolve a NamedTuple to a Tuple Type
-    def fake_rcb(key):
-        return None
-
-    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
+    if rcb is None:
+        rcb = _fake_rcb
+    return torch._C._resolve_type_from_object(ann, loc, rcb)


-def ann_to_type(ann, loc):
-    the_type = try_ann_to_type(ann, loc)
+def ann_to_type(ann, loc, rcb=None):
+    the_type = try_ann_to_type(ann, loc, rcb)
     if the_type is not None:
         return the_type
     raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")