mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support Union in TorchScript (#64234)
Summary: This PR is created to replace https://github.com/pytorch/pytorch/pull/53180 PR stack, which has all the review discussions. Reason for needing a replacement is due to a messy Sandcastle issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/64234 Reviewed By: gmagogsfm Differential Revision: D30656444 Pulled By: ansley fbshipit-source-id: 77536c8bcc88162e2c72636026ca3c16891d669a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
91b926fab3
commit
6831d8e379
@ -885,33 +885,28 @@ def is_dict(ann) -> bool:
|
||||
(getattr(ann, '__origin__', None) is Dict or
|
||||
getattr(ann, '__origin__', None) is dict)
|
||||
|
||||
def is_optional(ann) -> bool:
|
||||
def is_union(ann):
|
||||
if ann is Union:
|
||||
raise_error_container_parameter_missing("Union")
|
||||
|
||||
return (hasattr(ann, '__module__') and
|
||||
ann.__module__ == 'typing' and
|
||||
(getattr(ann, '__origin__', None) is Union))
|
||||
|
||||
def is_optional(ann):
|
||||
if ann is Optional:
|
||||
raise_error_container_parameter_missing("Optional")
|
||||
|
||||
# Optional[T] is just shorthand for Union[T, None], so check for both
|
||||
def safe_is_subclass(the_type, super_type):
|
||||
# Don't throw if `the_type` isn't a class type (e.g. if it is
|
||||
# another type annotation instance)
|
||||
if not inspect.isclass(the_type):
|
||||
return False
|
||||
return issubclass(the_type, super_type)
|
||||
def is_optional_as_optional(ann):
|
||||
return (hasattr(ann, '__module__') and
|
||||
ann.__module__ == 'typing' and
|
||||
(getattr(ann, '__origin__', None) is Optional))
|
||||
|
||||
if not hasattr(ann, '__module__'):
|
||||
return False
|
||||
def is_union_as_optional(ann):
|
||||
ann_args = ann.__args__
|
||||
return len(ann_args) == 2 and None in ann_args
|
||||
|
||||
union_optional = False
|
||||
if ann.__module__ == 'typing' and \
|
||||
(getattr(ann, '__origin__', None) is Union):
|
||||
args = getattr(ann, '__args__', ())
|
||||
if len(args) == 2:
|
||||
union_optional = (safe_is_subclass(args[1], type(None)) and not safe_is_subclass(args[0], type(None))) \
|
||||
or (safe_is_subclass(args[0], type(None)) and not safe_is_subclass(args[1], type(None)))
|
||||
|
||||
optional = ann.__module__ == 'typing' and \
|
||||
(getattr(ann, '__origin__', None) is Optional)
|
||||
|
||||
return optional or union_optional
|
||||
return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))
|
||||
|
||||
def is_future(ann) -> bool:
|
||||
if ann is Future:
|
||||
@ -1192,15 +1187,16 @@ def container_checker(obj, target_type) -> bool:
|
||||
elif not isinstance(el, el_type):
|
||||
return False
|
||||
return True
|
||||
elif origin_type is Union: # actually handles Optional Case
|
||||
elif origin_type is Union: # also handles Optional
|
||||
if obj is None: # check before recursion because None is always fine
|
||||
return True
|
||||
optional_type = get_args(target_type)[0]
|
||||
optional_origin = get_origin(optional_type)
|
||||
if optional_origin:
|
||||
return container_checker(obj, optional_type)
|
||||
elif isinstance(obj, optional_type):
|
||||
return True
|
||||
inner_types = get_args(target_type)
|
||||
for t in inner_types:
|
||||
t_origin = get_origin(t)
|
||||
if (t_origin):
|
||||
return container_checker(obj, t)
|
||||
elif isinstance(obj, t):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user