diff --git a/test/jit/test_dataclasses.py b/test/jit/test_dataclasses.py
deleted file mode 100644
index 0aebe8aba0ec..000000000000
--- a/test/jit/test_dataclasses.py
+++ /dev/null
@@ -1,165 +0,0 @@
-# Owner(s): ["oncall: jit"]
-# flake8: noqa
-
-from dataclasses import dataclass, field, InitVar
-from hypothesis import given, settings, strategies as st
-from torch.testing._internal.jit_utils import JitTestCase
-from typing import List, Optional
-import sys
-import torch
-import unittest
-from enum import Enum
-
-# Example jittable dataclass
-@torch.jit.script
-@dataclass(order=True)
-class Point:
-    x: float
-    y: float
-    norm: Optional[torch.Tensor] = None
-
-    def __post_init__(self):
-        self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5
-
-class MixupScheme(Enum):
-
-    INPUT = ["input"]
-
-    MANIFOLD = [
-        "input",
-        "before_fusion_projection",
-        "after_fusion_projection",
-        "after_classifier_projection",
-    ]
-
-
-@dataclass
-class MixupParams:
-    def __init__(self, alpha: float = 0.125, scheme: MixupScheme = MixupScheme.INPUT):
-        self.alpha = alpha
-        self.scheme = scheme
-
-class MixupScheme2(Enum):
-    A = 1
-    B = 2
-
-
-@dataclass
-class MixupParams2:
-    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
-        self.alpha = alpha
-        self.scheme = scheme
-
-@dataclass
-class MixupParams3:
-    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
-        self.alpha = alpha
-        self.scheme = scheme
-
-
-# Make sure the Meta internal tooling doesn't raise an overflow error
-NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)
-
-class TestDataclasses(JitTestCase):
-
-    @classmethod
-    def tearDownClass(cls):
-        torch._C._jit_clear_class_registry()
-
-    # We only support InitVar in JIT dataclasses for Python 3.8+ because it would be very hard
-    # to support without the `type` attribute on InitVar (see comment in _dataclass_impls.py).
-    @unittest.skipIf(sys.version_info < (3, 8), "InitVar not supported in Python < 3.8")
-    def test_init_vars(self):
-        @torch.jit.script
-        @dataclass(order=True)
-        class Point2:
-            x: float
-            y: float
-            norm_p: InitVar[int] = 2
-            norm: Optional[torch.Tensor] = None
-
-            def __post_init__(self, norm_p: int):
-                self.norm = (torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p) ** (1 / norm_p)
-
-        def fn(x: float, y: float, p: int):
-            pt = Point2(x, y, p)
-            return pt.norm
-
-        self.checkScript(fn, (1.0, 2.0, 3))
-
-    # Sort of tests both __post_init__ and optional fields
-    @settings(deadline=None)
-    @given(NonHugeFloats, NonHugeFloats)
-    def test__post_init__(self, x, y):
-        def fn(x: float, y: float):
-            pt = Point(x, y)
-            return pt.norm
-
-        self.checkScript(fn, [x, y])
-
-    @settings(deadline=None)
-    @given(st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats))
-    def test_comparators(self, pt1, pt2):
-        x1, y1 = pt1
-        x2, y2 = pt2
-
-        def compare(x1: float, y1: float, x2: float, y2: float):
-            pt1 = Point(x1, y1)
-            pt2 = Point(x2, y2)
-            return (
-                pt1 == pt2,
-                # pt1 != pt2,  # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
-                pt1 < pt2,
-                pt1 <= pt2,
-                pt1 > pt2,
-                pt1 >= pt2,
-            )
-
-        self.checkScript(compare, [x1, y1, x2, y2])
-
-    def test_default_factories(self):
-        @dataclass
-        class Foo(object):
-            x: List[int] = field(default_factory=list)
-
-        with self.assertRaises(NotImplementedError):
-            torch.jit.script(Foo)
-            def fn():
-                foo = Foo()
-                return foo.x
-
-            torch.jit.script(fn)()
-
-    # The user should be able to write their own __eq__ implementation
-    # without us overriding it.
-    def test_custom__eq__(self):
-        @torch.jit.script
-        @dataclass
-        class CustomEq:
-            a: int
-            b: int
-
-            def __eq__(self, other: 'CustomEq') -> bool:
-                return self.a == other.a  # ignore the b field
-
-        def fn(a: int, b1: int, b2: int):
-            pt1 = CustomEq(a, b1)
-            pt2 = CustomEq(a, b2)
-            return pt1 == pt2
-
-        self.checkScript(fn, [1, 2, 3])
-
-    def test_no_source(self):
-        with self.assertRaises(RuntimeError):
-            # uses list in Enum is not supported
-            torch.jit.script(MixupParams)
-
-        torch.jit.script(MixupParams2)  # don't throw
-
-
-    def test_use_unregistered_dataclass_raises(self):
-
-        def f(a: MixupParams3):
-            return 0
-
-        with self.assertRaises(OSError):
-            torch.jit.script(f)
diff --git a/test/jit/test_misc.py b/test/jit/test_misc.py
index f720154bd7ce..bf3c3c3e71c1 100644
--- a/test/jit/test_misc.py
+++ b/test/jit/test_misc.py
@@ -6,6 +6,7 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
 from torch.testing import FileCheck
 from torch import jit
 from jit.test_module_interface import TestModuleInterface  # noqa: F401
