Compute bounds for the variables created during codegen (#123100)
Previously we would just bail out on computing these bounds for any variable that did not come from the FX graph. Now we propagate bounds whenever we have a rule for that op.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100
Approved by: https://github.com/jgong5, https://github.com/peterbell10
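A minimal standalone sketch of the idea (illustrative values, not code from this commit): once an op has a rule in ValueRangeAnalysis, bounds can be propagated through it even for variables that never appear in the FX graph.

    from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges

    vr = ValueRangeAnalysis()
    x = ValueRanges(-5, 4)                 # e.g. the output of randint(low=-5, high=5)
    print(vr.mul(x, ValueRanges.wrap(2)))  # a range like [-10, 8]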
commit 320af5eaa6
parent 15a9770225
committed by PyTorch MergeBot
test/inductor/test_torchinductor.py
@@ -9616,7 +9616,12 @@ class CommonTemplate:
         # This used to not compile due to a wrong return type of randint64_cpu
         # See https://github.com/pytorch/pytorch/issues/117435
         def fn(n):
-            return torch.randint(low=-5, high=5, size=(n,), dtype=torch.int64) % 10
+            return (
+                torch.randint(
+                    low=-5, high=5, size=(n,), dtype=torch.int64, device=self.device
+                )
+                % 10
+            )

         res = torch.compile(fn)(20)
         self.assertTrue(torch.all((0 <= res) & (res < 10)).item())
torch/_inductor/codegen/common.py
@@ -27,7 +27,7 @@ import torch.fx
 from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.utils import _pytree as pytree
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
-from torch.utils._sympy.value_ranges import ValueRanges
+from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges

 from .. import config, metrics
 from ..utils import DeferredLineBase, IndentedBuffer, sympy_dot, sympy_subs, unique
@@ -269,6 +269,7 @@ class DataTypePropagation:
         if node.target in (
             "get_index",
             "index_expr",
+            "randint64",
         ):
             return torch.int64

@@ -529,7 +530,7 @@ class OpOverrides:

     @staticmethod
     def reciprocal(x):
-        return ops.truediv("1", x)
+        return ops.truediv(ops.constant(1, torch.int32), x)

     @staticmethod
     def square(x):
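One reason for this change that is easy to miss: a bare string such as "1" is opaque to the new bounds propagation, while ops.constant produces a CSE variable that carries a singleton range. A hedged sketch of the bound such a constant contributes:

    from torch.utils._sympy.value_ranges import ValueRanges

    one = ValueRanges.wrap(1)  # roughly what the analysis sees for ops.constant(1, torch.int32)
    assert one.is_singleton() and one.lower == 1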
@@ -566,7 +567,11 @@ class OpOverrides:
     @staticmethod
     def remainder(a, b):
         r = ops.mod(a, b)
-        return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)
+        cond = ops.and_(
+            ops.ne(r, ops.constant(0, torch.int32)),
+            ops.ne(ops.signbit(r), ops.signbit(b)),
+        )
+        return ops.where(cond, ops.add(r, b), r)

     @staticmethod
     def load_seed(name, offset):
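The rewritten remainder encodes Python modulo semantics on top of a truncated mod: if the remainder is nonzero and its sign disagrees with the divisor's, shift it by one divisor. A plain-Python reference sketch of the same logic (py_remainder is an illustrative name):

    import math

    def py_remainder(a: float, b: float) -> float:
        r = math.fmod(a, b)                # truncated remainder, like ops.mod
        if r != 0 and (r < 0) != (b < 0):  # sign mismatch, like ops.ne(signbit(r), signbit(b))
            r += b                         # like ops.add(r, b)
        return r

    assert py_remainder(-7, 3) == 2 and py_remainder(7, -3) == -2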
@@ -1473,24 +1478,17 @@ class Kernel(CodeGen):
         # TODO: hoist this to top level
         class CSEProxy:
             self.name = "CSEProxy"
+            vr_analysis = ValueRangeAnalysis()

             @staticmethod
             def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                 def inner(*args, **kwargs):
-                    # TritonTemplateKernel has no current_node
-                    buf_bounds = ValueRanges.unknown()
-                    if (
-                        fx_node := getattr(V.interpreter, "current_node", None)
-                    ) and fx_node.target == name:
-                        assert isinstance(self.node_to_bounds, dict)
-                        buf_bounds = self.node_to_bounds.get(
-                            fx_node, ValueRanges.unknown()
-                        )
+                    bounds = CSEProxy._bound_variable(name, *args, **kwargs)

                     value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                     def do_cse(v):
-                        csevar = self.cse.generate(self.compute, v, bounds=buf_bounds)
+                        csevar = self.cse.generate(self.compute, v, bounds=bounds)
                         csevar.update_on_args(name, args, kwargs)
                         return csevar

@@ -1498,6 +1496,49 @@ class Kernel(CodeGen):

                 return inner

+            @staticmethod
+            def _bound_variable(name, *args, **kwargs):
+                """
+                If the variable comes from an FX node, we forward the bound we have already computed.
+                Otherwise, if the variable is created while codegen'ing another op, we try to compute its bounds.
+                """
+                from ..select_algorithm import TritonTemplateKernel
+
+                if isinstance(V.kernel, TritonTemplateKernel):
+                    return ValueRanges.unknown()
+
+                fx_node = V.interpreter.current_node
+                if fx_node.target == name:
+                    assert isinstance(self.node_to_bounds, dict)
+                    return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
+                elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
+                    # These create lots of inner strings. We would need to compute the bounds at the ops
+                    # We will also likely not get much from computing VRs on these nodes
+                    if any(
+                        s in fx_node.target
+                        for s in ("set_indirect", "reduction", "scan")
+                    ):
+                        return ValueRanges.unknown()
+
+                    # We assume that the inputs come from `ops.` and are not strings. If you want to generate
+                    # intermediary strings, wrap them in CSE variables with properly initialised bounds.
+
+                    # If there is no FX bound but we know how to compute one we do so
+                    assert not kwargs
+
+                    def arg_to_bound(x):
+                        if isinstance(x, CSEVariable):
+                            return x.bounds
+                        elif isinstance(x, sympy.Expr):
+                            return bound_sympy(x)
+                        else:
+                            return x
+
+                    arg_bounds = list(map(arg_to_bound, args))
+                    return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
+                else:
+                    return ValueRanges.unknown()
+
             @staticmethod
             def indirect_indexing(
                 var: CSEVariable, size: sympy.Expr, check: bool = True
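The fallback path above dispatches by op name into ValueRangeAnalysis. A standalone sketch of that call shape, with illustrative bounds standing in for CSEVariable.bounds:

    from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges

    vr_analysis = ValueRangeAnalysis()
    arg_bounds = [ValueRanges(0, 7), ValueRanges(1, 1)]  # stand-ins for the args' bounds
    print(getattr(vr_analysis, "add")(*arg_bounds))      # same call shape as _bound_variable; [1, 8]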
torch/_inductor/codegen/cpp.py
@@ -34,6 +34,7 @@ from ..scheduler import (
 )
 from ..utils import (
     cache_on_self,
+    get_bounds_index_expr,
     get_fused_kernel_name,
     is_welford_reduction,
     parallel_num_threads,
@@ -841,7 +842,7 @@ class CppOverrides(OpOverrides):
     @staticmethod
     def constant(val, dtype):
         opt_ctx: OptimizationContext = get_current_node_opt_ctx()
-        assert opt_ctx and opt_ctx.dtype is not None
+        assert opt_ctx and opt_ctx.dtype is not None, opt_ctx
         dtype = opt_ctx.dtype
         if dtype in DTYPE_LOWP_FP:
             # Since load promotes all half-precision inputs to float, constants
@@ -854,7 +855,12 @@ class CppOverrides(OpOverrides):
         opt_ctx: OptimizationContext = get_current_node_opt_ctx()
         assert opt_ctx and opt_ctx.dtype is not None
         dtype = opt_ctx.dtype
-        return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype)
+
+        idx_str = cexpr(V.kernel.rename_indexing(expr))
+        var = V.kernel.cse.generate(
+            V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr)
+        )
+        return ops.to_dtype(var, dtype)

     @staticmethod
     def masked(mask, body, other):
@@ -1451,7 +1457,10 @@ class CppVecOverrides(CppOverrides):
         if stride == 0:
             return CppOverrides.index_expr(expr, dtype)
         elif stride is not None:
-            value = ops.to_dtype(cexpr(index), dtype)
+            idx = V.kernel.cse.generate(
+                V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr)
+            )
+            value = ops.to_dtype(idx, dtype)
             if isinstance(value, OpsValue):
                 value = value.value
             csevar = V.kernel.arange(value, stride)
torch/_inductor/codegen/triton.py
@@ -23,7 +23,6 @@ from typing import (
     Optional,
     Set,
     Tuple,
-    TYPE_CHECKING,
     Union,
 )

@@ -39,6 +38,7 @@ from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
 from torch._prims_common import is_integer_dtype
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
+from torch.utils._sympy.value_ranges import ValueRanges
 from torch.utils._triton import has_triton_package

 from ..._dynamo.utils import counters
@@ -58,6 +58,7 @@ from ..runtime.runtime_utils import (
 from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
 from ..utils import (
     cache_on_self,
+    get_bounds_index_expr,
     get_dtype_size,
     get_fused_kernel_name,
     get_kernel_metadata,
@@ -86,8 +87,6 @@ from .common import (
 from .multi_kernel import MultiKernel
 from .triton_utils import config_of, signature_of, signature_to_meta

-if TYPE_CHECKING:
-    from torch.utils._sympy.value_ranges import ValueRanges

 log = logging.getLogger(__name__)
 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
@@ -619,7 +618,7 @@ class TritonOverrides(OpOverrides):
         elif bug == "accuracy":
             return f"{x} + 1"
         elif bug is None:
-            return ops.maximum("0", x)
+            return ops.maximum(ops.constant(0, torch.int32), x)
         else:
             raise AssertionError(
                 f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}"
@@ -864,11 +863,9 @@ class TritonOverrides(OpOverrides):

     @staticmethod
     def sign(x):
-        def to_int(s):
-            return f"{s}.to(tl.int8)"
-
-        left = to_int(ops.lt("0", x))
-        right = to_int(ops.lt(x, "0"))
+        z = ops.constant(0, torch.int32)
+        left = ops.to_dtype((ops.lt(z, x)), torch.int8)
+        right = ops.to_dtype((ops.lt(x, z)), torch.int8)
         sub = ops.sub(left, right)
         return f"{sub}.to({x}.dtype)"
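The new sign lowers to the classic branchless identity sign(x) = (0 < x) - (x < 0), computed in int8 and cast back to x's dtype. A plain-Python sketch of the identity:

    def sign(x: float) -> int:
        return int(0 < x) - int(x < 0)

    assert (sign(-3.0), sign(0.0), sign(2.5)) == (-1, 0, 1)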
@@ -916,8 +913,9 @@ class TritonKernelOverrides(TritonOverrides):
     def index_expr(cls, expr, dtype):
         indexing = V.kernel.indexing(expr, block_ptr=False)
         assert isinstance(indexing, IndexingOptions)
-        # This is called from CSEProxy.__getattr__, so we'll set the bounds there
-        var = V.kernel.cse.generate(V.kernel.compute, indexing.index_str)
+        var = V.kernel.cse.generate(
+            V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr)
+        )

         if dtype not in {torch.int32, torch.int64}:
             var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype))
@@ -929,10 +927,14 @@ class TritonKernelOverrides(TritonOverrides):
         with V.kernel.mask_loads(mask) as new_mask:
             result = body()

+        # Remove once CSEVariables track the dtype
+        if result.bounds.is_bool:
+            other = bool(other)
         # Take dtype from result to prevent accidental promotion
         other = V.kernel.cse.generate(
             V.kernel.compute,
             f"tl.full({result}.shape, {triton_constant(other)}, {result}.dtype)",
+            bounds=ValueRanges.wrap(other),
         )
         return ops.where(new_mask, result, other)
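The boolean coercion guards against accidental promotion: if the masked body produced a boolean, a float default such as 0.0 must become a Python bool before it is baked into tl.full, and wrapping it records a singleton bound for the constant. A sketch with an assumed boolean result:

    from torch.utils._sympy.value_ranges import ValueRanges

    other = 0.0
    result_is_bool = True           # stand-in for result.bounds.is_bool
    if result_is_bool:
        other = bool(other)         # keeps tl.full from promoting the result to float
    print(ValueRanges.wrap(other))  # singleton boolean range for the constant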
torch/_inductor/config.py
@@ -355,6 +355,9 @@ always_keep_tensor_constants = False
 # assert that indirect indexing does not read / write out of bounds
 assert_indirect_indexing = True

+# compute CSE bounds on variables that do not appear in the FX graph
+compute_all_bounds = False
+
 # constant folding on the joint graph
 joint_graph_constant_folding = True
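A hedged usage sketch of the new flag (config.patch is the usual way to scope inductor config overrides; the compiled function mirrors the test above):

    import torch
    from torch._inductor import config

    with config.patch(compute_all_bounds=True):
        fn = torch.compile(lambda n: torch.randint(-5, 5, (n,)) % 10)
        res = fn(20)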
torch/_inductor/utils.py
@@ -49,6 +49,7 @@ from torch.autograd.profiler_util import EventList
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
 from torch.utils._sympy.symbol import make_symbol, SymT
+from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
 from . import config
 from .runtime.runtime_utils import ceildiv as runtime_ceildiv
@@ -539,6 +540,20 @@ def sympy_str(expr: sympy.Expr) -> str:
     return str(expr)


+def get_bounds_index_expr(index):
+    from .virtualized import V
+
+    # If this expression does not come from an FX node, we compute its bounds
+    if (
+        config.compute_all_bounds
+        and (fx_node := getattr(V.interpreter, "current_node", None))
+        and fx_node.target != "index_expr"
+    ):
+        return bound_sympy(index)
+    else:
+        return ValueRanges.unknown()
+
+
 def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
     """
     Used to generate an integer-nonnegative symbol.
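get_bounds_index_expr ultimately defers to bound_sympy, which can also be exercised standalone; a sketch with an illustrative index symbol and explicit ranges:

    import sympy
    from torch.utils._sympy.value_ranges import ValueRanges, bound_sympy

    i0 = sympy.Symbol("i0", integer=True)
    print(bound_sympy(2 * i0 + 1, {i0: ValueRanges(0, 7)}))  # a range like [1, 15]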
torch/utils/_sympy/value_ranges.py
@@ -138,7 +138,7 @@ class ValueRanges(Generic[_T]):
             if not sympy_generic_le(lower, upper):
                 raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
         except TypeError:
-            raise TypeError(f"Could not compare {lower} <= {upper}")
+            raise TypeError(f"Could not compare {lower} <= {upper}")  # noqa: TRY200
         # Because this is a frozen class
         object.__setattr__(self, "lower", lower)
         object.__setattr__(self, "upper", upper)
@@ -340,6 +340,9 @@ class SymPyValueRangeAnalysis:

     @staticmethod
     def constant(value, dtype):
+        if isinstance(value, ValueRanges):
+            assert value.is_singleton()
+            value = value.lower
         # NB: value is NOT a sympy expression, it's a constant!
         is_python = isinstance(value, (int, float, bool))
         assert is_python or isinstance(
@@ -663,7 +666,9 @@ class SymPyValueRangeAnalysis:
         b = ValueRanges.wrap(b)
         c = ValueRanges.wrap(c)
         a = a.boolify()
-        assert b.is_bool == c.is_bool
+        # We sometimes write unknown without specifying the type correctly
+        # In particular, we do that when initialising the bounds for loads in bounds.py
+        assert b.is_bool == c.is_bool or ValueRanges.unknown() in (b, c)
         if b.is_bool:
             return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
         else: