mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Support C++ statically_known_true (#151346)
Differential Revision: [D73040543](https://our.internmc.facebook.com/intern/diff/D73040543/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/151346 Approved by: https://github.com/laithsakka
This commit is contained in:
committed by
PyTorch MergeBot
parent
8895c290f4
commit
eb1f85a2a0
@ -222,8 +222,8 @@ inline Tensor applySlice(
|
||||
? (*self_sizes)[dim]
|
||||
: self.sym_size(dim);
|
||||
if (!disable_slice_optimization &&
|
||||
TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) &&
|
||||
TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) {
|
||||
TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
|
||||
TORCH_STATICALLY_KNOWN_TRUE(length.sym_eq(stop)) && step == 1) {
|
||||
return self;
|
||||
}
|
||||
}
|
||||
|
@ -80,6 +80,14 @@ bool SymBool::guard_or_false(const char* file, int64_t line) const {
|
||||
return a->guard_or_false(file, line);
|
||||
}
|
||||
|
||||
bool SymBool::statically_known_true(const char* file, int64_t line) const {
|
||||
if (auto ma = maybe_as_bool()) {
|
||||
return *ma;
|
||||
}
|
||||
SymNode a = toSymNodeImpl();
|
||||
return a->statically_known_true(file, line);
|
||||
}
|
||||
|
||||
bool SymBool::guard_or_true(const char* file, int64_t line) const {
|
||||
if (auto ma = maybe_as_bool()) {
|
||||
return *ma;
|
||||
|
@ -1,3 +1,4 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
@ -62,6 +63,7 @@ class C10_API SymBool {
|
||||
bool guard_bool(const char* file, int64_t line) const;
|
||||
bool expect_true(const char* file, int64_t line) const;
|
||||
bool guard_size_oblivious(const char* file, int64_t line) const;
|
||||
bool statically_known_true(const char* file, int64_t line) const;
|
||||
bool guard_or_false(const char* file, int64_t line) const;
|
||||
bool guard_or_true(const char* file, int64_t line) const;
|
||||
|
||||
@ -129,6 +131,20 @@ inline bool guard_or_false(
|
||||
return b.guard_or_false(file, line);
|
||||
}
|
||||
|
||||
inline bool statically_known_true(
|
||||
bool b,
|
||||
const char* file [[maybe_unused]],
|
||||
int64_t line [[maybe_unused]]) {
|
||||
return b;
|
||||
}
|
||||
|
||||
inline bool statically_known_true(
|
||||
const c10::SymBool& b,
|
||||
const char* file,
|
||||
int64_t line) {
|
||||
return b.statically_known_true(file, line);
|
||||
}
|
||||
|
||||
inline bool guard_or_true(
|
||||
bool b,
|
||||
const char* file [[maybe_unused]],
|
||||
@ -146,6 +162,9 @@ inline bool guard_or_true(
|
||||
#define TORCH_GUARD_SIZE_OBLIVIOUS(cond) \
|
||||
c10::guard_size_oblivious((cond), __FILE__, __LINE__)
|
||||
|
||||
#define TORCH_STATICALLY_KNOWN_TRUE(cond) \
|
||||
c10::statically_known_true((cond), __FILE__, __LINE__)
|
||||
|
||||
#define TORCH_GUARD_OR_FALSE(cond) \
|
||||
c10::guard_or_false((cond), __FILE__, __LINE__)
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
@ -191,6 +192,11 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
||||
// with a better implementation!
|
||||
return guard_bool(file, line);
|
||||
}
|
||||
virtual bool statically_known_true(const char* file, int64_t line) {
|
||||
// No improvement for unbacked SymBools by default, replace this
|
||||
// with a better implementation!
|
||||
return guard_bool(file, line);
|
||||
}
|
||||
virtual bool guard_or_true(const char* file, int64_t line) {
|
||||
// No improvement for unbacked SymBools by default, replace this
|
||||
// with a better implementation!
|
||||
|
@ -696,6 +696,30 @@ graph():
|
||||
ep = export(f, args, strict=False)
|
||||
self.assertEqual(ep.module()(*args), f(*args))
|
||||
|
||||
@testing.expectedFailureCppSerDes # Cpp serder seems to fail parsing complicated guards
|
||||
def test_export_statically_known_true(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
shape = y.shape[0] ** 2 - 3 * y.shape[0]
|
||||
end = shape
|
||||
return x[:, :end]
|
||||
|
||||
dynamic_shapes = (
|
||||
(torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC),
|
||||
(torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC),
|
||||
)
|
||||
|
||||
ep = export(
|
||||
Foo(),
|
||||
(torch.randn(4, 4), torch.randn(4, 4)),
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
strict=False,
|
||||
)
|
||||
FileCheck().check_count("torch.ops.aten.slice.Tensor", 2, exactly=True).run(
|
||||
str(ep.graph)
|
||||
)
|
||||
FileCheck().check_count("operator.sub", 1, exactly=True).run(str(ep.graph))
|
||||
|
||||
def test_colon_parameter(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
@ -1218,7 +1218,7 @@ def forward(self, x_1):
|
||||
gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
|
||||
# Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
|
||||
# 1 ok)
|
||||
self.assertEqual(len(gm.shape_env.guards), 1)
|
||||
self.assertEqual(len(gm.shape_env.guards), 0)
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
|
||||
def test_cpu_scalar_cuda(self):
|
||||
|
@ -140,6 +140,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
|
||||
return getPyObj().attr("guard_or_false")(file, line).cast<bool>();
|
||||
}
|
||||
|
||||
bool statically_known_true(const char* file, int64_t line) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("statically_known_true")(file, line).cast<bool>();
|
||||
}
|
||||
|
||||
bool guard_or_true(const char* file, int64_t line) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("guard_or_true")(file, line).cast<bool>();
|
||||
|
@ -572,6 +572,12 @@ class SymNode:
|
||||
_advise_is_size(SymInt(self))
|
||||
return r
|
||||
|
||||
def statically_known_true(self, file, line):
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
|
||||
assert self.is_bool()
|
||||
return statically_known_true(SymBool(self))
|
||||
|
||||
def guard_size_oblivious(self, file, line):
|
||||
"""
|
||||
Like guard_bool, but if we encounter unbacked symbols, if those symbols
|
||||
|
Reference in New Issue
Block a user