Eliminate c10::guts::to_string (#108480)
This PR replaces c10::guts::to_string with std::to_string. The bulk of the change switches the optimizer state key from a stringified TensorImpl address to the void* pointer itself, since the string form is needed only for serialization and hashing a pointer is cheaper than hashing a string. A few other guts helpers in the affected source files (guts::disjunction, guts::negation) are replaced with their std:: equivalents as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/108480 Approved by: https://github.com/Skylion007
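For context on the key-type change: a minimal sketch of why a pointer key beats a stringified pointer, using std::unordered_map in place of the ska::flat_hash_map the optimizers actually use; ParamState and touch() are illustrative names, not part of this commit.

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>

struct ParamState { int64_t step = 0; };  // stand-in for OptimizerParamState

std::unordered_map<std::string, std::unique_ptr<ParamState>> by_string;
std::unordered_map<void*, std::unique_ptr<ParamState>> by_pointer;

void touch(void* impl) {
  // Old scheme: every lookup first formats a heap-allocated decimal string,
  // then hashes it byte by byte.
  by_string[std::to_string(reinterpret_cast<size_t>(impl))] = nullptr;
  // New scheme: the pointer itself is the key; hashing a machine word is
  // constant-time and allocates nothing.
  by_pointer[impl] = nullptr;
}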
@@ -1,14 +1,13 @@
 #include <ATen/MapAllocator.h>

 #include <atomic>
-#include <string>
 #include <random>
+#include <string>
 #if ATOMIC_INT_LOCK_FREE == 2
 #define AT_ATOMIC_IPC_REFCOUNT 1
 #endif

 #include <c10/core/CPUAllocator.h>
-#include <c10/util/C++17.h>
 #include <c10/util/Unicode.h>

 /* stuff for mapped files */
@@ -17,9 +16,9 @@
 #endif

 #if defined(HAVE_MMAP)
+#include <fcntl.h>
 #include <sys/mman.h>
 #include <sys/stat.h>
-#include <fcntl.h>
 #endif

 #if !defined(_MSC_VER) || defined(HAVE_MMAP)
@@ -28,28 +27,29 @@
 #elif defined(_MSC_VER)
 #include <c10/util/win32-headers.h>
 #endif
+#include <fmt/format.h>

 namespace at {

 static constexpr int64_t map_alloc_alignment = 64;

-TORCH_API std::string NewProcessWideShmHandle()
-{
+std::string NewProcessWideShmHandle() {
   static std::atomic<uint64_t> counter{0};
   static std::random_device rd;
-  std::string handle = "/torch_";
 #ifdef _MSC_VER
-  handle += c10::guts::to_string(GetCurrentProcessId());
+  return fmt::format(
+      "/torch_{}_{}_{}",
+      GetCurrentProcessId(),
+      rd(),
+      counter.fetch_add(1, std::memory_order_relaxed));
 #else
-  handle += c10::guts::to_string(getpid());
+  return fmt::format(
+      "/torch_{}_{}_{}",
+      getpid(),
+      rd(),
+      counter.fetch_add(1, std::memory_order_relaxed));
 #endif
-  handle += "_";
-  handle += c10::guts::to_string(rd());
-  handle += "_";
-  handle += c10::guts::to_string(counter.fetch_add(1, std::memory_order_relaxed));
-  return handle;
 }

 #if defined(_WIN32) || defined(HAVE_MMAP)

 namespace {
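A rough usage sketch of the rewritten helper (the pid, random, and counter values below are invented): one fmt::format call now produces the handle the old code assembled through five string appends.

#include <fmt/format.h>
#include <cstdint>
#include <iostream>
#include <string>

int main() {
  uint64_t pid = 4242, rnd = 1700580157, seq = 0;  // illustrative values only
  // Same shape as NewProcessWideShmHandle(): /torch_<pid>_<random>_<counter>
  std::string handle = fmt::format("/torch_{}_{}_{}", pid, rnd, seq);
  std::cout << handle << '\n';  // prints /torch_4242_1700580157_0
}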
@@ -907,7 +907,7 @@ TEST(OperatorRegistrationTestLegacyFunctionBasedKernel, givenKernelWithOptionalI
 }

 std::string concatKernel(const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
-  return a + b + c10::guts::to_string(c);
+  return a + b + std::to_string(c);
 }

 void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
@@ -649,7 +649,7 @@ TEST(OperatorRegistrationTestFunctionBasedKernel, givenKernelWithOptionalInputs_
 }

 std::string concatKernel(const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
-  return a + b + c10::guts::to_string(c);
+  return a + b + std::to_string(c);
 }

 void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
@@ -854,7 +854,7 @@ void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
 TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) {
   std::string prefix = "prefix";
   auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", [&] (const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
-    return prefix + a + b + c10::guts::to_string(c);
+    return prefix + a + b + std::to_string(c);
   });
   expectCallsConcatUnboxed(DispatchKey::CPU);
 }
@@ -576,7 +576,7 @@ void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
 TEST(OperatorRegistrationTestLambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) {
   auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", torch::RegisterOperators::options()
     .kernel(DispatchKey::CPU, [] (const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
-      return a + b + c10::guts::to_string(c);
+      return a + b + std::to_string(c);
     }));
   expectCallsConcatUnboxed(DispatchKey::CPU);
 }
@@ -787,7 +787,7 @@ struct ConcatKernel final : OperatorKernel {
   explicit ConcatKernel(std::string prefix): prefix_(std::move(prefix)) {}

   std::string operator()(const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
-    return prefix_ + a + b + c10::guts::to_string(c);
+    return prefix_ + a + b + std::to_string(c);
   }

   std::string prefix_;
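These test kernels only ever stringify an int64_t, and std::to_string has provided a long long overload since C++11, so the substitution is a drop-in replacement. A self-contained sketch of the same pattern (standalone, not the test file itself):

#include <cstdint>
#include <string>
#include <type_traits>

std::string concat(std::string a, const std::string& b, int64_t c) {
  // std::to_string(long long) exists since C++11; no c10 helper is needed.
  return a + b + std::to_string(c);
}

static_assert(
    std::is_same<decltype(std::to_string(int64_t{5})), std::string>::value,
    "int64_t resolves to a std::to_string overload");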
@@ -14,6 +14,7 @@
 #include <c10/util/MaybeOwned.h>
 #include <c10/util/intrusive_ptr.h>
 #include <typeindex>
+#include <type_traits>
 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
@@ -510,17 +511,17 @@ public:
   template <
       typename... Args,
       std::enable_if_t<
-          !guts::disjunction<
+          !std::disjunction<
               std::is_lvalue_reference<Args>...,
-              guts::negation<std::is_constructible<IValue, Args>>...>::value,
+              std::negation<std::is_constructible<IValue, Args>>...>::value,
           std::nullptr_t> = nullptr>
   IValue(const std::tuple<Args...>& t);
   template <
       typename... Args,
       std::enable_if_t<
-          !guts::disjunction<
+          !std::disjunction<
               std::is_lvalue_reference<Args>...,
-              guts::negation<std::is_constructible<IValue, Args>>...>::value,
+              std::negation<std::is_constructible<IValue, Args>>...>::value,
           std::nullptr_t> = nullptr>
   IValue(std::tuple<Args...>&& t);
   bool isTuple() const {
@@ -981,7 +982,7 @@ public:
     TORCH_FORALL_TAGS(DEFINE_CASE)
 #undef DEFINE_CASE
   }
-  return "InvalidTag(" + c10::guts::to_string(static_cast<int>(tag)) + ")";
+  return "InvalidTag(" + std::to_string(static_cast<int>(tag)) + ")";
 }

 // generic v.to<at::Tensor>() implementations
@@ -1048,7 +1048,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
   using IValueWithStorages = std::tuple<IValue, std::vector<WeakStorage>>;
 #if __cpp_lib_is_invocable >= 201703
   static_assert(
-      guts::disjunction<
+      std::disjunction<
           std::is_invocable_r<IValue, T, Future&>,
           std::is_invocable_r<IValueWithStorages, T, Future&>>::value,
       "The callback must have signature IValue(Future&) or "
@@ -1918,9 +1918,9 @@ template <
     typename... Args,
     typename Indices = std::make_index_sequence<sizeof...(Args)>,
     std::enable_if_t<
-        !guts::disjunction<
+        !std::disjunction<
            std::is_lvalue_reference<Args>...,
-            guts::negation<std::is_constructible<IValue, Args>>...>::value,
+            std::negation<std::is_constructible<IValue, Args>>...>::value,
        std::nullptr_t> = nullptr>
 std::tuple<Args...> generic_to(const IValue& ivalue, _fake_type<std::tuple<Args...>>) {
   const auto& vals = ivalue.toTupleRef().elements();
@@ -2098,9 +2098,9 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
 template <
     typename... Args,
     std::enable_if_t<
-        !guts::disjunction<
+        !std::disjunction<
             std::is_lvalue_reference<Args>...,
-            guts::negation<std::is_constructible<IValue, Args>>...>::value,
+            std::negation<std::is_constructible<IValue, Args>>...>::value,
         std::nullptr_t>>
 inline IValue::IValue(const std::tuple<Args...>& t)
     : IValue(c10::guts::apply(c10::ivalue::Tuple::create<const Args&...>, t)) {
@@ -2109,9 +2109,9 @@ inline IValue::IValue(const std::tuple<Args...>& t)
 template <
     typename... Args,
     std::enable_if_t<
-        !guts::disjunction<
+        !std::disjunction<
             std::is_lvalue_reference<Args>...,
-            guts::negation<std::is_constructible<IValue, Args>>...>::value,
+            std::negation<std::is_constructible<IValue, Args>>...>::value,
         std::nullptr_t>>
 inline IValue::IValue(std::tuple<Args...>&& t)
     : IValue(c10::guts::apply(c10::ivalue::Tuple::create<Args&&...>, std::move(t))) {
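guts::disjunction and guts::negation were pre-C++17 polyfills; with C++17 as the floor, std::disjunction and std::negation from <type_traits> are exact replacements. The constraint the IValue tuple constructors express, sketched in isolation with a hypothetical Value type:

#include <type_traits>

struct Value {
  Value(int);
  Value(double);
};

// Enabled only when no Args is an lvalue reference and every Args can
// construct Value -- the same shape of condition as in ivalue.h.
template <typename... Args>
constexpr bool tuple_convertible =
    !std::disjunction<
        std::is_lvalue_reference<Args>...,
        std::negation<std::is_constructible<Value, Args>>...>::value;

static_assert(tuple_convertible<int, double>, "by-value convertible args pass");
static_assert(!tuple_convertible<int&>, "lvalue references are rejected");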
@@ -1,6 +1,6 @@
 #include <ATen/core/op_registration/infer_schema.h>
 #include <c10/util/irange.h>
-#include <sstream>
+#include <fmt/format.h>

 namespace c10 {

@@ -8,46 +8,55 @@ namespace detail {
 namespace infer_schema {
 namespace {

-std::string fastToString(size_t x) {
-  if (C10_LIKELY(x < 10)) {
-    std::string result;
-    result.push_back('_');
-    result.push_back('0' + x);
-    return result;
-  }
-  return "_" + c10::guts::to_string(x);
-}
-
 std::vector<Argument> createArgumentVector(c10::ArrayRef<ArgumentDef> args) {
   std::vector<Argument> result;
   result.reserve(args.size());
   for (const auto i : c10::irange(args.size())) {
     // Arguments are named "_<index>"
-    result.emplace_back(fastToString(i), (*args[i].getFakeTypeFn)(), (*args[i].getTypeFn)());
+    result.emplace_back(
+        fmt::format("_{}", i),
+        (*args[i].getFakeTypeFn)(),
+        (*args[i].getTypeFn)());
   }
   return result;
 }
-}
+} // namespace
 // This is intentionally a separate function and in a .cpp file
 // because then the template is smaller and that benefits binary size
-FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns) {
-  return FunctionSchema(std::move(name), std::move(overload_name), createArgumentVector(arguments), createArgumentVector(returns));
+FunctionSchema make_function_schema(
+    std::string&& name,
+    std::string&& overload_name,
+    c10::ArrayRef<ArgumentDef> arguments,
+    c10::ArrayRef<ArgumentDef> returns) {
+  return FunctionSchema(
+      std::move(name),
+      std::move(overload_name),
+      createArgumentVector(arguments),
+      createArgumentVector(returns));
 }

-FunctionSchema make_function_schema(c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns) {
+FunctionSchema make_function_schema(
+    c10::ArrayRef<ArgumentDef> arguments,
+    c10::ArrayRef<ArgumentDef> returns) {
   return make_function_schema("", "", arguments, returns);
 }
-}
-}
+} // namespace infer_schema
+} // namespace detail

-c10::optional<std::string> findSchemaDifferences(const FunctionSchema& lhs, const FunctionSchema& rhs) {
+c10::optional<std::string> findSchemaDifferences(
+    const FunctionSchema& lhs,
+    const FunctionSchema& rhs) {
   if (lhs.arguments().size() != rhs.arguments().size()) {
-    return "The number of arguments is different. " + guts::to_string(lhs.arguments().size()) +
-           " vs " + guts::to_string(rhs.arguments().size()) + ".";
+    return fmt::format(
+        "The number of arguments is different. {} vs {}.",
+        lhs.arguments().size(),
+        rhs.arguments().size());
   }
   if (lhs.returns().size() != rhs.returns().size()) {
-    return "The number of returns is different. " + guts::to_string(lhs.returns().size()) +
-           " vs " + guts::to_string(rhs.returns().size());
+    return fmt::format(
+        "The number of returns is different. {} vs {}.",
+        lhs.returns().size(),
+        rhs.returns().size());
   }

   for (const auto i : c10::irange(lhs.arguments().size())) {
@@ -57,8 +66,11 @@ c10::optional<std::string> findSchemaDifferences(const FunctionSchema& lhs, cons
     // cheaper, particularly when one of the types is a singleton like
     // NumberType or AnyType.
     if (leftType.get() != rightType.get() && *leftType != *rightType) {
-      return "Type mismatch in argument " + guts::to_string(i+1) + ": " + lhs.arguments()[i].type()->str() +
-             " vs " + rhs.arguments()[i].type()->str();
+      return fmt::format(
+          "Type mismatch in argument {}: {} vs {}.",
+          i + 1,
+          lhs.arguments()[i].type()->str(),
+          rhs.arguments()[i].type()->str());
     }
   }

@@ -67,8 +79,11 @@ c10::optional<std::string> findSchemaDifferences(const FunctionSchema& lhs, cons
     const TypePtr& rightType = rhs.returns()[i].type();
     // See above about comparing pointers first.
     if (leftType.get() != rightType.get() && *leftType != *rightType) {
-      return "Type mismatch in return " + guts::to_string(i+1) + ": " + lhs.returns()[i].type()->str() +
-             " vs " + rhs.returns()[i].type()->str();
+      return fmt::format(
+          "Type mismatch in return {}: {} vs {}.",
+          i + 1,
+          lhs.returns()[i].type()->str(),
+          rhs.returns()[i].type()->str());
     }
   }

@@ -76,4 +91,4 @@ c10::optional<std::string> findSchemaDifferences(const FunctionSchema& lhs, cons
   return c10::nullopt;
 }

-}
+} // namespace c10
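The rewritten diagnostics follow the pattern below: one fmt::format call instead of a chain of operator+ concatenations, which avoids building intermediate temporary strings. A minimal sketch (the function name is illustrative, not the file's actual signature):

#include <fmt/format.h>
#include <cstddef>
#include <string>

std::string argument_count_mismatch(size_t lhs, size_t rhs) {
  // One pass over the format string; no temporaries from repeated operator+.
  return fmt::format("The number of arguments is different. {} vs {}.", lhs, rhs);
}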
@@ -214,7 +214,7 @@ TEST(OrderedPreservingDictTest, test_range_erase) {
   const int64_t nb_values = 1000;
   HMap map;
   for (const auto i : c10::irange(nb_values)) {
-    map[c10::guts::to_string(i)] = i;
+    map[std::to_string(i)] = i;
     auto begin = map.begin();
     for (int64_t j = 0; j <= i; ++j, begin++) {
       TORCH_INTERNAL_ASSERT(begin->second == j);
@@ -239,8 +239,7 @@ TEST(OrderedPreservingDictTest, test_range_erase) {
     if (i >= 10 && i < 220) {
       continue;
     }
-    auto exp_it =
-        std::pair<std::string, std::int64_t>(c10::guts::to_string(i), i);
+    auto exp_it = std::pair<std::string, std::int64_t>(std::to_string(i), i);
     TORCH_INTERNAL_ASSERT(*it == exp_it);
     ++it;
   }
@@ -313,13 +312,13 @@ TEST(OrderedPreservingDictTest, test_copy_constructor_and_operator) {
   const std::size_t nb_values = 100;
   HMap map;
   for (const auto i : c10::irange(nb_values)) {
-    map[c10::guts::to_string(i)] = c10::guts::to_string(i);
+    map[std::to_string(i)] = std::to_string(i);
   }

   HMap map_copy = map;
   HMap map_copy2(map);
   HMap map_copy3;
-  map_copy3[c10::guts::to_string(0)] = c10::guts::to_string(0);
+  map_copy3[std::to_string(0)] = std::to_string(0);

   map_copy3 = map;

@@ -15,7 +15,6 @@

 #pragma once

-#include <c10/util/C++17.h>
 #include <c10/util/Deprecated.h>
 #include <c10/util/Exception.h>
 #include <c10/util/SmallVector.h>
@@ -233,50 +233,6 @@ struct function_takes_identity_argument<
 #endif
 } // namespace detail

-// GCC 4.8 doesn't define std::to_string, even though that's in C++11. Let's
-// define it.
-namespace detail {
-class DummyClassForToString final {};
-} // namespace detail
-} // namespace guts
-} // namespace c10
-namespace std {
-// We use SFINAE to detect if std::to_string exists for a type, but that only
-// works if the function name is defined. So let's define a std::to_string for a
-// dummy type. If you're getting an error here saying that this overload doesn't
-// match your std::to_string() call, then you're calling std::to_string() but
-// should be calling c10::guts::to_string().
-inline std::string to_string(c10::guts::detail::DummyClassForToString) {
-  return "";
-}
-
-} // namespace std
-namespace c10 {
-namespace guts {
-namespace detail {
-
-template <class T, class Enable = void>
-struct to_string_ final {
-  static std::string call(T value) {
-    std::ostringstream str;
-    str << value;
-    return str.str();
-  }
-};
-// If a std::to_string exists, use that instead
-template <class T>
-struct to_string_<T, void_t<decltype(std::to_string(std::declval<T>()))>>
-    final {
-  static std::string call(T value) {
-    return std::to_string(value);
-  }
-};
-} // namespace detail
-template <class T>
-inline std::string to_string(T value) {
-  return detail::to_string_<T>::call(value);
-}
-
 } // namespace guts
 } // namespace c10

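The deleted helper picked std::to_string via SFINAE when an overload existed and otherwise fell back to streaming through std::ostringstream. With it gone, arithmetic arguments move to std::to_string directly; any caller that relied on the stream fallback has to spell it out (or use fmt). A hedged migration sketch, with stream_to_string as a hypothetical replacement name:

#include <sstream>
#include <string>

// Arithmetic types: direct drop-in replacement for c10::guts::to_string.
std::string s = std::to_string(42);

// Stream-printable, non-arithmetic types: the old fallback written out by hand.
template <class T>
std::string stream_to_string(const T& value) {
  std::ostringstream str;
  str << value;
  return str.str();
}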
@@ -182,8 +182,7 @@ TEST(OptimTest, OptimizerAccessors) {

   // test for state() with non-const reference return
   auto& state_ = static_cast<AdagradParamState&>(
-      *(optimizer
-            .state()[c10::guts::to_string(params_1[0].unsafeGetTensorImpl())]));
+      *(optimizer.state()[params_1[0].unsafeGetTensorImpl()]));
   state_.step(state_.step() + 1);

   const auto& optimizer_ = Adagrad(params, options);
@@ -54,9 +54,9 @@ void is_optimizer_param_group_equal(

 template <typename DerivedOptimizerParamState>
 void is_optimizer_state_equal(
-    const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
+    const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
         lhs_state,
-    const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
+    const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
         rhs_state) {
   ASSERT_TRUE(lhs_state.size() == rhs_state.size());
   for (const auto& value : lhs_state) {
@@ -188,7 +188,7 @@ void write_tensors_to_archive(
       key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
   for (const auto index : c10::irange(buffers.size())) {
     archive.write(
-        key + "/" + c10::to_string(index), buffers[index], /*is_buffer=*/true);
+        key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
   }
 }

@@ -218,7 +218,7 @@ TEST(SerializeTest, KeysFunc) {
   torch::serialize::OutputArchive output_archive;
   for (const auto i : c10::irange(3)) {
     output_archive.write(
-        "element/" + c10::to_string(i), c10::IValue(static_cast<int64_t>(i)));
+        "element/" + std::to_string(i), c10::IValue(static_cast<int64_t>(i)));
   }
   output_archive.save_to(tempfile.name);
   torch::serialize::InputArchive input_archive;
@@ -226,7 +226,7 @@ TEST(SerializeTest, KeysFunc) {
   std::vector<std::string> keys = input_archive.keys();
   ASSERT_EQ(keys.size(), 3);
   for (const auto i : c10::irange(keys.size())) {
-    ASSERT_EQ(keys[i], "element/" + c10::to_string(i));
+    ASSERT_EQ(keys[i], "element/" + std::to_string(i));
   }
 }

@@ -235,7 +235,7 @@ TEST(SerializeTest, TryReadFunc) {
   torch::serialize::OutputArchive output_archive;
   for (const auto i : c10::irange(3)) {
     output_archive.write(
-        "element/" + c10::to_string(i), c10::IValue(static_cast<int64_t>(i)));
+        "element/" + std::to_string(i), c10::IValue(static_cast<int64_t>(i)));
   }
   output_archive.save_to(tempfile.name);
   torch::serialize::InputArchive input_archive;
@@ -557,7 +557,7 @@ TEST(SerializeTest, Optim_Adagrad) {
   const auto& params_ = optim1.param_groups()[0].params();
   const auto& optim1_state = optim1.state();
   for (const auto& param : params_) {
-    auto key_ = c10::guts::to_string(param.unsafeGetTensorImpl());
+    auto key_ = param.unsafeGetTensorImpl();
     const AdagradParamState& curr_state_ =
         static_cast<const AdagradParamState&>(*(optim1_state.at(key_).get()));
     sum_buffers.emplace_back(curr_state_.sum());
@@ -602,7 +602,7 @@ TEST(SerializeTest, Optim_SGD) {
   const auto& optim1_state = optim1.state();
   for (const auto i : c10::irange(params_.size())) {
     if (i != (params_.size() - 1)) {
-      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
+      auto key_ = params_[i].unsafeGetTensorImpl();
       const SGDParamState& curr_state_ =
           static_cast<const SGDParamState&>(*(optim1_state.at(key_).get()));
       momentum_buffers.emplace_back(curr_state_.momentum_buffer());
@@ -653,7 +653,7 @@ TEST(SerializeTest, Optim_Adam) {
   const auto& optim1_state = optim1.state();
   for (const auto i : c10::irange(params_.size())) {
     if (i != (params_.size() - 1)) {
-      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
+      auto key_ = params_[i].unsafeGetTensorImpl();
       const AdamParamState& curr_state_ =
           static_cast<const AdamParamState&>(*(optim1_state.at(key_).get()));
       step_buffers.emplace_back(curr_state_.step());
@@ -712,7 +712,7 @@ TEST(SerializeTest, Optim_AdamW) {
   const auto& optim1_state = optim1.state();
   for (const auto i : c10::irange(params_.size())) {
     if (i != (params_.size() - 1)) {
-      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
+      auto key_ = params_[i].unsafeGetTensorImpl();
       const AdamWParamState& curr_state_ =
           static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
       step_buffers.emplace_back(curr_state_.step());
@@ -769,7 +769,7 @@ TEST(SerializeTest, Optim_RMSprop) {
   const auto& optim1_state = optim1.state();
   for (const auto i : c10::irange(params_.size())) {
     if (i != (params_.size() - 1)) {
-      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
+      auto key_ = params_[i].unsafeGetTensorImpl();
       const RMSpropParamState& curr_state_ =
           static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
       square_average_buffers.emplace_back(curr_state_.square_avg());
@@ -799,8 +799,8 @@ TEST(SerializeTest, Optim_RMSprop) {
   // old RMSprop didn't track step value
   for (const auto i : c10::irange(params1_2_.size())) {
     if (i != (params1_2_.size() - 1)) {
-      auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
-      auto key1_2_ = c10::guts::to_string(params1_2_[i].unsafeGetTensorImpl());
+      auto key_ = params_[i].unsafeGetTensorImpl();
+      auto key1_2_ = params1_2_[i].unsafeGetTensorImpl();
       const RMSpropParamState& curr_state_ =
           static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
       RMSpropParamState& curr_state1_2_ =
@@ -838,7 +838,7 @@ TEST(SerializeTest, Optim_LBFGS) {
   std::deque<at::Tensor> old_dirs, old_stps;

   const auto& params_ = optim1.param_groups()[0].params();
-  auto key_ = c10::guts::to_string(params_[0].unsafeGetTensorImpl());
+  auto key_ = params_[0].unsafeGetTensorImpl();
   const auto& optim1_state =
       static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get()));
   d = optim1_state.d();
@@ -865,7 +865,7 @@ TEST(SerializeTest, Optim_LBFGS) {
       torch::load, optim1_2, optim_tempfile_old_format.name);

   const auto& params1_2_ = optim1_2.param_groups()[0].params();
-  auto param_key = c10::guts::to_string(params1_2_[0].unsafeGetTensorImpl());
+  auto param_key = params1_2_[0].unsafeGetTensorImpl();
   auto& optim1_2_state =
       static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get()));

@@ -6,6 +6,7 @@
 #include <torch/serialize/archive.h>
 #include <torch/types.h>

 #include <string>
+#include <utility>
 #include <vector>

@@ -84,8 +85,7 @@ class TORCH_API Adagrad : public Optimizer {
           p.data(),
           defaults.initial_accumulator_value(),
           at::MemoryFormat::Preserve));
-      state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
-          std::move(state);
+      state_[p.unsafeGetTensorImpl()] = std::move(state);
     }
   }
 }
@@ -156,12 +156,12 @@ class TORCH_API Optimizer {
   const std::vector<OptimizerParamGroup>& param_groups() const noexcept;

   /// Provides a reference to the state this optimizer holds
-  ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
+  ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
   state() noexcept;

   /// Provides a const reference to the state this optimizer holds
-  const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
-  state() const noexcept;
+  const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& state()
+      const noexcept;

   /// Serializes the optimizer state into the given `archive`.
   virtual void save(serialize::OutputArchive& archive) const;
@@ -171,7 +171,7 @@ class TORCH_API Optimizer {

  protected:
   std::vector<OptimizerParamGroup> param_groups_;
-  ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>> state_;
+  ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_;
   std::unique_ptr<OptimizerOptions> defaults_;
 };

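With the new signature, callers index optimizer state by the raw TensorImpl pointer rather than its stringified address. A sketch of the lookup, modeled on the OptimizerAccessors test earlier in the diff (Adagrad chosen for illustration; bump_step is a hypothetical helper):

#include <torch/torch.h>

void bump_step(torch::optim::Adagrad& optimizer, const torch::Tensor& param) {
  // The key is now param.unsafeGetTensorImpl() itself (implicitly a void*).
  auto& state = static_cast<torch::optim::AdagradParamState&>(
      *optimizer.state()[param.unsafeGetTensorImpl()]);
  state.step(state.step() + 1);
}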
@@ -17,11 +17,12 @@ namespace detail {
 template <typename DerivedOptimizerParamState>
 void serialize(
     serialize::OutputArchive& archive,
-    const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
+    const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
         state) {
   for (const auto& item : state) {
     serialize::OutputArchive param_state_archive(archive.compilation_unit());
-    std::string tensorimpl_key = item.first;
+    std::string tensorimpl_key =
+        std::to_string(reinterpret_cast<size_t>(item.first));
     const DerivedOptimizerParamState& curr_state =
         static_cast<const DerivedOptimizerParamState&>(*(item.second.get()));
     curr_state.serialize(param_state_archive);
@@ -33,15 +34,14 @@ void serialize(
 template <typename DerivedOptimizerParamState>
 void serialize(
     serialize::InputArchive& archive,
-    ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
-        state) {
+    ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& state) {
   std::vector<std::string> tensorimpl_keys = archive.keys();
   for (const std::string& tensorimpl_key : tensorimpl_keys) {
     serialize::InputArchive param_state_archive;
     archive.read(tensorimpl_key, param_state_archive);
     DerivedOptimizerParamState param_state;
     param_state.serialize(param_state_archive);
-    state[tensorimpl_key] =
+    state[reinterpret_cast<void*>(std::stoull(tensorimpl_key))] =
         std::make_unique<DerivedOptimizerParamState>(param_state);
   }
 }
@@ -61,8 +61,9 @@ void serialize(
       "params/size", torch::tensor(static_cast<int64_t>(params.size())));
   for (const auto index : c10::irange(params.size())) {
     param_group_archive.write(
-        "params/" + c10::guts::to_string(index),
-        IValue(c10::guts::to_string(params[index].unsafeGetTensorImpl())));
+        "params/" + std::to_string(index),
+        IValue(std::to_string(
+            reinterpret_cast<size_t>(params[index].unsafeGetTensorImpl()))));
   }
   const DerivedOptimizerParamOptions& param_group_options =
       static_cast<const DerivedOptimizerParamOptions&>(
@@ -71,8 +72,7 @@ void serialize(
         param_group_archive.compilation_unit());
     param_group_options.serialize(param_group_options_archive);
     param_group_archive.write("options", param_group_options_archive);
-    archive.write(
-        "param_groups/" + c10::guts::to_string(i), param_group_archive);
+    archive.write("param_groups/" + std::to_string(i), param_group_archive);
   }
 }

@@ -92,15 +92,14 @@ void serialize(
   const int64_t param_groups_size = param_groups_size_tensor.item<int64_t>();
   for (const auto i : c10::irange(param_groups_size)) {
     serialize::InputArchive param_group_archive;
-    archive.read(
-        "param_groups/" + c10::guts::to_string(i), param_group_archive);
+    archive.read("param_groups/" + std::to_string(i), param_group_archive);
     torch::Tensor size_tensor;
     param_group_archive.read("params/size", size_tensor);
     const int64_t size = size_tensor.item<int64_t>();
     std::vector<std::string> params;
     for (const auto index : c10::irange(size)) {
       IValue ivalue;
-      param_group_archive.read("params/" + c10::to_string(index), ivalue);
+      param_group_archive.read("params/" + std::to_string(index), ivalue);
       std::string element = ivalue.toStringRef();
       params.emplace_back(element);
     }
@@ -170,8 +169,7 @@ void serialize(serialize::InputArchive& archive, Optimizer& optimizer) {
   TORCH_INTERNAL_ASSERT(pytorch_version.toStringRef() == "1.5.0");
   serialize::InputArchive state_archive;
   archive.read("state", state_archive);
-  ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>
-      saved_state;
+  ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> saved_state;
   detail::serialize<DerivedOptimizerParamState>(state_archive, saved_state);

   serialize::InputArchive param_groups_archive;
@@ -194,10 +192,11 @@ void serialize(serialize::InputArchive& archive, Optimizer& optimizer) {
         "loaded state dict contains a parameter group that has a different size than the optimizer's parameter group");

     for (const auto idx : c10::irange(params.size())) {
-      if (saved_state.find(param_group_old_keys[idx]) != saved_state.end()) {
-        optimizer
-            .state()[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] =
-            std::move(saved_state[param_group_old_keys[idx]]);
+      auto param_group_old_key =
+          reinterpret_cast<void*>(std::stoull(param_group_old_keys[idx]));
+      if (saved_state.find(param_group_old_key) != saved_state.end()) {
+        optimizer.state()[params[idx].unsafeGetTensorImpl()] =
+            std::move(saved_state[param_group_old_key]);
       }
     }
   }
@@ -213,7 +212,7 @@ void serialize(
       key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
   for (const auto index : c10::irange(buffers.size())) {
     archive.write(
-        key + "/" + c10::to_string(index), buffers[index], /*is_buffer=*/true);
+        key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
   }
 }

@@ -230,7 +229,7 @@ void serialize(
   for (const auto index : c10::irange(size)) {
     buffers.emplace_back();
     archive.read(
-        key + "/" + c10::to_string(index), buffers.back(), /*is_buffer=*/true);
+        key + "/" + std::to_string(index), buffers.back(), /*is_buffer=*/true);
   }
 }

@@ -67,7 +67,7 @@ void save(const std::vector<torch::Tensor>& tensor_vec, SaveToArgs&&... args) {
   serialize::OutputArchive archive(std::make_shared<jit::CompilationUnit>());
   for (const auto i : c10::irange(tensor_vec.size())) {
     auto& value = tensor_vec[i];
-    archive.write(c10::to_string(i), value);
+    archive.write(std::to_string(i), value);
   }
   archive.save_to(std::forward<SaveToArgs>(args)...);
 }
@@ -135,7 +135,7 @@ void load(std::vector<torch::Tensor>& tensor_vec, LoadFromArgs&&... args) {
   // the serialized `std::vector<torch::Tensor>`.
   size_t index = 0;
   torch::Tensor value;
-  while (archive.try_read(c10::to_string(index), value)) {
+  while (archive.try_read(std::to_string(index), value)) {
     tensor_vec.push_back(std::move(value));
     value = torch::Tensor();
     index++;
@@ -77,11 +77,11 @@ Tensor Adagrad::step(LossClosure closure) {
       }
       auto grad = p.grad();
       TORCH_INTERNAL_ASSERT(
-          state_[c10::guts::to_string(p.unsafeGetTensorImpl())] != nullptr,
+          state_[p.unsafeGetTensorImpl()] != nullptr,
           "state found NULL for the Tensor ",
           p);
-      auto& state = static_cast<AdagradParamState&>(
-          *state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
+      auto& state =
+          static_cast<AdagradParamState&>(*state_[p.unsafeGetTensorImpl()]);
       auto& options = static_cast<AdagradOptions&>(group.options());

       state.step(state.step() + 1);
@@ -147,8 +147,7 @@ void Adagrad::load(serialize::InputArchive& archive) {
       auto state = std::make_unique<AdagradParamState>();
       state->step(step_buffers[idx]);
       state->sum(sum_buffers[idx]);
-      state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] =
-          std::move(state);
+      state_[params[idx].unsafeGetTensorImpl()] = std::move(state);
     }
   }
 }
@@ -84,8 +84,7 @@ Tensor Adam::step(LossClosure closure) {
       }
       auto grad = p.grad();
       TORCH_CHECK(!grad.is_sparse(), "Adam does not support sparse gradients" /*, please consider SparseAdam instead*/);
-      auto param_state =
-          state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
+      auto param_state = state_.find(p.unsafeGetTensorImpl());
       auto& options = static_cast<AdamOptions&>(group.options());

       // State initialization
@@ -100,12 +99,11 @@ Tensor Adam::step(LossClosure closure) {
           // Maintains max of all exp. moving avg. of sq. grad. values
           state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
         }
-        state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
-            std::move(state);
+        state_[p.unsafeGetTensorImpl()] = std::move(state);
       }

-      auto& state = static_cast<AdamParamState&>(
-          *state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
+      auto& state =
+          static_cast<AdamParamState&>(*state_[p.unsafeGetTensorImpl()]);
       auto& exp_avg = state.exp_avg();
       auto& exp_avg_sq = state.exp_avg_sq();
       auto& max_exp_avg_sq = state.max_exp_avg_sq();
@@ -179,8 +177,7 @@ void Adam::load(serialize::InputArchive& archive) {
       if (idx < max_exp_average_sq_buffers.size()) {
        state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
       }
-      state_[c10::guts::to_string(params.at(idx).unsafeGetTensorImpl())] =
-          std::move(state);
+      state_[params.at(idx).unsafeGetTensorImpl()] = std::move(state);
     }
   }
 }
@@ -84,8 +84,7 @@ Tensor AdamW::step(LossClosure closure) {
       }
       const auto& grad = p.grad();
       TORCH_CHECK(!grad.is_sparse(), "AdamW does not support sparse gradients" /*, please consider SparseAdamW instead*/);
-      auto param_state =
-          state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
+      auto param_state = state_.find(p.unsafeGetTensorImpl());
       auto& options = static_cast<AdamWOptions&>(group.options());

       // Perform stepweight decay
@@ -105,12 +104,11 @@ Tensor AdamW::step(LossClosure closure) {
           // Maintains max of all exp. moving avg. of sq. grad. values
           state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
         }
-        state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
-            std::move(state);
+        state_[p.unsafeGetTensorImpl()] = std::move(state);
       }

-      auto& state = static_cast<AdamWParamState&>(
-          *state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
+      auto& state =
+          static_cast<AdamWParamState&>(*state_[p.unsafeGetTensorImpl()]);
       auto& exp_avg = state.exp_avg();
       auto& exp_avg_sq = state.exp_avg_sq();
       auto& max_exp_avg_sq = state.max_exp_avg_sq();
@@ -180,8 +178,7 @@ void AdamW::load(serialize::InputArchive& archive) {
       if (idx < max_exp_average_sq_buffers.size()) {
         state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
       }
-      state_[c10::guts::to_string(params.at(idx).unsafeGetTensorImpl())] =
-          std::move(state);
+      state_[params.at(idx).unsafeGetTensorImpl()] = std::move(state);
     }
   }
 }
@@ -438,14 +438,13 @@ Tensor LBFGS::step(LossClosure closure) {

   // NOTE: LBFGS has only global state, but we register it as state for
   // the first param, because this helps with casting in load_state_dict
-  auto param_state =
-      state_.find(c10::guts::to_string(_params.at(0).unsafeGetTensorImpl()));
+  auto param_state = state_.find(_params.at(0).unsafeGetTensorImpl());
   if (param_state == state_.end()) {
-    state_[c10::guts::to_string(_params.at(0).unsafeGetTensorImpl())] =
+    state_[_params.at(0).unsafeGetTensorImpl()] =
         std::make_unique<LBFGSParamState>();
   }
   auto& state = static_cast<LBFGSParamState&>(
-      *state_[c10::guts::to_string(_params.at(0).unsafeGetTensorImpl())]);
+      *state_[_params.at(0).unsafeGetTensorImpl()]);
   // evaluate initial f(x) and df/dx
   Tensor orig_loss;
   {
@@ -655,8 +654,7 @@ void LBFGS::load(serialize::InputArchive& archive) {
   state->prev_loss(prev_loss.item<double>());
   state->old_dirs(old_dirs);
   state->old_stps(old_stps);
-  state_[c10::guts::to_string(
-      param_groups_.at(0).params().at(0).unsafeGetTensorImpl())] =
+  state_[param_groups_.at(0).params().at(0).unsafeGetTensorImpl()] =
       std::move(state);
 }

@@ -109,7 +109,7 @@ void Optimizer::add_param_group(const OptimizerParamGroup& param_group) {
   }
   for (const auto& p : param_group_.params()) {
     TORCH_CHECK(
-        state_.count(c10::guts::to_string(p.unsafeGetTensorImpl())) == 0,
+        state_.count(p.unsafeGetTensorImpl()) == 0,
         "some parameters appear in more than one parameter group");
   }
   param_groups_.emplace_back(std::move(param_group_));
@@ -171,12 +171,12 @@ const std::vector<OptimizerParamGroup>& Optimizer::param_groups()
   return param_groups_;
 }

-ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
-Optimizer::state() noexcept {
+ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& Optimizer::
+    state() noexcept {
   return state_;
 }

-const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
+const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
 Optimizer::state() const noexcept {
   return state_;
 }
@@ -85,8 +85,7 @@ Tensor RMSprop::step(LossClosure closure) {
       auto grad = p.grad();
       TORCH_CHECK(
           !grad.is_sparse(), "RMSprop does not support sparse gradients");
-      auto param_state =
-          state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
+      auto param_state = state_.find(p.unsafeGetTensorImpl());
       auto& options = static_cast<RMSpropOptions&>(group.options());

       // State initialization
@@ -100,12 +99,11 @@ Tensor RMSprop::step(LossClosure closure) {
         if (options.centered()) {
           state->grad_avg(torch::zeros_like(p, MemoryFormat::Preserve));
         }
-        state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
-            std::move(state);
+        state_[p.unsafeGetTensorImpl()] = std::move(state);
       }

-      auto& state = static_cast<RMSpropParamState&>(
-          *state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
+      auto& state =
+          static_cast<RMSpropParamState&>(*state_[p.unsafeGetTensorImpl()]);
       auto& square_avg = state.square_avg();
       auto alpha = options.alpha();

@@ -176,8 +174,7 @@ void RMSprop::load(serialize::InputArchive& archive) {
       if (idx < grad_average_buffers.size()) {
         state->grad_avg(grad_average_buffers.at(idx));
       }
-      state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] =
-          std::move(state);
+      state_[params[idx].unsafeGetTensorImpl()] = std::move(state);
     }
   }
 }
@@ -84,14 +84,12 @@ Tensor SGD::step(LossClosure closure) {
       }
       if (momentum != 0) {
         Tensor buf;
-        auto param_state =
-            state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
+        auto param_state = state_.find(p.unsafeGetTensorImpl());
         if (param_state == state_.end()) {
           buf = torch::clone(d_p).detach();
           auto state = std::make_unique<SGDParamState>();
           state->momentum_buffer(buf);
-          state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
-              std::move(state);
+          state_[p.unsafeGetTensorImpl()] = std::move(state);
         } else {
           buf = static_cast<SGDParamState&>(*param_state->second)
                     .momentum_buffer();
@@ -130,8 +128,7 @@ void SGD::load(serialize::InputArchive& archive) {
     for (const auto idx : c10::irange(momentum_buffers.size())) {
       auto state = std::make_unique<SGDParamState>();
       state->momentum_buffer(momentum_buffers[idx]);
-      state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] =
-          std::move(state);
+      state_[params[idx].unsafeGetTensorImpl()] = std::move(state);
     }
   }
 }
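All of the optimizer step() implementations above now share the same find-or-create pattern on the pointer key. Sketched in isolation (types simplified, names hypothetical):

#include <memory>
#include <unordered_map>

struct State { double buf = 0.0; };  // stand-in for an *ParamState type
std::unordered_map<void*, std::unique_ptr<State>> state_;

State& find_or_create(void* impl) {
  auto it = state_.find(impl);  // no string is built for the probe
  if (it == state_.end()) {
    it = state_.emplace(impl, std::make_unique<State>()).first;
  }
  return *it->second;
}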
@@ -3,7 +3,6 @@
 #include <torch/csrc/autograd/profiler_kineto.h>

 #include <c10/macros/Export.h>
-#include <c10/util/C++17.h>
 #include <c10/util/Exception.h>
 #include <c10/util/flat_hash_map.h>
 #include <c10/util/irange.h>
@@ -16,7 +16,6 @@

 #include <ATen/core/TensorBase.h>
 #include <c10/macros/Macros.h>
-#include <c10/util/C++17.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Logging.h>
 #include <c10/util/Optional.h>
@@ -1,4 +1,3 @@
-#include <c10/util/C++17.h>
 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
 #include <torch/csrc/distributed/rpc/rpc_agent.h>
 #include <torch/csrc/distributed/rpc/utils.h>
@@ -1,7 +1,5 @@
 #include <torch/csrc/distributed/rpc/python_call.h>

-#include <c10/util/C++17.h>
-
 namespace torch {
 namespace distributed {
 namespace rpc {
@@ -1,5 +1,4 @@
 #include <ATen/ThreadLocalState.h>
-#include <c10/util/C++17.h>
 #include <torch/csrc/distributed/autograd/context/container.h>
 #include <torch/csrc/distributed/autograd/utils.h>
 #include <torch/csrc/distributed/rpc/message.h>
@@ -1,7 +1,5 @@
 #include <torch/csrc/distributed/rpc/python_resp.h>

-#include <c10/util/C++17.h>
-
 namespace torch {
 namespace distributed {
 namespace rpc {
@@ -1,6 +1,5 @@
 #include <torch/csrc/distributed/rpc/script_resp.h>

-#include <c10/util/C++17.h>
 #include <torch/csrc/distributed/rpc/rpc_agent.h>
 #include <torch/csrc/jit/serialization/pickle.h>
 #include <torch/csrc/jit/serialization/unpickler.h>
@@ -1,6 +1,5 @@
 #include <torch/csrc/distributed/rpc/unpickled_python_call.h>

-#include <c10/util/C++17.h>
 #include <torch/csrc/distributed/rpc/python_rpc_handler.h>

 namespace torch {
@@ -1,6 +1,5 @@
 #include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h>

-#include <c10/util/C++17.h>
 #include <torch/csrc/distributed/rpc/python_rpc_handler.h>

 namespace torch {
@@ -93,7 +93,7 @@ C10_EXPORT std::string kindToString(int kind) {
     TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
 #undef DEFINE_CASE
     default:
-      throw std::runtime_error("Unknown kind: " + c10::guts::to_string(kind));
+      throw std::runtime_error("Unknown kind: " + std::to_string(kind));
   }
 }

@@ -518,8 +518,7 @@ struct Lexer {
         indent_stack.pop_back();
         next_tokens.emplace_back(TK_DEDENT, r.range);
         if (indent_stack.empty()) {
-          reportError(
-              "invalid indent level " + c10::guts::to_string(depth), r);
+          reportError("invalid indent level " + std::to_string(depth), r);
         }
       }
       return; // We've already queued the tokens
@@ -138,7 +138,7 @@ c10::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
     L.expect(')');
   } else if (L.nextIf('!')) {
     alias_info.addBeforeSet(
-        Symbol::fromQualString("alias::$" + c10::guts::to_string(next_id++)));
+        Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
     alias_info.setIsWrite(true);
   } else {
     return c10::nullopt;
@@ -508,7 +508,7 @@ RangeValue::RangeValue(
     if (!typ->cast<IntType>()) {
       throw ErrorReport(loc)
           << "all inputs of range must be ints, found " << typ->repr_str()
-          << " in argument " << c10::guts::to_string(i);
+          << " in argument " << std::to_string(i);
     }
   }

@@ -564,7 +564,7 @@ mobile::Module _load_for_mobile_impl(
     // Add model_name and model_size to metadata_map
     extra_files.insert(std::make_pair("model_name", result.name()));
     extra_files.insert(
-        std::make_pair("model_size", c10::guts::to_string(model_size)));
+        std::make_pair("model_size", std::to_string(model_size)));
     metadata_map = observer->processMetadataFromExtra(extra_files);
     observer->onExitLoadModel(instance_key, metadata_map);
   }
@@ -63,7 +63,7 @@ void SGD::add_param_group(const SGDParamGroup& param_group) {
   }
   for (const auto& p : param_group_.params()) {
     TORCH_CHECK(
-        state_.count(c10::guts::to_string(p.unsafeGetTensorImpl())) == 0,
+        state_.count(p.unsafeGetTensorImpl()) == 0,
         "some parameters appear in more than one parameter group");
   }
   param_groups_.emplace_back(std::move(param_group_));
@@ -104,14 +104,12 @@ Tensor SGD::step(const LossClosure& closure) {
       }
       if (momentum != 0) {
        Tensor buf;
-        auto param_state =
-            state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
+        auto param_state = state_.find(p.unsafeGetTensorImpl());
         if (param_state == state_.end()) {
           buf = torch::clone(d_p).detach();
           auto state = std::make_unique<SGDParamState>();
           state->momentum_buffer(buf);
-          state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
-              std::move(state);
+          state_[p.unsafeGetTensorImpl()] = std::move(state);
         } else {
           buf = static_cast<SGDParamState&>(*param_state->second)
                     .momentum_buffer();
@@ -122,7 +122,7 @@ class TORCH_API SGD {
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::vector<SGDParamGroup> param_groups_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
-  ska::flat_hash_map<std::string, std::unique_ptr<SGDParamState>> state_;
+  ska::flat_hash_map<void*, std::unique_ptr<SGDParamState>> state_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::unique_ptr<SGDOptions> defaults_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
@@ -126,7 +126,7 @@ Node* insertDeQuant(
   Node* dequant = graph->create(Symbol::aten("dequantize"), {quantized_val});
   dequant->output()
       ->setDebugName(
-          original_val->debugName() + ".dequant." + c10::guts::to_string(id))
+          original_val->debugName() + ".dequant." + std::to_string(id))
       ->setType(original_val->type());
   graph->insertNode(dequant);
   return dequant;
@@ -305,7 +305,7 @@ void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
   // root_key
   std::string root_key = get_tensor_id_ != nullptr
       ? get_tensor_id_(tensor)
-      : c10::to_string(tensor_data_.size());
+      : std::to_string(tensor_data_.size());
   pushString(root_key);
   // location
   pushString(tensor.device().str());
@@ -362,7 +362,7 @@ struct PythonPrintImpl {
     std::string name = candidate;
     while (used.count(name) || reserved_names.count(name)) {
       // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
-      name = candidate + c10::to_string(next_id[name]++);
+      name = candidate + std::to_string(next_id[name]++);
     }
     used.insert(name);
     return name;