[Caffe2] Create fewer strings during argument fetching (#64285)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64285

With C++14 heterogeneous ordered container lookup, it is no longer necessary to create a `std::string` in order to look up elements of a `CaffeMap` keyed by std::string. Accordingly, this diff reworks the argument-getting operator functions to avoid that in favor of `c10::string_view`.
ghstack-source-id: 137139818
ghstack-source-id: 137139818

Test Plan: buildsizebot iOS apps -- code size win. less strings is probably marginally good for perf but this only happens at setup time anyway.

Reviewed By: dzhulgakov

Differential Revision: D26826676

fbshipit-source-id: ee653b14dc2c528bae8c90f0fc6a7a419cbca1d6
This commit is contained in:
Scott Wolchok
2021-09-01 13:24:11 -07:00
committed by Facebook GitHub Bot
parent 468001600c
commit 03a58a2ba0
5 changed files with 79 additions and 56 deletions

View File

@ -1,6 +1,7 @@
#pragma once
#include <c10/util/StringUtil.h>
#include <c10/util/string_view.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
@ -272,7 +273,7 @@ struct FunctionSchema {
});
}
c10::optional<int> argumentIndexWithName(const std::string& name) const {
c10::optional<int> argumentIndexWithName(c10::string_view name) const {
for(size_t i = 0; i < arguments().size(); ++i) {
if(name == arguments()[i].name())
return i;

View File

@ -831,7 +831,7 @@ std::function<void(const OperatorDef&)> GetOperatorLogger() {
}
c10::optional<int> OperatorBase::argumentIndexWithName(
const std::string& name) const {
c10::string_view name) const {
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
return getFunctionSchema().argumentIndexWithName(name);

View File

@ -15,6 +15,7 @@
#include <c10/macros/Macros.h>
#include <c10/util/Registry.h>
#include <c10/util/string_view.h>
#include <c10/util/typeid.h>
#include <c10/core/Stream.h>
#include "caffe2/core/blob.h"
@ -97,7 +98,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
/** @brief Checks if the operator has an argument of the given name.
*/
inline bool HasArgument(const string& name) const {
inline bool HasArgument(c10::string_view name) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::HasArgument(*operator_def_, name);
@ -108,7 +109,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
// Functions that deal with arguments. Basically, this allows us to map an
// argument name to a specific type of argument that we are trying to access.
template <typename T>
inline T GetSingleArgument(const string& name, const T& default_value) const {
inline T GetSingleArgument(c10::string_view name, const T& default_value) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
@ -126,7 +127,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
}
template <typename T>
inline bool HasSingleArgumentOfType(const string& name) const {
inline bool HasSingleArgumentOfType(c10::string_view name) const {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
*operator_def_, name);
@ -141,7 +142,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
template <typename T>
inline vector<T> GetRepeatedArgument(
const string& name,
c10::string_view name,
const vector<T>& default_value = {}) const;
// Get the inputs and outputs as specific types.
@ -654,7 +655,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
}
}
c10::optional<int> argumentIndexWithName(const std::string& name) const;
c10::optional<int> argumentIndexWithName(c10::string_view name) const;
// An event used by asynchronous execution.
std::unique_ptr<Event> event_;
@ -664,7 +665,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
template <>
inline NetDef OperatorBase::GetSingleArgument<NetDef>(
const std::string& name,
c10::string_view name,
const NetDef& default_value) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
@ -756,7 +757,7 @@ inline vector<int16_t> OperatorBase::GetVectorFromIValueList<int16_t>(
template <typename T>
inline vector<T> OperatorBase::GetRepeatedArgument(
const string& name,
c10::string_view name,
const vector<T>& default_value) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
@ -778,7 +779,7 @@ inline vector<T> OperatorBase::GetRepeatedArgument(
// int16_t. We need to load it as List<int64_t> and transform to int16_t.
template <>
inline vector<int16_t> OperatorBase::GetRepeatedArgument<int16_t>(
const string& name,
c10::string_view name,
const vector<int16_t>& default_value) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");

View File

@ -323,8 +323,12 @@ C10_EXPORT ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
}
}
C10_EXPORT bool ArgumentHelper::HasArgument(const string& name) const {
C10_EXPORT bool ArgumentHelper::HasArgument(c10::string_view name) const {
#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
return arg_map_.count(name);
#else
return arg_map_.count(std::string(name));
#endif
}
namespace {
@ -364,18 +368,19 @@ std::ostream& operator<<(std::ostream& output, const NetDef& n) {
T, fieldname, enforce_lossless_conversion) \
template <> \
C10_EXPORT T ArgumentHelper::GetSingleArgument<T>( \
const string& name, const T& default_value) const { \
if (arg_map_.count(name) == 0) { \
c10::string_view name, const T& default_value) const { \
auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
if (it == arg_map_.end()) { \
VLOG(1) << "Using default parameter value " << default_value \
<< " for parameter " << name; \
return default_value; \
} \
CAFFE_ENFORCE( \
arg_map_.at(name).has_##fieldname(), \
it->second.has_##fieldname(), \
"Argument ", \
name, \
" does not have the right field: expected field " #fieldname); \
auto value = arg_map_.at(name).fieldname(); \
auto value = it->second.fieldname(); \
if (enforce_lossless_conversion) { \
auto supportsConversion = \
SupportsLosslessConversion<decltype(value), T>(value); \
@ -391,11 +396,12 @@ std::ostream& operator<<(std::ostream& output, const NetDef& n) {
} \
template <> \
C10_EXPORT bool ArgumentHelper::HasSingleArgumentOfType<T>( \
const string& name) const { \
if (arg_map_.count(name) == 0) { \
c10::string_view name) const { \
auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
if (it == arg_map_.end()) { \
return false; \
} \
return arg_map_.at(name).has_##fieldname(); \
return it->second.has_##fieldname(); \
}
INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false)
@ -415,13 +421,14 @@ INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false)
#define INSTANTIATE_GET_REPEATED_ARGUMENT( \
T, fieldname, enforce_lossless_conversion) \
template <> \
C10_EXPORT std::vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
const string& name, const std::vector<T>& default_value) const { \
if (arg_map_.count(name) == 0) { \
C10_EXPORT std::vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
c10::string_view name, const std::vector<T>& default_value) const { \
auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
if (it == arg_map_.end()) { \
return default_value; \
} \
std::vector<T> values; \
for (const auto& v : arg_map_.at(name).fieldname()) { \
std::vector<T> values; \
for (const auto& v : it->second.fieldname()) { \
if (enforce_lossless_conversion) { \
auto supportsConversion = \
SupportsLosslessConversion<decltype(v), T>(v); \
@ -531,7 +538,7 @@ C10_EXPORT bool HasInput(const OperatorDef& op, const std::string& input) {
// Return the argument index or -1 if it does not exist.
C10_EXPORT int GetArgumentIndex(
const google::protobuf::RepeatedPtrField<Argument>& args,
const string& name) {
c10::string_view name) {
int index = 0;
for (const Argument& arg : args) {
if (arg.name() == name) {
@ -544,7 +551,7 @@ C10_EXPORT int GetArgumentIndex(
C10_EXPORT const Argument& GetArgument(
const OperatorDef& def,
const string& name) {
c10::string_view name) {
int index = GetArgumentIndex(def.arg(), name);
if (index != -1) {
return def.arg(index);
@ -557,7 +564,7 @@ C10_EXPORT const Argument& GetArgument(
}
}
C10_EXPORT const Argument& GetArgument(const NetDef& def, const string& name) {
C10_EXPORT const Argument& GetArgument(const NetDef& def, c10::string_view name) {
int index = GetArgumentIndex(def.arg(), name);
if (index != -1) {
return def.arg(index);
@ -572,7 +579,7 @@ C10_EXPORT const Argument& GetArgument(const NetDef& def, const string& name) {
C10_EXPORT const Argument* GetArgumentPtr(
const OperatorDef& def,
const string& name) {
c10::string_view name) {
int index = GetArgumentIndex(def.arg(), name);
if (index != -1) {
return &def.arg(index);
@ -583,7 +590,7 @@ C10_EXPORT const Argument* GetArgumentPtr(
C10_EXPORT const Argument* GetArgumentPtr(
const NetDef& def,
const string& name) {
c10::string_view name) {
int index = GetArgumentIndex(def.arg(), name);
if (index != -1) {
return &def.arg(index);
@ -594,7 +601,7 @@ C10_EXPORT const Argument* GetArgumentPtr(
C10_EXPORT bool GetFlagArgument(
const google::protobuf::RepeatedPtrField<Argument>& args,
const string& name,
c10::string_view name,
bool default_value) {
int index = GetArgumentIndex(args, name);
if (index != -1) {
@ -609,13 +616,13 @@ C10_EXPORT bool GetFlagArgument(
C10_EXPORT bool GetFlagArgument(
const OperatorDef& def,
const string& name,
c10::string_view name,
bool default_value) {
return GetFlagArgument(def.arg(), name, default_value);
}
C10_EXPORT bool
GetFlagArgument(const NetDef& def, const string& name, bool default_value) {
GetFlagArgument(const NetDef& def, c10::string_view name, bool default_value) {
return GetFlagArgument(def.arg(), name, default_value);
}

View File

@ -8,10 +8,18 @@
#endif // !CAFFE2_USE_LITE_PROTO
#include <c10/util/Logging.h>
#include <c10/util/string_view.h>
#include "caffe2/utils/proto_wrap.h"
#include "caffe2/proto/caffe2_pb.h"
#ifndef C10_ANDROID
#define CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
#define CAFFE2_ARG_MAP_FIND(map, key) map.find(key)
#else
#define CAFFE2_ARG_MAP_FIND(map, key) map.find(std::string(key))
#endif
namespace caffe2 {
using std::string;
@ -204,40 +212,40 @@ TORCH_API bool HasInput(const OperatorDef& op, const std::string& input);
class C10_EXPORT ArgumentHelper {
public:
template <typename Def>
static bool HasArgument(const Def& def, const string& name) {
static bool HasArgument(const Def& def, c10::string_view name) {
return ArgumentHelper(def).HasArgument(name);
}
template <typename Def, typename T>
static T GetSingleArgument(
const Def& def,
const string& name,
c10::string_view name,
const T& default_value) {
return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
}
template <typename Def, typename T>
static bool HasSingleArgumentOfType(const Def& def, const string& name) {
static bool HasSingleArgumentOfType(const Def& def, c10::string_view name) {
return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
}
template <typename Def, typename T>
static std::vector<T> GetRepeatedArgument(
const Def& def,
const string& name,
c10::string_view name,
const std::vector<T>& default_value = std::vector<T>()) {
return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
}
template <typename Def, typename MessageType>
static MessageType GetMessageArgument(const Def& def, const string& name) {
static MessageType GetMessageArgument(const Def& def, c10::string_view name) {
return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
}
template <typename Def, typename MessageType>
static std::vector<MessageType> GetRepeatedMessageArgument(
const Def& def,
const string& name) {
c10::string_view name) {
return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
}
@ -255,24 +263,25 @@ class C10_EXPORT ArgumentHelper {
explicit ArgumentHelper(const OperatorDef& def);
explicit ArgumentHelper(const NetDef& netdef);
bool HasArgument(const string& name) const;
bool HasArgument(c10::string_view name) const;
template <typename T>
T GetSingleArgument(const string& name, const T& default_value) const;
T GetSingleArgument(c10::string_view name, const T& default_value) const;
template <typename T>
bool HasSingleArgumentOfType(const string& name) const;
bool HasSingleArgumentOfType(c10::string_view name) const;
template <typename T>
std::vector<T> GetRepeatedArgument(
const string& name,
c10::string_view name,
const std::vector<T>& default_value = std::vector<T>()) const;
template <typename MessageType>
MessageType GetMessageArgument(const string& name) const {
CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
MessageType GetMessageArgument(c10::string_view name) const {
auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);
CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name);
MessageType message;
if (arg_map_.at(name).has_s()) {
if (it->second.has_s()) {
CAFFE_ENFORCE(
message.ParseFromString(arg_map_.at(name).s()),
message.ParseFromString(it->second.s()),
"Failed to parse content from the string");
} else {
VLOG(1) << "Return empty message for parameter " << name;
@ -281,42 +290,47 @@ class C10_EXPORT ArgumentHelper {
}
template <typename MessageType>
std::vector<MessageType> GetRepeatedMessageArgument(const string& name) const {
CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
std::vector<MessageType> messages(arg_map_.at(name).strings_size());
std::vector<MessageType> GetRepeatedMessageArgument(c10::string_view name) const {
auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);
CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name);
std::vector<MessageType> messages(it->second.strings_size());
for (int i = 0; i < messages.size(); ++i) {
CAFFE_ENFORCE(
messages[i].ParseFromString(arg_map_.at(name).strings(i)),
messages[i].ParseFromString(it->second.strings(i)),
"Failed to parse content from the string");
}
return messages;
}
private:
std::map<string, Argument> arg_map_;
std::map<string, Argument
#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
, std::less<>
#endif
> arg_map_;
};
// **** Arguments Utils *****
// Helper methods to get an argument from OperatorDef or NetDef given argument
// name. Throws if argument does not exist.
TORCH_API const Argument& GetArgument(const OperatorDef& def, const string& name);
TORCH_API const Argument& GetArgument(const NetDef& def, const string& name);
TORCH_API const Argument& GetArgument(const OperatorDef& def, c10::string_view name);
TORCH_API const Argument& GetArgument(const NetDef& def, c10::string_view name);
// Helper methods to get an argument from OperatorDef or NetDef given argument
// name. Returns nullptr if argument does not exist.
TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, const string& name);
TORCH_API const Argument* GetArgumentPtr(const NetDef& def, const string& name);
TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, c10::string_view name);
TORCH_API const Argument* GetArgumentPtr(const NetDef& def, c10::string_view name);
// Helper methods to query a boolean argument flag from OperatorDef or NetDef
// given argument name. If argument does not exist, return default value.
// Throws if argument exists but the type is not boolean.
TORCH_API bool GetFlagArgument(
const OperatorDef& def,
const string& name,
c10::string_view name,
bool default_value = false);
TORCH_API bool GetFlagArgument(
const NetDef& def,
const string& name,
c10::string_view name,
bool default_value = false);
TORCH_API Argument* GetMutableArgument(