# Owner(s): ["module: dynamo"]
import dataclasses
import importlib
import inspect
import math
import types
import unittest
import warnings
from typing import Any

import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.deprecated as deprecated_func
from torch._dynamo.trace_rules import (
    LEGACY_MOD_INLINELIST,
    load_object,
    lookup_inner,
    manual_torch_name_rule_map,
    MOD_INLINELIST,
    torch_c_binding_in_graph_functions,
    torch_non_c_binding_in_graph_functions,
)
from torch._dynamo.utils import hashable, is_safe_constant, istype
from torch._dynamo.variables import (
    SkipFunctionVariable,
    TorchInGraphFunctionVariable,
    UserFunctionVariable,
)
from torch.testing._internal.common_utils import skipIfWindows


try:
    from .utils import create_dummy_module_and_function
except ImportError:
    from utils import create_dummy_module_and_function


ignored_c_binding_in_graph_function_names = {
    # Ignored because they have manual rules defined at `trace_rules.manual_torch_name_rule_map`.
    "torch._nested_tensor_from_mask",
    "torch._nested_from_padded",
    "torch.sparse_compressed_tensor",
    "torch.sparse_bsc_tensor",
    "torch.sparse_bsr_tensor",
    "torch.sparse_coo_tensor",
    "torch.sparse_csc_tensor",
    "torch.sparse_csr_tensor",
    "torch.cuda._get_device_properties",
    # Ignored and go through rules defined at `trace_rules.check`.
    "torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode",
    "torch._cslt_sparse_mm_search",
    "torch._C._abort",
    "torch._C._mps_is_on_macos_or_newer",
    "torch._C._swap_tensor_impl",
    "torch._C._unsafe_reset_storage",
    "torch._dynamo.eval_frame.reset_code",
    "torch._C.autocast_decrement_nesting",
    "torch._C.autocast_increment_nesting",
    "torch._C.clear_autocast_cache",
    "torch._C.set_anomaly_enabled",
    "torch._C.set_autocast_cache_enabled",
    "torch._C.set_autocast_cpu_dtype",
    "torch._C.set_autocast_cpu_enabled",
    "torch._C.set_autocast_enabled",
    "torch._C.set_autocast_gpu_dtype",
    "torch._C.set_autocast_ipu_dtype",
    "torch._C.set_autocast_ipu_enabled",
    "torch._C.set_autocast_xla_dtype",
    "torch._C.set_autocast_xla_enabled",
    "torch.resize_as_",
    "torch.resize_as_sparse_",
    "torch._C._data_address",
    "torch._C._is_cow_tensor",
    "torch._lazy_clone",
    "torch._test_parallel_materialize",
    "torch._C._storage_address",
    "torch._C._pickle_save",
    "torch._validate_sparse_compressed_tensor_args",
    "torch._validate_sparse_csr_tensor_args",
    "torch._validate_sparse_bsr_tensor_args",
    "torch._validate_sparse_csc_tensor_args",
    "torch._validate_sparse_coo_tensor_args",
    "torch._validate_sparse_bsc_tensor_args",
    "torch._validate_compressed_sparse_indices",
}
if torch._C._llvm_enabled():
    ignored_c_binding_in_graph_function_names |= {
        "torch._C._te.set_llvm_aot_workflow",
        "torch._C._te.set_llvm_target_cpu",
        "torch._C._te.set_llvm_target_attrs",
        "torch._C._te.set_llvm_target_triple",
    }


# Helper function to dump the torch name rule map generated based on
# the heuristic defined in gen_allowed_objs_and_ids.
def dump_allowed_torch_name_rule_map() -> None:
    m = gen_allowed_objs_and_ids(record=True, c_binding_only=False).name_rule_map
    for k, v in m.items():
        print(f'"{k}": {v.__name__},')


@dataclasses.dataclass
class AllowedObjects:
    """
    Track the objects, object id - name pairs, and name - dynamo wrapping rule pairs
    from the heuristic defined in `gen_allowed_objs_and_ids`.
    """

    object_ids: dict[int, str]
    c_binding_in_graph_functions: set[Any]
    non_c_binding_in_graph_functions: set[Any]
    name_rule_map: dict[str, Any]


def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects:
    """
    Walk torch.* and collect the ids of everything in it.
    """

    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = {}
    c_binding_in_graph_functions = set()
    non_c_binding_in_graph_functions = set()
    torch_name_rule_map = {}

    # On some platforms, these functions are loaded as classes instead of functions.
    # To handle these odd cases, we need this special check.
    def is_special_functions(obj):
        return hashable(obj) and obj in {
            torch._C._cuda_isCurrentStreamCapturing,
            torch._C._graph_pool_handle,
        }

    # Add obj to the c_binding_in_graph_functions set or the non_c_binding_in_graph_functions set
    # if it's a torch function or method.
    # This is used to generate the in-graph function list based on the heuristic.
    def heuristic_record_if_in_graph_function(obj, module, name):
        try:
            if hasattr(obj, "__wrapped__"):
                obj = obj.__wrapped__
        except Exception:
            pass
        if isinstance(
            obj,
            (
                types.FunctionType,
                types.BuiltinFunctionType,
                types.MethodDescriptorType,
                types.WrapperDescriptorType,
            ),
        ) or is_special_functions(obj):
            torch_name_rule_map[f"{module.__name__}.{name}"] = (
                TorchInGraphFunctionVariable
            )
            if c_binding_only:
                if not hasattr(obj, "__code__"):
                    c_binding_in_graph_functions.add(obj)
            else:
                if hasattr(obj, "__code__"):
                    non_c_binding_in_graph_functions.add(obj)
                else:
                    c_binding_in_graph_functions.add(obj)

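    # Return True only when `obj` comes from a torch/math (sub)module that is not
    # on the disallowed prefix list below.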
    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters. This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break. To ensure this, we inline
        # these functions, rather than keep them opaquely in the graph.
        disallowed_modules = [
            "torch.optim.",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch._C._autograd",
            "torch._C._cudart",
            "torch._C._distributed_autograd",
            "torch._C._distributed_c10d",
            "torch._C._distributed_rpc",
            "torch._C._functorch",
            "torch._C._monitor",
            "torch._C._nvtx",
            "torch._C._lazy",
            "torch._C._profiler",
            "torch.__config__",
            "torch._custom_op",
            "torch._decomp",
            "torch._dispatch",
            "torch._export",
            "torch._functorch.make_functional",
            "torch._functorch.compile_utils",
            "torch._functorch.partitioners",
            "torch._functorch.aot_autograd",
            "torch._functorch.compilers",
            "torch._functorch.fx_minifier",
            "torch.autograd.profiler_util",
            "torch.autograd.profiler",
            "torch._jit_internal",
            "torch._library",
            "torch._lobpcg",
            "torch._logging",
            "torch._meta_registrations",
            "torch._namedtensor_internals",
            "torch._numpy",
            "torch._sources",
            "torch._subclasses",
            "torch._tensor",
            "torch._tensor_str",
            "torch._utils",
            "torch._utils_internal",
            "torch._vmap_internals",
            "torch.compiler",
            "torch.distributed",
            "torch.export",
            "torch.hub",
            "torch.jit",
            "torch.library",
            "torch.masked.maskedtensor",
            "torch.nn.init",
            "torch.nn.modules.module",
            "torch.nn.parallel",
            "torch.nn.utils",
            "torch.multiprocessing",
            "torch.onnx",
            "torch.overrides",
            "torch.package",
            "torch.profiler",
            "torch.serialization",
            "torch.storage",
            "torch.utils",
            "torch.distributed.",
        ]

        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__

        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

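    # Recursively walk `module`, recording the ids of allowed submodules and of the
    # objects found in them (optionally recording in-graph functions along the way).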
    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                # Dynamo allows all builtins into the graph and does not attempt
                # to introspect into them. We don't want to allow instances of
                # HigherOrderOperator into the graph all the time (Dynamo needs
                # to introspect the body functions of these HigherOrderOperators
                # first, decide they are safe, and then allow them into the graph).
                # So we exclude HigherOrderOperator from being a builtin.
                import torch._ops

                if isinstance(obj, torch._ops.HigherOrderOperator):
                    continue

                # We want to trace through `grad` and `vmap`
                if obj in (
                    torch.func.grad,
                    deprecated_func.grad,
                    torch.func.vmap,
                    deprecated_func.vmap,
                    torch.nn.functional.triplet_margin_with_distance_loss,
                    torch.cond,
                ):
                    continue

                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    if record:
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    if record:
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    return AllowedObjects(
        torch_object_ids,
        c_binding_in_graph_functions,
        non_c_binding_in_graph_functions,
        torch_name_rule_map,
    )


class TraceRuleTests(torch._dynamo.test_case.TestCase):
    def _check_set_equality(self, generated, used, rule_map, ignored_set):
        x = generated - used
        y = used - generated
        msg1 = (
            f"New torch objects: {x} "
            f"were not added to `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer to the instructions in `torch/_dynamo/trace_rules.py` for more details."
        )
        msg2 = (
            f"Existing torch objects: {y} were removed. "
            f"Please remove them from `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer to the instructions in `torch/_dynamo/trace_rules.py` for more details."
        )
        self.assertTrue(len(x) == 0, msg1)
        self.assertTrue(len(y) == 0, msg2)

    # We are using python function and module string names for these inlinelists;
    # this unit test makes sure the functions/modules can be correctly imported
    # or loaded in case there is a typo in the strings.
    def test_skipfiles_inlinelist(self):
        for m in LEGACY_MOD_INLINELIST.union(MOD_INLINELIST):
            try:
                mod = importlib.import_module(m)
            except ImportError:
                continue
            else:
                self.assertTrue(
                    isinstance(mod, types.ModuleType),
                    f"{m} from trace_rules.MOD_INLINELIST/LEGACY_MOD_INLINELIST "
                    "is not a python module, please check and correct it.",
                )

    @unittest.skip(
        "This test keeps getting broken and our disable infra is not handling it well. See #120627"
    )
    def test_torch_name_rule_map_updated(self):
        # Generate the allowed objects based on the heuristic defined in `allowed_functions.py`.
        objs = gen_allowed_objs_and_ids(record=True, c_binding_only=True)
        # Test that C-binding in-graph functions are updated in torch_name_rule_map.
        generated = objs.c_binding_in_graph_functions
        used = set()
        for x in (
            set(torch_c_binding_in_graph_functions.keys())
            | ignored_c_binding_in_graph_function_names
        ):
            obj = load_object(x)
            if obj is not None:
                used.add(obj)
        self._check_set_equality(
            generated,
            used,
            "torch_c_binding_in_graph_functions",
            "ignored_c_binding_in_graph_function_names",
        )
        # For non-C-binding in-graph functions, we only test that they can be loaded successfully.
        for f in torch_non_c_binding_in_graph_functions:
            self.assertTrue(
                isinstance(
                    load_object(f),
                    (
                        types.FunctionType,
                        types.BuiltinFunctionType,
                        types.MethodDescriptorType,
                        types.WrapperDescriptorType,
                    ),
                )
            )

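    # Overriding the trace rule for `torch._dynamo.utils.istype` (skipped by default)
    # should let Dynamo inline it, so `fn` compiles with fullgraph=True.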
    def test_force_inline_torch_function(self):
        # `torch._dynamo.utils.istype` is skipped by default
        def fn(x):
            if istype(x, torch.Tensor):
                return x + 1
            else:
                return x - 1

        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inline `torch._dynamo.utils.istype` by setting trace rule.
        _manual_torch_name_rule_map["torch._dynamo.utils.istype"] = UserFunctionVariable

        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]

        self.assertTrue(
            "torch._dynamo" not in torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST
        )
        self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST)

        with (
            unittest.mock.patch(
                "torch._dynamo.trace_rules.torch_name_rule_map",
                _torch_name_rule_map,
            ),
            unittest.mock.patch(
                "torch._dynamo.trace_rules.get_torch_obj_rule_map",
                torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,  # bypass functools.lru_cache
            ),
        ):
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

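    # Same idea as above, but for a function from a dynamically created dummy module
    # that is first registered as skipped via trace_rules.add.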
    def test_force_inline_custom_function(self):
        mod, func = create_dummy_module_and_function()

        def fn(x):
            return func(x)

        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inline `mod.func` by setting trace rule.
        _manual_torch_name_rule_map[f"{mod.__name__}.{func.__name__}"] = (
            UserFunctionVariable
        )

        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]

        with (
            unittest.mock.patch(
                "torch._dynamo.trace_rules.torch_name_rule_map",
                _torch_name_rule_map,
            ),
            unittest.mock.patch(
                "torch._dynamo.trace_rules.get_torch_obj_rule_map",
                torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,  # bypass functools.lru_cache
            ),
        ):
            # First, add the module to SKIP_DIRS so that it will be skipped by default.
            torch._dynamo.trace_rules.add(mod.__name__)
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

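    # Any function in `torch_non_c_binding_in_graph_functions` that has a special
    # handler must appear in the audited `safe_handlers` list below.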
    def test_no_special_handlers_for_torch_non_c_bindings(self):
        handlers = TorchInGraphFunctionVariable._get_handlers()
        # These handlers are manually audited to be safe
        safe_handlers = (
            "handle_tracing_state_functions",  # No global state (constant)
            "handle_radians",  # No global state (constant)
            "handle_is_tensor",  # No global state
            "handle_torch_compile",  # No global state, constant
            "handle_ntuple",  # No global state
            "handle_is_grad_enabled",  # Safely implemented
            "handle_use_deterministic_algorithms",  # Guarded variable
            "handle_are_deterministic_algorithms_enabled",  # Guarded constant
            "handle_device_interface_stream",  # No global state
            "handle_cudnn_is_acceptable",  # No global state
            "handle_assert",  # No global state (constant)
            "handle_nested_tensor",  # No global state
        )
        for fn in handlers:
            if isinstance(fn, staticmethod) or inspect.ismethod(fn):
                fn_name = f"{fn.__module__}#{fn.__name__}"
            else:
                fn_name = f"{fn.__module__}.{fn.__name__}"
            if handlers[fn].__name__ in safe_handlers:
                continue
            self.assertFalse(
                fn_name in torch_non_c_binding_in_graph_functions,
                (
                    f"torch function {fn_name} has a special handler {handlers[fn].__name__}.\n"
                    "We expected all functions in `torch_non_c_binding_in_graph_functions` to be safe to cache.\n"
                    "Functions with special handlers may not be safe to cache, since they can close over global state.\n"
                    "If your handler/function is safe to cache, please add it to the list of safe handlers above.\n"
                    "Otherwise, add it to `manual_torch_name_rule_map` instead."
                ),
            )

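    # `lookup_inner` should fall back to SkipFunctionVariable when an object's
    # `__name__` cannot even be read.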
    def test_almost_impossible_missing_name(self):
        class weird:  # noqa: UP004
            def __getattribute__(self, name):
                if name == "__name__":
                    raise AttributeError("test")

        w = weird()
        o = set()
        with self.assertRaises(AttributeError):
            w.__name__
        self.assertEqual(lookup_inner(w, name=None, reasons=o), SkipFunctionVariable)


class TestModuleSurviveSkipFiles(torch._dynamo.test_case.TestCase):
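    # An nn.Module defined in a skipped file (torch.testing._internal.common_fsdp)
    # should still produce compiled frames when `.compile()` is called on it.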
    @unittest.skipIf(
        not torch.distributed.is_available(),
        "need to import MLP module from distributed",
    )
    @skipIfWindows(
        msg="AssertionError: False is not true : MLP did not survive skip files"
    )
    def test_module_survive_skip_files(self):
        from torch.testing._internal.common_fsdp import MLP

        model = MLP(3)
        inp = torch.randn((2, 3))
        frame_count_before = torch._dynamo.convert_frame.FRAME_COUNTER
        model.compile(backend="eager")
        model(inp)
        frame_count_after = torch._dynamo.convert_frame.FRAME_COUNTER
        self.assertTrue(
            frame_count_after > frame_count_before, "MLP did not survive skip files"
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()