Compare commits

...

8 Commits

Author SHA1 Message Date
dc7e7a8ebd Update on "[hops] fix cond subgraph runtime asserts from polluting main graph"
The main issue addressed by this PR can be illustrated by the following example:

```
import torch

def compute(x, w):
    return torch.nn.functional.linear(x, w)

def nop(x, w):
    # This runtime check is only meant to hold in the false branch.
    torch._check(x.shape[0] == 0)
    return torch.empty_like(x)

def chunked_compute(x, w):
    return torch.cond(x.shape[0] > 0, compute, nop, (x, w))
```

Here, the runtime assertion `x.shape[0] == 0` should only apply in the false branch of the conditional. However, the current behavior incorrectly “leaks” this information into the global ShapeEnv during tracing. This happens because the runtime assert in the subgraph updates the global shape state unconditionally.

This PR fixes that by guarding every runtime assert recorded inside a subgraph with the predicate of the enclosing `torch.cond` (its negation for the false branch). In the example above, the traced graph now carries an assert equivalent to:

```
# "if the false branch runs (x.shape[0] <= 0), then x.shape[0] == 0 must hold"
torch._check(not (x.shape[0] <= 0) or x.shape[0] == 0)
```

The check only fires when the false branch would actually be taken, so the assertion stays valid in its local context without polluting the global shape information.
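For intuition only, here is a minimal sketch of the guarded form in plain Python (the `guarded_check` helper is made up for this illustration and is not part of the PR): the branch-local check becomes an implication on the branch predicate, so it is vacuously true whenever the other branch is the one that runs.

```
# Illustration only: "branch predicate implies check" encoded as
# "(not predicate) or check", which is how a guarded runtime assert behaves.
def guarded_check(branch_predicate: bool, check: bool) -> bool:
    return (not branch_predicate) or check

# False branch of torch.cond(x.shape[0] > 0, ...): its predicate is shape0 <= 0.
shape0 = 4
assert guarded_check(shape0 <= 0, shape0 == 0)  # branch not taken: vacuously true
shape0 = 0
assert guarded_check(shape0 <= 0, shape0 == 0)  # branch taken: the check really holds
```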

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 ezyang EikanWang wenzhe-nrv voznesenskym penguinwu Guobing-Chen zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 14:19:33 -07:00
a99e3dab46 Update on "[hops] fix cond subgraph runtime asserts from polluting main graph"
[ghstack-poisoned]
2025-10-27 13:02:49 -07:00
26f7ccf80e Update on "[hops] fix cond subgraph runtime asserts from polluting main graph"
[ghstack-poisoned]
2025-10-27 11:26:27 -07:00
a07d2993c9 Update on "[hops] fix cond subgraph runtime asserts from polluting main graph"
[ghstack-poisoned]
2025-10-26 23:55:50 -07:00
0c51b1a3c6 Update on "fix cond subgraph runtime asserts from polluting main graph"
[ghstack-poisoned]
2025-10-26 23:45:56 -07:00
75d1f9c045 Update on "fix cond subgraph runtime asserts from polluting main graph"
[ghstack-poisoned]
2025-10-26 23:42:24 -07:00
74ae9c8054 Update on "fix cond subgraph runtime asserts from polluting main graph"
[ghstack-poisoned]
2025-10-26 23:37:35 -07:00
94db5919f0 fix cond subgraph runtime asserts from polluting main graph
[ghstack-poisoned]
2025-10-26 23:25:20 -07:00
8 changed files with 199 additions and 31 deletions

View File

@@ -13454,6 +13454,47 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
y = torch.tensor(5)
f(x, y)
def test_cond_ra_pollution(self):
def compute(x, w):
return torch.nn.functional.linear(x, w)
def nop(x, w):
torch._check(x.shape[0] == 0)
return torch.empty_like(x)
def chunked_compute(x, w):
return torch.cond(x.shape[0] > 0, compute, nop, (x, w))
x, w = (
torch.randn(4, 16, requires_grad=True),
torch.randn(16, 16, requires_grad=True),
)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(16, 16)
def forward(self, x):
return chunked_compute(x, self.linear.weight)
torch._dynamo.decorators.mark_unbacked(x, 0)
orig_mod = Model()
mod = torch._dynamo.functional_export._dynamo_graph_capture_for_export(
orig_mod
)(x)
torch.export._trace._restore_state_dict(orig_mod, mod)
# Previously, this would cause an error because torch._check(x.shape[0] == 0)
# would propagate a runtime assertion from the subgraph (nop branch) into
# the main graph, leading to an incorrect assertion error. The sequence:
# 1) Trace through the nop subgraph.
# 2) Add a runtime assert that u0 == 0 which erroneously would update the
# global shape environment.
# 3) When generating runtime asserts for the main graph, the shape
# environment incorrectly asserts u0 == 0, causing a false assertion.
mod(x)
def test_full_graph_capture_scalar_outputs(self):
@torch.compile(fullgraph=True)
def foo(a):

View File

@@ -1353,23 +1353,30 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
# NB: 0 is predicate
ix = 1 if branch else 2
# TODO: Support kwargs
(
(ret_val, ret_spec),
ret_graph,
ret_lifted_freevars,
) = speculate_subgraph(
tx,
args[ix],
operands_seq,
{},
"cond",
source_target=self.value,
should_flatten_outputs=True,
# TODO - removing consts from control flow ops need more work
remove_consts_from_outputs=False,
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
)
ra_context = contextlib.nullcontext()
if hasattr(args[0], "sym_num"):
pred = args[0].sym_num.node.expr
prelude = pred if branch else ~pred
ra_context = tx.output.shape_env.patch_ra_prelude(prelude)
with ra_context:
(
(ret_val, ret_spec),
ret_graph,
ret_lifted_freevars,
) = speculate_subgraph(
tx,
args[ix],
operands_seq,
{},
"cond",
source_target=self.value,
should_flatten_outputs=True,
# TODO - removing consts from control flow ops need more work
remove_consts_from_outputs=False,
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
)
if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)):
unimplemented(

View File

@@ -334,14 +334,37 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
from torch._higher_order_ops.utils import _has_gen_schema
if _has_gen_schema(self):
schema = self.gen_schema(*args, **kwargs)
if any(arg.is_write for arg in schema.arguments):
raise RuntimeError(
f"The {self.name()} HigherOrderOperator does not currently support training "
"with in-place input or buffer mutations "
"If you require this feature, please submit an issue to PyTorch. "
"Alternatively, consider creating your own custom autograd.Function. "
)
try:
schema = self.gen_schema(*args, **kwargs)
if any(arg.is_write for arg in schema.arguments):
raise RuntimeError(
f"The {self.name()} HigherOrderOperator does not currently support training "
"with in-place input or buffer mutations "
"If you require this feature, please submit an issue to PyTorch. "
"Alternatively, consider creating your own custom autograd.Function. "
)
except RuntimeError as e:
if "Expected cond to be True, but got False" in str(e):
# Although we attempt to detect in-place input or buffer mutations,
# the current approach in CondOp::gen_schema is not fully reliable.
# Specifically, we invoke materialize_as_graph on both the true and false
# subgraphs with the provided inputs at runtime (not compile time).
# This can lead to unintended side effects: for example, consider the following code:
#
# def nop(x, w):
# torch._check(x.shape[0] == 0)
#
# torch.cond(x.shape[0] > 0, compute, nop, (x, w))
#
# If, at runtime, x.shape[0] > 0, the assertion in nop will be triggered,
# even though that branch is not actually taken. As a result, strictly enforcing
# a hard failure based on this check would incorrectly penalize valid programs
# due to the unsoundness of our detection mechanism. Therefore, rather than
# failing outright, we conservatively proceed under the assumption that there
# are no in-place input or buffer mutations.
pass
else:
raise
return fn(*args, **kwargs)

View File

@@ -84,6 +84,7 @@ from torch.utils._sympy.functions import (
IsNonOverlappingAndDenseIndicator,
Max,
Mod,
OrderedAnd,
PythonMod,
TruncToInt,
)
@@ -3765,6 +3766,8 @@ class ShapeEnv:
self.guards: list[ShapeGuard] = []
self.axioms: dict[sympy.Expr, sympy.Expr] = {}
self.ra_prelude: Optional[sympy.Expr] = None
# A set of ids that have already been allocated. This is used
# for when we allocate symbol ids using the hash of the source
# names to ensure we don't have collisions via linear probing
@@ -6282,6 +6285,15 @@ class ShapeEnv:
self, e: SympyBoolean
) -> tuple[tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]:
"""Given a expression, it returns a list of predicates that follow from it"""
if isinstance(e, OrderedAnd):
# Because SymPy's default And does not preserve operand order,
# we introduced OrderedAnd to maintain order. As a result, we
# cannot make additional global logical assumptions about the
# conjunction as a whole, since the semantics of OrderedAnd are
# intentionally more restrictive.
return tuple()
equiv: dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {}
def add_expr(expr: SympyBoolean) -> None:
@@ -7787,6 +7799,9 @@ class ShapeEnv:
"""
expr = orig_expr
if self.ra_prelude is not None:
expr = OrderedAnd(self.ra_prelude, expr)
# TODO: split conjunctions and evaluate them separately
static_expr = self._maybe_evaluate_static(expr)
@@ -7939,6 +7954,25 @@ class ShapeEnv:
"constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper
)
@contextmanager
def patch_ra_prelude(self, prelude: sympy.Expr) -> Iterator[None]:
"""
Context manager that ensures all runtime asserts generated while this context manager
is active include a prelude expression. This is mainly used in torch.cond to guarantee
that runtime asserts in subgraphs are guarded by the original cond predicate, preventing
them from leaking into the main graph.
"""
prev = self.ra_prelude
if prev is not None:
prelude = OrderedAnd(prev, prelude)
self.ra_prelude = prelude
try:
yield
finally:
self.ra_prelude = prev
def _is_int(expr: object) -> bool:
return isinstance(expr, SymInt) and expr.node.expr.is_number
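To make the `patch_ra_prelude` context manager added above more concrete, here is a toy sketch of the bookkeeping it performs (assumptions: `ToyShapeEnv` and `defer_runtime_assert` are invented stand-ins for illustration, not the real `ShapeEnv` API): while the context is active, every recorded runtime assert is combined with the prelude so that it only has to hold when the prelude holds.

```
import contextlib

class ToyShapeEnv:
    """Toy stand-in for illustration only (not the real ShapeEnv)."""

    def __init__(self):
        self.ra_prelude = None      # predicate guarding asserts recorded right now
        self.runtime_asserts = []   # what would become deferred runtime asserts

    @contextlib.contextmanager
    def patch_ra_prelude(self, prelude):
        # Nest preludes like the real context manager: combine with the previous
        # one and restore it on exit.
        prev = self.ra_prelude
        self.ra_prelude = prelude if prev is None else f"({prev}) and ({prelude})"
        try:
            yield
        finally:
            self.ra_prelude = prev

    def defer_runtime_assert(self, expr):
        if self.ra_prelude is not None:
            # Guarded form: the assert only has to hold when the prelude holds.
            expr = f"(not ({self.ra_prelude})) or ({expr})"
        self.runtime_asserts.append(expr)

env = ToyShapeEnv()
with env.patch_ra_prelude("u0 <= 0"):    # false branch of torch.cond(u0 > 0, ...)
    env.defer_runtime_assert("u0 == 0")  # torch._check inside that branch
print(env.runtime_asserts)               # ['(not (u0 <= 0)) or (u0 == 0)']
```

In the actual patch this combination is kept symbolic as `OrderedAnd(prelude, expr)` and only turned into the `not`/`any` form when deferred runtime asserts are materialized as FX nodes.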

View File

@@ -178,7 +178,24 @@ def insert_deferred_runtime_asserts(
assert isinstance(node.target, str)
target = getattr(fake_args[0], node.target)
fake_args = fake_args[1:]
node.meta[val_key] = target(*fake_args) # type: ignore[operator]
# The OrderedAnd function in torch.utils._sympy.functions combines
# `not` and `any` operations to generate runtime assertions correctly
# in the code. For these specific operations, we avoid evaluating the
# function directly, as doing so could unnecessarily trigger
# data-dependent errors. For example, if there's a runtime
# assertion `u0 <= 0`, evaluating the meta of not(u0 <= 0) would
# cause us to guard on the inner expression and potentially raise a
# data-dependent error. Therefore, we choose not to compute the meta
# in these cases, since it's not essential.
calculate_meta = (
node.target != operator.not_
and node.target != any
and not any(hasattr(a, "target") and a.target == any for a in node.args) # type: ignore[union-attr]
)
if calculate_meta:
node.meta[val_key] = target(*fake_args) # type: ignore[operator]
except NotImplementedError:
# This can happen when attempting to reify a symbol with an unsupported call_function node,
# e.g. with NestedTensors + sym_size.int via match_symbol().
@@ -195,11 +212,12 @@ def insert_deferred_runtime_asserts(
Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis
def _sympy_interp(expr_to_proxy, expr):
def _sympy_interp(expr_to_proxy, expr, graph):
# sympy_interp() with hash consing
from sympy import Integer, Number, Symbol
from sympy.logic.boolalg import BooleanAtom
from torch.utils._sympy.functions import OrderedAnd
from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
# hash cons
@@ -209,10 +227,31 @@ def insert_deferred_runtime_asserts(
if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)):
return sympy_interp(Analysis, expr_to_proxy, expr)
if isinstance(expr, OrderedAnd):
predicate = (_sympy_interp(expr_to_proxy, expr.args[0], graph)).node
runtime_assert = _sympy_interp(expr_to_proxy, expr.args[1], graph).node
not_predicate = fx.Proxy(
graph.call_function(operator.not_, (predicate,)), tracer=tracer
).node
return fx.Proxy(
graph.call_function(
any,
(
[
not_predicate,
runtime_assert,
],
),
),
tracer=tracer,
)
# hash cons on arguments, run expr handler
expr_to_proxy[expr] = _run_sympy_handler(
Analysis,
[_sympy_interp(expr_to_proxy, arg) for arg in expr.args],
[_sympy_interp(expr_to_proxy, arg, graph) for arg in expr.args],
expr,
)
return expr_to_proxy[expr]
@@ -258,7 +297,7 @@ def insert_deferred_runtime_asserts(
# Convert the sympy expression into a sequence of FX
# nodes
with _set_node_metadata_hook(gm, _node_metadata_hook):
res = _sympy_interp(expr_to_proxy, ra.expr).node
res = _sympy_interp(expr_to_proxy, ra.expr, graph).node
graph.call_function(
torch.ops.aten._assert_scalar.default,
@@ -410,6 +449,7 @@ def insert_deferred_runtime_asserts(
expr_to_proxy,
# pyrefly: ignore # unbound-name
sym_expr,
graph,
) # type: ignore[arg-type]
# won't try DCE-ing tensor compute here
hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type]
@@ -627,7 +667,9 @@ def insert_deferred_runtime_asserts(
),
):
if (min_val := convert(vr.lower)) is not None:
ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
ge = _sympy_interp(
expr_to_proxy, i0 >= min_val, graph
).node
graph.call_function(
torch.ops.aten._assert_scalar.default,
(
@@ -637,7 +679,9 @@ def insert_deferred_runtime_asserts(
)
added_asserts.add(i0 >= min_val)
if (max_val := convert(vr.upper)) is not None:
le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
le = _sympy_interp(
expr_to_proxy, i0 <= max_val, graph
).node
graph.call_function(
torch.ops.aten._assert_scalar.default,
(

View File

@@ -1461,3 +1461,16 @@ def make_opaque_bitwise_fn(name, real_op_name):
BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_")
BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_")
from sympy.logic.boolalg import BooleanFunction
class OrderedAnd(BooleanFunction):
@classmethod
def eval(cls, *args):
# Returning None tells SymPy not to simplify further
return None
def _sympystr(self, printer):
return " and ".join(printer._print(a) for a in self.args)
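A brief, hedged demonstration of why a dedicated `OrderedAnd` is used rather than `sympy.And` (the class is re-declared below only to keep the snippet self-contained; the real definition is the one shown above in `torch.utils._sympy.functions`): because `eval` returns `None`, SymPy leaves the conjunction unevaluated and keeps its operands in the given order, so nothing further is deduced from it.

```
import sympy
from sympy.logic.boolalg import BooleanFunction

class OrderedAnd(BooleanFunction):
    # Toy re-declaration for illustration; eval() returning None tells SymPy
    # not to simplify or reorder the expression.
    @classmethod
    def eval(cls, *args):
        return None

u0 = sympy.Symbol("u0", integer=True)

# sympy.And is free to flatten, reorder, and simplify its operands.
print(sympy.And(u0 <= 0, sympy.Eq(u0, 0)))

# OrderedAnd stays unevaluated, with the prelude first and the assert second.
print(OrderedAnd(u0 <= 0, sympy.Eq(u0, 0)))
```

That opacity is also why the ShapeEnv axiom-derivation change above returns an empty tuple for `OrderedAnd`, and why its value-range handler simply reports `ValueRanges.unknown()`.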

View File

@@ -34,6 +34,7 @@ from .functions import (
Mod,
ModularIndexing,
OpaqueUnaryFn_log2,
OrderedAnd,
PowByNatural,
PythonMod,
RoundDecimal,
@@ -108,6 +109,7 @@ def handlers():
OpaqueUnaryFn_log2: "log2",
BitwiseFn_bitwise_and: "bitwise_and",
BitwiseFn_bitwise_or: "bitwise_or",
OrderedAnd: "ordered_and",
}
# TODO: This is kind of pointless, we shouldn't be generating sympy.sin
# for these functions, they should be Opaque instead

View File

@@ -513,6 +513,10 @@ class SymPyValueRangeAnalysis:
def and_(a, b):
return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And)
@staticmethod
def ordered_and(a, b):
return ValueRanges.unknown()
@staticmethod
def _bool_to_int(x):
if x.is_singleton():