Add compiler bisector (#131936)

This is a utility to aid torch.compile debugging. You provide a function that returns True on success and False on failure; alternatively, you can run the check out of process and invoke the bisect helper with `good` or `bad`.

The bisector will first go through backends - `eager`, `aot_eager`, `aot_eager_decomp_partition`, `inductor` to find the first failing backend. Then, it will go through subsystems within the backend - currently limited but could be expanded - and try to find the first subsystem for which disabling fixes the problem. Once it has found the failing subsystem, it will find the number of times the subsystem is applied, and then bisect through it.

An example usage of how to hook it up for aot_eager_decomp_partition and decomposition subsystem is :

```
    from torch._inductor.bisect_helper import BisectionManager
    if op in CURRENT_DECOMPOSITION_TABLE:
        if BisectionManager.disable_subsystem("aot_eager_decomp_partition", "decomposition", lambda: repr(op)):
            return NotImplemented
```

Once it has discovered the problematic change, it will print out the associated debug info, and you can set the same limits with `TORCH_BISECT_BACKEND` `TORCH_BISECT_SUBSYSTEM` and `TORCH_BISECT_MAX`.

We could add further options as an automated way of going through a checklist for diagnosing divergence - e.g., a mode to emulate amp casts.

Fix for https://github.com/pytorch/pytorch/issues/126546

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131936
Approved by: https://github.com/ezyang
This commit is contained in:
eellison
2024-10-09 10:10:52 -07:00
committed by PyTorch MergeBot
parent cfe970260a
commit 47af7cc962
10 changed files with 685 additions and 41 deletions

View File

@ -0,0 +1,112 @@
# Owner(s): ["module: dynamo"]
import unittest
from contextlib import contextmanager
from importlib import import_module
import torch
import torch._prims_common as utils
from torch._dynamo.test_case import TestCase
from torch._inductor import config
from torch._inductor.bisect_helper import BisectionManager
from torch.testing._internal.inductor_utils import HAS_CUDA
aten = torch.ops.aten
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
f32 = torch.float32
i64 = torch.int64
i32 = torch.int32
@requires_cuda
class TestCompilerBisector(TestCase):
    """End-to-end tests for BisectionManager.do_bisect: each test injects a
    known failure and checks that the bisector reports the expected backend,
    subsystem, and bisect number."""

    def test_bad_decomp(self):
        """A NaN-producing exponential_ decomposition should be attributed to
        the aot_eager_decomp_partition backend's "decomposition" subsystem."""
        # Imported for the side effect of loading the module patched below.
        mod = import_module("torch._inductor.compile_fx")

        def bad_exp_decomp(self, rate=1, generator=None):
            # Deliberately broken decomposition: performs the same dtype/rate
            # checks as a real decomp, but returns an all-NaN tensor.
            assert generator is None
            torch._check(
                not utils.is_complex_dtype(self.dtype)
                and not utils.is_integer_dtype(self.dtype)
                and not utils.is_boolean_dtype(self.dtype),
                lambda: f"Exponential distribution is a continuous probability distribution. \
dtype must be a floating point but you specified {self.dtype}",
            )
            torch._check(
                rate > 0.0,
                lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
            )
            return torch.rand_like(self) * float("nan")

        @contextmanager
        def patch_exp_decomp():
            # Swap the decomp-table factory so compilation picks up the broken
            # exponential decomposition; restore the original on exit.
            from torch._inductor.compile_fx import select_decomp_table as old_decomp

            def get_decomp():
                out = old_decomp()
                out = out.copy()
                out[aten.exponential.default] = bad_exp_decomp
                return out

            torch._inductor.compile_fx.select_decomp_table = get_decomp
            try:
                yield
            finally:
                torch._inductor.compile_fx.select_decomp_table = old_decomp

        def vq(x):
            return (x + 3).exponential_() * 10.5

        def test_fn():
            torch._dynamo.reset()
            with patch_exp_decomp():
                vq_compiled = torch.compile(vq)
                x = torch.randn(4, 400, 256).cuda()
                # Restore RNG state after the eager run so the compiled run
                # sees the same random stream. The eager result `out` is not
                # compared directly — presumably it exists only to consume the
                # RNG identically; verify against original test intent.
                with torch._dynamo.utils.preserve_rng_state():
                    out = vq(x)
                out_compiled = vq_compiled(x)

            return not out_compiled.isnan().any()

        out = BisectionManager.do_bisect(test_fn)
        self.assertEqual(out.backend, "aot_eager_decomp_partition")
        self.assertEqual(out.subsystem, "decomposition")
        self.assertEqual(out.bisect_number, 1)
        self.assertTrue("aten.exponential" in out.debug_info)

    def test_bad_lowering(self):
        """An injected relu accuracy bug should be attributed to the inductor
        backend's "lowerings" subsystem."""

        def test_fn():
            torch._dynamo.reset()
            with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"):

                def my_func(x):
                    return ((x * -1) - 0.01).relu()

                inp = torch.rand([100], device="cuda")

                return torch.allclose(torch.compile(my_func)(inp), my_func(inp))

        out = BisectionManager.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "lowerings")
        self.assertEqual(out.bisect_number, 2)
        self.assertTrue("relu" in out.debug_info)

    def test_eager_backend(self):
        # should indicate problem with first backend
        def test_fn():
            return False

        out = BisectionManager.do_bisect(test_fn)
        self.assertEqual(out.backend, "eager")
        self.assertEqual(out.subsystem, None)
