mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Caffe2] Create fewer strings during argument fetching (#64285)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64285 With C++14 heterogeneous ordered container lookup, it is no longer necessary to create a `std::string` in order to look up elements of a `CaffeMap` keyed by std::string. Accordingly, this diff reworks the argument-getting operator functions to avoid that in favor of `c10::string_view`. ghstack-source-id: 137139818 ghstack-source-id: 137139818 Test Plan: buildsizebot iOS apps -- code size win. less strings is probably marginally good for perf but this only happens at setup time anyway. Reviewed By: dzhulgakov Differential Revision: D26826676 fbshipit-source-id: ee653b14dc2c528bae8c90f0fc6a7a419cbca1d6
This commit is contained in:
committed by
Facebook GitHub Bot
parent
468001600c
commit
03a58a2ba0
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <c10/util/string_view.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/interned_strings.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
@ -272,7 +273,7 @@ struct FunctionSchema {
|
||||
});
|
||||
}
|
||||
|
||||
c10::optional<int> argumentIndexWithName(const std::string& name) const {
|
||||
c10::optional<int> argumentIndexWithName(c10::string_view name) const {
|
||||
for(size_t i = 0; i < arguments().size(); ++i) {
|
||||
if(name == arguments()[i].name())
|
||||
return i;
|
||||
|
@ -831,7 +831,7 @@ std::function<void(const OperatorDef&)> GetOperatorLogger() {
|
||||
}
|
||||
|
||||
c10::optional<int> OperatorBase::argumentIndexWithName(
|
||||
const std::string& name) const {
|
||||
c10::string_view name) const {
|
||||
#if defined(EXPOSE_C2_OPS) || \
|
||||
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
|
||||
return getFunctionSchema().argumentIndexWithName(name);
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Registry.h>
|
||||
#include <c10/util/string_view.h>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/core/Stream.h>
|
||||
#include "caffe2/core/blob.h"
|
||||
@ -97,7 +98,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
|
||||
|
||||
/** @brief Checks if the operator has an argument of the given name.
|
||||
*/
|
||||
inline bool HasArgument(const string& name) const {
|
||||
inline bool HasArgument(c10::string_view name) const {
|
||||
if (isLegacyOperator()) {
|
||||
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
|
||||
return ArgumentHelper::HasArgument(*operator_def_, name);
|
||||
@ -108,7 +109,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
|
||||
// Functions that deal with arguments. Basically, this allows us to map an
|
||||
// argument name to a specific type of argument that we are trying to access.
|
||||
template <typename T>
|
||||
inline T GetSingleArgument(const string& name, const T& default_value) const {
|
||||
inline T GetSingleArgument(c10::string_view name, const T& default_value) const {
|
||||
if (isLegacyOperator()) {
|
||||
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
|
||||
return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
|
||||
@ -126,7 +127,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool HasSingleArgumentOfType(const string& name) const {
|
||||
inline bool HasSingleArgumentOfType(c10::string_view name) const {
|
||||
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
|
||||
return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
|
||||
*operator_def_, name);
|
||||
@ -141,7 +142,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
|
||||
|
||||
template <typename T>
|
||||
inline vector<T> GetRepeatedArgument(
|
||||
const string& name,
|
||||
c10::string_view name,
|
||||
const vector<T>& default_value = {}) const;
|
||||
|
||||
// Get the inputs and outputs as specific types.
|
||||
@ -654,7 +655,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
|
||||
}
|
||||
}
|
||||
|
||||
c10::optional<int> argumentIndexWithName(const std::string& name) const;
|
||||
c10::optional<int> argumentIndexWithName(c10::string_view name) const;
|
||||
|
||||
// An event used by asynchronous execution.
|
||||
std::unique_ptr<Event> event_;
|
||||
@ -664,7 +665,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
|
||||
|
||||
template <>
|
||||
inline NetDef OperatorBase::GetSingleArgument<NetDef>(
|
||||
const std::string& name,
|
||||
c10::string_view name,
|
||||
const NetDef& default_value) const {
|
||||
if (isLegacyOperator()) {
|
||||
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
|
||||
@ -756,7 +757,7 @@ inline vector<int16_t> OperatorBase::GetVectorFromIValueList<int16_t>(
|
||||
|
||||
template <typename T>
|
||||
inline vector<T> OperatorBase::GetRepeatedArgument(
|
||||
const string& name,
|
||||
c10::string_view name,
|
||||
const vector<T>& default_value) const {
|
||||
if (isLegacyOperator()) {
|
||||
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
|
||||
@ -778,7 +779,7 @@ inline vector<T> OperatorBase::GetRepeatedArgument(
|
||||
// int16_t. We need to load it as List<int64_t> and transform to int16_t.
|
||||
template <>
|
||||
inline vector<int16_t> OperatorBase::GetRepeatedArgument<int16_t>(
|
||||
const string& name,
|
||||
c10::string_view name,
|
||||
const vector<int16_t>& default_value) const {
|
||||
if (isLegacyOperator()) {
|
||||
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
|
||||
|
@ -323,8 +323,12 @@ C10_EXPORT ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
|
||||
}
|
||||
}
|
||||
|
||||
C10_EXPORT bool ArgumentHelper::HasArgument(const string& name) const {
|
||||
C10_EXPORT bool ArgumentHelper::HasArgument(c10::string_view name) const {
|
||||
#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
|
||||
return arg_map_.count(name);
|
||||
#else
|
||||
return arg_map_.count(std::string(name));
|
||||
#endif
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -364,18 +368,19 @@ std::ostream& operator<<(std::ostream& output, const NetDef& n) {
|
||||
T, fieldname, enforce_lossless_conversion) \
|
||||
template <> \
|
||||
C10_EXPORT T ArgumentHelper::GetSingleArgument<T>( \
|
||||
const string& name, const T& default_value) const { \
|
||||
if (arg_map_.count(name) == 0) { \
|
||||
c10::string_view name, const T& default_value) const { \
|
||||
auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
|
||||
if (it == arg_map_.end()) { \
|
||||
VLOG(1) << "Using default parameter value " << default_value \
|
||||
<< " for parameter " << name; \
|
||||
return default_value; \
|
||||
} \
|
||||
CAFFE_ENFORCE( \
|
||||
arg_map_.at(name).has_##fieldname(), \
|
||||
it->second.has_##fieldname(), \
|
||||
"Argument ", \
|
||||
name, \
|
||||
" does not have the right field: expected field " #fieldname); \
|
||||
auto value = arg_map_.at(name).fieldname(); \
|
||||
auto value = it->second.fieldname(); \
|
||||
if (enforce_lossless_conversion) { \
|
||||
auto supportsConversion = \
|
||||
SupportsLosslessConversion<decltype(value), T>(value); \
|
||||
@ -391,11 +396,12 @@ std::ostream& operator<<(std::ostream& output, const NetDef& n) {
|
||||
} \
|
||||
template <> \
|
||||
C10_EXPORT bool ArgumentHelper::HasSingleArgumentOfType<T>( \
|
||||
const string& name) const { \
|
||||
if (arg_map_.count(name) == 0) { \
|
||||
c10::string_view name) const { \
|
||||
auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
|
||||
if (it == arg_map_.end()) { \
|
||||
return false; \
|
||||
} \
|
||||
return arg_map_.at(name).has_##fieldname(); \
|
||||
return it->second.has_##fieldname(); \
|
||||
}
|
||||
|
||||
INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false)
|
||||
@ -415,13 +421,14 @@ INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false)
|
||||
#define INSTANTIATE_GET_REPEATED_ARGUMENT( \
|
||||
T, fieldname, enforce_lossless_conversion) \
|
||||
template <> \
|
||||
C10_EXPORT std::vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
|
||||
const string& name, const std::vector<T>& default_value) const { \
|
||||
if (arg_map_.count(name) == 0) { \
|
||||
C10_EXPORT std::vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
|
||||
c10::string_view name, const std::vector<T>& default_value) const { \
|
||||
auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
|
||||
if (it == arg_map_.end()) { \
|
||||
return default_value; \
|
||||
} \
|
||||
std::vector<T> values; \
|
||||
for (const auto& v : arg_map_.at(name).fieldname()) { \
|
||||
std::vector<T> values; \
|
||||
for (const auto& v : it->second.fieldname()) { \
|
||||
if (enforce_lossless_conversion) { \
|
||||
auto supportsConversion = \
|
||||
SupportsLosslessConversion<decltype(v), T>(v); \
|
||||
@ -531,7 +538,7 @@ C10_EXPORT bool HasInput(const OperatorDef& op, const std::string& input) {
|
||||
// Return the argument index or -1 if it does not exist.
|
||||
C10_EXPORT int GetArgumentIndex(
|
||||
const google::protobuf::RepeatedPtrField<Argument>& args,
|
||||
const string& name) {
|
||||
c10::string_view name) {
|
||||
int index = 0;
|
||||
for (const Argument& arg : args) {
|
||||
if (arg.name() == name) {
|
||||
@ -544,7 +551,7 @@ C10_EXPORT int GetArgumentIndex(
|
||||
|
||||
C10_EXPORT const Argument& GetArgument(
|
||||
const OperatorDef& def,
|
||||
const string& name) {
|
||||
c10::string_view name) {
|
||||
int index = GetArgumentIndex(def.arg(), name);
|
||||
if (index != -1) {
|
||||
return def.arg(index);
|
||||
@ -557,7 +564,7 @@ C10_EXPORT const Argument& GetArgument(
|
||||
}
|
||||
}
|
||||
|
||||
C10_EXPORT const Argument& GetArgument(const NetDef& def, const string& name) {
|
||||
C10_EXPORT const Argument& GetArgument(const NetDef& def, c10::string_view name) {
|
||||
int index = GetArgumentIndex(def.arg(), name);
|
||||
if (index != -1) {
|
||||
return def.arg(index);
|
||||
@ -572,7 +579,7 @@ C10_EXPORT const Argument& GetArgument(const NetDef& def, const string& name) {
|
||||
|
||||
C10_EXPORT const Argument* GetArgumentPtr(
|
||||
const OperatorDef& def,
|
||||
const string& name) {
|
||||
c10::string_view name) {
|
||||
int index = GetArgumentIndex(def.arg(), name);
|
||||
if (index != -1) {
|
||||
return &def.arg(index);
|
||||
@ -583,7 +590,7 @@ C10_EXPORT const Argument* GetArgumentPtr(
|
||||
|
||||
C10_EXPORT const Argument* GetArgumentPtr(
|
||||
const NetDef& def,
|
||||
const string& name) {
|
||||
c10::string_view name) {
|
||||
int index = GetArgumentIndex(def.arg(), name);
|
||||
if (index != -1) {
|
||||
return &def.arg(index);
|
||||
@ -594,7 +601,7 @@ C10_EXPORT const Argument* GetArgumentPtr(
|
||||
|
||||
C10_EXPORT bool GetFlagArgument(
|
||||
const google::protobuf::RepeatedPtrField<Argument>& args,
|
||||
const string& name,
|
||||
c10::string_view name,
|
||||
bool default_value) {
|
||||
int index = GetArgumentIndex(args, name);
|
||||
if (index != -1) {
|
||||
@ -609,13 +616,13 @@ C10_EXPORT bool GetFlagArgument(
|
||||
|
||||
C10_EXPORT bool GetFlagArgument(
|
||||
const OperatorDef& def,
|
||||
const string& name,
|
||||
c10::string_view name,
|
||||
bool default_value) {
|
||||
return GetFlagArgument(def.arg(), name, default_value);
|
||||
}
|
||||
|
||||
C10_EXPORT bool
|
||||
GetFlagArgument(const NetDef& def, const string& name, bool default_value) {
|
||||
GetFlagArgument(const NetDef& def, c10::string_view name, bool default_value) {
|
||||
return GetFlagArgument(def.arg(), name, default_value);
|
||||
}
|
||||
|
||||
|
@ -8,10 +8,18 @@
|
||||
#endif // !CAFFE2_USE_LITE_PROTO
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include <c10/util/string_view.h>
|
||||
|
||||
#include "caffe2/utils/proto_wrap.h"
|
||||
#include "caffe2/proto/caffe2_pb.h"
|
||||
|
||||
#ifndef C10_ANDROID
|
||||
#define CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
|
||||
#define CAFFE2_ARG_MAP_FIND(map, key) map.find(key)
|
||||
#else
|
||||
#define CAFFE2_ARG_MAP_FIND(map, key) map.find(std::string(key))
|
||||
#endif
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
using std::string;
|
||||
@ -204,40 +212,40 @@ TORCH_API bool HasInput(const OperatorDef& op, const std::string& input);
|
||||
class C10_EXPORT ArgumentHelper {
|
||||
public:
|
||||
template <typename Def>
|
||||
static bool HasArgument(const Def& def, const string& name) {
|
||||
static bool HasArgument(const Def& def, c10::string_view name) {
|
||||
return ArgumentHelper(def).HasArgument(name);
|
||||
}
|
||||
|
||||
template <typename Def, typename T>
|
||||
static T GetSingleArgument(
|
||||
const Def& def,
|
||||
const string& name,
|
||||
c10::string_view name,
|
||||
const T& default_value) {
|
||||
return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
|
||||
}
|
||||
|
||||
template <typename Def, typename T>
|
||||
static bool HasSingleArgumentOfType(const Def& def, const string& name) {
|
||||
static bool HasSingleArgumentOfType(const Def& def, c10::string_view name) {
|
||||
return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
|
||||
}
|
||||
|
||||
template <typename Def, typename T>
|
||||
static std::vector<T> GetRepeatedArgument(
|
||||
const Def& def,
|
||||
const string& name,
|
||||
c10::string_view name,
|
||||
const std::vector<T>& default_value = std::vector<T>()) {
|
||||
return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
|
||||
}
|
||||
|
||||
template <typename Def, typename MessageType>
|
||||
static MessageType GetMessageArgument(const Def& def, const string& name) {
|
||||
static MessageType GetMessageArgument(const Def& def, c10::string_view name) {
|
||||
return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
|
||||
}
|
||||
|
||||
template <typename Def, typename MessageType>
|
||||
static std::vector<MessageType> GetRepeatedMessageArgument(
|
||||
const Def& def,
|
||||
const string& name) {
|
||||
c10::string_view name) {
|
||||
return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
|
||||
}
|
||||
|
||||
@ -255,24 +263,25 @@ class C10_EXPORT ArgumentHelper {
|
||||
|
||||
explicit ArgumentHelper(const OperatorDef& def);
|
||||
explicit ArgumentHelper(const NetDef& netdef);
|
||||
bool HasArgument(const string& name) const;
|
||||
bool HasArgument(c10::string_view name) const;
|
||||
|
||||
template <typename T>
|
||||
T GetSingleArgument(const string& name, const T& default_value) const;
|
||||
T GetSingleArgument(c10::string_view name, const T& default_value) const;
|
||||
template <typename T>
|
||||
bool HasSingleArgumentOfType(const string& name) const;
|
||||
bool HasSingleArgumentOfType(c10::string_view name) const;
|
||||
template <typename T>
|
||||
std::vector<T> GetRepeatedArgument(
|
||||
const string& name,
|
||||
c10::string_view name,
|
||||
const std::vector<T>& default_value = std::vector<T>()) const;
|
||||
|
||||
template <typename MessageType>
|
||||
MessageType GetMessageArgument(const string& name) const {
|
||||
CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
|
||||
MessageType GetMessageArgument(c10::string_view name) const {
|
||||
auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);
|
||||
CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name);
|
||||
MessageType message;
|
||||
if (arg_map_.at(name).has_s()) {
|
||||
if (it->second.has_s()) {
|
||||
CAFFE_ENFORCE(
|
||||
message.ParseFromString(arg_map_.at(name).s()),
|
||||
message.ParseFromString(it->second.s()),
|
||||
"Failed to parse content from the string");
|
||||
} else {
|
||||
VLOG(1) << "Return empty message for parameter " << name;
|
||||
@ -281,42 +290,47 @@ class C10_EXPORT ArgumentHelper {
|
||||
}
|
||||
|
||||
template <typename MessageType>
|
||||
std::vector<MessageType> GetRepeatedMessageArgument(const string& name) const {
|
||||
CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
|
||||
std::vector<MessageType> messages(arg_map_.at(name).strings_size());
|
||||
std::vector<MessageType> GetRepeatedMessageArgument(c10::string_view name) const {
|
||||
auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);
|
||||
CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name);
|
||||
std::vector<MessageType> messages(it->second.strings_size());
|
||||
for (int i = 0; i < messages.size(); ++i) {
|
||||
CAFFE_ENFORCE(
|
||||
messages[i].ParseFromString(arg_map_.at(name).strings(i)),
|
||||
messages[i].ParseFromString(it->second.strings(i)),
|
||||
"Failed to parse content from the string");
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<string, Argument> arg_map_;
|
||||
std::map<string, Argument
|
||||
#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
|
||||
, std::less<>
|
||||
#endif
|
||||
> arg_map_;
|
||||
};
|
||||
|
||||
// **** Arguments Utils *****
|
||||
|
||||
// Helper methods to get an argument from OperatorDef or NetDef given argument
|
||||
// name. Throws if argument does not exist.
|
||||
TORCH_API const Argument& GetArgument(const OperatorDef& def, const string& name);
|
||||
TORCH_API const Argument& GetArgument(const NetDef& def, const string& name);
|
||||
TORCH_API const Argument& GetArgument(const OperatorDef& def, c10::string_view name);
|
||||
TORCH_API const Argument& GetArgument(const NetDef& def, c10::string_view name);
|
||||
// Helper methods to get an argument from OperatorDef or NetDef given argument
|
||||
// name. Returns nullptr if argument does not exist.
|
||||
TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, const string& name);
|
||||
TORCH_API const Argument* GetArgumentPtr(const NetDef& def, const string& name);
|
||||
TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, c10::string_view name);
|
||||
TORCH_API const Argument* GetArgumentPtr(const NetDef& def, c10::string_view name);
|
||||
|
||||
// Helper methods to query a boolean argument flag from OperatorDef or NetDef
|
||||
// given argument name. If argument does not exist, return default value.
|
||||
// Throws if argument exists but the type is not boolean.
|
||||
TORCH_API bool GetFlagArgument(
|
||||
const OperatorDef& def,
|
||||
const string& name,
|
||||
c10::string_view name,
|
||||
bool default_value = false);
|
||||
TORCH_API bool GetFlagArgument(
|
||||
const NetDef& def,
|
||||
const string& name,
|
||||
c10::string_view name,
|
||||
bool default_value = false);
|
||||
|
||||
TORCH_API Argument* GetMutableArgument(
|
||||
|
Reference in New Issue
Block a user