[2/N] Dynamo supports skip by function & removes skipfiles circular import (#110835)

Several improvements for skipfiles:
* Add ```FUNC_INLINELIST``` to support function level skip/inline check.
  * Use ```fn.__code__``` to match function since we can't get the function object sometimes.
* Use python module string name for ```FILE_INLINELIST``` and ```SUBMODULE_INLINELIST```.
  * Use filename to match file and python module, which can fundamentally resolved the circular import issues introduced by skipfiles.
  * Use ```TYPE_CHECKING``` to ensure the python module string name is correct.
* Add unit tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110835
Approved by: https://github.com/ezyang
This commit is contained in:
Yanbo Liang
2023-10-12 00:44:41 +00:00
committed by PyTorch MergeBot
parent a6b452dfdc
commit 986ad3bfa6
6 changed files with 299 additions and 82 deletions

View File

@ -0,0 +1,101 @@
# Owner(s): ["module: dynamo"]
import importlib
import types
import unittest
import torch
import torch._dynamo.test_case
from torch._dynamo.skipfiles import (
FILE_INLINELIST,
FUNC_INLINELIST,
SUBMODULE_INLINELIST,
)
from torch._dynamo.utils import istype
try:
from .utils import create_dummy_module_and_function
except ImportError:
from utils import create_dummy_module_and_function
def gen_get_func_inlinelist(dummy_func_inlinelist):
def get_func_inlinelist():
inlinelist = set()
for f in dummy_func_inlinelist:
module_name, fn_name = f.rsplit(".", 1)
m = importlib.import_module(module_name)
fn = getattr(m, fn_name)
inlinelist.add(fn.__code__)
return inlinelist
return get_func_inlinelist
class AllowInlineSkipTests(torch._dynamo.test_case.TestCase):
# We are using python function and module string names for these inlinelist,
# this unit test is to make sure the functions/modules can be correctly imported
# or loaded in case there is typo in the strings.
def test_skipfiles_inlinelist_correctness(self):
for m in FILE_INLINELIST.union(SUBMODULE_INLINELIST):
self.assertTrue(isinstance(importlib.import_module(m), types.ModuleType))
for f in FUNC_INLINELIST:
module_name, fn_name = f.rsplit(".", 1)
m = importlib.import_module(module_name)
self.assertTrue(isinstance(getattr(m, fn_name), types.FunctionType))
def test_func_inlinelist_torch_function(self):
def fn(x):
if istype(x, torch.Tensor):
return x + 1
else:
return x - 1
func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
func_inlinelist.add("torch._dynamo.utils.istype")
self.assertTrue(
"torch._dynamo.utils" not in torch._dynamo.skipfiles.FILE_INLINELIST
)
self.assertTrue(
"torch._dynamo" not in torch._dynamo.skipfiles.SUBMODULE_INLINELIST
)
with unittest.mock.patch(
"torch._dynamo.skipfiles.get_func_inlinelist",
gen_get_func_inlinelist(func_inlinelist),
):
x = torch.rand(3)
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_func_inlinelist_third_party_function(self):
mod, func = create_dummy_module_and_function()
def fn(x):
return func(x)
func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
func_inlinelist.add(f"{mod.__name__}.{func.__name__}")
with unittest.mock.patch(
"torch._dynamo.skipfiles.get_func_inlinelist",
gen_get_func_inlinelist(func_inlinelist),
), unittest.mock.patch(
"torch._dynamo.skipfiles.SKIP_DIRS",
torch._dynamo.skipfiles.SKIP_DIRS.copy(),
):
# First adding the module to SKIP_DIRS so that it will be skipped.
torch._dynamo.skipfiles.add(mod.__name__)
x = torch.rand(3)
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -1,4 +1,8 @@
# Owner(s): ["module: dynamo"]
import importlib
import os
import sys
import types
import torch
import torch._dynamo
@ -20,3 +24,27 @@ def outer_func(func):
return torch.sin(a + 1), inner_func()
return wrapped
# Create a dummy python module and function to test skipfiles rules.
module_code = """
def add(x):
return x + 1
"""
def add(x):
return x + 1
def create_dummy_module_and_function():
module = types.ModuleType("dummy_module")
module.__spec__ = importlib.machinery.ModuleSpec(
"dummy_module", None, origin=os.path.abspath(__file__)
)
exec(module_code, module.__dict__)
sys.modules["dummy_module"] = module
# Need to override the original function since its __code__.co_filename is not a regular python file name,
# and the skipfiles rules use filename when checking SKIP_DIRS.
module.add = add
return module, module.add

View File

@ -182,7 +182,7 @@ class OptimizedModule(torch.nn.Module):
def _initialize(self):
# Do this stuff in constructor to lower overhead slightly
if isinstance(self._orig_mod.forward, types.MethodType) and skipfiles.check(
inspect.getsourcefile(self._orig_mod.forward)
self._orig_mod.forward
):
# This may be a torch.nn.* instance in skipfiles.py which
# won't trigger a frame evaluation workaround to add an extra
@ -362,7 +362,7 @@ class _TorchDynamoContext:
except TypeError:
filename = None
if (
(filename is None or skipfiles.check(filename))
(filename is None or skipfiles.check(fn))
and (
getattr(fn, "__name__", "") not in ["_call_impl", "_wrapped_call_impl"]
)
@ -519,7 +519,7 @@ def catch_errors_wrapper(callback, hooks: Hooks):
if (
# TODO: the first condition is not covered by any test
frame.f_lasti >= first_real_inst_idx(frame.f_code)
or skipfiles.check(frame.f_code.co_filename)
or skipfiles.check(frame.f_code)
or config.disable
):
log.debug("skipping %s %s", frame.f_code.co_name, frame.f_code.co_filename)
@ -1218,7 +1218,7 @@ def export(
if (
(shape_env := getattr(fake_mode, "shape_env", None)) is not None
and (dim_constraints := shape_env.dim_constraints) is not None
and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
and not skipfiles.check(call_to_inspect)
):
dim_constraints.solve()
dim_constraints.remove_redundant_dynamic_results()

View File

@ -8,7 +8,6 @@ import copyreg
import dataclasses
import enum
import functools
import glob
import importlib
import inspect
import linecache
@ -35,8 +34,14 @@ import torch
import torch._inductor.test_operators
import torch.distributed
import torch.utils._content_store
from .utils import getfile
from .variables.functions import (
NestedUserFunctionVariable,
UserFunctionVariable,
UserMethodVariable,
)
from . import comptime, external_utils, polyfill
"""
A note on skipfiles:
@ -59,10 +64,10 @@ Dynamo skip/inline rules & priorities are defined as follows:
* BUILTIN_SKIPLIST contains builtin python modules, such as abc, collections, etc.
* THIRDPARTY_SKIPLIST contains common third party libraries, such as numpy, pandas, etc.
* Functions in these two SKIPLISTs are always skipped, except when they are explicitly
put into the two INLINELIST: FILENAME_INLINELIST and SUBMODULE_INLINELIST.
put into the three INLINELIST: FUNC_INLINELIST, FILE_INLINELIST and SUBMODULE_INLINELIST.
* PyTorch(torch) is in the BUILTIN_SKIPLIST by default, but there are many cases
where we want inline the functions under torch namespace. We should add them
into FILENAME_INLINELIST or SUBMODULE_INLINELIST to make dynamo inline those functions.
into one of the three *_INLINELIST to make dynamo inline those functions.
* If you call functions under skipped modules/files, Dynamo will wrap these functions
as SkipFilesVariable. There are a few functions(e.g, collections.OrderedDict) that
we have special handling at SkipFilesVariable.call_function.
@ -70,11 +75,18 @@ Dynamo skip/inline rules & priorities are defined as follows:
Overall: *_INLINELIST has precedence over *_SKIPLIST has precedence over DEFAULT (inline)
To figure out what the behavior is, check the following list in order:
* FILENAME_INLINELIST (Inline if YES)
* FUNC_INLINELIST (Inline if YES)
* FILE_INLINELIST (Inline if YES)
* SUBMODULE_INLINELIST (Inline if YES)
* BUILTIN_SKIPLIST & THIRDPARTY_SKIPLIST (Skip if YES)
* Inline by default
In general, if you want to force inline a function or module, please consider adding
the function's file or python module to FILE_INLINELIST first.
Use the FUNC_INLINELIST only when there are other functions under the same file that
you don't want to inline.
In the future, we will consolidate FILE_INLINELIST and SUBMODULE_INLINELIST into one list
as we use the same logic (filename.startswith) to determine if a file or module is skipped.
"""
@ -102,7 +114,7 @@ BUILTIN_SKIPLIST = (
tempfile,
threading,
tokenize,
torch, # torch/* is skipped by default unless specified in FILENAME_INLINELIST or SUBMODULE_INLINELIST
torch, # torch/* is skipped by default unless specified in FILE_INLINELIST or SUBMODULE_INLINELIST
traceback,
types,
typing,
@ -145,74 +157,107 @@ def _module_dir(m: types.ModuleType):
return _strip_init_py(m.__file__)
# TODO(ybliang): Change to user *.__file__ rather than hard code string for this list.
# Force inline functions in these files, even the files is in *_SKIPLIST.
FILENAME_INLINELIST = {
torch.nn.Sequential.__init__.__code__.co_filename,
torch.set_rng_state.__code__.co_filename,
torch._inductor.test_operators.__file__,
torch.utils._content_store.__file__,
external_utils.__file__,
comptime.__file__,
polyfill.__file__,
torch.optim._functional.__file__,
torch.utils._foreach_utils.__file__,
_module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
_module_dir(torch) + "ao/quantization/quantizer/xnnpack_quantizer.py",
_module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
_module_dir(torch) + "ao/quantization/pt2e/utils.py",
_module_dir(torch) + "ao/quantization/pt2e/eval_utils.py",
_module_dir(torch) + "_dynamo/_trace_wrapped_higher_order_op.py",
_module_dir(torch) + "_export/constraints.py",
_module_dir(torch) + "_higher_order_ops/cond.py",
_module_dir(torch) + "_functorch/apis.py",
_module_dir(torch) + "_functorch/deprecated.py",
_module_dir(torch) + "distributed/tensor/parallel/_utils.py",
_module_dir(torch) + "distributed/tensor/parallel/style.py",
_module_dir(torch) + "distributed/tensor/parallel/_data_parallel_utils.py",
_module_dir(torch) + "distributed/_tensor/api.py",
_module_dir(torch) + "distributed/_tensor/device_mesh.py",
# TODO: Add a decoractor for easily adding functions to FUNC_INLINELIST
# after resolving all circular import issues.
FUNC_INLINELIST = {
"torch._constrain_as_size",
"torch._constrain_as_value",
}
if torch.distributed.is_available():
# Inline the checkpoint code from distributed
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper
FILENAME_INLINELIST |= {
torch.distributed.algorithms._checkpoint.checkpoint_wrapper.__file__
# Force inline functions in these files or directories, even they are in *_SKIPLIST.
# We are using python module name instead of file or directory object to avoid circular dependency.
# Please keep this sorted alphabetically.
# TODO: Merge FILE_INLINELIST into SUBMODULE_INLINELIST.
FILE_INLINELIST = {
"torch._dynamo._trace_wrapped_higher_order_op",
"torch._dynamo.comptime",
"torch._dynamo.external_utils",
"torch._dynamo.polyfill",
"torch._export.db.examples",
"torch._export.wrappers",
"torch._functorch.apis",
"torch._functorch.deprecated",
"torch._higher_order_ops.cond",
"torch._inductor.test_operators",
"torch.ao.quantization.pt2e.eval_utils",
"torch.ao.quantization.pt2e.qat_utils",
"torch.ao.quantization.pt2e.representation.rewrite",
"torch.ao.quantization.pt2e.utils",
"torch.ao.quantization.quantizer.xnnpack_quantizer",
"torch.nn.modules.container",
"torch.optim._functional",
"torch.random",
"torch.utils._content_store",
"torch.utils._foreach_utils",
}
if torch.distributed.is_available():
FILE_INLINELIST |= {
"torch.distributed._tensor.api",
"torch.distributed._tensor.device_mesh",
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
"torch.distributed.tensor.parallel._data_parallel_utils",
"torch.distributed.tensor.parallel._utils",
"torch.distributed.tensor.parallel.style",
}
# Include optimizer code for tracing
FILENAME_INLINELIST |= {
inspect.getfile(obj)
for obj in torch.optim.__dict__.values()
if inspect.isclass(obj)
}
# TODO (zhxchen17) Make exportdb importable here.
FILENAME_INLINELIST |= set(
glob.glob(_module_dir(torch) + "_export/db/examples/*.py"),
) | {
_module_dir(torch) + "_export/wrappers.py",
FILE_INLINELIST |= {
str(obj.__module__) for obj in torch.optim.__dict__.values() if inspect.isclass(obj)
}
# TODO: consolidate SUBMODULE_INLINELIST and FILE_INLINELIST into one list
# Force inline functions under these modules, even the modules is in *_SKIPLIST.
SUBMODULE_INLINELIST = {
torch.nn,
torch.distributions,
torch.testing,
torch.ao.nn,
torch._refs,
torch._prims,
torch._decomp,
torch.utils._contextlib,
torch.utils._pytree,
torch.fx._pytree,
torch.sparse,
"torch._refs",
"torch._prims",
"torch._decomp",
"torch.ao.nn",
"torch.distributions",
"torch.fx._pytree",
"torch.nn",
"torch.sparse",
"torch.testing",
"torch.utils._contextlib",
"torch.utils._pytree",
}
if torch.distributed.is_available():
SUBMODULE_INLINELIST.add("torch.distributed._functional_collectives")
# TODO: support adding bound method into this list
@functools.lru_cache(None)
def get_func_inlinelist():
inlinelist = set()
for f in FUNC_INLINELIST:
module_name, fn_name = f.rsplit(".", 1)
m = importlib.import_module(module_name)
fn = getattr(m, fn_name)
inlinelist.add(fn.__code__)
return inlinelist
@functools.lru_cache(None)
def get_file_inlinelist():
inlinelist = set()
for f in FILE_INLINELIST:
inlinelist.add(_module_dir(torch) + f[len("torch.") :].replace(".", "/"))
return inlinelist
@functools.lru_cache(None)
def get_submodule_inlinelist():
inlinelist = set()
for m in SUBMODULE_INLINELIST:
inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
return inlinelist
# skip some standard python builtin libs
SKIP_DIRS = [
"<frozen importlib",
@ -258,15 +303,15 @@ class SkipResult:
reason: Optional[str]
# TODO(ybliang): This is a temp function, we should consolidate this with check_verbose.
def _check_verbose_inner(filename, allow_torch=False):
# TODO(ybliang): This is a temp function, we should consolidate this with check_file.
def _check_file_inner(filename, allow_torch=False):
"""Should skip this file?"""
if filename is None:
return SkipResult(True, "filename is None")
if filename in FILENAME_INLINELIST:
if any(filename.startswith(d) for d in get_file_inlinelist()):
return SkipResult(
False,
"inlined according skipfiles.FILENAME_INLINELIST",
"inlined according skipfiles.FILE_INLINELIST",
)
# TODO(ybliang): the is_torch check should be consolidate with is_torch_inline_allowed
if allow_torch and is_torch(filename):
@ -285,8 +330,8 @@ def _check_verbose_inner(filename, allow_torch=False):
return SkipResult(False, "inlined by default")
def check_verbose(filename, allow_torch=False, extra_check=False):
result = _check_verbose_inner(filename, allow_torch)
def check_file(filename, allow_torch=False, extra_check=False):
result = _check_file_inner(filename, allow_torch)
if extra_check and result.skipped and is_torch_inline_allowed(filename):
return SkipResult(
False,
@ -296,8 +341,57 @@ def check_verbose(filename, allow_torch=False, extra_check=False):
return result
def check(filename, allow_torch=False, extra_check=False):
return check_verbose(filename, allow_torch, extra_check).skipped
"""
This is the main entry point to determine whether an object (function) should be inlined or skipped.
Let's illustrate the logic with an example:
@torch.compile
def f1(x, y):
......
f2(x, y)
......
def f2(x, y):
......
f3(x, y)
......
def f3(x, y):
......
There are mainly three call sites of check/check_verbose:
* The compile region entrance (like function f1), the correspoinding code is located at eval_frame.py.
* When tracing the recursively called functions (like function f2 and f3).
* Dynamo decides inline/skip everytime it encounters a new recursively function call, and the call site
is in InliningInstructionTranslator.check_inlineable of symbolic_convert.py.
* If f2 is skipped by Dynamo, when evaluating the frame of f3, Dynamo need the inline/skip check again
and the call site is in catch_errors_wrapper.catch_errors of eval_frame.py.
* For global variables and function arguments, Dynamo needs to decide if they are wrapped as SkipFilesVariable in builder.py.
"""
def check_verbose(obj, allow_torch=False, extra_check=False):
if isinstance(
obj, (UserFunctionVariable, UserMethodVariable, NestedUserFunctionVariable)
):
filename = obj.get_filename()
obj = obj.get_code()
elif isinstance(obj, types.CodeType):
filename = obj.co_filename
elif isinstance(obj, (types.FunctionType, types.MethodType)):
filename = getfile(obj)
obj = obj.__code__
else:
filename = getfile(obj)
if obj in get_func_inlinelist():
return SkipResult(
False,
"inlined according skipfiles.FUNC_INLINELIST",
)
return check_file(filename, allow_torch, extra_check)
def check(obj, allow_torch=False, extra_check=False):
return check_verbose(obj, allow_torch, extra_check).skipped
# skip common third party libs
@ -308,12 +402,7 @@ _recompile_re()
def is_torch_inline_allowed(filename):
if torch.distributed.is_available():
from torch.distributed import _functional_collectives
SUBMODULE_INLINELIST.add(_functional_collectives)
return any(filename.startswith(_module_dir(mod)) for mod in SUBMODULE_INLINELIST)
return any(filename.startswith(d) for d in get_submodule_inlinelist())
@functools.lru_cache(None)

View File

@ -2247,7 +2247,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
except NotImplementedError:
pass # closures
result = skipfiles.check_verbose(func.get_filename(), extra_check=True)
result = skipfiles.check_verbose(func, extra_check=True)
if result.skipped:
from torch._dynamo.variables.misc import (
produce_trampoline_autograd_apply,

View File

@ -60,7 +60,6 @@ from ..utils import (
clone_input,
get_fake_value,
get_static_address_type,
getfile,
global_key_name,
is_namedtuple,
is_typing,
@ -534,12 +533,12 @@ class VariableBuilder:
)
elif (
istype(value, (type, types.FunctionType))
and skipfiles.check(getfile(value), allow_torch=True)
and skipfiles.check(value, allow_torch=True)
and not inspect.getattr_static(value, "_torchdynamo_inline", False)
):
return SkipFilesVariable(
value,
skipfiles.check_verbose(getfile(value), allow_torch=True).reason,
skipfiles.check_verbose(value, allow_torch=True).reason,
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)