mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[JIT] Fix annotation extraction for named tuple (#81506)
Followup similar to https://github.com/pytorch/pytorch/pull/81334 In Python-3.10 type annotations from base class are not inherited to the child class which leads to the following error: ``` % python -c "import torch;torch.jit.script(torch.nn.utils.rnn.pack_padded_sequence)" ... PackedSequence(Tensor data, Tensor batch_sizes, Tensor sorted_indices, Tensor unsorted_indices) -> (): Expected a value of type 'Tensor (inferred)' for argument 'sorted_indices' but instead found type 'Optional[Tensor]'. Inferred 'sorted_indices' to be of type 'Tensor' because it was not annotated with an explicit type. : File "/Users/nshulga/git/pytorch-worktree/torch/nn/utils/rnn.py", line 197 data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args( data, batch_sizes, sorted_indices, unsorted_indices) return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices) ~~~~~~~~~~~~~~ <--- HERE ``` Which stems from the fact that `torch.nn.utils.rnn.PackedSequence.__annotations__` returns empty list for python-3.10 as seen below: ``` % conda run -n py_39 python3 -c "import sys;import torch;print(torch.nn.utils.rnn.PackedSequence.__annotations__, sys.version_info[:2])" {'data': <class 'torch.Tensor'>, 'batch_sizes': <class 'torch.Tensor'>, 'sorted_indices': typing.Optional[torch.Tensor], 'unsorted_indices': typing.Optional[torch.Tensor]} (3, 9) % conda run -n py_310 python3 -c "import sys;import torch;print(torch.nn.utils.rnn.PackedSequence.__annotations__, sys.version_info[:2])" {} (3, 10) ``` Fix by checking annotations of parent class if base one does not have any. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81506 Approved by: https://github.com/suo
This commit is contained in:
committed by
PyTorch MergeBot
parent
938643b8bc
commit
75fdebde62
@ -1083,11 +1083,19 @@ def _get_named_tuple_properties(obj):
|
||||
if field in obj._field_defaults]
|
||||
else:
|
||||
defaults = []
|
||||
# In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
|
||||
# Also, annotations from base class are not inherited so they need to be queried explicitly
|
||||
if sys.version_info[:2] < (3, 10):
|
||||
obj_annotations = getattr(obj, '__annotations__', {})
|
||||
else:
|
||||
obj_annotations = inspect.get_annotations(obj)
|
||||
if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
|
||||
obj_annotations = inspect.get_annotations(obj.__base__)
|
||||
|
||||
annotations = []
|
||||
has_annotations = hasattr(obj, '__annotations__')
|
||||
for field in obj._fields:
|
||||
if has_annotations and field in obj.__annotations__:
|
||||
the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
|
||||
if field in obj_annotations:
|
||||
the_type = torch.jit.annotations.ann_to_type(obj_annotations[field], fake_range())
|
||||
annotations.append(the_type)
|
||||
else:
|
||||
annotations.append(torch._C.TensorType.getInferred())
|
||||
|
Reference in New Issue
Block a user