mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add compiler bisector (#131936)
This is a utility to aid the torch.compile debugging. You provide a function that returns True on success, False on failure, or do something out of process and run bisect_helper `good | bad`.
The bisector will first go through backends - `eager`, `aot_eager`, `aot_eager_decomp_partition`, `inductor` to find the first failing backend. Then, it will go through subsystems within the backend - currently limited but could be expanded - and try to find the first subsystem for which disabling fixes the problem. Once it has found the failing subsystem, it will find the number of times the subsystem is applied, and then bisect through it.
An example usage of how to hook it up for aot_eager_decomp_partition and decomposition subsystem is :
```
from torch._inductor.bisect_helper import BisectionManager
if op in CURRENT_DECOMPOSITION_TABLE:
if BisectionManager.disable_subsystem("aot_eager_decomp_partition", "decomposition", lambda: repr(op)):
return NotImplemented
```
Once it has discovered the problematic change, it will print out the associated debug info, and you can set the same limits with `TORCH_BISECT_BACKEND` `TORCH_BISECT_SUBSYSTEM` and `TORCH_BISECT_MAX`.
We could add further options as an automated way of going through a check list for checking divergence - e.g., the mode to emulate amp casts.
Fix for https://github.com/pytorch/pytorch/issues/126546
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131936
Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
cfe970260a
commit
47af7cc962
112
test/dynamo/test_compiler_bisector.py
Normal file
112
test/dynamo/test_compiler_bisector.py
Normal file
@ -0,0 +1,112 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from importlib import import_module
|
||||
|
||||
import torch
|
||||
import torch._prims_common as utils
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._inductor import config
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
from torch.testing._internal.inductor_utils import HAS_CUDA
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
|
||||
|
||||
f32 = torch.float32
|
||||
i64 = torch.int64
|
||||
i32 = torch.int32
|
||||
|
||||
|
||||
@requires_cuda
|
||||
class TestCompilerBisector(TestCase):
|
||||
def test_bad_decomp(self):
|
||||
mod = import_module("torch._inductor.compile_fx")
|
||||
|
||||
def bad_exp_decomp(self, rate=1, generator=None):
|
||||
assert generator is None
|
||||
torch._check(
|
||||
not utils.is_complex_dtype(self.dtype)
|
||||
and not utils.is_integer_dtype(self.dtype)
|
||||
and not utils.is_boolean_dtype(self.dtype),
|
||||
lambda: f"Exponential distribution is a continuous probability distribution. \
|
||||
dtype must be a floating point but you specified {self.dtype}",
|
||||
)
|
||||
torch._check(
|
||||
rate > 0.0,
|
||||
lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
|
||||
)
|
||||
return torch.rand_like(self) * float("nan")
|
||||
|
||||
@contextmanager
|
||||
def patch_exp_decomp():
|
||||
from torch._inductor.compile_fx import select_decomp_table as old_decomp
|
||||
|
||||
def get_decomp():
|
||||
out = old_decomp()
|
||||
out = out.copy()
|
||||
out[aten.exponential.default] = bad_exp_decomp
|
||||
return out
|
||||
|
||||
torch._inductor.compile_fx.select_decomp_table = get_decomp
|
||||
try:
|
||||
yield
|
||||
|
||||
finally:
|
||||
torch._inductor.compile_fx.select_decomp_table = old_decomp
|
||||
|
||||
def vq(x):
|
||||
return (x + 3).exponential_() * 10.5
|
||||
|
||||
def test_fn():
|
||||
torch._dynamo.reset()
|
||||
with patch_exp_decomp():
|
||||
vq_compiled = torch.compile(vq)
|
||||
x = torch.randn(4, 400, 256).cuda()
|
||||
with torch._dynamo.utils.preserve_rng_state():
|
||||
out = vq(x)
|
||||
out_compiled = vq_compiled(x)
|
||||
|
||||
return not out_compiled.isnan().any()
|
||||
|
||||
out = BisectionManager.do_bisect(test_fn)
|
||||
self.assertEqual(out.backend, "aot_eager_decomp_partition")
|
||||
self.assertEqual(out.subsystem, "decomposition")
|
||||
self.assertEqual(out.bisect_number, 1)
|
||||
self.assertTrue("aten.exponential" in out.debug_info)
|
||||
|
||||
def test_bad_lowering(self):
|
||||
def test_fn():
|
||||
torch._dynamo.reset()
|
||||
with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"):
|
||||
|
||||
def my_func(x):
|
||||
return ((x * -1) - 0.01).relu()
|
||||
|
||||
inp = torch.rand([100], device="cuda")
|
||||
|
||||
return torch.allclose(torch.compile(my_func)(inp), my_func(inp))
|
||||
|
||||
out = BisectionManager.do_bisect(test_fn)
|
||||
self.assertEqual(out.backend, "inductor")
|
||||
self.assertEqual(out.subsystem, "lowerings")
|
||||
self.assertEqual(out.bisect_number, 2)
|
||||
self.assertTrue("relu" in out.debug_info)
|
||||
|
||||
def test_eager_backend(self):
|
||||
# should indicate problem with first backend
|
||||
def test_fn():
|
||||
return False
|
||||
|
||||
out = BisectionManager.do_bisect(test_fn)
|
||||
self.assertEqual(out.backend, "eager")
|
||||
self.assertEqual(out.subsystem, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
@ -220,6 +220,7 @@ RUN_PARALLEL_BLOCKLIST = [
|
||||
"test_cuda_nvml_based_avail",
|
||||
# temporarily sets a global config
|
||||
"test_autograd_fallback",
|
||||
"inductor/test_compiler_bisector",
|
||||
] + FSDP_TEST
|
||||
|
||||
# Test files that should always be run serially with other test files,
|
||||
|
||||
@ -2446,6 +2446,12 @@ def compile(
|
||||
)
|
||||
if mode is None and options is None:
|
||||
mode = "default"
|
||||
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
|
||||
if bisect_backend := BisectionManager.get_backend():
|
||||
backend = bisect_backend
|
||||
|
||||
if backend == "inductor":
|
||||
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
|
||||
else:
|
||||
|
||||
470
torch/_inductor/bisect_helper.py
Normal file
470
torch/_inductor/bisect_helper.py
Normal file
@ -0,0 +1,470 @@
|
||||
import collections
|
||||
import dataclasses
|
||||
import functools
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from torch._inductor.runtime.cache_dir_utils import cache_dir
|
||||
|
||||
|
||||
# Set the subdirectory name
|
||||
SUBDIR_NAME = "bisect"
|
||||
|
||||
# Dictionary of backend -> subsystems
|
||||
BACKENDS: Dict[str, List[str]] = {
|
||||
# run dynamo without aot_autograd
|
||||
"eager": [],
|
||||
# run dynamo with aot_autograd, but no partitioner or decomps
|
||||
"aot_eager": [],
|
||||
# run dynamo with aot autograd, decompositions and partitioner
|
||||
"aot_eager_decomp_partition": [
|
||||
"decomposition" # number of decompositions we apply in tracing
|
||||
], # TODO - add cse ?
|
||||
"inductor": [
|
||||
"post_grad_passes", # passes applied individually on forward, and backward in inductor
|
||||
"lowerings", # lowering aten operators to inductor
|
||||
], # TODO - add more - fusions, amp numeric mode ?
|
||||
}
|
||||
|
||||
subsystem_call_counter: Dict[str, int] = collections.Counter()
|
||||
call_counter_debug_info: Dict[int, str] = {}
|
||||
|
||||
|
||||
def reset_counters() -> None:
|
||||
subsystem_call_counter.clear()
|
||||
call_counter_debug_info.clear()
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_env_val(env_str: str) -> Optional[str]:
|
||||
return os.environ.get(env_str, None)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BisectionResult:
|
||||
"""
|
||||
backend: torch.compile backend responsible for failure
|
||||
subsystem: optional, registered component identified for failure
|
||||
bisect_number: optional, number of times the subsystem needed to be applied to trigger failure
|
||||
debug_info: associated info of the triggering bisect application of subsystem
|
||||
"""
|
||||
|
||||
backend: str
|
||||
subsystem: Optional[str] = None
|
||||
bisect_number: Optional[int] = None
|
||||
debug_info: Optional[str] = None
|
||||
|
||||
|
||||
class BisectionManager:
|
||||
bisection_enabled: bool = False
|
||||
|
||||
@classmethod
|
||||
def get_dir(cls) -> str:
|
||||
return f"{cache_dir()}/{SUBDIR_NAME}"
|
||||
|
||||
@classmethod
|
||||
def write_lines_to_file(cls, file_path: str, lines: List[str]) -> None:
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
with open(file_path, "w") as file:
|
||||
file.writelines(lines)
|
||||
|
||||
@classmethod
|
||||
def read_lines_from_file(cls, file_path: str) -> List[str]:
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path) as file:
|
||||
return file.readlines()
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def update_run_state(
|
||||
cls, backend_name: str, subsystem_name: str, run_state: str
|
||||
) -> None:
|
||||
file_path = os.path.join(
|
||||
cls.get_dir(), backend_name, f"{subsystem_name}_run_state.txt"
|
||||
)
|
||||
cls.write_lines_to_file(file_path, [run_state])
|
||||
|
||||
@classmethod
|
||||
def update_bisect_status(cls, backend_name: str, subsystem_name: str) -> None:
|
||||
file_path = os.path.join(cls.get_dir(), "bisect_status.txt")
|
||||
lines = [f"backend={backend_name}\n", f"subsystem={subsystem_name}\n"]
|
||||
cls.write_lines_to_file(file_path, lines)
|
||||
|
||||
@classmethod
|
||||
def update_bisect_range(
|
||||
cls, backend_name: str, subsystem_name: str, low: int, high: int
|
||||
) -> None:
|
||||
file_path = os.path.join(
|
||||
cls.get_dir(), backend_name, f"{subsystem_name}_bisect_range.txt"
|
||||
)
|
||||
lines = [f"low={low}\n", f"high={high}\n"]
|
||||
cls.write_lines_to_file(file_path, lines)
|
||||
|
||||
@classmethod
|
||||
def get_backend(cls) -> Optional[str]:
|
||||
"""
|
||||
Returns the active backend, if any
|
||||
"""
|
||||
if val := get_env_val("TORCH_BISECT_BACKEND"):
|
||||
return val
|
||||
|
||||
file_path = os.path.join(cls.get_dir(), "bisect_status.txt")
|
||||
lines = cls.read_lines_from_file(file_path)
|
||||
for line in lines:
|
||||
if line.startswith("backend="):
|
||||
return line.strip().split("=")[1]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_subsystem(cls) -> Optional[str]:
|
||||
"""
|
||||
Returns the active subsystem, if any
|
||||
"""
|
||||
|
||||
if val := get_env_val("TORCH_BISECT_SUBSYSTEM"):
|
||||
return val
|
||||
|
||||
file_path = os.path.join(cls.get_dir(), "bisect_status.txt")
|
||||
lines = cls.read_lines_from_file(file_path)
|
||||
for line in lines:
|
||||
if line.startswith("subsystem="):
|
||||
return line.strip().split("=")[1]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_run_state(cls, backend_name: str, subsystem_name: str) -> Optional[str]:
|
||||
"""
|
||||
Returns the current stage of bisecting, if Any
|
||||
"""
|
||||
|
||||
file_path = os.path.join(
|
||||
cls.get_dir(), backend_name, f"{subsystem_name}_run_state.txt"
|
||||
)
|
||||
lines = cls.read_lines_from_file(file_path)
|
||||
if lines:
|
||||
out = lines[0].strip()
|
||||
assert out in ("test_disable", "find_max_bounds", "bisect")
|
||||
return out
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_bisect_range(
|
||||
cls, backend_name: str, subsystem_name: str
|
||||
) -> Tuple[int, int]:
|
||||
file_path = os.path.join(
|
||||
cls.get_dir(), backend_name, f"{subsystem_name}_bisect_range.txt"
|
||||
)
|
||||
lines = cls.read_lines_from_file(file_path)
|
||||
low = None
|
||||
high = None
|
||||
for line in reversed(lines):
|
||||
if line.startswith("low="):
|
||||
low = int(line.strip().split("=")[1])
|
||||
elif line.startswith("high="):
|
||||
high = int(line.strip().split("=")[1])
|
||||
|
||||
if low is not None and high is not None:
|
||||
break
|
||||
|
||||
if low is None or high is None:
|
||||
raise RuntimeError(
|
||||
f"Trying to get bisect range when it is not set: subsystem {subsystem_name}"
|
||||
)
|
||||
|
||||
return low, high
|
||||
|
||||
@classmethod
|
||||
def delete_bisect_status(cls) -> None:
|
||||
if os.path.exists(cls.get_dir()):
|
||||
shutil.rmtree(cls.get_dir())
|
||||
print("Bisection status deleted.")
|
||||
else:
|
||||
print("No bisection status found.")
|
||||
|
||||
@classmethod
|
||||
def get_system_counter(cls, name: str, increment: bool = True) -> int:
|
||||
global subsystem_call_counter
|
||||
curr = subsystem_call_counter[name]
|
||||
if increment:
|
||||
subsystem_call_counter[name] += 1
|
||||
return curr
|
||||
|
||||
@classmethod
|
||||
def disable_subsystem(
|
||||
cls,
|
||||
backend: str,
|
||||
subsystem: str,
|
||||
debug_info: Optional[Callable[[], str]] = None,
|
||||
) -> bool:
|
||||
if not cls.bisection_enabled:
|
||||
return False
|
||||
|
||||
if cls.get_backend() != backend:
|
||||
return False
|
||||
|
||||
if cls.get_subsystem() != subsystem:
|
||||
return False
|
||||
|
||||
if val := get_env_val("TORCH_BISECT_MAX"):
|
||||
counter = cls.get_system_counter(subsystem, increment=True)
|
||||
return counter > int(val)
|
||||
|
||||
run_state = cls.get_run_state(backend, subsystem)
|
||||
if run_state == "test_disable":
|
||||
# First run, disable completely
|
||||
return True
|
||||
elif run_state == "find_max_bounds":
|
||||
# Second run, update bisection range and return True to enable the subsystem
|
||||
cls.update_bisect_range(
|
||||
backend,
|
||||
subsystem,
|
||||
0,
|
||||
cls.get_system_counter(subsystem, increment=True),
|
||||
)
|
||||
return False
|
||||
else:
|
||||
assert run_state == "bisect"
|
||||
# If the environment variable is not set, use the bisection range midpoint
|
||||
low, high = cls.get_bisect_range(backend, subsystem)
|
||||
# if high - low <= 2:
|
||||
midpoint = (low + high) // 2
|
||||
call_counter = cls.get_system_counter(subsystem)
|
||||
|
||||
if (
|
||||
call_counter >= low
|
||||
and call_counter <= high
|
||||
and (low - high) <= 2
|
||||
and debug_info is not None
|
||||
):
|
||||
call_counter_debug_info[call_counter] = debug_info()
|
||||
|
||||
return call_counter > midpoint
|
||||
|
||||
@classmethod
|
||||
def advance_subsystem(cls, curr_backend: str, curr_subsystem: str) -> Optional[str]:
|
||||
"""
|
||||
Tries to move to the next subsystem within the current system.
|
||||
"""
|
||||
print(f"Disabling {curr_subsystem} did not fix the issue.")
|
||||
|
||||
current_subsystems = BACKENDS[curr_backend]
|
||||
current_subsystem_index = current_subsystems.index(curr_subsystem)
|
||||
|
||||
if current_subsystem_index < len(current_subsystems) - 1:
|
||||
curr_subsystem = current_subsystems[current_subsystem_index + 1]
|
||||
cls.update_bisect_status(curr_backend, curr_subsystem)
|
||||
cls.update_run_state(curr_backend, curr_subsystem, "test_disable")
|
||||
print(f"Moving to the next subsystem: {curr_backend} - {curr_subsystem}")
|
||||
return curr_subsystem
|
||||
else:
|
||||
print(
|
||||
f"All subsystems in {curr_backend} have been checked. The issue is not in this system."
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def advance_backend(cls, curr_backend: str) -> Optional[str]:
|
||||
"""
|
||||
Tries Move to the next backend.
|
||||
"""
|
||||
current_system_index = list(BACKENDS.keys()).index(curr_backend)
|
||||
|
||||
if current_system_index < len(BACKENDS) - 1:
|
||||
curr_backend = list(BACKENDS.keys())[current_system_index + 1]
|
||||
cls.update_bisect_status(curr_backend, "")
|
||||
print(f"Moving to the next system: {curr_backend}")
|
||||
return curr_backend
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def perform_bisection(
|
||||
cls,
|
||||
curr_backend: str,
|
||||
curr_subsystem: str,
|
||||
fn: Callable[[], bool],
|
||||
cli_interface: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
Perform the bisection process for the current system and subsystem. Returns True if the issue is found, False otherwise.
|
||||
"""
|
||||
while True:
|
||||
run_state = cls.get_run_state(curr_backend, curr_subsystem)
|
||||
reset_counters()
|
||||
if run_state == "test_disable":
|
||||
if not fn():
|
||||
next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem)
|
||||
if not next_subsystem:
|
||||
return False
|
||||
curr_subsystem = next_subsystem
|
||||
else:
|
||||
# breakpoint()
|
||||
print(
|
||||
f"Disabling {curr_subsystem} fixed the issue. Starting bisect by getting upper bound."
|
||||
)
|
||||
cls.update_run_state(
|
||||
curr_backend, curr_subsystem, "find_max_bounds"
|
||||
)
|
||||
elif run_state == "find_max_bounds":
|
||||
if fn():
|
||||
raise RuntimeError(
|
||||
f"Function succeeded with 'find_max_bounds' status for {curr_backend} - {curr_subsystem}."
|
||||
)
|
||||
else:
|
||||
_, high = cls.get_bisect_range(curr_backend, curr_subsystem)
|
||||
print(f"Upper bound of {high} found for {curr_backend}.")
|
||||
cls.update_run_state(curr_backend, curr_subsystem, "bisect")
|
||||
elif run_state == "bisect":
|
||||
low, high = cls.get_bisect_range(curr_backend, curr_subsystem)
|
||||
midpoint = (low + high) // 2
|
||||
print(
|
||||
f"Bisecting {curr_backend} - {curr_subsystem} (Range: [{low}, {high}], Midpoint: {midpoint})"
|
||||
)
|
||||
if fn():
|
||||
cls.update_bisect_range(
|
||||
curr_backend, curr_subsystem, midpoint + 1, high
|
||||
)
|
||||
else:
|
||||
cls.update_bisect_range(curr_backend, curr_subsystem, low, midpoint)
|
||||
low, high = cls.get_bisect_range(curr_backend, curr_subsystem)
|
||||
if low == high:
|
||||
print(
|
||||
f"Binary search completed for {curr_backend} - {curr_subsystem}. The bisect number is {low}. "
|
||||
f"Debug info: {call_counter_debug_info.get(low, 'not found')}"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected run_state {run_state}")
|
||||
|
||||
if cli_interface:
|
||||
sys.exit(0)
|
||||
|
||||
@classmethod
|
||||
def initialize_system(cls) -> None:
|
||||
curr_backend = next(iter(BACKENDS.keys()))
|
||||
curr_subsystem = ""
|
||||
cls.update_bisect_status(curr_backend, curr_subsystem)
|
||||
print(f"Starting bisection process with system: {curr_backend}")
|
||||
|
||||
@classmethod
|
||||
def do_bisect(
|
||||
cls, fn: Callable[[], bool], cli_interface: bool = False
|
||||
) -> Optional[BisectionResult]:
|
||||
if not cli_interface:
|
||||
bisection_enabled_orig = cls.bisection_enabled
|
||||
cls.delete_bisect_status()
|
||||
cls.bisection_enabled = True
|
||||
|
||||
# TODO - cli interface, and in-process different directories
|
||||
class DisableBisect:
|
||||
def __del__(self) -> None:
|
||||
cls.bisection_enabled = bisection_enabled_orig
|
||||
cls.delete_bisect_status()
|
||||
|
||||
cleanup = DisableBisect()
|
||||
|
||||
curr_backend = cls.get_backend()
|
||||
curr_subsystem = cls.get_subsystem()
|
||||
|
||||
if not curr_backend:
|
||||
cls.initialize_system()
|
||||
curr_backend = cls.get_backend()
|
||||
curr_subsystem = cls.get_subsystem()
|
||||
|
||||
while True:
|
||||
assert curr_backend is not None
|
||||
reset_counters()
|
||||
if curr_subsystem:
|
||||
result = cls.perform_bisection(
|
||||
curr_backend, curr_subsystem, fn, cli_interface=cli_interface
|
||||
)
|
||||
if result:
|
||||
curr_subsystem = cls.get_subsystem()
|
||||
assert curr_subsystem is not None
|
||||
low, _ = cls.get_bisect_range(curr_backend, curr_subsystem)
|
||||
return BisectionResult(
|
||||
curr_backend,
|
||||
curr_subsystem,
|
||||
low,
|
||||
call_counter_debug_info.get(low, None),
|
||||
)
|
||||
|
||||
next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem)
|
||||
if not next_subsystem:
|
||||
print(
|
||||
f"The issue is in the {curr_backend} system, but could not identify subsystem."
|
||||
)
|
||||
assert curr_backend is not None
|
||||
return BisectionResult(curr_backend)
|
||||
|
||||
curr_subsystem = next_subsystem
|
||||
else:
|
||||
if fn():
|
||||
next_backend = cls.advance_backend(curr_backend)
|
||||
if not next_backend:
|
||||
print("All systems have been checked.")
|
||||
return None
|
||||
|
||||
curr_backend = next_backend
|
||||
else:
|
||||
current_subsystems = BACKENDS[curr_backend]
|
||||
if current_subsystems:
|
||||
curr_subsystem = current_subsystems[0]
|
||||
cls.update_bisect_status(curr_backend, curr_subsystem)
|
||||
cls.update_run_state(
|
||||
curr_backend, curr_subsystem, "test_disable"
|
||||
)
|
||||
print(
|
||||
f"The issue is in the {curr_backend} system. Moving to the first subsystem: {curr_subsystem}"
|
||||
)
|
||||
else:
|
||||
print(f"The issue is in the {curr_backend} system.")
|
||||
return BisectionResult(curr_backend)
|
||||
|
||||
if cli_interface:
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def command_line_usage() -> None:
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python bisect_update.py <start|end|good|bad>")
|
||||
sys.exit(1)
|
||||
|
||||
bisection_manager = BisectionManager()
|
||||
command = sys.argv[1]
|
||||
|
||||
if command == "end":
|
||||
bisection_manager.delete_bisect_status()
|
||||
sys.exit(0)
|
||||
|
||||
if command == "start":
|
||||
bisection_manager.delete_bisect_status()
|
||||
bisection_manager.initialize_system()
|
||||
sys.exit(0)
|
||||
|
||||
if command not in ["good", "bad"]:
|
||||
print("Invalid command. Must be 'good', 'bad', 'start', or 'end'.")
|
||||
sys.exit(1)
|
||||
|
||||
def test_function() -> bool:
|
||||
return command == "good"
|
||||
|
||||
if not bisection_manager.get_backend():
|
||||
raise ValueError("Must call start prior to good or bad")
|
||||
|
||||
bisection_manager.do_bisect(test_function, cli_interface=True)
|
||||
|
||||
|
||||
def get_is_bisection_enabled() -> bool:
|
||||
return (
|
||||
BisectionManager.get_subsystem() is not None
|
||||
or BisectionManager.get_backend() is not None
|
||||
)
|
||||
|
||||
|
||||
BisectionManager.bisection_enabled = get_is_bisection_enabled()
|
||||
|
||||
if __name__ == "__main__":
|
||||
command_line_usage()
|
||||
@ -1285,6 +1285,12 @@ class FxGraphCache:
|
||||
"Freezing may introduce constants that aren't static across runs"
|
||||
)
|
||||
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
|
||||
if BisectionManager.bisection_enabled:
|
||||
log.debug("dont cache graph when bisect enabled")
|
||||
raise BypassFxGraphCache
|
||||
|
||||
# The treatment of guards in the caching implementation requires that
|
||||
# we have a shape env.
|
||||
if FxGraphCache._get_shape_env() is None:
|
||||
|
||||
@ -4,7 +4,7 @@ import itertools
|
||||
import logging
|
||||
import operator
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
import torch
|
||||
import torch._inductor as inductor
|
||||
@ -65,6 +65,19 @@ pass_patterns = [
|
||||
]
|
||||
|
||||
|
||||
def apply_pass(pass_fn: Callable[[], object], name: Optional[str] = None) -> None:
|
||||
# TODO - we should just make this part of GraphTransformObserver
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
|
||||
debug_info: Optional[Callable[[], str]] = None
|
||||
if name is not None:
|
||||
debug_info = lambda: name # noqa: E731
|
||||
|
||||
if BisectionManager.disable_subsystem("inductor", "post_grad_passes", debug_info):
|
||||
return
|
||||
pass_fn()
|
||||
|
||||
|
||||
def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
||||
"""
|
||||
Passes that run on after grad. This is called once on the forwards
|
||||
@ -80,23 +93,28 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
if is_inference and config.reorder_for_locality:
|
||||
reorder_for_locality(gm.graph)
|
||||
apply_pass(lambda: reorder_for_locality(gm.graph), "reorder_for_locality")
|
||||
|
||||
fake_tensor_updater = FakeTensorUpdater(gm.graph)
|
||||
|
||||
if config.post_grad_custom_pre_pass is not None:
|
||||
if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass:
|
||||
with GraphTransformObserver(
|
||||
gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
config.post_grad_custom_pre_pass(gm.graph)
|
||||
apply_pass(
|
||||
lambda: post_grad_custom_pre_pass(gm.graph), "post_grad_custom_pre_pass"
|
||||
)
|
||||
|
||||
if config.pattern_matcher:
|
||||
lazy_init()
|
||||
optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
|
||||
group_batch_fusion_passes(gm.graph, pre_grad=False)
|
||||
remove_noop_ops(gm.graph)
|
||||
for patterns in pass_patterns:
|
||||
patterns.apply(gm.graph) # type: ignore[arg-type]
|
||||
apply_pass(
|
||||
lambda: group_batch_fusion_passes(gm.graph, pre_grad=False),
|
||||
"group_batch_fusion_passes",
|
||||
)
|
||||
apply_pass(lambda: remove_noop_ops(gm.graph), "remove_noop_ops")
|
||||
for i, patterns in enumerate(pass_patterns):
|
||||
apply_pass(lambda: patterns.apply(gm.graph), f"pass_pattern_{i}") # type: ignore[arg-type]
|
||||
for pass_name in config.post_grad_fusion_options:
|
||||
# skip all patterns for group batch fusions
|
||||
if pass_name in POST_GRAD_FUSIONS:
|
||||
@ -105,7 +123,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
||||
inductor_before_change = save_inductor_dict(
|
||||
[pattern_matcher_pass.pass_name]
|
||||
)
|
||||
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
|
||||
apply_pass(lambda: pattern_matcher_pass.apply(gm.graph), pass_name) # type: ignore[arg-type]
|
||||
if not is_same_dict(counters["inductor"], inductor_before_change):
|
||||
optimus_scuba_log[
|
||||
f"{pattern_matcher_pass.pass_name}_post_grad"
|
||||
@ -117,30 +135,40 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
||||
micro_pipeline_tp_pass(gm.graph)
|
||||
|
||||
if config._fuse_ddp_communication:
|
||||
fuse_ddp_communication(
|
||||
apply_pass(
|
||||
lambda: fuse_ddp_communication(
|
||||
gm.graph,
|
||||
config._fuse_ddp_communication_passes,
|
||||
config._fuse_ddp_bucket_size,
|
||||
),
|
||||
"fuse_ddp_communication",
|
||||
)
|
||||
|
||||
if config.post_grad_custom_post_pass is not None:
|
||||
if post_grad_custom_post_pass := config.post_grad_custom_post_pass:
|
||||
with GraphTransformObserver(
|
||||
gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform
|
||||
):
|
||||
config.post_grad_custom_post_pass(gm.graph)
|
||||
apply_pass(
|
||||
lambda: post_grad_custom_post_pass(gm.graph),
|
||||
"post_grad_custom_post_pass",
|
||||
)
|
||||
|
||||
stable_topological_sort(gm.graph)
|
||||
apply_pass(lambda: stable_topological_sort(gm.graph), "stable_sort")
|
||||
|
||||
move_constructors_to_gpu(gm.graph)
|
||||
apply_pass(lambda: move_constructors_to_gpu(gm.graph), "move_constructors_to_cuda")
|
||||
|
||||
fake_tensor_updater.incremental_update()
|
||||
|
||||
# Keep these last, since they introduces mutation. Look at
|
||||
# ./fx_passes/README.md for a discussion of mutation invariants.
|
||||
reinplace_inplaceable_ops(gm.graph)
|
||||
decompose_auto_functionalized(gm.graph)
|
||||
apply_pass(lambda: reinplace_inplaceable_ops(gm.graph), "reinplace_inplaceable_ops")
|
||||
apply_pass(
|
||||
lambda: decompose_auto_functionalized(gm.graph), "decompose_auto_functionalized"
|
||||
)
|
||||
|
||||
comms.reinplace_fsdp_all_gather(gm.graph)
|
||||
apply_pass(
|
||||
lambda: comms.reinplace_fsdp_all_gather(gm.graph), "reinplace_fsdp_all_gather"
|
||||
)
|
||||
|
||||
gm.recompile()
|
||||
optimus_scuba_log["after_recompile_post_grad"] = upload_graph(gm.graph)
|
||||
|
||||
@ -1304,6 +1304,8 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
def debug(msg: str) -> None:
|
||||
log.debug("lowering %s %s", LazyString(n.format_node), msg)
|
||||
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
|
||||
buffer_watermark = len(self.buffers)
|
||||
operation_watermark = len(self.operations)
|
||||
|
||||
@ -1320,7 +1322,12 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
if (
|
||||
n.op == "call_function"
|
||||
and n.target is not operator.getitem
|
||||
and fallback_node_due_to_unsupported_type(n)
|
||||
and (
|
||||
fallback_node_due_to_unsupported_type(n)
|
||||
or BisectionManager.disable_subsystem(
|
||||
"inductor", "lowerings", lambda: repr(n)
|
||||
)
|
||||
)
|
||||
):
|
||||
debug("fallback_handler")
|
||||
result = fallback_handler(n.target, add_to_fallback_set=False)(
|
||||
|
||||
23
torch/_inductor/runtime/cache_dir_utils.py
Normal file
23
torch/_inductor/runtime/cache_dir_utils.py
Normal file
@ -0,0 +1,23 @@
|
||||
import getpass
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
|
||||
# Factoring out to file without torch dependencies
|
||||
|
||||
|
||||
def cache_dir() -> str:
|
||||
cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
|
||||
if cache_dir is None:
|
||||
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir()
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
return cache_dir
|
||||
|
||||
|
||||
def default_cache_dir() -> str:
|
||||
sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
|
||||
return os.path.join(
|
||||
tempfile.gettempdir(),
|
||||
"torchinductor_" + sanitized_username,
|
||||
)
|
||||
@ -3,13 +3,13 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import getpass
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401
|
||||
cache_dir,
|
||||
default_cache_dir,
|
||||
)
|
||||
|
||||
|
||||
def conditional_product(*args):
|
||||
@ -86,22 +86,6 @@ def get_max_y_grid():
|
||||
return 65535
|
||||
|
||||
|
||||
def cache_dir() -> str:
|
||||
cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
|
||||
if cache_dir is None:
|
||||
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir()
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
return cache_dir
|
||||
|
||||
|
||||
def default_cache_dir():
|
||||
sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
|
||||
return os.path.join(
|
||||
tempfile.gettempdir(),
|
||||
"torchinductor_" + sanitized_username,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import colorama
|
||||
|
||||
|
||||
@ -2202,7 +2202,14 @@ def maybe_handle_decomp(
|
||||
args: Tuple[object, ...],
|
||||
kwargs: Dict[str, object],
|
||||
) -> object:
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
|
||||
if op in CURRENT_DECOMPOSITION_TABLE:
|
||||
if BisectionManager.disable_subsystem(
|
||||
"aot_eager_decomp_partition", "decomposition", lambda: repr(op)
|
||||
):
|
||||
return NotImplemented
|
||||
|
||||
with proxy_mode:
|
||||
proxy_mode.decomp_layers += 1
|
||||
out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user