Support C++ statically_known_true (#151346)

Differential Revision: [D73040543](https://our.internmc.facebook.com/intern/diff/D73040543/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151346
Approved by: https://github.com/laithsakka
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-04-17 15:13:47 -07:00
committed by PyTorch MergeBot
parent 8895c290f4
commit eb1f85a2a0
8 changed files with 71 additions and 3 deletions

View File

@ -222,8 +222,8 @@ inline Tensor applySlice(
? (*self_sizes)[dim]
: self.sym_size(dim);
if (!disable_slice_optimization &&
TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) &&
TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) {
TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
TORCH_STATICALLY_KNOWN_TRUE(length.sym_eq(stop)) && step == 1) {
return self;
}
}

View File

@ -80,6 +80,14 @@ bool SymBool::guard_or_false(const char* file, int64_t line) const {
return a->guard_or_false(file, line);
}
bool SymBool::statically_known_true(const char* file, int64_t line) const {
if (auto ma = maybe_as_bool()) {
return *ma;
}
SymNode a = toSymNodeImpl();
return a->statically_known_true(file, line);
}
bool SymBool::guard_or_true(const char* file, int64_t line) const {
if (auto ma = maybe_as_bool()) {
return *ma;

View File

@ -1,3 +1,4 @@
#pragma once
#include <c10/core/SymNodeImpl.h>
@ -62,6 +63,7 @@ class C10_API SymBool {
bool guard_bool(const char* file, int64_t line) const;
bool expect_true(const char* file, int64_t line) const;
bool guard_size_oblivious(const char* file, int64_t line) const;
bool statically_known_true(const char* file, int64_t line) const;
bool guard_or_false(const char* file, int64_t line) const;
bool guard_or_true(const char* file, int64_t line) const;
@ -129,6 +131,20 @@ inline bool guard_or_false(
return b.guard_or_false(file, line);
}
// A plain bool is its own static truth value: no symbolic reasoning is
// required, so the call-site location parameters are ignored and the
// value is returned unchanged.
inline bool statically_known_true(
    bool b,
    [[maybe_unused]] const char* file,
    [[maybe_unused]] int64_t line) {
  return b;
}
// Dispatches the statically-known-true query to the SymBool itself,
// threading the caller's file/line through for diagnostics.
inline bool statically_known_true(
    const c10::SymBool& b,
    const char* file,
    int64_t line) {
  const bool known = b.statically_known_true(file, line);
  return known;
}
inline bool guard_or_true(
bool b,
const char* file [[maybe_unused]],
@ -146,6 +162,9 @@ inline bool guard_or_true(
#define TORCH_GUARD_SIZE_OBLIVIOUS(cond) \
c10::guard_size_oblivious((cond), __FILE__, __LINE__)
// Convenience macro: forwards the current source location
// (__FILE__/__LINE__) to c10::statically_known_true for diagnostics.
#define TORCH_STATICALLY_KNOWN_TRUE(cond) \
c10::statically_known_true((cond), __FILE__, __LINE__)
#define TORCH_GUARD_OR_FALSE(cond) \
c10::guard_or_false((cond), __FILE__, __LINE__)

View File

@ -1,3 +1,4 @@
#pragma once
#include <c10/macros/Export.h>
@ -191,6 +192,11 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
// with a better implementation!
return guard_bool(file, line);
}
virtual bool statically_known_true(const char* file, int64_t line) {
  // No improvement for unbacked SymBools by default, replace this
  // with a better implementation!
  // NOTE(review): this default delegates to guard_bool, which can install
  // a guard, whereas statically_known_true is conventionally a guard-free
  // query -- confirm overrides are in place where that matters (the
  // PythonSymNodeImpl binding does override it).
  return guard_bool(file, line);
}
virtual bool guard_or_true(const char* file, int64_t line) {
// No improvement for unbacked SymBools by default, replace this
// with a better implementation!

View File

@ -696,6 +696,30 @@ graph():
ep = export(f, args, strict=False)
self.assertEqual(ep.module()(*args), f(*args))
@testing.expectedFailureCppSerDes  # Cpp SerDes seems to fail parsing complicated guards
def test_export_statically_known_true(self):
    """Export a module that slices with a symbolic end (`s**2 - 3*s`).

    The exported graph must retain both aten.slice calls and the symbolic
    subtraction: the slice-elision fast path in applySlice now uses
    statically_known_true instead of a size-oblivious guard, so it must
    not specialize on (or guard) the data-dependent `end`.
    """
    class Foo(torch.nn.Module):
        def forward(self, x, y):
            # Symbolic expression in y's first dim; used as a slice bound.
            shape = y.shape[0] ** 2 - 3 * y.shape[0]
            end = shape
            return x[:, :end]
    # All four dims dynamic so the slice bound stays symbolic at export.
    dynamic_shapes = (
        (torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC),
        (torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC),
    )
    ep = export(
        Foo(),
        (torch.randn(4, 4), torch.randn(4, 4)),
        dynamic_shapes=dynamic_shapes,
        strict=False,
    )
    # x[:, :end] lowers to one slice per dim -- both must survive.
    FileCheck().check_count("torch.ops.aten.slice.Tensor", 2, exactly=True).run(
        str(ep.graph)
    )
    # The symbolic `s**2 - 3*s` computation must remain in the graph.
    FileCheck().check_count("operator.sub", 1, exactly=True).run(str(ep.graph))
def test_colon_parameter(self):
class M(torch.nn.Module):
def __init__(self) -> None:

View File

@ -1218,7 +1218,7 @@ def forward(self, x_1):
gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
# Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
# 1 ok)
self.assertEqual(len(gm.shape_env.guards), 1)
self.assertEqual(len(gm.shape_env.guards), 0)
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_cpu_scalar_cuda(self):

View File

@ -140,6 +140,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
return getPyObj().attr("guard_or_false")(file, line).cast<bool>();
}
// Bridges the C++ statically_known_true query to the Python SymNode's
// method of the same name; the GIL must be held while touching Python
// objects, hence the scoped acquire.
bool statically_known_true(const char* file, int64_t line) override {
  py::gil_scoped_acquire gil;
  auto method = getPyObj().attr("statically_known_true");
  return method(file, line).cast<bool>();
}
bool guard_or_true(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_or_true")(file, line).cast<bool>();

View File

@ -572,6 +572,12 @@ class SymNode:
_advise_is_size(SymInt(self))
return r
def statically_known_true(self, file, line):
    """Return True only if this boolean SymNode can be proven true from
    static information alone.

    ``file`` and ``line`` mirror the C++ SymNodeImpl calling convention
    but are unused here; the work is delegated to
    torch.fx.experimental.symbolic_shapes.statically_known_true.
    """
    # Local import -- presumably to avoid a circular dependency at module
    # load time; matches the import style of neighboring methods.
    from torch.fx.experimental.symbolic_shapes import statically_known_true
    assert self.is_bool()
    return statically_known_true(SymBool(self))
def guard_size_oblivious(self, file, line):
"""
Like guard_bool, but if we encounter unbacked symbols, if those symbols