Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-04 16:04:58 +08:00.
			
		
		
		
	See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178 Approved by: https://github.com/bobrenjc93
		
			
				
	
	
		
			139 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			139 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# mypy: allow-untyped-defs
 | 
						|
import ast
import functools
import inspect
import re
from textwrap import dedent
from typing import Any, NamedTuple, Optional

from torch._C import ErrorReport
from torch._C._jit_tree_views import SourceRangeFactory
 | 
						|
 | 
						|
 | 
						|
def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> tuple[list[str], int, Optional[str]]:
    """
    Wrapper around inspect.getsourcefile and inspect.getsourcelines.

    Args:
        obj: the object (typically a function) whose source is requested.
        error_msg: optional extra text; when non-empty it is appended on its
            own line to the error message raised if the source is unreadable.

    Returns:
        A tuple ``(sourcelines, file_lineno, filename)`` where ``sourcelines``
        and ``file_lineno`` come from ``inspect.getsourcelines`` and
        ``filename`` comes from ``inspect.getsourcefile`` (may be None).

    Raises:
        OSError: if ``inspect`` cannot access the source of ``obj``; the
            original error is chained as the cause.
    """
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as err:
        message_parts = [
            f"Can't get source for {obj}. TorchScript requires source access in "
            "order to carry out compilation, make sure original .py files are "
            "available."
        ]
        if error_msg:
            message_parts.append(error_msg)
        raise OSError("\n".join(message_parts)) from err

    return sourcelines, file_lineno, filename
 | 
						|
 | 
						|
 | 
						|
def normalize_source_lines(sourcelines: list[str]) -> list[str]:
    """
    This helper function accepts a list of source lines. It finds the
    indentation level of the function definition (`def`), then it indents
    all lines in the function body to a point at or greater than that
    level. This allows for comments and continued string literals that
    are at a lower indentation than the rest of the code.

    Args:
        sourcelines: function source code, separated into lines by
                        the '\n' character

    Returns:
        A list of source lines that have been correctly aligned. If no
        `def` line is found (e.g. for a lambda), the input list itself is
        returned unchanged.
    """
    # Match `def` only as a whole keyword. A bare prefix test would also
    # accept lines like `default_value=3,` (e.g. inside a multi-line decorator
    # call preceding the real `def`) and take the wrong indentation level.
    def_line = re.compile(r"(\s*)def\b")

    # Find the line number of the function definition and its leading
    # whitespace.
    idx = None
    whitespace = ""
    for i, line in enumerate(sourcelines):
        match = def_line.match(line)
        if match is not None:
            idx = i
            whitespace = match.group(1)
            break

    # This will happen when the function is a lambda- we won't find "def" anywhere in the source
    # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
    # `parse_def()`, but we might want to handle this case in the future.
    if idx is None:
        return sourcelines

    # Give every line at least the `def` line's leading whitespace. Lines that
    # already start with it keep their (deeper) relative indentation.
    aligned = [whitespace + s.removeprefix(whitespace) for s in sourcelines]
    # The `def` line itself is kept exactly as it was.
    aligned[idx] = sourcelines[idx]
    return aligned
 | 
						|
 | 
						|
 | 
						|
# Thin wrapper around SourceRangeFactory to store extra metadata
 | 
						|
# about the function-to-be-compiled.
 | 
						|
class SourceContext(SourceRangeFactory):
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        source,
 | 
						|
        filename,
 | 
						|
        file_lineno,
 | 
						|
        leading_whitespace_len,
 | 
						|
        uses_true_division=True,
 | 
						|
        funcname=None,
 | 
						|
    ):
 | 
						|
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
 | 
						|
        self.uses_true_division = uses_true_division
 | 
						|
        self.filename = filename
 | 
						|
        self.funcname = funcname
 | 
						|
 | 
						|
 | 
						|
@functools.cache
def make_source_context(*args):
    """
    Construct a SourceContext from positional ``args``, memoized by argument tuple.

    Because of ``functools.cache``, repeated calls with an equal argument
    tuple return the same SourceContext instance, and every argument must
    be hashable.

    NOTE(review): the cache is unbounded, so each distinct argument tuple,
    including the full source string, is retained for the life of the process.
    """
    context = SourceContext(*args)
    return context
 | 
						|
 | 
						|
 | 
						|
def fake_range():
    """
    Return a placeholder source range.

    The range comes from ``make_raw_range(0, 1)`` on a throwaway SourceContext
    with empty source, no filename, line number 0 and no leading whitespace.
    Presumably used where a range object is required but no real source
    location exists.
    """
    empty_ctx = SourceContext("", None, 0, 0)
    return empty_ctx.make_raw_range(0, 1)
 | 
						|
 | 
						|
 | 
						|
class ParsedDef(NamedTuple):
    """
    Result of parse_def: the parsed AST of a single function plus the source
    text and location information it was parsed from.
    """

    ast: ast.Module  # module returned by ast.parse on the dedented source
    ctx: SourceContext  # source-range factory plus function metadata
    source: str  # normalized (but not dedented) source text
    filename: Optional[str]  # from inspect.getsourcefile; may be None
    file_lineno: int  # starting line number from inspect.getsourcelines
 | 
						|
 | 
						|
 | 
						|
def parse_def(fn):
    """
    Retrieve, normalize and parse the source code of ``fn``.

    Returns:
        A ParsedDef holding the parsed module, a (cached) SourceContext, the
        normalized source, the filename and the starting line number.

    Raises:
        OSError: propagated from get_source_lines_and_file when the source
            of ``fn`` cannot be read.
        RuntimeError: if the dedented source does not consist of exactly
            one top-level ``def`` statement.
    """
    lines, lineno, fname = get_source_lines_and_file(fn, ErrorReport.call_stack())
    source = "".join(normalize_source_lines(lines))
    dedented = dedent(source)
    tree = ast.parse(dedented)

    body = tree.body
    if not (len(body) == 1 and isinstance(body[0], ast.FunctionDef)):
        raise RuntimeError(f"Expected a single top-level function: {fname}:{lineno}")

    # Number of characters dedent() stripped from the first source line.
    first_line = source.split("\n", 1)[0]
    first_dedented = dedented.split("\n", 1)[0]
    indent_len = len(first_line) - len(first_dedented)

    # Arguments stay positional: make_source_context is memoized on its
    # positional argument tuple.
    ctx = make_source_context(source, fname, lineno, indent_len, True, fn.__name__)
    return ParsedDef(
        ast=tree,
        ctx=ctx,
        source=source,
        filename=fname,
        file_lineno=lineno,
    )
 |