Revert "Add inductor backend to device interface; make minifier_tests more device agnostic (#151314)"

This reverts commit 77bc959fe122bfd131e339ca36cab445a1860806.

Reverted https://github.com/pytorch/pytorch/pull/151314 on behalf of https://github.com/atalman due to sorry change is failing internally ([comment](https://github.com/pytorch/pytorch/pull/151314#issuecomment-3229774015))
This commit is contained in:
PyTorch MergeBot
2025-08-27 21:21:19 +00:00
parent 38ed57d446
commit 014b98dd09
8 changed files with 47 additions and 191 deletions
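
Taken together, the revert swaps the device-agnostic config-patching decorator added by #151314 back to direct, backend-namespaced inductor_config.patch calls. A minimal before/after sketch of the two decorator styles, assembled from the test diff below purely for illustration (not itself part of the patch):

import torch._inductor.config as inductor_config
# Removed by this revert: device-aware patching helper
from torch.testing._internal.inductor_utils import try_patch_inductor_backend_config

# Style removed by this revert: resolve the codegen backend ("cpp", "triton", ...) from the device
@try_patch_inductor_backend_config("cpu", "inject_relu_bug_TESTING_ONLY", "compile_error")
def test_after_aot_cpp_compile_error(self):
    self._test_after_aot("cpu", "CppCompileError")

# Style restored by this revert: name the backend namespace explicitly
@inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "compile_error")
def test_after_aot_cpu_compile_error(self):
    self._test_after_aot("cpu", "CppCompileError")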

View File

@ -5,7 +5,7 @@ from unittest.mock import patch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
from torch._dynamo.test_minifier_common import MinifierTestBase
from torch._inductor.codegen.common import get_wrapper_codegen_for_device
from torch._inductor import config
from torch.export import load as export_load
from torch.testing._internal.common_utils import (
IS_JETSON,
@ -13,11 +13,7 @@ from torch.testing._internal.common_utils import (
skipIfXpu,
TEST_WITH_ASAN,
)
from torch.testing._internal.inductor_utils import (
backend_for_device,
GPU_TYPE,
try_patch_inductor_backend_config,
)
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.triton_utils import requires_gpu
@ -38,43 +34,27 @@ inner(torch.randn(20, 20).to("{device}"))
"""
self._run_full_test(run_code, "aot", expected_error, isolate=False)
@unittest.skipIf(
backend_for_device("cpu") != "cpp", "Specifically testing C++ codegen"
)
@unittest.skipIf(IS_JETSON, "Fails on Jetson")
@try_patch_inductor_backend_config(
"cpu", "inject_relu_bug_TESTING_ONLY", "compile_error"
)
def test_after_aot_cpp_compile_error(self):
@inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "compile_error")
def test_after_aot_cpu_compile_error(self):
self._test_after_aot("cpu", "CppCompileError")
@unittest.skipIf(IS_JETSON, "Fails on Jetson")
@try_patch_inductor_backend_config(
"cpu", "inject_relu_bug_TESTING_ONLY", "accuracy"
)
@inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
def test_after_aot_cpu_accuracy_error(self):
self._test_after_aot("cpu", "AccuracyError")
@requires_gpu
@unittest.skipIf(
backend_for_device(GPU_TYPE) != "triton", "Specifically testing Triton codegen"
)
@try_patch_inductor_backend_config(
GPU_TYPE, "inject_relu_bug_TESTING_ONLY", "compile_error"
)
def test_after_aot_triton_compile_error(self):
@inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "compile_error")
def test_after_aot_gpu_compile_error(self):
self._test_after_aot(GPU_TYPE, "SyntaxError")
@requires_gpu
@try_patch_inductor_backend_config(
GPU_TYPE, "inject_relu_bug_TESTING_ONLY", "accuracy"
)
@inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy")
def test_after_aot_gpu_accuracy_error(self):
self._test_after_aot(GPU_TYPE, "AccuracyError")
@try_patch_inductor_backend_config(
"cpu", "inject_relu_bug_TESTING_ONLY", "accuracy"
)
@inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
def test_constant_in_graph(self):
run_code = """\
@torch.compile()
@ -86,7 +66,7 @@ inner(torch.randn(2))
self._run_full_test(run_code, "aot", "AccuracyError", isolate=False)
@requires_gpu
@patch.object(inductor_config, "joint_graph_constant_folding", False)
@patch.object(config, "joint_graph_constant_folding", False)
def test_rmse_improves_over_atol(self):
# From https://twitter.com/itsclivetime/status/1651135821045719041?s=20
run_code = """
@ -115,12 +95,8 @@ inner(torch.tensor(655 * 100, dtype=torch.half, device='GPU_TYPE'))
# 655 * 100 precision, and so we report no problem
self._run_full_test(run_code, "aot", None, isolate=False)
@try_patch_inductor_backend_config(
"cpu", "inject_relu_bug_TESTING_ONLY", "accuracy"
)
@try_patch_inductor_backend_config(
"cpu", "inject_log1p_bug_TESTING_ONLY", "accuracy"
)
@inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
@inductor_config.patch("cpp.inject_log1p_bug_TESTING_ONLY", "accuracy")
def test_accuracy_vs_strict_accuracy(self):
run_code = """
@torch.compile()
@ -174,9 +150,7 @@ class Repro(torch.nn.Module):
return (relu,)""",
)
@try_patch_inductor_backend_config(
"cpu", "inject_relu_bug_TESTING_ONLY", "accuracy"
)
@inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
def test_offload_to_disk(self):
# Just a smoketest, this doesn't actually test that memory
# usage went down. Test case is carefully constructed to hit
@ -205,8 +179,6 @@ inner(torch.randn(20, 20))
# NB: The program is intentionally quite simple, just enough to
# trigger one minification step, no more (dedicated minifier tests
# should exercise minifier only)
if get_wrapper_codegen_for_device(device, cpp_wrapper=True) is None:
raise unittest.SkipTest(f"Device {device} does not support c++ wrapper")
run_code = f"""\
class Model(torch.nn.Module):
def __init__(self):
@ -239,8 +211,6 @@ with torch.no_grad():
# NB: The program is intentionally quite simple, just enough to
# trigger one minification step, no more (dedicated minifier tests
# should exercise minifier only)
if get_wrapper_codegen_for_device(device, cpp_wrapper=True) is None:
raise unittest.SkipTest(f"Device {device} does not support c++ wrapper")
# It tests that the minifier can handle unflattened inputs and kwargs
run_code = f"""\
@ -289,73 +259,53 @@ def forward(self, linear):
return pytree.tree_unflatten((relu,), self._out_spec)""",
)
@unittest.skipIf(
backend_for_device("cpu") != "cpp", "Specifically testing C++ codegen"
)
@unittest.skipIf(IS_JETSON, "Fails on Jetson")
@try_patch_inductor_backend_config(
"cpu",
"inject_relu_bug_TESTING_ONLY",
@inductor_config.patch(
"cpp.inject_relu_bug_TESTING_ONLY",
"compile_error",
)
def test_aoti_cpp_compile_error(self):
def test_aoti_cpu_compile_error(self):
res = self._test_aoti("cpu", "CppCompileError")
self._aoti_check_relu_repro(res)
@unittest.skipIf(
backend_for_device("cpu") != "cpp", "Specifically testing C++ codegen"
)
@unittest.skipIf(IS_JETSON, "Fails on Jetson")
@try_patch_inductor_backend_config(
"cpu",
"inject_relu_bug_TESTING_ONLY",
@inductor_config.patch(
"cpp.inject_relu_bug_TESTING_ONLY",
"compile_error",
)
def test_aoti_cpp_compile_error_unflatten(self):
def test_aoti_cpu_compile_error_unflatten(self):
res = self._test_aoti_unflattened_inputs("cpu", "CppCompileError")
self._aoti_check_relu_repro(res)
@requires_gpu
@unittest.skipIf(
backend_for_device(GPU_TYPE) != "triton", "Specifically testing Triton codegen"
)
@skipIfXpu(msg="AOTI for XPU not enabled yet")
@try_patch_inductor_backend_config(
GPU_TYPE,
"inject_relu_bug_TESTING_ONLY",
@inductor_config.patch(
"triton.inject_relu_bug_TESTING_ONLY",
"compile_error",
)
def test_aoti_triton_compile_error(self):
def test_aoti_gpu_compile_error(self):
res = self._test_aoti(GPU_TYPE, "SyntaxError")
self._aoti_check_relu_repro(res)
@requires_gpu
@unittest.skipIf(
backend_for_device(GPU_TYPE) != "triton", "Specifically testing Triton codegen"
)
@skipIfXpu(msg="AOTI for XPU not enabled yet")
@try_patch_inductor_backend_config(
GPU_TYPE,
"inject_relu_bug_TESTING_ONLY",
@inductor_config.patch(
"triton.inject_relu_bug_TESTING_ONLY",
"compile_error",
)
def test_aoti_triton_compile_error_unflatten(self):
def test_aoti_gpu_compile_error_unflatten(self):
res = self._test_aoti_unflattened_inputs(GPU_TYPE, "SyntaxError")
self._aoti_check_relu_repro(res)
@unittest.skipIf(IS_JETSON, "Fails on Jetson")
@try_patch_inductor_backend_config(
"cpu", "inject_relu_bug_TESTING_ONLY", "accuracy"
)
@inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
def test_aoti_cpu_accuracy_error(self):
res = self._test_aoti("cpu", "AccuracyError")
self._aoti_check_relu_repro(res)
@requires_gpu
@skipIfXpu(msg="AOTI for XPU not enabled yet")
@try_patch_inductor_backend_config(
GPU_TYPE, "inject_relu_bug_TESTING_ONLY", "accuracy"
)
@inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy")
def test_aoti_gpu_accuracy_error(self):
res = self._test_aoti(GPU_TYPE, "AccuracyError")
self._aoti_check_relu_repro(res)

View File

@ -148,10 +148,6 @@ class DeviceInterface:
def memory_allocated(device: torch.types.Device = None) -> int:
raise NotImplementedError
@staticmethod
def inductor_backend() -> Optional[str]:
return None
@staticmethod
def is_triton_capable(device: torch.types.Device = None) -> bool:
"""
@ -268,10 +264,6 @@ class CudaInterface(DeviceInterface):
else:
return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0]
@staticmethod
def inductor_backend() -> Optional[str]:
return torch._inductor.config.cuda_backend
@staticmethod
def is_triton_capable(device: torch.types.Device = None) -> bool:
return (
@ -365,10 +357,6 @@ class MtiaInterface(DeviceInterface):
cc = torch.mtia.get_device_capability(device)
return cc
@staticmethod
def inductor_backend() -> Optional[str]:
return "triton"
@staticmethod
def is_triton_capable(device: torch.types.Device = None) -> bool:
return True
@ -452,10 +440,6 @@ class XpuInterface(DeviceInterface):
def is_bf16_supported(including_emulation: bool = False) -> bool:
return torch.xpu.is_bf16_supported()
@staticmethod
def inductor_backend() -> Optional[str]:
return "triton"
@staticmethod
def is_triton_capable(device: torch.types.Device = None) -> bool:
return True
@ -529,10 +513,6 @@ class CpuInterface(DeviceInterface):
if "cpu" not in triton.backends.backends:
raise RuntimeError("triton not built with the 'cpu' backend")
@staticmethod
def inductor_backend() -> Optional[str]:
return torch._inductor.config.cpu_backend
class MpsInterface(DeviceInterface):
@staticmethod
@ -574,10 +554,6 @@ class MpsInterface(DeviceInterface):
def current_device() -> int:
return 0
@staticmethod
def inductor_backend() -> Optional[str]:
return "mps"
device_interfaces: dict[str, type[DeviceInterface]] = {}
_device_initialized = False

View File

@ -100,6 +100,14 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}")
cls._exit_stack.close() # type: ignore[attr-defined]
def _gen_codegen_fn_patch_code(self, device: str, bug_type: str) -> str:
assert bug_type in ("compile_error", "runtime_error", "accuracy")
return f"""\
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r}
"""
def _maybe_subprocess_run(
self, args: Sequence[Any], *, isolate: bool, cwd: Optional[str] = None
) -> subprocess.CompletedProcess[bytes]:

View File

@ -3099,14 +3099,7 @@ def same(
and math.isnan(res_error)
# Some unit test for the accuracy minifier relies on
# returning false in this case.
and not any(
(
torch._inductor.config.cpp.inject_relu_bug_TESTING_ONLY,
torch._inductor.config.cpp.inject_log1p_bug_TESTING_ONLY,
torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY,
torch._inductor.config.triton.inject_log1p_bug_TESTING_ONLY,
)
)
and not torch._inductor.config.cpp.inject_relu_bug_TESTING_ONLY
):
passes_test = True
if not passes_test:

View File

@ -1296,21 +1296,7 @@ class TritonOverrides(OpOverrides):
@staticmethod
@maybe_upcast_float32()
def log1p(x):
bug = config.triton.inject_log1p_bug_TESTING_ONLY
if bug == "compile_error":
return "compile error!"
elif bug == "runtime_error":
# NB: this only triggers runtime error as long as input
# is not all zero
return f'triton_helpers.device_assert_then({x} == 0, "injected assert fail", {x})'
elif bug == "accuracy":
return f"{x} + 1"
elif bug is None:
return f"libdevice.log1p({x})"
else:
raise AssertionError(
f"unrecognized config triton.inject_log1p_bug_TESTING_ONLY = {bug!r}"
)
return f"libdevice.log1p({x})"
@staticmethod
@maybe_upcast_float32()

View File

@ -1383,7 +1383,6 @@ class triton:
# extraction and minification functionality.
# Valid values: "compile_error", "runtime_error", "accuracy"
inject_relu_bug_TESTING_ONLY: Optional[str] = None
inject_log1p_bug_TESTING_ONLY: Optional[str] = None
# Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
codegen_upcast_to_fp32 = True

View File

@ -109,7 +109,7 @@ def get_gpu_type() -> str:
return gpu_type
from torch._dynamo.device_interface import DeviceInterface, get_interface_for_device
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import detect_fake_mode
from torch.autograd import DeviceType
from torch.autograd.profiler_util import EventList
@ -3154,16 +3154,15 @@ def register_op_requires_libdevice_fp64(name: str) -> None:
def get_current_backend() -> str:
"""Get the codegen backend for the current graph, or throw."""
from torch._inductor.virtualized import V
device: torch.device = V.graph.get_current_device_or_throw()
device_interface: type[DeviceInterface] = get_interface_for_device(device.type)
device_inductor_backend: Optional[str] = device_interface.inductor_backend()
if device_inductor_backend is None:
raise ValueError(f"Couldn't get an Inductor backend for device {device.type}")
return device_inductor_backend
device_str = V.graph.get_current_device_or_throw().type
if device_str == "cpu":
return config.cpu_backend
elif device_str == "mps":
return "mps"
else:
return config.cuda_backend
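
The restored selection is a plain device-type switch rather than a DeviceInterface lookup; spelled out as a standalone sketch (the helper name is hypothetical, and the cpp/triton defaults are implied by the skip conditions in the test diff above, not by this patch):

def _backend_for_device_type(device_str: str) -> str:
    # Restored mapping: cpu uses config.cpu_backend (typically "cpp"), mps is
    # hard-coded, and every other device falls through to config.cuda_backend
    # (typically "triton").
    if device_str == "cpu":
        return config.cpu_backend
    if device_str == "mps":
        return "mps"
    return config.cuda_backend

By contrast, the reverted variant asked the device interface for its inductor_backend() and raised a ValueError when it returned None, which was the hook #151314 added for resolving backends of other devices.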
def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:

View File

@ -9,10 +9,8 @@ import contextlib
import os
from subprocess import CalledProcessError
import sys
from typing import Any, Optional
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch.fx.experimental.proxy_tensor import make_fx
from torch._dynamo.device_interface import get_interface_for_device
from torch._inductor.graph import GraphLowering
from torch._inductor.compile_fx import shape_env_from_inputs
from torch._inductor.codecache import CppCodeCache
@ -37,6 +35,8 @@ from torch.testing._internal.common_device_type import (
from torch.testing._internal.common_utils import (
LazyVal,
IS_FBCODE,
)
from torch.testing._internal.common_utils import (
TestCase,
IS_CI,
IS_WINDOWS,
@ -347,58 +347,3 @@ def patch_inductor_backend(
original_custom_pass,
original_custom_backend_config
)
def backend_for_device(device: str) -> Optional[str]:
""" Get the Inductor codegen backend used for the device ``device``. """
if dev_int := get_interface_for_device(device):
return dev_int.inductor_backend()
return None
def try_patch_inductor_backend_config(device: str, key: str,
value: Any) -> contextlib.ContextDecorator:
"""
Try to patch the backend-specific Inductor options, for the codegen backend
corresponding to the given ``device``. If that config can't be found to
patch, skip the test.
Will patch the member of the global ``config.$BACKEND``, if it exists. If
the given device also specifies a custom config module, will also try to
patch its ``$BACKEND`` member if it exists.
"""
device_backend = backend_for_device(device)
if device_backend is None:
return unittest.skip(
f"Can't patch Inductor config {key} for device {device}")
config_modules = [torch._inductor.config]
if custom_config_module := get_custom_backend_config_for_device(device):
config_modules.append(custom_config_module)
contexts: list[contextlib.ContextDecorator] = []
for mod in config_modules:
if (
hasattr(mod, f"{device_backend}")
and hasattr(mod, f"{device_backend}.{key}")
):
contexts.append(mod.patch(f"{device_backend}.{key}", value))
if len(contexts) == 0:
return unittest.skip(
f"Can't patch Inductor config {key} for device {device}")
class ContextStack(contextlib.ContextDecorator):
def __init__(self, contexts: list[contextlib.ContextDecorator]) -> None:
self.contexts: list[contextlib.ContextDecorator] = contexts
def __enter__(self) -> None:
for cd in self.contexts:
cd.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore[no-untyped-def]
for cd in self.contexts:
cd.__exit__(exc_type, exc_val, exc_tb)
return ContextStack(contexts)
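
For context, the removed helper was designed to be applied as a test decorator: it stacks one patch(...) context per config module that exposes the backend-specific key, and degrades to unittest.skip when nothing can be patched. A hedged usage sketch (the test class is illustrative only):

import unittest
from torch.testing._internal.inductor_utils import try_patch_inductor_backend_config  # removed by this revert

class ExampleTest(unittest.TestCase):
    @try_patch_inductor_backend_config("cpu", "inject_relu_bug_TESTING_ONLY", "accuracy")
    def test_with_injected_relu_bug(self):
        # While this runs, the cpu backend's config key (cpp.inject_relu_bug_TESTING_ONLY)
        # is patched to "accuracy"; on a device whose backend config lacks the key, the
        # decorator turns the test into a skip instead of failing.
        ...

The hand-rolled ContextStack is presumably there because contextlib.ExitStack composes context managers but is not a ContextDecorator, so it cannot decorate a test method directly.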