mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[JIT] Partially support ForwardRef type annotations for NamedTuple attributes (#96933)
**Summary** NamedTuple attributes can be annotated to declare their type: ```python class MyNamedTuple(NamedTuple): x: int y: torch.Tensor z: MyOtherType ``` Normally in python you can also declare your types as strings, `x: 'int'`. But NamedTuples previously didn't support this, because their annotation evaluation process was slightly different. This PR updates the NamedTuple attribute type annotation evaluation method to support ForwardRef declarations (i.e. declaring as strings). **Details** Below I repeat the comment I left in _jit_internal.py: NamedTuple types are slightly different from normal types. Normally, annotations are evaluted like this (during jit.script): 1. Load strings of python code into c++ and parse. 2. Get annotations as strings 3. Use the PythonResolver's resolution callback (rcb) to convert the string into a python object 4. We call into annotations.py:ann_to_type to convert python obj from step 3 into a type that torchscript understands. NamedTuples are more complicated, because they have sub-types. Normally, once we have the NamedTuple type object from #3, we can just look at the annotation literal values and use ann_to_type directly on them. But sometimes, users will annotate with string literals, e.g. ``` x: 'int' ``` This also happens with PEP563 (from __forward__ import annotations) These annotations appear in the annotation dict as ForwardRef('int'). Then, we need to convert the string into a python object. This requires having local context for custom objects or imported types. rcb() is what gives us this. So, we plumb rcb through the stack so it can be used in this context for the if block below. FAQ: - Why do we need this special handling for NamedTuple but string annotations work fine for normal types? Normally, we parse the string directly and then call rcb() directly from C++. - Why not use ForwardRef._evaluate? For that, we need globals() and locals() for the local context where the NamedTuple was defined. rcb is what lets us look up into these. So, basically rcb does the hard work for us. - What is rcb? rcb is a ResolutionCallback - python callable that takes a string and returns a type. It's generated by `createResolutionCallback.*` in _jit_internal.py. **Why is this only partial support**: This only plumbs the rcb through some paths. In particular, the `toSugaredValue` path uses a fake rcb. **Alternatives**: We could also treat this the way we treat non-nn.Module classes: we evaluate them separately, ahead of time. That solution is probably better, but probably requires a more risky refactor for the way NamedTuples are handled. Fixes #95858 Pull Request resolved: https://github.com/pytorch/pytorch/pull/96933 Approved by: https://github.com/qihqi
This commit is contained in:
committed by
PyTorch MergeBot
parent
d850c33bfe
commit
a133b5081c
@ -23,6 +23,7 @@ from typing import ( # noqa: F401
|
||||
Callable,
|
||||
Dict,
|
||||
Final,
|
||||
ForwardRef,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
@ -1199,7 +1200,7 @@ def _try_get_dispatched_fn(fn):
|
||||
|
||||
|
||||
def _get_named_tuple_properties(
|
||||
obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None
|
||||
obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None, rcb=None
|
||||
):
|
||||
if loc is None:
|
||||
loc = fake_range()
|
||||
@ -1225,7 +1226,53 @@ def _get_named_tuple_properties(
|
||||
annotations = []
|
||||
for field in obj._fields:
|
||||
if field in obj_annotations:
|
||||
the_type = torch.jit.annotations.ann_to_type(obj_annotations[field], loc)
|
||||
field_type = obj_annotations[field]
|
||||
# [Note: ForwardRef annotations in NamedTuple attributes]
|
||||
# NamedTuple types are slightly different from normal types.
|
||||
#
|
||||
# Normally, annotations are evaluted like this (during jit.script):
|
||||
# 1. Load strings of python code into c++ and parse.
|
||||
# 2. Get annotations as strings
|
||||
# 3. Use the PythonResolver's resolution callback (rcb) to convert
|
||||
# the string into a python object
|
||||
# 4. We call into annotations.py:ann_to_type to convert python obj
|
||||
# from step 3 into a type that torchscript understands.
|
||||
#
|
||||
# NamedTuples are more complicated, because it has sub-types.
|
||||
# Normally, once we have the NamedTuple type object from #3,
|
||||
# we can just look at the annotation literal values and use
|
||||
# ann_to_type directly on them.
|
||||
#
|
||||
# But sometimes, users will annotate with string literals, e.g.
|
||||
# x: 'int'
|
||||
# This also happens with PEP563 (from __forward__ import annotations)
|
||||
#
|
||||
# These annotations appear in the annotation dict as ForwardRef('int').
|
||||
#
|
||||
# Then, we need to convert the string into a python object. This
|
||||
# requires having local context for custom objects or imported types.
|
||||
# rcb() is what gives us this. So, we plumb rcb through the stack so
|
||||
# it can be used in this context for the if block below.
|
||||
#
|
||||
# FAQ:
|
||||
# - Why do we need this special handling for NamedTuple but string
|
||||
# annotations work fine for normal types? Normally, we parse the
|
||||
# string directly and then call rcb() directly from C++.
|
||||
# - Why not use ForwardRef._evaluate? For that, we need globals()
|
||||
# and locals() for the local context where the NamedTuple was defined.
|
||||
# rcb is what lets us look up into these. So, basically rcb does the
|
||||
# hard work for us.
|
||||
if isinstance(field_type, ForwardRef) and rcb is not None:
|
||||
rcb_type = rcb(field_type.__forward_arg__)
|
||||
# rcb returns None if it can't find anything.
|
||||
if rcb_type is None:
|
||||
raise ValueError(
|
||||
f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
|
||||
f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
|
||||
f" Issue occurred at {loc.highlight()}"
|
||||
)
|
||||
field_type = rcb_type
|
||||
the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
|
||||
annotations.append(the_type)
|
||||
else:
|
||||
annotations.append(torch._C.TensorType.getInferred())
|
||||
|
Reference in New Issue
Block a user