Compare commits

...

7 Commits

Author SHA1 Message Date
5fda70b854 Split up maybe_evaluate_static for more clarity
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
2024-09-07 21:01:08 -07:00
572a37a7ed Remove find_localzeros
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 439c4ef36e315428ad7f033999241fddaaf172cd
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133325
2024-09-07 20:34:38 -07:00
ffc33d516a Don't uselessly recompute axiom dict every static eval call
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 2dc16a94638726ce129532de212e5c8d282232fd
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135429
2024-09-07 20:34:22 -07:00
750d057509 Add scribe logging for maybe_evaluate_static_worker
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 212ab3893a4f3127adf9647e2d3f8db21051ea3d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135138
2024-09-07 20:34:17 -07:00
41e8059f8c Inherit all secrets to inductor workflow
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 99ca283922614b6ec6672ea1bab9d4bb95d3fb9f
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135354
2024-09-07 20:34:17 -07:00
cbb2b07240 Deal with size oblivious before going into worker
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: b09f00fe8d79f7a47fc1e95a71d8f0ebf6cfec4d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135137
2024-09-07 20:34:16 -07:00
b5200faeec Refactor maybe_evaluate_static into a worker function off of ShapeEnv
By refactoring this way, I can put a non-expiring LRU cache here.
Splitting also will make it easier for me to tell who is using up all
the time.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: f515ce208eccfe299626395cee23df6017e73c4f
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135107
2024-09-07 20:34:16 -07:00
3 changed files with 205 additions and 116 deletions

View File

@ -58,8 +58,7 @@ jobs:
{ config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
]}
secrets:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
secrets: inherit
linux-focal-cuda12_1-py3_10-gcc9-inductor-test:
name: cuda12.1-py3.10-gcc9-sm86
@ -69,8 +68,7 @@ jobs:
build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86
docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-inductor-build.outputs.test-matrix }}
secrets:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
secrets: inherit
linux-focal-cuda12_1-py3_12-gcc9-inductor-build:
name: cuda12.1-py3.12-gcc9-sm86
@ -86,6 +84,7 @@ jobs:
{ config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
]}
secrets: inherit
linux-focal-cuda12_1-py3_12-gcc9-inductor-test:
name: cuda12.1-py3.12-gcc9-sm86
@ -95,6 +94,7 @@ jobs:
build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86
docker-image: ${{ needs.linux-focal-cuda12_1-py3_12-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_1-py3_12-gcc9-inductor-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cpu-py3_12-inductor-halide-build:
name: linux-jammy-cpu-py3.12-gcc11-inductor-halide
@ -108,6 +108,7 @@ jobs:
{ include: [
{ config: "inductor-halide", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" },
]}
secrets: inherit
linux-jammy-cpu-py3_12-inductor-halide-test:
name: linux-jammy-cpu-py3.12-gcc11-inductor-halide
@ -117,6 +118,7 @@ jobs:
build-environment: linux-jammy-py3.12-gcc11
docker-image: ${{ needs.linux-jammy-cpu-py3_12-inductor-halide-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cpu-py3_12-inductor-halide-build.outputs.test-matrix }}
secrets: inherit
linux-focal-cuda12_4-py3_10-gcc9-inductor-build:
# Should be synced with the one in inductor-periodic.yml but this only runs inductor_timm
@ -134,8 +136,7 @@ jobs:
{ config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
]}
secrets:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
secrets: inherit
linux-focal-cuda12_4-py3_10-gcc9-inductor-test:
name: cuda12.4-py3.10-gcc9-sm86
@ -146,8 +147,7 @@ jobs:
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86
docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }}
secrets:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
secrets: inherit
linux-jammy-cpu-py3_9-gcc11-inductor-build:
name: linux-jammy-cpu-py3.9-gcc11-inductor
@ -201,8 +201,7 @@ jobs:
{ config: "cpu_inductor_freezing_avx2_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "cpu_inductor_freezing_avx2_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
]}
secrets:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
secrets: inherit
linux-jammy-cpu-py3_9-gcc11-inductor-test:
name: linux-jammy-cpu-py3.9-gcc11-inductor
@ -212,5 +211,4 @@ jobs:
build-environment: linux-jammy-py3.9-gcc11-build
docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }}
secrets:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
secrets: inherit

