Class rename (#139490)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139490
Approved by: https://github.com/exclamaforte, https://github.com/zou3519
ghstack dependencies: #139295
Authored by eellison on 2024-11-01 12:20:56 -07:00; committed by PyTorch MergeBot
parent c95adb9c5b
commit ee2f8a50d3
9 changed files with 71 additions and 36 deletions

View File

@@ -6,10 +6,10 @@ from importlib import import_module
import torch
import torch._prims_common as utils
from torch._dynamo.test_case import TestCase
from torch._dynamo.utils import preserve_rng_state
from torch._inductor import config
from torch._inductor.bisect_helper import BisectionManager
from torch._inductor.compiler_bisector import CompilerBisector
from torch._inductor.test_case import TestCase
from torch.library import _scoped_library, Library
from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -91,7 +91,7 @@ class TestCompilerBisector(TestCase):
return not out_compiled.isnan().any()
out = BisectionManager.do_bisect(test_fn)
out = CompilerBisector.do_bisect(test_fn)
self.assertEqual(out.backend, "aot_eager_decomp_partition")
self.assertEqual(out.subsystem, "decomposition")
self.assertEqual(out.bisect_number, 1)
@@ -124,7 +124,7 @@ class TestCompilerBisector(TestCase):
return torch.allclose(out, out_c)
out = BisectionManager.do_bisect(test_fn)
out = CompilerBisector.do_bisect(test_fn)
self.assertEqual(out.backend, "inductor")
self.assertEqual(out.subsystem, "joint_graph_passes")
self.assertEqual(out.bisect_number, 4)
@@ -144,7 +144,7 @@ class TestCompilerBisector(TestCase):
return torch.allclose(out, out_c)
out = BisectionManager.do_bisect(test_fn)
out = CompilerBisector.do_bisect(test_fn)
self.assertEqual(out.backend, "inductor")
self.assertEqual(out.subsystem, "inductor_fallback_random")
self.assertTrue("inductor_fallback_random" in out.debug_info)
@@ -192,7 +192,7 @@ class TestCompilerBisector(TestCase):
return False
return True
out = BisectionManager.do_bisect(test_fn)
out = CompilerBisector.do_bisect(test_fn)
self.assertEqual(out.backend, "aot_eager_decomp_partition_crossref")
def test_emulate_precision_casts(self):
@@ -213,7 +213,7 @@ class TestCompilerBisector(TestCase):
return torch.equal(eager_scale, compile_scale)
out = BisectionManager.do_bisect(test_fn)
out = CompilerBisector.do_bisect(test_fn)
self.assertEqual(out.backend, "inductor")
self.assertEqual(out.subsystem, "inductor_emulate_precision_casts")
@@ -229,7 +229,7 @@ class TestCompilerBisector(TestCase):
return torch.allclose(torch.compile(my_func)(inp), my_func(inp))
out = BisectionManager.do_bisect(test_fn)
out = CompilerBisector.do_bisect(test_fn)
self.assertEqual(out.backend, "inductor")
self.assertEqual(out.subsystem, "lowerings")
self.assertEqual(out.bisect_number, 2)
@@ -240,7 +240,7 @@ class TestCompilerBisector(TestCase):
def test_fn():
return False
out = BisectionManager.do_bisect(test_fn)
out = CompilerBisector.do_bisect(test_fn)
self.assertEqual(out.backend, "eager")
self.assertEqual(out.subsystem, None)

View File

@@ -2255,9 +2255,9 @@ class _TorchCompileInductorWrapper:
)
def apply_options(self, options: _Optional[_Dict[str, _Any]]):
from torch._inductor.bisect_helper import BisectionManager
from torch._inductor.compiler_bisector import CompilerBisector
if bisect_changes := BisectionManager.get_config_change("inductor"):
if bisect_changes := CompilerBisector.get_config_change("inductor"):
options = {} if options is None else options
options = (
{**bisect_changes} if options is None else {**options, **bisect_changes} # type: ignore[dict-item]
@@ -2496,9 +2496,9 @@ def compile(
if mode is None and options is None:
mode = "default"
from torch._inductor.bisect_helper import BisectionManager
from torch._inductor.compiler_bisector import CompilerBisector
if bisect_backend := BisectionManager.get_backend():
if bisect_backend := CompilerBisector.get_backend():
backend = bisect_backend
if backend == "inductor":

View File

@@ -118,21 +118,27 @@ def boxed_nop(fx_g, example_inputs):
return run
def fake_crossref_boxed_nop(fx_g, example_inputs):
def fake_crossref_boxed_nop(fx_g, example_inputs, ignore_op_fn=None):
def run(args):
with torch._subclasses.CrossRefFakeMode():
with torch._subclasses.CrossRefFakeMode(ignore_op_fn):
return torch.fx.Interpreter(fx_g).boxed_run(args)
run._boxed_call = True
return run
def ignore_builtins(op: torch._ops.OpOverload) -> bool:
return op.namespace in ("aten", "prims", "prim")
def get_nop_func():
return (
boxed_nop
if not torch._functorch.config.fake_tensor_crossref
else fake_crossref_boxed_nop
)
if not torch._functorch.config.fake_tensor_crossref:
return boxed_nop
elif torch._functorch.config.fake_tensor_crossref == "all":
return fake_crossref_boxed_nop
else:
assert torch._functorch.config.fake_tensor_crossref == "custom_ops"
return functools.partial(fake_crossref_boxed_nop, ignore_op_fn=ignore_builtins)
# Useful for debugging purposes
@@ -172,10 +178,10 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
"aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs
)
from torch._inductor.bisect_helper import BisectionManager
from torch._inductor.compiler_bisector import CompilerBisector
config_patches = {"unlift_effect_tokens": True}
if bisect_changes := BisectionManager.get_config_change(
if bisect_changes := CompilerBisector.get_config_change(
"aot_eager_decomp_partition"
):
config_patches.update(bisect_changes)
@@ -201,7 +207,15 @@ register_backend(
def aot_eager_decomp_partition_crossref(gm, fake_tensor_inputs, **kwargs):
with functorch_config.patch(fake_tensor_crossref=True):
# If the config is already set, respect it; otherwise only test custom ops.
# Bad custom-op metas always manifest as an error, whereas bad aten metas only
# sometimes do, so by default use the less noisy option.
config_val = (
"custom_ops"
if not functorch_config.fake_tensor_crossref
else functorch_config.fake_tensor_crossref
)
with functorch_config.patch(fake_tensor_crossref=config_val):
return aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs)
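For illustration only (not part of this commit's diff): a minimal sketch of how the new default plays out, assuming `aot_eager_decomp_partition_crossref` is registered as a torch.compile backend under that name, which the register_backend context above suggests. With no explicit setting, the crossref backend patches `fake_tensor_crossref` to "custom_ops"; patching the flag yourself opts into the noisier "all" mode.

import torch
from torch._functorch import config as functorch_config

def f(x):  # toy function, purely illustrative
    return x.sin() + x.cos()

x = torch.randn(8)

# Default: only custom (non aten/prims/prim) ops are cross-checked.
torch.compile(f, backend="aot_eager_decomp_partition_crossref")(x)

# An explicit setting is respected by the backend, enabling the noisier mode.
with functorch_config.patch(fake_tensor_crossref="all"):
    torch.compile(f, backend="aot_eager_decomp_partition_crossref")(x)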

View File

@@ -164,6 +164,7 @@ unlift_effect_tokens = False
# Run aot_eager_decomp_partition with CrossRefFakeMode.
# Options: False, "all", "custom_ops"
fake_tensor_crossref = False
# This mode specifies that we should also keep track of the real
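A minimal sketch (not in this diff) of scoping the new flag with the config patch helper used elsewhere in this commit; the accepted values are the ones listed in the comment above.

from torch._functorch import config as functorch_config

# Accepted values: False (default), "custom_ops", "all"
with functorch_config.patch(fake_tensor_crossref="custom_ops"):
    pass  # compile and run the model under test here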

View File

@@ -1329,9 +1329,9 @@ class FxGraphCache:
"static across runs"
)
from torch._inductor.bisect_helper import BisectionManager
from torch._inductor.compiler_bisector import CompilerBisector
if BisectionManager.bisection_enabled:
if CompilerBisector.bisection_enabled:
log.debug("dont cache graph when bisect enabled")
raise BypassFxGraphCache

View File

@@ -95,7 +95,23 @@ class BisectionResult:
debug_info: Optional[str] = None
class BisectionManager:
class CompilerBisector:
"""
This class iteratively runs torch.compile backends (eager, aot_eager, inductor) to find the
first backend that can reproduce an issue.
Once it discovers the offending backend, it iteratively disables subsystems within that backend.
For subsystems that are applied repeatedly, such as post-grad passes or lowerings of nodes to
inductor IR, it bisects to find the offending application.
The idiomatic way to run it is with `do_bisect`. You can also drive it by setting the env flags
`TORCH_BISECT_BACKEND`, `TORCH_BISECT_SUBSYSTEM`, and `TORCH_BISECT_MAX`.
It also supports a CLI interface, although this is less well tested:
run `python compiler_bisector.py [start | good | bad | end]`.
"""
bisection_enabled: bool = False
@classmethod
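A minimal sketch of the programmatic flow the docstring describes, mirroring the tests earlier in this commit; the toy model, dtype, and tolerances below are illustrative and not taken from the diff.

import torch
import torch._dynamo
from torch._inductor.compiler_bisector import CompilerBisector

def test_fn() -> bool:
    # Return True when compiled output matches eager; the bisector searches for
    # the first backend/subsystem whose application makes this return False.
    def f(x):
        return (x.float().softmax(dim=-1) * 10).sum()

    torch._dynamo.reset()
    inp = torch.randn(32, 32, dtype=torch.bfloat16)
    out_eager = f(inp)
    out_compiled = torch.compile(f)(inp)
    return torch.allclose(out_eager, out_compiled, atol=1e-2, rtol=1e-2)

result = CompilerBisector.do_bisect(test_fn)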
@@ -461,6 +477,10 @@ class BisectionManager:
def do_bisect(
cls, fn: Callable[[], bool], cli_interface: bool = False
) -> Optional[BisectionResult]:
"""
Run fn repeatedly while bisecting torch.compile. fn should return True on success and False on failure.
"""
if not cli_interface:
bisection_enabled_orig = cls.bisection_enabled
cls.delete_bisect_status()
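Continuing the sketch: `do_bisect` returns a BisectionResult (or None), and a test_fn that fails even in eager should bisect all the way down to the eager backend, as in the last test near the top of this commit. The expectations in the comments come from that test, not from running this snippet.

from torch._inductor.compiler_bisector import CompilerBisector

def always_failing_test_fn() -> bool:
    return False  # fails even in eager

result = CompilerBisector.do_bisect(always_failing_test_fn)
if result is None:
    print("issue did not reproduce under any backend")
else:
    print(result.backend)        # "eager" for this trivial case (per the test above)
    print(result.subsystem)      # None when no subsystem is implicated
    print(result.bisect_number)
    print(result.debug_info)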
@@ -557,7 +577,7 @@ def command_line_usage() -> None:
print("Usage: python bisect_update.py <start|end|good|bad>")
sys.exit(1)
bisection_manager = BisectionManager()
bisection_manager = CompilerBisector()
command = sys.argv[1]
if command == "end":
@@ -584,12 +604,12 @@ def command_line_usage() -> None:
def get_is_bisection_enabled() -> bool:
return (
BisectionManager.get_subsystem() is not None
or BisectionManager.get_backend() is not None
CompilerBisector.get_subsystem() is not None
or CompilerBisector.get_backend() is not None
)
BisectionManager.bisection_enabled = get_is_bisection_enabled()
CompilerBisector.bisection_enabled = get_is_bisection_enabled()
if __name__ == "__main__":
command_line_usage()

View File

@@ -1346,7 +1346,7 @@ class GraphLowering(torch.fx.Interpreter):
def debug(msg: str) -> None:
log.debug("lowering %s %s", LazyString(n.format_node), msg)
from torch._inductor.bisect_helper import BisectionManager
from torch._inductor.compiler_bisector import CompilerBisector
buffer_watermark = len(self.buffers)
operation_watermark = len(self.operations)
@@ -1366,7 +1366,7 @@ class GraphLowering(torch.fx.Interpreter):
and n.target is not operator.getitem
and (
fallback_node_due_to_unsupported_type(n)
or BisectionManager.disable_subsystem(
or CompilerBisector.disable_subsystem(
"inductor", "lowerings", lambda: repr(n)
)
)

View File

@@ -2227,10 +2227,10 @@ def maybe_handle_decomp(
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
from torch._inductor.bisect_helper import BisectionManager
from torch._inductor.compiler_bisector import CompilerBisector
if op in CURRENT_DECOMPOSITION_TABLE:
if BisectionManager.disable_subsystem(
if CompilerBisector.disable_subsystem(
"aot_eager_decomp_partition", "decomposition", lambda: repr(op)
):
return NotImplemented

View File

@@ -76,9 +76,9 @@ class GraphTransformObserver:
return False
debug_info = lambda: self.passname # noqa: E731
from torch._inductor.bisect_helper import BisectionManager
from torch._inductor.compiler_bisector import CompilerBisector
return BisectionManager.disable_subsystem(
return CompilerBisector.disable_subsystem(
"inductor", self.subsystem, debug_info
)