Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
Revert "Add inductor backend to device interface; make minifier_tests more device agnostic (#151314)"
This reverts commit 77bc959fe122bfd131e339ca36cab445a1860806. Reverted https://github.com/pytorch/pytorch/pull/151314 on behalf of https://github.com/atalman due to sorry change is failing internally ([comment](https://github.com/pytorch/pytorch/pull/151314#issuecomment-3229774015))
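For context, the reverted PR had switched the minifier tests to a device-agnostic decorator that resolves the Inductor codegen backend for a device before patching its config; this revert returns them to patching the backend-prefixed config directly. An illustrative before/after sketch, using the decorator and test names exactly as they appear in the diff below:

# Removed by this revert (style introduced in #151314): resolve the codegen backend
# for the device and patch its config, skipping the test if it cannot be patched.
@try_patch_inductor_backend_config("cpu", "inject_relu_bug_TESTING_ONLY", "compile_error")
def test_after_aot_cpp_compile_error(self):
    self._test_after_aot("cpu", "CppCompileError")

# Restored by this revert: patch the cpp/triton-prefixed Inductor config directly.
@inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "compile_error")
def test_after_aot_cpu_compile_error(self):
    self._test_after_aot("cpu", "CppCompileError")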
@@ -5,7 +5,7 @@ from unittest.mock import patch
 import torch._dynamo.config as dynamo_config
 import torch._inductor.config as inductor_config
 from torch._dynamo.test_minifier_common import MinifierTestBase
-from torch._inductor.codegen.common import get_wrapper_codegen_for_device
+from torch._inductor import config
 from torch.export import load as export_load
 from torch.testing._internal.common_utils import (
     IS_JETSON,
@@ -13,11 +13,7 @@ from torch.testing._internal.common_utils import (
     skipIfXpu,
     TEST_WITH_ASAN,
 )
-from torch.testing._internal.inductor_utils import (
-    backend_for_device,
-    GPU_TYPE,
-    try_patch_inductor_backend_config,
-)
+from torch.testing._internal.inductor_utils import GPU_TYPE
 from torch.testing._internal.triton_utils import requires_gpu
 
 
@@ -38,43 +34,27 @@ inner(torch.randn(20, 20).to("{device}"))
 """
         self._run_full_test(run_code, "aot", expected_error, isolate=False)
 
-    @unittest.skipIf(
-        backend_for_device("cpu") != "cpp", "Specifically testing C++ codegen"
-    )
     @unittest.skipIf(IS_JETSON, "Fails on Jetson")
-    @try_patch_inductor_backend_config(
-        "cpu", "inject_relu_bug_TESTING_ONLY", "compile_error"
-    )
-    def test_after_aot_cpp_compile_error(self):
+    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "compile_error")
+    def test_after_aot_cpu_compile_error(self):
         self._test_after_aot("cpu", "CppCompileError")
 
     @unittest.skipIf(IS_JETSON, "Fails on Jetson")
-    @try_patch_inductor_backend_config(
-        "cpu", "inject_relu_bug_TESTING_ONLY", "accuracy"
-    )
+    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
     def test_after_aot_cpu_accuracy_error(self):
         self._test_after_aot("cpu", "AccuracyError")
 
     @requires_gpu
-    @unittest.skipIf(
-        backend_for_device(GPU_TYPE) != "triton", "Specifically testing Triton codegen"
-    )
-    @try_patch_inductor_backend_config(
-        GPU_TYPE, "inject_relu_bug_TESTING_ONLY", "compile_error"
-    )
-    def test_after_aot_triton_compile_error(self):
+    @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "compile_error")
+    def test_after_aot_gpu_compile_error(self):
         self._test_after_aot(GPU_TYPE, "SyntaxError")
 
     @requires_gpu
-    @try_patch_inductor_backend_config(
-        GPU_TYPE, "inject_relu_bug_TESTING_ONLY", "accuracy"
-    )
+    @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy")
     def test_after_aot_gpu_accuracy_error(self):
         self._test_after_aot(GPU_TYPE, "AccuracyError")
 
-    @try_patch_inductor_backend_config(
-        "cpu", "inject_relu_bug_TESTING_ONLY", "accuracy"
-    )
+    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
     def test_constant_in_graph(self):
         run_code = """\
 @torch.compile()
@@ -86,7 +66,7 @@ inner(torch.randn(2))
         self._run_full_test(run_code, "aot", "AccuracyError", isolate=False)
 
     @requires_gpu
-    @patch.object(inductor_config, "joint_graph_constant_folding", False)
+    @patch.object(config, "joint_graph_constant_folding", False)
     def test_rmse_improves_over_atol(self):
         # From https://twitter.com/itsclivetime/status/1651135821045719041?s=20
         run_code = """
@@ -115,12 +95,8 @@ inner(torch.tensor(655 * 100, dtype=torch.half, device='GPU_TYPE'))
         # 655 * 100 precision, and so we report no problem
         self._run_full_test(run_code, "aot", None, isolate=False)
 
-    @try_patch_inductor_backend_config(
-        "cpu", "inject_relu_bug_TESTING_ONLY", "accuracy"
-    )
-    @try_patch_inductor_backend_config(
-        "cpu", "inject_log1p_bug_TESTING_ONLY", "accuracy"
-    )
+    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
+    @inductor_config.patch("cpp.inject_log1p_bug_TESTING_ONLY", "accuracy")
     def test_accuracy_vs_strict_accuracy(self):
         run_code = """
 @torch.compile()
@@ -174,9 +150,7 @@ class Repro(torch.nn.Module):
     return (relu,)""",
         )
 
-    @try_patch_inductor_backend_config(
-        "cpu", "inject_relu_bug_TESTING_ONLY", "accuracy"
-    )
+    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
     def test_offload_to_disk(self):
         # Just a smoketest, this doesn't actually test that memory
         # usage went down. Test case is carefully constructed to hit
@@ -205,8 +179,6 @@ inner(torch.randn(20, 20))
         # NB: The program is intentionally quite simple, just enough to
         # trigger one minification step, no more (dedicated minifier tests
         # should exercise minifier only)
-        if get_wrapper_codegen_for_device(device, cpp_wrapper=True) is None:
-            raise unittest.SkipTest(f"Device {device} does not support c++ wrapper")
         run_code = f"""\
 class Model(torch.nn.Module):
     def __init__(self):
@@ -239,8 +211,6 @@ with torch.no_grad():
         # NB: The program is intentionally quite simple, just enough to
         # trigger one minification step, no more (dedicated minifier tests
         # should exercise minifier only)
-        if get_wrapper_codegen_for_device(device, cpp_wrapper=True) is None:
-            raise unittest.SkipTest(f"Device {device} does not support c++ wrapper")
 
         # It tests that the minifier can handle unflattened inputs and kwargs
         run_code = f"""\
@@ -289,73 +259,53 @@ def forward(self, linear):
     return pytree.tree_unflatten((relu,), self._out_spec)""",
         )
 
-    @unittest.skipIf(
-        backend_for_device("cpu") != "cpp", "Specifically testing C++ codegen"
-    )
     @unittest.skipIf(IS_JETSON, "Fails on Jetson")
-    @try_patch_inductor_backend_config(
-        "cpu",
-        "inject_relu_bug_TESTING_ONLY",
+    @inductor_config.patch(
+        "cpp.inject_relu_bug_TESTING_ONLY",
         "compile_error",
     )
-    def test_aoti_cpp_compile_error(self):
+    def test_aoti_cpu_compile_error(self):
         res = self._test_aoti("cpu", "CppCompileError")
         self._aoti_check_relu_repro(res)
 
-    @unittest.skipIf(
-        backend_for_device("cpu") != "cpp", "Specifically testing C++ codegen"
-    )
     @unittest.skipIf(IS_JETSON, "Fails on Jetson")
-    @try_patch_inductor_backend_config(
-        "cpu",
-        "inject_relu_bug_TESTING_ONLY",
+    @inductor_config.patch(
+        "cpp.inject_relu_bug_TESTING_ONLY",
         "compile_error",
     )
-    def test_aoti_cpp_compile_error_unflatten(self):
+    def test_aoti_cpu_compile_error_unflatten(self):
         res = self._test_aoti_unflattened_inputs("cpu", "CppCompileError")
         self._aoti_check_relu_repro(res)
 
     @requires_gpu
-    @unittest.skipIf(
-        backend_for_device(GPU_TYPE) != "triton", "Specifically testing Triton codegen"
-    )
     @skipIfXpu(msg="AOTI for XPU not enabled yet")
-    @try_patch_inductor_backend_config(
-        GPU_TYPE,
-        "inject_relu_bug_TESTING_ONLY",
+    @inductor_config.patch(
+        "triton.inject_relu_bug_TESTING_ONLY",
         "compile_error",
     )
-    def test_aoti_triton_compile_error(self):
+    def test_aoti_gpu_compile_error(self):
         res = self._test_aoti(GPU_TYPE, "SyntaxError")
         self._aoti_check_relu_repro(res)
 
     @requires_gpu
-    @unittest.skipIf(
-        backend_for_device(GPU_TYPE) != "triton", "Specifically testing Triton codegen"
-    )
     @skipIfXpu(msg="AOTI for XPU not enabled yet")
-    @try_patch_inductor_backend_config(
-        GPU_TYPE,
-        "inject_relu_bug_TESTING_ONLY",
+    @inductor_config.patch(
+        "triton.inject_relu_bug_TESTING_ONLY",
         "compile_error",
     )
-    def test_aoti_triton_compile_error_unflatten(self):
+    def test_aoti_gpu_compile_error_unflatten(self):
         res = self._test_aoti_unflattened_inputs(GPU_TYPE, "SyntaxError")
         self._aoti_check_relu_repro(res)
 
     @unittest.skipIf(IS_JETSON, "Fails on Jetson")
-    @try_patch_inductor_backend_config(
-        "cpu", "inject_relu_bug_TESTING_ONLY", "accuracy"
-    )
+    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
     def test_aoti_cpu_accuracy_error(self):
         res = self._test_aoti("cpu", "AccuracyError")
         self._aoti_check_relu_repro(res)
 
     @requires_gpu
     @skipIfXpu(msg="AOTI for XPU not enabled yet")
-    @try_patch_inductor_backend_config(
-        GPU_TYPE, "inject_relu_bug_TESTING_ONLY", "accuracy"
-    )
+    @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy")
     def test_aoti_gpu_accuracy_error(self):
         res = self._test_aoti(GPU_TYPE, "AccuracyError")
         self._aoti_check_relu_repro(res)
@@ -148,10 +148,6 @@ class DeviceInterface:
     def memory_allocated(device: torch.types.Device = None) -> int:
         raise NotImplementedError
 
-    @staticmethod
-    def inductor_backend() -> Optional[str]:
-        return None
-
     @staticmethod
     def is_triton_capable(device: torch.types.Device = None) -> bool:
         """
@@ -268,10 +264,6 @@ class CudaInterface(DeviceInterface):
         else:
             return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0]
 
-    @staticmethod
-    def inductor_backend() -> Optional[str]:
-        return torch._inductor.config.cuda_backend
-
     @staticmethod
     def is_triton_capable(device: torch.types.Device = None) -> bool:
         return (
@@ -365,10 +357,6 @@ class MtiaInterface(DeviceInterface):
         cc = torch.mtia.get_device_capability(device)
         return cc
 
-    @staticmethod
-    def inductor_backend() -> Optional[str]:
-        return "triton"
-
     @staticmethod
     def is_triton_capable(device: torch.types.Device = None) -> bool:
         return True
@@ -452,10 +440,6 @@ class XpuInterface(DeviceInterface):
     def is_bf16_supported(including_emulation: bool = False) -> bool:
         return torch.xpu.is_bf16_supported()
 
-    @staticmethod
-    def inductor_backend() -> Optional[str]:
-        return "triton"
-
     @staticmethod
     def is_triton_capable(device: torch.types.Device = None) -> bool:
         return True
@@ -529,10 +513,6 @@ class CpuInterface(DeviceInterface):
         if "cpu" not in triton.backends.backends:
             raise RuntimeError("triton not built with the 'cpu' backend")
 
-    @staticmethod
-    def inductor_backend() -> Optional[str]:
-        return torch._inductor.config.cpu_backend
-
 
 class MpsInterface(DeviceInterface):
     @staticmethod
@@ -574,10 +554,6 @@ class MpsInterface(DeviceInterface):
     def current_device() -> int:
         return 0
 
-    @staticmethod
-    def inductor_backend() -> Optional[str]:
-        return "mps"
-
 
 device_interfaces: dict[str, type[DeviceInterface]] = {}
 _device_initialized = False
@@ -100,6 +100,14 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
             print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}")
         cls._exit_stack.close() # type: ignore[attr-defined]
 
+    def _gen_codegen_fn_patch_code(self, device: str, bug_type: str) -> str:
+        assert bug_type in ("compile_error", "runtime_error", "accuracy")
+        return f"""\
+{torch._dynamo.config.codegen_config()}
+{torch._inductor.config.codegen_config()}
+torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r}
+"""
+
     def _maybe_subprocess_run(
         self, args: Sequence[Any], *, isolate: bool, cwd: Optional[str] = None
     ) -> subprocess.CompletedProcess[bytes]:
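For reference, the restored _gen_codegen_fn_patch_code helper hard-codes the cpp/triton split instead of consulting the device interface; roughly, _gen_codegen_fn_patch_code("cpu", "compile_error") produces a patch snippet that inlines the current dynamo and inductor codegen_config() output and ends with:

torch._inductor.config.cpp.inject_relu_bug_TESTING_ONLY = 'compile_error'

while any non-CPU device gets the triton.-prefixed counterpart.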
@@ -3099,14 +3099,7 @@ def same(
                 and math.isnan(res_error)
                 # Some unit test for the accuracy minifier relies on
                 # returning false in this case.
-                and not any(
-                    (
-                        torch._inductor.config.cpp.inject_relu_bug_TESTING_ONLY,
-                        torch._inductor.config.cpp.inject_log1p_bug_TESTING_ONLY,
-                        torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY,
-                        torch._inductor.config.triton.inject_log1p_bug_TESTING_ONLY,
-                    )
-                )
+                and not torch._inductor.config.cpp.inject_relu_bug_TESTING_ONLY
             ):
                 passes_test = True
             if not passes_test:
@@ -1296,21 +1296,7 @@ class TritonOverrides(OpOverrides):
     @staticmethod
     @maybe_upcast_float32()
     def log1p(x):
-        bug = config.triton.inject_log1p_bug_TESTING_ONLY
-        if bug == "compile_error":
-            return "compile error!"
-        elif bug == "runtime_error":
-            # NB: this only triggers runtime error as long as input
-            # is not all zero
-            return f'triton_helpers.device_assert_then({x} == 0, "injected assert fail", {x})'
-        elif bug == "accuracy":
-            return f"{x} + 1"
-        elif bug is None:
-            return f"libdevice.log1p({x})"
-        else:
-            raise AssertionError(
-                f"unrecognized config triton.inject_log1p_bug_TESTING_ONLY = {bug!r}"
-            )
+        return f"libdevice.log1p({x})"
 
     @staticmethod
     @maybe_upcast_float32()
@@ -1383,7 +1383,6 @@ class triton:
     # extraction and minification functionality.
     # Valid values: "compile_error", "runtime_error", "accuracy"
     inject_relu_bug_TESTING_ONLY: Optional[str] = None
-    inject_log1p_bug_TESTING_ONLY: Optional[str] = None
 
     # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
     codegen_upcast_to_fp32 = True
@@ -109,7 +109,7 @@ def get_gpu_type() -> str:
     return gpu_type
 
 
-from torch._dynamo.device_interface import DeviceInterface, get_interface_for_device
+from torch._dynamo.device_interface import get_interface_for_device
 from torch._dynamo.utils import detect_fake_mode
 from torch.autograd import DeviceType
 from torch.autograd.profiler_util import EventList
@@ -3154,16 +3154,15 @@ def register_op_requires_libdevice_fp64(name: str) -> None:
 
 
 def get_current_backend() -> str:
-    """Get the codegen backend for the current graph, or throw."""
     from torch._inductor.virtualized import V
 
-    device: torch.device = V.graph.get_current_device_or_throw()
-    device_interface: type[DeviceInterface] = get_interface_for_device(device.type)
-
-    device_inductor_backend: Optional[str] = device_interface.inductor_backend()
-    if device_inductor_backend is None:
-        raise ValueError(f"Couldn't get an Inductor backend for device {device.type}")
-    return device_inductor_backend
+    device_str = V.graph.get_current_device_or_throw().type
+    if device_str == "cpu":
+        return config.cpu_backend
+    elif device_str == "mps":
+        return "mps"
+    else:
+        return config.cuda_backend
 
 
 def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
@@ -9,10 +9,8 @@ import contextlib
 import os
 from subprocess import CalledProcessError
 import sys
-from typing import Any, Optional
 import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
 from torch.fx.experimental.proxy_tensor import make_fx
-from torch._dynamo.device_interface import get_interface_for_device
 from torch._inductor.graph import GraphLowering
 from torch._inductor.compile_fx import shape_env_from_inputs
 from torch._inductor.codecache import CppCodeCache
@@ -37,6 +35,8 @@ from torch.testing._internal.common_device_type import (
 from torch.testing._internal.common_utils import (
     LazyVal,
     IS_FBCODE,
+)
+from torch.testing._internal.common_utils import (
     TestCase,
     IS_CI,
     IS_WINDOWS,
@@ -347,58 +347,3 @@ def patch_inductor_backend(
             original_custom_pass,
             original_custom_backend_config
         )
-
-def backend_for_device(device: str) -> Optional[str]:
-    """ Get the Inductor codegen backend used for the device ``device``. """
-    if dev_int := get_interface_for_device(device):
-        return dev_int.inductor_backend()
-    return None
-
-def try_patch_inductor_backend_config(device: str, key: str,
-                                      value: Any) -> contextlib.ContextDecorator:
-    """
-    Try to patch the backend-specific Inductor options, for the codegen backend
-    corresponding to the given ``device``. If that config can't be found to
-    patch, skip the test.
-
-    Will patch the member of the global ``config.$BACKEND``, if it exists. If
-    the given device also specifies a custom config module, will also try to
-    patch its ``$BACKEND`` member if it exists.
-
-    """
-    device_backend = backend_for_device(device)
-
-    if device_backend is None:
-        return unittest.skip(
-            f"Can't patch Inductor config {key} for device {device}")
-
-    config_modules = [torch._inductor.config]
-    if custom_config_module := get_custom_backend_config_for_device(device):
-        config_modules.append(custom_config_module)
-
-    contexts: list[contextlib.ContextDecorator] = []
-
-    for mod in config_modules:
-        if (
-            hasattr(mod, f"{device_backend}")
-            and hasattr(mod, f"{device_backend}.{key}")
-        ):
-            contexts.append(mod.patch(f"{device_backend}.{key}", value))
-
-    if len(contexts) == 0:
-        return unittest.skip(
-            f"Can't patch Inductor config {key} for device {device}")
-
-    class ContextStack(contextlib.ContextDecorator):
-        def __init__(self, contexts: list[contextlib.ContextDecorator]) -> None:
-            self.contexts: list[contextlib.ContextDecorator] = contexts
-
-        def __enter__(self) -> None:
-            for cd in self.contexts:
-                cd.__enter__()
-
-        def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore[no-untyped-def]
-            for cd in self.contexts:
-                cd.__exit__(exc_type, exc_val, exc_tb)
-
-    return ContextStack(contexts)