mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Speed up DataTypeToTypeMeta (#66113)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66113 For a benchmark compiled in opt-mode in which the lookup items were shuffled and then the items were looked up round-robin fashion 10M times (for a total of 140M lookups) we see: ``` Function Container Time (ms) Multiplier TypeMetaToDataType if-chain 233 1x TypeMetaToDataType std::vector 795 3.41x TypeMetaToDataType std::map 1566 6.72x TypeMetaToDataType std::unordered_map 2136 9.17x DataTypeToTypeMeta switch 102 1x DataTypeToTypeMeta std::vector 666 6.53x DataTypeToTypeMeta std::map 1212 11.9x DataTypeToTypeMeta std::unordered_map 1539 15.1x DataTypeToTypeMeta folly::F14FastMap 1789 17.5x ``` From this, we draw two conclusions: 1. Using a complex container like `std::map` is worse than using a simple vector lookup here (there aren't enough items for the Big-O to assert itself). 2. Using any container at all is a mistake. (Unless we pull in more exotic reasoning like invalidating the code cache or preventing inlining.) Test Plan: Sandcastle Reviewed By: dzhulgakov Differential Revision: D31375117 fbshipit-source-id: 0b310c6c2e94080d125c82fb7c2b43ab869adbcb
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1e4bcbdddb
commit
2f1ab477f1
@ -4,58 +4,77 @@
|
||||
#include <atomic>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
namespace caffe2 {
|
||||
|
||||
// Maps a runtime TypeMeta to the corresponding TensorProto::DataType enum
// value used for serialization.
//
// Returns TensorProto_DataType_UNDEFINED for any type that has no protobuf
// mapping. Quantized types (qint8/quint8/qint32) serialize as their
// underlying integer representations. BYTE has no TypeMeta mapping: uint8_t
// is always used when serializing; BYTE is kept only for backward
// compatibility when deserializing.
//
// NOTE(perf): an if/else chain over TypeMeta ids benchmarks markedly faster
// than std::map / std::unordered_map lookups here — there are too few
// entries for asymptotic complexity to matter — so we deliberately avoid a
// container.
TensorProto::DataType TypeMetaToDataType(const TypeMeta meta) {
  static_assert(
      sizeof(int) == 4, "int in this compiler does not equal to 4 bytes.");
  // Can't use a switch because `meta_id` is not an integer type
  const auto meta_id = meta.id();
  if (meta_id == TypeMeta::Id<float>()) {
    return TensorProto_DataType_FLOAT;
  } else if (meta_id == TypeMeta::Id<int>()) {
    return TensorProto_DataType_INT32;
  } else if (meta_id == TypeMeta::Id<std::string>()) {
    // Qualified std::string for consistency with DataTypeToTypeMeta below
    // (was unqualified `string`, relying on a using-declaration).
    return TensorProto_DataType_STRING;
  } else if (meta_id == TypeMeta::Id<bool>()) {
    return TensorProto_DataType_BOOL;
  } else if (meta_id == TypeMeta::Id<uint8_t>()) {
    return TensorProto_DataType_UINT8;
  } else if (meta_id == TypeMeta::Id<int8_t>()) {
    return TensorProto_DataType_INT8;
  } else if (meta_id == TypeMeta::Id<uint16_t>()) {
    return TensorProto_DataType_UINT16;
  } else if (meta_id == TypeMeta::Id<int16_t>()) {
    return TensorProto_DataType_INT16;
  } else if (meta_id == TypeMeta::Id<int64_t>()) {
    return TensorProto_DataType_INT64;
  } else if (meta_id == TypeMeta::Id<at::Half>()) {
    return TensorProto_DataType_FLOAT16;
  } else if (meta_id == TypeMeta::Id<double>()) {
    return TensorProto_DataType_DOUBLE;
  } else if (meta_id == TypeMeta::Id<c10::qint8>()) {
    return TensorProto_DataType_INT8;
  } else if (meta_id == TypeMeta::Id<c10::quint8>()) {
    return TensorProto_DataType_UINT8;
  } else if (meta_id == TypeMeta::Id<c10::qint32>()) {
    return TensorProto_DataType_INT32;
  } else {
    return TensorProto_DataType_UNDEFINED;
  }
}
|
||||
|
||||
// Inverse of TypeMetaToDataType: maps a serialized TensorProto data-type
// enum value back to the runtime TypeMeta describing that C++ type.
//
// Throws std::runtime_error for enum values with no TypeMeta equivalent.
//
// NOTE(perf): a plain switch benchmarks far faster than any map-style
// container lookup here, so we deliberately avoid a lookup table.
const TypeMeta DataTypeToTypeMeta(const TensorProto_DataType& dt) {
  switch (dt) {
    case TensorProto_DataType_FLOAT:
      return TypeMeta::Make<float>();
    case TensorProto_DataType_INT32:
      return TypeMeta::Make<int>();
    case TensorProto_DataType_BYTE:
      // BYTE is legacy-only; it deserializes as uint8_t.
      return TypeMeta::Make<uint8_t>();
    case TensorProto_DataType_STRING:
      return TypeMeta::Make<std::string>();
    case TensorProto_DataType_BOOL:
      return TypeMeta::Make<bool>();
    case TensorProto_DataType_UINT8:
      return TypeMeta::Make<uint8_t>();
    case TensorProto_DataType_INT8:
      return TypeMeta::Make<int8_t>();
    case TensorProto_DataType_UINT16:
      return TypeMeta::Make<uint16_t>();
    case TensorProto_DataType_INT16:
      return TypeMeta::Make<int16_t>();
    case TensorProto_DataType_INT64:
      return TypeMeta::Make<int64_t>();
    case TensorProto_DataType_FLOAT16:
      return TypeMeta::Make<at::Half>();
    case TensorProto_DataType_DOUBLE:
      return TypeMeta::Make<double>();
    default:
      throw std::runtime_error("Unknown data type.");
  }  // dropped the stray ';' that followed the switch block
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
@ -47,7 +47,7 @@ inline int32_t GetDimFromOrderString(const std::string& str) {
|
||||
// Separator character used when composing hierarchical scope names.
inline constexpr char NameScopeSeparator() {
  return '/';
}
||||
// From TypeMeta to caffe2::DataType protobuffer enum.
|
||||
TORCH_API TensorProto::DataType TypeMetaToDataType(const TypeMeta meta);
|
||||
TORCH_API TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta);
|
||||
|
||||
// From caffe2::DataType protobuffer enum to TypeMeta
|
||||
TORCH_API const TypeMeta DataTypeToTypeMeta(const TensorProto::DataType& dt);
|
||||
|
Reference in New Issue
Block a user