if __name__ == "__main__":
    # Entry point when this test file is run directly.
    from torch._dynamo.test_case import run_tests

    run_tests()

View File

@ -220,6 +220,7 @@ RUN_PARALLEL_BLOCKLIST = [
"test_cuda_nvml_based_avail",
# temporarily sets a global config
"test_autograd_fallback",
"inductor/test_compiler_bisector",
] + FSDP_TEST
# Test files that should always be run serially with other test files,

View File

@ -2446,6 +2446,12 @@ def compile(
)
if mode is None and options is None:
mode = "default"
from torch._inductor.bisect_helper import BisectionManager
if bisect_backend := BisectionManager.get_backend():
backend = bisect_backend
if backend == "inductor":
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:

View File

@ -0,0 +1,470 @@
import collections
import dataclasses
import functools
import os
import shutil
import sys
from typing import Callable, Dict, List, Optional, Tuple
from torch._inductor.runtime.cache_dir_utils import cache_dir
# Name of the subdirectory (under the inductor cache dir) that stores all
# bisection state files.
SUBDIR_NAME = "bisect"

# Dictionary of backend -> subsystems, in the order they are bisected.
BACKENDS: Dict[str, List[str]] = {
    # run dynamo without aot_autograd
    "eager": [],
    # run dynamo with aot_autograd, but no partitioner or decomps
    "aot_eager": [],
    # run dynamo with aot autograd, decompositions and partitioner
    "aot_eager_decomp_partition": [
        "decomposition"  # number of decompositions we apply in tracing
    ],  # TODO - add cse ?
    "inductor": [
        "post_grad_passes",  # passes applied individually on forward, and backward in inductor
        "lowerings",  # lowering aten operators to inductor
    ],  # TODO - add more - fusions, amp numeric mode ?
}
# Per-process counters: how many times each subsystem has been applied, and
# the debug string captured at each call index.
subsystem_call_counter: Dict[str, int] = collections.Counter()
call_counter_debug_info: Dict[int, str] = {}


def reset_counters() -> None:
    """Clear the per-run subsystem counters and captured debug info in place."""
    for tracker in (subsystem_call_counter, call_counter_debug_info):
        tracker.clear()
@functools.lru_cache(None)
def get_env_val(env_str: str) -> Optional[str]:
    """Read an environment variable; the first result is cached per process."""
    return os.environ.get(env_str)