+import unittest
 import os
 import sys
 import torch
@@ -46,6 +47,24 @@ class TestMisc(JitTestCase):
         self.assertEqual(out, out_script)
         self.assertEqual(captured, captured_script)

+    @unittest.skipIf(sys.version_info[:2] < (3, 7), "`dataclasses` module not present on < 3.7")
+    def test_dataclass_error(self):
+        from dataclasses import dataclass
+
+        @dataclass
+        class NormalizationInfo(object):
+            mean: float = 0.0
+
+            def compute(self, total_rows):
+                return self.mean
+
+        def fn():
+            return NormalizationInfo(1, 2, 3, 4, 5)
+
+        with self.assertRaisesRegex(OSError, "could not get source code"):
+            torch.jit.script(fn)
+
+
     def test_kwarg_support(self):
         with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
             class M(torch.nn.Module):
diff --git a/test/test_jit.py b/test/test_jit.py
index adbda1a575e9..c23cb17df62e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -78,7 +78,6 @@ from jit.test_device_analysis import TestDeviceAnalysis  # noqa: F401
 from jit.test_dce import TestDCE  # noqa: F401
 from jit.test_sparse import TestSparse  # noqa: F401
 from jit.test_tensor_methods import TestTensorMethods  # noqa: F401
-from jit.test_dataclasses import TestDataclasses  # noqa: F401

 # Torch
 from torch import Tensor
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index b3c61edaf83a..3c067d5c1c53 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -47,24 +47,6 @@ except ImportError:

 boolean_dispatched: 'weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]' = weakref.WeakKeyDictionary()  # noqa: T484

-FAKE_FILENAME_PREFIX = '__torch_jit_dataclass'
-
-
-class SourceLoader:
-
-    def __init__(self):
-        self.content = {}
-
-    def cache(self, fn, source):
-        self.content[fn] = source
-
-    def get_source(self, fn):
-        return self.content.get(fn)
-
-
-loader = SourceLoader()
-
-
 def createResolutionCallbackFromEnv(lookup_base):
     """
     Creates a resolution callback that will look up qualified names in an
@@ -341,14 +323,6 @@ def get_type_hint_captures(fn):
         A Dict[str, Any] containing a mapping from the literal annotations used
         on fn to the Python objects they refer to.
     """
-    # First, try to get the source of the function. We'll need to parse it to find the actual string names
-    # that were used to annotate the types, since inspect.signature() will only return the class object that
-    # the annotation refers to, not the string name. If we can't get the source, simply return an empty dict.
-    # This may happen in cases where the function is synthesized dynamically at runtime.
-    src = loader.get_source(fn)
-    if src is None:
-        src = inspect.getsource(fn)
-
     # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
     # types are strings. These are only understood by TorchScript in the context of a type annotation
     # that refers to a class in its own definition, but trying to include a mapping for this in the result
@@ -364,6 +338,8 @@ def get_type_hint_captures(fn):
     # Then, get the literal type annotations from the function declaration
     # by source inspection. This accounts for the case in which aliases are used
     # to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
