Revert "Reland "[pytorch][PR] Support dataclasses in TorchScript" take 2 (#74353)"

This reverts commit 5547741960a01fbd3a97d1ddd5ae9b43d8f1169c.

Reverted https://github.com/pytorch/pytorch/pull/74889 on behalf of https://github.com/malfet
Author: Nikita Shulga
Date: 2022-03-31 04:16:45 -07:00
parent 0b845bb645
commit fa1a41ca71
9 changed files with 43 additions and 404 deletions
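For context, a minimal sketch of the user-visible effect of this revert (illustrative names; it mirrors the test_dataclass_error test restored below):

# After this revert, TorchScript can no longer compile code that constructs
# a dataclass, because the dataclasses module synthesizes __init__ at runtime
# and there is no source file for the JIT frontend to read.
from dataclasses import dataclass
import torch

@dataclass
class NormalizationInfo:
    mean: float = 0.0

def fn():
    return NormalizationInfo(mean=1.0)

# torch.jit.script(fn)  # raises OSError: could not get source code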

View File

@@ -1,165 +0,0 @@
# Owner(s): ["oncall: jit"]
# flake8: noqa
from dataclasses import dataclass, field, InitVar
from hypothesis import given, settings, strategies as st
from torch.testing._internal.jit_utils import JitTestCase
from typing import List, Optional
import sys
import torch
import unittest
from enum import Enum

# Example jittable dataclass
@torch.jit.script
@dataclass(order=True)
class Point:
    x: float
    y: float
    norm: Optional[torch.Tensor] = None

    def __post_init__(self):
        self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5

class MixupScheme(Enum):
    INPUT = ["input"]
    MANIFOLD = [
        "input",
        "before_fusion_projection",
        "after_fusion_projection",
        "after_classifier_projection",
    ]

@dataclass
class MixupParams:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme = MixupScheme.INPUT):
        self.alpha = alpha
        self.scheme = scheme

class MixupScheme2(Enum):
    A = 1
    B = 2

@dataclass
class MixupParams2:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme

@dataclass
class MixupParams3:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme

# Make sure the Meta internal tooling doesn't raise an overflow error
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)

class TestDataclasses(JitTestCase):
    @classmethod
    def tearDownClass(cls):
        torch._C._jit_clear_class_registry()

    # We only support InitVar in JIT dataclasses for Python 3.8+ because it would be very hard
    # to support without the `type` attribute on InitVar (see comment in _dataclass_impls.py).
    @unittest.skipIf(sys.version_info < (3, 8), "InitVar not supported in Python < 3.8")
    def test_init_vars(self):
        @torch.jit.script
        @dataclass(order=True)
        class Point2:
            x: float
            y: float
            norm_p: InitVar[int] = 2
            norm: Optional[torch.Tensor] = None

            def __post_init__(self, norm_p: int):
                self.norm = (torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p) ** (1 / norm_p)

        def fn(x: float, y: float, p: int):
            pt = Point2(x, y, p)
            return pt.norm

        self.checkScript(fn, (1.0, 2.0, 3))

    # Sort of tests both __post_init__ and optional fields
    @settings(deadline=None)
    @given(NonHugeFloats, NonHugeFloats)
    def test__post_init__(self, x, y):
        def fn(x: float, y: float):
            pt = Point(x, y)
            return pt.norm

        self.checkScript(fn, [x, y])

    @settings(deadline=None)
    @given(st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats))
    def test_comparators(self, pt1, pt2):
        x1, y1 = pt1
        x2, y2 = pt2

        def compare(x1: float, y1: float, x2: float, y2: float):
            pt1 = Point(x1, y1)
            pt2 = Point(x2, y2)
            return (
                pt1 == pt2,
                # pt1 != pt2, # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
                pt1 < pt2,
                pt1 <= pt2,
                pt1 > pt2,
                pt1 >= pt2,
            )

        self.checkScript(compare, [x1, y1, x2, y2])

    def test_default_factories(self):
        @dataclass
        class Foo(object):
            x: List[int] = field(default_factory=list)

        with self.assertRaises(NotImplementedError):
            torch.jit.script(Foo)

        def fn():
            foo = Foo()
            return foo.x

        torch.jit.script(fn)()

    # The user should be able to write their own __eq__ implementation
    # without us overriding it.
    def test_custom__eq__(self):
        @torch.jit.script
        @dataclass
        class CustomEq:
            a: int
            b: int

            def __eq__(self, other: 'CustomEq') -> bool:
                return self.a == other.a # ignore the b field

        def fn(a: int, b1: int, b2: int):
            pt1 = CustomEq(a, b1)
            pt2 = CustomEq(a, b2)
            return pt1 == pt2

        self.checkScript(fn, [1, 2, 3])

    def test_no_source(self):
        with self.assertRaises(RuntimeError):
            # uses list in Enum is not supported
            torch.jit.script(MixupParams)

        torch.jit.script(MixupParams2) # don't throw

    def test_use_unregistered_dataclass_raises(self):
        def f(a: MixupParams3):
            return 0

        with self.assertRaises(OSError):
            torch.jit.script(f)
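A side note on the Python 3.8 gate in test_init_vars above (my own illustration, not part of the deleted file): the synthesizer relies on InitVar exposing the wrapped annotation, which only exists from 3.8 onward:

# Only on Python 3.8+ does dataclasses.InitVar record the wrapped type,
# which is what lets InitVar[int] be rewritten into a plain int parameter.
import sys
from dataclasses import InitVar

if sys.version_info >= (3, 8):
    print(InitVar[int].type)  # <class 'int'>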

View File

@@ -6,6 +6,7 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
from torch.testing import FileCheck
from torch import jit
from jit.test_module_interface import TestModuleInterface # noqa: F401
import unittest
import os
import sys
import torch
@@ -46,6 +47,24 @@ class TestMisc(JitTestCase):
        self.assertEqual(out, out_script)
        self.assertEqual(captured, captured_script)

    @unittest.skipIf(sys.version_info[:2] < (3, 7), "`dataclasses` module not present on < 3.7")
    def test_dataclass_error(self):
        from dataclasses import dataclass

        @dataclass
        class NormalizationInfo(object):
            mean: float = 0.0

            def compute(self, total_rows):
                return self.mean

        def fn():
            return NormalizationInfo(1, 2, 3, 4, 5)

        with self.assertRaisesRegex(OSError, "could not get source code"):
            torch.jit.script(fn)

    def test_kwarg_support(self):
        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
            class M(torch.nn.Module):

View File

@@ -78,7 +78,6 @@ from jit.test_device_analysis import TestDeviceAnalysis # noqa: F401
from jit.test_dce import TestDCE # noqa: F401
from jit.test_sparse import TestSparse # noqa: F401
from jit.test_tensor_methods import TestTensorMethods # noqa: F401
from jit.test_dataclasses import TestDataclasses # noqa: F401
# Torch
from torch import Tensor

View File

@@ -47,24 +47,6 @@ except ImportError:
boolean_dispatched: 'weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]' = weakref.WeakKeyDictionary() # noqa: T484

FAKE_FILENAME_PREFIX = '__torch_jit_dataclass'

class SourceLoader:
    def __init__(self):
        self.content = {}

    def cache(self, fn, source):
        self.content[fn] = source

    def get_source(self, fn):
        return self.content.get(fn)

loader = SourceLoader()

def createResolutionCallbackFromEnv(lookup_base):
    """
    Creates a resolution callback that will look up qualified names in an
@@ -341,14 +323,6 @@ def get_type_hint_captures(fn):
        A Dict[str, Any] containing a mapping from the literal annotations used on
        fn to the Python objects they refer to.
    """
    # First, try to get the source of the function. We'll need to parse it to find the actual string names
    # that were used to annotate the types, since inspect.signature() will only return the class object that
    # the annotation refers to, not the string name. If we can't get the source, simply return an empty dict.
    # This may happen in cases where the function is synthesized dynamically at runtime.
    src = loader.get_source(fn)
    if src is None:
        src = inspect.getsource(fn)

    # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
    # types are strings. These are only understood by TorchScript in the context of a type annotation
    # that refers to a class in its own definition, but trying to include a mapping for this in the result
@@ -364,6 +338,8 @@ def get_type_hint_captures(fn):
    # Then, get the literal type annotations from the function declaration
    # by source inspection. This accounts for the case in which aliases are used
    # to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
    src = inspect.getsource(fn)

    # frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
    a = ast.parse(dedent(src))
    if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
@@ -929,7 +905,7 @@ def is_optional(ann):
    def is_union_as_optional(ann):
        ann_args = ann.__args__
        return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args)
        return len(ann_args) == 2 and None in ann_args

    return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))
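For reference (not part of the diff), the reverted variant above accounted for how typing normalizes None inside a Union:

# typing replaces the literal None in Optional/Union annotations with NoneType,
# so __args__ contains type(None) rather than None itself.
from typing import Optional, Union

print(Optional[int].__args__)                   # (<class 'int'>, <class 'NoneType'>)
print(None in Union[int, None].__args__)        # False
print(type(None) in Union[int, None].__args__)  # True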

View File

@@ -48,20 +48,13 @@ def normalize_source_lines(sourcelines: List[str]) -> List[str]:
        return text[text.startswith(prefix) and len(prefix):]

    # Find the line and line number containing the function definition
    idx = None
    for i, l in enumerate(sourcelines):
        if l.lstrip().startswith("def"):
            idx = i
            break

    # This will happen when the function is a lambda- we won't find "def" anywhere in the source
    # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
    # `parse_def()`, but we might want to handle this case in the future.
    if idx is None:
        return sourcelines

    fn_def = sourcelines[idx]

    # Get a string representing the amount of leading whitespace
    fn_def = sourcelines[idx]
    whitespace = fn_def.split("def")[0]

    # Add this leading whitespace to all lines before and after the `def`
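A small illustration of the guard being removed above (mine, not from the patch): a lambda's source contains no `def` line, so the search would never assign idx, which is why the guard returned early for lambdas:

# For a lambda there is no line starting with "def", so the loop above never
# assigns idx; the (removed) guard returned the lines unchanged in that case.
import inspect

square = lambda x: x * x  # noqa: E731
lines, _ = inspect.getsourcelines(square)
print(any(line.lstrip().startswith("def") for line in lines))  # False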

View File

@@ -227,11 +227,10 @@ struct TORCH_API CompilationUnit {
    // Tombstone the method in the compilation unit.
    // Don't erase because the dict_
    auto it = dict_.find(method->qualname());
    if (it != dict_.end()) {
      functions_[it->second] = nullptr;
      // Erase in our big lookup table
      dict_.erase(it);
    }
    TORCH_INTERNAL_ASSERT(it != dict_.end());
    functions_[it->second] = nullptr;
    // Erase in our big lookup table
    dict_.erase(it);
  }

  // Classes can have multiple pointers to the same hook,
  // need to make sure to not delete it twice

View File

@@ -1,152 +0,0 @@
# Functions for synthesizing magic methods for JIT-compiled dataclasses
import os
from functools import partial
from torch._jit_internal import is_optional, FAKE_FILENAME_PREFIX
from torch._sources import ParsedDef, SourceContext
from typing import Callable, Dict, List
import ast
import dataclasses
import inspect
import sys

def _get_fake_filename(cls, method_name):
    return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)

def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
    body = '\n'.join(f' {b}' for b in body_lines)
    decl = f'def {name}{signature}:\n{body}'

    # Parse the function declaration
    try:
        py_ast = ast.parse(decl)
    except SyntaxError:
        # This should only happen if there's some unforeseeable change
        # in the dataclasses module that makes our synthesized code fail
        raise RuntimeError(
            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
            "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
        )
    fake_filename = _get_fake_filename(cls, name)

    # Parse the function
    return ParsedDef(
        py_ast,
        ctx=SourceContext(
            source=decl,
            filename=fake_filename,
            file_lineno=0,
            leading_whitespace_len=0
        ),
        source=decl,
        filename=fake_filename,
        file_lineno=0
    )

def synthesize__init__(cls) -> ParsedDef:
    # Supporting default factories in the way that people expect would sort of require us to
    # allow compiling lambda functions, which is not currently supported.
    if any(field.default_factory is not dataclasses.MISSING for field in dataclasses.fields(cls)):
        raise NotImplementedError("Default factory initializers are not supported in TorchScript dataclasses")

    # Simply read off the generated __init__ signature from CPython's implementation. It'll be
    # almost correct except for InitVar annotations, which we need to handle specially.
    signature = inspect.signature(cls.__init__)

    # Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
    # see CPython commit here https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
    init_vars: List[str] = []
    if sys.version_info >= (3, 8):
        params = []
        for name, param in signature.parameters.items():
            ann = param.annotation

            if isinstance(ann, dataclasses.InitVar):
                # The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here
                init_vars.append(name)
                params.append(param.replace(annotation=ann.type)) # type: ignore[attr-defined]
            else:
                params.append(param)

        signature = signature.replace(parameters=params)

    body = [
        # Assign all attributes to self
        f'self.{field.name} = {field.name}'
        for field in dataclasses.fields(cls)
        if field.init and field.name not in init_vars
    ]
    # Call user's impl of __post_init__ if it exists
    if hasattr(cls, '__post_init__'):
        body.append('self.__post_init__(' + ', '.join(init_vars) + ')')

    return compose_fn(cls, '__init__', body or ['pass'], signature=str(signature))

# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
def synthesize__repr__(cls) -> ParsedDef:
    return compose_fn(
        cls, '__repr__',
        [f"return '{cls.__name__}(" + ", ".join([
            f"{field.name}=self.{field.name}"
            for field in dataclasses.fields(cls) if field.repr
        ]) + ")'"],
        signature='(self) -> str'
    )

def synthesize__hash__(cls) -> ParsedDef:
    return compose_fn(
        cls, '__hash__',
        [
            # This is just a placeholder to prevent compilation from failing; this won't even get called at
            # all right now because the TorchScript interpreter doesn't call custom __hash__ implementations
            "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
        ],
        signature='(self) -> int'
    )

# Implementation for __eq__ and __ne__
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
    return synthesize_comparison(cls, name, allow_eq=True, raise_on_none=False, inner=[
        f"if val1 {converse} val2: return False"
    ])

def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
    return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=[
        f"if val1 {op} val2: return True",
        f"elif val2 {op} val1: return False",
    ])

def synthesize_comparison(cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]) -> ParsedDef:
    body = []
    for field in dataclasses.fields(cls):
        if not field.compare:
            continue

        body.extend([
            f"val1 = self.{field.name}",
            f"val2 = other.{field.name}",
        ])
        body.extend(
            inner if not is_optional(field.type) else [
                # Type refinement for optional fields; we need this to avoid type errors from the interpreter
                "if val1 is not None and val2 is not None:",
                *[' ' + line for line in inner],
                "elif (val1 is None) != (val2 is None):",
                f" raise TypeError('Cannot compare {cls.__name__} with None')" if raise_on_none else " return False"
            ]
        )

    body.append(f"return {allow_eq}")
    return compose_fn(cls, name, body, signature=f'(self, other: {cls.__name__}) -> bool')

DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
    "__init__": synthesize__init__,
    "__repr__": synthesize__repr__,
    "__hash__": synthesize__hash__,
    "__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
    "__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
    "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
    "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
    "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
    "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
}
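To make the synthesis above concrete, this is roughly what synthesize_inequality would emit for the Point dataclass from the deleted test file (an approximate rendering written by hand, not captured output):

# Approximate source generated for Point.__lt__ (fields x, y, norm), as built
# by synthesize_comparison with op="<", allow_eq=False, raise_on_none=True.
def __lt__(self, other: Point) -> bool:
    val1 = self.x
    val2 = other.x
    if val1 < val2: return True
    elif val2 < val1: return False
    val1 = self.y
    val2 = other.y
    if val1 < val2: return True
    elif val2 < val1: return False
    val1 = self.norm
    val2 = other.norm
    if val1 is not None and val2 is not None:
        if val1 < val2: return True
        elif val2 < val1: return False
    elif (val1 is None) != (val2 is None):
        raise TypeError('Cannot compare Point with None')
    return False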

View File

@@ -284,10 +284,7 @@ def get_enum_value_type(e: Type[enum.Enum], loc):
    # Even though Python supports this case, we chose to not implement it to
    # avoid overcomplicate logic here for a rare use case. Please report a
    # feature request if you find it necessary.
    res = torch._C.unify_type_list(ir_types)
    if not res:
        return AnyType.get()
    return res
    return torch._C.unify_type_list(ir_types)

def is_tensor(ann):
    if issubclass(ann, torch.Tensor):
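Related context for the unify_type_list change above (an illustrative sketch of my own, not from the patch): an Enum whose members hold Python lists has no single value type, which is the situation test_no_source in the deleted test file exercised:

# Enum members whose values are lists (as in MixupScheme above) are not
# supported by TorchScript, so scripting code that takes such an Enum fails.
from enum import Enum
import torch

class Scheme(Enum):
    INPUT = ["input"]
    MANIFOLD = ["input", "before_fusion_projection"]

def f(s: Scheme) -> int:
    return 0

# torch.jit.script(f)  # raises: list-valued Enum members are unsupported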

View File

@@ -1,7 +1,6 @@
import torch
import sys
import ast
import dataclasses
import inspect
import string
from collections import namedtuple
@@ -18,11 +17,9 @@ from torch._C._jit_tree_views import (
    SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
    DictComp,
)
from torch._sources import get_source_lines_and_file, ParsedDef, parse_def, make_source_context
from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
from torch._sources import get_source_lines_and_file, parse_def, make_source_context
from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers # noqa: F401
from torch import _jit_internal
import torch.jit.annotations

_IS_ASTUNPARSE_INSTALLED = False
@@ -198,48 +195,24 @@ def get_jit_class_def(cls, self_name):
    def is_classmethod(fn):
        return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls

    # Get and parse the source code for this class
    sourcelines, file_lineno, filename = get_source_lines_and_file(cls, torch._C.ErrorReport.call_stack())
    source = ''.join(sourcelines)
    methods = [get_jit_def(obj,
                           name,
                           self_name=self_name,
                           is_classmethod=is_classmethod(obj)) for (name, obj) in methods]
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)

    class_ast = py_ast.body[0]
    assert isinstance(class_ast, ast.ClassDef)

    # Special case for dataclasses. In general we need access to the source code for
    # an object in order to JIT compile it. But the dataclasses module dynamically synthesizes
    # magic methods for classes, and we can't get the source code for these methods. As a
    # workaround, we synthesize TorchScript-friendly implementations ourselves.
    if dataclasses.is_dataclass(cls):
        # Detect whether the user manually implemented any of the magic methods. If they did,
        # we don't want to synthesize/override them.
        overrides = {
            method.name
            for method in class_ast.body
            if isinstance(method, ast.FunctionDef) and method.name in DATACLASS_MAGIC_METHODS
        }
        for i, (name, _) in enumerate(methods):
            # Is this a magic method we can synthesize?
            synthesizer_fn = DATACLASS_MAGIC_METHODS.get(name)
            if synthesizer_fn and name not in overrides:
                parsed_def = synthesizer_fn(cls)
                methods[i] = name, parsed_def
                func = getattr(cls, name)
                _jit_internal.loader.cache(func, parsed_def.source)

    method_defs = [
        get_jit_def(obj, name, self_name=self_name, is_classmethod=is_classmethod(obj))
        for (name, obj) in methods
    ]
    properties = get_class_properties(cls, self_name)

    sourcelines, file_lineno, filename = get_source_lines_and_file(cls, torch._C.ErrorReport.call_stack())
    source = ''.join(sourcelines)
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)
    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
    ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, False)
    class_ast = py_ast.body[0]
    assert isinstance(class_ast, ast.ClassDef)
    assigns = get_class_assigns(ctx, class_ast)

    return build_class_def(ctx, class_ast, method_defs, properties, self_name, assigns)
    return build_class_def(ctx, class_ast, methods, properties, self_name, assigns)

def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
@@ -247,7 +220,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile or a pre-parsed ParsedDef object
        fn: A function object to compile
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
@@ -257,7 +230,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
            but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.
    """
    parsed_def = parse_def(fn) if not isinstance(fn, ParsedDef) else fn
    parsed_def = parse_def(fn)
    type_line = torch.jit.annotations.get_type_line(parsed_def.source)
    fn_def = parsed_def.ast.body[0]
@@ -284,7 +257,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    # for the arguments from type_trace_db
    type_trace_db = torch.jit._script._get_type_trace_db()
    pdt_arg_types = None
    if monkeytype_trace and not isinstance(fn, ParsedDef):
    if monkeytype_trace:
        qualname = get_qualified_name(fn)
        pdt_arg_types = type_trace_db.get_args_types(qualname)
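Putting the removed pieces together, a rough standalone sketch of the mechanism this revert deletes (it reuses the SourceLoader shown in the _jit_internal hunk above; the helper name here is mine):

# The frontend cached the source it synthesized for dataclass magic methods,
# so that later source inspection could still succeed for functions that
# inspect.getsource() knows nothing about.
import inspect

class SourceLoader:
    def __init__(self):
        self.content = {}

    def cache(self, fn, source):
        self.content[fn] = source

    def get_source(self, fn):
        return self.content.get(fn)

loader = SourceLoader()

def get_source_or_cached(fn):
    # mirrors the removed lookup in get_type_hint_captures: prefer the cache,
    # fall back to the real source on disk
    src = loader.get_source(fn)
    if src is None:
        src = inspect.getsource(fn)
    return src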