@dataclasses.dataclass
class BisectionResult:
    """Outcome of a completed bisection run.

    backend: torch.compile backend responsible for failure
    subsystem: optional, registered component identified for failure
    bisect_number: optional, number of times the subsystem needed to be applied to trigger failure
    debug_info: associated info of the triggering bisect application of subsystem
    """

    backend: str
    subsystem: Optional[str] = None
    bisect_number: Optional[int] = None
    debug_info: Optional[str] = None
class BisectionManager:
    """File-backed driver for bisecting torch.compile failures.

    All state (active backend/subsystem, run stage, and the binary-search
    range) is persisted as small text files under ``cache_dir()/bisect``, so
    a session can span multiple processes: each run reads the files, tries
    one configuration, and writes the result for the next run.
    """

    # Enabled by do_bisect() in-process, or at import time when status files /
    # TORCH_BISECT_* env vars indicate an active session (see module bottom).
    bisection_enabled: bool = False

    @classmethod
    def get_dir(cls) -> str:
        """Root directory holding all bisection state files."""
        return f"{cache_dir()}/{SUBDIR_NAME}"

    @classmethod
    def write_lines_to_file(cls, file_path: str, lines: List[str]) -> None:
        """Overwrite ``file_path`` with ``lines``, creating parent dirs."""
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "w") as file:
            file.writelines(lines)

    @classmethod
    def read_lines_from_file(cls, file_path: str) -> List[str]:
        """Return the lines of ``file_path``, or [] if it does not exist."""
        if os.path.exists(file_path):
            with open(file_path) as file:
                return file.readlines()
        return []

    @classmethod
    def update_run_state(
        cls, backend_name: str, subsystem_name: str, run_state: str
    ) -> None:
        """Persist the bisection stage for a backend/subsystem pair."""
        file_path = os.path.join(
            cls.get_dir(), backend_name, f"{subsystem_name}_run_state.txt"
        )
        cls.write_lines_to_file(file_path, [run_state])

    @classmethod
    def update_bisect_status(cls, backend_name: str, subsystem_name: str) -> None:
        """Persist which backend/subsystem is currently under investigation."""
        file_path = os.path.join(cls.get_dir(), "bisect_status.txt")
        lines = [f"backend={backend_name}\n", f"subsystem={subsystem_name}\n"]
        cls.write_lines_to_file(file_path, lines)

    @classmethod
    def update_bisect_range(
        cls, backend_name: str, subsystem_name: str, low: int, high: int
    ) -> None:
        """Persist the current [low, high] binary-search window."""
        file_path = os.path.join(
            cls.get_dir(), backend_name, f"{subsystem_name}_bisect_range.txt"
        )
        lines = [f"low={low}\n", f"high={high}\n"]
        cls.write_lines_to_file(file_path, lines)

    @classmethod
    def get_backend(cls) -> Optional[str]:
        """
        Returns the active backend, if any. The TORCH_BISECT_BACKEND env var
        takes precedence over the persisted status file.
        """
        if val := get_env_val("TORCH_BISECT_BACKEND"):
            return val

        file_path = os.path.join(cls.get_dir(), "bisect_status.txt")
        lines = cls.read_lines_from_file(file_path)
        for line in lines:
            if line.startswith("backend="):
                return line.strip().split("=")[1]
        return None

    @classmethod
    def get_subsystem(cls) -> Optional[str]:
        """
        Returns the active subsystem, if any. The TORCH_BISECT_SUBSYSTEM env
        var takes precedence. May return "" (falsy) when the status file
        records an empty subsystem.
        """
        if val := get_env_val("TORCH_BISECT_SUBSYSTEM"):
            return val

        file_path = os.path.join(cls.get_dir(), "bisect_status.txt")
        lines = cls.read_lines_from_file(file_path)
        for line in lines:
            if line.startswith("subsystem="):
                return line.strip().split("=")[1]
        return None

    @classmethod
    def get_run_state(cls, backend_name: str, subsystem_name: str) -> Optional[str]:
        """
        Returns the current stage of bisecting, if any: one of
        "test_disable", "find_max_bounds", or "bisect".
        """
        file_path = os.path.join(
            cls.get_dir(), backend_name, f"{subsystem_name}_run_state.txt"
        )
        lines = cls.read_lines_from_file(file_path)
        if lines:
            out = lines[0].strip()
            assert out in ("test_disable", "find_max_bounds", "bisect")
            return out
        return None

    @classmethod
    def get_bisect_range(
        cls, backend_name: str, subsystem_name: str
    ) -> Tuple[int, int]:
        """Read the persisted [low, high] window; raise if it was never set."""
        file_path = os.path.join(
            cls.get_dir(), backend_name, f"{subsystem_name}_bisect_range.txt"
        )
        lines = cls.read_lines_from_file(file_path)
        low = None
        high = None
        # Scan from the end so the most recently written values win.
        for line in reversed(lines):
            if line.startswith("low="):
                low = int(line.strip().split("=")[1])
            elif line.startswith("high="):
                high = int(line.strip().split("=")[1])

            if low is not None and high is not None:
                break

        if low is None or high is None:
            raise RuntimeError(
                f"Trying to get bisect range when it is not set: subsystem {subsystem_name}"
            )

        return low, high

    @classmethod
    def delete_bisect_status(cls) -> None:
        """Remove all persisted bisection state."""
        if os.path.exists(cls.get_dir()):
            shutil.rmtree(cls.get_dir())
            print("Bisection status deleted.")
        else:
            print("No bisection status found.")

    @classmethod
    def get_system_counter(cls, name: str, increment: bool = True) -> int:
        """Return the current call index for a subsystem, optionally advancing it."""
        global subsystem_call_counter
        curr = subsystem_call_counter[name]
        if increment:
            subsystem_call_counter[name] += 1
        return curr

    @classmethod
    def disable_subsystem(
        cls,
        backend: str,
        subsystem: str,
        debug_info: Optional[Callable[[], str]] = None,
    ) -> bool:
        """Hook called from instrumented code; returns True when the current
        application of ``subsystem`` should be skipped.

        Behavior depends on the persisted run state:
        - "test_disable": skip every application.
        - "find_max_bounds": apply everything, counting applications to
          establish the upper bound of the search range.
        - "bisect": skip applications above the midpoint of the range.
        """
        if not cls.bisection_enabled:
            return False

        if cls.get_backend() != backend:
            return False

        if cls.get_subsystem() != subsystem:
            return False

        # Manual override: skip all applications past a fixed index.
        if val := get_env_val("TORCH_BISECT_MAX"):
            counter = cls.get_system_counter(subsystem, increment=True)
            return counter > int(val)

        run_state = cls.get_run_state(backend, subsystem)
        if run_state == "test_disable":
            # First run, disable completely
            return True
        elif run_state == "find_max_bounds":
            # Second run: keep the subsystem enabled (return False) while
            # recording the running count as the upper bound of the range.
            cls.update_bisect_range(
                backend,
                subsystem,
                0,
                cls.get_system_counter(subsystem, increment=True),
            )
            return False
        else:
            assert run_state == "bisect"
            low, high = cls.get_bisect_range(backend, subsystem)
            midpoint = (low + high) // 2
            call_counter = cls.get_system_counter(subsystem)
            if (
                call_counter >= low
                and call_counter <= high
                # NOTE(review): (low - high) <= 2 is always true when
                # low <= high; this may have intended (high - low) <= 2 so
                # debug info is only captured near convergence — confirm.
                and (low - high) <= 2
                and debug_info is not None
            ):
                call_counter_debug_info[call_counter] = debug_info()
            # Applications at or below the midpoint stay enabled.
            return call_counter > midpoint

    @classmethod
    def advance_subsystem(cls, curr_backend: str, curr_subsystem: str) -> Optional[str]:
        """
        Tries to move to the next subsystem within the current system.
        Returns the new subsystem name, or None when all have been checked.
        """
        print(f"Disabling {curr_subsystem} did not fix the issue.")

        current_subsystems = BACKENDS[curr_backend]
        current_subsystem_index = current_subsystems.index(curr_subsystem)

        if current_subsystem_index < len(current_subsystems) - 1:
            curr_subsystem = current_subsystems[current_subsystem_index + 1]
            cls.update_bisect_status(curr_backend, curr_subsystem)
            cls.update_run_state(curr_backend, curr_subsystem, "test_disable")
            print(f"Moving to the next subsystem: {curr_backend} - {curr_subsystem}")
            return curr_subsystem
        else:
            print(
                f"All subsystems in {curr_backend} have been checked. The issue is not in this system."
            )
            return None

    @classmethod
    def advance_backend(cls, curr_backend: str) -> Optional[str]:
        """
        Tries to move to the next backend. Returns the new backend name, or
        None when all backends have been checked.
        """
        current_system_index = list(BACKENDS.keys()).index(curr_backend)

        if current_system_index < len(BACKENDS) - 1:
            curr_backend = list(BACKENDS.keys())[current_system_index + 1]
            cls.update_bisect_status(curr_backend, "")
            print(f"Moving to the next system: {curr_backend}")
            return curr_backend
        else:
            return None

    @classmethod
    def perform_bisection(
        cls,
        curr_backend: str,
        curr_subsystem: str,
        fn: Callable[[], bool],
        cli_interface: bool = True,
    ) -> bool:
        """
        Perform the bisection process for the current system and subsystem.
        Returns True if the issue is found, False otherwise. With
        cli_interface=True, each invocation performs a single step and exits.
        """
        while True:
            run_state = cls.get_run_state(curr_backend, curr_subsystem)
            reset_counters()
            if run_state == "test_disable":
                if not fn():
                    # Disabling did not help; the fault lies elsewhere.
                    next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem)
                    if not next_subsystem:
                        return False
                    curr_subsystem = next_subsystem
                else:
                    print(
                        f"Disabling {curr_subsystem} fixed the issue. Starting bisect by getting upper bound."
                    )
                    cls.update_run_state(
                        curr_backend, curr_subsystem, "find_max_bounds"
                    )
            elif run_state == "find_max_bounds":
                if fn():
                    raise RuntimeError(
                        f"Function succeeded with 'find_max_bounds' status for {curr_backend} - {curr_subsystem}."
                    )
                else:
                    _, high = cls.get_bisect_range(curr_backend, curr_subsystem)
                    print(f"Upper bound of {high} found for {curr_backend}.")
                    cls.update_run_state(curr_backend, curr_subsystem, "bisect")
            elif run_state == "bisect":
                low, high = cls.get_bisect_range(curr_backend, curr_subsystem)
                midpoint = (low + high) // 2
                print(
                    f"Bisecting {curr_backend} - {curr_subsystem} (Range: [{low}, {high}], Midpoint: {midpoint})"
                )
                if fn():
                    # Success with [0, midpoint] enabled: culprit is above it.
                    cls.update_bisect_range(
                        curr_backend, curr_subsystem, midpoint + 1, high
                    )
                else:
                    cls.update_bisect_range(curr_backend, curr_subsystem, low, midpoint)
                low, high = cls.get_bisect_range(curr_backend, curr_subsystem)
                if low == high:
                    print(
                        f"Binary search completed for {curr_backend} - {curr_subsystem}. The bisect number is {low}. "
                        f"Debug info: {call_counter_debug_info.get(low, 'not found')}"
                    )
                    return True
            else:
                raise RuntimeError(f"Unexpected run_state {run_state}")

            if cli_interface:
                # Out-of-process mode: one step per invocation.
                sys.exit(0)

    @classmethod
    def initialize_system(cls) -> None:
        """Start a fresh session at the first backend with no subsystem set."""
        curr_backend = next(iter(BACKENDS.keys()))
        curr_subsystem = ""
        cls.update_bisect_status(curr_backend, curr_subsystem)
        print(f"Starting bisection process with system: {curr_backend}")

    @classmethod
    def do_bisect(
        cls, fn: Callable[[], bool], cli_interface: bool = False
    ) -> Optional[BisectionResult]:
        """Run the full bisection loop over backends and subsystems.

        ``fn`` returns True on success, False on failure. Returns a
        BisectionResult identifying the failing backend (plus subsystem and
        application number when found), or None when every backend succeeds.
        """
        if not cli_interface:
            bisection_enabled_orig = cls.bisection_enabled
            cls.delete_bisect_status()
            cls.bisection_enabled = True

            # TODO - cli interface, and in-process different directories
            class DisableBisect:
                # Restores the prior state when the guard is garbage collected.
                def __del__(self) -> None:
                    cls.bisection_enabled = bisection_enabled_orig
                    cls.delete_bisect_status()

            cleanup = DisableBisect()

        curr_backend = cls.get_backend()
        curr_subsystem = cls.get_subsystem()

        if not curr_backend:
            cls.initialize_system()
            curr_backend = cls.get_backend()
            curr_subsystem = cls.get_subsystem()

        while True:
            assert curr_backend is not None
            reset_counters()
            if curr_subsystem:
                result = cls.perform_bisection(
                    curr_backend, curr_subsystem, fn, cli_interface=cli_interface
                )
                if result:
                    curr_subsystem = cls.get_subsystem()
                    assert curr_subsystem is not None
                    low, _ = cls.get_bisect_range(curr_backend, curr_subsystem)
                    return BisectionResult(
                        curr_backend,
                        curr_subsystem,
                        low,
                        call_counter_debug_info.get(low, None),
                    )

                next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem)
                if not next_subsystem:
                    print(
                        f"The issue is in the {curr_backend} system, but could not identify subsystem."
                    )
                    assert curr_backend is not None
                    return BisectionResult(curr_backend)

                curr_subsystem = next_subsystem
            else:
                if fn():
                    # This backend works; try the next one.
                    next_backend = cls.advance_backend(curr_backend)
                    if not next_backend:
                        print("All systems have been checked.")
                        return None

                    curr_backend = next_backend
                else:
                    # First failing backend: descend into its subsystems, if any.
                    current_subsystems = BACKENDS[curr_backend]
                    if current_subsystems:
                        curr_subsystem = current_subsystems[0]
                        cls.update_bisect_status(curr_backend, curr_subsystem)
                        cls.update_run_state(
                            curr_backend, curr_subsystem, "test_disable"
                        )
                        print(
                            f"The issue is in the {curr_backend} system. Moving to the first subsystem: {curr_subsystem}"
                        )
                    else:
                        print(f"The issue is in the {curr_backend} system.")
                        return BisectionResult(curr_backend)

            if cli_interface:
                sys.exit(0)
