[deploy][1/n] Make deploy code conform to PyTorch style. (#65861)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65861

First in a series. This PR changes the code in deploy.h/cpp and
interpreter_impl.h/cpp to be camel case instead of snake case. Starting
with this as it has the most impact on downstream users.

Test Plan: Imported from OSS

Reviewed By: shannonzhu

Differential Revision: D31291183

Pulled By: suo

fbshipit-source-id: ba6f74042947c9a08fb9cb3ad7276d8dbb5b2934
This commit is contained in:
Michael Suo
2021-09-30 22:56:04 -07:00
committed by Facebook GitHub Bot
parent 765b6a90f3
commit 33c03cb61a
11 changed files with 235 additions and 239 deletions

View File

@ -24,8 +24,8 @@ TEST(IMethodTest, CallMethod) {
auto scriptMethod = scriptModel.get_method("forward");
torch::deploy::InterpreterManager manager(3);
torch::deploy::Package package = manager.load_package(path("SIMPLE", simple));
auto pyModel = package.load_pickle("model", "model.pkl");
torch::deploy::Package package = manager.loadPackage(path("SIMPLE", simple));
auto pyModel = package.loadPickle("model", "model.pkl");
torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward");
EXPECT_EQ(scriptMethod.name(), "forward");
@ -52,8 +52,8 @@ TEST(IMethodTest, GetArgumentNames) {
EXPECT_STREQ(scriptNames[0].c_str(), "input");
torch::deploy::InterpreterManager manager(3);
torch::deploy::Package package = manager.load_package(path("SIMPLE", simple));
auto pyModel = package.load_pickle("model", "model.pkl");
torch::deploy::Package package = manager.loadPackage(path("SIMPLE", simple));
auto pyModel = package.loadPickle("model", "model.pkl");
torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward");
auto& pyNames = pyMethod.getArgumentNames();

View File

@ -7,9 +7,9 @@
#include <unistd.h>
struct InterpreterSymbol {
const char* start_sym;
const char* end_sym;
bool custom_loader;
const char* startSym;
const char* endSym;
bool customLoader;
};
// these symbols are generated by cmake, using ld -r -b binary
@ -21,7 +21,7 @@ struct InterpreterSymbol {
namespace torch {
namespace deploy {
const std::initializer_list<InterpreterSymbol> interpreter_search_path = {
const std::initializer_list<InterpreterSymbol> kInterpreterSearchPath = {
{"_binary_libtorch_deployinterpreter_all_so_start",
"_binary_libtorch_deployinterpreter_all_so_end",
true},
@ -35,41 +35,41 @@ const std::initializer_list<InterpreterSymbol> interpreter_search_path = {
static bool writeDeployInterpreter(FILE* dst) {
TORCH_INTERNAL_ASSERT(dst);
const char* lib_start = nullptr;
const char* lib_end = nullptr;
bool custom_loader = false;
for (const auto& s : interpreter_search_path) {
lib_start = (const char*)dlsym(nullptr, s.start_sym);
if (lib_start) {
lib_end = (const char*)dlsym(nullptr, s.end_sym);
custom_loader = s.custom_loader;
const char* libStart = nullptr;
const char* libEnd = nullptr;
bool customLoader = false;
for (const auto& s : kInterpreterSearchPath) {
libStart = (const char*)dlsym(nullptr, s.startSym);
if (libStart) {
libEnd = (const char*)dlsym(nullptr, s.endSym);
customLoader = s.customLoader;
break;
}
}
TORCH_CHECK(
lib_start != nullptr && lib_end != nullptr,
libStart != nullptr && libEnd != nullptr,
"torch::deploy requires a build-time dependency on embedded_interpreter or embedded_interpreter_cuda, neither of which were found. torch::cuda::is_available()=",
torch::cuda::is_available());
size_t size = lib_end - lib_start;
size_t written = fwrite(lib_start, 1, size, dst);
size_t size = libEnd - libStart;
size_t written = fwrite(libStart, 1, size, dst);
TORCH_INTERNAL_ASSERT(size == written, "expected written == size");
return custom_loader;
return customLoader;
}
InterpreterManager::InterpreterManager(size_t n_interp) : resources_(n_interp) {
InterpreterManager::InterpreterManager(size_t nInterp) : resources_(nInterp) {
TORCH_DEPLOY_TRY
for (const auto i : c10::irange(n_interp)) {
for (const auto i : c10::irange(nInterp)) {
instances_.emplace_back(this);
auto I = instances_.back().acquire_session();
auto I = instances_.back().acquireSession();
// make torch.version.interp be the interpreter id
// can be used for balancing work across GPUs
I.global("torch", "version").attr("__setattr__")({"interp", int(i)});
// std::cerr << "Interpreter " << i << " initialized\n";
instances_.back().pImpl_->set_find_module(
instances_.back().pImpl_->setFindModule(
[this](const std::string& name) -> at::optional<std::string> {
auto it = registered_module_sources_.find(name);
if (it != registered_module_sources_.end()) {
auto it = registeredModuleSource_.find(name);
if (it != registeredModuleSource_.end()) {
return it->second;
} else {
return at::nullopt;
@ -81,7 +81,7 @@ InterpreterManager::InterpreterManager(size_t n_interp) : resources_(n_interp) {
// Since torch::deploy::Obj.toIValue cannot infer empty list, we hack it to
// return None for empty list.
// TODO(jwtan): Make the discovery of these modules easier.
register_module_source(
reigsterModuleSource(
"GetArgumentNamesModule",
"from inspect import signature\n"
"from typing import Callable, Optional\n"
@ -93,55 +93,54 @@ InterpreterManager::InterpreterManager(size_t n_interp) : resources_(n_interp) {
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
Package InterpreterManager::load_package(const std::string& uri) {
Package InterpreterManager::loadPackage(const std::string& uri) {
TORCH_DEPLOY_TRY
return Package(uri, this);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
Package InterpreterManager::load_package(
Package InterpreterManager::loadPackage(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader) {
TORCH_DEPLOY_TRY
return Package(reader, this);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
Obj InterpreterSession::from_movable(const ReplicatedObj& obj) {
Obj InterpreterSession::fromMovable(const ReplicatedObj& obj) {
TORCH_DEPLOY_TRY
return impl_->unpickle_or_get(obj.pImpl_->object_id_, obj.pImpl_->data_);
return impl_->unpickleOrGet(obj.pImpl_->objectId_, obj.pImpl_->data_);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
InterpreterSession ReplicatedObj::acquire_session(
const Interpreter* on_this_interpreter) const {
InterpreterSession ReplicatedObj::acquireSession(
const Interpreter* onThisInterpreter) const {
TORCH_DEPLOY_TRY
InterpreterSession I = on_this_interpreter
? on_this_interpreter->acquire_session()
: pImpl_->manager_->acquire_one();
I.self = I.from_movable(*this);
InterpreterSession I = onThisInterpreter ? onThisInterpreter->acquireSession()
: pImpl_->manager_->acquireOne();
I.self = I.fromMovable(*this);
return I;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
// NOLINTNEXTLINE(bugprone-exception-escape)
InterpreterSession::~InterpreterSession() {
if (manager_ && notify_idx_ >= 0) {
manager_->resources_.free(notify_idx_);
if (manager_ && notifyIdx_ >= 0) {
manager_->resources_.free(notifyIdx_);
}
}
void ReplicatedObjImpl::unload(const Interpreter* on_this_interpreter) {
void ReplicatedObjImpl::unload(const Interpreter* onThisInterpreter) {
TORCH_DEPLOY_TRY
if (!on_this_interpreter) {
if (!onThisInterpreter) {
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
for (auto& interp : manager_->all_instances()) {
for (auto& interp : manager_->allInstances()) {
unload(&interp);
}
return;
}
InterpreterSession I = on_this_interpreter->acquire_session();
I.impl_->unload(object_id_);
InterpreterSession I = onThisInterpreter->acquireSession();
I.impl_->unload(objectId_);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
@ -150,20 +149,20 @@ ReplicatedObjImpl::~ReplicatedObjImpl() {
unload(nullptr);
}
void ReplicatedObj::unload(const Interpreter* on_this_interpreter) {
void ReplicatedObj::unload(const Interpreter* onThisInterpreter) {
TORCH_DEPLOY_TRY
pImpl_->unload(on_this_interpreter);
pImpl_->unload(onThisInterpreter);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
ReplicatedObj InterpreterSession::create_movable(Obj obj) {
ReplicatedObj InterpreterSession::createMovable(Obj obj) {
TORCH_DEPLOY_TRY
TORCH_CHECK(
manager_,
"Can only create a movable object when the session was created from an interpreter that is part of a InterpreterManager");
auto pickled = impl_->pickle(self, obj);
return ReplicatedObj(std::make_shared<ReplicatedObjImpl>(
manager_->next_object_id_++, std::move(pickled), manager_));
manager_->nextObjectId_++, std::move(pickled), manager_));
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
@ -187,24 +186,24 @@ static dlopen_t find_real_dlopen() {
Interpreter::Interpreter(InterpreterManager* manager)
: handle_(nullptr), manager_(manager) {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
char library_name[] = "/tmp/torch_deployXXXXXX";
int fd = mkstemp(library_name);
char libraryName[] = "/tmp/torch_deployXXXXXX";
int fd = mkstemp(libraryName);
TORCH_INTERNAL_ASSERT(fd != -1, "failed to create temporary file");
library_name_ = library_name;
libraryName_ = libraryName;
FILE* dst = fdopen(fd, "wb");
custom_loader_ = writeDeployInterpreter(dst);
customLoader_ = writeDeployInterpreter(dst);
fclose(dst);
int flags = RTLD_LOCAL | RTLD_LAZY;
if (custom_loader_) {
if (customLoader_) {
flags |= RTLD_DEEPBIND;
}
#ifdef FBCODE_CAFFE2
static dlopen_t dlopen_ = find_real_dlopen();
handle_ = dlopen_(library_name, flags);
handle_ = dlopen_(libraryName, flags);
#else
handle_ = dlopen(library_name, flags);
handle_ = dlopen(libraryName, flags);
#endif
if (!handle_) {
@ -214,30 +213,29 @@ Interpreter::Interpreter(InterpreterManager* manager)
// note: if you want better debugging symbols for things inside
// new_intepreter_impl, comment out this line so that the so lasts long enough
// for the debugger to see it.
unlink(library_name_.c_str());
unlink(libraryName_.c_str());
if (custom_loader_) {
if (customLoader_) {
// when using the custom loader we need to link python symbols against
// the right version of the symbols for the interpreter which an be looked
// up from the handle_ to this shared library. here we register the handle
// with the code that does custom loading of python extensions.
auto deploy_set_self_ptr =
(void (*)(void*))dlsym(handle_, "deploy_set_self");
AT_ASSERT(deploy_set_self_ptr);
deploy_set_self_ptr(handle_);
auto deploySetSelfPtr = (void (*)(void*))dlsym(handle_, "deploy_set_self");
AT_ASSERT(deploySetSelfPtr);
deploySetSelfPtr(handle_);
}
void* new_interpreter_impl = dlsym(handle_, "new_interpreter_impl");
AT_ASSERT(new_interpreter_impl);
void* newInterpreterImpl = dlsym(handle_, "newInterpreterImpl");
AT_ASSERT(newInterpreterImpl);
pImpl_ = std::unique_ptr<InterpreterImpl>(
((InterpreterImpl * (*)()) new_interpreter_impl)());
((InterpreterImpl * (*)()) newInterpreterImpl)());
}
Interpreter::~Interpreter() {
if (handle_) {
// ensure python uninitialization runs before we dlclose the library
pImpl_.reset();
if (custom_loader_) {
if (customLoader_) {
auto deploy_flush_python_libs =
(void (*)())dlsym(handle_, "deploy_flush_python_libs");
deploy_flush_python_libs();
@ -250,7 +248,7 @@ int LoadBalancer::acquire() {
TORCH_DEPLOY_TRY
thread_local int last = 0;
size_t minusers = SIZE_MAX;
int min_idx = 0;
int minIdx = 0;
for (size_t i = 0; i < n_; ++i, ++last) {
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
if (last >= n_) {
@ -273,14 +271,14 @@ int LoadBalancer::acquire() {
if (prev < minusers) {
minusers = prev;
min_idx = last;
minIdx = last;
}
}
// we failed to find a completely free interpreter. heuristically use the
// one with the least number of user (note that this may have changed since
// then, so this is only a heuristic).
__atomic_fetch_add(&uses_[8 * min_idx], 1ULL, __ATOMIC_SEQ_CST);
return min_idx;
__atomic_fetch_add(&uses_[8 * minIdx], 1ULL, __ATOMIC_SEQ_CST);
return minIdx;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
@ -293,8 +291,8 @@ void LoadBalancer::free(int where) {
void PythonMethodWrapper::setArgumentNames(
std::vector<std::string>& argumentNamesOut) const {
auto session = model_.acquire_session();
auto method = session.self.attr(method_name_.c_str());
auto session = model_.acquireSession();
auto method = session.self.attr(methodName_.c_str());
auto iArgumentNames =
session.global("GetArgumentNamesModule", "getArgumentNames")({method})
.toIValue();

View File

@ -32,13 +32,13 @@ struct TORCH_API InterpreterSession {
return impl_->global(module, name);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
Obj from_ivalue(at::IValue ivalue) {
Obj fromIValue(at::IValue ivalue) {
TORCH_DEPLOY_TRY
return impl_->from_ivalue(std::move(ivalue));
return impl_->fromIValue(std::move(ivalue));
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
ReplicatedObj create_movable(Obj obj);
Obj from_movable(const ReplicatedObj& obj);
ReplicatedObj createMovable(Obj obj);
Obj fromMovable(const ReplicatedObj& obj);
private:
friend struct ReplicatedObj;
@ -47,27 +47,27 @@ struct TORCH_API InterpreterSession {
friend struct ReplicatedObjImpl;
std::unique_ptr<InterpreterSessionImpl> impl_;
InterpreterManager* manager_; // if created from one
int64_t notify_idx_ = -1;
int64_t notifyIdx_ = -1;
};
class TORCH_API Interpreter {
private:
std::string library_name_;
std::string libraryName_;
void* handle_;
std::unique_ptr<InterpreterImpl> pImpl_;
bool custom_loader_ = false;
bool customLoader_ = false;
InterpreterManager* manager_; // optional if managed by one
public:
Interpreter(InterpreterManager* manager);
InterpreterSession acquire_session() const {
InterpreterSession acquireSession() const {
TORCH_DEPLOY_TRY
return InterpreterSession(pImpl_->acquire_session(), manager_);
return InterpreterSession(pImpl_->acquireSession(), manager_);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
~Interpreter();
Interpreter(Interpreter&& rhs) noexcept
: library_name_(std::move(rhs.library_name_)),
: libraryName_(std::move(rhs.libraryName_)),
handle_(rhs.handle_),
pImpl_(std::move(rhs.pImpl_)),
manager_(rhs.manager_) {
@ -108,22 +108,22 @@ struct TORCH_API LoadBalancer {
};
struct TORCH_API InterpreterManager {
explicit InterpreterManager(size_t n_interp = 2);
explicit InterpreterManager(size_t nInterp = 2);
// get a free model, guarenteed that no other user of acquire_one has the same
// get a free model, guarenteed that no other user of acquireOne has the same
// model. It _is_ possible that other users will be using the interpreter.
InterpreterSession acquire_one() {
InterpreterSession acquireOne() {
TORCH_DEPLOY_TRY
int where = resources_.acquire();
InterpreterSession I = instances_[where].acquire_session();
I.notify_idx_ = where;
InterpreterSession I = instances_[where].acquireSession();
I.notifyIdx_ = where;
return I;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
// use to make sure something gets run on all interpreters, such as loading or
// unloading a model eagerly
at::ArrayRef<Interpreter> all_instances() {
at::ArrayRef<Interpreter> allInstances() {
TORCH_DEPLOY_TRY
return instances_;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
@ -134,8 +134,8 @@ struct TORCH_API InterpreterManager {
resources_.setResourceLimit(N);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
Package load_package(const std::string& uri);
Package load_package(
Package loadPackage(const std::string& uri);
Package loadPackage(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader);
// convience function for loading some python source code as a module across
@ -143,8 +143,8 @@ struct TORCH_API InterpreterManager {
// execute python code, or for small amounts of application logic that are
// best written in Python. For larger amounts of code, prefer creating and
// loading them as packages.
void register_module_source(std::string name, std::string src) {
registered_module_sources_[std::move(name)] = std::move(src);
void reigsterModuleSource(std::string name, std::string src) {
registeredModuleSource_[std::move(name)] = std::move(src);
}
InterpreterManager(const InterpreterManager&) = delete;
@ -154,10 +154,10 @@ struct TORCH_API InterpreterManager {
private:
friend struct Package;
friend struct InterpreterSession;
size_t next_object_id_ = 0;
size_t nextObjectId_ = 0;
std::vector<Interpreter> instances_;
LoadBalancer resources_;
std::unordered_map<std::string, std::string> registered_module_sources_;
std::unordered_map<std::string, std::string> registeredModuleSource_;
};
struct TORCH_API ReplicatedObjImpl {
@ -166,51 +166,51 @@ struct TORCH_API ReplicatedObjImpl {
// NOLINTNEXTLINE(modernize-pass-by-value)
PickledObject data,
InterpreterManager* manager)
: object_id_(object_id), data_(data), manager_(manager) {}
: objectId_(object_id), data_(data), manager_(manager) {}
// NOLINTNEXTLINE(bugprone-exception-escape)
~ReplicatedObjImpl();
void unload(const Interpreter* on_this_interpreter);
int64_t object_id_;
void unload(const Interpreter* onThisInterpreter);
int64_t objectId_;
PickledObject data_;
InterpreterManager* manager_;
};
struct TORCH_API ReplicatedObj {
ReplicatedObj() : pImpl_(nullptr) {}
InterpreterSession acquire_session(
const Interpreter* on_this_interpreter = nullptr) const;
InterpreterSession acquireSession(
const Interpreter* onThisInterpreter = nullptr) const;
at::IValue operator()(at::ArrayRef<at::IValue> args) const {
TORCH_DEPLOY_TRY
auto I = acquire_session();
auto I = acquireSession();
return I.self(args).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
[[nodiscard]] at::IValue call_kwargs(
[[nodiscard]] at::IValue callKwargs(
std::vector<at::IValue> args,
std::unordered_map<std::string, c10::IValue> kwargs) const {
TORCH_DEPLOY_TRY
auto I = acquire_session();
return I.self.call_kwargs(std::move(args), std::move(kwargs)).toIValue();
auto I = acquireSession();
return I.self.callKwargs(std::move(args), std::move(kwargs)).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
[[nodiscard]] at::IValue call_kwargs(
[[nodiscard]] at::IValue callKwargs(
std::unordered_map<std::string, c10::IValue> kwargs) const {
TORCH_DEPLOY_TRY
auto I = acquire_session();
return I.self.call_kwargs(std::move(kwargs)).toIValue();
auto I = acquireSession();
return I.self.callKwargs(std::move(kwargs)).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
[[nodiscard]] bool hasattr(const char* name) const {
TORCH_DEPLOY_TRY
auto I = acquire_session();
auto I = acquireSession();
return I.self.hasattr(name);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
void unload(const Interpreter* on_this_interpreter = nullptr);
void unload(const Interpreter* onThisInterpreter = nullptr);
private:
ReplicatedObj(std::shared_ptr<ReplicatedObjImpl> pImpl)
@ -229,11 +229,11 @@ class PythonMethodWrapper : public torch::IMethod {
// TODO(whc) make bound method pickleable, then directly construct from that
PythonMethodWrapper(
torch::deploy::ReplicatedObj model,
std::string method_name)
: model_(std::move(model)), method_name_(std::move(method_name)) {}
std::string methodName)
: model_(std::move(model)), methodName_(std::move(methodName)) {}
const std::string& name() const override {
return method_name_;
return methodName_;
}
c10::IValue operator()(
@ -241,35 +241,33 @@ class PythonMethodWrapper : public torch::IMethod {
const IValueMap& kwargs = IValueMap()) const override {
// TODO(whc) ideally, pickle the method itself as replicatedobj, to skip
// this lookup each time
auto model_session = model_.acquire_session();
auto method = model_session.self.attr(method_name_.c_str());
return method.call_kwargs(args, kwargs).toIValue();
auto modelSession = model_.acquireSession();
auto method = modelSession.self.attr(methodName_.c_str());
return method.callKwargs(args, kwargs).toIValue();
}
private:
void setArgumentNames(std::vector<std::string>&) const override;
torch::deploy::ReplicatedObj model_;
std::string method_name_;
std::string methodName_;
};
struct TORCH_API Package {
// shorthand for getting the object as a pickle resource in the package
ReplicatedObj load_pickle(
const std::string& module,
const std::string& file) {
ReplicatedObj loadPickle(const std::string& module, const std::string& file) {
TORCH_DEPLOY_TRY
auto I = acquire_session();
auto I = acquireSession();
auto loaded = I.self.attr("load_pickle")({module, file});
return I.create_movable(loaded);
return I.createMovable(loaded);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
InterpreterSession acquire_session() {
InterpreterSession acquireSession() {
TORCH_DEPLOY_TRY
auto I = manager_->acquire_one();
I.self = I.impl_->create_or_get_package_importer_from_container_file(
container_file_);
auto I = manager_->acquireOne();
I.self =
I.impl_->createOrGetPackageImporterFromContainerFile(containerFile_);
return I;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
@ -280,19 +278,19 @@ struct TORCH_API Package {
InterpreterManager*
pm) // or really any of the constructors to our zip file format
: manager_(pm),
container_file_(
containerFile_(
std::make_shared<caffe2::serialize::PyTorchStreamReader>(uri)) {}
Package(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader,
InterpreterManager*
pm) // or really any of the constructors to our zip file format
: manager_(pm),
container_file_(
containerFile_(
std::make_shared<caffe2::serialize::PyTorchStreamReader>(reader)) {}
friend struct ReplicatedObj;
friend struct InterpreterManager;
InterpreterManager* manager_;
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> container_file_;
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> containerFile_;
};
} // namespace deploy

View File

@ -55,12 +55,12 @@ const int min_items_to_complete = 1;
struct RunPython {
static torch::deploy::ReplicatedObj load_and_wrap(
torch::deploy::Package& package) {
auto I = package.acquire_session();
auto I = package.acquireSession();
auto obj = I.self.attr("load_pickle")({"model", "model.pkl"});
if (cuda) {
obj = I.global("gpu_wrapper", "GPUWrapper")({obj});
}
return I.create_movable(obj);
return I.createMovable(obj);
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
RunPython(
@ -69,7 +69,7 @@ struct RunPython {
const torch::deploy::Interpreter* interps)
: obj_(load_and_wrap(package)), eg_(std::move(eg)), interps_(interps) {}
void operator()(int i) {
auto I = obj_.acquire_session();
auto I = obj_.acquireSession();
if (cuda) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<at::IValue> eg2 = {i};
@ -189,12 +189,12 @@ struct Benchmark {
pthread_barrier_init(&first_run_, nullptr, n_threads_ + 1);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
torch::deploy::Package package = manager_.load_package(file_to_run_);
torch::deploy::Package package = manager_.loadPackage(file_to_run_);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<at::IValue> eg;
{
auto I = package.acquire_session();
auto I = package.acquireSession();
eg = I.global("builtins", "tuple")(
I.self.attr("load_pickle")({"model", "example.pkl"}))
@ -208,7 +208,7 @@ struct Benchmark {
run_one_work_item = RunJIT(file_to_run_, std::move(eg));
} else {
run_one_work_item =
RunPython(package, std::move(eg), manager_.all_instances().data());
RunPython(package, std::move(eg), manager_.allInstances().data());
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -305,8 +305,8 @@ int main(int argc, char* argv[]) {
torch::deploy::InterpreterManager manager(max_thread);
// make sure gpu_wrapper.py is in the import path
for (auto& interp : manager.all_instances()) {
auto I = interp.acquire_session();
for (auto& interp : manager.allInstances()) {
auto I = interp.acquireSession();
I.global("sys", "path").attr("append")({"torch/csrc/deploy/example"});
}

View File

@ -57,8 +57,8 @@ if __name__ == '__main__':
from torch.package import PackageImporter
i = PackageImporter(sys.argv[1])
torch.version.interp = 0
model = i.load_pickle('model', 'model.pkl')
eg = i.load_pickle('model', 'example.pkl')
model = i.loadPickle('model', 'model.pkl')
eg = i.loadPickle('model', 'example.pkl')
r = model(*eg)
gpu_model = GPUWrapper(model)

View File

@ -1,4 +1,4 @@
INTERPRETER_0.1 {
global: new_interpreter_impl;
global: newInterpreterImpl;
local: *;
};

View File

@ -371,9 +371,9 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterImpl
// we cache these so we don't have to repeat the conversion of strings into
// Python and hash table lookups to get to these object
save_storage = global_impl("torch._deploy", "_save_storages");
load_storage = global_impl("torch._deploy", "_load_storages");
get_package = global_impl("torch._deploy", "_get_package");
saveStorage = global_impl("torch._deploy", "_save_storages");
loadStorage = global_impl("torch._deploy", "_load_storages");
getPackage = global_impl("torch._deploy", "_get_package");
objects = global_impl("torch._deploy", "_deploy_objects");
// Release the GIL that PyInitialize acquires
PyEval_SaveThread();
@ -385,16 +385,16 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterImpl
// note: this leads the referneces to these objects, but we are about to
// deinit python anyway so it doesn't matter
objects.release();
save_storage.release();
load_storage.release();
get_package.release();
saveStorage.release();
loadStorage.release();
getPackage.release();
if (Py_FinalizeEx() != 0) {
exit(1); // can't use TORCH_INTERNAL_ASSERT because we are in a
// non-throwing destructor.
}
}
void set_find_module(
void setFindModule(
std::function<at::optional<std::string>(const std::string&)> find_module)
override {
std::function<py::object(const std::string&)> wrapped_find_module =
@ -410,10 +410,10 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterImpl
.attr("append")(register_module_importer);
}
torch::deploy::InterpreterSessionImpl* acquire_session() override;
py::object save_storage;
py::object load_storage;
py::object get_package;
torch::deploy::InterpreterSessionImpl* acquireSession() override;
py::object saveStorage;
py::object loadStorage;
py::object getPackage;
py::dict objects;
std::mutex init_lock_;
};
@ -426,18 +426,18 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterSessionImpl
return wrap(global_impl(module, name));
}
Obj from_ivalue(IValue value) override {
Obj fromIValue(IValue value) override {
return wrap(torch::jit::toPyObject(value));
}
Obj create_or_get_package_importer_from_container_file(
Obj createOrGetPackageImporterFromContainerFile(
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
container_file_) override {
containerFile_) override {
InitLockAcquire guard(interp_->init_lock_);
return wrap(interp_->get_package(container_file_));
return wrap(interp_->getPackage(containerFile_));
}
PickledObject pickle(Obj container, Obj obj) override {
py::tuple result = interp_->save_storage(unwrap(container), unwrap(obj));
py::tuple result = interp_->saveStorage(unwrap(container), unwrap(obj));
py::bytes bytes = py::cast<py::bytes>(result[0]);
py::list storages = py::cast<py::list>(result[1]);
py::list dtypes = py::cast<py::list>(result[2]);
@ -458,7 +458,7 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterSessionImpl
std::move(dtypes_c),
std::move(container_file)};
}
Obj unpickle_or_get(int64_t id, const PickledObject& obj) override {
Obj unpickleOrGet(int64_t id, const PickledObject& obj) override {
py::dict objects = interp_->objects;
py::object id_p = py::cast(id);
if (objects.contains(id_p)) {
@ -479,8 +479,8 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterSessionImpl
obj.storages_[i], scalarTypeToTypeMeta(obj.types_[i])));
storages[i] = std::move(new_storage);
}
py::object result = interp_->load_storage(
id, obj.container_file_, py::bytes(obj.data_), storages);
py::object result = interp_->loadStorage(
id, obj.containerFile_, py::bytes(obj.data_), storages);
return wrap(result);
}
void unload(int64_t id) override {
@ -511,7 +511,7 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterSessionImpl
return wrap(call(unwrap(obj), m_args));
}
Obj call_kwargs(
Obj callKwargs(
Obj obj,
std::vector<at::IValue> args,
std::unordered_map<std::string, c10::IValue> kwargs) override {
@ -528,10 +528,10 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterSessionImpl
return wrap(call(unwrap(obj), py_args, py_kwargs));
}
Obj call_kwargs(Obj obj, std::unordered_map<std::string, c10::IValue> kwargs)
Obj callKwargs(Obj obj, std::unordered_map<std::string, c10::IValue> kwargs)
override {
std::vector<at::IValue> args;
return call_kwargs(obj, args, kwargs);
return callKwargs(obj, args, kwargs);
}
bool hasattr(Obj obj, const char* attr) override {
@ -571,12 +571,12 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterSessionImpl
};
torch::deploy::InterpreterSessionImpl* ConcreteInterpreterImpl::
acquire_session() {
acquireSession() {
return new ConcreteInterpreterSessionImpl(this);
}
extern "C" __attribute__((visibility("default")))
torch::deploy::InterpreterImpl*
new_interpreter_impl(void) {
newInterpreterImpl(void) {
return new ConcreteInterpreterImpl();
}

View File

@ -56,7 +56,7 @@ struct PickledObject {
// types for the storages, required to
// reconstruct correct Python storages
std::vector<at::ScalarType> types_;
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> container_file_;
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> containerFile_;
};
// this is a wrapper class that refers to a PyObject* instance in a particular
@ -74,10 +74,10 @@ struct Obj {
at::IValue toIValue() const;
Obj operator()(at::ArrayRef<Obj> args);
Obj operator()(at::ArrayRef<at::IValue> args);
Obj call_kwargs(
Obj callKwargs(
std::vector<at::IValue> args,
std::unordered_map<std::string, c10::IValue> kwargs);
Obj call_kwargs(std::unordered_map<std::string, c10::IValue> kwargs);
Obj callKwargs(std::unordered_map<std::string, c10::IValue> kwargs);
bool hasattr(const char* attr);
Obj attr(const char* attr);
@ -97,23 +97,23 @@ struct InterpreterSessionImpl {
private:
virtual Obj global(const char* module, const char* name) = 0;
virtual Obj from_ivalue(at::IValue value) = 0;
virtual Obj create_or_get_package_importer_from_container_file(
virtual Obj fromIValue(at::IValue value) = 0;
virtual Obj createOrGetPackageImporterFromContainerFile(
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
container_file_) = 0;
containerFile_) = 0;
virtual PickledObject pickle(Obj container, Obj obj) = 0;
virtual Obj unpickle_or_get(int64_t id, const PickledObject& obj) = 0;
virtual Obj unpickleOrGet(int64_t id, const PickledObject& obj) = 0;
virtual void unload(int64_t id) = 0;
virtual at::IValue toIValue(Obj obj) const = 0;
virtual Obj call(Obj obj, at::ArrayRef<Obj> args) = 0;
virtual Obj call(Obj obj, at::ArrayRef<at::IValue> args) = 0;
virtual Obj call_kwargs(
virtual Obj callKwargs(
Obj obj,
std::vector<at::IValue> args,
std::unordered_map<std::string, c10::IValue> kwargs) = 0;
virtual Obj call_kwargs(
virtual Obj callKwargs(
Obj obj,
std::unordered_map<std::string, c10::IValue> kwargs) = 0;
virtual Obj attr(Obj obj, const char* attr) = 0;
@ -126,8 +126,8 @@ struct InterpreterSessionImpl {
};
struct InterpreterImpl {
virtual InterpreterSessionImpl* acquire_session() = 0;
virtual void set_find_module(
virtual InterpreterSessionImpl* acquireSession() = 0;
virtual void setFindModule(
std::function<at::optional<std::string>(const std::string&)>
find_module) = 0;
virtual ~InterpreterImpl() = default; // this will uninitialize python
@ -154,17 +154,17 @@ inline Obj Obj::operator()(at::ArrayRef<at::IValue> args) {
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
inline Obj Obj::call_kwargs(
inline Obj Obj::callKwargs(
std::vector<at::IValue> args,
std::unordered_map<std::string, c10::IValue> kwargs) {
TORCH_DEPLOY_TRY
return interaction_->call_kwargs(*this, std::move(args), std::move(kwargs));
return interaction_->callKwargs(*this, std::move(args), std::move(kwargs));
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
inline Obj Obj::call_kwargs(
inline Obj Obj::callKwargs(
std::unordered_map<std::string, c10::IValue> kwargs) {
TORCH_DEPLOY_TRY
return interaction_->call_kwargs(*this, std::move(kwargs));
return interaction_->callKwargs(*this, std::move(kwargs));
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
inline bool Obj::hasattr(const char* attr) {

View File

@ -21,11 +21,11 @@ int main(int argc, char* argv[]) {
void compare_torchpy_jit(const char* model_filename, const char* jit_filename) {
// Test
torch::deploy::InterpreterManager m(1);
torch::deploy::Package p = m.load_package(model_filename);
auto model = p.load_pickle("model", "model.pkl");
torch::deploy::Package p = m.loadPackage(model_filename);
auto model = p.loadPickle("model", "model.pkl");
at::IValue eg;
{
auto I = p.acquire_session();
auto I = p.acquireSession();
eg = I.self.attr("load_pickle")({"model", "example.pkl"}).toIValue();
}
@ -49,9 +49,9 @@ const char* path(const char* envname, const char* path) {
TEST(TorchpyTest, LoadLibrary) {
torch::deploy::InterpreterManager m(1);
torch::deploy::Package p = m.load_package(
torch::deploy::Package p = m.loadPackage(
path("LOAD_LIBRARY", "torch/csrc/deploy/example/generated/load_library"));
auto model = p.load_pickle("fn", "fn.pkl");
auto model = p.loadPickle("fn", "fn.pkl");
model({});
}
@ -62,14 +62,14 @@ TEST(TorchpyTest, InitTwice) {
TEST(TorchpyTest, DifferentInterps) {
torch::deploy::InterpreterManager m(2);
m.register_module_source("check_none", "check = id(None)\n");
m.reigsterModuleSource("check_none", "check = id(None)\n");
int64_t id0 = 0, id1 = 0;
{
auto I = m.all_instances()[0].acquire_session();
auto I = m.allInstances()[0].acquireSession();
id0 = I.global("check_none", "check").toIValue().toInt();
}
{
auto I = m.all_instances()[1].acquire_session();
auto I = m.allInstances()[1].acquireSession();
id1 = I.global("check_none", "check").toIValue().toInt();
}
ASSERT_NE(id0, id1);
@ -89,18 +89,18 @@ TEST(TorchpyTest, Movable) {
torch::deploy::InterpreterManager m(1);
torch::deploy::ReplicatedObj obj;
{
auto I = m.acquire_one();
auto I = m.acquireOne();
auto model =
I.global("torch.nn", "Module")(std::vector<torch::deploy::Obj>());
obj = I.create_movable(model);
obj = I.createMovable(model);
}
obj.acquire_session();
obj.acquireSession();
}
TEST(TorchpyTest, MultiSerialSimpleModel) {
torch::deploy::InterpreterManager manager(3);
torch::deploy::Package p = manager.load_package(path("SIMPLE", simple));
auto model = p.load_pickle("model", "model.pkl");
torch::deploy::Package p = manager.loadPackage(path("SIMPLE", simple));
auto model = p.loadPickle("model", "model.pkl");
auto ref_model = torch::jit::load(path("SIMPLE_JIT", simple_jit));
auto input = torch::ones({10, 20});
@ -124,13 +124,13 @@ TEST(TorchpyTest, MultiSerialSimpleModel) {
std::vector<c10::IValue> args;
args.emplace_back(input);
std::unordered_map<std::string, c10::IValue> kwargs_empty;
auto jit_output_args = model.call_kwargs(args, kwargs_empty).toTensor();
auto jit_output_args = model.callKwargs(args, kwargs_empty).toTensor();
ASSERT_TRUE(ref_output.equal(jit_output_args));
// and with kwargs only
std::unordered_map<std::string, c10::IValue> kwargs;
kwargs["input"] = input;
auto jit_output_kwargs = model.call_kwargs(kwargs).toTensor();
auto jit_output_kwargs = model.callKwargs(kwargs).toTensor();
ASSERT_TRUE(ref_output.equal(jit_output_kwargs));
// test hasattr
@ -142,8 +142,8 @@ TEST(TorchpyTest, ThreadedSimpleModel) {
size_t nthreads = 3;
torch::deploy::InterpreterManager manager(nthreads);
torch::deploy::Package p = manager.load_package(path("SIMPLE", simple));
auto model = p.load_pickle("model", "model.pkl");
torch::deploy::Package p = manager.loadPackage(path("SIMPLE", simple));
auto model = p.loadPickle("model", "model.pkl");
auto ref_model = torch::jit::load(path("SIMPLE_JIT", simple_jit));
auto input = torch::ones({10, 20});
@ -179,13 +179,13 @@ TEST(TorchpyTest, ThrowsSafely) {
// See explanation in deploy.h
torch::deploy::InterpreterManager manager(3);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(manager.load_package("some garbage path"), c10::Error);
EXPECT_THROW(manager.loadPackage("some garbage path"), c10::Error);
torch::deploy::Package p = manager.load_package(path("SIMPLE", simple));
torch::deploy::Package p = manager.loadPackage(path("SIMPLE", simple));
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(p.load_pickle("some other", "garbage path"), c10::Error);
EXPECT_THROW(p.loadPickle("some other", "garbage path"), c10::Error);
auto model = p.load_pickle("model", "model.pkl");
auto model = p.loadPickle("model", "model.pkl");
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW(model(at::IValue("unexpected input")), c10::Error);
}
@ -193,30 +193,30 @@ TEST(TorchpyTest, ThrowsSafely) {
TEST(TorchpyTest, AcquireMultipleSessionsInTheSamePackage) {
torch::deploy::InterpreterManager m(1);
torch::deploy::Package p = m.load_package(path("SIMPLE", simple));
auto I = p.acquire_session();
torch::deploy::Package p = m.loadPackage(path("SIMPLE", simple));
auto I = p.acquireSession();
auto I1 = p.acquire_session();
auto I1 = p.acquireSession();
}
TEST(TorchpyTest, AcquireMultipleSessionsInDifferentPackages) {
torch::deploy::InterpreterManager m(1);
torch::deploy::Package p = m.load_package(path("SIMPLE", simple));
auto I = p.acquire_session();
torch::deploy::Package p = m.loadPackage(path("SIMPLE", simple));
auto I = p.acquireSession();
torch::deploy::Package p1 = m.load_package(
torch::deploy::Package p1 = m.loadPackage(
path("RESNET", "torch/csrc/deploy/example/generated/resnet"));
auto I1 = p1.acquire_session();
auto I1 = p1.acquireSession();
}
TEST(TorchpyTest, TensorSharingNotAllowed) {
size_t nthreads = 2;
torch::deploy::InterpreterManager m(nthreads);
// generate a tensor from one interpreter
auto I0 = m.all_instances()[0].acquire_session();
auto I1 = m.all_instances()[1].acquire_session();
auto obj = I0.global("torch", "empty")({I0.from_ivalue(2)});
auto I0 = m.allInstances()[0].acquireSession();
auto I1 = m.allInstances()[1].acquireSession();
auto obj = I0.global("torch", "empty")({I0.fromIValue(2)});
auto t = obj.toIValue().toTensor();
// try to feed it to the other interpreter, should error
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
@ -237,9 +237,9 @@ TEST(TorchpyTest, TaggingRace) {
std::atomic<int64_t> failed(0);
at::parallel_for(0, nthreads, 1, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
auto I = m.all_instances()[i].acquire_session();
auto I = m.allInstances()[i].acquireSession();
try {
I.from_ivalue(t);
I.fromIValue(t);
success++;
} catch (const c10::Error& e) {
failed++;
@ -255,20 +255,20 @@ TEST(TorchpyTest, DisarmHook) {
at::Tensor t = torch::empty(2);
{
torch::deploy::InterpreterManager m(1);
auto I = m.acquire_one();
I.from_ivalue(t);
auto I = m.acquireOne();
I.fromIValue(t);
} // unload the old interpreter
torch::deploy::InterpreterManager m(1);
auto I = m.acquire_one();
auto I = m.acquireOne();
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
ASSERT_THROW(I.from_ivalue(t), c10::Error); // NOT a segfault
ASSERT_THROW(I.fromIValue(t), c10::Error); // NOT a segfault
}
TEST(TorchpyTest, RegisterModule) {
torch::deploy::InterpreterManager m(2);
m.register_module_source("foomodule", "def add1(x): return x + 1\n");
for (const auto& interp : m.all_instances()) {
auto I = interp.acquire_session();
m.reigsterModuleSource("foomodule", "def add1(x): return x + 1\n");
for (const auto& interp : m.allInstances()) {
auto I = interp.acquireSession();
AT_ASSERT(3 == I.global("foomodule", "add1")({2}).toIValue().toInt());
}
}
@ -276,9 +276,9 @@ TEST(TorchpyTest, RegisterModule) {
TEST(TorchpyTest, FxModule) {
size_t nthreads = 3;
torch::deploy::InterpreterManager manager(nthreads);
torch::deploy::Package p = manager.load_package(path(
torch::deploy::Package p = manager.loadPackage(path(
"SIMPLE_LEAF_FX", "torch/csrc/deploy/example/generated/simple_leaf_fx"));
auto model = p.load_pickle("model", "model.pkl");
auto model = p.loadPickle("model", "model.pkl");
std::vector<at::Tensor> outputs;
auto input = torch::ones({5, 10});
@ -304,8 +304,8 @@ thread_local int in_another_module = 5;
TEST(TorchpyTest, SharedLibraryLoad) {
torch::deploy::InterpreterManager manager(2);
auto no_args = at::ArrayRef<torch::deploy::Obj>();
for (auto& interp : manager.all_instances()) {
auto I = interp.acquire_session();
for (auto& interp : manager.allInstances()) {
auto I = interp.acquireSession();
const char* test_lib_path = getenv("LIBTEST_DEPLOY_LIB");
if (!test_lib_path) {
@ -329,8 +329,8 @@ TEST(TorchpyTest, SharedLibraryLoad) {
// I.global("numpy", "array"); // force numpy to load here so it is loaded
// // twice before we run the tests
}
for (auto& interp : manager.all_instances()) {
auto I = interp.acquire_session();
for (auto& interp : manager.allInstances()) {
auto I = interp.acquireSession();
// auto i =
// I.global("test_deploy_python", "numpy_test")({1}).toIValue().toInt();
I.global("libtest_deploy_lib", "raise_and_catch_exception")({true});
@ -372,16 +372,16 @@ TEST(TorchpyTest, UsesDistributed) {
"USES_DISTRIBUTED",
"torch/csrc/deploy/example/generated/uses_distributed");
torch::deploy::InterpreterManager m(1);
torch::deploy::Package p = m.load_package(model_filename);
torch::deploy::Package p = m.loadPackage(model_filename);
{
auto I = p.acquire_session();
auto I = p.acquireSession();
I.self.attr("import_module")({"uses_distributed"});
}
}
TEST(TorchpyTest, Autograd) {
torch::deploy::InterpreterManager m(2);
m.register_module_source("autograd_test", R"PYTHON(
m.reigsterModuleSource("autograd_test", R"PYTHON(
import torch
x = torch.ones(5) # input tensor
@ -396,11 +396,11 @@ result = torch.Tensor([1,2,3])
)PYTHON");
at::Tensor w_grad0, w_grad1;
{
auto I = m.all_instances()[0].acquire_session();
auto I = m.allInstances()[0].acquireSession();
w_grad0 = I.global("autograd_test", "result").toIValue().toTensor();
}
{
auto I = m.all_instances()[1].acquire_session();
auto I = m.allInstances()[1].acquireSession();
w_grad1 = I.global("autograd_test", "result").toIValue().toTensor();
}
EXPECT_TRUE(w_grad0.equal(w_grad1));

View File

@ -30,15 +30,15 @@ TEST(TorchDeployGPUTest, SimpleModel) {
// Test
torch::deploy::InterpreterManager m(1);
torch::deploy::Package p = m.load_package(model_filename);
auto model = p.load_pickle("model", "model.pkl");
torch::deploy::Package p = m.loadPackage(model_filename);
auto model = p.loadPickle("model", "model.pkl");
{
auto M = model.acquire_session();
auto M = model.acquireSession();
M.self.attr("to")({"cuda"});
}
std::vector<at::IValue> inputs;
{
auto I = p.acquire_session();
auto I = p.acquireSession();
auto eg = I.self.attr("load_pickle")({"model", "example.pkl"}).toIValue();
inputs = eg.toTuple()->elements();
inputs[0] = inputs[0].toTensor().to("cuda");
@ -59,9 +59,9 @@ TEST(TorchDeployGPUTest, UsesDistributed) {
"USES_DISTRIBUTED",
"torch/csrc/deploy/example/generated/uses_distributed");
torch::deploy::InterpreterManager m(1);
torch::deploy::Package p = m.load_package(model_filename);
torch::deploy::Package p = m.loadPackage(model_filename);
{
auto I = p.acquire_session();
auto I = p.acquireSession();
I.self.attr("import_module")({"uses_distributed"});
}
}
@ -73,10 +73,10 @@ TEST(TorchDeployGPUTest, TensorRT) {
auto packagePath = path(
"MAKE_TRT_MODULE", "torch/csrc/deploy/example/generated/make_trt_module");
torch::deploy::InterpreterManager m(1);
torch::deploy::Package p = m.load_package(packagePath);
auto makeModel = p.load_pickle("make_trt_module", "model.pkl");
torch::deploy::Package p = m.loadPackage(packagePath);
auto makeModel = p.loadPickle("make_trt_module", "model.pkl");
{
auto I = makeModel.acquire_session();
auto I = makeModel.acquireSession();
auto model = I.self(at::ArrayRef<at::IValue>{});
auto input = at::ones({1, 2, 3}).cuda();
auto output = input * 2;

View File

@ -6,14 +6,14 @@
bool run() {
torch::deploy::InterpreterManager m(2);
m.register_module_source("check_none", "check = id(None)\n");
m.reigsterModuleSource("check_none", "check = id(None)\n");
int64_t id0 = 0, id1 = 0;
{
auto I = m.all_instances()[0].acquire_session();
auto I = m.allInstances()[0].acquireSession();
id0 = I.global("check_none", "check").toIValue().toInt();
}
{
auto I = m.all_instances()[1].acquire_session();
auto I = m.allInstances()[1].acquireSession();
id1 = I.global("check_none", "check").toIValue().toInt();
}
return id0 != id1;