Compute bounds for the variables created during codegen (#123100)

Previously, we would just bail out and use unknown bounds for any
variable that did not come from the FX graph. Now we propagate bounds
whenever we have a rule for that op.
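
As a rough illustration of what "a rule for that op" buys us (a sketch using the public value-range helpers; not code from this diff):

from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges

vr = ValueRangeAnalysis()
# Say codegen produced a variable known to lie in [-5, 4] (e.g. the output of randint(-5, 5)).
x = ValueRanges(-5, 4)
# With a rule for the op we can bound the result instead of falling back to ValueRanges.unknown():
print(vr.add(x, x))  # roughly [-10, 8]
print(vr.abs(x))     # roughly [0, 5]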

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100
Approved by: https://github.com/jgong5, https://github.com/peterbell10
Authored by: lezcano
Date: 2024-05-07 18:39:24 +00:00
Committed by: PyTorch MergeBot
parent 15a9770225
commit 320af5eaa6
7 changed files with 110 additions and 30 deletions

View File

@@ -9616,7 +9616,12 @@ class CommonTemplate:
# This used to not compile due to a wrong return type of randint64_cpu
# See https://github.com/pytorch/pytorch/issues/117435
def fn(n):
return torch.randint(low=-5, high=5, size=(n,), dtype=torch.int64) % 10
return (
torch.randint(
low=-5, high=5, size=(n,), dtype=torch.int64, device=self.device
)
% 10
)
res = torch.compile(fn)(20)
self.assertTrue(torch.all((0 <= res) & (res < 10)).item())

View File

@@ -27,7 +27,7 @@ import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
from .. import config, metrics
from ..utils import DeferredLineBase, IndentedBuffer, sympy_dot, sympy_subs, unique
@@ -269,6 +269,7 @@ class DataTypePropagation:
if node.target in (
"get_index",
"index_expr",
"randint64",
):
return torch.int64
@@ -529,7 +530,7 @@ class OpOverrides:
@staticmethod
def reciprocal(x):
return ops.truediv("1", x)
return ops.truediv(ops.constant(1, torch.int32), x)
@staticmethod
def square(x):
@@ -566,7 +567,11 @@ class OpOverrides:
@staticmethod
def remainder(a, b):
r = ops.mod(a, b)
return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)
cond = ops.and_(
ops.ne(r, ops.constant(0, torch.int32)),
ops.ne(ops.signbit(r), ops.signbit(b)),
)
return ops.where(cond, ops.add(r, b), r)
@staticmethod
def load_seed(name, offset):
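
For reference, a quick Python sketch (not part of this diff) of the identity the new remainder lowering relies on: take a C-style mod, then shift it into the divisor's sign when the two disagree, which recovers Python's a % b semantics.

import math

def py_remainder(a, b):
    # C-style remainder carries the sign of `a`...
    r = math.fmod(a, b)
    # ...so shift it by `b` when the signs disagree, matching `a % b`.
    if r != 0 and (r < 0) != (b < 0):
        r += b
    return r

assert py_remainder(-7, 3) == -7 % 3 == 2
assert py_remainder(7, -3) == 7 % -3 == -2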
@@ -1473,24 +1478,17 @@ class Kernel(CodeGen):
# TODO: hoist this to top level
class CSEProxy:
self.name = "CSEProxy"
vr_analysis = ValueRangeAnalysis()
@staticmethod
def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
def inner(*args, **kwargs):
# TritonTemplateKernel has no current_node
buf_bounds = ValueRanges.unknown()
if (
fx_node := getattr(V.interpreter, "current_node", None)
) and fx_node.target == name:
assert isinstance(self.node_to_bounds, dict)
buf_bounds = self.node_to_bounds.get(
fx_node, ValueRanges.unknown()
)
bounds = CSEProxy._bound_variable(name, *args, **kwargs)
value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
def do_cse(v):
csevar = self.cse.generate(self.compute, v, bounds=buf_bounds)
csevar = self.cse.generate(self.compute, v, bounds=bounds)
csevar.update_on_args(name, args, kwargs)
return csevar
@@ -1498,6 +1496,49 @@ class Kernel(CodeGen):
return inner
@staticmethod
def _bound_variable(name, *args, **kwargs):
"""
If the variable comes from an FX node, we forward the bound we have already computed
Otherwise, if the variable is created while codegen'ing another op, we try to compute its bounds
"""
from ..select_algorithm import TritonTemplateKernel
if isinstance(V.kernel, TritonTemplateKernel):
return ValueRanges.unknown()
fx_node = V.interpreter.current_node
if fx_node.target == name:
assert isinstance(self.node_to_bounds, dict)
return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
# These ops create lots of inner strings, so we would need to compute the bounds at the ops
# We are also unlikely to get much from computing VRs on these nodes
if any(
s in fx_node.target
for s in ("set_indirect", "reduction", "scan")
):
return ValueRanges.unknown()
# We assume that the inputs come from `ops.` and are not strings. If you want to generate
# intermediary strings, wrap them in CSE variables with properly initialised bounds.
# If there is no FX bound but we know how to compute one we do so
assert not kwargs
def arg_to_bound(x):
if isinstance(x, CSEVariable):
return x.bounds
elif isinstance(x, sympy.Expr):
return bound_sympy(x)
else:
return x
arg_bounds = list(map(arg_to_bound, args))
return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
else:
return ValueRanges.unknown()
@staticmethod
def indirect_indexing(
var: CSEVariable, size: sympy.Expr, check: bool = True
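
A small sketch (again not from the diff) of the bound_sympy fallback that arg_to_bound above uses for plain sympy expressions; the symbol range here is made up for illustration:

import sympy
from torch.utils._sympy.value_ranges import ValueRanges, bound_sympy

s = sympy.Symbol("s", integer=True, nonnegative=True)
# Bound an index expression given a hypothetical range for its free symbol.
print(bound_sympy(2 * s + 1, {s: ValueRanges(0, 10)}))  # roughly [1, 21]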

View File

@@ -34,6 +34,7 @@ from ..scheduler import (
)
from ..utils import (
cache_on_self,
get_bounds_index_expr,
get_fused_kernel_name,
is_welford_reduction,
parallel_num_threads,
@@ -841,7 +842,7 @@ class CppOverrides(OpOverrides):
@staticmethod
def constant(val, dtype):
opt_ctx: OptimizationContext = get_current_node_opt_ctx()
assert opt_ctx and opt_ctx.dtype is not None
assert opt_ctx and opt_ctx.dtype is not None, opt_ctx
dtype = opt_ctx.dtype
if dtype in DTYPE_LOWP_FP:
# Since load promotes all half-precision inputs to float, constants
@@ -854,7 +855,12 @@
opt_ctx: OptimizationContext = get_current_node_opt_ctx()
assert opt_ctx and opt_ctx.dtype is not None
dtype = opt_ctx.dtype
return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype)
idx_str = cexpr(V.kernel.rename_indexing(expr))
var = V.kernel.cse.generate(
V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr)
)
return ops.to_dtype(var, dtype)
@staticmethod
def masked(mask, body, other):
@@ -1451,7 +1457,10 @@ class CppVecOverrides(CppOverrides):
if stride == 0:
return CppOverrides.index_expr(expr, dtype)
elif stride is not None:
value = ops.to_dtype(cexpr(index), dtype)
idx = V.kernel.cse.generate(
V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr)
)
value = ops.to_dtype(idx, dtype)
if isinstance(value, OpsValue):
value = value.value
csevar = V.kernel.arange(value, stride)

View File

@@ -23,7 +23,6 @@ from typing import (
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
@@ -39,6 +38,7 @@ from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._triton import has_triton_package
from ..._dynamo.utils import counters
@@ -58,6 +58,7 @@ from ..runtime.runtime_utils import (
from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
from ..utils import (
cache_on_self,
get_bounds_index_expr,
get_dtype_size,
get_fused_kernel_name,
get_kernel_metadata,
@@ -86,8 +87,6 @@ from .common import (
from .multi_kernel import MultiKernel
from .triton_utils import config_of, signature_of, signature_to_meta
if TYPE_CHECKING:
from torch.utils._sympy.value_ranges import ValueRanges
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
@@ -619,7 +618,7 @@ class TritonOverrides(OpOverrides):
elif bug == "accuracy":
return f"{x} + 1"
elif bug is None:
return ops.maximum("0", x)
return ops.maximum(ops.constant(0, torch.int32), x)
else:
raise AssertionError(
f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}"
@@ -864,11 +863,9 @@ class TritonOverrides(OpOverrides):
@staticmethod
def sign(x):
def to_int(s):
return f"{s}.to(tl.int8)"
left = to_int(ops.lt("0", x))
right = to_int(ops.lt(x, "0"))
z = ops.constant(0, torch.int32)
left = ops.to_dtype((ops.lt(z, x)), torch.int8)
right = ops.to_dtype((ops.lt(x, z)), torch.int8)
sub = ops.sub(left, right)
return f"{sub}.to({x}.dtype)"
@@ -916,8 +913,9 @@ class TritonKernelOverrides(TritonOverrides):
def index_expr(cls, expr, dtype):
indexing = V.kernel.indexing(expr, block_ptr=False)
assert isinstance(indexing, IndexingOptions)
# This is called from CSEProxy.__getattr__, so we'll set the bounds there
var = V.kernel.cse.generate(V.kernel.compute, indexing.index_str)
var = V.kernel.cse.generate(
V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr)
)
if dtype not in {torch.int32, torch.int64}:
var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype))
@@ -929,10 +927,14 @@ class TritonKernelOverrides(TritonOverrides):
with V.kernel.mask_loads(mask) as new_mask:
result = body()
# Remove once CSEVariables track the dtype
if result.bounds.is_bool:
other = bool(other)
# Take dtype from result to prevent accidental promotion
other = V.kernel.cse.generate(
V.kernel.compute,
f"tl.full({result}.shape, {triton_constant(other)}, {result}.dtype)",
bounds=ValueRanges.wrap(other),
)
return ops.where(new_mask, result, other)

View File

@@ -355,6 +355,9 @@ always_keep_tensor_constants = False
# assert that indirect indexing does not read / write out of bounds
assert_indirect_indexing = True
# compute CSE bounds on variables that do not appear in the FX graph
compute_all_bounds = False
# constant folding on the joint graph
joint_graph_constant_folding = True
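
A possible way to exercise the new flag (hypothetical usage, not part of the diff): turn it on before compiling so bounds are also computed for variables that only exist during codegen.

import torch
import torch._inductor.config as inductor_config

inductor_config.compute_all_bounds = True  # default is False

@torch.compile
def f(x):
    return (x % 10).abs()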

View File

@@ -49,6 +49,7 @@ from torch.autograd.profiler_util import EventList
from torch.fx.passes.shape_prop import ShapeProp
from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import make_symbol, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from . import config
from .runtime.runtime_utils import ceildiv as runtime_ceildiv
@@ -539,6 +540,20 @@ def sympy_str(expr: sympy.Expr) -> str:
return str(expr)
def get_bounds_index_expr(index):
from .virtualized import V
# If this expression does not come from an FX node, we compute its bounds
if (
config.compute_all_bounds
and (fx_node := getattr(V.interpreter, "current_node", None))
and fx_node.target != "index_expr"
):
return bound_sympy(index)
else:
return ValueRanges.unknown()
def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
"""
Used to generate an integer-nonnegative symbol.

View File

@@ -138,7 +138,7 @@ class ValueRanges(Generic[_T]):
if not sympy_generic_le(lower, upper):
raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
except TypeError:
raise TypeError(f"Could not compare {lower} <= {upper}")
raise TypeError(f"Could not compare {lower} <= {upper}") # noqa: TRY200
# Because this is a frozen class
object.__setattr__(self, "lower", lower)
object.__setattr__(self, "upper", upper)
@@ -340,6 +340,9 @@ class SymPyValueRangeAnalysis:
@staticmethod
def constant(value, dtype):
if isinstance(value, ValueRanges):
assert value.is_singleton()
value = value.lower
# NB: value is NOT a sympy expression, it's a constant!
is_python = isinstance(value, (int, float, bool))
assert is_python or isinstance(
@@ -663,7 +666,9 @@ class SymPyValueRangeAnalysis:
b = ValueRanges.wrap(b)
c = ValueRanges.wrap(c)
a = a.boolify()
assert b.is_bool == c.is_bool
# We sometimes write unknown without specifying the type correctly
# In particular, we do that when initialising the bounds for loads in bounds.py
assert b.is_bool == c.is_bool or ValueRanges.unknown() in (b, c)
if b.is_bool:
return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
else:
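
A rough usage sketch of the relaxed constant() entry point (an assumption based on the hunk above, not code from it):

import torch
from torch.utils._sympy.value_ranges import SymPyValueRangeAnalysis, ValueRanges

# A singleton range is now accepted and unwrapped to its constant value.
r = SymPyValueRangeAnalysis.constant(ValueRanges.wrap(3), torch.int64)
print(r)  # expected: the singleton range [3, 3]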