Compare commits

...

11 Commits

Author SHA1 Message Date
5262f8243c Update on "Avoid DDE in narrow with unbacked start"
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice.
The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate, 
for that case we shall pass dim_size instead of start+length 

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-29 12:05:19 -07:00
2b190740d5 Update on "Avoid DDE in narrow with unbacked start"
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice.
The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate, 
for that case we shall pass dim_size instead of start+length 

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-29 09:04:27 -07:00
e0c56a147d Update on "Avoid DDE in narrow with unbacked start"
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice.
The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate, 
for that case we shall pass dim_size instead of start+length 

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-29 08:47:32 -07:00
ef83b7fd8b Update on "Avoid DDE in narrow with unbacked start"
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice.
The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate, 
for that case we shall pass dim_size instead of start+length 

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-28 20:12:30 -07:00
fc9afa627f Update on "Avoid DDE in narrow with unbacked start"
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice.
The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate, 
for that case we shall pass dim_size instead of start+length 

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-28 18:19:50 -07:00
7cf5d059fa Update on "Avoid DDE in narrow with unbacked start"
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice.
The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate, 
for that case we shall pass dim_size instead of start+length 

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-28 18:17:20 -07:00
de032cd203 Update on "Avoid DDE in narrow with unbacked start"
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice
but we have to write the check carefully.  

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-28 18:10:00 -07:00
88387fc1ff Update on "Avoid DDE in narrow with unbacked start"
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice
but we have to write the check carefully.  

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-28 18:05:50 -07:00
36a046fd47 Update on "Avoid DDE in narrow with unbacked start"
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice
but we have to write the check carefully.  

[ghstack-poisoned]
2025-10-28 17:48:10 -07:00
a9776fa0cc Update on "Avoid DDE in narrow with unbacked start"
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice
but we have to write the check carefully.  

[ghstack-poisoned]
2025-10-27 23:29:51 -07:00
b225bcaa97 Avoid DDE in narrow with unbacked start
[ghstack-poisoned]
2025-10-27 19:02:55 -07:00
8 changed files with 180 additions and 18 deletions

View File

@ -1,5 +1,6 @@
#include <ATen/core/ATen_fwd.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymInt.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
@ -1710,11 +1711,14 @@ Tensor narrow_symint(
"], but got ",
start,
")")
if (start < 0) {
start = start + cur_size;
}
// Bounds check without converting start:
// - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
// length <= 0
// - If start >= 0: need start + length <= cur_size
auto end = start + length;
TORCH_SYM_CHECK(
start.sym_le(cur_size - length),
(start.sym_lt(0).sym_and((end).sym_le(0)))
.sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
"start (",
start,
") + length (",
@ -1722,7 +1726,31 @@ Tensor narrow_symint(
") exceeds dimension size (",
cur_size,
").");
return at::slice_symint(self, dim, start, start + length, 1);
if (TORCH_GUARD_OR_FALSE(start.sym_ge(0).sym_or(end.sym_ne(0)))) {
return at::slice_symint(self, dim, start, end, 1);
} else if (TORCH_GUARD_OR_FALSE(start.sym_lt(0))) {
// Avoid the complex symbolic expressions path for non-unbacked.
return at::slice_symint(self, dim, start + cur_size, end + cur_size, 1);
} else {
// Cannot statically determine the condition due to unbacked.
// This is an interesting situation; when start is negative and
// start + length == 0, slice and narrow do different things.
// i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
// pass curr_size instead of 0. Otherwise, they would do the same thing.
// This says at runtime: if start < 0 and end == 0, then pass curr_size
// instead of 0.
auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
auto result =
at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
// Ensure slice allocated unbacked size is specialized to length.
SymInt new_size = result.sym_sizes()[dim];
TORCH_SYM_CHECK(new_size.sym_eq(length), "")
return result;
}
}
// This overload exists purely for XLA, because they wanted to pass in

View File

@ -1,4 +1,5 @@
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
namespace c10 {
@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
return toSymNodeImpl()->has_hint();
}
SymInt SymBool::toSymInt() const {
// If concrete bool, return concrete SymInt
if (auto ma = maybe_as_bool()) {
return SymInt(*ma ? 1 : 0);
}
// Symbolic case: use sym_ite to convert bool to int (0 or 1)
auto node = toSymNodeImpl();
auto one_node = node->wrap_int(1);
auto zero_node = node->wrap_int(0);
return SymInt(node->sym_ite(one_node, zero_node));
}
} // namespace c10

View File

@ -12,6 +12,8 @@
namespace c10 {
class SymInt;
class C10_API SymBool {
public:
/*implicit*/ SymBool(bool b) : data_(b) {}
@ -80,6 +82,10 @@ class C10_API SymBool {
return toSymNodeImplUnowned()->constant_bool();
}
// Convert SymBool to SymInt (0 or 1)
// This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
SymInt toSymInt() const;
bool is_heap_allocated() const {
return ptr_;
}

View File

@ -6102,26 +6102,19 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
retry_export(
cf_implicitsize(),
(torch.tensor(2), torch.randn(10)),
fixes=[
# Could not guard on data-dependent expression u0 < 0
"torch._check(i >= 0)",
],
fixes=[],
)
class cf_stacklist(torch.nn.Module):
def forward(self, xs, y, fixes):
i = y.item()
eval(fixes)
# instead of xs[i]
return torch.stack(xs, 0).narrow(0, i, 1).squeeze()
retry_export(
cf_stacklist(),
([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
fixes=[
# Could not guard on data-dependent expression u0 < 0
"torch._check(i >= 0)",
],
fixes=[],
)
class cf_tensorsplit(torch.nn.Module):
@ -6175,7 +6168,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
class cf_stacklist(torch.nn.Module):
def forward(self, xs, y):
# y.item() is not a local, so we can't suggest a fix
return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
if y.item() < 0:
return (
torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze()
)
else:
return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
with self.assertRaisesRegex(
error_type,
@ -6205,7 +6203,18 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
def forward(self, xs, y):
box = Box(y.item())
# box.content is not a local, so we can't suggest a fix
return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze()
if box.content < 0:
return (
torch.stack(xs, 0)
.narrow(0, box.content + xs.size(), 1)
.squeeze()
)
else:
return (
torch.stack(xs, 0)
.narrow(0, box.content + xs.size(), 1)
.squeeze()
)
with self.assertRaisesRegex(
error_type,

View File

@ -4329,6 +4329,57 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
self.assertEqual(compiled(a, b), func(a, b))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_narrow_unbacked_start(self):
def func(x, start, length):
# unbacked start
u0 = start.item()
return torch.narrow(x, 0, u0, length)
compiled_func = torch.compile(func, fullgraph=True, backend="inductor")
x = torch.tensor([1, 2, 3, 4, 5, 6])
# Test cases: (start, length)
test_cases = [
# Negative starts
(-2, 2), # Start from second-to-last element
(-1, 1), # Start from last element
(-3, 3), # Start from third-to-last element
(-6, 2), # Start from beginning (negative)
(-4, 1), # Start from fourth-to-last element
# Positive starts
(0, 2), # Start from beginning
(1, 3), # Start from second element
(2, 2), # Start from third element
(4, 2), # Start near end
# Edge cases
(0, 6), # Full tensor
(0, 1), # Single element from start
(5, 1), # Single element from end
]
for start_val, length in test_cases:
with self.subTest(start=start_val, length=length):
start = torch.tensor([start_val])
# Test with compiled function
result_compiled = compiled_func(x, start, length)
# Test with eager function (expected behavior)
result_eager = func(x, start, length)
# Compare results
self.assertEqual(result_compiled, result_eager)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_narrow_unbacked_start_cpp_wrapper(self):
"""Test narrow with unbacked start with cpp_wrapper"""
self.test_narrow_unbacked_start()
instantiate_parametrized_tests(TestUnbacked)

View File

@ -2041,7 +2041,8 @@ class PythonWrapperCodegen(CodeGen):
neg = self.codegen_sizevar(
sympy.Max(0, sympy.Min(x + node.size, node.size))
)
return f"{pos} if {x} >= 0 else {neg}"
x_cond = self.codegen_sizevar(x)
return f"{pos} if {x_cond} >= 0 else {neg}"
def codegen_with_step(start_var, end_var, step):
if step == 1:

View File

@ -547,6 +547,7 @@ def rebind_unbacked(
assert shape_env is not None
for raw_u0, path in bindings.items():
u1 = pytree.key_get(result, path)
# Sometimes, things were previously unbacked bindings become constants.
# There are two situations this can happen.
#
@ -602,7 +603,23 @@ def rebind_unbacked(
if u1.node.hint is not None:
continue
raw_u1 = u1.node.expr
# unbacked symbols bindings might be replaced to other backed or
# unbacked replacements.
#
# Example:
# u = x.item()
# torch._check(u == 5)
#
# The safest approach is to retrieve raw_u1 from u1.node._expr
# and perform the rebinding on the original unbacked symbol,
# even if its no longer directly referenced.
#
# In other words, we should always rebind the original symbol
# before any replacements are applied.
# u0 -> u0 == s1
raw_u1 = u1.node._expr
# TODO Do we still need this logic below?
# Simplify SymBool binding
if (
isinstance(raw_u1, sympy.Piecewise)

View File

@ -306,6 +306,24 @@ class PythonPrinter(ExprPrinter):
raise TypeError("ndigits must be an instance of sympy.Integer")
return f"round({self._print(number)}, {ndigits})"
def _print_Piecewise(self, expr: sympy.Expr) -> str:
# Convert Piecewise(expr_cond_pairs) to nested ternary expressions
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: e1 if c1 else (e2 if c2 else (... else eN))
result = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self._print(expr_i)
if cond_i == True: # noqa: E712
# This is the default case
result = expr_str
else:
cond_str = self._print(cond_i)
if result is None:
result = expr_str
else:
result = f"({expr_str} if {cond_str} else {result})"
return result if result else "0"
class CppPrinter(ExprPrinter):
def _print_Integer(self, expr: sympy.Expr) -> str:
@ -327,6 +345,24 @@ class CppPrinter(ExprPrinter):
)
return f"{c} ? {p} : {q}"
def _print_Piecewise(self, expr: sympy.Expr) -> str:
# Convert Piecewise(expr_cond_pairs) to nested ternary operators
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
result = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
if cond_i == True: # noqa: E712
# This is the default case
result = expr_str
else:
cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5)
if result is None:
result = expr_str
else:
result = f"{cond_str} ? {expr_str} : {result}"
return f"({result})" if result else "0"
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
x, div, mod = expr.args
x = self.doprint(x)