Support Union in TorchScript (#64234)

Summary:
This PR is created to replace https://github.com/pytorch/pytorch/pull/53180 PR stack, which has all the review discussions. Reason for needing a replacement is due to a messy Sandcastle issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64234

Reviewed By: gmagogsfm

Differential Revision: D30656444

Pulled By: ansley

fbshipit-source-id: 77536c8bcc88162e2c72636026ca3c16891d669a
This commit is contained in:
Ansley Ussery
2021-09-03 06:10:37 -07:00
committed by Facebook GitHub Bot
parent 91b926fab3
commit 6831d8e379
50 changed files with 2137 additions and 467 deletions

View File

@ -885,33 +885,28 @@ def is_dict(ann) -> bool:
(getattr(ann, '__origin__', None) is Dict or
getattr(ann, '__origin__', None) is dict)
def is_optional(ann) -> bool:
def is_union(ann):
if ann is Union:
raise_error_container_parameter_missing("Union")
return (hasattr(ann, '__module__') and
ann.__module__ == 'typing' and
(getattr(ann, '__origin__', None) is Union))
def is_optional(ann):
if ann is Optional:
raise_error_container_parameter_missing("Optional")
# Optional[T] is just shorthand for Union[T, None], so check for both
def safe_is_subclass(the_type, super_type):
# Don't throw if `the_type` isn't a class type (e.g. if it is
# another type annotation instance)
if not inspect.isclass(the_type):
return False
return issubclass(the_type, super_type)
def is_optional_as_optional(ann):
return (hasattr(ann, '__module__') and
ann.__module__ == 'typing' and
(getattr(ann, '__origin__', None) is Optional))
if not hasattr(ann, '__module__'):
return False
def is_union_as_optional(ann):
ann_args = ann.__args__
return len(ann_args) == 2 and None in ann_args
union_optional = False
if ann.__module__ == 'typing' and \
(getattr(ann, '__origin__', None) is Union):
args = getattr(ann, '__args__', ())
if len(args) == 2:
union_optional = (safe_is_subclass(args[1], type(None)) and not safe_is_subclass(args[0], type(None))) \
or (safe_is_subclass(args[0], type(None)) and not safe_is_subclass(args[1], type(None)))
optional = ann.__module__ == 'typing' and \
(getattr(ann, '__origin__', None) is Optional)
return optional or union_optional
return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))
def is_future(ann) -> bool:
if ann is Future:
@ -1192,15 +1187,16 @@ def container_checker(obj, target_type) -> bool:
elif not isinstance(el, el_type):
return False
return True
elif origin_type is Union: # actually handles Optional Case
elif origin_type is Union: # also handles Optional
if obj is None: # check before recursion because None is always fine
return True
optional_type = get_args(target_type)[0]
optional_origin = get_origin(optional_type)
if optional_origin:
return container_checker(obj, optional_type)
elif isinstance(obj, optional_type):
return True
inner_types = get_args(target_type)
for t in inner_types:
t_origin = get_origin(t)
if (t_origin):
return container_checker(obj, t)
elif isinstance(obj, t):
return True
return False