mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow building the C++ API without cereal (#11498)
Summary: I am working on unifying the C++ extensions and C++ API, and one constraint for this is that we will want to be able to build the C++ API without cereal, since we won't want to ship it with the Python `torch` package. For this I introduce a `TORCH_WITH_CEREAL` option to CMake. If on, the C++ API will be built with cereal and thus serialization support. If off, serialization functions will throw exceptions, but the library will otherwise still compile the same. __This option is on by default, so for regular C++ API users nothing will change__. However, from C++ extensions, we'll be able to turn it off. This effectively means we won't be searching for any cereal headers from C++ API headers, which wouldn't be installed in the Python package. ebetica ezyang soumith Pull Request resolved: https://github.com/pytorch/pytorch/pull/11498 Differential Revision: D9784803 Pulled By: goldsborough fbshipit-source-id: 5d0a1f2501993012d28cf3d730f45932b483abc4
This commit is contained in:
committed by
Facebook Github Bot
parent
12efef166a
commit
130d55a5f4
@ -124,6 +124,7 @@ cmake_dependent_option(
|
||||
cmake_dependent_option(
|
||||
USE_GLOO_IBVERBS "Use Gloo IB verbs for distributed. Only available if USE_GLOO is on." OFF
|
||||
"USE_GLOO" OFF)
|
||||
option(TORCH_USE_CEREAL "Build the C++ API with Cereal for serialization support" OFF)
|
||||
|
||||
# Used when building Caffe2 through setup.py
|
||||
option(BUILDING_WITH_TORCH_LIBS "Tell cmake if Caffe2 is being built alongside torch libs" OFF)
|
||||
|
@ -125,6 +125,7 @@ function (caffe2_print_configuration_summary)
|
||||
message(STATUS " USE_GLOO : ${USE_GLOO}")
|
||||
message(STATUS " USE_GLOO_IBVERBS : ${USE_GLOO_IBVERBS}")
|
||||
endif()
|
||||
message(STATUS " TORCH_USE_CEREAL : ${TORCH_USE_CEREAL}")
|
||||
|
||||
message(STATUS " Public Dependencies : ${Caffe2_PUBLIC_DEPENDENCY_LIBS}")
|
||||
message(STATUS " Private Dependencies : ${Caffe2_DEPENDENCY_LIBS}")
|
||||
|
@ -9,7 +9,8 @@ from setup_helpers.cuda import USE_CUDA
|
||||
if __name__ == '__main__':
|
||||
# Placeholder for future interface. For now just gives a nice -h.
|
||||
parser = argparse.ArgumentParser(description='Build libtorch')
|
||||
args = parser.parse_args()
|
||||
parser.add_argument('--use-cereal', action='store_true')
|
||||
options = parser.parse_args()
|
||||
|
||||
os.environ['BUILD_TORCH'] = 'ON'
|
||||
os.environ['BUILD_TEST'] = 'ON'
|
||||
@ -19,11 +20,13 @@ if __name__ == '__main__':
|
||||
tools_path = os.path.dirname(os.path.abspath(__file__))
|
||||
build_pytorch_libs = os.path.join(tools_path, 'build_pytorch_libs.sh')
|
||||
|
||||
command = '{} --use-nnpack '.format(build_pytorch_libs)
|
||||
command = [build_pytorch_libs, '--use-nnpack']
|
||||
if USE_CUDA:
|
||||
command += '--use-cuda '
|
||||
command += 'caffe2'
|
||||
command.append('--use-cuda')
|
||||
if options.use_cereal:
|
||||
command.append('--use-cereal')
|
||||
command.append('caffe2')
|
||||
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
subprocess.check_call(shlex.split(command), universal_newlines=True)
|
||||
subprocess.check_call(command, universal_newlines=True)
|
||||
|
@ -22,6 +22,7 @@ USE_NNPACK=0
|
||||
USE_MKLDNN=0
|
||||
USE_GLOO_IBVERBS=0
|
||||
CAFFE2_STATIC_LINK_CUDA=0
|
||||
TORCH_USE_CEREAL=0
|
||||
RERUN_CMAKE=1
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
@ -46,6 +47,9 @@ while [[ $# -gt 0 ]]; do
|
||||
--cuda-static-link)
|
||||
CAFFE2_STATIC_LINK_CUDA=1
|
||||
;;
|
||||
--use-cereal)
|
||||
TORCH_USE_CEREAL=1
|
||||
;;
|
||||
*)
|
||||
break
|
||||
;;
|
||||
@ -190,6 +194,7 @@ function build() {
|
||||
-DTHCUNN_SO_VERSION=1 \
|
||||
-DTHD_SO_VERSION=1 \
|
||||
-DUSE_CUDA=$USE_CUDA \
|
||||
-DTORCH_USE_CEREAL=$TORCH_USE_CEREAL \
|
||||
-DBUILD_EXAMPLES=OFF \
|
||||
-DBUILD_TEST=$BUILD_TEST \
|
||||
-DNO_NNPACK=$((1-$USE_NNPACK)) \
|
||||
|
@ -211,7 +211,6 @@ if (NOT NO_API AND NOT USE_ROCM)
|
||||
${TORCH_SRC_DIR}/csrc/api/src/optim/sgd.cpp
|
||||
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
|
||||
@ -302,6 +301,13 @@ if (NOT NO_API AND NOT USE_ROCM)
|
||||
target_include_directories(torch PUBLIC
|
||||
${TORCH_SRC_DIR}/csrc/api
|
||||
${TORCH_SRC_DIR}/csrc/api/include)
|
||||
|
||||
if (TORCH_USE_CEREAL)
|
||||
target_compile_definitions(torch PUBLIC TORCH_USE_CEREAL)
|
||||
# SYSTEM headers are included with -isystem and thus do not trigger warnings.
|
||||
target_include_directories(torch SYSTEM PUBLIC
|
||||
"${TORCH_ROOT}/third_party/cereal/include") # For cereal/
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(USE_CUDA)
|
||||
@ -369,10 +375,6 @@ target_include_directories(torch PRIVATE ${ATen_CPU_INCLUDE})
|
||||
target_include_directories(torch PUBLIC
|
||||
${TORCH_SRC_DIR}/csrc)
|
||||
|
||||
# SYSTEM headers are included with -isystem and thus do not trigger warnings.
|
||||
target_include_directories(torch SYSTEM PUBLIC
|
||||
"${TORCH_ROOT}/third_party/cereal/include") # For cereal/
|
||||
|
||||
set_target_properties(torch PROPERTIES VERSION 1 SOVERSION 1)
|
||||
|
||||
if(NOT ${CMAKE_VERSION} VERSION_LESS "3.1")
|
||||
@ -407,7 +409,7 @@ endif()
|
||||
if (BUILD_TEST AND NOT NO_API AND NOT USE_ROCM)
|
||||
set(TORCH_API_TEST_DIR "${TORCH_ROOT}/test/cpp/api")
|
||||
|
||||
add_executable(test_api
|
||||
set(TORCH_API_TEST_SOURCES
|
||||
${TORCH_API_TEST_DIR}/any.cpp
|
||||
${TORCH_API_TEST_DIR}/cursor.cpp
|
||||
${TORCH_API_TEST_DIR}/integration.cpp
|
||||
@ -419,15 +421,19 @@ if (BUILD_TEST AND NOT NO_API AND NOT USE_ROCM)
|
||||
${TORCH_API_TEST_DIR}/parallel.cpp
|
||||
${TORCH_API_TEST_DIR}/rnn.cpp
|
||||
${TORCH_API_TEST_DIR}/sequential.cpp
|
||||
${TORCH_API_TEST_DIR}/serialization.cpp
|
||||
${TORCH_API_TEST_DIR}/static.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor_cuda.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor.cpp
|
||||
${TORCH_API_TEST_DIR}/jit.cpp
|
||||
# Temporary until ATen tests are built with Caffe2
|
||||
${TORCH_API_TEST_DIR}/tensor_options.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor_options_cuda.cpp
|
||||
)
|
||||
)
|
||||
|
||||
if (TORCH_USE_CEREAL)
|
||||
list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/serialization.cpp)
|
||||
endif()
|
||||
|
||||
add_executable(test_api ${TORCH_API_TEST_SOURCES})
|
||||
|
||||
target_include_directories(test_api
|
||||
PUBLIC
|
||||
|
@ -2,13 +2,11 @@
|
||||
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/serialization.h>
|
||||
#include <torch/tensor.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cereal/access.hpp>
|
||||
#include <cereal/cereal.hpp>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@ -37,12 +35,16 @@ class Adagrad : public Optimizer {
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar) {
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
ar(CEREAL_NVP(sum_));
|
||||
ar(CEREAL_NVP(step_));
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
}
|
||||
|
||||
private:
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
friend class cereal::access;
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
Adagrad() : options(0) {}
|
||||
|
||||
std::vector<Tensor> sum_;
|
||||
@ -50,3 +52,10 @@ class Adagrad : public Optimizer {
|
||||
};
|
||||
} // namespace optim
|
||||
} // namespace torch
|
||||
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
CEREAL_REGISTER_TYPE(torch::optim::Adagrad);
|
||||
CEREAL_REGISTER_POLYMORPHIC_RELATION(
|
||||
torch::optim::Optimizer,
|
||||
torch::optim::Adagrad);
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
|
@ -3,12 +3,10 @@
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/serialization.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cereal/access.hpp>
|
||||
#include <cereal/cereal.hpp>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@ -36,16 +34,20 @@ class Adam : public Optimizer {
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar) {
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
ar(CEREAL_NVP(step_buffers_),
|
||||
CEREAL_NVP(exp_average_buffers_),
|
||||
CEREAL_NVP(exp_average_sq_buffers_),
|
||||
CEREAL_NVP(max_exp_average_sq_buffers_));
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
}
|
||||
|
||||
AdamOptions options;
|
||||
|
||||
private:
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
friend class cereal::access;
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
Adam() : options(0) {}
|
||||
|
||||
std::vector<int64_t> step_buffers_;
|
||||
@ -53,6 +55,12 @@ class Adam : public Optimizer {
|
||||
std::vector<Tensor> exp_average_sq_buffers_;
|
||||
std::vector<Tensor> max_exp_average_sq_buffers_;
|
||||
};
|
||||
|
||||
} // namespace optim
|
||||
} // namespace torch
|
||||
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
CEREAL_REGISTER_TYPE(torch::optim::Adam);
|
||||
CEREAL_REGISTER_POLYMORPHIC_RELATION(
|
||||
torch::optim::Optimizer,
|
||||
torch::optim::Adam);
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
|
@ -2,12 +2,10 @@
|
||||
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/serialization.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cereal/access.hpp>
|
||||
#include <cereal/cereal.hpp>
|
||||
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
@ -41,6 +39,7 @@ class LBFGS : public LossClosureOptimizer {
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar) {
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
ar(CEREAL_NVP(d));
|
||||
ar(CEREAL_NVP(t));
|
||||
ar(CEREAL_NVP(H_diag));
|
||||
@ -48,10 +47,13 @@ class LBFGS : public LossClosureOptimizer {
|
||||
ar(CEREAL_NVP(prev_loss));
|
||||
ar(CEREAL_NVP(old_dirs));
|
||||
ar(CEREAL_NVP(old_stps));
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
}
|
||||
|
||||
private:
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
friend class cereal::access;
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
LBFGS() : options(0) {}
|
||||
|
||||
Tensor gather_flat_grad();
|
||||
@ -69,6 +71,5 @@ class LBFGS : public LossClosureOptimizer {
|
||||
int64_t func_evals{0};
|
||||
int64_t state_n_iter{0};
|
||||
};
|
||||
|
||||
} // namespace optim
|
||||
} // namespace torch
|
||||
|
@ -2,12 +2,10 @@
|
||||
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/serialization.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cereal/access.hpp>
|
||||
#include <cereal/cereal.hpp>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@ -41,19 +39,29 @@ class RMSprop : public Optimizer {
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar) {
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
ar(CEREAL_NVP(square_average_buffers_));
|
||||
ar(CEREAL_NVP(momentum_buffers_));
|
||||
ar(CEREAL_NVP(grad_average_buffers_));
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
}
|
||||
|
||||
private:
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
friend class cereal::access;
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
RMSprop() : options(0) {}
|
||||
|
||||
std::vector<Tensor> square_average_buffers_;
|
||||
std::vector<Tensor> momentum_buffers_;
|
||||
std::vector<Tensor> grad_average_buffers_;
|
||||
};
|
||||
|
||||
} // namespace optim
|
||||
} // namespace torch
|
||||
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
CEREAL_REGISTER_TYPE(torch::optim::RMSprop);
|
||||
CEREAL_REGISTER_POLYMORPHIC_RELATION(
|
||||
torch::optim::Optimizer,
|
||||
torch::optim::RMSprop);
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
|
@ -3,13 +3,11 @@
|
||||
#include <torch/nn/module.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/optim/optimizer.h>
|
||||
#include <torch/serialization.h>
|
||||
#include <torch/tensor.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cereal/access.hpp>
|
||||
#include <cereal/cereal.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@ -37,13 +35,17 @@ class SGD : public Optimizer {
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar) {
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
ar(CEREAL_NVP(momentum_buffers_));
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
}
|
||||
|
||||
SGDOptions options;
|
||||
|
||||
private:
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
friend class cereal::access;
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
SGD() : options(0) {}
|
||||
|
||||
std::vector<Tensor> momentum_buffers_;
|
||||
@ -52,3 +54,10 @@ class SGD : public Optimizer {
|
||||
};
|
||||
} // namespace optim
|
||||
} // namespace torch
|
||||
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
CEREAL_REGISTER_TYPE(torch::optim::SGD);
|
||||
CEREAL_REGISTER_POLYMORPHIC_RELATION(
|
||||
torch::optim::Optimizer,
|
||||
torch::optim::SGD);
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
|
@ -2,45 +2,69 @@
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include <torch/optim.h>
|
||||
#include <torch/tensor.h>
|
||||
#include <torch/utils.h>
|
||||
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
#include <cereal/access.hpp>
|
||||
#include <cereal/cereal.hpp>
|
||||
#include <cereal/types/polymorphic.hpp>
|
||||
|
||||
#include "cereal/archives/binary.hpp"
|
||||
#include "cereal/types/polymorphic.hpp"
|
||||
|
||||
#include "cereal/types/string.hpp"
|
||||
#include "cereal/types/unordered_map.hpp"
|
||||
#include "cereal/types/vector.hpp"
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
|
||||
namespace torch {
|
||||
|
||||
// Some convenience functions for saving and loading
|
||||
template <typename T>
|
||||
void save(std::ostream& stream, T const& obj) {
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
cereal::BinaryOutputArchive archive(stream);
|
||||
archive(*obj);
|
||||
#else
|
||||
AT_ERROR("PyTorch compiled without serialization support");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void load(std::istream& stream, T& obj) {
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
cereal::BinaryInputArchive archive(stream);
|
||||
archive(*obj);
|
||||
#else
|
||||
AT_ERROR("PyTorch compiled without serialization support");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void save(std::ostream& stream, T const* obj) {
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
cereal::BinaryOutputArchive archive(stream);
|
||||
archive(*obj);
|
||||
#else
|
||||
AT_ERROR("PyTorch compiled without serialization support");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void load(std::istream& stream, T* obj) {
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
cereal::BinaryInputArchive archive(stream);
|
||||
archive(*obj);
|
||||
#else
|
||||
AT_ERROR("PyTorch compiled without serialization support");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void save(std::string const& path, T const& obj) {
|
||||
std::ofstream os(path, std::ios::binary);
|
||||
torch::save(os, obj);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void load(std::string const& path, T& obj) {
|
||||
std::ifstream is(path, std::ios::binary);
|
||||
@ -74,8 +98,7 @@ inline int32_t scalarTypeId(torch::Dtype type) {
|
||||
case torch::Dtype::Undefined:
|
||||
return 8;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"Unknown scalar type: " + std::to_string(static_cast<int>(type)));
|
||||
AT_ERROR("Unknown scalar type: ", static_cast<int>(type));
|
||||
}
|
||||
}
|
||||
|
||||
@ -100,7 +123,7 @@ inline torch::Dtype scalarTypeFromId(int32_t id) {
|
||||
case 8:
|
||||
return torch::Dtype::Undefined;
|
||||
default:
|
||||
throw std::runtime_error("Unknown scalar type id: " + std::to_string(id));
|
||||
AT_ERROR("Unknown scalar type id: ", id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -117,8 +140,7 @@ inline int32_t backendId(at::Backend backend) {
|
||||
case at::Backend::Undefined:
|
||||
return 4;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"Unknown backend: " + std::to_string(static_cast<int>(backend)));
|
||||
AT_ERROR("Unknown backend: ", static_cast<int>(backend));
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,33 +157,15 @@ inline at::Backend backendFromId(int32_t id) {
|
||||
case 4:
|
||||
return at::Backend::Undefined;
|
||||
default:
|
||||
throw std::runtime_error("Unknown backend id: " + std::to_string(id));
|
||||
AT_ERROR("Unknown backend id: ", id);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace torch
|
||||
|
||||
// This is super ugly and I don't know how to simplify it
|
||||
CEREAL_REGISTER_TYPE(torch::optim::SGD);
|
||||
CEREAL_REGISTER_POLYMORPHIC_RELATION(
|
||||
torch::optim::Optimizer,
|
||||
torch::optim::SGD);
|
||||
CEREAL_REGISTER_TYPE(torch::optim::Adagrad);
|
||||
CEREAL_REGISTER_POLYMORPHIC_RELATION(
|
||||
torch::optim::Optimizer,
|
||||
torch::optim::Adagrad);
|
||||
CEREAL_REGISTER_TYPE(torch::optim::RMSprop);
|
||||
CEREAL_REGISTER_POLYMORPHIC_RELATION(
|
||||
torch::optim::Optimizer,
|
||||
torch::optim::RMSprop);
|
||||
CEREAL_REGISTER_TYPE(torch::optim::Adam);
|
||||
CEREAL_REGISTER_POLYMORPHIC_RELATION(
|
||||
torch::optim::Optimizer,
|
||||
torch::optim::Adam);
|
||||
|
||||
#if defined(TORCH_USE_CEREAL)
|
||||
namespace cereal {
|
||||
|
||||
namespace agimpl {
|
||||
|
||||
template <class Archive>
|
||||
@ -269,3 +273,4 @@ void load(Archive& archive, torch::Tensor& tensor) {
|
||||
}
|
||||
}
|
||||
} // namespace cereal
|
||||
#endif // defined(TORCH_USE_CEREAL)
|
||||
|
Reference in New Issue
Block a user