Fix qembeddingbag_byte_prepack_meta to use sym_sizes (#159985)

Summary: In qembeddingbag_byte_prepack_meta, weight.sizes() returns concrete ints. We should use weight.sym_sizes() to return SymInts instead, so the output shape stays symbolic when the meta function runs under tracing with dynamic shapes.
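
To illustrate the difference, here is a minimal sketch of a meta function built on sym_sizes(). This is not the PyTorch source; the helper name and the use of at::empty_symint to materialize the symbolic shape are assumptions about how the result is consumed downstream:

#include <ATen/ATen.h>

// Sketch: shape propagation that stays symbolic under dynamic shapes.
at::Tensor byte_prepack_meta_sketch(const at::Tensor& weight) {
  // sym_sizes() yields c10::SymInt values, each of which may hold a plain
  // int64_t or a symbolic expression captured during tracing.
  const auto sizes = weight.sym_sizes();
  const auto cols_dim = weight.ndimension() - 1;
  auto out_shape = sizes.vec();  // std::vector<c10::SymInt>
  // SymInt supports ordinary arithmetic, so the "+ 8 bytes" adjustment for
  // the per-row FP32 scale and zero_point stays symbolic.
  out_shape[cols_dim] = sizes[cols_dim] + 2 * static_cast<int64_t>(sizeof(float));
  // Allocate a meta tensor carrying the symbolic output shape.
  return at::empty_symint(
      out_shape, weight.options().dtype(at::kByte).device(at::kMeta));
}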

Test Plan:
CI

Rollback Plan:

Reviewed By: kqfu, henryoier

Differential Revision: D79744512

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159985
Approved by: https://github.com/jerryzh168, https://github.com/henryoier
Author: Sherlock Huang
Date: 2025-08-07 21:22:29 +00:00
Committed by: PyTorch MergeBot
Parent: e619c6bb90
Commit: 8147370733

@@ -333,14 +333,14 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
       weight.scalar_type() == at::ScalarType::Float ||
           weight.scalar_type() == at::ScalarType::Half,
       "'embedding_bag_byte_prepack' only support float32 or float16.");
-  const auto weight_sizes = weight.sizes();
-  const auto cols_dim = weight_sizes.size() - 1;
-  const int32_t embedding_cols = static_cast<int32_t>(weight_sizes[cols_dim]);
+  const auto weight_sizes = weight.sym_sizes();
+  const auto cols_dim = weight.ndimension() - 1;
+  const auto embedding_cols = weight_sizes[cols_dim];
   // Add 8 bytes per column to store FP32 scale and zero_point per row.
-  const int32_t output_columns = static_cast<int32_t>(embedding_cols + 2 * sizeof(float));
+  const auto output_columns = embedding_cols + 2 * sizeof(float);
   // Adjust output dimensions to account for FP32 scale and zero_points.
-  std::vector<int64_t> output_shape = weight_sizes.vec();
+  auto output_shape = weight_sizes.vec();
   output_shape.at(cols_dim) = output_columns;
   at::SymDimVector output_shape_vec(output_shape);
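
For reference, the SymInt arithmetic in the new code can be illustrated in isolation (a minimal standalone sketch, relying only on c10::SymInt's public operator overloads; the helper name is hypothetical):

#include <c10/core/SymInt.h>

// A SymInt wrapping a plain integer behaves like ordinary int64_t math; one
// wrapping a symbolic expression (e.g. "s1") produces "s1 + 8" rather than
// being specialized to a concrete value, which is the point of the fix.
c10::SymInt add_scale_zp_columns(const c10::SymInt& embedding_cols) {
  // 2 * sizeof(float) == 8 bytes for the per-row FP32 scale and zero_point.
  return embedding_cols + static_cast<int64_t>(2 * sizeof(float));
}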