import _collections_abc
import _weakrefset
import abc
import collections
import contextlib
import copy
import copyreg
import dataclasses
import enum
import functools
import importlib
import inspect
import linecache
import logging
import multiprocessing
import operator
import os
import posixpath
import random
import re
import selectors
import signal
import tempfile
import threading
import tokenize
import traceback
import types
import typing
import unittest
import weakref
from typing import Optional

import torch
import torch._inductor.test_operators
import torch.distributed
import torch.utils._content_store

from .utils import getfile
from .variables.functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
)


"""
|
|
A note on skipfiles:
|
|
|
|
Dynamo consults this file to determine whether function should be inlined or skipped.
|
|
|
|
A skip applies at the frame boundary, meaning dynamo either triggers a graph break
|
|
at the beginning of the frame or attempts to trace/inline the whole frame. When skipping
|
|
a frame, recursively called frames are still traced by dynamo unless also skipped.
|
|
|
|
Skipfiles (skipped at the file level instead of function level) still apply on a
|
|
frame-by-frame boundary as dynamo traces, but apply to all functions in that file.
|
|
|
|
@skip is a helper decorator that can be applied to your function to cause it to be
|
|
included here.
|
|
|
|
Dynamo skip/inline rules & priorities are defined as follows:
|
|
* Inline is the default behavior and will be used unless explicitly skipped.
|
|
* Dynamo has two SKIPLIST: BUILTIN_SKIPLIST and THIRDPARTY_SKIPLIST.
|
|
* BUILTIN_SKIPLIST contains builtin python modules, such as abc, collections, etc.
|
|
* THIRDPARTY_SKIPLIST contains common third party libraries, such as numpy, pandas, etc.
|
|
* Functions in these two SKIPLISTs are always skipped, except when they are explicitly
|
|
put into the three INLINELIST: FUNC_INLINELIST, FILE_INLINELIST and SUBMODULE_INLINELIST.
|
|
* PyTorch(torch) is in the BUILTIN_SKIPLIST by default, but there are many cases
|
|
where we want inline the functions under torch namespace. We should add them
|
|
into one of the three *_INLINELIST to make dynamo inline those functions.
|
|
* If you call functions under skipped modules/files, Dynamo will wrap these functions
|
|
as SkipFilesVariable. There are a few functions(e.g, collections.OrderedDict) that
|
|
we have special handling at SkipFilesVariable.call_function.
|
|
|
|
Overall: *_INLINELIST has precedence over *_SKIPLIST has precedence over DEFAULT (inline)
|
|
|
|
To figure out what the behavior is, check the following list in order:
|
|
* FUNC_INLINELIST (Inline if YES)
|
|
* FILE_INLINELIST (Inline if YES)
|
|
* SUBMODULE_INLINELIST (Inline if YES)
|
|
* BUILTIN_SKIPLIST & THIRDPARTY_SKIPLIST (Skip if YES)
|
|
* Inline by default
|
|
|
|
In general, if you want to force inline a function or module, please consider adding
|
|
the function's file or python module to FILE_INLINELIST first.
|
|
Use the FUNC_INLINELIST only when there are other functions under the same file that
|
|
you don't want to inline.
|
|
In the future, we will consolidate FILE_INLINELIST and SUBMODULE_INLINELIST into one list
|
|
as we use the same logic (filename.startswith) to determine if a file or module is skipped.
|
|
"""
|
|
|
|
|
|
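# A minimal illustration of the precedence rules above (hypothetical lookups,
# not a definitive trace of Dynamo's behavior):
#   torch.nn.functional.relu  -> torch is in BUILTIN_SKIPLIST, but the file lives
#                                under torch/nn/, which SUBMODULE_INLINELIST covers
#                                via "torch.nn", so it is inlined.
#   numpy.sum                 -> "numpy" is in THIRDPARTY_SKIPLIST and no INLINELIST
#                                covers it, so it is skipped.
#   a function in user code   -> matches no list, so it is inlined by default.
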
BUILTIN_SKIPLIST = (
    abc,
    collections,
    contextlib,
    copy,
    copyreg,
    dataclasses,
    enum,
    functools,
    importlib,
    inspect,
    linecache,
    logging,
    multiprocessing,
    operator,
    os,
    posixpath,
    random,
    re,
    selectors,
    signal,
    tempfile,
    threading,
    tokenize,
    torch,  # torch/* is skipped by default unless specified in FILE_INLINELIST or SUBMODULE_INLINELIST
    traceback,
    types,
    typing,
    unittest,
    weakref,
    _collections_abc,
    _weakrefset,
)

# The third party skiplist is defined by module name strings, because users may not
# have these libraries installed. We should switch to lazy import & skip in the future.
THIRDPARTY_SKIPLIST = (
    "functorch",
    "fx2trt_oss",
    "intel_extension_for_pytorch",
    "networkx",
    "numpy",
    "omegaconf",
    "onnx",
    "onnxruntime",
    "onnx_tf",
    "pandas",
    "sklearn",
    "tabulate",
    "tensorflow",
    "tensorrt",
    "torch2trt",
    "tqdm",
    "tree",
    "tvm",
    "xarray",
)


def _strip_init_py(s):
    # Turn a package's ".../__init__.py" path into its directory path.
    return re.sub(r"__init__\.py$", "", s)


def _module_dir(m: types.ModuleType):
    return _strip_init_py(m.__file__)


# TODO: Add a decorator for easily adding functions to FUNC_INLINELIST
# after resolving all circular import issues.
FUNC_INLINELIST = {
    "torch._constrain_as_size",
    "torch._constrain_as_value",
}

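# Entries are fully qualified "module.function" names; get_func_inlinelist() below
# resolves each one to its code object. For example, adding a (hypothetical) entry
# "torch.foo.bar" would force inlining of torch.foo.bar only, while the rest of
# torch/foo.py stays skipped.
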
# Force inline functions in these files or directories, even if they are in *_SKIPLIST.
# We use python module names instead of file or directory objects to avoid circular dependencies.
# Please keep this sorted alphabetically.
# TODO: Merge FILE_INLINELIST into SUBMODULE_INLINELIST.
FILE_INLINELIST = {
    "torch._dynamo._trace_wrapped_higher_order_op",
    "torch._dynamo.comptime",
    "torch._dynamo.external_utils",
    "torch._dynamo.polyfill",
    "torch._export.db.examples",
    "torch._export.wrappers",
    "torch._functorch.apis",
    "torch._functorch.deprecated",
    "torch._higher_order_ops.cond",
    "torch._inductor.test_operators",
    "torch.ao.quantization.pt2e.eval_utils",
    "torch.ao.quantization.pt2e.qat_utils",
    "torch.ao.quantization.pt2e.representation.rewrite",
    "torch.ao.quantization.pt2e.utils",
    "torch.ao.quantization.quantizer.xnnpack_quantizer",
    "torch.nn.modules.container",
    "torch.optim._functional",
    "torch.random",
    "torch.utils._content_store",
    "torch.utils._foreach_utils",
}


if torch.distributed.is_available():
    FILE_INLINELIST |= {
        "torch.distributed._tensor.api",
        "torch.distributed._tensor.device_mesh",
        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
        "torch.distributed.tensor.parallel._data_parallel_utils",
        "torch.distributed.tensor.parallel._utils",
        "torch.distributed.tensor.parallel.style",
    }

# Include optimizer code for tracing
FILE_INLINELIST |= {
    str(obj.__module__) for obj in torch.optim.__dict__.values() if inspect.isclass(obj)
}

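# The comprehension above collects the defining module of every class exported from
# torch.optim; e.g. torch.optim.Adam contributes "torch.optim.adam" (illustrative,
# the exact set depends on the installed torch version).
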
# TODO: consolidate SUBMODULE_INLINELIST and FILE_INLINELIST into one list
# Force inline functions under these modules, even if the modules are in *_SKIPLIST.
SUBMODULE_INLINELIST = {
    "torch._refs",
    "torch._prims",
    "torch._decomp",
    "torch.ao.nn",
    "torch.distributions",
    "torch.fx._pytree",
    "torch.nn",
    "torch.sparse",
    "torch.testing",
    "torch.utils._contextlib",
    "torch.utils._pytree",
}


if torch.distributed.is_available():
    SUBMODULE_INLINELIST.add("torch.distributed._functional_collectives")


# TODO: support adding bound methods to this list
@functools.lru_cache(None)
def get_func_inlinelist():
    inlinelist = set()
    for f in FUNC_INLINELIST:
        module_name, fn_name = f.rsplit(".", 1)
        m = importlib.import_module(module_name)
        fn = getattr(m, fn_name)
        inlinelist.add(fn.__code__)
    return inlinelist


@functools.lru_cache(None)
def get_file_inlinelist():
    inlinelist = set()
    for f in FILE_INLINELIST:
        inlinelist.add(_module_dir(torch) + f[len("torch.") :].replace(".", "/"))
    return inlinelist

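# Each entry is converted from a module name to a filename prefix; for example
# (assuming torch is installed at /site-packages/torch/), "torch._dynamo.comptime"
# becomes "/site-packages/torch/_dynamo/comptime", which matches
# ".../_dynamo/comptime.py" via the startswith checks below.
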
@functools.lru_cache(None)
def get_submodule_inlinelist():
    inlinelist = set()
    for m in SUBMODULE_INLINELIST:
        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
    return inlinelist


# Skip some standard python builtin libs.
SKIP_DIRS = [
    "<frozen importlib",
    "<__array_function__ internals>",
] + [_module_dir(m) for m in BUILTIN_SKIPLIST]

SKIP_DIRS_RE = None

is_fbcode = importlib.import_module("torch._inductor.config").is_fbcode()
# Skip fbcode paths (including torch.package paths) containing
# one of the following strings.
FBCODE_SKIP_DIRS = {
    "torchrec/distributed",
    "torchrec/fb/distributed",
    "caffe2/torch/fb/sparsenn/pooled_embeddings_modules.py",
}
FBCODE_SKIP_DIRS_RE = re.compile(f".*({'|'.join(map(re.escape, FBCODE_SKIP_DIRS))})")


def _recompile_re():
    global SKIP_DIRS_RE
    SKIP_DIRS_RE = re.compile(f"^({'|'.join(map(re.escape, SKIP_DIRS))})")


def add(import_name: str):
    if isinstance(import_name, types.ModuleType):
        return add(import_name.__name__)
    assert isinstance(import_name, str)
    module_spec = importlib.util.find_spec(import_name)
    if not module_spec:
        return
    origin = module_spec.origin
    if origin is None:
        return
    global SKIP_DIRS_RE
    SKIP_DIRS.append(_strip_init_py(origin))
    _recompile_re()

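# Hypothetical usage: add("scipy") locates scipy's install directory via
# importlib.util.find_spec and appends it to SKIP_DIRS, so every frame whose
# filename starts with that directory is skipped.
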
@dataclasses.dataclass
class SkipResult:
    skipped: bool
    reason: Optional[str]


# TODO(ybliang): This is a temp function; we should consolidate it with check_file.
def _check_file_inner(filename, allow_torch=False):
    """Should skip this file?"""
    if filename is None:
        return SkipResult(True, "filename is None")
    if any(filename.startswith(d) for d in get_file_inlinelist()):
        return SkipResult(
            False,
            "inlined according to skipfiles.FILE_INLINELIST",
        )
    # TODO(ybliang): the is_torch check should be consolidated with is_torch_inline_allowed
    if allow_torch and is_torch(filename):
        return SkipResult(
            False,
            "inlined according to skipfiles.is_torch",
        )
    if is_fbcode and bool(FBCODE_SKIP_DIRS_RE.match(filename)):
        return SkipResult(
            True,
            "skipped according to skipfiles.FBCODE_SKIP_DIRS",
        )
    if bool(SKIP_DIRS_RE.match(filename)):
        return SkipResult(True, "skipped according to skipfiles.SKIP_DIRS")
    else:
        return SkipResult(False, "inlined by default")


def check_file(filename, allow_torch=False, extra_check=False):
    result = _check_file_inner(filename, allow_torch)
    if extra_check and result.skipped and is_torch_inline_allowed(filename):
        return SkipResult(
            False,
            "inlined according to skipfiles.is_torch_inline_allowed returning True",
        )
    else:
        return result


"""
|
|
This is the main entry point to determine whether an object (function) should be inlined or skipped.
|
|
Let's illustrate the logic with an example:
|
|
@torch.compile
|
|
def f1(x, y):
|
|
......
|
|
f2(x, y)
|
|
......
|
|
|
|
def f2(x, y):
|
|
......
|
|
f3(x, y)
|
|
......
|
|
|
|
def f3(x, y):
|
|
......
|
|
|
|
There are mainly three call sites of check/check_verbose:
|
|
* The compile region entrance (like function f1), the correspoinding code is located at eval_frame.py.
|
|
* When tracing the recursively called functions (like function f2 and f3).
|
|
* Dynamo decides inline/skip everytime it encounters a new recursively function call, and the call site
|
|
is in InliningInstructionTranslator.check_inlineable of symbolic_convert.py.
|
|
* If f2 is skipped by Dynamo, when evaluating the frame of f3, Dynamo need the inline/skip check again
|
|
and the call site is in catch_errors_wrapper.catch_errors of eval_frame.py.
|
|
* For global variables and function arguments, Dynamo needs to decide if they are wrapped as SkipFilesVariable in builder.py.
|
|
"""
|
|
|
|
|
|
def check_verbose(obj, allow_torch=False, extra_check=False):
    if isinstance(
        obj, (UserFunctionVariable, UserMethodVariable, NestedUserFunctionVariable)
    ):
        filename = obj.get_filename()
        obj = obj.get_code()
    elif isinstance(obj, types.CodeType):
        filename = obj.co_filename
    elif isinstance(obj, (types.FunctionType, types.MethodType)):
        filename = getfile(obj)
        obj = obj.__code__
    else:
        filename = getfile(obj)
    if obj in get_func_inlinelist():
        return SkipResult(
            False,
            "inlined according to skipfiles.FUNC_INLINELIST",
        )
    return check_file(filename, allow_torch, extra_check)


def check(obj, allow_torch=False, extra_check=False):
    return check_verbose(obj, allow_torch, extra_check).skipped

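# Hypothetical usage: check(some_fn) returns True when some_fn's frame should be
# skipped, while check_verbose(some_fn) also carries the matching rule in
# SkipResult.reason (e.g. "skipped according to skipfiles.SKIP_DIRS").
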
# Skip common third party libs.
for _name in THIRDPARTY_SKIPLIST:
    add(_name)

_recompile_re()


def is_torch_inline_allowed(filename):
    return any(filename.startswith(d) for d in get_submodule_inlinelist())


@functools.lru_cache(None)
def dynamo_dir():
    import torch._dynamo

    return _module_dir(torch._dynamo)


def is_torch(filename):
    if filename.startswith(dynamo_dir()):
        return False
    return filename.startswith(_module_dir(torch))