def command_line_usage() -> None:
    """CLI entry point: drive a bisection session via start/end/good/bad."""
    if len(sys.argv) < 2:
        print("Usage: python bisect_update.py <start|end|good|bad>")
        sys.exit(1)

    manager = BisectionManager()
    command = sys.argv[1]

    # Session management commands both clear state; "start" also re-seeds it.
    if command in ("end", "start"):
        manager.delete_bisect_status()
        if command == "start":
            manager.initialize_system()
        sys.exit(0)

    if command in ("good", "bad"):
        if not manager.get_backend():
            raise ValueError("Must call start prior to good or bad")

        # The out-of-process "test function" simply reports what the user
        # observed for this run.
        def test_function() -> bool:
            return command == "good"

        manager.do_bisect(test_function, cli_interface=True)
        return

    print("Invalid command. Must be 'good', 'bad', 'start', or 'end'.")
    sys.exit(1)
def get_is_bisection_enabled() -> bool:
    """True when a bisection session is active (a subsystem or backend is set)."""
    if BisectionManager.get_subsystem() is not None:
        return True
    return BisectionManager.get_backend() is not None
# Enable bisection at import time when an active session is detected (env
# vars or persisted status files), so instrumented call sites participate.
BisectionManager.bisection_enabled = get_is_bisection_enabled()

