# Feature

Inductor sometimes uses `Identity` functions to group various terms of an expression. While this is convenient in some scenarios, it can frustrate pattern matching. For example, when we're matching an indexing expression to tell whether it can be represented as a block pointer, that analysis should be invariant to `Identity` ops. This PR adds a few features to achieve this invariance.

- Create a new expansion mode `expr.expand(identity=True)`, which removes all `Identity` functions from the expression (see the sketch below).
- Preprocess the expression with this expansion prior to pattern matching.
- Bonus: create a new test utility function called `dummy_graph()`, which creates a simple `GraphLowering`. This is useful for testing the pattern matcher, as we need to initialize `V.graph` before we can access `V.graph.sizevars`.

# Test plan

This PR adds a few new unit tests:

- Added a unit test specifically for `expr.expand(identity=True)`.
- Added a new unit test module for the block pattern matcher. Tested that we can correctly match some example patterns containing `Identity` ops.

I originally intended to add an end-to-end test compiling pointwise cat and mapping the corresponding memory accesses to block pointers. However, it looks like that will take more work, since the [relevant code path](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton.py#L1306) disables block pointer analysis. It might be better to defer that to a future PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146000
Approved by: https://github.com/eellison, https://github.com/jansel
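As a rough illustration of the `expr.expand(identity=True)` mode described above, here is a minimal hedged sketch; the import path for Inductor's `Identity` wrapper is an assumption, not something spelled out in this description:

```python
import sympy

# Assumed import path for Inductor's sympy Identity wrapper.
from torch.utils._sympy.functions import Identity

x, y = sympy.symbols("x y")
expr = Identity(x + Identity(2 * y))

# The new expansion hint strips every Identity wrapper, so downstream pattern
# matching (e.g. the block pointer analysis) sees the underlying expression.
assert expr.expand(identity=True) == x + 2 * y
```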
# mypy: ignore-errors

import contextlib
import functools
import logging
import os
import re
import sys
import unittest
from subprocess import CalledProcessError

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch._inductor.codecache import CppCodeCache
from torch._inductor.compile_fx import shape_env_from_inputs
from torch._inductor.graph import GraphLowering
from torch._inductor.utils import (
    GPU_TYPES,
    get_gpu_shared_memory,
    get_gpu_type,
    is_big_gpu,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
    IS_CI,
    IS_FBCODE,
    IS_WINDOWS,
    LazyVal,
    TestCase,
)
from torch.utils._triton import has_triton


log: logging.Logger = logging.getLogger(__name__)

def test_cpu():
    try:
        CppCodeCache.load("")
        return not IS_FBCODE
    except (
        CalledProcessError,
        OSError,
        torch._inductor.exc.InvalidCxxCompiler,
        torch._inductor.exc.CppCompileError,
    ):
        return False


HAS_CPU = LazyVal(test_cpu)

HAS_TRITON = has_triton()

if HAS_TRITON:
    import triton

    TRITON_HAS_CPU = "cpu" in triton.backends.backends
else:
    TRITON_HAS_CPU = False


HAS_CUDA = torch.cuda.is_available() and HAS_TRITON

HAS_XPU = torch.xpu.is_available() and HAS_TRITON

HAS_MPS = torch.mps.is_available()

HAS_GPU = HAS_CUDA or HAS_XPU

GPU_TYPE = get_gpu_type()

HAS_MULTIGPU = any(
    getattr(torch, gpu).is_available() and getattr(torch, gpu).device_count() >= 2
    for gpu in GPU_TYPES
)


def _check_has_dynamic_shape(
    self: TestCase,
    code,
):
    for_loop_found = False
    has_dynamic = False
    lines = code.split("\n")
    for line in lines:
        if "for(" in line:
            for_loop_found = True
            if re.search(r";.*ks.*;", line) is not None:
                has_dynamic = True
                break
    self.assertTrue(
        has_dynamic, msg=f"Failed to find dynamic for loop variable\n{code}"
    )
    self.assertTrue(for_loop_found, f"Failed to find for loop\n{code}")


def skipDeviceIf(cond, msg, *, device):
    if cond:

        def decorate_fn(fn):
            @functools.wraps(fn)
            def inner(self, *args, **kwargs):
                if not hasattr(self, "device"):
                    warn_msg = "Expected the test class to have a device attribute, but none was found. "
                    if hasattr(self, "device_type"):
                        warn_msg += "Consider using the skip device decorators in common_device_type.py"
                    log.warning(warn_msg)
                if self.device == device:
                    raise unittest.SkipTest(msg)
                return fn(self, *args, **kwargs)

            return inner

    else:

        def decorate_fn(fn):
            return fn

    return decorate_fn


def skip_windows_ci(name: str, file: str) -> None:
    if IS_WINDOWS and IS_CI:
        module = os.path.basename(file).removesuffix(".py")
        sys.stderr.write(
            f"Windows CI does not have necessary dependencies for {module} tests yet\n"
        )
        if name == "__main__":
            sys.exit(0)
        raise unittest.SkipTest("requires sympy/functorch/filelock")


# TODO: Remove HAS_MPS condition when `HAS_GPU` includes HAS_MPS
requires_gpu = functools.partial(
    unittest.skipIf, not (HAS_GPU or HAS_MPS), "requires gpu"
)
requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton")


def requires_cuda_with_enough_memory(min_mem_required):
    def inner(fn):
        if (
            not torch.cuda.is_available()
            or torch.cuda.get_device_properties().total_memory < min_mem_required
        ):
            return unittest.skip(
                f"Only if the CUDA device has at least {min_mem_required / 1e9:.3f}GB memory to be safe"
            )(fn)
        else:
            return fn

    return inner


skipCUDAIf = functools.partial(skipDeviceIf, device="cuda")
skipXPUIf = functools.partial(skipDeviceIf, device="xpu")
skipCPUIf = functools.partial(skipDeviceIf, device="cpu")


IS_A100 = LazyVal(lambda: HAS_CUDA and get_gpu_shared_memory() == 166912)

IS_H100 = LazyVal(lambda: HAS_CUDA and get_gpu_shared_memory() == 232448)

IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu())


def dummy_graph() -> GraphLowering:
    """
    Create a simple GraphLowering. This is useful for unit testing code which
    accesses V.graph.sizevars, since V.graph must be initialized before it can
    be used.
    """
    example_inputs = [torch.randn(10) for _ in range(2)]
    gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs)
    shape_env = shape_env_from_inputs(example_inputs)
    graph = GraphLowering(
        gm,
        shape_env=shape_env,
    )

    return graph
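

# Hedged usage sketch: code under test usually needs the dummy graph installed
# as the active graph handler before it can read V.graph.sizevars. Assuming the
# standard Inductor virtualized interface:
#
#     from torch._inductor.virtualized import V
#
#     with V.set_graph_handler(dummy_graph()):
#         ...  # exercise code that reads V.graph.sizevars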


def maybe_skip_size_asserts(op):
    """
    For certain ops, the meta and eager implementations return different
    strides. This causes size/stride asserts to fail. Skip adding those
    asserts for now.
    """
    if (
        op.aten_name
        in (
            "fft_hfftn",
            "fft_hfft",
            "fft_hfft2",
            "fft_ihfftn",
            "fft_fft",
            "fft_fft2",
            "fft_fftn",
            "fft_ifft",
            "fft_ifft2",
            "fft_ifftn",
            "fft_irfft",
            "fft_irfft2",
            "fft_irfftn",
            "fft_ihfft",
            "fft_ihfft2",
            "fft_rfft",
            "fft_rfft2",
            "fft_rfftn",
            "linalg_eig",
            "linalg_eigvals",
        )
        and "TORCHINDUCTOR_SIZE_ASSERTS" not in os.environ
    ):
        return torch._inductor.config.patch(size_asserts=False)
    else:
        return contextlib.nullcontext()
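

# Hedged usage sketch: the return value is a context manager, so callers would
# typically wrap the op invocation, e.g. (run_op is a hypothetical helper):
#
#     with maybe_skip_size_asserts(op):
#         run_op(op, args)  # size/stride asserts are patched off for listed ops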