mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Eliminate c10::guts::to_string (#108480)
This PR replace c10::guts::to_string with std::to_string. The major part of changes is using void* as optimizer state key since string is used only for serialization and using pointers as hashing keys is more efficient than a string. Some other guts functions in the affected source files are also replaced. Pull Request resolved: https://github.com/pytorch/pytorch/pull/108480 Approved by: https://github.com/Skylion007
This commit is contained in:
@ -1,14 +1,13 @@
|
|||||||
#include <ATen/MapAllocator.h>
|
#include <ATen/MapAllocator.h>
|
||||||
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <string>
|
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <string>
|
||||||
#if ATOMIC_INT_LOCK_FREE == 2
|
#if ATOMIC_INT_LOCK_FREE == 2
|
||||||
#define AT_ATOMIC_IPC_REFCOUNT 1
|
#define AT_ATOMIC_IPC_REFCOUNT 1
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <c10/core/CPUAllocator.h>
|
#include <c10/core/CPUAllocator.h>
|
||||||
#include <c10/util/C++17.h>
|
|
||||||
#include <c10/util/Unicode.h>
|
#include <c10/util/Unicode.h>
|
||||||
|
|
||||||
/* stuff for mapped files */
|
/* stuff for mapped files */
|
||||||
@ -17,9 +16,9 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(HAVE_MMAP)
|
#if defined(HAVE_MMAP)
|
||||||
|
#include <fcntl.h>
|
||||||
#include <sys/mman.h>
|
#include <sys/mman.h>
|
||||||
#include <sys/stat.h>
|
#include <sys/stat.h>
|
||||||
#include <fcntl.h>
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if !defined(_MSC_VER) || defined(HAVE_MMAP)
|
#if !defined(_MSC_VER) || defined(HAVE_MMAP)
|
||||||
@ -28,28 +27,29 @@
|
|||||||
#elif defined(_MSC_VER)
|
#elif defined(_MSC_VER)
|
||||||
#include <c10/util/win32-headers.h>
|
#include <c10/util/win32-headers.h>
|
||||||
#endif
|
#endif
|
||||||
|
#include <fmt/format.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at {
|
||||||
|
|
||||||
static constexpr int64_t map_alloc_alignment = 64;
|
static constexpr int64_t map_alloc_alignment = 64;
|
||||||
|
|
||||||
TORCH_API std::string NewProcessWideShmHandle()
|
std::string NewProcessWideShmHandle() {
|
||||||
{
|
|
||||||
static std::atomic<uint64_t> counter{0};
|
static std::atomic<uint64_t> counter{0};
|
||||||
static std::random_device rd;
|
static std::random_device rd;
|
||||||
std::string handle = "/torch_";
|
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
handle += c10::guts::to_string(GetCurrentProcessId());
|
return fmt::format(
|
||||||
|
"/torch_{}_{}_{}",
|
||||||
|
GetCurrentProcessId(),
|
||||||
|
rd(),
|
||||||
|
counter.fetch_add(1, std::memory_order_relaxed));
|
||||||
#else
|
#else
|
||||||
handle += c10::guts::to_string(getpid());
|
return fmt::format(
|
||||||
|
"/torch_{}_{}_{}",
|
||||||
|
getpid(),
|
||||||
|
rd(),
|
||||||
|
counter.fetch_add(1, std::memory_order_relaxed));
|
||||||
#endif
|
#endif
|
||||||
handle += "_";
|
|
||||||
handle += c10::guts::to_string(rd());
|
|
||||||
handle += "_";
|
|
||||||
handle += c10::guts::to_string(counter.fetch_add(1, std::memory_order_relaxed));
|
|
||||||
return handle;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(_WIN32) || defined(HAVE_MMAP)
|
#if defined(_WIN32) || defined(HAVE_MMAP)
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -907,7 +907,7 @@ TEST(OperatorRegistrationTestLegacyFunctionBasedKernel, givenKernelWithOptionalI
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string concatKernel(const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
|
std::string concatKernel(const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
|
||||||
return a + b + c10::guts::to_string(c);
|
return a + b + std::to_string(c);
|
||||||
}
|
}
|
||||||
|
|
||||||
void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
|
void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
|
||||||
|
@ -649,7 +649,7 @@ TEST(OperatorRegistrationTestFunctionBasedKernel, givenKernelWithOptionalInputs_
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string concatKernel(const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
|
std::string concatKernel(const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
|
||||||
return a + b + c10::guts::to_string(c);
|
return a + b + std::to_string(c);
|
||||||
}
|
}
|
||||||
|
|
||||||
void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
|
void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
|
||||||
|
@ -854,7 +854,7 @@ void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
|
|||||||
TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) {
|
TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) {
|
||||||
std::string prefix = "prefix";
|
std::string prefix = "prefix";
|
||||||
auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", [&] (const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
|
auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", [&] (const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
|
||||||
return prefix + a + b + c10::guts::to_string(c);
|
return prefix + a + b + std::to_string(c);
|
||||||
});
|
});
|
||||||
expectCallsConcatUnboxed(DispatchKey::CPU);
|
expectCallsConcatUnboxed(DispatchKey::CPU);
|
||||||
}
|
}
|
||||||
|
@ -576,7 +576,7 @@ void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
|
|||||||
TEST(OperatorRegistrationTestLambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) {
|
TEST(OperatorRegistrationTestLambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) {
|
||||||
auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", torch::RegisterOperators::options()
|
auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", torch::RegisterOperators::options()
|
||||||
.kernel(DispatchKey::CPU, [] (const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
|
.kernel(DispatchKey::CPU, [] (const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
|
||||||
return a + b + c10::guts::to_string(c);
|
return a + b + std::to_string(c);
|
||||||
}));
|
}));
|
||||||
expectCallsConcatUnboxed(DispatchKey::CPU);
|
expectCallsConcatUnboxed(DispatchKey::CPU);
|
||||||
}
|
}
|
||||||
|
@ -787,7 +787,7 @@ struct ConcatKernel final : OperatorKernel {
|
|||||||
explicit ConcatKernel(std::string prefix): prefix_(std::move(prefix)) {}
|
explicit ConcatKernel(std::string prefix): prefix_(std::move(prefix)) {}
|
||||||
|
|
||||||
std::string operator()(const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
|
std::string operator()(const Tensor& tensor1, std::string a, const std::string& b, int64_t c) {
|
||||||
return prefix_ + a + b + c10::guts::to_string(c);
|
return prefix_ + a + b + std::to_string(c);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string prefix_;
|
std::string prefix_;
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
#include <c10/util/MaybeOwned.h>
|
#include <c10/util/MaybeOwned.h>
|
||||||
#include <c10/util/intrusive_ptr.h>
|
#include <c10/util/intrusive_ptr.h>
|
||||||
#include <typeindex>
|
#include <typeindex>
|
||||||
|
#include <type_traits>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -510,17 +511,17 @@ public:
|
|||||||
template <
|
template <
|
||||||
typename... Args,
|
typename... Args,
|
||||||
std::enable_if_t<
|
std::enable_if_t<
|
||||||
!guts::disjunction<
|
!std::disjunction<
|
||||||
std::is_lvalue_reference<Args>...,
|
std::is_lvalue_reference<Args>...,
|
||||||
guts::negation<std::is_constructible<IValue, Args>>...>::value,
|
std::negation<std::is_constructible<IValue, Args>>...>::value,
|
||||||
std::nullptr_t> = nullptr>
|
std::nullptr_t> = nullptr>
|
||||||
IValue(const std::tuple<Args...>& t);
|
IValue(const std::tuple<Args...>& t);
|
||||||
template <
|
template <
|
||||||
typename... Args,
|
typename... Args,
|
||||||
std::enable_if_t<
|
std::enable_if_t<
|
||||||
!guts::disjunction<
|
!std::disjunction<
|
||||||
std::is_lvalue_reference<Args>...,
|
std::is_lvalue_reference<Args>...,
|
||||||
guts::negation<std::is_constructible<IValue, Args>>...>::value,
|
std::negation<std::is_constructible<IValue, Args>>...>::value,
|
||||||
std::nullptr_t> = nullptr>
|
std::nullptr_t> = nullptr>
|
||||||
IValue(std::tuple<Args...>&& t);
|
IValue(std::tuple<Args...>&& t);
|
||||||
bool isTuple() const {
|
bool isTuple() const {
|
||||||
@ -981,7 +982,7 @@ public:
|
|||||||
TORCH_FORALL_TAGS(DEFINE_CASE)
|
TORCH_FORALL_TAGS(DEFINE_CASE)
|
||||||
#undef DEFINE_CASE
|
#undef DEFINE_CASE
|
||||||
}
|
}
|
||||||
return "InvalidTag(" + c10::guts::to_string(static_cast<int>(tag)) + ")";
|
return "InvalidTag(" + std::to_string(static_cast<int>(tag)) + ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
// generic v.to<at::Tensor>() implementations
|
// generic v.to<at::Tensor>() implementations
|
||||||
|
@ -1048,7 +1048,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
|
|||||||
using IValueWithStorages = std::tuple<IValue, std::vector<WeakStorage>>;
|
using IValueWithStorages = std::tuple<IValue, std::vector<WeakStorage>>;
|
||||||
#if __cpp_lib_is_invocable >= 201703
|
#if __cpp_lib_is_invocable >= 201703
|
||||||
static_assert(
|
static_assert(
|
||||||
guts::disjunction<
|
std::disjunction<
|
||||||
std::is_invocable_r<IValue, T, Future&>,
|
std::is_invocable_r<IValue, T, Future&>,
|
||||||
std::is_invocable_r<IValueWithStorages, T, Future&>>::value,
|
std::is_invocable_r<IValueWithStorages, T, Future&>>::value,
|
||||||
"The callback must have signature IValue(Future&) or "
|
"The callback must have signature IValue(Future&) or "
|
||||||
@ -1918,9 +1918,9 @@ template <
|
|||||||
typename... Args,
|
typename... Args,
|
||||||
typename Indices = std::make_index_sequence<sizeof...(Args)>,
|
typename Indices = std::make_index_sequence<sizeof...(Args)>,
|
||||||
std::enable_if_t<
|
std::enable_if_t<
|
||||||
!guts::disjunction<
|
!std::disjunction<
|
||||||
std::is_lvalue_reference<Args>...,
|
std::is_lvalue_reference<Args>...,
|
||||||
guts::negation<std::is_constructible<IValue, Args>>...>::value,
|
std::negation<std::is_constructible<IValue, Args>>...>::value,
|
||||||
std::nullptr_t> = nullptr>
|
std::nullptr_t> = nullptr>
|
||||||
std::tuple<Args...> generic_to(const IValue& ivalue, _fake_type<std::tuple<Args...>>) {
|
std::tuple<Args...> generic_to(const IValue& ivalue, _fake_type<std::tuple<Args...>>) {
|
||||||
const auto& vals = ivalue.toTupleRef().elements();
|
const auto& vals = ivalue.toTupleRef().elements();
|
||||||
@ -2098,9 +2098,9 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
|
|||||||
template <
|
template <
|
||||||
typename... Args,
|
typename... Args,
|
||||||
std::enable_if_t<
|
std::enable_if_t<
|
||||||
!guts::disjunction<
|
!std::disjunction<
|
||||||
std::is_lvalue_reference<Args>...,
|
std::is_lvalue_reference<Args>...,
|
||||||
guts::negation<std::is_constructible<IValue, Args>>...>::value,
|
std::negation<std::is_constructible<IValue, Args>>...>::value,
|
||||||
std::nullptr_t>>
|
std::nullptr_t>>
|
||||||
inline IValue::IValue(const std::tuple<Args...>& t)
|
inline IValue::IValue(const std::tuple<Args...>& t)
|
||||||
: IValue(c10::guts::apply(c10::ivalue::Tuple::create<const Args&...>, t)) {
|
: IValue(c10::guts::apply(c10::ivalue::Tuple::create<const Args&...>, t)) {
|
||||||
@ -2109,9 +2109,9 @@ inline IValue::IValue(const std::tuple<Args...>& t)
|
|||||||
template <
|
template <
|
||||||
typename... Args,
|
typename... Args,
|
||||||
std::enable_if_t<
|
std::enable_if_t<
|
||||||
!guts::disjunction<
|
!std::disjunction<
|
||||||
std::is_lvalue_reference<Args>...,
|
std::is_lvalue_reference<Args>...,
|
||||||
guts::negation<std::is_constructible<IValue, Args>>...>::value,
|
std::negation<std::is_constructible<IValue, Args>>...>::value,
|
||||||
std::nullptr_t>>
|
std::nullptr_t>>
|
||||||
inline IValue::IValue(std::tuple<Args...>&& t)
|
inline IValue::IValue(std::tuple<Args...>&& t)
|
||||||
: IValue(c10::guts::apply(c10::ivalue::Tuple::create<Args&&...>, std::move(t))) {
|
: IValue(c10::guts::apply(c10::ivalue::Tuple::create<Args&&...>, std::move(t))) {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <ATen/core/op_registration/infer_schema.h>
|
#include <ATen/core/op_registration/infer_schema.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
#include <sstream>
|
#include <fmt/format.h>
|
||||||
|
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
|
|
||||||
@ -8,46 +8,55 @@ namespace detail {
|
|||||||
namespace infer_schema {
|
namespace infer_schema {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
std::string fastToString(size_t x) {
|
|
||||||
if (C10_LIKELY(x < 10)) {
|
|
||||||
std::string result;
|
|
||||||
result.push_back('_');
|
|
||||||
result.push_back('0' + x);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
return "_" + c10::guts::to_string(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Argument> createArgumentVector(c10::ArrayRef<ArgumentDef> args) {
|
std::vector<Argument> createArgumentVector(c10::ArrayRef<ArgumentDef> args) {
|
||||||
std::vector<Argument> result;
|
std::vector<Argument> result;
|
||||||
result.reserve(args.size());
|
result.reserve(args.size());
|
||||||
for (const auto i : c10::irange(args.size())) {
|
for (const auto i : c10::irange(args.size())) {
|
||||||
// Arguments are named "_<index>"
|
// Arguments are named "_<index>"
|
||||||
result.emplace_back(fastToString(i), (*args[i].getFakeTypeFn)(), (*args[i].getTypeFn)());
|
result.emplace_back(
|
||||||
|
fmt::format("_{}", i),
|
||||||
|
(*args[i].getFakeTypeFn)(),
|
||||||
|
(*args[i].getTypeFn)());
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
} // namespace
|
||||||
// This is intentionally a separate function and in a .cpp file
|
// This is intentionally a separate function and in a .cpp file
|
||||||
// because then the template is smaller and that benefits binary size
|
// because then the template is smaller and that benefits binary size
|
||||||
FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns) {
|
FunctionSchema make_function_schema(
|
||||||
return FunctionSchema(std::move(name), std::move(overload_name), createArgumentVector(arguments), createArgumentVector(returns));
|
std::string&& name,
|
||||||
|
std::string&& overload_name,
|
||||||
|
c10::ArrayRef<ArgumentDef> arguments,
|
||||||
|
c10::ArrayRef<ArgumentDef> returns) {
|
||||||
|
return FunctionSchema(
|
||||||
|
std::move(name),
|
||||||
|
std::move(overload_name),
|
||||||
|
createArgumentVector(arguments),
|
||||||
|
createArgumentVector(returns));
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionSchema make_function_schema(c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns) {
|
FunctionSchema make_function_schema(
|
||||||
|
c10::ArrayRef<ArgumentDef> arguments,
|
||||||
|
c10::ArrayRef<ArgumentDef> returns) {
|
||||||
return make_function_schema("", "", arguments, returns);
|
return make_function_schema("", "", arguments, returns);
|
||||||
}
|
}
|
||||||
}
|
} // namespace infer_schema
|
||||||
}
|
} // namespace detail
|
||||||
|
|
||||||
c10::optional<std::string> findSchemaDifferences(const FunctionSchema& lhs, const FunctionSchema& rhs) {
|
c10::optional<std::string> findSchemaDifferences(
|
||||||
|
const FunctionSchema& lhs,
|
||||||
|
const FunctionSchema& rhs) {
|
||||||
if (lhs.arguments().size() != rhs.arguments().size()) {
|
if (lhs.arguments().size() != rhs.arguments().size()) {
|
||||||
return "The number of arguments is different. " + guts::to_string(lhs.arguments().size()) +
|
return fmt::format(
|
||||||
" vs " + guts::to_string(rhs.arguments().size()) + ".";
|
"The number of arguments is different. {} vs {}.",
|
||||||
|
lhs.arguments().size(),
|
||||||
|
rhs.arguments().size());
|
||||||
}
|
}
|
||||||
if (lhs.returns().size() != rhs.returns().size()) {
|
if (lhs.returns().size() != rhs.returns().size()) {
|
||||||
return "The number of returns is different. " + guts::to_string(lhs.returns().size()) +
|
return fmt::format(
|
||||||
" vs " + guts::to_string(rhs.returns().size());
|
"The number of returns is different. {} vs {}.",
|
||||||
|
lhs.returns().size(),
|
||||||
|
rhs.returns().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto i : c10::irange(lhs.arguments().size())) {
|
for (const auto i : c10::irange(lhs.arguments().size())) {
|
||||||
@ -57,8 +66,11 @@ c10::optional<std::string> findSchemaDifferences(const FunctionSchema& lhs, cons
|
|||||||
// cheaper, particularly when one of the types is a singleton like
|
// cheaper, particularly when one of the types is a singleton like
|
||||||
// NumberType or AnyType.
|
// NumberType or AnyType.
|
||||||
if (leftType.get() != rightType.get() && *leftType != *rightType) {
|
if (leftType.get() != rightType.get() && *leftType != *rightType) {
|
||||||
return "Type mismatch in argument " + guts::to_string(i+1) + ": " + lhs.arguments()[i].type()->str() +
|
return fmt::format(
|
||||||
" vs " + rhs.arguments()[i].type()->str();
|
"Type mismatch in argument {}: {} vs {}.",
|
||||||
|
i + 1,
|
||||||
|
lhs.arguments()[i].type()->str(),
|
||||||
|
rhs.arguments()[i].type()->str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,8 +79,11 @@ c10::optional<std::string> findSchemaDifferences(const FunctionSchema& lhs, cons
|
|||||||
const TypePtr& rightType = rhs.returns()[i].type();
|
const TypePtr& rightType = rhs.returns()[i].type();
|
||||||
// See above about comparing pointers first.
|
// See above about comparing pointers first.
|
||||||
if (leftType.get() != rightType.get() && *leftType != *rightType) {
|
if (leftType.get() != rightType.get() && *leftType != *rightType) {
|
||||||
return "Type mismatch in return " + guts::to_string(i+1) + ": " + lhs.returns()[i].type()->str() +
|
return fmt::format(
|
||||||
" vs " + rhs.returns()[i].type()->str();
|
"Type mismatch in return {}: {} vs {}.",
|
||||||
|
i + 1,
|
||||||
|
lhs.returns()[i].type()->str(),
|
||||||
|
rhs.returns()[i].type()->str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,4 +91,4 @@ c10::optional<std::string> findSchemaDifferences(const FunctionSchema& lhs, cons
|
|||||||
return c10::nullopt;
|
return c10::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
} // namespace c10
|
||||||
|
@ -214,7 +214,7 @@ TEST(OrderedPreservingDictTest, test_range_erase) {
|
|||||||
const int64_t nb_values = 1000;
|
const int64_t nb_values = 1000;
|
||||||
HMap map;
|
HMap map;
|
||||||
for (const auto i : c10::irange(nb_values)) {
|
for (const auto i : c10::irange(nb_values)) {
|
||||||
map[c10::guts::to_string(i)] = i;
|
map[std::to_string(i)] = i;
|
||||||
auto begin = map.begin();
|
auto begin = map.begin();
|
||||||
for (int64_t j = 0; j <= i; ++j, begin++) {
|
for (int64_t j = 0; j <= i; ++j, begin++) {
|
||||||
TORCH_INTERNAL_ASSERT(begin->second == j);
|
TORCH_INTERNAL_ASSERT(begin->second == j);
|
||||||
@ -239,8 +239,7 @@ TEST(OrderedPreservingDictTest, test_range_erase) {
|
|||||||
if (i >= 10 && i < 220) {
|
if (i >= 10 && i < 220) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto exp_it =
|
auto exp_it = std::pair<std::string, std::int64_t>(std::to_string(i), i);
|
||||||
std::pair<std::string, std::int64_t>(c10::guts::to_string(i), i);
|
|
||||||
TORCH_INTERNAL_ASSERT(*it == exp_it);
|
TORCH_INTERNAL_ASSERT(*it == exp_it);
|
||||||
++it;
|
++it;
|
||||||
}
|
}
|
||||||
@ -313,13 +312,13 @@ TEST(OrderedPreservingDictTest, test_copy_constructor_and_operator) {
|
|||||||
const std::size_t nb_values = 100;
|
const std::size_t nb_values = 100;
|
||||||
HMap map;
|
HMap map;
|
||||||
for (const auto i : c10::irange(nb_values)) {
|
for (const auto i : c10::irange(nb_values)) {
|
||||||
map[c10::guts::to_string(i)] = c10::guts::to_string(i);
|
map[std::to_string(i)] = std::to_string(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
HMap map_copy = map;
|
HMap map_copy = map;
|
||||||
HMap map_copy2(map);
|
HMap map_copy2(map);
|
||||||
HMap map_copy3;
|
HMap map_copy3;
|
||||||
map_copy3[c10::guts::to_string(0)] = c10::guts::to_string(0);
|
map_copy3[std::to_string(0)] = std::to_string(0);
|
||||||
|
|
||||||
map_copy3 = map;
|
map_copy3 = map;
|
||||||
|
|
||||||
|
@ -15,7 +15,6 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <c10/util/C++17.h>
|
|
||||||
#include <c10/util/Deprecated.h>
|
#include <c10/util/Deprecated.h>
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <c10/util/SmallVector.h>
|
#include <c10/util/SmallVector.h>
|
||||||
|
@ -233,50 +233,6 @@ struct function_takes_identity_argument<
|
|||||||
#endif
|
#endif
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
// GCC 4.8 doesn't define std::to_string, even though that's in C++11. Let's
|
|
||||||
// define it.
|
|
||||||
namespace detail {
|
|
||||||
class DummyClassForToString final {};
|
|
||||||
} // namespace detail
|
|
||||||
} // namespace guts
|
|
||||||
} // namespace c10
|
|
||||||
namespace std {
|
|
||||||
// We use SFINAE to detect if std::to_string exists for a type, but that only
|
|
||||||
// works if the function name is defined. So let's define a std::to_string for a
|
|
||||||
// dummy type. If you're getting an error here saying that this overload doesn't
|
|
||||||
// match your std::to_string() call, then you're calling std::to_string() but
|
|
||||||
// should be calling c10::guts::to_string().
|
|
||||||
inline std::string to_string(c10::guts::detail::DummyClassForToString) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace std
|
|
||||||
namespace c10 {
|
|
||||||
namespace guts {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
template <class T, class Enable = void>
|
|
||||||
struct to_string_ final {
|
|
||||||
static std::string call(T value) {
|
|
||||||
std::ostringstream str;
|
|
||||||
str << value;
|
|
||||||
return str.str();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
// If a std::to_string exists, use that instead
|
|
||||||
template <class T>
|
|
||||||
struct to_string_<T, void_t<decltype(std::to_string(std::declval<T>()))>>
|
|
||||||
final {
|
|
||||||
static std::string call(T value) {
|
|
||||||
return std::to_string(value);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace detail
|
|
||||||
template <class T>
|
|
||||||
inline std::string to_string(T value) {
|
|
||||||
return detail::to_string_<T>::call(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace guts
|
} // namespace guts
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
||||||
|
@ -182,8 +182,7 @@ TEST(OptimTest, OptimizerAccessors) {
|
|||||||
|
|
||||||
// test for state() with non-const reference return
|
// test for state() with non-const reference return
|
||||||
auto& state_ = static_cast<AdagradParamState&>(
|
auto& state_ = static_cast<AdagradParamState&>(
|
||||||
*(optimizer
|
*(optimizer.state()[params_1[0].unsafeGetTensorImpl()]));
|
||||||
.state()[c10::guts::to_string(params_1[0].unsafeGetTensorImpl())]));
|
|
||||||
state_.step(state_.step() + 1);
|
state_.step(state_.step() + 1);
|
||||||
|
|
||||||
const auto& optimizer_ = Adagrad(params, options);
|
const auto& optimizer_ = Adagrad(params, options);
|
||||||
|
@ -54,9 +54,9 @@ void is_optimizer_param_group_equal(
|
|||||||
|
|
||||||
template <typename DerivedOptimizerParamState>
|
template <typename DerivedOptimizerParamState>
|
||||||
void is_optimizer_state_equal(
|
void is_optimizer_state_equal(
|
||||||
const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
|
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
|
||||||
lhs_state,
|
lhs_state,
|
||||||
const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
|
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
|
||||||
rhs_state) {
|
rhs_state) {
|
||||||
ASSERT_TRUE(lhs_state.size() == rhs_state.size());
|
ASSERT_TRUE(lhs_state.size() == rhs_state.size());
|
||||||
for (const auto& value : lhs_state) {
|
for (const auto& value : lhs_state) {
|
||||||
@ -188,7 +188,7 @@ void write_tensors_to_archive(
|
|||||||
key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
|
key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
|
||||||
for (const auto index : c10::irange(buffers.size())) {
|
for (const auto index : c10::irange(buffers.size())) {
|
||||||
archive.write(
|
archive.write(
|
||||||
key + "/" + c10::to_string(index), buffers[index], /*is_buffer=*/true);
|
key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -218,7 +218,7 @@ TEST(SerializeTest, KeysFunc) {
|
|||||||
torch::serialize::OutputArchive output_archive;
|
torch::serialize::OutputArchive output_archive;
|
||||||
for (const auto i : c10::irange(3)) {
|
for (const auto i : c10::irange(3)) {
|
||||||
output_archive.write(
|
output_archive.write(
|
||||||
"element/" + c10::to_string(i), c10::IValue(static_cast<int64_t>(i)));
|
"element/" + std::to_string(i), c10::IValue(static_cast<int64_t>(i)));
|
||||||
}
|
}
|
||||||
output_archive.save_to(tempfile.name);
|
output_archive.save_to(tempfile.name);
|
||||||
torch::serialize::InputArchive input_archive;
|
torch::serialize::InputArchive input_archive;
|
||||||
@ -226,7 +226,7 @@ TEST(SerializeTest, KeysFunc) {
|
|||||||
std::vector<std::string> keys = input_archive.keys();
|
std::vector<std::string> keys = input_archive.keys();
|
||||||
ASSERT_EQ(keys.size(), 3);
|
ASSERT_EQ(keys.size(), 3);
|
||||||
for (const auto i : c10::irange(keys.size())) {
|
for (const auto i : c10::irange(keys.size())) {
|
||||||
ASSERT_EQ(keys[i], "element/" + c10::to_string(i));
|
ASSERT_EQ(keys[i], "element/" + std::to_string(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -235,7 +235,7 @@ TEST(SerializeTest, TryReadFunc) {
|
|||||||
torch::serialize::OutputArchive output_archive;
|
torch::serialize::OutputArchive output_archive;
|
||||||
for (const auto i : c10::irange(3)) {
|
for (const auto i : c10::irange(3)) {
|
||||||
output_archive.write(
|
output_archive.write(
|
||||||
"element/" + c10::to_string(i), c10::IValue(static_cast<int64_t>(i)));
|
"element/" + std::to_string(i), c10::IValue(static_cast<int64_t>(i)));
|
||||||
}
|
}
|
||||||
output_archive.save_to(tempfile.name);
|
output_archive.save_to(tempfile.name);
|
||||||
torch::serialize::InputArchive input_archive;
|
torch::serialize::InputArchive input_archive;
|
||||||
@ -557,7 +557,7 @@ TEST(SerializeTest, Optim_Adagrad) {
|
|||||||
const auto& params_ = optim1.param_groups()[0].params();
|
const auto& params_ = optim1.param_groups()[0].params();
|
||||||
const auto& optim1_state = optim1.state();
|
const auto& optim1_state = optim1.state();
|
||||||
for (const auto& param : params_) {
|
for (const auto& param : params_) {
|
||||||
auto key_ = c10::guts::to_string(param.unsafeGetTensorImpl());
|
auto key_ = param.unsafeGetTensorImpl();
|
||||||
const AdagradParamState& curr_state_ =
|
const AdagradParamState& curr_state_ =
|
||||||
static_cast<const AdagradParamState&>(*(optim1_state.at(key_).get()));
|
static_cast<const AdagradParamState&>(*(optim1_state.at(key_).get()));
|
||||||
sum_buffers.emplace_back(curr_state_.sum());
|
sum_buffers.emplace_back(curr_state_.sum());
|
||||||
@ -602,7 +602,7 @@ TEST(SerializeTest, Optim_SGD) {
|
|||||||
const auto& optim1_state = optim1.state();
|
const auto& optim1_state = optim1.state();
|
||||||
for (const auto i : c10::irange(params_.size())) {
|
for (const auto i : c10::irange(params_.size())) {
|
||||||
if (i != (params_.size() - 1)) {
|
if (i != (params_.size() - 1)) {
|
||||||
auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
|
auto key_ = params_[i].unsafeGetTensorImpl();
|
||||||
const SGDParamState& curr_state_ =
|
const SGDParamState& curr_state_ =
|
||||||
static_cast<const SGDParamState&>(*(optim1_state.at(key_).get()));
|
static_cast<const SGDParamState&>(*(optim1_state.at(key_).get()));
|
||||||
momentum_buffers.emplace_back(curr_state_.momentum_buffer());
|
momentum_buffers.emplace_back(curr_state_.momentum_buffer());
|
||||||
@ -653,7 +653,7 @@ TEST(SerializeTest, Optim_Adam) {
|
|||||||
const auto& optim1_state = optim1.state();
|
const auto& optim1_state = optim1.state();
|
||||||
for (const auto i : c10::irange(params_.size())) {
|
for (const auto i : c10::irange(params_.size())) {
|
||||||
if (i != (params_.size() - 1)) {
|
if (i != (params_.size() - 1)) {
|
||||||
auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
|
auto key_ = params_[i].unsafeGetTensorImpl();
|
||||||
const AdamParamState& curr_state_ =
|
const AdamParamState& curr_state_ =
|
||||||
static_cast<const AdamParamState&>(*(optim1_state.at(key_).get()));
|
static_cast<const AdamParamState&>(*(optim1_state.at(key_).get()));
|
||||||
step_buffers.emplace_back(curr_state_.step());
|
step_buffers.emplace_back(curr_state_.step());
|
||||||
@ -712,7 +712,7 @@ TEST(SerializeTest, Optim_AdamW) {
|
|||||||
const auto& optim1_state = optim1.state();
|
const auto& optim1_state = optim1.state();
|
||||||
for (const auto i : c10::irange(params_.size())) {
|
for (const auto i : c10::irange(params_.size())) {
|
||||||
if (i != (params_.size() - 1)) {
|
if (i != (params_.size() - 1)) {
|
||||||
auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
|
auto key_ = params_[i].unsafeGetTensorImpl();
|
||||||
const AdamWParamState& curr_state_ =
|
const AdamWParamState& curr_state_ =
|
||||||
static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
|
static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
|
||||||
step_buffers.emplace_back(curr_state_.step());
|
step_buffers.emplace_back(curr_state_.step());
|
||||||
@ -769,7 +769,7 @@ TEST(SerializeTest, Optim_RMSprop) {
|
|||||||
const auto& optim1_state = optim1.state();
|
const auto& optim1_state = optim1.state();
|
||||||
for (const auto i : c10::irange(params_.size())) {
|
for (const auto i : c10::irange(params_.size())) {
|
||||||
if (i != (params_.size() - 1)) {
|
if (i != (params_.size() - 1)) {
|
||||||
auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
|
auto key_ = params_[i].unsafeGetTensorImpl();
|
||||||
const RMSpropParamState& curr_state_ =
|
const RMSpropParamState& curr_state_ =
|
||||||
static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
|
static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
|
||||||
square_average_buffers.emplace_back(curr_state_.square_avg());
|
square_average_buffers.emplace_back(curr_state_.square_avg());
|
||||||
@ -799,8 +799,8 @@ TEST(SerializeTest, Optim_RMSprop) {
|
|||||||
// old RMSprop didn't track step value
|
// old RMSprop didn't track step value
|
||||||
for (const auto i : c10::irange(params1_2_.size())) {
|
for (const auto i : c10::irange(params1_2_.size())) {
|
||||||
if (i != (params1_2_.size() - 1)) {
|
if (i != (params1_2_.size() - 1)) {
|
||||||
auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl());
|
auto key_ = params_[i].unsafeGetTensorImpl();
|
||||||
auto key1_2_ = c10::guts::to_string(params1_2_[i].unsafeGetTensorImpl());
|
auto key1_2_ = params1_2_[i].unsafeGetTensorImpl();
|
||||||
const RMSpropParamState& curr_state_ =
|
const RMSpropParamState& curr_state_ =
|
||||||
static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
|
static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
|
||||||
RMSpropParamState& curr_state1_2_ =
|
RMSpropParamState& curr_state1_2_ =
|
||||||
@ -838,7 +838,7 @@ TEST(SerializeTest, Optim_LBFGS) {
|
|||||||
std::deque<at::Tensor> old_dirs, old_stps;
|
std::deque<at::Tensor> old_dirs, old_stps;
|
||||||
|
|
||||||
const auto& params_ = optim1.param_groups()[0].params();
|
const auto& params_ = optim1.param_groups()[0].params();
|
||||||
auto key_ = c10::guts::to_string(params_[0].unsafeGetTensorImpl());
|
auto key_ = params_[0].unsafeGetTensorImpl();
|
||||||
const auto& optim1_state =
|
const auto& optim1_state =
|
||||||
static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get()));
|
static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get()));
|
||||||
d = optim1_state.d();
|
d = optim1_state.d();
|
||||||
@ -865,7 +865,7 @@ TEST(SerializeTest, Optim_LBFGS) {
|
|||||||
torch::load, optim1_2, optim_tempfile_old_format.name);
|
torch::load, optim1_2, optim_tempfile_old_format.name);
|
||||||
|
|
||||||
const auto& params1_2_ = optim1_2.param_groups()[0].params();
|
const auto& params1_2_ = optim1_2.param_groups()[0].params();
|
||||||
auto param_key = c10::guts::to_string(params1_2_[0].unsafeGetTensorImpl());
|
auto param_key = params1_2_[0].unsafeGetTensorImpl();
|
||||||
auto& optim1_2_state =
|
auto& optim1_2_state =
|
||||||
static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get()));
|
static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get()));
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
#include <torch/serialize/archive.h>
|
#include <torch/serialize/archive.h>
|
||||||
#include <torch/types.h>
|
#include <torch/types.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -84,8 +85,7 @@ class TORCH_API Adagrad : public Optimizer {
|
|||||||
p.data(),
|
p.data(),
|
||||||
defaults.initial_accumulator_value(),
|
defaults.initial_accumulator_value(),
|
||||||
at::MemoryFormat::Preserve));
|
at::MemoryFormat::Preserve));
|
||||||
state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
|
state_[p.unsafeGetTensorImpl()] = std::move(state);
|
||||||
std::move(state);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -156,12 +156,12 @@ class TORCH_API Optimizer {
|
|||||||
const std::vector<OptimizerParamGroup>& param_groups() const noexcept;
|
const std::vector<OptimizerParamGroup>& param_groups() const noexcept;
|
||||||
|
|
||||||
/// Provides a reference to the state this optimizer holds
|
/// Provides a reference to the state this optimizer holds
|
||||||
ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
|
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
|
||||||
state() noexcept;
|
state() noexcept;
|
||||||
|
|
||||||
/// Provides a const reference to the state this optimizer holds
|
/// Provides a const reference to the state this optimizer holds
|
||||||
const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
|
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& state()
|
||||||
state() const noexcept;
|
const noexcept;
|
||||||
|
|
||||||
/// Serializes the optimizer state into the given `archive`.
|
/// Serializes the optimizer state into the given `archive`.
|
||||||
virtual void save(serialize::OutputArchive& archive) const;
|
virtual void save(serialize::OutputArchive& archive) const;
|
||||||
@ -171,7 +171,7 @@ class TORCH_API Optimizer {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::vector<OptimizerParamGroup> param_groups_;
|
std::vector<OptimizerParamGroup> param_groups_;
|
||||||
ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>> state_;
|
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_;
|
||||||
std::unique_ptr<OptimizerOptions> defaults_;
|
std::unique_ptr<OptimizerOptions> defaults_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -17,11 +17,12 @@ namespace detail {
|
|||||||
template <typename DerivedOptimizerParamState>
|
template <typename DerivedOptimizerParamState>
|
||||||
void serialize(
|
void serialize(
|
||||||
serialize::OutputArchive& archive,
|
serialize::OutputArchive& archive,
|
||||||
const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
|
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
|
||||||
state) {
|
state) {
|
||||||
for (const auto& item : state) {
|
for (const auto& item : state) {
|
||||||
serialize::OutputArchive param_state_archive(archive.compilation_unit());
|
serialize::OutputArchive param_state_archive(archive.compilation_unit());
|
||||||
std::string tensorimpl_key = item.first;
|
std::string tensorimpl_key =
|
||||||
|
std::to_string(reinterpret_cast<size_t>(item.first));
|
||||||
const DerivedOptimizerParamState& curr_state =
|
const DerivedOptimizerParamState& curr_state =
|
||||||
static_cast<const DerivedOptimizerParamState&>(*(item.second.get()));
|
static_cast<const DerivedOptimizerParamState&>(*(item.second.get()));
|
||||||
curr_state.serialize(param_state_archive);
|
curr_state.serialize(param_state_archive);
|
||||||
@ -33,15 +34,14 @@ void serialize(
|
|||||||
template <typename DerivedOptimizerParamState>
|
template <typename DerivedOptimizerParamState>
|
||||||
void serialize(
|
void serialize(
|
||||||
serialize::InputArchive& archive,
|
serialize::InputArchive& archive,
|
||||||
ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
|
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& state) {
|
||||||
state) {
|
|
||||||
std::vector<std::string> tensorimpl_keys = archive.keys();
|
std::vector<std::string> tensorimpl_keys = archive.keys();
|
||||||
for (const std::string& tensorimpl_key : tensorimpl_keys) {
|
for (const std::string& tensorimpl_key : tensorimpl_keys) {
|
||||||
serialize::InputArchive param_state_archive;
|
serialize::InputArchive param_state_archive;
|
||||||
archive.read(tensorimpl_key, param_state_archive);
|
archive.read(tensorimpl_key, param_state_archive);
|
||||||
DerivedOptimizerParamState param_state;
|
DerivedOptimizerParamState param_state;
|
||||||
param_state.serialize(param_state_archive);
|
param_state.serialize(param_state_archive);
|
||||||
state[tensorimpl_key] =
|
state[reinterpret_cast<void*>(std::stoull(tensorimpl_key))] =
|
||||||
std::make_unique<DerivedOptimizerParamState>(param_state);
|
std::make_unique<DerivedOptimizerParamState>(param_state);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -61,8 +61,9 @@ void serialize(
|
|||||||
"params/size", torch::tensor(static_cast<int64_t>(params.size())));
|
"params/size", torch::tensor(static_cast<int64_t>(params.size())));
|
||||||
for (const auto index : c10::irange(params.size())) {
|
for (const auto index : c10::irange(params.size())) {
|
||||||
param_group_archive.write(
|
param_group_archive.write(
|
||||||
"params/" + c10::guts::to_string(index),
|
"params/" + std::to_string(index),
|
||||||
IValue(c10::guts::to_string(params[index].unsafeGetTensorImpl())));
|
IValue(std::to_string(
|
||||||
|
reinterpret_cast<size_t>(params[index].unsafeGetTensorImpl()))));
|
||||||
}
|
}
|
||||||
const DerivedOptimizerParamOptions& param_group_options =
|
const DerivedOptimizerParamOptions& param_group_options =
|
||||||
static_cast<const DerivedOptimizerParamOptions&>(
|
static_cast<const DerivedOptimizerParamOptions&>(
|
||||||
@ -71,8 +72,7 @@ void serialize(
|
|||||||
param_group_archive.compilation_unit());
|
param_group_archive.compilation_unit());
|
||||||
param_group_options.serialize(param_group_options_archive);
|
param_group_options.serialize(param_group_options_archive);
|
||||||
param_group_archive.write("options", param_group_options_archive);
|
param_group_archive.write("options", param_group_options_archive);
|
||||||
archive.write(
|
archive.write("param_groups/" + std::to_string(i), param_group_archive);
|
||||||
"param_groups/" + c10::guts::to_string(i), param_group_archive);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,15 +92,14 @@ void serialize(
|
|||||||
const int64_t param_groups_size = param_groups_size_tensor.item<int64_t>();
|
const int64_t param_groups_size = param_groups_size_tensor.item<int64_t>();
|
||||||
for (const auto i : c10::irange(param_groups_size)) {
|
for (const auto i : c10::irange(param_groups_size)) {
|
||||||
serialize::InputArchive param_group_archive;
|
serialize::InputArchive param_group_archive;
|
||||||
archive.read(
|
archive.read("param_groups/" + std::to_string(i), param_group_archive);
|
||||||
"param_groups/" + c10::guts::to_string(i), param_group_archive);
|
|
||||||
torch::Tensor size_tensor;
|
torch::Tensor size_tensor;
|
||||||
param_group_archive.read("params/size", size_tensor);
|
param_group_archive.read("params/size", size_tensor);
|
||||||
const int64_t size = size_tensor.item<int64_t>();
|
const int64_t size = size_tensor.item<int64_t>();
|
||||||
std::vector<std::string> params;
|
std::vector<std::string> params;
|
||||||
for (const auto index : c10::irange(size)) {
|
for (const auto index : c10::irange(size)) {
|
||||||
IValue ivalue;
|
IValue ivalue;
|
||||||
param_group_archive.read("params/" + c10::to_string(index), ivalue);
|
param_group_archive.read("params/" + std::to_string(index), ivalue);
|
||||||
std::string element = ivalue.toStringRef();
|
std::string element = ivalue.toStringRef();
|
||||||
params.emplace_back(element);
|
params.emplace_back(element);
|
||||||
}
|
}
|
||||||
@ -170,8 +169,7 @@ void serialize(serialize::InputArchive& archive, Optimizer& optimizer) {
|
|||||||
TORCH_INTERNAL_ASSERT(pytorch_version.toStringRef() == "1.5.0");
|
TORCH_INTERNAL_ASSERT(pytorch_version.toStringRef() == "1.5.0");
|
||||||
serialize::InputArchive state_archive;
|
serialize::InputArchive state_archive;
|
||||||
archive.read("state", state_archive);
|
archive.read("state", state_archive);
|
||||||
ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>
|
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> saved_state;
|
||||||
saved_state;
|
|
||||||
detail::serialize<DerivedOptimizerParamState>(state_archive, saved_state);
|
detail::serialize<DerivedOptimizerParamState>(state_archive, saved_state);
|
||||||
|
|
||||||
serialize::InputArchive param_groups_archive;
|
serialize::InputArchive param_groups_archive;
|
||||||
@ -194,10 +192,11 @@ void serialize(serialize::InputArchive& archive, Optimizer& optimizer) {
|
|||||||
"loaded state dict contains a parameter group that has a different size than the optimizer's parameter group");
|
"loaded state dict contains a parameter group that has a different size than the optimizer's parameter group");
|
||||||
|
|
||||||
for (const auto idx : c10::irange(params.size())) {
|
for (const auto idx : c10::irange(params.size())) {
|
||||||
if (saved_state.find(param_group_old_keys[idx]) != saved_state.end()) {
|
auto param_group_old_key =
|
||||||
optimizer
|
reinterpret_cast<void*>(std::stoull(param_group_old_keys[idx]));
|
||||||
.state()[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] =
|
if (saved_state.find(param_group_old_key) != saved_state.end()) {
|
||||||
std::move(saved_state[param_group_old_keys[idx]]);
|
optimizer.state()[params[idx].unsafeGetTensorImpl()] =
|
||||||
|
std::move(saved_state[param_group_old_key]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -213,7 +212,7 @@ void serialize(
|
|||||||
key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
|
key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
|
||||||
for (const auto index : c10::irange(buffers.size())) {
|
for (const auto index : c10::irange(buffers.size())) {
|
||||||
archive.write(
|
archive.write(
|
||||||
key + "/" + c10::to_string(index), buffers[index], /*is_buffer=*/true);
|
key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,7 +229,7 @@ void serialize(
|
|||||||
for (const auto index : c10::irange(size)) {
|
for (const auto index : c10::irange(size)) {
|
||||||
buffers.emplace_back();
|
buffers.emplace_back();
|
||||||
archive.read(
|
archive.read(
|
||||||
key + "/" + c10::to_string(index), buffers.back(), /*is_buffer=*/true);
|
key + "/" + std::to_string(index), buffers.back(), /*is_buffer=*/true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ void save(const std::vector<torch::Tensor>& tensor_vec, SaveToArgs&&... args) {
|
|||||||
serialize::OutputArchive archive(std::make_shared<jit::CompilationUnit>());
|
serialize::OutputArchive archive(std::make_shared<jit::CompilationUnit>());
|
||||||
for (const auto i : c10::irange(tensor_vec.size())) {
|
for (const auto i : c10::irange(tensor_vec.size())) {
|
||||||
auto& value = tensor_vec[i];
|
auto& value = tensor_vec[i];
|
||||||
archive.write(c10::to_string(i), value);
|
archive.write(std::to_string(i), value);
|
||||||
}
|
}
|
||||||
archive.save_to(std::forward<SaveToArgs>(args)...);
|
archive.save_to(std::forward<SaveToArgs>(args)...);
|
||||||
}
|
}
|
||||||
@ -135,7 +135,7 @@ void load(std::vector<torch::Tensor>& tensor_vec, LoadFromArgs&&... args) {
|
|||||||
// the serialized `std::vector<torch::Tensor>`.
|
// the serialized `std::vector<torch::Tensor>`.
|
||||||
size_t index = 0;
|
size_t index = 0;
|
||||||
torch::Tensor value;
|
torch::Tensor value;
|
||||||
while (archive.try_read(c10::to_string(index), value)) {
|
while (archive.try_read(std::to_string(index), value)) {
|
||||||
tensor_vec.push_back(std::move(value));
|
tensor_vec.push_back(std::move(value));
|
||||||
value = torch::Tensor();
|
value = torch::Tensor();
|
||||||
index++;
|
index++;
|
||||||
|
@ -77,11 +77,11 @@ Tensor Adagrad::step(LossClosure closure) {
|
|||||||
}
|
}
|
||||||
auto grad = p.grad();
|
auto grad = p.grad();
|
||||||
TORCH_INTERNAL_ASSERT(
|
TORCH_INTERNAL_ASSERT(
|
||||||
state_[c10::guts::to_string(p.unsafeGetTensorImpl())] != nullptr,
|
state_[p.unsafeGetTensorImpl()] != nullptr,
|
||||||
"state found NULL for the Tensor ",
|
"state found NULL for the Tensor ",
|
||||||
p);
|
p);
|
||||||
auto& state = static_cast<AdagradParamState&>(
|
auto& state =
|
||||||
*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
|
static_cast<AdagradParamState&>(*state_[p.unsafeGetTensorImpl()]);
|
||||||
auto& options = static_cast<AdagradOptions&>(group.options());
|
auto& options = static_cast<AdagradOptions&>(group.options());
|
||||||
|
|
||||||
state.step(state.step() + 1);
|
state.step(state.step() + 1);
|
||||||
@ -147,8 +147,7 @@ void Adagrad::load(serialize::InputArchive& archive) {
|
|||||||
auto state = std::make_unique<AdagradParamState>();
|
auto state = std::make_unique<AdagradParamState>();
|
||||||
state->step(step_buffers[idx]);
|
state->step(step_buffers[idx]);
|
||||||
state->sum(sum_buffers[idx]);
|
state->sum(sum_buffers[idx]);
|
||||||
state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] =
|
state_[params[idx].unsafeGetTensorImpl()] = std::move(state);
|
||||||
std::move(state);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -84,8 +84,7 @@ Tensor Adam::step(LossClosure closure) {
|
|||||||
}
|
}
|
||||||
auto grad = p.grad();
|
auto grad = p.grad();
|
||||||
TORCH_CHECK(!grad.is_sparse(), "Adam does not support sparse gradients" /*, please consider SparseAdam instead*/);
|
TORCH_CHECK(!grad.is_sparse(), "Adam does not support sparse gradients" /*, please consider SparseAdam instead*/);
|
||||||
auto param_state =
|
auto param_state = state_.find(p.unsafeGetTensorImpl());
|
||||||
state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
|
|
||||||
auto& options = static_cast<AdamOptions&>(group.options());
|
auto& options = static_cast<AdamOptions&>(group.options());
|
||||||
|
|
||||||
// State initialization
|
// State initialization
|
||||||
@ -100,12 +99,11 @@ Tensor Adam::step(LossClosure closure) {
|
|||||||
// Maintains max of all exp. moving avg. of sq. grad. values
|
// Maintains max of all exp. moving avg. of sq. grad. values
|
||||||
state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
|
state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
|
||||||
}
|
}
|
||||||
state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
|
state_[p.unsafeGetTensorImpl()] = std::move(state);
|
||||||
std::move(state);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& state = static_cast<AdamParamState&>(
|
auto& state =
|
||||||
*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
|
static_cast<AdamParamState&>(*state_[p.unsafeGetTensorImpl()]);
|
||||||
auto& exp_avg = state.exp_avg();
|
auto& exp_avg = state.exp_avg();
|
||||||
auto& exp_avg_sq = state.exp_avg_sq();
|
auto& exp_avg_sq = state.exp_avg_sq();
|
||||||
auto& max_exp_avg_sq = state.max_exp_avg_sq();
|
auto& max_exp_avg_sq = state.max_exp_avg_sq();
|
||||||
@ -179,8 +177,7 @@ void Adam::load(serialize::InputArchive& archive) {
|
|||||||
if (idx < max_exp_average_sq_buffers.size()) {
|
if (idx < max_exp_average_sq_buffers.size()) {
|
||||||
state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
|
state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
|
||||||
}
|
}
|
||||||
state_[c10::guts::to_string(params.at(idx).unsafeGetTensorImpl())] =
|
state_[params.at(idx).unsafeGetTensorImpl()] = std::move(state);
|
||||||
std::move(state);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -84,8 +84,7 @@ Tensor AdamW::step(LossClosure closure) {
|
|||||||
}
|
}
|
||||||
const auto& grad = p.grad();
|
const auto& grad = p.grad();
|
||||||
TORCH_CHECK(!grad.is_sparse(), "AdamW does not support sparse gradients" /*, please consider SparseAdamW instead*/);
|
TORCH_CHECK(!grad.is_sparse(), "AdamW does not support sparse gradients" /*, please consider SparseAdamW instead*/);
|
||||||
auto param_state =
|
auto param_state = state_.find(p.unsafeGetTensorImpl());
|
||||||
state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
|
|
||||||
auto& options = static_cast<AdamWOptions&>(group.options());
|
auto& options = static_cast<AdamWOptions&>(group.options());
|
||||||
|
|
||||||
// Perform stepweight decay
|
// Perform stepweight decay
|
||||||
@ -105,12 +104,11 @@ Tensor AdamW::step(LossClosure closure) {
|
|||||||
// Maintains max of all exp. moving avg. of sq. grad. values
|
// Maintains max of all exp. moving avg. of sq. grad. values
|
||||||
state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
|
state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve));
|
||||||
}
|
}
|
||||||
state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
|
state_[p.unsafeGetTensorImpl()] = std::move(state);
|
||||||
std::move(state);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& state = static_cast<AdamWParamState&>(
|
auto& state =
|
||||||
*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
|
static_cast<AdamWParamState&>(*state_[p.unsafeGetTensorImpl()]);
|
||||||
auto& exp_avg = state.exp_avg();
|
auto& exp_avg = state.exp_avg();
|
||||||
auto& exp_avg_sq = state.exp_avg_sq();
|
auto& exp_avg_sq = state.exp_avg_sq();
|
||||||
auto& max_exp_avg_sq = state.max_exp_avg_sq();
|
auto& max_exp_avg_sq = state.max_exp_avg_sq();
|
||||||
@ -180,8 +178,7 @@ void AdamW::load(serialize::InputArchive& archive) {
|
|||||||
if (idx < max_exp_average_sq_buffers.size()) {
|
if (idx < max_exp_average_sq_buffers.size()) {
|
||||||
state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
|
state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx));
|
||||||
}
|
}
|
||||||
state_[c10::guts::to_string(params.at(idx).unsafeGetTensorImpl())] =
|
state_[params.at(idx).unsafeGetTensorImpl()] = std::move(state);
|
||||||
std::move(state);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -438,14 +438,13 @@ Tensor LBFGS::step(LossClosure closure) {
|
|||||||
|
|
||||||
// NOTE: LBFGS has only global state, but we register it as state for
|
// NOTE: LBFGS has only global state, but we register it as state for
|
||||||
// the first param, because this helps with casting in load_state_dict
|
// the first param, because this helps with casting in load_state_dict
|
||||||
auto param_state =
|
auto param_state = state_.find(_params.at(0).unsafeGetTensorImpl());
|
||||||
state_.find(c10::guts::to_string(_params.at(0).unsafeGetTensorImpl()));
|
|
||||||
if (param_state == state_.end()) {
|
if (param_state == state_.end()) {
|
||||||
state_[c10::guts::to_string(_params.at(0).unsafeGetTensorImpl())] =
|
state_[_params.at(0).unsafeGetTensorImpl()] =
|
||||||
std::make_unique<LBFGSParamState>();
|
std::make_unique<LBFGSParamState>();
|
||||||
}
|
}
|
||||||
auto& state = static_cast<LBFGSParamState&>(
|
auto& state = static_cast<LBFGSParamState&>(
|
||||||
*state_[c10::guts::to_string(_params.at(0).unsafeGetTensorImpl())]);
|
*state_[_params.at(0).unsafeGetTensorImpl()]);
|
||||||
// evaluate initial f(x) and df/dx
|
// evaluate initial f(x) and df/dx
|
||||||
Tensor orig_loss;
|
Tensor orig_loss;
|
||||||
{
|
{
|
||||||
@ -655,8 +654,7 @@ void LBFGS::load(serialize::InputArchive& archive) {
|
|||||||
state->prev_loss(prev_loss.item<double>());
|
state->prev_loss(prev_loss.item<double>());
|
||||||
state->old_dirs(old_dirs);
|
state->old_dirs(old_dirs);
|
||||||
state->old_stps(old_stps);
|
state->old_stps(old_stps);
|
||||||
state_[c10::guts::to_string(
|
state_[param_groups_.at(0).params().at(0).unsafeGetTensorImpl()] =
|
||||||
param_groups_.at(0).params().at(0).unsafeGetTensorImpl())] =
|
|
||||||
std::move(state);
|
std::move(state);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -109,7 +109,7 @@ void Optimizer::add_param_group(const OptimizerParamGroup& param_group) {
|
|||||||
}
|
}
|
||||||
for (const auto& p : param_group_.params()) {
|
for (const auto& p : param_group_.params()) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
state_.count(c10::guts::to_string(p.unsafeGetTensorImpl())) == 0,
|
state_.count(p.unsafeGetTensorImpl()) == 0,
|
||||||
"some parameters appear in more than one parameter group");
|
"some parameters appear in more than one parameter group");
|
||||||
}
|
}
|
||||||
param_groups_.emplace_back(std::move(param_group_));
|
param_groups_.emplace_back(std::move(param_group_));
|
||||||
@ -171,12 +171,12 @@ const std::vector<OptimizerParamGroup>& Optimizer::param_groups()
|
|||||||
return param_groups_;
|
return param_groups_;
|
||||||
}
|
}
|
||||||
|
|
||||||
ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
|
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>& Optimizer::
|
||||||
Optimizer::state() noexcept {
|
state() noexcept {
|
||||||
return state_;
|
return state_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const ska::flat_hash_map<std::string, std::unique_ptr<OptimizerParamState>>&
|
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
|
||||||
Optimizer::state() const noexcept {
|
Optimizer::state() const noexcept {
|
||||||
return state_;
|
return state_;
|
||||||
}
|
}
|
||||||
|
@ -85,8 +85,7 @@ Tensor RMSprop::step(LossClosure closure) {
|
|||||||
auto grad = p.grad();
|
auto grad = p.grad();
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
!grad.is_sparse(), "RMSprop does not support sparse gradients");
|
!grad.is_sparse(), "RMSprop does not support sparse gradients");
|
||||||
auto param_state =
|
auto param_state = state_.find(p.unsafeGetTensorImpl());
|
||||||
state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
|
|
||||||
auto& options = static_cast<RMSpropOptions&>(group.options());
|
auto& options = static_cast<RMSpropOptions&>(group.options());
|
||||||
|
|
||||||
// State initialization
|
// State initialization
|
||||||
@ -100,12 +99,11 @@ Tensor RMSprop::step(LossClosure closure) {
|
|||||||
if (options.centered()) {
|
if (options.centered()) {
|
||||||
state->grad_avg(torch::zeros_like(p, MemoryFormat::Preserve));
|
state->grad_avg(torch::zeros_like(p, MemoryFormat::Preserve));
|
||||||
}
|
}
|
||||||
state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
|
state_[p.unsafeGetTensorImpl()] = std::move(state);
|
||||||
std::move(state);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& state = static_cast<RMSpropParamState&>(
|
auto& state =
|
||||||
*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
|
static_cast<RMSpropParamState&>(*state_[p.unsafeGetTensorImpl()]);
|
||||||
auto& square_avg = state.square_avg();
|
auto& square_avg = state.square_avg();
|
||||||
auto alpha = options.alpha();
|
auto alpha = options.alpha();
|
||||||
|
|
||||||
@ -176,8 +174,7 @@ void RMSprop::load(serialize::InputArchive& archive) {
|
|||||||
if (idx < grad_average_buffers.size()) {
|
if (idx < grad_average_buffers.size()) {
|
||||||
state->grad_avg(grad_average_buffers.at(idx));
|
state->grad_avg(grad_average_buffers.at(idx));
|
||||||
}
|
}
|
||||||
state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] =
|
state_[params[idx].unsafeGetTensorImpl()] = std::move(state);
|
||||||
std::move(state);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -84,14 +84,12 @@ Tensor SGD::step(LossClosure closure) {
|
|||||||
}
|
}
|
||||||
if (momentum != 0) {
|
if (momentum != 0) {
|
||||||
Tensor buf;
|
Tensor buf;
|
||||||
auto param_state =
|
auto param_state = state_.find(p.unsafeGetTensorImpl());
|
||||||
state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
|
|
||||||
if (param_state == state_.end()) {
|
if (param_state == state_.end()) {
|
||||||
buf = torch::clone(d_p).detach();
|
buf = torch::clone(d_p).detach();
|
||||||
auto state = std::make_unique<SGDParamState>();
|
auto state = std::make_unique<SGDParamState>();
|
||||||
state->momentum_buffer(buf);
|
state->momentum_buffer(buf);
|
||||||
state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
|
state_[p.unsafeGetTensorImpl()] = std::move(state);
|
||||||
std::move(state);
|
|
||||||
} else {
|
} else {
|
||||||
buf = static_cast<SGDParamState&>(*param_state->second)
|
buf = static_cast<SGDParamState&>(*param_state->second)
|
||||||
.momentum_buffer();
|
.momentum_buffer();
|
||||||
@ -130,8 +128,7 @@ void SGD::load(serialize::InputArchive& archive) {
|
|||||||
for (const auto idx : c10::irange(momentum_buffers.size())) {
|
for (const auto idx : c10::irange(momentum_buffers.size())) {
|
||||||
auto state = std::make_unique<SGDParamState>();
|
auto state = std::make_unique<SGDParamState>();
|
||||||
state->momentum_buffer(momentum_buffers[idx]);
|
state->momentum_buffer(momentum_buffers[idx]);
|
||||||
state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] =
|
state_[params[idx].unsafeGetTensorImpl()] = std::move(state);
|
||||||
std::move(state);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
#include <torch/csrc/autograd/profiler_kineto.h>
|
#include <torch/csrc/autograd/profiler_kineto.h>
|
||||||
|
|
||||||
#include <c10/macros/Export.h>
|
#include <c10/macros/Export.h>
|
||||||
#include <c10/util/C++17.h>
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <c10/util/flat_hash_map.h>
|
#include <c10/util/flat_hash_map.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
|
@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
#include <ATen/core/TensorBase.h>
|
#include <ATen/core/TensorBase.h>
|
||||||
#include <c10/macros/Macros.h>
|
#include <c10/macros/Macros.h>
|
||||||
#include <c10/util/C++17.h>
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <c10/util/Logging.h>
|
#include <c10/util/Logging.h>
|
||||||
#include <c10/util/Optional.h>
|
#include <c10/util/Optional.h>
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
#include <c10/util/C++17.h>
|
|
||||||
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
|
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
|
||||||
#include <torch/csrc/distributed/rpc/rpc_agent.h>
|
#include <torch/csrc/distributed/rpc/rpc_agent.h>
|
||||||
#include <torch/csrc/distributed/rpc/utils.h>
|
#include <torch/csrc/distributed/rpc/utils.h>
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
#include <torch/csrc/distributed/rpc/python_call.h>
|
#include <torch/csrc/distributed/rpc/python_call.h>
|
||||||
|
|
||||||
#include <c10/util/C++17.h>
|
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace distributed {
|
namespace distributed {
|
||||||
namespace rpc {
|
namespace rpc {
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
#include <ATen/ThreadLocalState.h>
|
#include <ATen/ThreadLocalState.h>
|
||||||
#include <c10/util/C++17.h>
|
|
||||||
#include <torch/csrc/distributed/autograd/context/container.h>
|
#include <torch/csrc/distributed/autograd/context/container.h>
|
||||||
#include <torch/csrc/distributed/autograd/utils.h>
|
#include <torch/csrc/distributed/autograd/utils.h>
|
||||||
#include <torch/csrc/distributed/rpc/message.h>
|
#include <torch/csrc/distributed/rpc/message.h>
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
#include <torch/csrc/distributed/rpc/python_resp.h>
|
#include <torch/csrc/distributed/rpc/python_resp.h>
|
||||||
|
|
||||||
#include <c10/util/C++17.h>
|
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace distributed {
|
namespace distributed {
|
||||||
namespace rpc {
|
namespace rpc {
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#include <torch/csrc/distributed/rpc/script_resp.h>
|
#include <torch/csrc/distributed/rpc/script_resp.h>
|
||||||
|
|
||||||
#include <c10/util/C++17.h>
|
|
||||||
#include <torch/csrc/distributed/rpc/rpc_agent.h>
|
#include <torch/csrc/distributed/rpc/rpc_agent.h>
|
||||||
#include <torch/csrc/jit/serialization/pickle.h>
|
#include <torch/csrc/jit/serialization/pickle.h>
|
||||||
#include <torch/csrc/jit/serialization/unpickler.h>
|
#include <torch/csrc/jit/serialization/unpickler.h>
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#include <torch/csrc/distributed/rpc/unpickled_python_call.h>
|
#include <torch/csrc/distributed/rpc/unpickled_python_call.h>
|
||||||
|
|
||||||
#include <c10/util/C++17.h>
|
|
||||||
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
|
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h>
|
#include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h>
|
||||||
|
|
||||||
#include <c10/util/C++17.h>
|
|
||||||
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
|
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
|
@ -93,7 +93,7 @@ C10_EXPORT std::string kindToString(int kind) {
|
|||||||
TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
|
TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
|
||||||
#undef DEFINE_CASE
|
#undef DEFINE_CASE
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown kind: " + c10::guts::to_string(kind));
|
throw std::runtime_error("Unknown kind: " + std::to_string(kind));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -518,8 +518,7 @@ struct Lexer {
|
|||||||
indent_stack.pop_back();
|
indent_stack.pop_back();
|
||||||
next_tokens.emplace_back(TK_DEDENT, r.range);
|
next_tokens.emplace_back(TK_DEDENT, r.range);
|
||||||
if (indent_stack.empty()) {
|
if (indent_stack.empty()) {
|
||||||
reportError(
|
reportError("invalid indent level " + std::to_string(depth), r);
|
||||||
"invalid indent level " + c10::guts::to_string(depth), r);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return; // We've already queued the tokens
|
return; // We've already queued the tokens
|
||||||
|
@ -138,7 +138,7 @@ c10::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
|
|||||||
L.expect(')');
|
L.expect(')');
|
||||||
} else if (L.nextIf('!')) {
|
} else if (L.nextIf('!')) {
|
||||||
alias_info.addBeforeSet(
|
alias_info.addBeforeSet(
|
||||||
Symbol::fromQualString("alias::$" + c10::guts::to_string(next_id++)));
|
Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
|
||||||
alias_info.setIsWrite(true);
|
alias_info.setIsWrite(true);
|
||||||
} else {
|
} else {
|
||||||
return c10::nullopt;
|
return c10::nullopt;
|
||||||
|
@ -508,7 +508,7 @@ RangeValue::RangeValue(
|
|||||||
if (!typ->cast<IntType>()) {
|
if (!typ->cast<IntType>()) {
|
||||||
throw ErrorReport(loc)
|
throw ErrorReport(loc)
|
||||||
<< "all inputs of range must be ints, found " << typ->repr_str()
|
<< "all inputs of range must be ints, found " << typ->repr_str()
|
||||||
<< " in argument " << c10::guts::to_string(i);
|
<< " in argument " << std::to_string(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -564,7 +564,7 @@ mobile::Module _load_for_mobile_impl(
|
|||||||
// Add model_name and model_size to metadata_map
|
// Add model_name and model_size to metadata_map
|
||||||
extra_files.insert(std::make_pair("model_name", result.name()));
|
extra_files.insert(std::make_pair("model_name", result.name()));
|
||||||
extra_files.insert(
|
extra_files.insert(
|
||||||
std::make_pair("model_size", c10::guts::to_string(model_size)));
|
std::make_pair("model_size", std::to_string(model_size)));
|
||||||
metadata_map = observer->processMetadataFromExtra(extra_files);
|
metadata_map = observer->processMetadataFromExtra(extra_files);
|
||||||
observer->onExitLoadModel(instance_key, metadata_map);
|
observer->onExitLoadModel(instance_key, metadata_map);
|
||||||
}
|
}
|
||||||
|
@ -63,7 +63,7 @@ void SGD::add_param_group(const SGDParamGroup& param_group) {
|
|||||||
}
|
}
|
||||||
for (const auto& p : param_group_.params()) {
|
for (const auto& p : param_group_.params()) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
state_.count(c10::guts::to_string(p.unsafeGetTensorImpl())) == 0,
|
state_.count(p.unsafeGetTensorImpl()) == 0,
|
||||||
"some parameters appear in more than one parameter group");
|
"some parameters appear in more than one parameter group");
|
||||||
}
|
}
|
||||||
param_groups_.emplace_back(std::move(param_group_));
|
param_groups_.emplace_back(std::move(param_group_));
|
||||||
@ -104,14 +104,12 @@ Tensor SGD::step(const LossClosure& closure) {
|
|||||||
}
|
}
|
||||||
if (momentum != 0) {
|
if (momentum != 0) {
|
||||||
Tensor buf;
|
Tensor buf;
|
||||||
auto param_state =
|
auto param_state = state_.find(p.unsafeGetTensorImpl());
|
||||||
state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
|
|
||||||
if (param_state == state_.end()) {
|
if (param_state == state_.end()) {
|
||||||
buf = torch::clone(d_p).detach();
|
buf = torch::clone(d_p).detach();
|
||||||
auto state = std::make_unique<SGDParamState>();
|
auto state = std::make_unique<SGDParamState>();
|
||||||
state->momentum_buffer(buf);
|
state->momentum_buffer(buf);
|
||||||
state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
|
state_[p.unsafeGetTensorImpl()] = std::move(state);
|
||||||
std::move(state);
|
|
||||||
} else {
|
} else {
|
||||||
buf = static_cast<SGDParamState&>(*param_state->second)
|
buf = static_cast<SGDParamState&>(*param_state->second)
|
||||||
.momentum_buffer();
|
.momentum_buffer();
|
||||||
|
@ -122,7 +122,7 @@ class TORCH_API SGD {
|
|||||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||||
std::vector<SGDParamGroup> param_groups_;
|
std::vector<SGDParamGroup> param_groups_;
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||||
ska::flat_hash_map<std::string, std::unique_ptr<SGDParamState>> state_;
|
ska::flat_hash_map<void*, std::unique_ptr<SGDParamState>> state_;
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||||
std::unique_ptr<SGDOptions> defaults_;
|
std::unique_ptr<SGDOptions> defaults_;
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||||
|
@ -126,7 +126,7 @@ Node* insertDeQuant(
|
|||||||
Node* dequant = graph->create(Symbol::aten("dequantize"), {quantized_val});
|
Node* dequant = graph->create(Symbol::aten("dequantize"), {quantized_val});
|
||||||
dequant->output()
|
dequant->output()
|
||||||
->setDebugName(
|
->setDebugName(
|
||||||
original_val->debugName() + ".dequant." + c10::guts::to_string(id))
|
original_val->debugName() + ".dequant." + std::to_string(id))
|
||||||
->setType(original_val->type());
|
->setType(original_val->type());
|
||||||
graph->insertNode(dequant);
|
graph->insertNode(dequant);
|
||||||
return dequant;
|
return dequant;
|
||||||
|
@ -305,7 +305,7 @@ void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
|
|||||||
// root_key
|
// root_key
|
||||||
std::string root_key = get_tensor_id_ != nullptr
|
std::string root_key = get_tensor_id_ != nullptr
|
||||||
? get_tensor_id_(tensor)
|
? get_tensor_id_(tensor)
|
||||||
: c10::to_string(tensor_data_.size());
|
: std::to_string(tensor_data_.size());
|
||||||
pushString(root_key);
|
pushString(root_key);
|
||||||
// location
|
// location
|
||||||
pushString(tensor.device().str());
|
pushString(tensor.device().str());
|
||||||
|
@ -362,7 +362,7 @@ struct PythonPrintImpl {
|
|||||||
std::string name = candidate;
|
std::string name = candidate;
|
||||||
while (used.count(name) || reserved_names.count(name)) {
|
while (used.count(name) || reserved_names.count(name)) {
|
||||||
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
|
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
|
||||||
name = candidate + c10::to_string(next_id[name]++);
|
name = candidate + std::to_string(next_id[name]++);
|
||||||
}
|
}
|
||||||
used.insert(name);
|
used.insert(name);
|
||||||
return name;
|
return name;
|
||||||
|
Reference in New Issue
Block a user