Revert "Replace _dynamo.config with an object instead of module (#96455)"
This reverts commit 3864207c2a71a3ba8dc13bcf9582a726a10292cd. Reverted https://github.com/pytorch/pytorch/pull/96455 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/96455#issuecomment-1576162237))
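For orientation before the diff: the reverted PR made torch._dynamo.config (and the inductor/functorch configs) instances of a ConfigMixin dataclass, while this revert restores the module-based ConfigModule machinery. A minimal sketch of the user-facing surface, which is essentially the same in both versions (illustrative only, not part of the diff; the flag names are taken from the config files changed below):

import torch

# plain attribute reads/writes work the same way before and after the revert
torch._dynamo.config.verbose = True

# temporary overrides via patch(), usable as a context manager or a decorator
with torch._dynamo.config.patch("cache_size_limit", 8):
    ...  # code compiled here sees the patched value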
@@ -1,12 +1,8 @@
# Owner(s): ["module: inductor"]
import copy
import math
import pickle
import unittest
from dataclasses import dataclass, field

import torch
from torch._config_utils import ConfigMixin

from torch._dynamo.test_case import run_tests, TestCase

@@ -228,46 +224,6 @@ class TestInductorConfig(TestCase):
            lambda: torch.compile(fn, backend="eager", mode="nope")(inp),
        )

    def test_config_mixin(self):
        @dataclass
        class Nest1(ConfigMixin):
            a: int = 5

        @dataclass
        class Nest2(ConfigMixin):
            n1: Nest1 = field(default_factory=Nest1)

        @dataclass
        class Nest3(ConfigMixin):
            n2: Nest2 = field(default_factory=Nest2)

        @dataclass
        class Nest4(ConfigMixin):
            n3: Nest3 = field(default_factory=Nest3)

        c = Nest4()
        self.assertEqual(c.to_dict(), {"n3.n2.n1.a": 5})

        c.n3.n2.n1.a = 6
        self.assertEqual(c.codegen_config("c"), "c.n3.n2.n1.a = 6")

        state_dict = c.to_dict()
        state_dict["n3.n2.n1.a"] = 7
        self.assertEqual(c.n3.n2.n1.a, 7)

        @c.patch("n3.n2.n1.a", 8)
        def test_patch():
            self.assertEqual(c.n3.n2.n1.a, 8)

        test_patch()
        # assert outside didnt change config value
        self.assertEqual(c.n3.n2.n1.a, 7)

    def test_config_pickle_ignore(self):
        config = copy.deepcopy(torch._inductor.config)
        config.trace.upload_tar = torch.ops.aten.add # something unpickable
        pickle.dumps(config) # not throw


if __name__ == "__main__":
    run_tests()
@@ -1,271 +0,0 @@
import contextlib
import copy
import dataclasses
import functools
import inspect
import pickle
import types
import unittest
from types import ModuleType
from typing import Any, Dict, Set

import torch


class ContextDecorator(contextlib.ContextDecorator):
    """
    Same as contextlib.ContextDecorator, but with support for
    `unittest.TestCase`
    """

    def __call__(self, func):
        if isinstance(func, type) and issubclass(func, unittest.TestCase):

            class _TestCase(func): # type: ignore[misc, valid-type]
                @classmethod
                def setUpClass(cls):
                    self.__enter__() # type: ignore[attr-defined]
                    try:
                        super().setUpClass()
                    except Exception:
                        self.__exit__(None, None, None) # type: ignore[attr-defined]
                        raise

                @classmethod
                def tearDownClass(cls):
                    try:
                        super().tearDownClass()
                    finally:
                        self.__exit__(None, None, None) # type: ignore[attr-defined]

            _TestCase.__name__ = func.__name__ # type:ignore[attr-defined]
            _TestCase.__qualname__ = func.__qualname__
            _TestCase.__module__ = func.__module__
            return _TestCase

        return super().__call__(func)


def _dataclass_obj_to_flat_dict(dc):
    fields = getattr(type(dc), "__dataclass_fields__", {})
    result = {}
    for name, field in fields.items():
        if not field.metadata.get("skip_pickle", False):
            value = getattr(dc, name)
            if dataclasses.is_dataclass(value):
                for k2, v2 in _dataclass_obj_to_flat_dict(value).items():
                    result[f"{name}.{k2}"] = v2
            else:
                result[name] = value
    return result


def _codegen_changes_of_dataclass_obj(dc, name):
    values = _dataclass_obj_to_flat_dict(dc)
    defaults = _dataclass_obj_to_flat_dict(type(dc)())

    result = []
    for k, v in values.items():
        if defaults[k] != v:
            result.append(f"{name}.{k} = {v!r}")
    return "\n".join(result)


class ConfigMixin:
    """Mixin class shared between dataclasses that meant to represent a config.

    Usage:
        @dataclass
        class SomeConfig(ConfigMixin):
            a: int
            b: int
            ...
            c: SomeOtherNestedConfig

    Note: c the nested config should also inherit ConfigMixin.
    ie.

        @dataclass
        class SomeOtherNestedConfig(ConfigMixin):
            d: ...

    This mixin will:
    1. Make the subclass pickable by allowing one to mark non-picklable
       field with {'skip_pickle': True} metadata.
    2. `save_config` which returns the config as bytes, and
       `load_config` what replaces fields of an instance with the content
       of serialized string. Note: these are legacy methods, it's better
       to use pickle directly.
    3. .to_dict will create a flat dict:
       in the SomeConfig above, it will return a dictionary with keys
       'a', 'b', 'c.d'
    4. .codegen_config will create a string of python code with
       modifications of this config compared to the default values.
    """

    def __getstate__(self):
        start = {}
        for name, field in self._fields().items():
            if not field.metadata.get("skip_pickle", False):
                start[name] = getattr(self, name)
        return start

    def __setstate__(self, state):
        self.__init__() # type: ignore[misc]
        self.__dict__.update(state)

    def save_config(self):
        return pickle.dumps(self, protocol=2)

    def load_config(self, content):
        state = pickle.loads(content)
        self.__dict__.update(state.__dict__)
        return self

    def _update_single(self, key, val):
        pieces = key.split(".")
        current = self
        for token in pieces[:-1]:
            current = getattr(current, token)
        setattr(current, pieces[-1], val)

    def _get_single(self, key):
        pieces = key.split(".")
        current = self
        for token in pieces:
            current = getattr(current, token)
        return current

    def update(self, content_dict):
        for k, v in content_dict.items():
            self._update_single(k, v)

    @classmethod
    def _fields(cls):
        return getattr(cls, "__dataclass_fields__", {})

    def __setattr__(self, key, val):
        if (
            not inspect.isclass(val)
            and key not in type(self).__dict__
            and key not in self._fields()
        ):
            raise AttributeError(
                f"Trying to set attribute {key} that is not part of this config {type(self).__name__}"
            )
        super().__setattr__(key, val)

    def to_dict(self):
        flatdict = _dataclass_obj_to_flat_dict(self)
        return BoundDict(flatdict, self)

    @classmethod
    def is_fbcode(cls):
        return not hasattr(torch.version, "git_version")

    def patch(self, arg1=None, arg2=None, **kwargs):
        """
        Decorator and/or context manager to make temporary changes to a config.

        As a decorator:

            @config.patch("name", val)
            @config.patch(name1=val1, name2=val2):
            @config.patch({"name1": val1, "name2", val2})
            def foo(...):
                ...

        As a context manager:

            with config.patch("name", val):
                ...
        """
        if arg1 is not None:
            if arg2 is not None:
                # patch("key", True) syntax
                changes = {arg1: arg2}
            else:
                # patch({"key": True}) syntax
                changes = arg1
            assert not kwargs
        else:
            # patch(key=True) syntax
            changes = kwargs
            assert arg2 is None
        assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
        prior: Dict[str, Any] = {}
        config = self

        class ConfigPatch(ContextDecorator):
            def __enter__(self):
                assert not prior
                for key in changes.keys():
                    # KeyError on invalid entry
                    prior[key] = config._get_single(key)
                config.update(changes)

            def __exit__(self, exc_type, exc_val, exc_tb):
                config.update(prior)
                prior.clear()

        return ConfigPatch()

    def codegen_config(self, name=None):
        """Convert config to Python statements that replicate current config.
        This does NOT include config settings that are at default values.
        """
        lines = []
        if name is None:
            name = self.__name__ # type: ignore[attr-defined]
        return _codegen_changes_of_dataclass_obj(self, name)


class BoundDict(dict):
    def __init__(self, orig, config):
        super().__init__(orig)
        self._config = config

    def __setitem__(self, key, val):
        self._config._update_single(key, val)
        super().__setitem__(key, val)


def make_config_dataclass(name, config_module):
    fields = []
    module_name = ".".join(config_module.__name__.split(".")[:-1])

    ignored_fields: Set[str] = getattr(config_module, "_save_config_ignore", set())
    for fname, default_value in config_module.__dict__.items():
        if callable(default_value) or isinstance(default_value, ModuleType):
            # Module level functions and imported modules are
            # usually not part of config.
            continue
        if fname.startswith("__"):
            continue
        annotation = config_module.__annotations__.get(fname)
        assert (
            annotation is not None
        ), f"Please specify type annotation for {fname} in {config_module.__name__}"
        should_skip = fname in ignored_fields
        field = dataclasses.field(
            default_factory=functools.partial(copy.copy, default_value),
            metadata={"skip_pickle": should_skip},
        )
        fields.append((fname, annotation, field))
    fields.append(("__name__", str, dataclasses.field(default=config_module.__name__)))
    cls = dataclasses.make_dataclass(
        name, fields, bases=(ConfigMixin, types.ModuleType)
    )
    cls.__dataclass_fields__["__name__"].default = config_module.__name__ # type: ignore[attr-defined]

    # NOTE: this is to make pickle work. In Python 3.12 make_dataclass
    # will take a module argument that it would set __module__ field inside.
    cls.__module__ = module_name
    return cls


def install_config_module(classname, module):
    orig_name = module.__name__
    module.__class__ = make_config_dataclass(classname, module)
    module.__init__() # call constructor by hand
    module.__name__ = orig_name
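The file deleted above documented its intended use in the ConfigMixin docstring; a rough sketch of that reverted-away API, as it would have looked before this revert lands (not valid afterwards):

import pickle
import torch

cfg = torch._dynamo.config       # in the reverted-away world: a DynamoConfig dataclass instance
flat = cfg.to_dict()             # flat dict; nested fields are joined with "."
blob = pickle.dumps(cfg)         # fields marked with skip_pickle metadata are dropped
print(cfg.codegen_config())      # Python statements for values that differ from defaults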
@@ -1,10 +1,3 @@
# Class type of config
# torch._dynamo.config is really an object of this type instead of a module.
from . import config

DynamoConfig = type(config)
DynamoConfig.__doc__ = "Dataclass that holds configs related to dynamo."

from . import allowed_functions, convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, register_backend
from .convert_frame import replay
@@ -33,8 +26,6 @@ from .utils import (
__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "config",
    "DynamoConfig",
    "disallow_in_graph",
    "forbid_in_graph",
    "graph_break",
@@ -2,12 +2,10 @@ import inspect
import os
import re
import sys
import tempfile
from os.path import abspath, dirname
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Type

import torch
from torch._config_utils import ConfigMixin
from . import external_utils


@@ -18,32 +16,30 @@ from . import external_utils
# see this design doc for more detailed info
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
# the name of a file to write the logs to
log_file_name: Optional[str] = None
log_file_name = None

# Verbose will print full stack traces on warnings and errors
verbose: bool = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"
verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"

# verify the correctness of optimized backend
verify_correctness: bool = False
verify_correctness = False

# need this many ops to create an FX graph
minimum_call_count: int = 1
minimum_call_count = 1

# turn on/off DCE pass
dead_code_elimination: bool = True
dead_code_elimination = True

# disable (for a function) when cache reaches this size
cache_size_limit: int = 64
cache_size_limit = 64

# whether or not to specialize on int inputs. This only has an effect with
# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
# inputs
specialize_int: bool = False
specialize_int = False

# Assume these functions return constants
# Key is the function itself and value is the constant return value of
# that function to substitute.
constant_functions: Dict[Callable, Any] = {
constant_functions = {
    torch.jit.is_scripting: False,
    torch.jit.is_tracing: False,
    torch._C._get_tracing_state: None,
@@ -54,29 +50,29 @@ constant_functions: Dict[Callable, Any] = {
}

# don't specialize on shapes and strides and put shape ops in graph
dynamic_shapes: bool = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"
dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"

# This is a temporarily flag, which changes the behavior of dynamic_shapes=True.
# When assume_static_by_default is True, we only allocate symbols for shapes marked dynamic via mark_dynamic.
# NOTE - this flag can be removed once we can run dynamic_shapes=False w/ the mark_dynamic API
# see [Note - on the state of mark_dynamic]
assume_static_by_default: bool = True
assume_static_by_default = True

# This flag changes how dynamic_shapes=True works, and is meant to be used in conjunction
# with assume_static_by_default=True.
# With this flag enabled, we always compile a frame as fully static for the first time, and, if we fail
# any guards due to wobbles in shape, we recompile with *all* the wobbled shapes as being marked dynamic.
automatic_dynamic_shapes: bool = True
automatic_dynamic_shapes = True

# Typically, if you mark_dynamic a dimension, we will error if the dimension
# actually ended up getting specialized. This knob changes the behavior so
# that we don't error at all. This is helpful for our CI where I'm using a
# heuristic to mark batch dimensions as dynamic and the heuristic may get it
# wrong.
allow_ignore_mark_dynamic: bool = False
allow_ignore_mark_dynamic = False

# Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing)
guard_nn_modules: bool = False
guard_nn_modules = False

# This feature doesn't really work. We offer this flag for experimental
# purposes / if you want to help us build out support.
@@ -94,38 +90,37 @@ guard_nn_modules: bool = False
# We do NOT currently support __torch_dispatch__. The implementation is
# currently buggy, the main show stopper for nontrivial use is
# https://github.com/pytorch/torchdynamo/issues/1952
traceable_tensor_subclasses: Set[Type] = set()
traceable_tensor_subclasses = set()

# Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager.
# This is a good way to get your model to work one way or another, but you may
# lose optimization opportunities this way. Devs, if your benchmark model is failing
# this way, you should figure out why instead of suppressing it.
suppress_errors: bool = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))

# Record and write an execution record of the current frame to a file
# if an exception is encountered
replay_record_enabled: bool = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
replay_record_enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

# Rewrite assert statement in python with torch._assert
rewrite_assert_with_torch_assert: bool = True
rewrite_assert_with_torch_assert = True

# Show a warning on every graph break
print_graph_breaks: bool = False

print_graph_breaks = False

# Show a warning for every specialization
print_specializations: bool = False
print_specializations = False

# Simplify guards, summarizing static and dynamic constraints on dimensions.
# NOTE: This only has an effect when dynamic_shapes=True.
summarize_dim_constraints: bool = False
summarize_dim_constraints = False

# Disable dynamo
disable: bool = os.environ.get("TORCH_COMPILE_DISABLE", False)
disable = os.environ.get("TORCH_COMPILE_DISABLE", False)

# If a PyTorch module is in this allowlist, torchdynamo will be allowed
# to inline objects from it or its children.
skipfiles_inline_module_allowlist: Set[ModuleType] = {
skipfiles_inline_module_allowlist = {
    torch.nn,
    torch.distributions,
    torch.testing,
@@ -140,7 +135,7 @@ skipfiles_inline_module_allowlist: Set[ModuleType] = {
# the `allowed_functions.is_allowed` function will not consider it
# when creating a list of PyTorch functions that will appear in
# FX IR.
allowed_functions_module_string_ignorelist: Set[str] = {
allowed_functions_module_string_ignorelist = {
    "torch.distributions",
    "torch.testing",
    "torch._refs",
@@ -148,18 +143,17 @@ allowed_functions_module_string_ignorelist: Set[str] = {
    "torch._decomp",
}


# Debug Flag to try minifier at different stages. Possible values are {None, "aot", "dynamo"}
# None - Minifier is switched off
# dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails
# aot - Runs minifier on the Aot Autograd produced graphs, if compilation fails
repro_after: Optional[str] = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)
repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)
# Compiler compilation debug info
# 1: Dumps the original graph out to repro.py if compilation fails
# 2: Dumps a minifier_launcher.py if compilation fails.
# 3: Always dumps a minifier_launcher.py. Good for segfaults.
# 4: Dumps a minifier_launcher.py if the accuracy fails.
repro_level: int = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))

# By default, we try to detect accuracy failure by running both forward
# and backward of a torchdynamo produced graph (if you are using repro_after
@@ -169,93 +163,95 @@ repro_level: int = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
# backwards step
# TODO: Detect this situation automatically so the user doesn't need
# to manually configure this
repro_forward_only: bool = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1"
repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1"

# The tolerance we should use when testing if a compiled graph
# has diverged so that we should treat it as an accuracy failure
repro_tolerance: float = 1e-3
repro_tolerance = 1e-3

# If True, when testing if two models are the same, we will test them against
# a third fp64 reference and only report a problem if the RMSE relative to the
# fp64 is greater. However, this will use more memory; you may disable this
# if memory usage is too high.
same_two_models_use_fp64: bool = True
same_two_models_use_fp64 = True

# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
# When this flag is set to False, we introduce a graph break instead of capturing.
# This requires dynamic_shapes to be True.
capture_scalar_outputs: bool = False
capture_scalar_outputs = False

# Not all backends support operators that have dynamic output shape (e.g.,
# nonzero, unique). When this flag is set to False, we introduce a graph
# break instead of capturing. This requires dynamic_shapes to be True.
# If you set this to True, you probably also want capture_scalar_outputs
# (these are separated for historical reasons).
capture_dynamic_output_shape_ops: bool = False
capture_dynamic_output_shape_ops = False

# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
# false_fn produces code with identical guards.
enforce_cond_guards_match: bool = True
enforce_cond_guards_match = True

# Automatically split model graph into pieces to match DDP bucket sizes
# to allow DDP comm/compute overlap. Disable to allow DDP models to
# run without graph-breaks, but also without comm/compute overlap.
# set torch._dynamo.config.log_level to INFO or DEBUG for more info
# about optimize_ddp behavior.
optimize_ddp: bool = True
optimize_ddp = True

# Whether to skip guarding on FSDP-managed modules
skip_fsdp_guards: bool = True
skip_fsdp_guards = True

# Make dynamo skip guarding on hooks on nn modules
# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them,
# dynamo will not notice and will execute whichever version you first compiled.
skip_nnmodule_hook_guards: bool = True
skip_nnmodule_hook_guards = True

# If True, raises exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage: bool = True
raise_on_ctx_manager_usage = True

# If True, raise when aot autograd is unsafe to use
raise_on_unsafe_aot_autograd: bool = False
raise_on_unsafe_aot_autograd = False

# Throw an error if backend changes without reset
raise_on_backend_change: bool = False
raise_on_backend_change = False

# If true, error with a better message if we symbolically trace over a
# dynamo-optimized function. If false, silently suppress dynamo.
error_on_nested_fx_trace: bool = True
error_on_nested_fx_trace = True

# Disables graph breaking on rnn. YMMV with backends.
allow_rnn: bool = False
allow_rnn = False

# If true, error if we try to compile a function that has
# been seen before.
error_on_recompile: bool = False
error_on_recompile = False

# reports why guards fail. Useful to identify the guards failing frequently and
# causing recompilations.
report_guard_failures: bool = os.environ.get("TORCHDYNAMO_REPORT_GUARD_FAILURES") == "1"
report_guard_failures = os.environ.get("TORCHDYNAMO_REPORT_GUARD_FAILURES") == "1"

# root folder of the project
base_dir: str = dirname(dirname(dirname(abspath(__file__))))
base_dir = dirname(dirname(dirname(abspath(__file__))))

# trace through numpy ndarray as tensor and try to translate numpy function to torch function.
numpy_ndarray_as_tensor: bool = False
numpy_ndarray_as_tensor = False


DEBUG_DIR_VAR_NAME: str = "TORCH_COMPILE_DEBUG_DIR"
def is_fbcode():
    return not hasattr(torch.version, "git_version")


debug_dir_root: str = ""
DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR"

if DEBUG_DIR_VAR_NAME in os.environ:
    debug_dir_root = os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug")
elif ConfigMixin.is_fbcode():
elif is_fbcode():
    debug_dir_root = os.path.join(tempfile.gettempdir(), "torch_compile_debug")
else:
    debug_dir_root = os.path.join(os.getcwd(), "torch_compile_debug")


_save_config_ignore: Set[str] = {
_save_config_ignore = {
    "repro_after",
    "repro_level",
    # workaround: "cannot pickle PyCapsule"
@@ -264,9 +260,9 @@ _save_config_ignore: Set[str] = {
    "skipfiles_inline_module_allowlist",
}

capture_autograd_function: bool = True
capture_autograd_function = True

_autograd_backward_strict_mode_banned_ops: List[str] = [
_autograd_backward_strict_mode_banned_ops = [
    "stride",
    "requires_grad",
    "storage_offset",
@@ -278,6 +274,7 @@ _autograd_backward_strict_mode_banned_ops.extend(
    [name for name, _ in inspect.getmembers(torch.Tensor) if re.match(r"^is_.*", name)]
)

from torch._config_utils import install_config_module

install_config_module("DynamoConfig", sys.modules[__name__])
from .config_utils import install_config_module

install_config_module(sys.modules[__name__])
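With this revert, torch._dynamo.config is once again a ConfigModule produced by install_config_module(sys.modules[__name__]). A brief usage sketch, using flags that appear in the hunk above; unknown keys are rejected by ConfigModule.__setattr__:

import torch._dynamo.config as dynamo_config

dynamo_config.verbose = True            # existing key: plain assignment works
dynamo_config.cache_size_limit = 128

try:
    dynamo_config.not_a_real_flag = 1   # hypothetical key, raises AttributeError
except AttributeError as err:
    print(err)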
@@ -1,10 +1,225 @@
import contextlib

import pickle
import unittest
from types import FunctionType, ModuleType
from typing import Any, Dict, Set
from unittest import mock

# Types saved/loaded in configs
CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)


def install_config_module(module):
    """
    Converts a module-level config into a `ConfigModule()`
    """

    class ConfigModuleInstance(ConfigModule):
        _bypass_keys = set()

    def visit(source, dest, prefix):
        """Walk the module structure and move everything to module._config"""
        for key, value in list(source.__dict__.items()):
            if key.startswith("__") or isinstance(value, (ModuleType, FunctionType)):
                continue

            name = f"{prefix}{key}"
            if isinstance(value, CONFIG_TYPES):
                config[name] = value
                default[name] = value
                if dest is module:
                    delattr(module, key)
            elif isinstance(value, type):
                assert value.__module__ == module.__name__
                # a subconfig with `class Blah:` syntax
                proxy = SubConfigProxy(module, f"{name}.")
                visit(value, proxy, f"{name}.")
                setattr(dest, key, proxy)
            else:
                raise AssertionError(f"Unhandled config {key}={value} ({type(value)})")

    config = dict()
    default = dict()
    visit(module, module, "")
    module._config = config
    module._default = default
    module._allowed_keys = set(config.keys())
    module.__class__ = ConfigModuleInstance


class ConfigModule(ModuleType):
    # The default values of the configuration settings. This can be used to
    # determine if the config has been changed or not.
    _default: Dict[str, Any]
    # The actual configuration settings. E.g., torch._dynamo.config.debug
    # would live as "debug" in the key, and torch._inductor.config.triton.cudagraphs
    # maps as "triton.cudagraphs"
    _config: Dict[str, Any]
    _allowed_keys: Set[str]
    _bypass_keys: Set[str]

    def __init__(self):
        raise NotImplementedError(
            f"use {__name__}.install_config_module(sys.modules[__name__])"
        )

    def __setattr__(self, name, value):
        if name in self._bypass_keys:
            super().__setattr__(name, value)
        elif name not in self._allowed_keys:
            raise AttributeError(f"{self.__name__}.{name} does not exist")
        else:
            self._config[name] = value

    def __getattr__(self, name):
        try:
            return self._config[name]
        except KeyError:
            # make hasattr() work properly
            raise AttributeError(f"{self.__name__}.{name} does not exist")

    def __delattr__(self, name):
        # must support delete because unittest.mock.patch deletes
        # then recreate things
        del self._config[name]

    def save_config(self):
        """Convert config to a pickled blob"""
        config = dict(self._config)
        for key in config.get("_save_config_ignore", ()):
            config.pop(key)
        return pickle.dumps(config, protocol=2)

    def codegen_config(self):
        """Convert config to Python statements that replicate current config.
        This does NOT include config settings that are at default values.
        """
        lines = []
        mod = self.__name__
        for k, v in self._config.items():
            if k in self._config.get("_save_config_ignore", ()):
                continue
            if v == self._default[k]:
                continue
            lines.append(f"{mod}.{k} = {v!r}")
        return "\n".join(lines)

    def load_config(self, data):
        """Restore from a prior call to save_config()"""
        self.to_dict().update(pickle.loads(data))

    def to_dict(self):
        return self._config

    def patch(self, arg1=None, arg2=None, **kwargs):
        """
        Decorator and/or context manager to make temporary changes to a config.

        As a decorator:

            @config.patch("name", val)
            @config.patch(name1=val1, name2=val2):
            @config.patch({"name1": val1, "name2", val2})
            def foo(...):
                ...

        As a context manager:

            with config.patch("name", val):
                ...
        """
        if arg1 is not None:
            if arg2 is not None:
                # patch("key", True) syntax
                changes = {arg1: arg2}
            else:
                # patch({"key": True}) syntax
                changes = arg1
            assert not kwargs
        else:
            # patch(key=True) syntax
            changes = kwargs
            assert arg2 is None
        assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
        prior = {}
        config = self

        class ConfigPatch(ContextDecorator):
            def __enter__(self):
                assert not prior
                for key in changes.keys():
                    # KeyError on invalid entry
                    prior[key] = config._config[key]
                config._config.update(changes)

            def __exit__(self, exc_type, exc_val, exc_tb):
                config._config.update(prior)
                prior.clear()

        return ConfigPatch()


class ContextDecorator(contextlib.ContextDecorator):
    """
    Same as contextlib.ContextDecorator, but with support for
    `unittest.TestCase`
    """

    def __call__(self, func):
        if isinstance(func, type) and issubclass(func, unittest.TestCase):

            class _TestCase(func):
                @classmethod
                def setUpClass(cls):
                    self.__enter__()
                    try:
                        super().setUpClass()
                    except Exception:
                        self.__exit__(None, None, None)
                        raise

                @classmethod
                def tearDownClass(cls):
                    try:
                        super().tearDownClass()
                    finally:
                        self.__exit__(None, None, None)

            _TestCase.__name__ = func.__name__
            _TestCase.__qualname__ = func.__qualname__
            _TestCase.__module__ = func.__module__

            return _TestCase

        return super().__call__(func)


class SubConfigProxy:
    """
    Shim to redirect to main config.
    `config.triton.cudagraphs` maps to _config["triton.cudagraphs"]
    """

    def __init__(self, config, prefix):
        # `super().__setattr__` to bypass custom `__setattr__`
        super().__setattr__("_config", config)
        super().__setattr__("_prefix", prefix)

    def __setattr__(self, name, value):
        return self._config.__setattr__(self._prefix + name, value)

    def __getattr__(self, name):
        return self._config.__getattr__(self._prefix + name)

    def __delattr__(self, name):
        return self._config.__delattr__(self._prefix + name)


def patch_object(obj, name, value):
    """
    Workaround `mock.patch.object` issue with ConfigModule
    """
    if isinstance(obj, ConfigMixin):
    if isinstance(obj, ConfigModule):
        return obj.patch(name, value)
    return mock.patch.object(obj, name, value)
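A short sketch of the restored config_utils API above (patch, to_dict, and class-style subconfigs flattened through SubConfigProxy); the keys used here come from the inductor config later in this diff:

import torch._inductor.config as inductor_config

# class-style subconfigs map to dotted keys: config.triton.cudagraphs <-> "triton.cudagraphs"
assert "triton.cudagraphs" in inductor_config.to_dict()

# temporary override, restored when the block exits
with inductor_config.patch("triton.cudagraphs", True):
    assert inductor_config.triton.cudagraphs

# also usable as a decorator (including on unittest.TestCase classes, via ContextDecorator)
@inductor_config.patch(max_fusion_size=32)
def run_with_small_fusions():
    ...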
@@ -1,4 +1,3 @@
import copy
import dataclasses
import io
import logging
@@ -98,8 +97,8 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_

    # NB: Can't use save_config because that will omit some fields,
    # but we must save and reset ALL fields
    dynamo_config = copy.deepcopy(torch._dynamo.config)
    inductor_config = copy.deepcopy(torch._inductor.config)
    dynamo_config = torch._dynamo.config._config.copy()
    inductor_config = torch._inductor.config._config.copy()
    try:
        stderr = io.StringIO()
        log_handler = logging.StreamHandler(stderr)
@@ -123,8 +122,8 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
        # around
        torch._dynamo.reset()
    finally:
        torch._dynamo.config.__dict__.update(dynamo_config.__dict__)
        torch._inductor.config.__dict__.update(inductor_config.__dict__)
        object.__setattr__(torch._dynamo.config, "_config", dynamo_config)
        object.__setattr__(torch._inductor.config, "_config", inductor_config)

    # TODO: return a more appropriate data structure here
    return subprocess.CompletedProcess(
@@ -3,7 +3,3 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from . import config

FuncTorchConfig = type(config)
__all__ = ['config', 'FunctorchConfig']
@@ -9,30 +9,32 @@ Global flags for aot autograd
"""
import os
import sys
from typing import Union

# Converts torch rng ops to their functional philox rng equivalents. Note that
# we functionalize only CUDA rng ops today.
functionalize_rng_ops: bool = False
functionalize_rng_ops = False

# can be useful for debugging if we are incorrectly creating meta fake tensors
fake_tensor_allow_meta: Union[str, bool] = os.environ.get("FAKE_ALLOW_META", True)
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)

# Enables optional asserts in hotpath code to check for errors. If
# you are seeing weird accuracy problems, try turning this on.
# This is currently off by default as it will harm tracing time,
# but it is on by default for aot_eager.
debug_assert: bool = False
debug_assert = False

debug_partitioner: Union[str, bool] = os.environ.get("AOT_PARTITIONER_DEBUG", False)
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", False)

static_weight_shapes: bool = True
static_weight_shapes = True

# Applies CSE to the graph before partitioning
cse: bool = True
cse = True

# Restricts the amount of computation AOTAutograd can do.
max_dist_from_bw: int = 3
max_dist_from_bw = 3

from torch._config_utils import install_config_module
install_config_module('FunctorchConfig', sys.modules[__name__])

from .._dynamo.config_utils import install_config_module

# adds patch, save_config, invalid config checks, etc
install_config_module(sys.modules[__name__])
@@ -1,10 +1,8 @@
from typing import Any, Dict, List, Optional

import torch.fx
from . import config

__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]
InductorConfig = type(config)


def compile(
@@ -1,157 +1,149 @@
import dataclasses
import os
import sys
from typing import Any, Dict, Optional, Tuple

import torch
from torch._config_utils import ConfigMixin

# add some debug printouts
debug: bool = False
debug = False

# Whether to disable a progress bar for autotuning
disable_progress: bool = True
disable_progress = True

# Whether to enable printing the source code for each future
verbose_progress: bool = False
verbose_progress = False

# use cpp wrapper instead of python wrapper
cpp_wrapper: bool = False
cpp_wrapper = False

# dead code elimination
dce: bool = False
dce = False

# assume weight tensors are fixed size
static_weight_shapes: bool = True
static_weight_shapes = True

# put correctness assertions in generated code
size_asserts: bool = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"

# enable loop reordering based on input orders
pick_loop_orders: bool = True
pick_loop_orders = True

# generate inplace computations
inplace_buffers: bool = True
inplace_buffers = True

# allow reusing buffers for more efficient memory use
allow_buffer_reuse: bool = True
allow_buffer_reuse = True

# codegen benchmark harness
benchmark_harness: bool = True
benchmark_harness = True

# fuse pointwise into templates
epilogue_fusion: bool = True
epilogue_fusion = True

# do epilogue fusions before other fusions
epilogue_fusion_first: bool = False
epilogue_fusion_first = False

# enable pattern match+replace optimizations
pattern_matcher: bool = True
pattern_matcher = True

# Optimize away split cat patterns (Experimental)
split_cat_fx_passes: bool = True
split_cat_fx_passes = True

# enable reordering pass
reordering: bool = True
reordering = True

# enable slow autotuning passes to select algorithms
max_autotune: bool = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"

# enable slow autotuning passes to select pointwise/reductions algorithms
max_autotune_pointwise: bool = (
    os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"
)
max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"

# enable slow autotuning passes to select gemm algorithms
max_autotune_gemm: bool = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"
max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"

# enable searching global and local cache regardless of `max_autotune`
search_autotune_cache: bool = (
    os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"
)
search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"

# We will disable creating subprocess for autotuning if this is False
autotune_in_subproc: bool = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"

coordinate_descent_tuning: bool = (
coordinate_descent_tuning = (
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
)

layout_optimization: bool = (
    os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", "1") == "1"
)
layout_optimization = os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", "1") == "1"

# Whether to keep the output strides the same as eager after layout optimization.
keep_output_stride: bool = (
    os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"
)
keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"

# Enabling this will let compiler print warning messages if a generated triton
# kernel has inputs with mixed layouts. This is helpful for perf debugging
# since kernel with mixed layout inputs may run much slower then one whose inputs
# have uniform layouts.
warn_mix_layout: bool = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"
warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"

# control store vs recompute heuristic
# For fanouts, rematerialization can lead to exponential blowup. So, have
# smaller threshold
realize_reads_threshold: int = 4
realize_bytes_threshold: int = 2000
realize_reads_threshold = 4
realize_bytes_threshold = 2000

# Threshold to prevent excessive accumulation of ops in one buffer during lowering
realize_acc_reads_threshold: int = 8
realize_acc_reads_threshold = 8

# fallback to eager for random/dropout, this is slow but useful for debugging
fallback_random: bool = False
fallback_random = False

# automatically create fallbacks when encountering an unhandled op
implicit_fallbacks: bool = True
implicit_fallbacks = True

# fuse even in cases without common reads
aggressive_fusion: bool = False
aggressive_fusion = False

# how many nodes to allow into a single fusion
max_fusion_size: int = 64
max_fusion_size = 64

# replace small reductions with pointwise, disable with `= 1`
unroll_reductions_threshold: int = 8
unroll_reductions_threshold = 8

# Add extra comments to output code (causes compile cache misses)
comment_origin: bool = False
comment_origin = False

# Convert 1x1 convs into matmuls
conv_1x1_as_mm: bool = False
conv_1x1_as_mm = False

# Enable split reductions for better utilization when the dimension
# being reduced over is large (by splitting it)
split_reductions: bool = True
split_reductions = True

benchmark_kernel: bool = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"
benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"

# Enable constant and index_expr folding
constant_and_index_propagation: bool = True
constant_and_index_propagation = True

# Enable indirect_indexing asserts for decompositions and lowerings
debug_index_asserts: bool = False
debug_index_asserts = False


def is_fbcode():
    return not hasattr(torch.version, "git_version")


# warnings intended for PyTorch developers, disable for point releases
is_nightly_or_source: bool = "dev" in torch.__version__ or "git" in torch.__version__
developer_warnings: bool = ConfigMixin.is_fbcode() or is_nightly_or_source
is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
developer_warnings = is_fbcode() or is_nightly_or_source


def decide_compile_threads():
    """
    Here are the precedence to decide compile_threads
    1. User can override it by TORCHINDUCTOR_COMPILE_THREADS.
    One may want to disable async compiling by
    1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by
       setting this to 1 to make pdb happy.
    2. Set to 1 if it's win32 platform or it's a fbcode build
    3. decide by the number of CPU cores
    """
    if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
        return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
    elif sys.platform == "win32" or ConfigMixin.is_fbcode():
    elif sys.platform == "win32" or is_fbcode():
        return 1
    else:
        return min(
@@ -162,63 +154,64 @@ def decide_compile_threads():
        )


compile_threads: int = decide_compile_threads()
compile_threads = decide_compile_threads()

# gemm autotuning global cache dir

global_cache_dir: Optional[str] = "fb/cache" if ConfigMixin.is_fbcode() else None
if is_fbcode():
    global_cache_dir = "fb/cache"
else:
    global_cache_dir = None

# If kernel is fused, the name is generated from the origin node op names
# for larger kernels limit this
kernel_name_max_ops: int = 10
kernel_name_max_ops = 10

# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
shape_padding: bool = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"
shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"

# Fx-based linear/matmul/bmm + permute/transpose vertical fusion
permute_fusion: bool = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"
permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"

# Mark the wrapper call in PyTorch profiler
profiler_mark_wrapper_call: bool = False
profiler_mark_wrapper_call = False

# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for
# every intermediate for which we can correlate it with an intermediate
# from the original FX graph
generate_intermediate_hooks: bool = False
generate_intermediate_hooks = False

# Populate traceback field on IRNode; good for debugging why origin_node is
# not populated, or finding out where an IRNode was constructed
debug_ir_traceback: bool = False
debug_ir_traceback = False

# used for debugging to make sure config is properly set
_raise_error_for_testing: bool = False
_raise_error_for_testing = False

_profile_var: str = os.environ.get("TORCHINDUCTOR_PROFILE", "")
profile_bandwidth: bool = _profile_var != ""
profile_bandwidth_regex: str = "" if _profile_var == "1" else _profile_var
_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "")
profile_bandwidth = _profile_var != ""
profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var

disable_cpp_codegen: bool = ConfigMixin.is_fbcode()
disable_cpp_codegen = is_fbcode()


# config specific to codegen/cpp.pp
@dataclasses.dataclass
class CppConfig(ConfigMixin):
# config specific to codegen/cpp.py
class cpp:
    # set to torch.get_num_threads()
    threads: int = -1
    threads = -1

    # Do not generate loops when the condition doesn't hold, like:
    # for(long i0=4096; i0<4096; i0+=1)
    no_redundant_loops: bool = True
    no_redundant_loops = True

    # Assume number of threads is dynamic, don't specialize thread number.
    # Kernels don't recompile on thread number changes with this flag on.
    # For single-threaded workload, turning it on would incur a slight
    # performance degradation.
    dynamic_threads: bool = False
    dynamic_threads = False

    simdlen: Optional[int] = None
    min_chunk_size: int = 4096
    cxx: Tuple[Optional[str]] = (
    simdlen = None
    min_chunk_size = 4096
    cxx = (
        None, # download gcc12 from conda-forge if conda is installed
        # "g++-12",
        # "g++-11",
@@ -228,160 +221,156 @@ class CppConfig(ConfigMixin):
        # "g++.par",
    )
    # Allow kernel performance profiling via PyTorch profiler
    enable_kernel_profile: bool = False
    enable_kernel_profile = False

    # enable weight prepacking to get a better performance; may lead to large memory footprint
    weight_prepack: bool = True
    weight_prepack = True

    # Inject a bug into our relu implementation; useful for testing our repro
    # extraction and minification functionality.
    # Valid values: "compile_error", "runtime_error", "accuracy"
    inject_relu_bug_TESTING_ONLY: Optional[str] = None
    inject_log1p_bug_TESTING_ONLY: Optional[str] = None
    inject_relu_bug_TESTING_ONLY = None
    inject_log1p_bug_TESTING_ONLY = None

    # If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise,
    # force usage as specified, without testing.
    vec_isa_ok: Optional[str] = None
    vec_isa_ok = None

    # similar to config.triton.descriptive_names
    descriptive_names: str = "original_aten"
    descriptive_names = "original_aten"

    # how many nodes to allow into a single horizontal fusion
    max_horizontal_fusion_size: int = 16
    max_horizontal_fusion_size = 16


# config specific to codegen/triton.py
@dataclasses.dataclass
class TritonConfig(ConfigMixin):
class triton:
    # Use cudagraphs on output code
    cudagraphs: bool = False
    cudagraphs = False

    # Use cudagraph trees for memory pooling if `cudagraphs` is True
    cudagraph_trees: bool = not ConfigMixin.is_fbcode()
    cudagraph_trees = not is_fbcode()

    # assertions not on the fast path, steady state
    slow_path_cudagraph_asserts: bool = True
    slow_path_cudagraph_asserts = True

    # TODO - need to debug why this prevents cleanup
    cudagraph_trees_history_recording: bool = False
    cudagraph_trees_history_recording = False

    # assertions on the fast path
    fast_path_cudagraph_asserts: bool = False
    fast_path_cudagraph_asserts = False

    # skip warmup for cudagraph trees
    skip_cudagraph_warmup: bool = False
    skip_cudagraph_warmup = False

    # Synchronize before and after every compiled graph.
    debug_sync_graph: bool = False
    debug_sync_graph = False

    # Synchronize after every kernel launch, to help pinpoint bugs
    debug_sync_kernel: bool = False
    debug_sync_kernel = False

    # Always load full blocks (rather than broadcasting inside the block)
    dense_indexing: bool = False
    dense_indexing = False

    # limit tiling dimensions
    max_tiles: int = 2
    max_tiles = 2

    # use triton.autotune for pointwise ops with complex layouts
    # this should only be disabled for debugging/testing
    autotune_pointwise: bool = True
    autotune_pointwise = True

    # should we stop a fusion to allow better tiling?
    tiling_prevents_pointwise_fusion: bool = True
    tiling_prevents_reduction_fusion: bool = True
    tiling_prevents_pointwise_fusion = True
    tiling_prevents_reduction_fusion = True

    # assert that indirect indexing does not read / write out of bounds
    assert_indirect_indexing: bool = True
    assert_indirect_indexing = True

    # should we give different names to kernels
    # Note: This is orthogonal to descriptive_names - this is deciding whether
    # our triton kernel names should all be `triton_` (to maximize caching) or
    # whether they should be unique.
    unique_kernel_names: bool = (
        os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1"
    )
    unique_kernel_names = os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1"

    # should we put op names in kernel names
    # False: No special names (just triton__1, triton__2, etc.)
    # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
    # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
    # "inductor_node": Maps to the node name in the FX graph passed to Inductor
    descriptive_names: str = "original_aten"
    descriptive_names = "original_aten"

    # use alternate codegen for smaller reductions
    persistent_reductions: bool = (
    persistent_reductions = (
        os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"
    )

    # hint to Triton when arguments are divisible by 16
    divisible_by_16: bool = True
    divisible_by_16 = True

    # theses are not enforced, but they are used by asserts in triton_heuristics.py
    # NOTE: mobilevit_s in timm_models required X to be set to the higher value 2048
    max_block: Dict[str, int] = dataclasses.field(
        default_factory=lambda: {"X": 2048, "Y": 1024, "Z": 1024}
    )
    max_block = {"X": 2048, "Y": 1024, "Z": 1024}

    # Store the generated cubin files for cpp wrapper code to load
    store_cubin: bool = False
    store_cubin = False

    # the max number of spills we allow for the configs we benchmark.
    # Setting this to 0 means we skip a config if it spills even a single
    # register.
    # Settting it to a larger value allows a config spilling a small amount
    # of registers being benchmarked.
    spill_threshold: int = 0
    spill_threshold = 0

    # Inject a bug into our relu implementation; useful for testing our repro
    # extraction and minification functionality.
    # Valid values: "compile_error", "runtime_error", "accuracy"
    inject_relu_bug_TESTING_ONLY: Optional[str] = None
    inject_relu_bug_TESTING_ONLY = None


# create a directory containing lots of debug information
@dataclasses.dataclass
class TraceConfig(ConfigMixin):
class trace:
    # master switch for all debugging flags below
    enabled: bool = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
    enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

    # Save python logger call >=logging.DEBUG
    debug_log: bool = False
    debug_log = False

    # Save python logger call >=logging.INFO
    info_log: bool = False
    info_log = False

    # Save input FX graph (post decomps, pre optimization)
    fx_graph: bool = True
    fx_graph = True

    # Save FX graph after transformations
    fx_graph_transformed: bool = True
    fx_graph_transformed = True

    # Save TorchInductor IR before fusion pass
    ir_pre_fusion: bool = True
    ir_pre_fusion = True

    # Save TorchInductor IR after fusion pass
    ir_post_fusion: bool = True
    ir_post_fusion = True

    # Copy generated code to trace dir
    output_code: bool = True
    output_code = True

    # SVG figure showing post-fusion graph
    graph_diagram: bool = False
    graph_diagram = False

    # Store cProfile (see snakeviz to view)
    compile_profile: bool = False
    compile_profile = False

    # Upload the .tar.gz file
    # Needs to be overriden based on specific environment needs
    # skip_pickle will make this field skipped by pickle
    upload_tar: Any = dataclasses.field(default=None, metadata={"skip_pickle": True})
    upload_tar = None


cpp: CppConfig = CppConfig()
triton: TritonConfig = TritonConfig()
trace: TraceConfig = TraceConfig()
_save_config_ignore = {
    # workaround: "Can't pickle <function ...>"
    "trace.upload_tar",
}

from torch._config_utils import install_config_module

install_config_module("InductorConfig", sys.modules[__name__])
from .._dynamo.config_utils import install_config_module

# adds patch, save_config, etc
install_config_module(sys.modules[__name__])
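The inductor config above shows the core of what this revert changes: nested dataclass fields with skip_pickle metadata go back to plain nested classes plus a _save_config_ignore set. A hedged side-by-side sketch of the two spellings shown in the hunk (no new API):

# Reverted-away dataclass style:
#     trace: TraceConfig = TraceConfig()
#     upload_tar: Any = dataclasses.field(default=None, metadata={"skip_pickle": True})
#
# Restored module style:
#     class trace:
#         upload_tar = None
#     _save_config_ignore = {"trace.upload_tar"}
#
# Either way, user code reads and writes the same dotted path:
import torch

torch._inductor.config.trace.enabled = True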