+    src = inspect.getsource(fn)
+
     # frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
     a = ast.parse(dedent(src))
     if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
@@ -929,7 +905,7 @@ def is_optional(ann):

     def is_union_as_optional(ann):
         ann_args = ann.__args__
-        return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args)
+        return len(ann_args) == 2 and None in ann_args

     return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))

diff --git a/torch/_sources.py b/torch/_sources.py
index a7a87481d3e5..e3d9064b38bc 100644
--- a/torch/_sources.py
+++ b/torch/_sources.py
@@ -48,20 +48,13 @@ def normalize_source_lines(sourcelines: List[str]) -> List[str]:
         return text[text.startswith(prefix) and len(prefix):]

     # Find the line and line number containing the function definition
-    idx = None
     for i, l in enumerate(sourcelines):
         if l.lstrip().startswith("def"):
             idx = i
             break
-
-    # This will happen when the function is a lambda- we won't find "def" anywhere in the source
-    # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
-    # `parse_def()`, but we might want to handle this case in the future.
-    if idx is None:
-        return sourcelines
+    fn_def = sourcelines[idx]

     # Get a string representing the amount of leading whitespace
-    fn_def = sourcelines[idx]
     whitespace = fn_def.split("def")[0]

     # Add this leading whitespace to all lines before and after the `def`
diff --git a/torch/csrc/jit/api/compilation_unit.h b/torch/csrc/jit/api/compilation_unit.h
index 6313923e2deb..658307bf4506 100644
--- a/torch/csrc/jit/api/compilation_unit.h
+++ b/torch/csrc/jit/api/compilation_unit.h
@@ -227,11 +227,10 @@ struct TORCH_API CompilationUnit {
     // Tombstone the method in the compilation unit.
     // Don't erase because the dict_
     auto it = dict_.find(method->qualname());
-    if (it != dict_.end()) {
-      functions_[it->second] = nullptr;
-      // Erase in our big lookup table
-      dict_.erase(it);
-    }
+    TORCH_INTERNAL_ASSERT(it != dict_.end());
+    functions_[it->second] = nullptr;
+    // Erase in our big lookup table
+    dict_.erase(it);
   }
   // Classes can have multiple pointers to the same hook,
   // need to make sure to not delete it twice
