mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit 5547741960a01fbd3a97d1ddd5ae9b43d8f1169c. Reverted https://github.com/pytorch/pytorch/pull/74889 on behalf of https://github.com/malfet
106 lines
3.9 KiB
Python
106 lines
3.9 KiB
Python
import ast
|
|
import functools
|
|
import inspect
|
|
from textwrap import dedent
|
|
from typing import Any, Optional, Tuple, List, NamedTuple
|
|
from torch._C import ErrorReport
|
|
from torch._C._jit_tree_views import SourceRangeFactory
|
|
|
|
def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
    """
    Wrapper around inspect.getsourcelines and inspect.getsourcefile.

    Args:
        obj: the object whose source is requested (typically a function).
        error_msg: optional extra context; when truthy it is appended on a
            new line to the error raised if the source cannot be read.

    Returns:
        (sourcelines, file_lineno, filename): the source lines and starting
        line number reported by ``inspect.getsourcelines``, and the path
        reported by ``inspect.getsourcefile`` (which may be ``None``).

    Raises:
        OSError: if the source cannot be retrieved; the original ``OSError``
            is chained as ``__cause__``.
    """
    filename = None  # in case getsourcefile throws
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as e:
        msg = (f"Can't get source for {obj}. TorchScript requires source access in "
               "order to carry out compilation, make sure original .py files are "
               "available.")
        if error_msg:
            msg += '\n' + error_msg
        raise OSError(msg) from e

    return sourcelines, file_lineno, filename
|
|
|
|
|
|
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
    """
    This helper function accepts a list of source lines. It finds the
    indentation level of the function definition (`def`), then it indents
    all lines in the function body to a point at or greater than that
    level. This allows for comments and continued string literals that
    are at a lower indentation than the rest of the code.

    Args:
        sourcelines: function source code, split into lines (each line
            normally keeps its trailing newline, as produced by
            inspect.getsourcelines)

    Returns:
        A list of source lines that have been correctly aligned. If no
        line starts with a ``def`` keyword (for example the source of a
        lambda), the input list is returned unchanged.
    """

    def remove_prefix(text: str, prefix: str) -> str:
        # Same as str.removeprefix (3.9+), kept for older interpreters.
        return text[len(prefix):] if text.startswith(prefix) else text

    def is_def_line(line: str) -> bool:
        # Require whitespace after the keyword so that names such as
        # `default=` or `define_x` (e.g. continuation lines of a multi-line
        # decorator) are not mistaken for the function definition.
        stripped = line.lstrip()
        return stripped.startswith("def") and stripped[3:4].isspace()

    # Find the line number containing the function definition
    idx = next((i for i, line in enumerate(sourcelines) if is_def_line(line)), None)

    # No `def` found (e.g. a lambda): there is no anchor to align against,
    # so leave the source alone and let the caller report the problem.
    if idx is None:
        return sourcelines

    fn_def = sourcelines[idx]

    # Get a string representing the amount of leading whitespace
    whitespace = fn_def[:len(fn_def) - len(fn_def.lstrip())]

    # Add this leading whitespace to all lines before and after the `def`
    aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
    aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]

    # Put it together again
    aligned_prefix.append(fn_def)
    return aligned_prefix + aligned_suffix
|
|
|
|
|
|
# SourceRangeFactory subclass that additionally remembers a few pieces of
# metadata (filename, function name, uses_true_division flag) about the
# function being compiled.
class SourceContext(SourceRangeFactory):
    def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True, funcname=None):
        # The base factory receives the four positional source arguments.
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        # Python-side metadata stored as plain instance attributes.
        self.funcname = funcname
        self.filename = filename
        self.uses_true_division = uses_true_division
|
|
|
|
|
|
@functools.lru_cache(maxsize=None)
def make_source_context(*ctor_args):
    """
    Memoized factory for SourceContext.

    All arguments are forwarded positionally to ``SourceContext``. Because
    of the ``lru_cache`` decorator, repeated calls with an equal argument
    tuple return the same cached instance.

    NOTE(review): the cache is unbounded (``maxsize=None``) and its keys
    include the full source text, so entries live for the process lifetime.
    """
    context = SourceContext(*ctor_args)
    return context
|
|
|
|
|
|
def fake_range():
    """
    Return a placeholder range from ``make_raw_range(0, 1)`` on a
    SourceContext built over an empty source with no filename.
    """
    placeholder_ctx = SourceContext('', None, 0, 0)
    return placeholder_ctx.make_raw_range(0, 1)
|
|
|
|
|
|
class ParsedDef(NamedTuple):
    """Result bundle produced by ``parse_def`` for a single function."""

    # Module produced by ast.parse on the dedented source.
    ast: ast.Module
    # SourceContext describing where the source came from.
    ctx: SourceContext
    # Normalized (but not dedented) source text.
    source: str
    # Path of the defining file; None when it could not be determined.
    filename: Optional[str]
    # Starting line number of the source within its file.
    file_lineno: int
|
|
|
|
def parse_def(fn):
    """
    Parse the source of ``fn`` into a ``ParsedDef``.

    The source is fetched with ``get_source_lines_and_file`` (the current
    ``ErrorReport.call_stack()`` is passed as extra error context), aligned
    with ``normalize_source_lines``, dedented and parsed with ``ast.parse``.

    Raises:
        RuntimeError: if the parsed module is not exactly one top-level
            ``def``.
    """
    raw_lines, start_line, path = get_source_lines_and_file(fn, ErrorReport.call_stack())
    full_text = ''.join(normalize_source_lines(raw_lines))
    flat_text = dedent(full_text)
    module_ast = ast.parse(flat_text)

    statements = module_ast.body
    is_single_function = len(statements) == 1 and isinstance(statements[0], ast.FunctionDef)
    if not is_single_function:
        raise RuntimeError(f"Expected a single top-level function: {path}:{start_line}")

    # Number of characters dedent() removed from the first line; handed to
    # the SourceContext as its leading_whitespace_len.
    stripped_width = len(full_text.partition('\n')[0]) - len(flat_text.partition('\n')[0])

    # Arguments stay positional so they produce the same lru_cache key as before.
    context = make_source_context(full_text, path, start_line, stripped_width, True, fn.__name__)
    return ParsedDef(module_ast, context, full_text, path, start_line)
|