mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Reland "[pytorch][PR] Support dataclasses in TorchScript" take 2 (#74353)"
This reverts commit 5547741960a01fbd3a97d1ddd5ae9b43d8f1169c. Reverted https://github.com/pytorch/pytorch/pull/74889 on behalf of https://github.com/malfet
This commit is contained in:
@ -1,165 +0,0 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
# flake8: noqa
|
||||
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
from hypothesis import given, settings, strategies as st
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from typing import List, Optional
|
||||
import sys
|
||||
import torch
|
||||
import unittest
|
||||
from enum import Enum
|
||||
|
||||
# Example jittable dataclass
|
||||
@torch.jit.script
|
||||
@dataclass(order=True)
|
||||
class Point:
|
||||
x: float
|
||||
y: float
|
||||
norm: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5
|
||||
|
||||
class MixupScheme(Enum):
|
||||
|
||||
INPUT = ["input"]
|
||||
|
||||
MANIFOLD = [
|
||||
"input",
|
||||
"before_fusion_projection",
|
||||
"after_fusion_projection",
|
||||
"after_classifier_projection",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MixupParams:
|
||||
def __init__(self, alpha: float = 0.125, scheme: MixupScheme = MixupScheme.INPUT):
|
||||
self.alpha = alpha
|
||||
self.scheme = scheme
|
||||
|
||||
class MixupScheme2(Enum):
|
||||
A = 1
|
||||
B = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class MixupParams2:
|
||||
def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
|
||||
self.alpha = alpha
|
||||
self.scheme = scheme
|
||||
|
||||
@dataclass
|
||||
class MixupParams3:
|
||||
def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
|
||||
self.alpha = alpha
|
||||
self.scheme = scheme
|
||||
|
||||
|
||||
# Make sure the Meta internal tooling doesn't raise an overflow error
|
||||
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)
|
||||
|
||||
class TestDataclasses(JitTestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
torch._C._jit_clear_class_registry()
|
||||
# We only support InitVar in JIT dataclasses for Python 3.8+ because it would be very hard
|
||||
# to support without the `type` attribute on InitVar (see comment in _dataclass_impls.py).
|
||||
@unittest.skipIf(sys.version_info < (3, 8), "InitVar not supported in Python < 3.8")
|
||||
def test_init_vars(self):
|
||||
@torch.jit.script
|
||||
@dataclass(order=True)
|
||||
class Point2:
|
||||
x: float
|
||||
y: float
|
||||
norm_p: InitVar[int] = 2
|
||||
norm: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self, norm_p: int):
|
||||
self.norm = (torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p) ** (1 / norm_p)
|
||||
|
||||
def fn(x: float, y: float, p: int):
|
||||
pt = Point2(x, y, p)
|
||||
return pt.norm
|
||||
|
||||
self.checkScript(fn, (1.0, 2.0, 3))
|
||||
|
||||
# Sort of tests both __post_init__ and optional fields
|
||||
@settings(deadline=None)
|
||||
@given(NonHugeFloats, NonHugeFloats)
|
||||
def test__post_init__(self, x, y):
|
||||
def fn(x: float, y: float):
|
||||
pt = Point(x, y)
|
||||
return pt.norm
|
||||
|
||||
self.checkScript(fn, [x, y])
|
||||
|
||||
@settings(deadline=None)
|
||||
@given(st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats))
|
||||
def test_comparators(self, pt1, pt2):
|
||||
x1, y1 = pt1
|
||||
x2, y2 = pt2
|
||||
|
||||
def compare(x1: float, y1: float, x2: float, y2: float):
|
||||
pt1 = Point(x1, y1)
|
||||
pt2 = Point(x2, y2)
|
||||
return (
|
||||
pt1 == pt2,
|
||||
# pt1 != pt2, # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
|
||||
pt1 < pt2,
|
||||
pt1 <= pt2,
|
||||
pt1 > pt2,
|
||||
pt1 >= pt2,
|
||||
)
|
||||
|
||||
self.checkScript(compare, [x1, y1, x2, y2])
|
||||
|
||||
def test_default_factories(self):
|
||||
@dataclass
|
||||
class Foo(object):
|
||||
x: List[int] = field(default_factory=list)
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
torch.jit.script(Foo)
|
||||
def fn():
|
||||
foo = Foo()
|
||||
return foo.x
|
||||
|
||||
torch.jit.script(fn)()
|
||||
|
||||
# The user should be able to write their own __eq__ implementation
|
||||
# without us overriding it.
|
||||
def test_custom__eq__(self):
|
||||
@torch.jit.script
|
||||
@dataclass
|
||||
class CustomEq:
|
||||
a: int
|
||||
b: int
|
||||
|
||||
def __eq__(self, other: 'CustomEq') -> bool:
|
||||
return self.a == other.a # ignore the b field
|
||||
|
||||
def fn(a: int, b1: int, b2: int):
|
||||
pt1 = CustomEq(a, b1)
|
||||
pt2 = CustomEq(a, b2)
|
||||
return pt1 == pt2
|
||||
|
||||
self.checkScript(fn, [1, 2, 3])
|
||||
|
||||
def test_no_source(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
# uses list in Enum is not supported
|
||||
torch.jit.script(MixupParams)
|
||||
|
||||
torch.jit.script(MixupParams2) # don't throw
|
||||
|
||||
|
||||
def test_use_unregistered_dataclass_raises(self):
|
||||
|
||||
def f(a: MixupParams3):
|
||||
return 0
|
||||
|
||||
with self.assertRaises(OSError):
|
||||
torch.jit.script(f)
|
@ -6,6 +6,7 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
from torch.testing import FileCheck
|
||||
from torch import jit
|
||||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||
import unittest
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
@ -46,6 +47,24 @@ class TestMisc(JitTestCase):
|
||||
self.assertEqual(out, out_script)
|
||||
self.assertEqual(captured, captured_script)
|
||||
|
||||
@unittest.skipIf(sys.version_info[:2] < (3, 7), "`dataclasses` module not present on < 3.7")
|
||||
def test_dataclass_error(self):
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class NormalizationInfo(object):
|
||||
mean: float = 0.0
|
||||
|
||||
def compute(self, total_rows):
|
||||
return self.mean
|
||||
|
||||
def fn():
|
||||
return NormalizationInfo(1, 2, 3, 4, 5)
|
||||
|
||||
with self.assertRaisesRegex(OSError, "could not get source code"):
|
||||
torch.jit.script(fn)
|
||||
|
||||
|
||||
def test_kwarg_support(self):
|
||||
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
|
||||
class M(torch.nn.Module):
|
||||
|
@ -78,7 +78,6 @@ from jit.test_device_analysis import TestDeviceAnalysis # noqa: F401
|
||||
from jit.test_dce import TestDCE # noqa: F401
|
||||
from jit.test_sparse import TestSparse # noqa: F401
|
||||
from jit.test_tensor_methods import TestTensorMethods # noqa: F401
|
||||
from jit.test_dataclasses import TestDataclasses # noqa: F401
|
||||
|
||||
# Torch
|
||||
from torch import Tensor
|
||||
|
@ -47,24 +47,6 @@ except ImportError:
|
||||
boolean_dispatched: 'weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]' = weakref.WeakKeyDictionary() # noqa: T484
|
||||
|
||||
|
||||
FAKE_FILENAME_PREFIX = '__torch_jit_dataclass'
|
||||
|
||||
|
||||
class SourceLoader:
|
||||
|
||||
def __init__(self):
|
||||
self.content = {}
|
||||
|
||||
def cache(self, fn, source):
|
||||
self.content[fn] = source
|
||||
|
||||
def get_source(self, fn):
|
||||
return self.content.get(fn)
|
||||
|
||||
|
||||
loader = SourceLoader()
|
||||
|
||||
|
||||
def createResolutionCallbackFromEnv(lookup_base):
|
||||
"""
|
||||
Creates a resolution callback that will look up qualified names in an
|
||||
@ -341,14 +323,6 @@ def get_type_hint_captures(fn):
|
||||
A Dict[str, Any] containing a mapping from the literal annotations used on
|
||||
fn to the Python objects they refer to.
|
||||
"""
|
||||
# First, try to get the source of the function. We'll need to parse it to find the actual string names
|
||||
# that were used to annotate the types, since inspect.signature() will only return the class object that
|
||||
# the annotation refers to, not the string name. If we can't get the source, simply return an empty dict.
|
||||
# This may happen in cases where the function is synthesized dynamically at runtime.
|
||||
src = loader.get_source(fn)
|
||||
if src is None:
|
||||
src = inspect.getsource(fn)
|
||||
|
||||
# Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
|
||||
# types are strings. These are only understood by TorchScript in the context of a type annotation
|
||||
# that refers to a class in its own definition, but trying to include a mapping for this in the result
|
||||
@ -364,6 +338,8 @@ def get_type_hint_captures(fn):
|
||||
# Then, get the literal type annotations from the function declaration
|
||||
# by source inspection. This accounts for the case in which aliases are used
|
||||
# to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
|
||||
src = inspect.getsource(fn)
|
||||
|
||||
# frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
|
||||
a = ast.parse(dedent(src))
|
||||
if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
|
||||
@ -929,7 +905,7 @@ def is_optional(ann):
|
||||
|
||||
def is_union_as_optional(ann):
|
||||
ann_args = ann.__args__
|
||||
return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args)
|
||||
return len(ann_args) == 2 and None in ann_args
|
||||
|
||||
return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))
|
||||
|
||||
|
@ -48,20 +48,13 @@ def normalize_source_lines(sourcelines: List[str]) -> List[str]:
|
||||
return text[text.startswith(prefix) and len(prefix):]
|
||||
|
||||
# Find the line and line number containing the function definition
|
||||
idx = None
|
||||
for i, l in enumerate(sourcelines):
|
||||
if l.lstrip().startswith("def"):
|
||||
idx = i
|
||||
break
|
||||
|
||||
# This will happen when the function is a lambda- we won't find "def" anywhere in the source
|
||||
# lines in that case. Currently trying to JIT compile a lambda will throw an error up in
|
||||
# `parse_def()`, but we might want to handle this case in the future.
|
||||
if idx is None:
|
||||
return sourcelines
|
||||
fn_def = sourcelines[idx]
|
||||
|
||||
# Get a string representing the amount of leading whitespace
|
||||
fn_def = sourcelines[idx]
|
||||
whitespace = fn_def.split("def")[0]
|
||||
|
||||
# Add this leading whitespace to all lines before and after the `def`
|
||||
|
@ -227,11 +227,10 @@ struct TORCH_API CompilationUnit {
|
||||
// Tombstone the method in the compilation unit.
|
||||
// Don't erase because the dict_
|
||||
auto it = dict_.find(method->qualname());
|
||||
if (it != dict_.end()) {
|
||||
functions_[it->second] = nullptr;
|
||||
// Erase in our big lookup table
|
||||
dict_.erase(it);
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(it != dict_.end());
|
||||
functions_[it->second] = nullptr;
|
||||
// Erase in our big lookup table
|
||||
dict_.erase(it);
|
||||
}
|
||||
// Classes can have multiple pointers to the same hook,
|
||||
// need to make sure to not delete it twice
|
||||
|
@ -1,152 +0,0 @@
|
||||
# Functions for synthesizing magic methods for JIT-compiled dataclasses
|
||||
import os
|
||||
from functools import partial
|
||||
from torch._jit_internal import is_optional, FAKE_FILENAME_PREFIX
|
||||
from torch._sources import ParsedDef, SourceContext
|
||||
from typing import Callable, Dict, List
|
||||
import ast
|
||||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
|
||||
def _get_fake_filename(cls, method_name):
|
||||
return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)
|
||||
|
||||
|
||||
def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
|
||||
body = '\n'.join(f' {b}' for b in body_lines)
|
||||
decl = f'def {name}{signature}:\n{body}'
|
||||
|
||||
# Parse the function declaration
|
||||
try:
|
||||
py_ast = ast.parse(decl)
|
||||
except SyntaxError:
|
||||
# This should only happen if there's some unforeseeable change
|
||||
# in the dataclasses module that makes our synthesized code fail
|
||||
raise RuntimeError(
|
||||
f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
|
||||
"Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
|
||||
)
|
||||
fake_filename = _get_fake_filename(cls, name)
|
||||
# Parse the function
|
||||
return ParsedDef(
|
||||
py_ast,
|
||||
ctx=SourceContext(
|
||||
source=decl,
|
||||
filename=fake_filename,
|
||||
file_lineno=0,
|
||||
leading_whitespace_len=0
|
||||
),
|
||||
source=decl,
|
||||
filename=fake_filename,
|
||||
file_lineno=0
|
||||
)
|
||||
|
||||
|
||||
def synthesize__init__(cls) -> ParsedDef:
|
||||
# Supporting default factories in the way that people expect would sort of require us to
|
||||
# allow compiling lambda functions, which is not currently supported.
|
||||
if any(field.default_factory is not dataclasses.MISSING for field in dataclasses.fields(cls)):
|
||||
raise NotImplementedError("Default factory initializers are not supported in TorchScript dataclasses")
|
||||
|
||||
# Simply read off the generated __init__ signature from CPython's implementation. It'll be
|
||||
# almost correct except for InitVar annotations, which we need to handle specially.
|
||||
signature = inspect.signature(cls.__init__)
|
||||
|
||||
# Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
|
||||
# see CPython commit here https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
|
||||
init_vars: List[str] = []
|
||||
if sys.version_info >= (3, 8):
|
||||
params = []
|
||||
for name, param in signature.parameters.items():
|
||||
ann = param.annotation
|
||||
|
||||
if isinstance(ann, dataclasses.InitVar):
|
||||
# The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here
|
||||
init_vars.append(name)
|
||||
params.append(param.replace(annotation=ann.type)) # type: ignore[attr-defined]
|
||||
else:
|
||||
params.append(param)
|
||||
|
||||
signature = signature.replace(parameters=params)
|
||||
|
||||
body = [
|
||||
# Assign all attributes to self
|
||||
f'self.{field.name} = {field.name}'
|
||||
for field in dataclasses.fields(cls)
|
||||
if field.init and field.name not in init_vars
|
||||
]
|
||||
# Call user's impl of __post_init__ if it exists
|
||||
if hasattr(cls, '__post_init__'):
|
||||
body.append('self.__post_init__(' + ', '.join(init_vars) + ')')
|
||||
|
||||
return compose_fn(cls, '__init__', body or ['pass'], signature=str(signature))
|
||||
|
||||
# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
|
||||
def synthesize__repr__(cls) -> ParsedDef:
|
||||
return compose_fn(
|
||||
cls, '__repr__',
|
||||
[f"return '{cls.__name__}(" + ", ".join([
|
||||
f"{field.name}=self.{field.name}"
|
||||
for field in dataclasses.fields(cls) if field.repr
|
||||
]) + ")'"],
|
||||
signature='(self) -> str'
|
||||
)
|
||||
|
||||
def synthesize__hash__(cls) -> ParsedDef:
|
||||
return compose_fn(
|
||||
cls, '__hash__',
|
||||
[
|
||||
# This is just a placeholder to prevent compilation from failing; this won't even get called at
|
||||
# all right now because the TorchScript interpreter doesn't call custom __hash__ implementations
|
||||
"raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
|
||||
],
|
||||
signature='(self) -> int'
|
||||
)
|
||||
|
||||
# Implementation for __eq__ and __ne__
|
||||
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
|
||||
return synthesize_comparison(cls, name, allow_eq=True, raise_on_none=False, inner=[
|
||||
f"if val1 {converse} val2: return False"
|
||||
])
|
||||
|
||||
def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
|
||||
return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=[
|
||||
f"if val1 {op} val2: return True",
|
||||
f"elif val2 {op} val1: return False",
|
||||
])
|
||||
|
||||
def synthesize_comparison(cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]) -> ParsedDef:
|
||||
body = []
|
||||
for field in dataclasses.fields(cls):
|
||||
if not field.compare:
|
||||
continue
|
||||
|
||||
body.extend([
|
||||
f"val1 = self.{field.name}",
|
||||
f"val2 = other.{field.name}",
|
||||
])
|
||||
body.extend(
|
||||
inner if not is_optional(field.type) else [
|
||||
# Type refinement for optional fields; we need this to avoid type errors from the interpreter
|
||||
"if val1 is not None and val2 is not None:",
|
||||
*[' ' + line for line in inner],
|
||||
"elif (val1 is None) != (val2 is None):",
|
||||
f" raise TypeError('Cannot compare {cls.__name__} with None')" if raise_on_none else " return False"
|
||||
]
|
||||
)
|
||||
|
||||
body.append(f"return {allow_eq}")
|
||||
return compose_fn(cls, name, body, signature=f'(self, other: {cls.__name__}) -> bool')
|
||||
|
||||
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
|
||||
"__init__": synthesize__init__,
|
||||
"__repr__": synthesize__repr__,
|
||||
"__hash__": synthesize__hash__,
|
||||
"__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
|
||||
"__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
|
||||
"__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
|
||||
"__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
|
||||
"__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
|
||||
"__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
|
||||
}
|
@ -284,10 +284,7 @@ def get_enum_value_type(e: Type[enum.Enum], loc):
|
||||
# Even though Python supports this case, we chose to not implement it to
|
||||
# avoid overcomplicate logic here for a rare use case. Please report a
|
||||
# feature request if you find it necessary.
|
||||
res = torch._C.unify_type_list(ir_types)
|
||||
if not res:
|
||||
return AnyType.get()
|
||||
return res
|
||||
return torch._C.unify_type_list(ir_types)
|
||||
|
||||
def is_tensor(ann):
|
||||
if issubclass(ann, torch.Tensor):
|
||||
|
@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import sys
|
||||
import ast
|
||||
import dataclasses
|
||||
import inspect
|
||||
import string
|
||||
from collections import namedtuple
|
||||
@ -18,11 +17,9 @@ from torch._C._jit_tree_views import (
|
||||
SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
|
||||
DictComp,
|
||||
)
|
||||
from torch._sources import get_source_lines_and_file, ParsedDef, parse_def, make_source_context
|
||||
from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
|
||||
from torch._sources import get_source_lines_and_file, parse_def, make_source_context
|
||||
from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
|
||||
from torch._jit_internal import should_drop, is_static_fn, FunctionModifiers # noqa: F401
|
||||
from torch import _jit_internal
|
||||
import torch.jit.annotations
|
||||
|
||||
_IS_ASTUNPARSE_INSTALLED = False
|
||||
@ -198,48 +195,24 @@ def get_jit_class_def(cls, self_name):
|
||||
def is_classmethod(fn):
|
||||
return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls
|
||||
|
||||
# Get and parse the source code for this class
|
||||
sourcelines, file_lineno, filename = get_source_lines_and_file(cls, torch._C.ErrorReport.call_stack())
|
||||
source = ''.join(sourcelines)
|
||||
methods = [get_jit_def(obj,
|
||||
name,
|
||||
self_name=self_name,
|
||||
is_classmethod=is_classmethod(obj)) for (name, obj) in methods]
|
||||
|
||||
dedent_src = dedent(source)
|
||||
py_ast = ast.parse(dedent_src)
|
||||
|
||||
class_ast = py_ast.body[0]
|
||||
assert isinstance(class_ast, ast.ClassDef)
|
||||
|
||||
# Special case for dataclasses. In general we need access to the source code for
|
||||
# an object in order to JIT compile it. But the dataclasses module dynamically synthesizes
|
||||
# magic methods for classes, and we can't get the source code for these methods. As a
|
||||
# workaround, we synthesize TorchScript-friendly implementations ourselves.
|
||||
if dataclasses.is_dataclass(cls):
|
||||
# Detect whether the user manually implemented any of the magic methods. If they did,
|
||||
# we don't want to synthesize/override them.
|
||||
overrides = {
|
||||
method.name
|
||||
for method in class_ast.body
|
||||
if isinstance(method, ast.FunctionDef) and method.name in DATACLASS_MAGIC_METHODS
|
||||
}
|
||||
for i, (name, _) in enumerate(methods):
|
||||
# Is this a magic method we can synthesize?
|
||||
synthesizer_fn = DATACLASS_MAGIC_METHODS.get(name)
|
||||
if synthesizer_fn and name not in overrides:
|
||||
parsed_def = synthesizer_fn(cls)
|
||||
methods[i] = name, parsed_def
|
||||
func = getattr(cls, name)
|
||||
_jit_internal.loader.cache(func, parsed_def.source)
|
||||
|
||||
method_defs = [
|
||||
get_jit_def(obj, name, self_name=self_name, is_classmethod=is_classmethod(obj))
|
||||
for (name, obj) in methods
|
||||
]
|
||||
properties = get_class_properties(cls, self_name)
|
||||
|
||||
sourcelines, file_lineno, filename = get_source_lines_and_file(cls, torch._C.ErrorReport.call_stack())
|
||||
source = ''.join(sourcelines)
|
||||
dedent_src = dedent(source)
|
||||
py_ast = ast.parse(dedent_src)
|
||||
leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
|
||||
ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, False)
|
||||
class_ast = py_ast.body[0]
|
||||
assert isinstance(class_ast, ast.ClassDef)
|
||||
assigns = get_class_assigns(ctx, class_ast)
|
||||
|
||||
return build_class_def(ctx, class_ast, method_defs, properties, self_name, assigns)
|
||||
return build_class_def(ctx, class_ast, methods, properties, self_name, assigns)
|
||||
|
||||
|
||||
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
|
||||
@ -247,7 +220,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
|
||||
Build a JIT AST (TreeView) from the given function.
|
||||
|
||||
Args:
|
||||
fn: A function object to compile or a pre-parsed ParsedDef object
|
||||
fn: A function object to compile
|
||||
def_name: The name to give to the resulting AST object. This is not
|
||||
always the same as `fn.__name__`, for example:
|
||||
def _forward(self):
|
||||
@ -257,7 +230,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
|
||||
but we want the result AST to have the name "forward".
|
||||
self_name: If this function is a method, what the type name of `self` is.
|
||||
"""
|
||||
parsed_def = parse_def(fn) if not isinstance(fn, ParsedDef) else fn
|
||||
parsed_def = parse_def(fn)
|
||||
type_line = torch.jit.annotations.get_type_line(parsed_def.source)
|
||||
fn_def = parsed_def.ast.body[0]
|
||||
|
||||
@ -284,7 +257,7 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
|
||||
# for the arguments from type_trace_db
|
||||
type_trace_db = torch.jit._script._get_type_trace_db()
|
||||
pdt_arg_types = None
|
||||
if monkeytype_trace and not isinstance(fn, ParsedDef):
|
||||
if monkeytype_trace:
|
||||
qualname = get_qualified_name(fn)
|
||||
pdt_arg_types = type_trace_db.get_args_types(qualname)
|
||||
|
||||
|
Reference in New Issue
Block a user