diff --git a/torch/jit/_dataclass_impls.py b/torch/jit/_dataclass_impls.py
deleted file mode 100644
index c58198d240a9..000000000000
--- a/torch/jit/_dataclass_impls.py
+++ /dev/null
@@ -1,152 +0,0 @@
-# Functions for synthesizing magic methods for JIT-compiled dataclasses
-import os
-from functools import partial
-from torch._jit_internal import is_optional, FAKE_FILENAME_PREFIX
-from torch._sources import ParsedDef, SourceContext
-from typing import Callable, Dict, List
-import ast
-import dataclasses
-import inspect
-import sys
-
-def _get_fake_filename(cls, method_name):
-    return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)
-
-
-def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
-    body = '\n'.join(f' {b}' for b in body_lines)
-    decl = f'def {name}{signature}:\n{body}'
-
-    # Parse the function declaration
-    try:
-        py_ast = ast.parse(decl)
-    except SyntaxError:
-        # This should only happen if there's some unforeseeable change
-        # in the dataclasses module that makes our synthesized code fail
-        raise RuntimeError(
-            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
-            "Please file a bug report at "
-        )
-    fake_filename = _get_fake_filename(cls, name)
-    # Parse the function
-    return ParsedDef(
-        py_ast,
-        ctx=SourceContext(
-            source=decl,
-            filename=fake_filename,
-            file_lineno=0,
-            leading_whitespace_len=0
-        ),
-        source=decl,
-        filename=fake_filename,
-        file_lineno=0
-    )
-
-
-def synthesize__init__(cls) -> ParsedDef:
-    # Supporting default factories in the way that people expect would sort of require us to
-    # allow compiling lambda functions, which is not currently supported.
-    if any(field.default_factory is not dataclasses.MISSING for field in dataclasses.fields(cls)):
-        raise NotImplementedError("Default factory initializers are not supported in TorchScript dataclasses")
-
-    # Simply read off the generated __init__ signature from CPython's implementation. It'll be
-    # almost correct except for InitVar annotations, which we need to handle specially.
-    signature = inspect.signature(cls.__init__)
-
-    # Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
-    # see CPython commit here https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
-    init_vars: List[str] = []
-    if sys.version_info >= (3, 8):
-        params = []
-        for name, param in signature.parameters.items():
-            ann = param.annotation
-
-            if isinstance(ann, dataclasses.InitVar):
-                # The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here
-                init_vars.append(name)
-                params.append(param.replace(annotation=ann.type))  # type: ignore[attr-defined]
-            else:
-                params.append(param)
-
-        signature = signature.replace(parameters=params)
-
-    body = [
-        # Assign all attributes to self
-        f'self.{field.name} = {field.name}'
-        for field in dataclasses.fields(cls)
-        if field.init and field.name not in init_vars
-    ]
-    # Call user's impl of __post_init__ if it exists
-    if hasattr(cls, '__post_init__'):
-        body.append('self.__post_init__(' + ', '.join(init_vars) + ')')
-
-    return compose_fn(cls, '__init__', body or ['pass'], signature=str(signature))
-
-# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
-def synthesize__repr__(cls) -> ParsedDef:
-    return compose_fn(
-        cls, '__repr__',
-        [f"return '{cls.__name__}(" + ", ".join([
-            f"{field.name}=self.{field.name}"
-            for field in dataclasses.fields(cls) if field.repr
-        ]) + ")'"],
-        signature='(self) -> str'
-    )
-
-def synthesize__hash__(cls) -> ParsedDef:
-    return compose_fn(
-        cls, '__hash__',
-        [
-            # This is just a placeholder to prevent compilation from failing; this won't even get called at
-            # all right now because the TorchScript interpreter doesn't call custom __hash__ implementations
-            "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
-        ],
-        signature='(self) -> int'
-    )
-
-# Implementation for __eq__ and __ne__
-def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
-    return synthesize_comparison(cls, name, allow_eq=True, raise_on_none=False, inner=[
-        f"if val1 {converse} val2: return False"
-    ])
-
-def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
-    return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=[
-        f"if val1 {op} val2: return True",
-        f"elif val2 {op} val1: return False",
-    ])
-
-def synthesize_comparison(cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]) -> ParsedDef:
-    body = []
-    for field in dataclasses.fields(cls):
-        if not field.compare:
-            continue
-
-        body.extend([
-            f"val1 = self.{field.name}",
-            f"val2 = other.{field.name}",
-        ])
-        body.extend(
-            inner if not is_optional(field.type) else [
-                # Type refinement for optional fields; we need this to avoid type errors from the interpreter
-                "if val1 is not None and val2 is not None:",
-                *[' ' + line for line in inner],
-                "elif (val1 is None) != (val2 is None):",
-                f" raise TypeError('Cannot compare {cls.__name__} with None')" if raise_on_none else " return False"
-            ]
-        )
-
-    body.append(f"return {allow_eq}")
-    return compose_fn(cls, name, body, signature=f'(self, other: {cls.__name__}) -> bool')
-
-DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
-    "__init__": synthesize__init__,
-    "__repr__": synthesize__repr__,
-    "__hash__": synthesize__hash__,
-    "__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
-    "__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
-    "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
-    "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
-    "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
-    "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
-}
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index a4a36ce36a5e..45d708c6e68f 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -284,10 +284,7 @@ def get_enum_value_type(e: Type[enum.Enum], loc):
     # Even though Python supports this case, we chose to not implement it to
     # avoid overcomplicate logic here for a rare use case. Please report a
     # feature request if you find it necessary.
-    res = torch._C.unify_type_list(ir_types)
-    if not res:
-        return AnyType.get()
-    return res
+    return torch._C.unify_type_list(ir_types)

 def is_tensor(ann):
     if issubclass(ann, torch.Tensor):
diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py
index 0b21844f43c3..fbbe962d40b7 100644
--- a/torch/jit/frontend.py
+++ b/torch/jit/frontend.py
@@ -1,7 +1,6 @@
 import torch
 import sys
 import ast
-import dataclasses
 import inspect
 import string
 from collections import namedtuple
