Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[2/N] Dynamo supports skip by function & removes skipfiles circular import (#110835)
Several improvements for skipfiles:
* Add `FUNC_INLINELIST` to support function-level skip/inline checks.
* Match functions by `fn.__code__`, since the function object itself is sometimes unavailable.
* Use Python module string names for `FILE_INLINELIST` and `SUBMODULE_INLINELIST`.
* Match files and Python modules by filename, which fundamentally resolves the circular import issues introduced by skipfiles.
* Use `TYPE_CHECKING` to ensure the Python module string names are correct.
* Add unit tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110835
Approved by: https://github.com/ezyang
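The core mechanism behind the function-level check: FUNC_INLINELIST stores fully qualified function names as strings, and they are resolved to code objects lazily, so a frame can be matched by its __code__ even when the function object itself is not reachable at check time (and without importing anything at module-load time, which is what caused the circular imports). A minimal sketch of that resolution step, mirroring the get_func_inlinelist helper added in this diff; the list contents shown are just the two entries visible below:

    import importlib

    FUNC_INLINELIST = {
        "torch._constrain_as_size",
        "torch._constrain_as_value",
    }

    def get_func_inlinelist():
        # Resolve "module.path.fn_name" strings into code objects so a frame's
        # f_code can be compared against the inlinelist lazily, instead of
        # importing the target functions when skipfiles is first imported.
        inlinelist = set()
        for f in FUNC_INLINELIST:
            module_name, fn_name = f.rsplit(".", 1)
            m = importlib.import_module(module_name)
            inlinelist.add(getattr(m, fn_name).__code__)
        return inlinelist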
Committed by: PyTorch MergeBot
Parent: a6b452dfdc
Commit: 986ad3bfa6

New file: test/dynamo/test_allow_inline_skip.py (101 lines)
test/dynamo/test_allow_inline_skip.py
@@ -0,0 +1,101 @@
# Owner(s): ["module: dynamo"]
import importlib
import types
import unittest

import torch
import torch._dynamo.test_case
from torch._dynamo.skipfiles import (
    FILE_INLINELIST,
    FUNC_INLINELIST,
    SUBMODULE_INLINELIST,
)
from torch._dynamo.utils import istype

try:
    from .utils import create_dummy_module_and_function
except ImportError:
    from utils import create_dummy_module_and_function


def gen_get_func_inlinelist(dummy_func_inlinelist):
    def get_func_inlinelist():
        inlinelist = set()
        for f in dummy_func_inlinelist:
            module_name, fn_name = f.rsplit(".", 1)
            m = importlib.import_module(module_name)
            fn = getattr(m, fn_name)
            inlinelist.add(fn.__code__)
        return inlinelist

    return get_func_inlinelist


class AllowInlineSkipTests(torch._dynamo.test_case.TestCase):
    # We are using python function and module string names for these inlinelist,
    # this unit test is to make sure the functions/modules can be correctly imported
    # or loaded in case there is typo in the strings.
    def test_skipfiles_inlinelist_correctness(self):
        for m in FILE_INLINELIST.union(SUBMODULE_INLINELIST):
            self.assertTrue(isinstance(importlib.import_module(m), types.ModuleType))
        for f in FUNC_INLINELIST:
            module_name, fn_name = f.rsplit(".", 1)
            m = importlib.import_module(module_name)
            self.assertTrue(isinstance(getattr(m, fn_name), types.FunctionType))

    def test_func_inlinelist_torch_function(self):
        def fn(x):
            if istype(x, torch.Tensor):
                return x + 1
            else:
                return x - 1

        func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
        func_inlinelist.add("torch._dynamo.utils.istype")

        self.assertTrue(
            "torch._dynamo.utils" not in torch._dynamo.skipfiles.FILE_INLINELIST
        )
        self.assertTrue(
            "torch._dynamo" not in torch._dynamo.skipfiles.SUBMODULE_INLINELIST
        )

        with unittest.mock.patch(
            "torch._dynamo.skipfiles.get_func_inlinelist",
            gen_get_func_inlinelist(func_inlinelist),
        ):
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_func_inlinelist_third_party_function(self):
        mod, func = create_dummy_module_and_function()

        def fn(x):
            return func(x)

        func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
        func_inlinelist.add(f"{mod.__name__}.{func.__name__}")

        with unittest.mock.patch(
            "torch._dynamo.skipfiles.get_func_inlinelist",
            gen_get_func_inlinelist(func_inlinelist),
        ), unittest.mock.patch(
            "torch._dynamo.skipfiles.SKIP_DIRS",
            torch._dynamo.skipfiles.SKIP_DIRS.copy(),
        ):
            # First adding the module to SKIP_DIRS so that it will be skipped.
            torch._dynamo.skipfiles.add(mod.__name__)
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
test/dynamo/utils.py
@@ -1,4 +1,8 @@
 # Owner(s): ["module: dynamo"]
+import importlib
+import os
+import sys
 import types

 import torch
 import torch._dynamo
@@ -20,3 +24,27 @@ def outer_func(func):
         return torch.sin(a + 1), inner_func()

     return wrapped
+
+
+# Create a dummy python module and function to test skipfiles rules.
+module_code = """
+def add(x):
+    return x + 1
+"""
+
+
+def add(x):
+    return x + 1
+
+
+def create_dummy_module_and_function():
+    module = types.ModuleType("dummy_module")
+    module.__spec__ = importlib.machinery.ModuleSpec(
+        "dummy_module", None, origin=os.path.abspath(__file__)
+    )
+    exec(module_code, module.__dict__)
+    sys.modules["dummy_module"] = module
+    # Need to override the original function since its __code__.co_filename is not a regular python file name,
+    # and the skipfiles rules use filename when checking SKIP_DIRS.
+    module.add = add
+    return module, module.add
torch/_dynamo/eval_frame.py
@@ -182,7 +182,7 @@ class OptimizedModule(torch.nn.Module):
     def _initialize(self):
         # Do this stuff in constructor to lower overhead slightly
         if isinstance(self._orig_mod.forward, types.MethodType) and skipfiles.check(
-            inspect.getsourcefile(self._orig_mod.forward)
+            self._orig_mod.forward
         ):
             # This may be a torch.nn.* instance in skipfiles.py which
             # won't trigger a frame evaluation workaround to add an extra
@@ -362,7 +362,7 @@ class _TorchDynamoContext:
         except TypeError:
             filename = None
         if (
-            (filename is None or skipfiles.check(filename))
+            (filename is None or skipfiles.check(fn))
             and (
                 getattr(fn, "__name__", "") not in ["_call_impl", "_wrapped_call_impl"]
             )
@@ -519,7 +519,7 @@ def catch_errors_wrapper(callback, hooks: Hooks):
        if (
            # TODO: the first condition is not covered by any test
            frame.f_lasti >= first_real_inst_idx(frame.f_code)
-           or skipfiles.check(frame.f_code.co_filename)
+           or skipfiles.check(frame.f_code)
            or config.disable
        ):
            log.debug("skipping %s %s", frame.f_code.co_name, frame.f_code.co_filename)
@@ -1218,7 +1218,7 @@ def export(
        if (
            (shape_env := getattr(fake_mode, "shape_env", None)) is not None
            and (dim_constraints := shape_env.dim_constraints) is not None
-           and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
+           and not skipfiles.check(call_to_inspect)
        ):
            dim_constraints.solve()
            dim_constraints.remove_redundant_dynamic_results()
torch/_dynamo/skipfiles.py
@@ -8,7 +8,6 @@ import copyreg
 import dataclasses
 import enum
 import functools
-import glob
 import importlib
 import inspect
 import linecache
@@ -35,8 +34,14 @@ import torch
 import torch._inductor.test_operators
 import torch.distributed
 import torch.utils._content_store
+from .utils import getfile
+
+from .variables.functions import (
+    NestedUserFunctionVariable,
+    UserFunctionVariable,
+    UserMethodVariable,
+)

 from . import comptime, external_utils, polyfill

 """
 A note on skipfiles:
@@ -59,10 +64,10 @@ Dynamo skip/inline rules & priorities are defined as follows:
 * BUILTIN_SKIPLIST contains builtin python modules, such as abc, collections, etc.
 * THIRDPARTY_SKIPLIST contains common third party libraries, such as numpy, pandas, etc.
 * Functions in these two SKIPLISTs are always skipped, except when they are explicitly
-  put into the two INLINELIST: FILENAME_INLINELIST and SUBMODULE_INLINELIST.
+  put into the three INLINELIST: FUNC_INLINELIST, FILE_INLINELIST and SUBMODULE_INLINELIST.
 * PyTorch(torch) is in the BUILTIN_SKIPLIST by default, but there are many cases
   where we want inline the functions under torch namespace. We should add them
-  into FILENAME_INLINELIST or SUBMODULE_INLINELIST to make dynamo inline those functions.
+  into one of the three *_INLINELIST to make dynamo inline those functions.
 * If you call functions under skipped modules/files, Dynamo will wrap these functions
   as SkipFilesVariable. There are a few functions(e.g, collections.OrderedDict) that
   we have special handling at SkipFilesVariable.call_function.
@@ -70,11 +75,18 @@ Dynamo skip/inline rules & priorities are defined as follows:
 Overall: *_INLINELIST has precedence over *_SKIPLIST has precedence over DEFAULT (inline)

 To figure out what the behavior is, check the following list in order:
-* FILENAME_INLINELIST (Inline if YES)
+* FUNC_INLINELIST (Inline if YES)
+* FILE_INLINELIST (Inline if YES)
 * SUBMODULE_INLINELIST (Inline if YES)
 * BUILTIN_SKIPLIST & THIRDPARTY_SKIPLIST (Skip if YES)
 * Inline by default

+In general, if you want to force inline a function or module, please consider adding
+the function's file or python module to FILE_INLINELIST first.
+Use the FUNC_INLINELIST only when there are other functions under the same file that
+you don't want to inline.
+In the future, we will consolidate FILE_INLINELIST and SUBMODULE_INLINELIST into one list
+as we use the same logic (filename.startswith) to determine if a file or module is skipped.
 """


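Taken together, the priority order documented in the note above amounts to a small predicate. The following is a rough illustrative sketch only, not the actual Dynamo implementation (which goes through check/check_verbose and a compiled SKIP_DIRS regex); it uses the helpers this file defines later in the diff:

    from torch._dynamo import skipfiles

    def should_inline(code_obj, filename):
        # Documented priority: FUNC_INLINELIST > FILE_INLINELIST >
        # SUBMODULE_INLINELIST > BUILTIN/THIRDPARTY skiplists > inline by default.
        if code_obj in skipfiles.get_func_inlinelist():
            return True
        if any(filename.startswith(d) for d in skipfiles.get_file_inlinelist()):
            return True
        if any(filename.startswith(d) for d in skipfiles.get_submodule_inlinelist()):
            return True
        if any(filename.startswith(d) for d in skipfiles.SKIP_DIRS):
            return False
        return True  # inline by default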
@@ -102,7 +114,7 @@ BUILTIN_SKIPLIST = (
     tempfile,
     threading,
     tokenize,
-    torch,  # torch/* is skipped by default unless specified in FILENAME_INLINELIST or SUBMODULE_INLINELIST
+    torch,  # torch/* is skipped by default unless specified in FILE_INLINELIST or SUBMODULE_INLINELIST
     traceback,
     types,
     typing,
@@ -145,74 +157,107 @@ def _module_dir(m: types.ModuleType):
     return _strip_init_py(m.__file__)


-# TODO(ybliang): Change to user *.__file__ rather than hard code string for this list.
-# Force inline functions in these files, even the files is in *_SKIPLIST.
-FILENAME_INLINELIST = {
-    torch.nn.Sequential.__init__.__code__.co_filename,
-    torch.set_rng_state.__code__.co_filename,
-    torch._inductor.test_operators.__file__,
-    torch.utils._content_store.__file__,
-    external_utils.__file__,
-    comptime.__file__,
-    polyfill.__file__,
-    torch.optim._functional.__file__,
-    torch.utils._foreach_utils.__file__,
-    _module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
-    _module_dir(torch) + "ao/quantization/quantizer/xnnpack_quantizer.py",
-    _module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
-    _module_dir(torch) + "ao/quantization/pt2e/utils.py",
-    _module_dir(torch) + "ao/quantization/pt2e/eval_utils.py",
-    _module_dir(torch) + "_dynamo/_trace_wrapped_higher_order_op.py",
-    _module_dir(torch) + "_export/constraints.py",
-    _module_dir(torch) + "_higher_order_ops/cond.py",
-    _module_dir(torch) + "_functorch/apis.py",
-    _module_dir(torch) + "_functorch/deprecated.py",
-    _module_dir(torch) + "distributed/tensor/parallel/_utils.py",
-    _module_dir(torch) + "distributed/tensor/parallel/style.py",
-    _module_dir(torch) + "distributed/tensor/parallel/_data_parallel_utils.py",
-    _module_dir(torch) + "distributed/_tensor/api.py",
-    _module_dir(torch) + "distributed/_tensor/device_mesh.py",
+# TODO: Add a decoractor for easily adding functions to FUNC_INLINELIST
+# after resolving all circular import issues.
+FUNC_INLINELIST = {
+    "torch._constrain_as_size",
+    "torch._constrain_as_value",
 }

-if torch.distributed.is_available():
-    # Inline the checkpoint code from distributed
-    import torch.distributed.algorithms._checkpoint.checkpoint_wrapper
-
-    FILENAME_INLINELIST |= {
-        torch.distributed.algorithms._checkpoint.checkpoint_wrapper.__file__
+# Force inline functions in these files or directories, even they are in *_SKIPLIST.
+# We are using python module name instead of file or directory object to avoid circular dependency.
+# Please keep this sorted alphabetically.
+# TODO: Merge FILE_INLINELIST into SUBMODULE_INLINELIST.
+FILE_INLINELIST = {
+    "torch._dynamo._trace_wrapped_higher_order_op",
+    "torch._dynamo.comptime",
+    "torch._dynamo.external_utils",
+    "torch._dynamo.polyfill",
+    "torch._export.db.examples",
+    "torch._export.wrappers",
+    "torch._functorch.apis",
+    "torch._functorch.deprecated",
+    "torch._higher_order_ops.cond",
+    "torch._inductor.test_operators",
+    "torch.ao.quantization.pt2e.eval_utils",
+    "torch.ao.quantization.pt2e.qat_utils",
+    "torch.ao.quantization.pt2e.representation.rewrite",
+    "torch.ao.quantization.pt2e.utils",
+    "torch.ao.quantization.quantizer.xnnpack_quantizer",
+    "torch.nn.modules.container",
+    "torch.optim._functional",
+    "torch.random",
+    "torch.utils._content_store",
+    "torch.utils._foreach_utils",
+}
+
+
+if torch.distributed.is_available():
+    FILE_INLINELIST |= {
+        "torch.distributed._tensor.api",
+        "torch.distributed._tensor.device_mesh",
+        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
+        "torch.distributed.tensor.parallel._data_parallel_utils",
+        "torch.distributed.tensor.parallel._utils",
+        "torch.distributed.tensor.parallel.style",
     }

 # Include optimizer code for tracing
-FILENAME_INLINELIST |= {
-    inspect.getfile(obj)
-    for obj in torch.optim.__dict__.values()
-    if inspect.isclass(obj)
-}
-
-# TODO (zhxchen17) Make exportdb importable here.
-FILENAME_INLINELIST |= set(
-    glob.glob(_module_dir(torch) + "_export/db/examples/*.py"),
-) | {
-    _module_dir(torch) + "_export/wrappers.py",
+FILE_INLINELIST |= {
+    str(obj.__module__) for obj in torch.optim.__dict__.values() if inspect.isclass(obj)
 }


+# TODO: consolidate SUBMODULE_INLINELIST and FILE_INLINELIST into one list
 # Force inline functions under these modules, even the modules is in *_SKIPLIST.
 SUBMODULE_INLINELIST = {
-    torch.nn,
-    torch.distributions,
-    torch.testing,
-    torch.ao.nn,
-    torch._refs,
-    torch._prims,
-    torch._decomp,
-    torch.utils._contextlib,
-    torch.utils._pytree,
-    torch.fx._pytree,
-    torch.sparse,
+    "torch._refs",
+    "torch._prims",
+    "torch._decomp",
+    "torch.ao.nn",
+    "torch.distributions",
+    "torch.fx._pytree",
+    "torch.nn",
+    "torch.sparse",
+    "torch.testing",
+    "torch.utils._contextlib",
+    "torch.utils._pytree",
 }


+if torch.distributed.is_available():
+    SUBMODULE_INLINELIST.add("torch.distributed._functional_collectives")
+
+
+# TODO: support adding bound method into this list
+@functools.lru_cache(None)
+def get_func_inlinelist():
+    inlinelist = set()
+    for f in FUNC_INLINELIST:
+        module_name, fn_name = f.rsplit(".", 1)
+        m = importlib.import_module(module_name)
+        fn = getattr(m, fn_name)
+        inlinelist.add(fn.__code__)
+    return inlinelist
+
+
+@functools.lru_cache(None)
+def get_file_inlinelist():
+    inlinelist = set()
+    for f in FILE_INLINELIST:
+        inlinelist.add(_module_dir(torch) + f[len("torch.") :].replace(".", "/"))
+    return inlinelist
+
+
+@functools.lru_cache(None)
+def get_submodule_inlinelist():
+    inlinelist = set()
+    for m in SUBMODULE_INLINELIST:
+        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
+    return inlinelist
+
+
 # skip some standard python builtin libs
 SKIP_DIRS = [
     "<frozen importlib",
@@ -258,15 +303,15 @@ class SkipResult:
     reason: Optional[str]


-# TODO(ybliang): This is a temp function, we should consolidate this with check_verbose.
-def _check_verbose_inner(filename, allow_torch=False):
+# TODO(ybliang): This is a temp function, we should consolidate this with check_file.
+def _check_file_inner(filename, allow_torch=False):
     """Should skip this file?"""
     if filename is None:
         return SkipResult(True, "filename is None")
-    if filename in FILENAME_INLINELIST:
+    if any(filename.startswith(d) for d in get_file_inlinelist()):
         return SkipResult(
             False,
-            "inlined according skipfiles.FILENAME_INLINELIST",
+            "inlined according skipfiles.FILE_INLINELIST",
         )
     # TODO(ybliang): the is_torch check should be consolidate with is_torch_inline_allowed
     if allow_torch and is_torch(filename):
@@ -285,8 +330,8 @@ def _check_verbose_inner(filename, allow_torch=False):
     return SkipResult(False, "inlined by default")


-def check_verbose(filename, allow_torch=False, extra_check=False):
-    result = _check_verbose_inner(filename, allow_torch)
+def check_file(filename, allow_torch=False, extra_check=False):
+    result = _check_file_inner(filename, allow_torch)
     if extra_check and result.skipped and is_torch_inline_allowed(filename):
         return SkipResult(
             False,
@@ -296,8 +341,57 @@ def check_verbose(filename, allow_torch=False, extra_check=False):
     return result


-def check(filename, allow_torch=False, extra_check=False):
-    return check_verbose(filename, allow_torch, extra_check).skipped
+"""
+This is the main entry point to determine whether an object (function) should be inlined or skipped.
+Let's illustrate the logic with an example:
+    @torch.compile
+    def f1(x, y):
+        ......
+        f2(x, y)
+        ......
+
+    def f2(x, y):
+        ......
+        f3(x, y)
+        ......
+
+    def f3(x, y):
+        ......
+
+There are mainly three call sites of check/check_verbose:
+* The compile region entrance (like function f1), the correspoinding code is located at eval_frame.py.
+* When tracing the recursively called functions (like function f2 and f3).
+    * Dynamo decides inline/skip everytime it encounters a new recursively function call, and the call site
+      is in InliningInstructionTranslator.check_inlineable of symbolic_convert.py.
+    * If f2 is skipped by Dynamo, when evaluating the frame of f3, Dynamo need the inline/skip check again
+      and the call site is in catch_errors_wrapper.catch_errors of eval_frame.py.
+* For global variables and function arguments, Dynamo needs to decide if they are wrapped as SkipFilesVariable in builder.py.
+"""
+
+
+def check_verbose(obj, allow_torch=False, extra_check=False):
+    if isinstance(
+        obj, (UserFunctionVariable, UserMethodVariable, NestedUserFunctionVariable)
+    ):
+        filename = obj.get_filename()
+        obj = obj.get_code()
+    elif isinstance(obj, types.CodeType):
+        filename = obj.co_filename
+    elif isinstance(obj, (types.FunctionType, types.MethodType)):
+        filename = getfile(obj)
+        obj = obj.__code__
+    else:
+        filename = getfile(obj)
+    if obj in get_func_inlinelist():
+        return SkipResult(
+            False,
+            "inlined according skipfiles.FUNC_INLINELIST",
+        )
+    return check_file(filename, allow_torch, extra_check)
+
+
+def check(obj, allow_torch=False, extra_check=False):
+    return check_verbose(obj, allow_torch, extra_check).skipped


 # skip common third party libs
@@ -308,12 +402,7 @@ _recompile_re()


 def is_torch_inline_allowed(filename):
-    if torch.distributed.is_available():
-        from torch.distributed import _functional_collectives
-
-        SUBMODULE_INLINELIST.add(_functional_collectives)
-
-    return any(filename.startswith(_module_dir(mod)) for mod in SUBMODULE_INLINELIST)
+    return any(filename.startswith(d) for d in get_submodule_inlinelist())


 @functools.lru_cache(None)
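For reference, after this change the entry points take the object itself rather than a source filename. A rough usage sketch at this point in the tree (illustrative only; my_fn is a made-up user function):

    from torch._dynamo import skipfiles

    def my_fn(x):  # hypothetical user function, for illustration only
        return x + 1

    # check() now receives the function or code object rather than a filename.
    print(skipfiles.check(my_fn))                 # True -> skip, False -> inline
    print(skipfiles.check(my_fn.__code__))        # code objects are matched against FUNC_INLINELIST
    print(skipfiles.check_verbose(my_fn).reason)  # SkipResult also carries a human-readable reason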
torch/_dynamo/symbolic_convert.py
@@ -2247,7 +2247,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
        except NotImplementedError:
            pass  # closures

-       result = skipfiles.check_verbose(func.get_filename(), extra_check=True)
+       result = skipfiles.check_verbose(func, extra_check=True)
        if result.skipped:
            from torch._dynamo.variables.misc import (
                produce_trampoline_autograd_apply,
|
@ -60,7 +60,6 @@ from ..utils import (
|
||||
clone_input,
|
||||
get_fake_value,
|
||||
get_static_address_type,
|
||||
getfile,
|
||||
global_key_name,
|
||||
is_namedtuple,
|
||||
is_typing,
|
||||
@@ -534,12 +533,12 @@ class VariableBuilder:
            )
        elif (
            istype(value, (type, types.FunctionType))
-           and skipfiles.check(getfile(value), allow_torch=True)
+           and skipfiles.check(value, allow_torch=True)
            and not inspect.getattr_static(value, "_torchdynamo_inline", False)
        ):
            return SkipFilesVariable(
                value,
-               skipfiles.check_verbose(getfile(value), allow_torch=True).reason,
+               skipfiles.check_verbose(value, allow_torch=True).reason,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )