mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Class rename (#139490)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139490 Approved by: https://github.com/exclamaforte, https://github.com/zou3519 ghstack dependencies: #139295
This commit is contained in:
committed by
PyTorch MergeBot
parent
c95adb9c5b
commit
ee2f8a50d3
@ -6,10 +6,10 @@ from importlib import import_module
|
||||
|
||||
import torch
|
||||
import torch._prims_common as utils
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._dynamo.utils import preserve_rng_state
|
||||
from torch._inductor import config
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch.library import _scoped_library, Library
|
||||
from torch.testing._internal.inductor_utils import HAS_CUDA
|
||||
|
||||
@ -91,7 +91,7 @@ class TestCompilerBisector(TestCase):
|
||||
|
||||
return not out_compiled.isnan().any()
|
||||
|
||||
out = BisectionManager.do_bisect(test_fn)
|
||||
out = CompilerBisector.do_bisect(test_fn)
|
||||
self.assertEqual(out.backend, "aot_eager_decomp_partition")
|
||||
self.assertEqual(out.subsystem, "decomposition")
|
||||
self.assertEqual(out.bisect_number, 1)
|
||||
@ -124,7 +124,7 @@ class TestCompilerBisector(TestCase):
|
||||
|
||||
return torch.allclose(out, out_c)
|
||||
|
||||
out = BisectionManager.do_bisect(test_fn)
|
||||
out = CompilerBisector.do_bisect(test_fn)
|
||||
self.assertEqual(out.backend, "inductor")
|
||||
self.assertEqual(out.subsystem, "joint_graph_passes")
|
||||
self.assertEqual(out.bisect_number, 4)
|
||||
@ -144,7 +144,7 @@ class TestCompilerBisector(TestCase):
|
||||
|
||||
return torch.allclose(out, out_c)
|
||||
|
||||
out = BisectionManager.do_bisect(test_fn)
|
||||
out = CompilerBisector.do_bisect(test_fn)
|
||||
self.assertEqual(out.backend, "inductor")
|
||||
self.assertEqual(out.subsystem, "inductor_fallback_random")
|
||||
self.assertTrue("inductor_fallback_random" in out.debug_info)
|
||||
@ -192,7 +192,7 @@ class TestCompilerBisector(TestCase):
|
||||
return False
|
||||
return True
|
||||
|
||||
out = BisectionManager.do_bisect(test_fn)
|
||||
out = CompilerBisector.do_bisect(test_fn)
|
||||
self.assertEqual(out.backend, "aot_eager_decomp_partition_crossref")
|
||||
|
||||
def test_emulate_precision_casts(self):
|
||||
@ -213,7 +213,7 @@ class TestCompilerBisector(TestCase):
|
||||
|
||||
return torch.equal(eager_scale, compile_scale)
|
||||
|
||||
out = BisectionManager.do_bisect(test_fn)
|
||||
out = CompilerBisector.do_bisect(test_fn)
|
||||
self.assertEqual(out.backend, "inductor")
|
||||
self.assertEqual(out.subsystem, "inductor_emulate_precision_casts")
|
||||
|
||||
@ -229,7 +229,7 @@ class TestCompilerBisector(TestCase):
|
||||
|
||||
return torch.allclose(torch.compile(my_func)(inp), my_func(inp))
|
||||
|
||||
out = BisectionManager.do_bisect(test_fn)
|
||||
out = CompilerBisector.do_bisect(test_fn)
|
||||
self.assertEqual(out.backend, "inductor")
|
||||
self.assertEqual(out.subsystem, "lowerings")
|
||||
self.assertEqual(out.bisect_number, 2)
|
||||
@ -240,7 +240,7 @@ class TestCompilerBisector(TestCase):
|
||||
def test_fn():
|
||||
return False
|
||||
|
||||
out = BisectionManager.do_bisect(test_fn)
|
||||
out = CompilerBisector.do_bisect(test_fn)
|
||||
self.assertEqual(out.backend, "eager")
|
||||
self.assertEqual(out.subsystem, None)
|
||||
|
||||
|
@ -2255,9 +2255,9 @@ class _TorchCompileInductorWrapper:
|
||||
)
|
||||
|
||||
def apply_options(self, options: _Optional[_Dict[str, _Any]]):
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
|
||||
if bisect_changes := BisectionManager.get_config_change("inductor"):
|
||||
if bisect_changes := CompilerBisector.get_config_change("inductor"):
|
||||
options = {} if options is None else options
|
||||
options = (
|
||||
{**bisect_changes} if options is None else {**options, **bisect_changes} # type: ignore[dict-item]
|
||||
@ -2496,9 +2496,9 @@ def compile(
|
||||
if mode is None and options is None:
|
||||
mode = "default"
|
||||
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
|
||||
if bisect_backend := BisectionManager.get_backend():
|
||||
if bisect_backend := CompilerBisector.get_backend():
|
||||
backend = bisect_backend
|
||||
|
||||
if backend == "inductor":
|
||||
|
@ -118,21 +118,27 @@ def boxed_nop(fx_g, example_inputs):
|
||||
return run
|
||||
|
||||
|
||||
def fake_crossref_boxed_nop(fx_g, example_inputs):
|
||||
def fake_crossref_boxed_nop(fx_g, example_inputs, ignore_op_fn=None):
|
||||
def run(args):
|
||||
with torch._subclasses.CrossRefFakeMode():
|
||||
with torch._subclasses.CrossRefFakeMode(ignore_op_fn):
|
||||
return torch.fx.Interpreter(fx_g).boxed_run(args)
|
||||
|
||||
run._boxed_call = True
|
||||
return run
|
||||
|
||||
|
||||
def ignore_builtins(op: torch._ops.OpOverload) -> bool:
|
||||
return op.namespace in ("aten", "prims", "prim")
|
||||
|
||||
|
||||
def get_nop_func():
|
||||
return (
|
||||
boxed_nop
|
||||
if not torch._functorch.config.fake_tensor_crossref
|
||||
else fake_crossref_boxed_nop
|
||||
)
|
||||
if not torch._functorch.config.fake_tensor_crossref:
|
||||
return boxed_nop
|
||||
elif torch._functorch.config.fake_tensor_crossref == "all":
|
||||
return fake_crossref_boxed_nop
|
||||
else:
|
||||
assert torch._functorch.config.fake_tensor_crossref == "custom_ops"
|
||||
return functools.partial(fake_crossref_boxed_nop, ignore_op_fn=ignore_builtins)
|
||||
|
||||
|
||||
# Useful for debugging purpose
|
||||
@ -172,10 +178,10 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
|
||||
"aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs
|
||||
)
|
||||
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
|
||||
config_patches = {"unlift_effect_tokens": True}
|
||||
if bisect_changes := BisectionManager.get_config_change(
|
||||
if bisect_changes := CompilerBisector.get_config_change(
|
||||
"aot_eager_decomp_partition"
|
||||
):
|
||||
config_patches.update(bisect_changes)
|
||||
@ -201,7 +207,15 @@ register_backend(
|
||||
|
||||
|
||||
def aot_eager_decomp_partition_crossref(gm, fake_tensor_inputs, **kwargs):
|
||||
with functorch_config.patch(fake_tensor_crossref=True):
|
||||
# if the config is set, respect it, otherwise only test custom_ops.
|
||||
# custom_op bad metas always manifest as an error whereas aten will only sometimes.
|
||||
# by default, use the less noisy option
|
||||
config_val = (
|
||||
"custom_ops"
|
||||
if not functorch_config.fake_tensor_crossref
|
||||
else functorch_config.fake_tensor_crossref
|
||||
)
|
||||
with functorch_config.patch(fake_tensor_crossref=config_val):
|
||||
return aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs)
|
||||
|
||||
|
||||
|
@ -164,6 +164,7 @@ unlift_effect_tokens = False
|
||||
|
||||
|
||||
# Run aot eager decomp partition with CrossRefFakeMode
|
||||
# options = False, "all", "custom_ops"
|
||||
fake_tensor_crossref = False
|
||||
|
||||
# This mode specifies that we should also keep track of the real
|
||||
|
@ -1329,9 +1329,9 @@ class FxGraphCache:
|
||||
"static across runs"
|
||||
)
|
||||
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
|
||||
if BisectionManager.bisection_enabled:
|
||||
if CompilerBisector.bisection_enabled:
|
||||
log.debug("dont cache graph when bisect enabled")
|
||||
raise BypassFxGraphCache
|
||||
|
||||
|
@ -95,7 +95,23 @@ class BisectionResult:
|
||||
debug_info: Optional[str] = None
|
||||
|
||||
|
||||
class BisectionManager:
|
||||
class CompilerBisector:
|
||||
"""
|
||||
This class iteratively runs torch.compile backends (eager, aot_eager, inductor) to find the
|
||||
first backend that can repro an issue.
|
||||
|
||||
Once it discovers the offending backend it will iteratively disable subsystems within the backend.
|
||||
For subsystems which are applied repeatedly, such as the number of post grad passes or number
|
||||
of lowering of nodes to inductor ir, it will bisect to find the offending application.
|
||||
|
||||
The idiomatic way to run it is with `do_bisect`. You can also use it by setting the env flags
|
||||
`TORCH_BISECT_BACKEND`, `TORCH_BISECT_SUBSYSTEM` and `TORCH_BISECT_MAX`.
|
||||
|
||||
It also supports a CLI interface, although this is less well tested.
|
||||
|
||||
You must run python compiler_bisector.py [start | good | bad | end]
|
||||
"""
|
||||
|
||||
bisection_enabled: bool = False
|
||||
|
||||
@classmethod
|
||||
@ -461,6 +477,10 @@ class BisectionManager:
|
||||
def do_bisect(
|
||||
cls, fn: Callable[[], bool], cli_interface: bool = False
|
||||
) -> Optional[BisectionResult]:
|
||||
"""
|
||||
Run fn repeatedly attempting to bisect torch.compile. fn should return True on success and False on failure.
|
||||
"""
|
||||
|
||||
if not cli_interface:
|
||||
bisection_enabled_orig = cls.bisection_enabled
|
||||
cls.delete_bisect_status()
|
||||
@ -557,7 +577,7 @@ def command_line_usage() -> None:
|
||||
print("Usage: python bisect_update.py <start|end|good|bad>")
|
||||
sys.exit(1)
|
||||
|
||||
bisection_manager = BisectionManager()
|
||||
bisection_manager = CompilerBisector()
|
||||
command = sys.argv[1]
|
||||
|
||||
if command == "end":
|
||||
@ -584,12 +604,12 @@ def command_line_usage() -> None:
|
||||
|
||||
def get_is_bisection_enabled() -> bool:
|
||||
return (
|
||||
BisectionManager.get_subsystem() is not None
|
||||
or BisectionManager.get_backend() is not None
|
||||
CompilerBisector.get_subsystem() is not None
|
||||
or CompilerBisector.get_backend() is not None
|
||||
)
|
||||
|
||||
|
||||
BisectionManager.bisection_enabled = get_is_bisection_enabled()
|
||||
CompilerBisector.bisection_enabled = get_is_bisection_enabled()
|
||||
|
||||
if __name__ == "__main__":
|
||||
command_line_usage()
|
@ -1346,7 +1346,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
def debug(msg: str) -> None:
|
||||
log.debug("lowering %s %s", LazyString(n.format_node), msg)
|
||||
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
|
||||
buffer_watermark = len(self.buffers)
|
||||
operation_watermark = len(self.operations)
|
||||
@ -1366,7 +1366,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
and n.target is not operator.getitem
|
||||
and (
|
||||
fallback_node_due_to_unsupported_type(n)
|
||||
or BisectionManager.disable_subsystem(
|
||||
or CompilerBisector.disable_subsystem(
|
||||
"inductor", "lowerings", lambda: repr(n)
|
||||
)
|
||||
)
|
||||
|
@ -2227,10 +2227,10 @@ def maybe_handle_decomp(
|
||||
args: Tuple[object, ...],
|
||||
kwargs: Dict[str, object],
|
||||
) -> object:
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
|
||||
if op in CURRENT_DECOMPOSITION_TABLE:
|
||||
if BisectionManager.disable_subsystem(
|
||||
if CompilerBisector.disable_subsystem(
|
||||
"aot_eager_decomp_partition", "decomposition", lambda: repr(op)
|
||||
):
|
||||
return NotImplemented
|
||||
|
@ -76,9 +76,9 @@ class GraphTransformObserver:
|
||||
return False
|
||||
|
||||
debug_info = lambda: self.passname # noqa: E731
|
||||
from torch._inductor.bisect_helper import BisectionManager
|
||||
from torch._inductor.compiler_bisector import CompilerBisector
|
||||
|
||||
return BisectionManager.disable_subsystem(
|
||||
return CompilerBisector.disable_subsystem(
|
||||
"inductor", self.subsystem, debug_info
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user