Use sym_eq in _check_rms_norm_inputs_symint (#165112)

Summary:
### Problem
ArrayRef's `equals()` does elementwise equality using the `==` operator. This can cause a DDE (data-dependent error) for unbacked SymInts, since the `==` operator calls `guard_bool`.
```
// SymInt.h
bool operator==(const SymInt& o) const {
  return sym_eq(o).guard_bool(__FILE__, __LINE__);
}
```
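
For illustration, a minimal sketch (not part of this PR; the helper name is hypothetical) of the pattern that hits the guard:
```
#include <c10/core/SymIntArrayRef.h>

// Hypothetical helper: compares two shapes eagerly.
// ArrayRef::equals() walks the elements and compares them with
// SymInt::operator==, which calls guard_bool(). If either side contains an
// unbacked SymInt (e.g. a size derived from .item() under torch.compile),
// guard_bool() cannot resolve the comparison and raises a data-dependent error.
inline bool shapes_match_eager(c10::SymIntArrayRef lhs, c10::SymIntArrayRef rhs) {
  return lhs.equals(rhs); // elementwise SymInt::operator== -> guard_bool()
}
```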

### Solution
Add `sym_equals()`, which performs elementwise symbolic equality on `SymIntArrayRef` and returns a `SymBool` instead of guarding. Use it in place of `equals()` when comparing `SymIntArrayRef`s.
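
As a rough sketch of the intended call pattern (the helper below is hypothetical; it assumes `sym_equals` lives in the `c10` namespace and uses the `TORCH_SYM_CHECK` macro seen in the diff):
```
#include <ATen/core/Tensor.h>
#include <c10/core/SymIntArrayRef.h>

// Hypothetical caller: checks weight's shape without forcing a guard.
// sym_equals() returns a c10::SymBool; TORCH_SYM_CHECK consumes the SymBool
// without the eager guard_bool() call, so unbacked SymInts need not be
// resolved at trace time.
inline void check_weight_shape(
    const at::Tensor& weight,
    c10::SymIntArrayRef normalized_shape) {
  if (weight.defined()) {
    TORCH_SYM_CHECK(
        c10::sym_equals(weight.sym_sizes(), normalized_shape),
        "Expected weight shape ", normalized_shape,
        ", but got ", weight.sym_sizes());
  }
}
```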

Reviewed By: guangy10, pianpwk, muchulee8

Differential Revision: D84168401

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165112
Approved by: https://github.com/Skylion007
Authored by Colin Peppler on 2025-10-14 00:06:24 +00:00; committed by PyTorch MergeBot
parent 9166f6120f
commit 37d57ac9cb
2 changed files with 43 additions and 19 deletions


@@ -3,6 +3,9 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/native/DispatchStub.h>
 #include <c10/util/accumulate.h>
+#include <c10/core/SymBool.h>
+#include <c10/util/StringUtil.h>
 
 namespace at::native {
@@ -19,28 +22,30 @@ C10_ALWAYS_INLINE void _check_rms_norm_inputs_symint(
       "Expected normalized_shape to be at least 1-dimensional, i.e., ",
       "containing at least one element, but got normalized_shape = ",
       normalized_shape);
-  TORCH_CHECK(
-      !weight.defined() || weight.sym_sizes().equals(normalized_shape),
-      "Expected weight to be of same shape as normalized_shape, but got ",
-      "weight of shape ",
-      weight.sym_sizes(),
-      " and normalized_shape = ",
-      normalized_shape);
+  if (weight.defined()) {
+    TORCH_SYM_CHECK(
+        sym_equals(weight.sym_sizes(), normalized_shape),
+        "Expected weight to be of same shape as normalized_shape, but got ",
+        "weight of shape ",
+        weight.sym_sizes(),
+        " and normalized_shape = ",
+        normalized_shape);
+  }
+
   const auto input_ndim = input.dim();
   const auto input_shape = input.sym_sizes();
-  if (input_ndim < normalized_ndim ||
-      !input_shape.slice(input_ndim - normalized_ndim)
-           .equals(normalized_shape)) {
-    std::stringstream ss;
-    ss << "Given normalized_shape=" << normalized_shape
-       << ", expected input with shape [*";
-    for (auto size : normalized_shape) {
-      ss << ", " << size;
-    }
-    ss << "], but got input of size" << input_shape;
-    TORCH_CHECK(false, ss.str());
-  }
+  TORCH_CHECK_VALUE(
+      input_ndim >= normalized_ndim,
+      "Input tensor must have at least ", normalized_ndim, " dimensions, but got ", input_ndim);
+
+  auto expect_input_shape_msg = c10::str(
+      "Given normalized_shape=", normalized_shape,
+      ", expected input with shape [*", c10::Join(", ", normalized_shape),
+      "], but got input of size", input_shape);
+
+  TORCH_SYM_CHECK(
+      sym_equals(input_shape.slice(input_ndim - normalized_ndim), normalized_shape),
+      expect_input_shape_msg);
 }
 
 C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_layer_norm_inputs(


@@ -86,4 +86,23 @@ inline SymIntArrayRef fromIntArrayRefSlow(IntArrayRef array_ref) {
       reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
 }
 
+inline c10::SymBool sym_equals(SymIntArrayRef LHS, SymIntArrayRef RHS) {
+  if (LHS.size() != RHS.size()) {
+    return c10::SymBool(false);
+  }
+
+  c10::SymBool result = sym_eq(LHS.size(), RHS.size());
+  for (size_t i = 0; i < RHS.size(); ++i) {
+    c10::SymBool equals = sym_eq(LHS[i], RHS[i]);
+    std::optional<bool> equals_bool = equals.maybe_as_bool();
+    if (equals_bool.has_value() && !*equals_bool) {
+      // Early return if element comparison is known to be false
+      return equals;
+    }
+    result = result.sym_and(equals);
+  }
+
+  return result;
+}
+
 } // namespace c10