From eb1f85a2a0053cf848a55498525f1b06293a99e7 Mon Sep 17 00:00:00 2001
From: Tugsbayasgalan Manlaibaatar
Date: Thu, 17 Apr 2025 15:13:47 -0700
Subject: [PATCH] Support C++ statically_known_true (#151346)
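
This adds a C++ counterpart to the Python `statically_known_true` helper:
`SymBool::statically_known_true`, a `SymNodeImpl::statically_known_true`
hook, and a `TORCH_STATICALLY_KNOWN_TRUE` macro, and switches the slice
fast path in `TensorIndexing.h` over to it. A minimal usage sketch follows
(illustrative only, not part of the diff below; `start`, `length`, `stop`,
`step`, and `self` are the locals of `applySlice`):

```cpp
// TORCH_STATICALLY_KNOWN_TRUE(cond) is true only when `cond` can be proven
// from what the shape environment already knows. For Python-backed SymNodes
// it does not install a new guard and simply yields false when the answer is
// unknown; the default SymNodeImpl fallback still defers to guard_bool.
if (TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
    TORCH_STATICALLY_KNOWN_TRUE(length.sym_eq(stop)) && step == 1) {
  // Slicing the full extent with step 1 is a no-op, so return the input.
  return self;
}
```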

Differential Revision: [D73040543](https://our.internmc.facebook.com/intern/diff/D73040543/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151346
Approved by: https://github.com/laithsakka
---
 aten/src/ATen/TensorIndexing.h    |  4 ++--
 c10/core/SymBool.cpp              |  8 ++++++++
 c10/core/SymBool.h                | 19 +++++++++++++++++++
 c10/core/SymNodeImpl.h            |  6 ++++++
 test/export/test_export.py        | 24 ++++++++++++++++++++++++
 test/test_proxy_tensor.py         |  2 +-
 torch/csrc/utils/python_symnode.h |  5 +++++
 torch/fx/experimental/sym_node.py |  6 ++++++
 8 files changed, 71 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h
index 38fe78901ce7..b385ea80b809 100644
--- a/aten/src/ATen/TensorIndexing.h
+++ b/aten/src/ATen/TensorIndexing.h
@@ -222,8 +222,8 @@ inline Tensor applySlice(
         ? (*self_sizes)[dim]
         : self.sym_size(dim);
     if (!disable_slice_optimization &&
-        TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) &&
-        TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) {
+        TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
+        TORCH_STATICALLY_KNOWN_TRUE(length.sym_eq(stop)) && step == 1) {
       return self;
     }
   }
diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp
index 63fcf064e01b..d804eb9d2740 100644
--- a/c10/core/SymBool.cpp
+++ b/c10/core/SymBool.cpp
@@ -80,6 +80,14 @@ bool SymBool::guard_or_false(const char* file, int64_t line) const {
   return a->guard_or_false(file, line);
 }
 
+bool SymBool::statically_known_true(const char* file, int64_t line) const {
+  if (auto ma = maybe_as_bool()) {
+    return *ma;
+  }
+  SymNode a = toSymNodeImpl();
+  return a->statically_known_true(file, line);
+}
+
 bool SymBool::guard_or_true(const char* file, int64_t line) const {
   if (auto ma = maybe_as_bool()) {
     return *ma;
diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h
index 875377b2eb37..6982d0380e57 100644
--- a/c10/core/SymBool.h
+++ b/c10/core/SymBool.h
@@ -1,3 +1,4 @@
+
 #pragma once
 
 #include
@@ -62,6 +63,7 @@ class C10_API SymBool {
   bool guard_bool(const char* file, int64_t line) const;
   bool expect_true(const char* file, int64_t line) const;
   bool guard_size_oblivious(const char* file, int64_t line) const;
+  bool statically_known_true(const char* file, int64_t line) const;
   bool guard_or_false(const char* file, int64_t line) const;
   bool guard_or_true(const char* file, int64_t line) const;
 
@@ -129,6 +131,20 @@ inline bool guard_or_false(
   return b.guard_or_false(file, line);
 }
 
+inline bool statically_known_true(
+    bool b,
+    const char* file [[maybe_unused]],
+    int64_t line [[maybe_unused]]) {
+  return b;
+}
+
+inline bool statically_known_true(
+    const c10::SymBool& b,
+    const char* file,
+    int64_t line) {
+  return b.statically_known_true(file, line);
+}
+
 inline bool guard_or_true(
     bool b,
     const char* file [[maybe_unused]],
@@ -146,6 +162,9 @@
 #define TORCH_GUARD_SIZE_OBLIVIOUS(cond) \
   c10::guard_size_oblivious((cond), __FILE__, __LINE__)
 
+#define TORCH_STATICALLY_KNOWN_TRUE(cond) \
+  c10::statically_known_true((cond), __FILE__, __LINE__)
+
 #define TORCH_GUARD_OR_FALSE(cond) \
   c10::guard_or_false((cond), __FILE__, __LINE__)
 
diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h
index 6589a1e0b780..ae5423632bb5 100644
--- a/c10/core/SymNodeImpl.h
+++ b/c10/core/SymNodeImpl.h
@@ -1,3 +1,4 @@
+
 #pragma once
 
 #include
@@ -191,6 +192,11 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
     // with a better implementation!
     return guard_bool(file, line);
   }
+  virtual bool statically_known_true(const char* file, int64_t line) {
+    // No improvement for unbacked SymBools by default, replace this
+    // with a better implementation!
+    return guard_bool(file, line);
+  }
   virtual bool guard_or_true(const char* file, int64_t line) {
     // No improvement for unbacked SymBools by default, replace this
     // with a better implementation!
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 968c76592045..b93924dff031 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -696,6 +696,30 @@ graph():
         ep = export(f, args, strict=False)
         self.assertEqual(ep.module()(*args), f(*args))
 
+    @testing.expectedFailureCppSerDes  # Cpp serder seems to fail parsing complicated guards
+    def test_export_statically_known_true(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x, y):
+                shape = y.shape[0] ** 2 - 3 * y.shape[0]
+                end = shape
+                return x[:, :end]
+
+        dynamic_shapes = (
+            (torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC),
+            (torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC),
+        )
+
+        ep = export(
+            Foo(),
+            (torch.randn(4, 4), torch.randn(4, 4)),
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+        FileCheck().check_count("torch.ops.aten.slice.Tensor", 2, exactly=True).run(
+            str(ep.graph)
+        )
+        FileCheck().check_count("operator.sub", 1, exactly=True).run(str(ep.graph))
+
     def test_colon_parameter(self):
         class M(torch.nn.Module):
             def __init__(self) -> None:
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 40c4750f04f0..c624c3e03f71 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1218,7 +1218,7 @@ def forward(self, x_1):
         gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
         # Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
         # 1 ok)
-        self.assertEqual(len(gm.shape_env.guards), 1)
+        self.assertEqual(len(gm.shape_env.guards), 0)
 
     @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
     def test_cpu_scalar_cuda(self):
diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h
index 9c73f9ca2b9e..69d03b9b7a43 100644
--- a/torch/csrc/utils/python_symnode.h
+++ b/torch/csrc/utils/python_symnode.h
@@ -140,6 +140,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
     return getPyObj().attr("guard_or_false")(file, line).cast<bool>();
   }
 
+  bool statically_known_true(const char* file, int64_t line) override {
+    py::gil_scoped_acquire acquire;
+    return getPyObj().attr("statically_known_true")(file, line).cast<bool>();
+  }
+
   bool guard_or_true(const char* file, int64_t line) override {
     py::gil_scoped_acquire acquire;
     return getPyObj().attr("guard_or_true")(file, line).cast<bool>();
diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py
index b902649d228d..865c17acbfe5 100644
--- a/torch/fx/experimental/sym_node.py
+++ b/torch/fx/experimental/sym_node.py
@@ -572,6 +572,12 @@ class SymNode:
             _advise_is_size(SymInt(self))
         return r
 
+    def statically_known_true(self, file, line):
+        from torch.fx.experimental.symbolic_shapes import statically_known_true
+
+        assert self.is_bool()
+        return statically_known_true(SymBool(self))
+
     def guard_size_oblivious(self, file, line):
         """
         Like guard_bool, but if we encounter unbacked symbols, if those symbols