if __name__ == "__main__":
    command_line_usage()

View File

@ -1285,6 +1285,12 @@ class FxGraphCache:
"Freezing may introduce constants that aren't static across runs"
)
from torch._inductor.bisect_helper import BisectionManager
if BisectionManager.bisection_enabled:
log.debug("dont cache graph when bisect enabled")
raise BypassFxGraphCache
# The treatment of guards in the caching implementation requires that
# we have a shape env.
if FxGraphCache._get_shape_env() is None:

View File

@ -4,7 +4,7 @@ import itertools
import logging
import operator
from collections import Counter, defaultdict
from typing import Any, Dict, List, Optional, Set
from typing import Any, Callable, Dict, List, Optional, Set
import torch
import torch._inductor as inductor
@ -65,6 +65,19 @@ pass_patterns = [
]
def apply_pass(pass_fn: Callable[[], object], name: Optional[str] = None) -> None:
    """Run a single post-grad pass unless the compiler bisector has disabled
    the inductor "post_grad_passes" subsystem for this application.

    ``name`` is surfaced as debug info when the bisector pins this pass.
    """
    # TODO - we should just make this part of GraphTransformObserver
    from torch._inductor.bisect_helper import BisectionManager

    debug_info: Optional[Callable[[], str]] = None
    if name is not None:

        def _describe() -> str:
            return name

        debug_info = _describe

    if not BisectionManager.disable_subsystem(
        "inductor", "post_grad_passes", debug_info
    ):
        pass_fn()
