diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
index 1e91fecd4500..807a9b25d377 100644
--- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
@@ -333,14 +333,14 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
       weight.scalar_type() == at::ScalarType::Float ||
           weight.scalar_type() == at::ScalarType::Half,
       "'embedding_bag_byte_prepack' only support float32 or float16.");
-  const auto weight_sizes = weight.sizes();
-  const auto cols_dim = weight_sizes.size() - 1;
-  const int32_t embedding_cols = static_cast<int32_t>(weight_sizes[cols_dim]);
+  const auto weight_sizes = weight.sym_sizes();
+  const auto cols_dim = weight.ndimension() - 1;
+  const auto embedding_cols = weight_sizes[cols_dim];
   // Add 8 bytes per column to store FP32 scale and zero_point per row.
-  const int32_t output_columns = static_cast<int32_t>(embedding_cols + 2 * sizeof(float));
+  const auto output_columns = embedding_cols + 2 * sizeof(float);
   // Adjust output dimensions to account for FP32 scale and zero_points.
-  std::vector<int64_t> output_shape = weight_sizes.vec();
+  auto output_shape = weight_sizes.vec();
   output_shape.at(cols_dim) = output_columns;
   at::SymDimVector output_shape_vec(output_shape);