[JIT] Fix annotation extraction for named tuple (#81506)

Followup similar to https://github.com/pytorch/pytorch/pull/81334

In Python-3.10 type annotations from base class are not inherited to the
child class which leads to the following error:
```
% python -c "import torch;torch.jit.script(torch.nn.utils.rnn.pack_padded_sequence)"
...
PackedSequence(Tensor data, Tensor batch_sizes, Tensor sorted_indices, Tensor unsorted_indices) -> ():
Expected a value of type 'Tensor (inferred)' for argument 'sorted_indices' but instead found type 'Optional[Tensor]'.
Inferred 'sorted_indices' to be of type 'Tensor' because it was not annotated with an explicit type.
:
  File "/Users/nshulga/git/pytorch-worktree/torch/nn/utils/rnn.py", line 197
    data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args(
        data, batch_sizes, sorted_indices, unsorted_indices)
    return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)
           ~~~~~~~~~~~~~~ <--- HERE
```

Which stems from the fact that  `torch.nn.utils.rnn.PackedSequence.__annotations__` returns empty list for python-3.10 as seen below:
```
% conda run -n py_39 python3 -c "import sys;import torch;print(torch.nn.utils.rnn.PackedSequence.__annotations__, sys.version_info[:2])"
{'data': <class 'torch.Tensor'>, 'batch_sizes': <class 'torch.Tensor'>, 'sorted_indices': typing.Optional[torch.Tensor], 'unsorted_indices': typing.Optional[torch.Tensor]} (3, 9)

 % conda run -n py_310 python3 -c "import sys;import torch;print(torch.nn.utils.rnn.PackedSequence.__annotations__, sys.version_info[:2])"
{} (3, 10)
```

Fix by checking annotations of parent class if base one does not have any.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81506
Approved by: https://github.com/suo
This commit is contained in:
Nikita Shulga
2022-07-15 02:47:15 +00:00
committed by PyTorch MergeBot
parent 938643b8bc
commit 75fdebde62

View File

@ -1083,11 +1083,19 @@ def _get_named_tuple_properties(obj):
if field in obj._field_defaults]
else:
defaults = []
# In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
# Also, annotations from base class are not inherited so they need to be queried explicitly
if sys.version_info[:2] < (3, 10):
obj_annotations = getattr(obj, '__annotations__', {})
else:
obj_annotations = inspect.get_annotations(obj)
if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
obj_annotations = inspect.get_annotations(obj.__base__)
annotations = []
has_annotations = hasattr(obj, '__annotations__')
for field in obj._fields:
if has_annotations and field in obj.__annotations__:
the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
if field in obj_annotations:
the_type = torch.jit.annotations.ann_to_type(obj_annotations[field], fake_range())
annotations.append(the_type)
else:
annotations.append(torch._C.TensorType.getInferred())