def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
"""
Passes that run on after grad. This is called once on the forwards
@ -80,23 +93,28 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
gm.graph.eliminate_dead_code()
if is_inference and config.reorder_for_locality:
reorder_for_locality(gm.graph)
apply_pass(lambda: reorder_for_locality(gm.graph), "reorder_for_locality")
fake_tensor_updater = FakeTensorUpdater(gm.graph)
if config.post_grad_custom_pre_pass is not None:
if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass:
with GraphTransformObserver(
gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform
):
config.post_grad_custom_pre_pass(gm.graph)
apply_pass(
lambda: post_grad_custom_pre_pass(gm.graph), "post_grad_custom_pre_pass"
)
if config.pattern_matcher:
lazy_init()
optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
group_batch_fusion_passes(gm.graph, pre_grad=False)
remove_noop_ops(gm.graph)
for patterns in pass_patterns:
patterns.apply(gm.graph) # type: ignore[arg-type]
apply_pass(
lambda: group_batch_fusion_passes(gm.graph, pre_grad=False),
"group_batch_fusion_passes",
)
apply_pass(lambda: remove_noop_ops(gm.graph), "remove_noop_ops")
for i, patterns in enumerate(pass_patterns):
apply_pass(lambda: patterns.apply(gm.graph), f"pass_pattern_{i}") # type: ignore[arg-type]
for pass_name in config.post_grad_fusion_options:
# skip all patterns for group batch fusions
if pass_name in POST_GRAD_FUSIONS:
@ -105,7 +123,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
inductor_before_change = save_inductor_dict(
[pattern_matcher_pass.pass_name]
)
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
apply_pass(lambda: pattern_matcher_pass.apply(gm.graph), pass_name) # type: ignore[arg-type]
if not is_same_dict(counters["inductor"], inductor_before_change):
optimus_scuba_log[
f"{pattern_matcher_pass.pass_name}_post_grad"
@ -117,30 +135,40 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
micro_pipeline_tp_pass(gm.graph)
if config._fuse_ddp_communication:
fuse_ddp_communication(
apply_pass(
lambda: fuse_ddp_communication(
gm.graph,
config._fuse_ddp_communication_passes,
config._fuse_ddp_bucket_size,
),
"fuse_ddp_communication",
)
if config.post_grad_custom_post_pass is not None:
if post_grad_custom_post_pass := config.post_grad_custom_post_pass:
with GraphTransformObserver(
gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform
):
config.post_grad_custom_post_pass(gm.graph)
apply_pass(
lambda: post_grad_custom_post_pass(gm.graph),
"post_grad_custom_post_pass",
)
stable_topological_sort(gm.graph)
apply_pass(lambda: stable_topological_sort(gm.graph), "stable_sort")
move_constructors_to_gpu(gm.graph)
apply_pass(lambda: move_constructors_to_gpu(gm.graph), "move_constructors_to_cuda")
fake_tensor_updater.incremental_update()
# Keep these last, since they introduces mutation. Look at
# ./fx_passes/README.md for a discussion of mutation invariants.
reinplace_inplaceable_ops(gm.graph)
decompose_auto_functionalized(gm.graph)
apply_pass(lambda: reinplace_inplaceable_ops(gm.graph), "reinplace_inplaceable_ops")
apply_pass(
lambda: decompose_auto_functionalized(gm.graph), "decompose_auto_functionalized"
)
comms.reinplace_fsdp_all_gather(gm.graph)
apply_pass(
lambda: comms.reinplace_fsdp_all_gather(gm.graph), "reinplace_fsdp_all_gather"
)
gm.recompile()
optimus_scuba_log["after_recompile_post_grad"] = upload_graph(gm.graph)

View File

@ -1304,6 +1304,8 @@ class GraphLowering(torch.fx.Interpreter):
def debug(msg: str) -> None:
log.debug("lowering %s %s", LazyString(n.format_node), msg)
from torch._inductor.bisect_helper import BisectionManager
buffer_watermark = len(self.buffers)
operation_watermark = len(self.operations)
@ -1320,7 +1322,12 @@ class GraphLowering(torch.fx.Interpreter):
if (
n.op == "call_function"
and n.target is not operator.getitem
and fallback_node_due_to_unsupported_type(n)
and (
fallback_node_due_to_unsupported_type(n)
or BisectionManager.disable_subsystem(
"inductor", "lowerings", lambda: repr(n)
)
)
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(

View File

@ -0,0 +1,23 @@
import getpass
import os
import re
import tempfile
# Factoring out to file without torch dependencies
def cache_dir() -> str:
    """Return the inductor cache directory, creating it if necessary.

    Honors TORCHINDUCTOR_CACHE_DIR when set; otherwise computes the per-user
    default and exports it back into the environment so subprocesses use the
    same directory.

    Fix: the original bound a local named ``cache_dir``, shadowing this
    function's own name; renamed to ``path`` for clarity.
    """
    path = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    if path is None:
        path = default_cache_dir()
        # Export so child processes agree on the cache location.
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = path
    os.makedirs(path, exist_ok=True)
    return path
def default_cache_dir() -> str:
    """Per-user default inductor cache location under the system temp dir."""
    # Strip characters that are illegal in directory names on common platforms.
    user = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
    return os.path.join(tempfile.gettempdir(), "torchinductor_" + user)

View File

@ -3,13 +3,13 @@ from __future__ import annotations
import contextlib
import functools
import getpass
import operator
import os
import re
import tempfile
import torch
from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401
cache_dir,
default_cache_dir,
)
def conditional_product(*args):
@ -86,22 +86,6 @@ def get_max_y_grid():
return 65535
def cache_dir() -> str:
cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
if cache_dir is None:
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir()
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
def default_cache_dir():
sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
return os.path.join(
tempfile.gettempdir(),
"torchinductor_" + sanitized_username,
)
try:
import colorama

View File

@ -2202,7 +2202,14 @@ def maybe_handle_decomp(
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
from torch._inductor.bisect_helper import BisectionManager
if op in CURRENT_DECOMPOSITION_TABLE:
if BisectionManager.disable_subsystem(
"aot_eager_decomp_partition", "decomposition", lambda: repr(op)
):
return NotImplemented
with proxy_mode:
proxy_mode.decomp_layers += 1
out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)