Rewrite maybe_reduce more carefully for unbacked SymInt (#119562)
Fixes https://github.com/pytorch/pytorch/issues/119476

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119562
Approved by: https://github.com/albanD
ghstack dependencies: #119559
Committed by: PyTorch MergeBot
Parent: 28f299a870
Commit: 6665b96ebb
@@ -413,4 +413,11 @@ inline SymBool sym_ge(const SymInt& a, const SymInt& b) {
   return a.sym_ge(b);
 }
 
+inline bool definitely_true(
+    const c10::SymBool& b,
+    const char* file,
+    int64_t line) {
+  return b.has_hint() && b.guard_bool(file, line);
+}
+
 } // namespace c10
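The helper above reports true only when the SymBool carries a concrete hint that guards true; an unbacked condition simply yields false instead of forcing a guard. A minimal, self-contained sketch of that contract, using a stand-in FakeSymBool rather than the real c10::SymBool (only the has_hint/guard_bool names mirror the diff; everything else is illustrative):

#include <cassert>
#include <cstdint>
#include <optional>

// Stand-in for c10::SymBool: an unbacked symbol carries no hint (nullopt).
struct FakeSymBool {
  std::optional<bool> hint;
  bool has_hint() const { return hint.has_value(); }
  // The real guard_bool would record a guard (and cannot be used on an
  // unbacked symbol); here we simply read the hint.
  bool guard_bool(const char* /*file*/, int64_t /*line*/) const {
    return hint.value();
  }
};

// Same shape as the helper added above: "definitely" true only when a hint
// exists and it is true; unbacked conditions fall through to false.
bool definitely_true(const FakeSymBool& b, const char* file, int64_t line) {
  return b.has_hint() && b.guard_bool(file, line);
}

int main() {
  assert(definitely_true(FakeSymBool{true}, __FILE__, __LINE__));
  assert(!definitely_true(FakeSymBool{false}, __FILE__, __LINE__));
  assert(!definitely_true(FakeSymBool{std::nullopt}, __FILE__, __LINE__)); // unbacked: no guard forced
  return 0;
}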
@@ -1,5 +1,6 @@
 #include <c10/core/Contiguity.h>
 #include <c10/core/MemoryFormat.h>
+#include <c10/core/SymInt.h>
 #include <c10/core/SymIntArrayRef.h>
 #include <c10/core/SymbolicShapeMeta.h>
 
@@ -130,17 +131,13 @@ DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and
 // test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 to run
 // very slowly.
 
-static bool definitely_true(const SymBool& b) {
-  return b.has_hint() && b.guard_bool(__FILE__, __LINE__);
-}
-
 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
   init_is_contiguous();
-  if (definitely_true(is_contiguous())) {
+  if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   init_is_channels_last_contiguous();
-  if (definitely_true(is_channels_last_contiguous())) {
+  if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   return is_contiguous() | is_channels_last_contiguous() |
@@ -149,7 +146,7 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
 
 SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
   init_is_channels_last_contiguous();
-  if (definitely_true(is_channels_last_contiguous())) {
+  if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
     return false;
   }
   return ~is_channels_last_contiguous() & compute_channels_last_contiguous_3d();
@@ -157,7 +154,7 @@ SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
 
 SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
   init_is_channels_last_3d_contiguous();
-  if (definitely_true(is_channels_last_3d_contiguous())) {
+  if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
     return false;
   }
   return ~is_channels_last_3d_contiguous() &
@@ -165,20 +162,20 @@ SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
 }
 
 SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const {
-  if (definitely_true(is_channels_last())) {
+  if (definitely_true(is_channels_last(), __FILE__, __LINE__)) {
     return false;
   }
   return ~is_channels_last() & compute_strides_like_channels_last_3d();
 }
 
 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
-  if (definitely_true(is_contiguous())) {
+  if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
-  if (definitely_true(is_channels_last_contiguous())) {
+  if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
-  if (definitely_true(is_channels_last_3d_contiguous())) {
+  if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   return is_contiguous() | is_channels_last_contiguous() |
@@ -186,7 +183,7 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
 }
 
 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const {
-  if (definitely_true(is_contiguous())) {
+  if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   return is_contiguous() | compute_non_overlapping_and_dense();
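All of the call-site changes above follow the same pattern: the early return fires only when the cached flag is definitely (i.e. backed and) true, and the fallback keeps the result symbolic. What changes is that __FILE__/__LINE__ are now supplied by each caller and the helper lives in c10, so it can also be reused from the autograd code further down; presumably this also attributes any recorded guard to the real call site rather than to the old file-local static helper. A small stand-alone illustration of why threading the location through matters (names here are illustrative, not PyTorch API):

#include <cstdint>
#include <cstdio>

// Old style: the location is baked into the shared helper, so every report
// points at this one line, regardless of who called it.
void report_here() {
  std::printf("checked at %s:%d\n", __FILE__, __LINE__);
}

// New style: the caller's location is threaded through as arguments,
// matching the file/line parameters added in the diff.
void report_from(const char* file, int64_t line) {
  std::printf("checked at %s:%lld\n", file, static_cast<long long>(line));
}

int main() {
  report_here();                    // always the helper's own line
  report_from(__FILE__, __LINE__);  // the actual call site
  return 0;
}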
@@ -8257,6 +8257,29 @@ def ___make_guard_fn():
         self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1.00")
         self.assertEqual(res, img1 + torch.sin(img1))
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_validate_outputs_unbacked(self):
+        class SillyCat(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x0, x1, i):
+                ctx.save_for_backward(i)
+                return torch.cat([x0, x1])
+
+            @staticmethod
+            def backward(ctx, grad_out):
+                (i,) = ctx.saved_tensors
+                i0, i1 = i.tolist()
+                g_x0, g_x1 = grad_out.split([i0, i1])
+                return g_x0, g_x1, None
+
+        @torch.compile(backend="aot_eager", fullgraph=True)
+        def f(x, i):
+            i0, i1 = i.tolist()
+            x0, x1 = x.split([i0, i1])
+            return SillyCat.apply(x0, x1, i)
+
+        f(torch.randn(9, requires_grad=True), torch.tensor([3, 6]))
+
     def test_str_format_assert1(self):
         @torch.compile(backend="eager", fullgraph=True)
         def fn(img):
@@ -57,13 +57,59 @@ at::Tensor InputMetadata::maybe_reduce(
     const size_t i,
     at::Tensor grad,
     const std::function<std::string(const std::string&)>& format_error) const {
-  if (!is_same_shape(grad)) {
-    if (is_expandable_to_shape(grad)) {
-      return reduce_grad(grad);
+  auto fail = [&]() {
+    const auto message = incompatible_shape_error_message(i, grad);
+    TORCH_CHECK(false, format_error(message.str()));
+  };
+
+  // Nested tensor makes my brain explode, so I've just hard-coded the logic
+  // for this case, at risk of code duplication.  This logic does NOT do the
+  // careful oblivious logic as seen below
+  if (is_nested_ || is_cpp_nested_tensor() || grad.is_nested() ||
+      ::torch::autograd::is_cpp_nested_tensor(grad)) {
+    if (!is_same_shape(grad)) {
+      if (is_expandable_to_shape(grad)) {
+        return reduce_grad(grad);
+      } else {
+        fail();
+      }
     } else {
-      const auto message = incompatible_shape_error_message(i, grad);
-      TORCH_CHECK(false, format_error(message.str()));
+      return grad;
     }
   }
+
+  auto shape = shape_as_dim_vector();
+  auto desired = grad.sym_sizes();
+
+  size_t ndim = shape.size();
+  size_t target_dim = desired.size();
+  if (ndim > target_dim) {
+    fail();
+  }
+  bool needs_reduce = false;
+  for (const auto i : c10::irange(ndim)) {
+    const auto& size = shape[ndim - i - 1];
+    const auto& target = desired[target_dim - i - 1];
+    // The conditions here are written carefully so that we are able to
+    // infer deferred runtime asserts
+    if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
+      // NB: we could short circuit this once needs_reduce is true but there's
+      // no point since the reduction function will guard on this anyway
+      if (!c10::definitely_true(size.sym_eq(target), __FILE__, __LINE__)) {
+        needs_reduce = true;
+      }
+    } else {
+      if (!size.sym_eq(target).expect_true(__FILE__, __LINE__)) {
+        fail();
+      }
+    }
+  }
+  if (ndim != target_dim) {
+    needs_reduce = true;
+  }
+
+  if (needs_reduce) {
+    return reduce_grad(grad);
+  } else {
+    return grad;
+  }
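For the non-nested path, the per-dimension loop is the core of the rewrite: a dimension that is size-obliviously equal to 1 marks the gradient for a sum-to-size reduction unless the target is definitely the same, while every other dimension becomes an equality check that expect_true can turn into a deferred runtime assert when unbacked sizes are involved, instead of failing at trace time. A toy model of that per-dimension decision, with std::optional<int64_t> standing in for a possibly-unbacked SymInt (all names below are illustrative, not the real c10/autograd API):

#include <cstdint>
#include <cstdio>
#include <optional>

// Toy model of a possibly-unbacked size: nullopt means "unbacked" (no hint).
using MaybeSize = std::optional<int64_t>;

// Size-oblivious "== 1": an unbacked size is assumed not to be 1, so this
// succeeds only for a backed, literal 1 and never forces a guard
// (rough analogue of TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))).
bool oblivious_eq_one(const MaybeSize& s) {
  return s.has_value() && *s == 1;
}

// "Definitely equal": only when both sides are backed and match
// (rough analogue of c10::definitely_true(size.sym_eq(target), ...)).
bool definitely_equal(const MaybeSize& a, const MaybeSize& b) {
  return a.has_value() && b.has_value() && *a == *b;
}

// Outcome of one loop iteration in the rewritten maybe_reduce.
struct Decision {
  bool needs_reduce = false;  // gradient must be reduced over this dim
  bool assert_equal = false;  // stands in for expect_true(size.sym_eq(target)),
                              // i.e. a (possibly deferred) runtime assert
};

Decision decide(const MaybeSize& size, const MaybeSize& target) {
  Decision d;
  if (oblivious_eq_one(size)) {
    // Broadcast candidate: reduce unless the target is definitely 1 as well.
    if (!definitely_equal(size, target)) {
      d.needs_reduce = true;
    }
  } else {
    // Non-broadcast dim (possibly unbacked): sizes must match; for unbacked
    // sizes this becomes a deferred runtime assert rather than a hard guard.
    d.assert_equal = true;
  }
  return d;
}

int main() {
  std::printf("size 1 vs 4          -> reduce=%d\n",
              decide(MaybeSize{1}, MaybeSize{4}).needs_reduce);
  std::printf("unbacked vs unbacked -> assert=%d\n",
              decide(std::nullopt, std::nullopt).assert_equal);
  return 0;
}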