Files
pytorch/benchmarks/operator_benchmark/pt/qembedding_pack_test.py
Xuehai Pan 7763c83af6 [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
2024-05-27 04:22:18 +00:00

110 lines
3.5 KiB
Python

import operator_benchmark as op_bench
import torch
# Short sweep: one table height (80 rows) across three embedding widths.
embeddingbag_conversion_short_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
    tags=("short",),
)

# Long sweep: three table heights crossed with seven embedding widths.
embeddingbag_conversion_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000),
    embedding_dim=(16, 64, 128, 256, 512, 1024, 2048),
    tags=("long",),
)

# Batched (3-D) sweep: same shapes as the short sweep plus a leading batch
# dimension of 10; consumed by the ThreeDim benchmark classes.
embeddingbag_conversion_three_dim_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
    batch_size=(10,),
    tags=("short",),
)
# Op tables for the three quantized embedding-bag bit rates. The generated
# entries are, in order:
#   qembeddingbag_byte_prepack  -> torch.ops.quantized.embedding_bag_byte_prepack
#   qembeddingbag_4bit_prepack  -> torch.ops.quantized.embedding_bag_4bit_prepack
#   qembeddingbag_2bit_prepack  -> torch.ops.quantized.embedding_bag_2bit_prepack
# and the matching *_unpack names/ops for unpack_ops.
conversion_ops = op_bench.op_list(
    attrs=tuple(
        (
            f"qembeddingbag_{rate}_prepack",
            getattr(torch.ops.quantized, f"embedding_bag_{rate}_prepack"),
        )
        for rate in ("byte", "4bit", "2bit")
    ),
    attr_names=("op_name", "op_func"),
)

unpack_ops = op_bench.op_list(
    attrs=tuple(
        (
            f"qembeddingbag_{rate}_unpack",
            getattr(torch.ops.quantized, f"embedding_bag_{rate}_unpack"),
        )
        for rate in ("byte", "4bit", "2bit")
    ),
    attr_names=("op_name", "op_func"),
)
class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase):
    """Time a prepack op on a 2-D float weight of shape (num_embeddings, embedding_dim).

    The module registers this class with ``conversion_ops``, so ``op_func`` is
    one of the ``embedding_bag_*_prepack`` ops.
    """

    def init(self, num_embeddings, embedding_dim, op_func):
        shape = (num_embeddings, embedding_dim)
        # torch.rand samples [0, 1); adding one keeps every value in [1, 2).
        weight = torch.rand(*shape, dtype=torch.float).add(1)
        self.inputs = {"weight": weight}
        self.op_func = op_func

    def forward(self, weight):
        # Presumably called by the harness with the entries of self.inputs.
        return self.op_func(weight)
class EmbeddingBagThreeDimFloatToFusedBase(op_bench.TorchBenchmarkBase):
    """Time a prepack op on a batched 3-D float weight.

    The weight has shape (batch_size, num_embeddings, embedding_dim). The
    module registers this class with ``conversion_ops`` and the three-dim
    configs.
    """

    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        shape = (batch_size, num_embeddings, embedding_dim)
        # torch.rand samples [0, 1); adding one keeps every value in [1, 2).
        weight = torch.rand(*shape, dtype=torch.float).add(1)
        self.inputs = {"weight": weight}
        self.op_func = op_func

    def forward(self, weight):
        # Presumably called by the harness with the entries of self.inputs.
        return self.op_func(weight)
class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
    """Time an unpack op on a 2-D uint8 "packed" weight.

    The module registers this class with ``unpack_ops``, so ``op_func`` is one
    of the ``embedding_bag_*_unpack`` ops. The input bytes are random; only the
    op's runtime is measured, not its output.
    """

    def init(self, num_embeddings, embedding_dim, op_func):
        # Each row is embedding_dim + 8 bytes wide. The extra bytes presumably
        # stand in for the per-row scale/bias the unpack ops read.
        # NOTE(review): the same width is used for every bit rate - confirm
        # that is intended for the 4-bit / 2-bit ops.
        #
        # The bytes are drawn directly as uint8 rather than by casting
        # torch.randn floats. Casting negative floats to an unsigned integer
        # type is undefined behavior in C++, so the old cast produced
        # different input bytes on different platforms.
        packed_weight = torch.randint(
            0, 256, (num_embeddings, embedding_dim + 8), dtype=torch.uint8
        )
        self.inputs = {"packed_weight": packed_weight}
        self.op_func = op_func

    def forward(self, packed_weight):
        # Presumably called by the harness with the entries of self.inputs.
        return self.op_func(packed_weight)
class EmbeddingBagThreeDimFusedToFloatBase(op_bench.TorchBenchmarkBase):
    """Time an unpack op on a batched 3-D uint8 "packed" weight.

    The weight has shape (batch_size, num_embeddings, embedding_dim + 8). The
    module registers this class with ``unpack_ops`` and the three-dim configs.
    The input bytes are random; only the op's runtime is measured.
    """

    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        # The trailing 8 bytes per row presumably stand in for per-row
        # scale/bias.
        # NOTE(review): the same width is used for every bit rate - confirm.
        #
        # The bytes are drawn directly as uint8 rather than by casting
        # torch.randn floats. Casting negative floats to an unsigned integer
        # type is undefined behavior in C++ and gives platform-dependent bytes.
        packed_weight = torch.randint(
            0,
            256,
            (batch_size, num_embeddings, embedding_dim + 8),
            dtype=torch.uint8,
        )
        self.inputs = {"packed_weight": packed_weight}
        self.op_func = op_func

    def forward(self, packed_weight):
        # Presumably called by the harness with the entries of self.inputs.
        return self.op_func(packed_weight)
# Register the benchmarks, in order:
#   2-D pack, 2-D unpack (short + long sweeps), then 3-D pack, 3-D unpack.
# Each row is (op table, config list, benchmark class).
for _ops, _configs, _bench_cls in (
    (
        conversion_ops,
        embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
        EmbeddingBagFloatToFusedBase,
    ),
    (
        unpack_ops,
        embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
        EmbeddingBagFusedToFloatBase,
    ),
    (
        conversion_ops,
        embeddingbag_conversion_three_dim_configs,
        EmbeddingBagThreeDimFloatToFusedBase,
    ),
    (
        unpack_ops,
        embeddingbag_conversion_three_dim_configs,
        EmbeddingBagThreeDimFusedToFloatBase,
    ),
):
    op_bench.generate_pt_tests_from_op_list(_ops, _configs, _bench_cls)
# Drop the loop temporaries so the module namespace is unchanged.
del _ops, _configs, _bench_cls
if __name__ == "__main__":
    # Run every benchmark registered above when invoked as a script.
    op_bench.benchmark_runner.main()