mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use sym_eq in _check_rms_norm_inputs_symint (#165112)
Summary: ### Problem ArrayRef's `equals()`does elementwise quality using `==` operator. This can cause a DDE for unbacked symints since `==` operator calls `guard_bool`. ``` // SymInt.h bool operator==(const SymInt& o) const { return sym_eq(o).guard_bool(__FILE__, __LINE__); } ``` ### Solution Adds `sym_equals()` to do elementwise equality for `SymIntArrayRef`. Use this instead of `equals()` for `SymIntArrayRef`. Reviewed By: guangy10, pianpwk, muchulee8 Differential Revision: D84168401 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165112 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
9166f6120f
commit
37d57ac9cb
@ -3,6 +3,9 @@
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/core/SymBool.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
|
||||
|
||||
namespace at::native {
|
||||
|
||||
@ -19,28 +22,30 @@ C10_ALWAYS_INLINE void _check_rms_norm_inputs_symint(
|
||||
"Expected normalized_shape to be at least 1-dimensional, i.e., ",
|
||||
"containing at least one element, but got normalized_shape = ",
|
||||
normalized_shape);
|
||||
TORCH_CHECK(
|
||||
!weight.defined() || weight.sym_sizes().equals(normalized_shape),
|
||||
"Expected weight to be of same shape as normalized_shape, but got ",
|
||||
"weight of shape ",
|
||||
weight.sym_sizes(),
|
||||
" and normalized_shape = ",
|
||||
normalized_shape);
|
||||
if (weight.defined()) {
|
||||
TORCH_SYM_CHECK(
|
||||
sym_equals(weight.sym_sizes(), normalized_shape),
|
||||
"Expected weight to be of same shape as normalized_shape, but got ",
|
||||
"weight of shape ",
|
||||
weight.sym_sizes(),
|
||||
" and normalized_shape = ",
|
||||
normalized_shape);
|
||||
}
|
||||
|
||||
const auto input_ndim = input.dim();
|
||||
const auto input_shape = input.sym_sizes();
|
||||
if (input_ndim < normalized_ndim ||
|
||||
!input_shape.slice(input_ndim - normalized_ndim)
|
||||
.equals(normalized_shape)) {
|
||||
std::stringstream ss;
|
||||
ss << "Given normalized_shape=" << normalized_shape
|
||||
<< ", expected input with shape [*";
|
||||
for (auto size : normalized_shape) {
|
||||
ss << ", " << size;
|
||||
}
|
||||
ss << "], but got input of size" << input_shape;
|
||||
TORCH_CHECK(false, ss.str());
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
input_ndim >= normalized_ndim,
|
||||
"Input tensor must have at least ", normalized_ndim, " dimensions, but got ", input_ndim);
|
||||
|
||||
auto expect_input_shape_msg = c10::str(
|
||||
"Given normalized_shape=", normalized_shape,
|
||||
", expected input with shape [*", c10::Join(", ", normalized_shape),
|
||||
"], but got input of size", input_shape);
|
||||
|
||||
TORCH_SYM_CHECK(
|
||||
sym_equals(input_shape.slice(input_ndim - normalized_ndim), normalized_shape),
|
||||
expect_input_shape_msg);
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_layer_norm_inputs(
|
||||
|
@ -86,4 +86,23 @@ inline SymIntArrayRef fromIntArrayRefSlow(IntArrayRef array_ref) {
|
||||
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
|
||||
}
|
||||
|
||||
inline c10::SymBool sym_equals(SymIntArrayRef LHS, SymIntArrayRef RHS) {
|
||||
if (LHS.size() != RHS.size()) {
|
||||
return c10::SymBool(false);
|
||||
}
|
||||
|
||||
c10::SymBool result = sym_eq(LHS.size(), RHS.size());
|
||||
for (size_t i = 0; i < RHS.size(); ++i) {
|
||||
c10::SymBool equals = sym_eq(LHS[i], RHS[i]);
|
||||
std::optional<bool> equals_bool = equals.maybe_as_bool();
|
||||
|
||||
if (equals_bool.has_value() && !*equals_bool) {
|
||||
// Early return if element comparison is known to be false
|
||||
return equals;
|
||||
}
|
||||
result = result.sym_and(equals);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
Reference in New Issue
Block a user