Revert "Replace _dynamo.config with an object instead of module (#96455)"

This reverts commit 3864207c2a71a3ba8dc13bcf9582a726a10292cd.

Reverted https://github.com/pytorch/pytorch/pull/96455 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/96455#issuecomment-1576162237))
PyTorch MergeBot
2023-06-05 07:06:14 +00:00
parent 258d398eec
commit f79d2b45fb
10 changed files with 408 additions and 536 deletions

View File

@@ -1,12 +1,8 @@
# Owner(s): ["module: inductor"]
import copy
import math
import pickle
import unittest
from dataclasses import dataclass, field
import torch
from torch._config_utils import ConfigMixin
from torch._dynamo.test_case import run_tests, TestCase
@@ -228,46 +224,6 @@ class TestInductorConfig(TestCase):
lambda: torch.compile(fn, backend="eager", mode="nope")(inp),
)
def test_config_mixin(self):
@dataclass
class Nest1(ConfigMixin):
a: int = 5
@dataclass
class Nest2(ConfigMixin):
n1: Nest1 = field(default_factory=Nest1)
@dataclass
class Nest3(ConfigMixin):
n2: Nest2 = field(default_factory=Nest2)
@dataclass
class Nest4(ConfigMixin):
n3: Nest3 = field(default_factory=Nest3)
c = Nest4()
self.assertEqual(c.to_dict(), {"n3.n2.n1.a": 5})
c.n3.n2.n1.a = 6
self.assertEqual(c.codegen_config("c"), "c.n3.n2.n1.a = 6")
state_dict = c.to_dict()
state_dict["n3.n2.n1.a"] = 7
self.assertEqual(c.n3.n2.n1.a, 7)
@c.patch("n3.n2.n1.a", 8)
def test_patch():
self.assertEqual(c.n3.n2.n1.a, 8)
test_patch()
# assert the config value didn't change outside the patch
self.assertEqual(c.n3.n2.n1.a, 7)
def test_config_pickle_ignore(self):
config = copy.deepcopy(torch._inductor.config)
config.trace.upload_tar = torch.ops.aten.add  # something unpicklable
pickle.dumps(config)  # should not throw
if __name__ == "__main__":
run_tests()
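The flat dotted-key behavior exercised by the test above can be sketched in a few self-contained lines. This is an illustration only, mirroring `_dataclass_obj_to_flat_dict` from torch/_config_utils.py below, not the code under test:

import dataclasses
from dataclasses import dataclass, field

def flatten(dc):
    # Recursively flatten nested dataclasses into dotted keys,
    # mirroring how ConfigMixin.to_dict produces "n3.n2.n1.a"
    out = {}
    for f in dataclasses.fields(dc):
        value = getattr(dc, f.name)
        if dataclasses.is_dataclass(value):
            for k, v in flatten(value).items():
                out[f"{f.name}.{k}"] = v
        else:
            out[f.name] = value
    return out

@dataclass
class Inner:
    a: int = 5

@dataclass
class Outer:
    inner: Inner = field(default_factory=Inner)

assert flatten(Outer()) == {"inner.a": 5}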

View File

@@ -1,271 +0,0 @@
import contextlib
import copy
import dataclasses
import functools
import inspect
import pickle
import types
import unittest
from types import ModuleType
from typing import Any, Dict, Set
import torch
class ContextDecorator(contextlib.ContextDecorator):
"""
Same as contextlib.ContextDecorator, but with support for
`unittest.TestCase`
"""
def __call__(self, func):
if isinstance(func, type) and issubclass(func, unittest.TestCase):
class _TestCase(func): # type: ignore[misc, valid-type]
@classmethod
def setUpClass(cls):
self.__enter__() # type: ignore[attr-defined]
try:
super().setUpClass()
except Exception:
self.__exit__(None, None, None) # type: ignore[attr-defined]
raise
@classmethod
def tearDownClass(cls):
try:
super().tearDownClass()
finally:
self.__exit__(None, None, None) # type: ignore[attr-defined]
_TestCase.__name__ = func.__name__ # type:ignore[attr-defined]
_TestCase.__qualname__ = func.__qualname__
_TestCase.__module__ = func.__module__
return _TestCase
return super().__call__(func)
def _dataclass_obj_to_flat_dict(dc):
fields = getattr(type(dc), "__dataclass_fields__", {})
result = {}
for name, field in fields.items():
if not field.metadata.get("skip_pickle", False):
value = getattr(dc, name)
if dataclasses.is_dataclass(value):
for k2, v2 in _dataclass_obj_to_flat_dict(value).items():
result[f"{name}.{k2}"] = v2
else:
result[name] = value
return result
def _codegen_changes_of_dataclass_obj(dc, name):
values = _dataclass_obj_to_flat_dict(dc)
defaults = _dataclass_obj_to_flat_dict(type(dc)())
result = []
for k, v in values.items():
if defaults[k] != v:
result.append(f"{name}.{k} = {v!r}")
return "\n".join(result)
class ConfigMixin:
"""Mixin class shared between dataclasses that meant to represent a config.
Usage:
@dataclass
class SomeConfig(ConfigMixin):
a: int
b: int
...
c: SomeOtherNestedConfig
Note: c, the nested config, should also inherit ConfigMixin,
i.e.
@dataclass
class SomeOtherNestedConfig(ConfigMixin):
d: ...
This mixin will:
1. Make the subclass picklable by allowing one to mark non-picklable
fields with {'skip_pickle': True} metadata.
2. Add `save_config`, which returns the config as bytes, and
`load_config`, which replaces fields of an instance with the contents
of a serialized string. Note: these are legacy methods; it's better
to use pickle directly.
3. `.to_dict` will create a flat dict:
for the SomeConfig above, it will return a dictionary with keys
'a', 'b', 'c.d'.
4. `.codegen_config` will create a string of Python code containing
the modifications of this config relative to the default values.
"""
def __getstate__(self):
start = {}
for name, field in self._fields().items():
if not field.metadata.get("skip_pickle", False):
start[name] = getattr(self, name)
return start
def __setstate__(self, state):
self.__init__() # type: ignore[misc]
self.__dict__.update(state)
def save_config(self):
return pickle.dumps(self, protocol=2)
def load_config(self, content):
state = pickle.loads(content)
self.__dict__.update(state.__dict__)
return self
def _update_single(self, key, val):
pieces = key.split(".")
current = self
for token in pieces[:-1]:
current = getattr(current, token)
setattr(current, pieces[-1], val)
def _get_single(self, key):
pieces = key.split(".")
current = self
for token in pieces:
current = getattr(current, token)
return current
def update(self, content_dict):
for k, v in content_dict.items():
self._update_single(k, v)
@classmethod
def _fields(cls):
return getattr(cls, "__dataclass_fields__", {})
def __setattr__(self, key, val):
if (
not inspect.isclass(val)
and key not in type(self).__dict__
and key not in self._fields()
):
raise AttributeError(
f"Trying to set attribute {key} that is not part of this config {type(self).__name__}"
)
super().__setattr__(key, val)
def to_dict(self):
flatdict = _dataclass_obj_to_flat_dict(self)
return BoundDict(flatdict, self)
@classmethod
def is_fbcode(cls):
return not hasattr(torch.version, "git_version")
def patch(self, arg1=None, arg2=None, **kwargs):
"""
Decorator and/or context manager to make temporary changes to a config.
As a decorator:
@config.patch("name", val)
@config.patch(name1=val1, name2=val2)
@config.patch({"name1": val1, "name2": val2})
def foo(...):
...
As a context manager:
with config.patch("name", val):
...
"""
if arg1 is not None:
if arg2 is not None:
# patch("key", True) syntax
changes = {arg1: arg2}
else:
# patch({"key": True}) syntax
changes = arg1
assert not kwargs
else:
# patch(key=True) syntax
changes = kwargs
assert arg2 is None
assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
prior: Dict[str, Any] = {}
config = self
class ConfigPatch(ContextDecorator):
def __enter__(self):
assert not prior
for key in changes.keys():
# KeyError on invalid entry
prior[key] = config._get_single(key)
config.update(changes)
def __exit__(self, exc_type, exc_val, exc_tb):
config.update(prior)
prior.clear()
return ConfigPatch()
def codegen_config(self, name=None):
"""Convert config to Python statements that replicate current config.
This does NOT include config settings that are at default values.
"""
if name is None:
name = self.__name__ # type: ignore[attr-defined]
return _codegen_changes_of_dataclass_obj(self, name)
class BoundDict(dict):
def __init__(self, orig, config):
super().__init__(orig)
self._config = config
def __setitem__(self, key, val):
self._config._update_single(key, val)
super().__setitem__(key, val)
def make_config_dataclass(name, config_module):
fields = []
module_name = ".".join(config_module.__name__.split(".")[:-1])
ignored_fields: Set[str] = getattr(config_module, "_save_config_ignore", set())
for fname, default_value in config_module.__dict__.items():
if callable(default_value) or isinstance(default_value, ModuleType):
# Module level functions and imported modules are
# usually not part of config.
continue
if fname.startswith("__"):
continue
annotation = config_module.__annotations__.get(fname)
assert (
annotation is not None
), f"Please specify type annotation for {fname} in {config_module.__name__}"
should_skip = fname in ignored_fields
field = dataclasses.field(
default_factory=functools.partial(copy.copy, default_value),
metadata={"skip_pickle": should_skip},
)
fields.append((fname, annotation, field))
fields.append(("__name__", str, dataclasses.field(default=config_module.__name__)))
cls = dataclasses.make_dataclass(
name, fields, bases=(ConfigMixin, types.ModuleType)
)
cls.__dataclass_fields__["__name__"].default = config_module.__name__ # type: ignore[attr-defined]
# NOTE: this is to make pickle work. In Python 3.12, make_dataclass
# will take a module argument and set the __module__ field itself.
cls.__module__ = module_name
return cls
def install_config_module(classname, module):
orig_name = module.__name__
module.__class__ = make_config_dataclass(classname, module)
module.__init__() # call constructor by hand
module.__name__ = orig_name
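A minimal sketch of how a config module would adopt this (now-removed) helper; the module `my_config` and its fields are hypothetical:

# hypothetical file: my_config.py -- every field needs a type annotation,
# as enforced by the assert in make_config_dataclass above
import sys

enabled: bool = True
threshold: int = 10

from torch._config_utils import install_config_module
install_config_module("MyConfig", sys.modules[__name__])

# elsewhere, after `import my_config`:
#   my_config.threshold = 99          # validated by ConfigMixin.__setattr__
#   my_config.to_dict()               # flat dict with "enabled", "threshold", ...
#   with my_config.patch("threshold", 5):
#       ...                           # restored on exit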

View File

@@ -1,10 +1,3 @@
# Class type of config
# torch._dynamo.config is really an object of this type instead of a module.
from . import config
DynamoConfig = type(config)
DynamoConfig.__doc__ = "Dataclass that holds configs related to dynamo."
from . import allowed_functions, convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, register_backend
from .convert_frame import replay
@@ -33,8 +26,6 @@ from .utils import (
__all__ = [
"allow_in_graph",
"assume_constant_result",
"config",
"DynamoConfig",
"disallow_in_graph",
"forbid_in_graph",
"graph_break",

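A small sketch of what the removed export above enabled before this revert (assumes a pre-revert build):

import torch._dynamo

# torch._dynamo.config was an instance of the DynamoConfig dataclass,
# not a plain module, so it could be type-checked and deep-copied
assert isinstance(torch._dynamo.config, torch._dynamo.DynamoConfig)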
View File

@@ -2,12 +2,10 @@ import inspect
import os
import re
import sys
import tempfile
from os.path import abspath, dirname
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Type
import torch
from torch._config_utils import ConfigMixin
from . import external_utils
@@ -18,32 +16,30 @@ from . import external_utils
# see this design doc for more detailed info
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
# the name of a file to write the logs to
log_file_name: Optional[str] = None
log_file_name = None
# Verbose will print full stack traces on warnings and errors
verbose: bool = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"
verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"
# verify the correctness of optimized backend
verify_correctness: bool = False
verify_correctness = False
# need this many ops to create an FX graph
minimum_call_count: int = 1
minimum_call_count = 1
# turn on/off DCE pass
dead_code_elimination: bool = True
dead_code_elimination = True
# disable (for a function) when cache reaches this size
cache_size_limit: int = 64
cache_size_limit = 64
# whether or not to specialize on int inputs. This only has an effect with
# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
# inputs
specialize_int: bool = False
specialize_int = False
# Assume these functions return constants
# Key is the function itself and value is the constant return value of
# that function to substitute.
constant_functions: Dict[Callable, Any] = {
constant_functions = {
torch.jit.is_scripting: False,
torch.jit.is_tracing: False,
torch._C._get_tracing_state: None,
@@ -54,29 +50,29 @@ constant_functions: Dict[Callable, Any] = {
}
# don't specialize on shapes and strides and put shape ops in graph
dynamic_shapes: bool = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"
dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"
# This is a temporary flag that changes the behavior of dynamic_shapes=True.
# When assume_static_by_default is True, we only allocate symbols for shapes marked dynamic via mark_dynamic.
# NOTE - this flag can be removed once we can run dynamic_shapes=False w/ the mark_dynamic API
# see [Note - on the state of mark_dynamic]
assume_static_by_default: bool = True
assume_static_by_default = True
# This flag changes how dynamic_shapes=True works, and is meant to be used in conjunction
# with assume_static_by_default=True.
# With this flag enabled, we always compile a frame as fully static for the first time, and, if we fail
# any guards due to wobbles in shape, we recompile with *all* the wobbled shapes as being marked dynamic.
automatic_dynamic_shapes: bool = True
automatic_dynamic_shapes = True
# Typically, if you mark_dynamic a dimension, we will error if the dimension
# actually ends up getting specialized. This knob changes the behavior so
# that we don't error at all. This is helpful for our CI, where we use a
# heuristic to mark batch dimensions as dynamic and the heuristic may get it
# wrong.
allow_ignore_mark_dynamic: bool = False
allow_ignore_mark_dynamic = False
# Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing)
guard_nn_modules: bool = False
guard_nn_modules = False
# This feature doesn't really work. We offer this flag for experimental
# purposes / if you want to help us build out support.
@@ -94,38 +90,37 @@ guard_nn_modules: bool = False
# We do NOT currently support __torch_dispatch__. The implementation is
# currently buggy, the main show stopper for nontrivial use is
# https://github.com/pytorch/torchdynamo/issues/1952
traceable_tensor_subclasses: Set[Type] = set()
traceable_tensor_subclasses = set()
# Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager.
# This is a good way to get your model to work one way or another, but you may
# lose optimization opportunities this way. Devs, if your benchmark model is failing
# this way, you should figure out why instead of suppressing it.
suppress_errors: bool = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
# Record and write an execution record of the current frame to a file
# if an exception is encountered
replay_record_enabled: bool = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
replay_record_enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
# Rewrite assert statement in python with torch._assert
rewrite_assert_with_torch_assert: bool = True
rewrite_assert_with_torch_assert = True
# Show a warning on every graph break
print_graph_breaks: bool = False
print_graph_breaks = False
# Show a warning for every specialization
print_specializations: bool = False
print_specializations = False
# Simplify guards, summarizing static and dynamic constraints on dimensions.
# NOTE: This only has an effect when dynamic_shapes=True.
summarize_dim_constraints: bool = False
summarize_dim_constraints = False
# Disable dynamo
disable: bool = os.environ.get("TORCH_COMPILE_DISABLE", False)
disable = os.environ.get("TORCH_COMPILE_DISABLE", False)
# If a PyTorch module is in this allowlist, torchdynamo will be allowed
# to inline objects from it or its children.
skipfiles_inline_module_allowlist: Set[ModuleType] = {
skipfiles_inline_module_allowlist = {
torch.nn,
torch.distributions,
torch.testing,
@@ -140,7 +135,7 @@ skipfiles_inline_module_allowlist: Set[ModuleType] = {
# the `allowed_functions.is_allowed` function will not consider it
# when creating a list of PyTorch functions that will appear in
# FX IR.
allowed_functions_module_string_ignorelist: Set[str] = {
allowed_functions_module_string_ignorelist = {
"torch.distributions",
"torch.testing",
"torch._refs",
@@ -148,18 +143,17 @@ allowed_functions_module_string_ignorelist: Set[str] = {
"torch._decomp",
}
# Debug Flag to try minifier at different stages. Possible values are {None, "aot", "dynamo"}
# None - Minifier is switched off
# dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails
# aot - Runs minifier on the Aot Autograd produced graphs, if compilation fails
repro_after: Optional[str] = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)
repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)
# Compiler compilation debug info
# 1: Dumps the original graph out to repro.py if compilation fails
# 2: Dumps a minifier_launcher.py if compilation fails.
# 3: Always dumps a minifier_launcher.py. Good for segfaults.
# 4: Dumps a minifier_launcher.py if the accuracy fails.
repro_level: int = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
# By default, we try to detect accuracy failure by running both forward
# and backward of a torchdynamo produced graph (if you are using repro_after
@@ -169,93 +163,95 @@ repro_level: int = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
# backwards step
# TODO: Detect this situation automatically so the user doesn't need
# to manually configure this
repro_forward_only: bool = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1"
repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1"
# The tolerance we should use when testing if a compiled graph
# has diverged so that we should treat it as an accuracy failure
repro_tolerance: float = 1e-3
repro_tolerance = 1e-3
# If True, when testing if two models are the same, we will test them against
# a third fp64 reference and only report a problem if the RMSE relative to the
# fp64 is greater. However, this will use more memory; you may disable this
# if memory usage is too high.
same_two_models_use_fp64: bool = True
same_two_models_use_fp64 = True
# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
# When this flag is set to False, we introduce a graph break instead of capturing.
# This requires dynamic_shapes to be True.
capture_scalar_outputs: bool = False
capture_scalar_outputs = False
# Not all backends support operators that have dynamic output shape (e.g.,
# nonzero, unique). When this flag is set to False, we introduce a graph
# break instead of capturing. This requires dynamic_shapes to be True.
# If you set this to True, you probably also want capture_scalar_outputs
# (these are separated for historical reasons).
capture_dynamic_output_shape_ops: bool = False
capture_dynamic_output_shape_ops = False
# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
# false_fn produce code with identical guards.
enforce_cond_guards_match: bool = True
enforce_cond_guards_match = True
# Automatically split model graph into pieces to match DDP bucket sizes
# to allow DDP comm/compute overlap. Disable to allow DDP models to
# run without graph-breaks, but also without comm/compute overlap.
# set torch._dynamo.config.log_level to INFO or DEBUG for more info
# about optimize_ddp behavior.
optimize_ddp: bool = True
optimize_ddp = True
# Whether to skip guarding on FSDP-managed modules
skip_fsdp_guards: bool = True
skip_fsdp_guards = True
# Make dynamo skip guarding on hooks on nn modules
# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them,
# dynamo will not notice and will execute whichever version you first compiled.
skip_nnmodule_hook_guards: bool = True
skip_nnmodule_hook_guards = True
# If True, raises exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage: bool = True
raise_on_ctx_manager_usage = True
# If True, raise when aot autograd is unsafe to use
raise_on_unsafe_aot_autograd: bool = False
raise_on_unsafe_aot_autograd = False
# Throw an error if backend changes without reset
raise_on_backend_change: bool = False
raise_on_backend_change = False
# If true, error with a better message if we symbolically trace over a
# dynamo-optimized function. If false, silently suppress dynamo.
error_on_nested_fx_trace: bool = True
error_on_nested_fx_trace = True
# Disables graph breaking on rnn. YMMV with backends.
allow_rnn: bool = False
allow_rnn = False
# If true, error if we try to compile a function that has
# been seen before.
error_on_recompile: bool = False
error_on_recompile = False
# Reports why guards fail. Useful for identifying guards that fail frequently
# and cause recompilations.
report_guard_failures: bool = os.environ.get("TORCHDYNAMO_REPORT_GUARD_FAILURES") == "1"
report_guard_failures = os.environ.get("TORCHDYNAMO_REPORT_GUARD_FAILURES") == "1"
# root folder of the project
base_dir: str = dirname(dirname(dirname(abspath(__file__))))
base_dir = dirname(dirname(dirname(abspath(__file__))))
# trace through numpy ndarray as tensor and try to translate numpy function to torch function.
numpy_ndarray_as_tensor: bool = False
numpy_ndarray_as_tensor = False
DEBUG_DIR_VAR_NAME: str = "TORCH_COMPILE_DEBUG_DIR"
def is_fbcode():
return not hasattr(torch.version, "git_version")
debug_dir_root: str = ""
DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR"
if DEBUG_DIR_VAR_NAME in os.environ:
debug_dir_root = os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug")
elif ConfigMixin.is_fbcode():
elif is_fbcode():
debug_dir_root = os.path.join(tempfile.gettempdir(), "torch_compile_debug")
else:
debug_dir_root = os.path.join(os.getcwd(), "torch_compile_debug")
_save_config_ignore: Set[str] = {
_save_config_ignore = {
"repro_after",
"repro_level",
# workaround: "cannot pickle PyCapsule"
@@ -264,9 +260,9 @@ _save_config_ignore: Set[str] = {
"skipfiles_inline_module_allowlist",
}
capture_autograd_function: bool = True
capture_autograd_function = True
_autograd_backward_strict_mode_banned_ops: List[str] = [
_autograd_backward_strict_mode_banned_ops = [
"stride",
"requires_grad",
"storage_offset",
@@ -278,6 +274,7 @@ _autograd_backward_strict_mode_banned_ops.extend(
[name for name, _ in inspect.getmembers(torch.Tensor) if re.match(r"^is_.*", name)]
)
from torch._config_utils import install_config_module
install_config_module("DynamoConfig", sys.modules[__name__])
from .config_utils import install_config_module
install_config_module(sys.modules[__name__])
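Both the removed dataclass config and the restored module-style config expose the same `patch` API once installed; a minimal usage sketch (the compiled function and inputs are hypothetical):

import torch
import torch._dynamo

# context-manager form: temporarily raise the per-function cache limit
with torch._dynamo.config.patch("cache_size_limit", 128):
    torch.compile(lambda x: x + 1)(torch.ones(4))

# decorator form, using the verbose flag defined above
@torch._dynamo.config.patch(verbose=True)
def run_compiled(fn, *args):
    return torch.compile(fn)(*args)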

View File

@@ -1,10 +1,225 @@
import contextlib
import pickle
import unittest
from types import FunctionType, ModuleType
from typing import Any, Dict, Set
from unittest import mock
# Types saved/loaded in configs
CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
def install_config_module(module):
"""
Converts a module-level config into a `ConfigModule()`
"""
class ConfigModuleInstance(ConfigModule):
_bypass_keys = set()
def visit(source, dest, prefix):
"""Walk the module structure and move everything to module._config"""
for key, value in list(source.__dict__.items()):
if key.startswith("__") or isinstance(value, (ModuleType, FunctionType)):
continue
name = f"{prefix}{key}"
if isinstance(value, CONFIG_TYPES):
config[name] = value
default[name] = value
if dest is module:
delattr(module, key)
elif isinstance(value, type):
assert value.__module__ == module.__name__
# a subconfig with `class Blah:` syntax
proxy = SubConfigProxy(module, f"{name}.")
visit(value, proxy, f"{name}.")
setattr(dest, key, proxy)
else:
raise AssertionError(f"Unhandled config {key}={value} ({type(value)})")
config = dict()
default = dict()
visit(module, module, "")
module._config = config
module._default = default
module._allowed_keys = set(config.keys())
module.__class__ = ConfigModuleInstance
class ConfigModule(ModuleType):
# The default values of the configuration settings. This can be used to
# determine if the config has been changed or not.
_default: Dict[str, Any]
# The actual configuration settings. E.g., torch._dynamo.config.debug
# would live as "debug" in the key, and torch._inductor.config.triton.cudagraphs
# maps as "triton.cudagraphs"
_config: Dict[str, Any]
_allowed_keys: Set[str]
_bypass_keys: Set[str]
def __init__(self):
raise NotImplementedError(
f"use {__name__}.install_config_module(sys.modules[__name__])"
)
def __setattr__(self, name, value):
if name in self._bypass_keys:
super().__setattr__(name, value)
elif name not in self._allowed_keys:
raise AttributeError(f"{self.__name__}.{name} does not exist")
else:
self._config[name] = value
def __getattr__(self, name):
try:
return self._config[name]
except KeyError:
# make hasattr() work properly
raise AttributeError(f"{self.__name__}.{name} does not exist")
def __delattr__(self, name):
# must support delete because unittest.mock.patch deletes
# then recreates things
del self._config[name]
def save_config(self):
"""Convert config to a pickled blob"""
config = dict(self._config)
for key in config.get("_save_config_ignore", ()):
config.pop(key)
return pickle.dumps(config, protocol=2)
def codegen_config(self):
"""Convert config to Python statements that replicate current config.
This does NOT include config settings that are at default values.
"""
lines = []
mod = self.__name__
for k, v in self._config.items():
if k in self._config.get("_save_config_ignore", ()):
continue
if v == self._default[k]:
continue
lines.append(f"{mod}.{k} = {v!r}")
return "\n".join(lines)
def load_config(self, data):
"""Restore from a prior call to save_config()"""
self.to_dict().update(pickle.loads(data))
def to_dict(self):
return self._config
def patch(self, arg1=None, arg2=None, **kwargs):
"""
Decorator and/or context manager to make temporary changes to a config.
As a decorator:
@config.patch("name", val)
@config.patch(name1=val1, name2=val2)
@config.patch({"name1": val1, "name2": val2})
def foo(...):
...
As a context manager:
with config.patch("name", val):
...
"""
if arg1 is not None:
if arg2 is not None:
# patch("key", True) syntax
changes = {arg1: arg2}
else:
# patch({"key": True}) syntax
changes = arg1
assert not kwargs
else:
# patch(key=True) syntax
changes = kwargs
assert arg2 is None
assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
prior = {}
config = self
class ConfigPatch(ContextDecorator):
def __enter__(self):
assert not prior
for key in changes.keys():
# KeyError on invalid entry
prior[key] = config._config[key]
config._config.update(changes)
def __exit__(self, exc_type, exc_val, exc_tb):
config._config.update(prior)
prior.clear()
return ConfigPatch()
class ContextDecorator(contextlib.ContextDecorator):
"""
Same as contextlib.ContextDecorator, but with support for
`unittest.TestCase`
"""
def __call__(self, func):
if isinstance(func, type) and issubclass(func, unittest.TestCase):
class _TestCase(func):
@classmethod
def setUpClass(cls):
self.__enter__()
try:
super().setUpClass()
except Exception:
self.__exit__(None, None, None)
raise
@classmethod
def tearDownClass(cls):
try:
super().tearDownClass()
finally:
self.__exit__(None, None, None)
_TestCase.__name__ = func.__name__
_TestCase.__qualname__ = func.__qualname__
_TestCase.__module__ = func.__module__
return _TestCase
return super().__call__(func)
class SubConfigProxy:
"""
Shim to redirect to main config.
`config.triton.cudagraphs` maps to _config["triton.cudagraphs"]
"""
def __init__(self, config, prefix):
# `super().__setattr__` to bypass custom `__setattr__`
super().__setattr__("_config", config)
super().__setattr__("_prefix", prefix)
def __setattr__(self, name, value):
return self._config.__setattr__(self._prefix + name, value)
def __getattr__(self, name):
return self._config.__getattr__(self._prefix + name)
def __delattr__(self, name):
return self._config.__delattr__(self._prefix + name)
def patch_object(obj, name, value):
"""
Workaround `mock.patch.object` issue with ConfigModule
"""
if isinstance(obj, ConfigMixin):
if isinstance(obj, ConfigModule):
return obj.patch(name, value)
return mock.patch.object(obj, name, value)
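In practice, the ConfigModule/SubConfigProxy pair restored above behaves like this; a sketch against the inductor config defined later in this commit:

import torch._inductor.config as inductor_config

# attribute access on the proxy routes to a flat dotted key in _config
inductor_config.triton.cudagraphs = True
assert inductor_config.to_dict()["triton.cudagraphs"] is True

# patch() accepts the same dotted keys
with inductor_config.patch({"triton.cudagraphs": False}):
    assert not inductor_config.triton.cudagraphs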

View File

@@ -1,4 +1,3 @@
import copy
import dataclasses
import io
import logging
@@ -98,8 +97,8 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
# NB: Can't use save_config because that will omit some fields,
# but we must save and reset ALL fields
dynamo_config = copy.deepcopy(torch._dynamo.config)
inductor_config = copy.deepcopy(torch._inductor.config)
dynamo_config = torch._dynamo.config._config.copy()
inductor_config = torch._inductor.config._config.copy()
try:
stderr = io.StringIO()
log_handler = logging.StreamHandler(stderr)
@@ -123,8 +122,8 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
# around
torch._dynamo.reset()
finally:
torch._dynamo.config.__dict__.update(dynamo_config.__dict__)
torch._inductor.config.__dict__.update(inductor_config.__dict__)
object.__setattr__(torch._dynamo.config, "_config", dynamo_config)
object.__setattr__(torch._inductor.config, "_config", inductor_config)
# TODO: return a more appropriate data structure here
return subprocess.CompletedProcess(

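The save/restore idiom used above, spelled out as a standalone sketch; object.__setattr__ is required because ConfigModule.__setattr__ rejects attribute names outside _allowed_keys:

import torch
import torch._dynamo

saved = torch._dynamo.config._config.copy()  # shallow copy of ALL fields
try:
    torch._dynamo.config.suppress_errors = True  # mutate freely for the repro
finally:
    # bypass ConfigModule.__setattr__, which would reject "_config"
    object.__setattr__(torch._dynamo.config, "_config", saved)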
View File

@@ -3,7 +3,3 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from . import config
FuncTorchConfig = type(config)
__all__ = ['config', 'FuncTorchConfig']

View File

@@ -9,30 +9,32 @@ Global flags for aot autograd
"""
import os
import sys
from typing import Union
# Converts torch rng ops to their functional philox rng equivalents. Note that
# we functionalize only CUDA rng ops today.
functionalize_rng_ops: bool = False
functionalize_rng_ops = False
# can be useful for debugging if we are incorrectly creating meta fake tensors
fake_tensor_allow_meta: Union[str, bool] = os.environ.get("FAKE_ALLOW_META", True)
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
# Enables optional asserts in hotpath code to check for errors. If
# you are seeing weird accuracy problems, try turning this on.
# This is currently off by default as it will harm tracing time,
# but it is on by default for aot_eager.
debug_assert: bool = False
debug_assert = False
debug_partitioner: Union[str, bool] = os.environ.get("AOT_PARTITIONER_DEBUG", False)
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", False)
static_weight_shapes: bool = True
static_weight_shapes = True
# Applies CSE to the graph before partitioning
cse: bool = True
cse = True
# Restricts the amount of computation AOTAutograd can do.
max_dist_from_bw: int = 3
max_dist_from_bw = 3
from torch._config_utils import install_config_module
install_config_module('FunctorchConfig', sys.modules[__name__])
from .._dynamo.config_utils import install_config_module
# adds patch, save_config, invalid config checks, etc
install_config_module(sys.modules[__name__])
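Once installed, this module gains the same patch/save_config helpers as the dynamo and inductor configs; a brief sketch (assumes the module is importable as torch._functorch.config):

from torch._functorch import config as functorch_config

# scope the optional hotpath asserts described above to one block
with functorch_config.patch("debug_assert", True):
    pass  # run the code under test here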

View File

@@ -1,10 +1,8 @@
from typing import Any, Dict, List, Optional
import torch.fx
from . import config
__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]
InductorConfig = type(config)
def compile(

View File

@@ -1,157 +1,149 @@
import dataclasses
import os
import sys
from typing import Any, Dict, Optional, Tuple
import torch
from torch._config_utils import ConfigMixin
# add some debug printouts
debug: bool = False
debug = False
# Whether to disable a progress bar for autotuning
disable_progress: bool = True
disable_progress = True
# Whether to enable printing the source code for each future
verbose_progress: bool = False
verbose_progress = False
# use cpp wrapper instead of python wrapper
cpp_wrapper: bool = False
cpp_wrapper = False
# dead code elimination
dce: bool = False
dce = False
# assume weight tensors are fixed size
static_weight_shapes: bool = True
static_weight_shapes = True
# put correctness assertions in generated code
size_asserts: bool = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
# enable loop reordering based on input orders
pick_loop_orders: bool = True
pick_loop_orders = True
# generate inplace computations
inplace_buffers: bool = True
inplace_buffers = True
# allow reusing buffers for more efficient memory use
allow_buffer_reuse: bool = True
allow_buffer_reuse = True
# codegen benchmark harness
benchmark_harness: bool = True
benchmark_harness = True
# fuse pointwise into templates
epilogue_fusion: bool = True
epilogue_fusion = True
# do epilogue fusions before other fusions
epilogue_fusion_first: bool = False
epilogue_fusion_first = False
# enable pattern match+replace optimizations
pattern_matcher: bool = True
pattern_matcher = True
# Optimize away split cat patterns (Experimental)
split_cat_fx_passes: bool = True
split_cat_fx_passes = True
# enable reordering pass
reordering: bool = True
reordering = True
# enable slow autotuning passes to select algorithms
max_autotune: bool = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
# enable slow autotuning passes to select pointwise/reductions algorithms
max_autotune_pointwise: bool = (
os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"
)
max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"
# enable slow autotuning passes to select gemm algorithms
max_autotune_gemm: bool = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"
max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"
# enable searching global and local cache regardless of `max_autotune`
search_autotune_cache: bool = (
os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"
)
search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"
# If this is False, we disable creating subprocesses for autotuning
autotune_in_subproc: bool = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
coordinate_descent_tuning: bool = (
coordinate_descent_tuning = (
os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
)
layout_optimization: bool = (
os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", "1") == "1"
)
layout_optimization = os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", "1") == "1"
# Whether to keep the output strides the same as eager after layout optimization.
keep_output_stride: bool = (
os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"
)
keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"
# Enabling this will let the compiler print warning messages if a generated triton
# kernel has inputs with mixed layouts. This is helpful for perf debugging,
# since a kernel with mixed layout inputs may run much slower than one whose inputs
# have uniform layouts.
warn_mix_layout: bool = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"
warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"
# Control the store-vs-recompute heuristic.
# For fanouts, rematerialization can lead to exponential blowup, so use a
# smaller threshold.
realize_reads_threshold: int = 4
realize_bytes_threshold: int = 2000
realize_reads_threshold = 4
realize_bytes_threshold = 2000
# Threshold to prevent excessive accumulation of ops in one buffer during lowering
realize_acc_reads_threshold: int = 8
realize_acc_reads_threshold = 8
# fallback to eager for random/dropout, this is slow but useful for debugging
fallback_random: bool = False
fallback_random = False
# automatically create fallbacks when encountering an unhandled op
implicit_fallbacks: bool = True
implicit_fallbacks = True
# fuse even in cases without common reads
aggressive_fusion: bool = False
aggressive_fusion = False
# how many nodes to allow into a single fusion
max_fusion_size: int = 64
max_fusion_size = 64
# replace small reductions with pointwise, disable with `= 1`
unroll_reductions_threshold: int = 8
unroll_reductions_threshold = 8
# Add extra comments to output code (causes compile cache misses)
comment_origin: bool = False
comment_origin = False
# Convert 1x1 convs into matmuls
conv_1x1_as_mm: bool = False
conv_1x1_as_mm = False
# Enable split reductions for better utilization when the dimension
# being reduced over is large (by splitting it)
split_reductions: bool = True
split_reductions = True
benchmark_kernel: bool = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"
benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"
# Enable constant and index_expr folding
constant_and_index_propagation: bool = True
constant_and_index_propagation = True
# Enable indirect_indexing asserts for decompositions and lowerings
debug_index_asserts: bool = False
debug_index_asserts = False
def is_fbcode():
return not hasattr(torch.version, "git_version")
# warnings intended for PyTorch developers, disable for point releases
is_nightly_or_source: bool = "dev" in torch.__version__ or "git" in torch.__version__
developer_warnings: bool = ConfigMixin.is_fbcode() or is_nightly_or_source
is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
developer_warnings = is_fbcode() or is_nightly_or_source
def decide_compile_threads():
"""
Here is the precedence for deciding compile_threads:
1. User can override it by TORCHINDUCTOR_COMPILE_THREADS.
One may want to disable async compiling by
1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by
setting this to 1 to make pdb happy.
2. Set to 1 on the win32 platform or in an fbcode build
3. Otherwise, decide by the number of CPU cores
"""
if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
elif sys.platform == "win32" or ConfigMixin.is_fbcode():
elif sys.platform == "win32" or is_fbcode():
return 1
else:
return min(
@@ -162,63 +154,64 @@ def decide_compile_threads():
)
compile_threads: int = decide_compile_threads()
compile_threads = decide_compile_threads()
# gemm autotuning global cache dir
global_cache_dir: Optional[str] = "fb/cache" if ConfigMixin.is_fbcode() else None
if is_fbcode():
global_cache_dir = "fb/cache"
else:
global_cache_dir = None
# If a kernel is fused, the name is generated from the origin node op names;
# for larger kernels, limit this
kernel_name_max_ops: int = 10
kernel_name_max_ops = 10
# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
shape_padding: bool = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"
shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"
# Fx-based linear/matmul/bmm + permute/transpose vertical fusion
permute_fusion: bool = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"
permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"
# Mark the wrapper call in PyTorch profiler
profiler_mark_wrapper_call: bool = False
profiler_mark_wrapper_call = False
# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for
# every intermediate for which we can correlate it with an intermediate
# from the original FX graph
generate_intermediate_hooks: bool = False
generate_intermediate_hooks = False
# Populate traceback field on IRNode; good for debugging why origin_node is
# not populated, or finding out where an IRNode was constructed
debug_ir_traceback: bool = False
debug_ir_traceback = False
# used for debugging to make sure config is properly set
_raise_error_for_testing: bool = False
_raise_error_for_testing = False
_profile_var: str = os.environ.get("TORCHINDUCTOR_PROFILE", "")
profile_bandwidth: bool = _profile_var != ""
profile_bandwidth_regex: str = "" if _profile_var == "1" else _profile_var
_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "")
profile_bandwidth = _profile_var != ""
profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var
disable_cpp_codegen: bool = ConfigMixin.is_fbcode()
disable_cpp_codegen = is_fbcode()
# config specific to codegen/cpp.pp
@dataclasses.dataclass
class CppConfig(ConfigMixin):
# config specific to codegen/cpp.py
class cpp:
# set to torch.get_num_threads()
threads: int = -1
threads = -1
# Do not generate loops when the condition doesn't hold, like:
# for(long i0=4096; i0<4096; i0+=1)
no_redundant_loops: bool = True
no_redundant_loops = True
# Assume number of threads is dynamic, don't specialize thread number.
# Kernels don't recompile on thread number changes with this flag on.
# For single-threaded workloads, turning it on would incur a slight
# performance degradation.
dynamic_threads: bool = False
dynamic_threads = False
simdlen: Optional[int] = None
min_chunk_size: int = 4096
cxx: Tuple[Optional[str]] = (
simdlen = None
min_chunk_size = 4096
cxx = (
None, # download gcc12 from conda-forge if conda is installed
# "g++-12",
# "g++-11",
@@ -228,160 +221,156 @@ class CppConfig(ConfigMixin):
# "g++.par",
)
# Allow kernel performance profiling via PyTorch profiler
enable_kernel_profile: bool = False
enable_kernel_profile = False
# enable weight prepacking to get better performance; may lead to a large memory footprint
weight_prepack: bool = True
weight_prepack = True
# Inject a bug into our relu implementation; useful for testing our repro
# extraction and minification functionality.
# Valid values: "compile_error", "runtime_error", "accuracy"
inject_relu_bug_TESTING_ONLY: Optional[str] = None
inject_log1p_bug_TESTING_ONLY: Optional[str] = None
inject_relu_bug_TESTING_ONLY = None
inject_log1p_bug_TESTING_ONLY = None
# If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise,
# force usage as specified, without testing.
vec_isa_ok: Optional[str] = None
vec_isa_ok = None
# similar to config.triton.descriptive_names
descriptive_names: str = "original_aten"
descriptive_names = "original_aten"
# how many nodes to allow into a single horizontal fusion
max_horizontal_fusion_size: int = 16
max_horizontal_fusion_size = 16
# config specific to codegen/triton.py
@dataclasses.dataclass
class TritonConfig(ConfigMixin):
class triton:
# Use cudagraphs on output code
cudagraphs: bool = False
cudagraphs = False
# Use cudagraph trees for memory pooling if `cudagraphs` is True
cudagraph_trees: bool = not ConfigMixin.is_fbcode()
cudagraph_trees = not is_fbcode()
# assertions not on the fast path, steady state
slow_path_cudagraph_asserts: bool = True
slow_path_cudagraph_asserts = True
# TODO - need to debug why this prevents cleanup
cudagraph_trees_history_recording: bool = False
cudagraph_trees_history_recording = False
# assertions on the fast path
fast_path_cudagraph_asserts: bool = False
fast_path_cudagraph_asserts = False
# skip warmup for cudagraph trees
skip_cudagraph_warmup: bool = False
skip_cudagraph_warmup = False
# Synchronize before and after every compiled graph.
debug_sync_graph: bool = False
debug_sync_graph = False
# Synchronize after every kernel launch, to help pinpoint bugs
debug_sync_kernel: bool = False
debug_sync_kernel = False
# Always load full blocks (rather than broadcasting inside the block)
dense_indexing: bool = False
dense_indexing = False
# limit tiling dimensions
max_tiles: int = 2
max_tiles = 2
# use triton.autotune for pointwise ops with complex layouts
# this should only be disabled for debugging/testing
autotune_pointwise: bool = True
autotune_pointwise = True
# should we stop a fusion to allow better tiling?
tiling_prevents_pointwise_fusion: bool = True
tiling_prevents_reduction_fusion: bool = True
tiling_prevents_pointwise_fusion = True
tiling_prevents_reduction_fusion = True
# assert that indirect indexing does not read / write out of bounds
assert_indirect_indexing: bool = True
assert_indirect_indexing = True
# should we give different names to kernels
# Note: This is orthogonal to descriptive_names - this is deciding whether
# our triton kernel names should all be `triton_` (to maximize caching) or
# whether they should be unique.
unique_kernel_names: bool = (
os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1"
)
unique_kernel_names = os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1"
# should we put op names in kernel names
# False: No special names (just triton__1, triton__2, etc.)
# "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
# "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
# "inductor_node": Maps to the node name in the FX graph passed to Inductor
descriptive_names: str = "original_aten"
descriptive_names = "original_aten"
# use alternate codegen for smaller reductions
persistent_reductions: bool = (
persistent_reductions = (
os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"
)
# hint to Triton when arguments are divisible by 16
divisible_by_16: bool = True
divisible_by_16 = True
# these are not enforced, but they are used by asserts in triton_heuristics.py
# NOTE: mobilevit_s in timm_models required X to be set to the higher value 2048
max_block: Dict[str, int] = dataclasses.field(
default_factory=lambda: {"X": 2048, "Y": 1024, "Z": 1024}
)
max_block = {"X": 2048, "Y": 1024, "Z": 1024}
# Store the generated cubin files for cpp wrapper code to load
store_cubin: bool = False
store_cubin = False
# the max number of spills we allow for the configs we benchmark.
# Setting this to 0 means we skip a config if it spills even a single
# register.
# Setting it to a larger value allows a config that spills a small number
# of registers to be benchmarked.
spill_threshold: int = 0
spill_threshold = 0
# Inject a bug into our relu implementation; useful for testing our repro
# extraction and minification functionality.
# Valid values: "compile_error", "runtime_error", "accuracy"
inject_relu_bug_TESTING_ONLY: Optional[str] = None
inject_relu_bug_TESTING_ONLY = None
# create a directory containing lots of debug information
@dataclasses.dataclass
class TraceConfig(ConfigMixin):
class trace:
# master switch for all debugging flags below
enabled: bool = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
# Save python logger call >=logging.DEBUG
debug_log: bool = False
debug_log = False
# Save python logger call >=logging.INFO
info_log: bool = False
info_log = False
# Save input FX graph (post decomps, pre optimization)
fx_graph: bool = True
fx_graph = True
# Save FX graph after transformations
fx_graph_transformed: bool = True
fx_graph_transformed = True
# Save TorchInductor IR before fusion pass
ir_pre_fusion: bool = True
ir_pre_fusion = True
# Save TorchInductor IR after fusion pass
ir_post_fusion: bool = True
ir_post_fusion = True
# Copy generated code to trace dir
output_code: bool = True
output_code = True
# SVG figure showing post-fusion graph
graph_diagram: bool = False
graph_diagram = False
# Store cProfile (see snakeviz to view)
compile_profile: bool = False
compile_profile = False
# Upload the .tar.gz file
# Needs to be overridden based on specific environment needs
# skip_pickle will make pickle skip this field
upload_tar: Any = dataclasses.field(default=None, metadata={"skip_pickle": True})
upload_tar = None
cpp: CppConfig = CppConfig()
triton: TritonConfig = TritonConfig()
trace: TraceConfig = TraceConfig()
_save_config_ignore = {
# workaround: "Can't pickle <function ...>"
"trace.upload_tar",
}
from torch._config_utils import install_config_module
install_config_module("InductorConfig", sys.modules[__name__])
from .._dynamo.config_utils import install_config_module
# adds patch, save_config, etc
install_config_module(sys.modules[__name__])
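Putting the restored inductor config together: codegen_config emits only non-default settings, using fully qualified dotted names (a sketch; exact output depends on your environment's defaults):

import torch._inductor.config as cfg

cfg.max_autotune = True
cfg.cpp.threads = 4  # stored as _config["cpp.threads"] via SubConfigProxy
print(cfg.codegen_config())
# roughly:
#   torch._inductor.config.max_autotune = True
#   torch._inductor.config.cpp.threads = 4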