Fix qembeddingbag_byte_prepack_meta to use sym_sizes (#159985)
Summary: In qembeddingbag_byte_prepack_meta, weight.sizes() returns concrete ints; we should use .sym_sizes() to return SymInts instead.

Test Plan: CI

Rollback Plan:

Reviewed By: kqfu, henryoier

Differential Revision: D79744512

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159985
Approved by: https://github.com/jerryzh168, https://github.com/henryoier
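The distinction the fix relies on is between Tensor::sizes(), which yields concrete int64_t values, and Tensor::sym_sizes(), which yields c10::SymInt values that can stay symbolic under dynamic-shape tracing. Below is a minimal sketch of that contrast, not the PyTorch function itself; the function and variable names are illustrative only.

#include <ATen/ATen.h>
#include <c10/core/SymInt.h>

void shape_of_last_dim_sketch(const at::Tensor& weight) {
  // Concrete path: sizes() is an IntArrayRef of int64_t. Reading a value here
  // under meta / dynamic-shape tracing specializes it to a fixed integer.
  at::IntArrayRef concrete_sizes = weight.sizes();
  int64_t concrete_cols = concrete_sizes[weight.ndimension() - 1];

  // Symbolic path: sym_sizes() is a SymIntArrayRef of c10::SymInt. Arithmetic
  // on a SymInt (here: the 8 extra bytes per row for the FP32 scale and
  // zero_point) keeps the result symbolic instead of baking in a constant.
  c10::SymIntArrayRef symbolic_sizes = weight.sym_sizes();
  c10::SymInt symbolic_cols = symbolic_sizes[weight.ndimension() - 1];
  c10::SymInt output_columns = symbolic_cols + 2 * sizeof(float);

  (void)concrete_cols;
  (void)output_columns;
}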
committed by PyTorch MergeBot
parent e619c6bb90
commit 8147370733
@@ -333,14 +333,14 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
       weight.scalar_type() == at::ScalarType::Float ||
           weight.scalar_type() == at::ScalarType::Half,
       "'embedding_bag_byte_prepack' only support float32 or float16.");
-  const auto weight_sizes = weight.sizes();
-  const auto cols_dim = weight_sizes.size() - 1;
-  const int32_t embedding_cols = static_cast<int32_t>(weight_sizes[cols_dim]);
+  const auto weight_sizes = weight.sym_sizes();
+  const auto cols_dim = weight.ndimension() - 1;
+  const auto embedding_cols = weight_sizes[cols_dim];
   // Add 8 bytes per column to store FP32 scale and zero_point per row.
-  const int32_t output_columns = static_cast<int32_t>(embedding_cols + 2 * sizeof(float));
+  const auto output_columns = embedding_cols + 2 * sizeof(float);

   // Adjust output dimensions to account for FP32 scale and zero_points.
-  std::vector<int64_t> output_shape = weight_sizes.vec();
+  auto output_shape = weight_sizes.vec();
   output_shape.at(cols_dim) = output_columns;
   at::SymDimVector output_shape_vec(output_shape);