Speed up DataTypeToTypeMeta (#66113)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66113

For a benchmark compiled in opt-mode, in which the lookup items were shuffled and then looked up in round-robin fashion 10M times (for a total of 140M lookups), we see:
```
Function           Container            Time (ms) Multiplier
TypeMetaToDataType if-chain                   233         1x
TypeMetaToDataType std::vector                795      3.41x
TypeMetaToDataType std::map                  1566      6.72x
TypeMetaToDataType std::unordered_map        2136      9.17x

DataTypeToTypeMeta switch                     102         1x
DataTypeToTypeMeta std::vector                666      6.53x
DataTypeToTypeMeta std::map                  1212      11.9x
DataTypeToTypeMeta std::unordered_map        1539      15.1x
DataTypeToTypeMeta folly::F14FastMap         1789      17.5x
```
From this, we draw two conclusions:
1. Using a complex container like `std::map` is worse than using a simple vector lookup here (there aren't enough items for the Big-O to assert itself).
2. Using any container at all is a mistake. (Unless we pull in more exotic reasoning like invalidating the code cache or preventing inlining.)

Test Plan: Sandcastle

Reviewed By: dzhulgakov

Differential Revision: D31375117

fbshipit-source-id: 0b310c6c2e94080d125c82fb7c2b43ab869adbcb
This commit is contained in:
Richard Barnes
2021-10-07 08:04:38 -07:00
committed by Facebook GitHub Bot
parent 1e4bcbdddb
commit 2f1ab477f1
2 changed files with 64 additions and 45 deletions

View File

@ -4,58 +4,77 @@
#include <atomic>
#include <memory>
#include <string>
#include <vector>
namespace caffe2 {
// Maps a TypeMeta to a TensorProto::DataType enum value by comparing
// meta.id() against the ids of a fixed list of types. An id that matches
// none of them yields TensorProto_DataType_UNDEFINED.
//
// NOTE(review): this span looks like a diff rendered without +/- markers.
// Two signatures (by value, by const reference) and two bodies (a static
// std::map lookup, an if-chain) are interleaved. Judging by the commit
// message, the const-reference signature and the if-chain are presumably the
// post-change version -- confirm against the repository.
TensorProto::DataType TypeMetaToDataType(const TypeMeta meta) {
TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta) {
// Compile-time guard on the width of int; presumably required because int is
// mapped to TensorProto_DataType_INT32 below -- TODO confirm.
static_assert(
sizeof(int) == 4, "int in this compiler does not equal to 4 bytes.");
// Map-based variant: a function-local static table keyed by TypeIdentifier.
static std::map<TypeIdentifier, TensorProto::DataType> data_type_map{
{TypeMeta::Id<float>(), TensorProto_DataType_FLOAT},
{TypeMeta::Id<int>(), TensorProto_DataType_INT32},
// BYTE does not have a type meta to proto mapping: we should
// always use uint8_t when serializing. BYTE is kept for backward
// compatibility.
// {TypeMeta::Id<>(), TensorProto_DataType_BYTE},
{TypeMeta::Id<string>(), TensorProto_DataType_STRING},
{TypeMeta::Id<bool>(), TensorProto_DataType_BOOL},
{TypeMeta::Id<uint8_t>(), TensorProto_DataType_UINT8},
{TypeMeta::Id<int8_t>(), TensorProto_DataType_INT8},
{TypeMeta::Id<uint16_t>(), TensorProto_DataType_UINT16},
{TypeMeta::Id<int16_t>(), TensorProto_DataType_INT16},
{TypeMeta::Id<int64_t>(), TensorProto_DataType_INT64},
{TypeMeta::Id<at::Half>(), TensorProto_DataType_FLOAT16},
{TypeMeta::Id<double>(), TensorProto_DataType_DOUBLE},
// c10::qint8 / c10::quint8 / c10::qint32 reuse the enum values of
// int8_t / uint8_t / int, so the mapping is many-to-one.
{TypeMeta::Id<c10::qint8>(), TensorProto_DataType_INT8},
{TypeMeta::Id<c10::quint8>(), TensorProto_DataType_UINT8},
{TypeMeta::Id<c10::qint32>(), TensorProto_DataType_INT32},
};
// A missing key maps to UNDEFINED rather than raising an error.
const auto it = data_type_map.find(meta.id());
return (
it == data_type_map.end() ? TensorProto_DataType_UNDEFINED : it->second);
// If-chain variant: the same pairs as the table above, compared one by one in
// the same order.
// Can't use a switch because `meta_id` is not an integer type
const auto meta_id = meta.id();
if (meta_id == TypeMeta::Id<float>()) {
return TensorProto_DataType_FLOAT;
} else if (meta_id == TypeMeta::Id<int>()) {
return TensorProto_DataType_INT32;
} else if (meta_id == TypeMeta::Id<string>()) {
return TensorProto_DataType_STRING;
} else if (meta_id == TypeMeta::Id<bool>()) {
return TensorProto_DataType_BOOL;
} else if (meta_id == TypeMeta::Id<uint8_t>()) {
return TensorProto_DataType_UINT8;
} else if (meta_id == TypeMeta::Id<int8_t>()) {
return TensorProto_DataType_INT8;
} else if (meta_id == TypeMeta::Id<uint16_t>()) {
return TensorProto_DataType_UINT16;
} else if (meta_id == TypeMeta::Id<int16_t>()) {
return TensorProto_DataType_INT16;
} else if (meta_id == TypeMeta::Id<int64_t>()) {
return TensorProto_DataType_INT64;
} else if (meta_id == TypeMeta::Id<at::Half>()) {
return TensorProto_DataType_FLOAT16;
} else if (meta_id == TypeMeta::Id<double>()) {
return TensorProto_DataType_DOUBLE;
// Quantized types: same enum values as the corresponding plain integer types.
} else if (meta_id == TypeMeta::Id<c10::qint8>()) {
return TensorProto_DataType_INT8;
} else if (meta_id == TypeMeta::Id<c10::quint8>()) {
return TensorProto_DataType_UINT8;
} else if (meta_id == TypeMeta::Id<c10::qint32>()) {
return TensorProto_DataType_INT32;
} else {
// No match: same fallback value as the map-based variant.
return TensorProto_DataType_UNDEFINED;
}
}
// Maps a TensorProto data-type enum value to the TypeMeta of a C++ type.
// Throws std::runtime_error("Unknown data type.") for an enum value that has
// no entry below.
//
// NOTE(review): this span looks like a diff rendered without +/- markers.
// A static std::map variant and a switch variant are interleaved and share
// the closing `};` and `}` lines. Judging by the commit message, the switch is
// presumably the post-change version -- confirm against the repository.
const TypeMeta DataTypeToTypeMeta(const TensorProto::DataType& dt) {
// Map-based variant: a function-local static table keyed by the enum value.
static std::map<TensorProto::DataType, TypeMeta> type_meta_map{
{TensorProto_DataType_FLOAT, TypeMeta::Make<float>()},
{TensorProto_DataType_INT32, TypeMeta::Make<int>()},
// BYTE and UINT8 both resolve to uint8_t.
{TensorProto_DataType_BYTE, TypeMeta::Make<uint8_t>()},
{TensorProto_DataType_STRING, TypeMeta::Make<std::string>()},
{TensorProto_DataType_BOOL, TypeMeta::Make<bool>()},
{TensorProto_DataType_UINT8, TypeMeta::Make<uint8_t>()},
{TensorProto_DataType_INT8, TypeMeta::Make<int8_t>()},
{TensorProto_DataType_UINT16, TypeMeta::Make<uint16_t>()},
{TensorProto_DataType_INT16, TypeMeta::Make<int16_t>()},
{TensorProto_DataType_INT64, TypeMeta::Make<int64_t>()},
{TensorProto_DataType_FLOAT16, TypeMeta::Make<at::Half>()},
{TensorProto_DataType_DOUBLE, TypeMeta::Make<double>()},
// Switch variant. The parameter type is spelled TensorProto_DataType here and
// TensorProto::DataType in the signature above; presumably these name the
// same enum type -- TODO confirm.
const TypeMeta DataTypeToTypeMeta(const TensorProto_DataType& dt) {
switch (dt) {
case TensorProto_DataType_FLOAT:
return TypeMeta::Make<float>();
case TensorProto_DataType_INT32:
return TypeMeta::Make<int>();
// BYTE and UINT8 both resolve to uint8_t.
case TensorProto_DataType_BYTE:
return TypeMeta::Make<uint8_t>();
case TensorProto_DataType_STRING:
return TypeMeta::Make<std::string>();
case TensorProto_DataType_BOOL:
return TypeMeta::Make<bool>();
case TensorProto_DataType_UINT8:
return TypeMeta::Make<uint8_t>();
case TensorProto_DataType_INT8:
return TypeMeta::Make<int8_t>();
case TensorProto_DataType_UINT16:
return TypeMeta::Make<uint16_t>();
case TensorProto_DataType_INT16:
return TypeMeta::Make<int16_t>();
case TensorProto_DataType_INT64:
return TypeMeta::Make<int64_t>();
case TensorProto_DataType_FLOAT16:
return TypeMeta::Make<at::Half>();
case TensorProto_DataType_DOUBLE:
return TypeMeta::Make<double>();
default:
// Any enum value not listed above is rejected with an exception.
throw std::runtime_error("Unknown data type.");
// NOTE(review): this `};` terminates the map initializer in the map variant;
// in the switch variant the `;` after the brace would be an empty statement.
};
// Map-based lookup: a missing key raises the same exception as the switch's
// default case.
const auto it = type_meta_map.find(dt);
if (it == type_meta_map.end()) {
throw std::runtime_error("Unknown data type.");
}
return it->second;
}
} // namespace caffe2

View File

@ -47,7 +47,7 @@ inline int32_t GetDimFromOrderString(const std::string& str) {
// Returns the separator character, '/'. Usable in constant expressions.
inline constexpr char NameScopeSeparator() {
  return '/';
}
// From TypeMeta to caffe2::DataType protobuffer enum.
// A TypeMeta with no mapping yields TensorProto_DataType_UNDEFINED.
// NOTE(review): two declarations follow (by-value and by-const-reference
// parameter). This looks like the before/after pair of a rendered diff --
// confirm which one the header actually keeps.
TORCH_API TensorProto::DataType TypeMetaToDataType(const TypeMeta meta);
TORCH_API TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta);
// From caffe2::DataType protobuffer enum to TypeMeta
// An enum value with no mapping throws std::runtime_error("Unknown data type.").
TORCH_API const TypeMeta DataTypeToTypeMeta(const TensorProto::DataType& dt);