@@ -18,11 +17,9 @@ from torch._C._jit_tree_views import (
     SliceExpr, Subscript, TernaryIf, With, WithItem, Property, DictComp,
 )
-from torch._sources import get_source_lines_and_file, ParsedDef, parse_def, make_source_context
-from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
+from torch._sources import get_source_lines_and_file, parse_def, make_source_context
 from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
 from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers  # noqa: F401
-from torch import _jit_internal
 import torch.jit.annotations

 _IS_ASTUNPARSE_INSTALLED = False
@@ -198,48 +195,24 @@ def get_jit_class_def(cls, self_name):
     def is_classmethod(fn):
         return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls

-    # Get and parse the source code for this class
-    sourcelines, file_lineno, filename = get_source_lines_and_file(cls, torch._C.ErrorReport.call_stack())
-    source = ''.join(sourcelines)
+    methods = [get_jit_def(obj,
+                           name,
+                           self_name=self_name,
+                           is_classmethod=is_classmethod(obj)) for (name, obj) in methods]

-    dedent_src = dedent(source)
-    py_ast = ast.parse(dedent_src)
-
-    class_ast = py_ast.body[0]
-    assert isinstance(class_ast, ast.ClassDef)
-
-    # Special case for dataclasses. In general we need access to the source code for
-    # an object in order to JIT compile it. But the dataclasses module dynamically synthesizes
-    # magic methods for classes, and we can't get the source code for these methods. As a
-    # workaround, we synthesize TorchScript-friendly implementations ourselves.
-    if dataclasses.is_dataclass(cls):
-        # Detect whether the user manually implemented any of the magic methods. If they did,
-        # we don't want to synthesize/override them.
-        overrides = {
-            method.name
-            for method in class_ast.body
-            if isinstance(method, ast.FunctionDef) and method.name in DATACLASS_MAGIC_METHODS
-        }
-        for i, (name, _) in enumerate(methods):
-            # Is this a magic method we can synthesize?
-            synthesizer_fn = DATACLASS_MAGIC_METHODS.get(name)
-            if synthesizer_fn and name not in overrides:
-                parsed_def = synthesizer_fn(cls)
-                methods[i] = name, parsed_def
-                func = getattr(cls, name)
-                _jit_internal.loader.cache(func, parsed_def.source)
-
-    method_defs = [
-        get_jit_def(obj, name, self_name=self_name, is_classmethod=is_classmethod(obj))
-        for (name, obj) in methods
-    ]
     properties = get_class_properties(cls, self_name)

+    sourcelines, file_lineno, filename = get_source_lines_and_file(cls, torch._C.ErrorReport.call_stack())
+    source = ''.join(sourcelines)
+    dedent_src = dedent(source)
+    py_ast = ast.parse(dedent_src)
     leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
     ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, False)
+    class_ast = py_ast.body[0]
+    assert isinstance(class_ast, ast.ClassDef)
     assigns = get_class_assigns(ctx, class_ast)

-    return build_class_def(ctx, class_ast, method_defs, properties, self_name, assigns)
+    return build_class_def(ctx, class_ast, methods, properties, self_name, assigns)


 def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
@@ -247,7 +220,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
     Build a JIT AST (TreeView) from the given function.

     Args:
-        fn: A function object to compile or a pre-parsed ParsedDef object
+        fn: A function object to compile
         def_name: The name to give to the resulting AST object. This is not
             always the same as `fn.__name__`, for example:
                 def _forward(self):
@@ -257,7 +230,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
             but we want the result AST to have the name "forward".
         self_name: If this function is a method, what the type name of `self` is.
     """
-    parsed_def = parse_def(fn) if not isinstance(fn, ParsedDef) else fn
+    parsed_def = parse_def(fn)
     type_line = torch.jit.annotations.get_type_line(parsed_def.source)
     fn_def = parsed_def.ast.body[0]
@@ -284,7 +257,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
     # for the arguments from type_trace_db
     type_trace_db = torch.jit._script._get_type_trace_db()
     pdt_arg_types = None
-    if monkeytype_trace and not isinstance(fn, ParsedDef):
+    if monkeytype_trace:
         qualname = get_qualified_name(fn)
         pdt_arg_types = type_trace_db.get_args_types(qualname)