View File

@ -19,6 +19,7 @@ import operator
import os
import re
import sys
import json
import threading
import traceback
from collections import defaultdict
@ -57,6 +58,7 @@ from torch.fx.experimental.recording import (
)
from torch.fx.experimental.sym_node import SymNode, SymTypes
from torch._logging import trace_structured, structured
import torch._logging.scribe as scribe
# NB: The sym_* functions are used via getattr() and must be imported here.
from torch import SymBool, SymFloat, SymInt
@ -1404,6 +1406,141 @@ def safe_expand(r):
else:
return r
@lru_cache(None)
def _maybe_evaluate_static_worker(
expr: sympy.Expr,
symbol_info: Tuple[Tuple[sympy.Symbol, ValueRanges, sympy.Integer], ...],
unbacked_only: bool,
):
"""
This variant of ShapeEnv._maybe_evaluate_static has no dependence on
ShapeEnv and thus can be cached indefinitely. It does the "heavy" lifting
for static evaluation, including nontrivial reliance on Sympy simplification
that occurs when we reallocate the symbols
"""
def log_and_return(res):
def jsonify(x):
if x.is_integer:
if x in [int_oo, -int_oo]:
return str(x)
return int(x)
elif x.is_real:
return float(x)
else:
return bool(x)
try:
env = {}
for s, vr, hint in symbol_info:
e = {}
if not s.is_integer and s.is_real:
e['real'] = True
else:
assert s.is_integer
if vr is not None:
e['vr'] = [jsonify(vr.lower), jsonify(vr.upper)]
if hint is not None:
e['hint'] = jsonify(hint)
env[str(s)] = e
entry = {
'expr': str(expr),
'env': env,
'res': jsonify(res),
'pytest_current_test': os.getenv('PYTEST_CURRENT_TEST'),
'argv': sys.argv,
}
if unbacked_only:
entry['unbacked_only'] = True
from torch._dynamo.utils import get_chromium_event_logger
chromium_log = get_chromium_event_logger()
entry['simple_stack'] = chromium_log.get_stack()
except Exception:
log.exception("log_and_return failed")
else:
signpost_event(
"dynamic",
"maybe_evaluate_static_worker",
entry,
)
scribe.open_source_signpost(
subsystem="dynamic",
name="maybe_evaluate_static_worker",
parameters=lambda: json.dumps(entry),
)
return res
# Simplify making use of value range lower bound
new_shape_env = {}
new_range_env = {}
for idx, sinfo in enumerate(symbol_info):
k, vr, hint = sinfo
if isinstance(hint, SingletonInt):
# Skip var_ranges logic for SingletonInt which is only used
# for jagged layout NestedTensors today
continue
lower = vr.lower
# Don't do anything if we don't have a nontrivial lower bound
# Also don't do anything if we asked only to simplify unbacked
# SymInt
if (
lower is -int_oo or
(unbacked_only and hint is not None) or
not vr.is_int
):
new_range_env[k] = vr
continue
# The goal is to take our symbols which have various lower bounds
# and reallocate them into new symbols which are exactly positive;
# e.g., if we have s0 in [2, inf], we want to turn it into ess0 in
# [1, inf], where s0 = ess0 + 1. This gives the most information
# to sympy for subsequent simplifications.
#
# Positive means >= 1
# Positive - 1 means >= 0
# Positive + lower - 1 means >= lower
# The new symbol 's' is "too low", so when we substitute it in
# we have to increase it by offset (and conversely, the new
# variables have to have their value range bounds adjusted as
# well)
s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True)
# Note:
# Offset might be a fraction (e.g. aten.split.Tensor), but shapes are always integers.
# Sympy might give unexpected results when comparing an integer with a non-integer
# Therefore, we cast offset to int here.
# For example:
# shape_0 = sympy.Symbol("shape_0", positive=True, integer=True)
# expr = sympy.Eq(shape_0 - 1/3, 4)
# expr.xreplace({}) # False
offset = int(lower - 1)
new_shape_env[k] = s + offset
new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset)
try:
new_expr = expr.xreplace(new_shape_env)
except RecursionError:
log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
return None
# We need to canonicalize, as after expand we may have something like `a + b = a` and
# sympy will not simplify the a. The two appearances of the a will then make value ranges
# analysis give loose bounds
new_expr = canonicalize_bool_expr(safe_expand(new_expr))
if new_expr.is_number:
return log_and_return(new_expr)
# Check if the range can solve it statically
out = bound_sympy(new_expr, new_range_env)
if out.is_singleton():
return log_and_return(out.lower)
return new_expr if unbacked_only else None
def error():
raise AssertionError("shouldn't be hit")
@ -2431,6 +2568,7 @@ class ShapeEnv:
)
self.guards: List[ShapeGuard] = []
self.axioms: Dict[sympy.Expr, sympy.Expr] = {}
# Maps symbolic ints to their original concrete values
# Currently populated from tensors
self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
@ -4443,11 +4581,6 @@ class ShapeEnv:
# axioms with compute hint NYE
assert not compute_hint or not axioms
if var_to_range is None:
var_ranges = self.var_to_range
else:
var_ranges = dict(var_to_range)
expr = self.simplify(expr)
if compute_hint:
@ -4456,114 +4589,72 @@ class ShapeEnv:
expr = canonicalize_bool_expr(expr)
# Pattern matching
symbols = tuple(expr.free_symbols)
if axioms is None:
axioms = self.get_axioms(symbols, compute_hint=compute_hint)
subst = {}
for e in axioms:
if e.free_symbols.issubset(expr.free_symbols):
subst.update(dict(self.get_implications(e)))
subst = self.axioms
else:
def compute_explicit_subst():
subst = {}
for e in axioms:
if e.free_symbols.issubset(expr.free_symbols):
subst.update(dict(self.get_implications(e)))
return subst
subst = compute_explicit_subst()
expr = expr.xreplace(subst)
# TODO: compute hint might have gotten broken here
symbols = tuple(expr.free_symbols)
fs = expr.free_symbols
# Simplify making use of value range lower bound
new_shape_env = {}
new_range_env = {}
for idx, k in enumerate(symbols):
if isinstance(self.var_to_val.get(k, None), SingletonInt):
# Skip var_ranges logic for SingletonInt which is only used
# for jagged layout NestedTensors today
continue
vr = var_ranges[k]
if size_oblivious and k in self.size_like:
lower = max(2, vr.lower)
# Clamping size-oblivious to some quantity below sys.maxsize
# helps us determine that f(u0) != sys.maxsize, which is a
# test that is looking for sys.maxsize as a sentinel, but you
# don't really want to worry about it for unbacked SymInts.
# This is similar to the flavor where size oblivious omits
# 0/1, it changes semantics but in a benign way.
upper = min(2 ** 48, vr.upper)
# This is a bit dodgy: what this means is that there was a
# size-like unbacked symbol whose upper bound < 2. This
# causes... problems.
if lower <= upper:
vr = ValueRanges(lower, upper)
if not fs and (expr.is_number or expr.is_Boolean):
return expr
def adjust_vr(k, vr):
# Check if the range can solve it statically quickly
if not (size_oblivious and k in self.size_like):
return vr
lower = max(2, vr.lower)
# Clamping size-oblivious to some quantity below sys.maxsize
# helps us determine that f(u0) != sys.maxsize, which is a
# test that is looking for sys.maxsize as a sentinel, but you
# don't really want to worry about it for unbacked SymInts.
# This is similar to the flavor where size oblivious omits
# 0/1, it changes semantics but in a benign way.
upper = min(2 ** 48, vr.upper)
# This is a bit dodgy: what this means is that there was a
# size-like unbacked symbol whose upper bound < 2. This
# causes... problems. When this happens, just ignore the
# preexisting upper bound
if lower > upper:
upper = max(lower, 2 ** 48)
return ValueRanges(lower, upper)
def compute_var_ranges():
if var_to_range is None:
if size_oblivious: # micro-optimization
#var_ranges = {k: adjust_vr(k, v) for k, v in self.var_to_range.items()}
return {k: adjust_vr(k, self.var_to_range[k]) for k in fs if k in self.var_to_range}
else:
return self.var_to_range
else:
lower = vr.lower
# Don't do anything if we don't have a nontrivial lower bound
# Also don't do anything if we asked only to simplify unbacked
# SymInt
if (
lower is -int_oo or
(unbacked_only and k in self.var_to_val) or
not vr.is_int
):
new_range_env[k] = vr
continue
# The goal is to take our symbols which have various lower bounds
# and reallocate them into new symbols which are exactly positive;
# e.g., if we have s0 in [2, inf], we want to turn it into ess0 in
# [1, inf], where s0 = ess0 + 1. This gives the most information
# to sympy for subsequent simplifications.
#
# Positive means >= 1
# Positive - 1 means >= 0
# Positive + lower - 1 means >= lower
# The new symbol 's' is "too low", so when we substitute it in
# we have to increase it by offset (and conversely, the new
# variables have to have their value range bounds adjusted as
# well)
s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True)
return {k: adjust_vr(k, v) for k, v in var_to_range}
# Note:
# Offset might be a fraction (e.g. aten.split.Tensor), but shapes are always integers.
# Sympy might give unexpected results when comparing an integer with a non-integer
# Therefore, we cast offset to int here.
# For example:
# shape_0 = sympy.Symbol("shape_0", positive=True, integer=True)
# expr = sympy.Eq(shape_0 - 1/3, 4)
# expr.xreplace({}) # False
offset = int(lower - 1)
new_shape_env[k] = s + offset
new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset)
var_ranges = compute_var_ranges()
try:
new_expr = expr.xreplace(new_shape_env)
except RecursionError:
log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
self.counter["sympy_recursion_error"] += 1
return None
# We need to canonicalize, as after expand we may have something like `a + b = a` and
# sympy will not simplify the a. The two appearances of the a will then make value ranges
# analysis give loose bounds
new_expr = canonicalize_bool_expr(safe_expand(new_expr))
if new_expr.is_number:
return new_expr
# This is bad to do, the replacement with division leaves us with
# rationals when atom.args[0] is addition, e.g., sympy will happily
# turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication!
"""
floor_div_replace = {}
for atom in new_expr.atoms(FloorDiv):
floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
new_expr = safe_expand(new_expr.xreplace(floor_div_replace))
# TODO: when unbacked_only, can sometimes early return even when there
# are still free symbols
if new_expr.is_number:
return new_expr
"""
# Check if the range can solve it statically
out = bound_sympy(new_expr, new_range_env)
out = bound_sympy(expr, var_ranges)
if out.is_singleton():
return out.lower
return new_expr if unbacked_only else None
def compute_symbol_info():
return tuple(
(s, var_ranges.get(s), self.var_to_val.get(s))
for s in sorted(fs, key=lambda s: str(s)) # TODO: speed up sort?
)
symbol_info = compute_symbol_info()
return _maybe_evaluate_static_worker(expr, symbol_info, unbacked_only)
@_lru_cache
def replace(self, expr: "sympy.Expr") -> "sympy.Expr":
@ -5288,6 +5379,7 @@ class ShapeEnv:
stack = CapturedTraceback.extract(skip=1)
guard = ShapeGuard(g, stack)
self.guards.append(guard)
self.axioms.update(dict(self.get_implications(g)))
else:
# it's fine to defer simple guards here without checking,
# the _maybe_guard_rel() call above will set replacements if possible,
@ -5399,6 +5491,7 @@ class ShapeEnv:
# and the guard in question has no unbacked SymInts in front
ix = cands[-1] if cands else None
self.deferred_runtime_asserts.setdefault(ix, []).append(ra)
self.axioms.update(dict(self.get_implications(expr)))
self.num_deferred_runtime_asserts += 1
self._update_version_counter()
self._log_guard("runtime_assert", orig_expr, forcing_spec=False)

View File

@ -546,8 +546,6 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
return cls.zero # type: ignore[attr-defined]
# remove redundant args that are easily identified
args = cls._collapse_arguments(args, **assumptions)
# find local zeros
args = cls._find_localzeros(args, **assumptions)
args = frozenset(args)
if not args: