Mirror of https://github.com/pytorch/pytorch.git
[DI] Allow explicit taskLauncher for torchscript interpreter (#46865)
Summary: By default, TorchScript execution is single-threaded and runs on the caller's thread. For the distributed inference use case, we need a way to customize where the TorchScript interpreter executes its work. This diff adds an explicit taskLauncher to the TorchScript interpreter so callers can supply their own executor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46865

Test Plan: unit tests pass.

fbshipit-source-id: 1d7b003926c0d1f8facc53206efb960cff8897ac

Fixes #{issue number}

Reviewed By: houseroad

Differential Revision: D24616102

Pulled By: garroud

fbshipit-source-id: 79202b62f92d0b0baf72e4bf7aa3f05e0da91d59
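The new parameter lets callers route interpreter work through their own executor rather than the default at::launch thread pool. The sketch below is illustrative only and not part of this commit: "model.pt" is a placeholder path, and the thread-per-task launcher merely stands in for a real executor (e.g. an RPC worker pool).

#include <torch/script.h>

#include <functional>
#include <thread>
#include <vector>

int main() {
  // Load a scripted module; "model.pt" is a placeholder path.
  torch::jit::Module module = torch::jit::load("model.pt");
  torch::jit::Function& fn = module.get_method("forward").function();

  // Any std::function<void(std::function<void()>)> is a valid TaskLauncher.
  // This naive thread-per-task launcher stands in for a real executor.
  torch::jit::TaskLauncher launcher = [](std::function<void()> work) {
    std::thread(std::move(work)).detach();
  };

  // The interpreter dispatches continuations and forked tasks through the
  // supplied launcher instead of the default at::launch thread pool.
  std::vector<c10::IValue> stack;
  stack.push_back(module._ivalue());
  fn.runAsync(stack, launcher)->wait();
  return 0;
}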
Committed by: Facebook GitHub Bot
Parent: b704cbeffe
Commit: 735f8cc6c2
@@ -35,7 +35,9 @@ struct BuiltinOpFunction : public Function {
     callable_(stack);
   }

-  c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) override {
+  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
+      Stack& stack,
+      TaskLauncher /* not used */) override {
     run(stack);
     auto res = c10::make_intrusive<c10::ivalue::Future>(stack.front().type());
     res->markCompleted(std::move(stack.front()));
@@ -8,6 +8,10 @@ namespace c10 {
 struct FunctionSchema;
 };

+namespace at {
+CAFFE2_API void launch(std::function<void()> func);
+}
+
 namespace torch {
 namespace jit {

@@ -17,6 +21,7 @@ struct GraphExecutor;
 using Stack = std::vector<at::IValue>;
 using Kwargs = std::unordered_map<std::string, at::IValue>;
 struct RecursiveMethodCallError : public std::exception {};
+using TaskLauncher = std::function<void(std::function<void()>)>;

 TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);

@@ -36,7 +41,9 @@ struct TORCH_API Function {

   virtual void run(Stack&& stack) = 0;

-  virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) = 0;
+  virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(
+      Stack& stack,
+      TaskLauncher taskLauncher = at::launch) = 0;

   virtual at::IValue operator()(
       std::vector<at::IValue> stack,
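The TaskLauncher alias above is the whole extension point: any callable that accepts a std::function<void()> and eventually runs it qualifies. As a hedged sketch, adapting a caller-owned pool might look like the following, where MyThreadPool and its enqueue method are hypothetical stand-ins, not part of this commit:

// MyThreadPool and enqueue() are hypothetical stand-ins for whatever
// executor the caller already owns (a folly::Executor, a worker pool, etc.).
torch::jit::TaskLauncher makeLauncher(MyThreadPool& pool) {
  return [&pool](std::function<void()> work) {
    pool.enqueue(std::move(work));  // hand the closure to the caller's pool
  };
}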
@@ -2,6 +2,9 @@

 #include "test/cpp/jit/test_utils.h"
+#include "torch/csrc/jit/runtime/graph_executor.h"
+#include "torch/jit.h"
+#include "torch/script.h"
 #include "torch/torch.h"

 namespace torch {
 namespace jit {
@@ -29,5 +32,40 @@ TEST(GraphExecutorTest, Basic_CUDA) {
   ASSERT_TRUE(almostEqual(stack[1].toTensor(), r1));
 }

+TEST(GraphExecutorTest, runAsync_executor) {
+  /*
+  TODO: there are some problems with C++ parsing of script programs that
+  involve fork. Use the pre-generated test module below for now.
+  Issue tracking this: github.com/pytorch/pytorch/issues/46368
+  The test module file is generated by the following:
+    class DemoModule(torch.nn.Module):
+      def forward(self):
+        r1 = torch.jit.fork(torch.mm, torch.rand(100,100), torch.rand(100,100))
+        r2 = torch.jit.fork(torch.mm, torch.rand(100,100), torch.rand(100,100))
+        return r1.wait() + r2.wait()
+    demo = DemoModule()
+    torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
+  */
+  std::string filePath(__FILE__);
+  auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
+  testModelFile.append("test_interpreter_async.pt");
+  auto module = load(testModelFile);
+  auto graph = module.get_method("forward").graph();
+  GraphExecutor graphExecutor(graph, "");
+  auto asyncCounter = 0;
+  std::mutex mtx;
+  // a dummy launcher that delegates to at::launch but also bumps a counter
+  auto launcher = [&](std::function<void()> f) {
+    mtx.lock();
+    ++asyncCounter;
+    mtx.unlock();
+    at::launch(std::move(f));
+  };
+  std::vector<IValue> stack;
+  stack.push_back(module._ivalue());
+  graphExecutor.runAsync(stack, launcher)->wait();
+  ASSERT_TRUE(asyncCounter > 0);
+}
+
 } // namespace jit
 } // namespace torch
@@ -1,6 +1,10 @@
 #include <gtest/gtest.h>

+#include <ATen/Parallel.h>
 #include "test/cpp/jit/test_utils.h"
+#include "torch/jit.h"
+#include "torch/script.h"
 #include "torch/torch.h"

 namespace torch {
 namespace jit {
@@ -138,5 +142,41 @@ TEST(InterpreterTest, Basic_CUDA) {
   ASSERT_TRUE(exactlyEqual(outputs[0], hx));
   ASSERT_TRUE(exactlyEqual(outputs[1], cx));
 }

+TEST(InterpreterTest, runAsyncBasicTest) {
+  /*
+  TODO: there are some problems with C++ parsing of script programs that
+  involve fork. Use the pre-generated test module below for now.
+  Issue tracking this: github.com/pytorch/pytorch/issues/46368
+  The test module file is generated by the following:
+    class DemoModule(torch.nn.Module):
+      def forward(self):
+        r1 = torch.jit.fork(torch.mm, torch.rand(100,100), torch.rand(100,100))
+        r2 = torch.jit.fork(torch.mm, torch.rand(100,100), torch.rand(100,100))
+        return r1.wait() + r2.wait()
+    demo = DemoModule()
+    torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
+  */
+  std::string filePath(__FILE__);
+  auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
+  testModelFile.append("test_interpreter_async.pt");
+  auto model = load(testModelFile);
+  auto graph = model.get_method("forward").graph();
+  Code function(graph, "");
+  auto asyncCounter = 0;
+  std::mutex mtx;
+  // a dummy launcher that delegates to at::launch but also bumps a counter
+  auto launcher = [&](std::function<void()> f) {
+    mtx.lock();
+    ++asyncCounter;
+    mtx.unlock();
+    at::launch(std::move(f));
+  };
+  std::vector<IValue> stack;
+  stack.push_back(model._ivalue());
+  InterpreterState interp(function, launcher);
+  interp.runAsync(stack)->wait();
+  ASSERT_TRUE(asyncCounter > 0);
+}
 } // namespace jit
 } // namespace torch
BIN  test/cpp/jit/test_interpreter_async.pt (new binary file; not shown)
@@ -39,8 +39,10 @@ void GraphFunction::run(Stack&& stack) {
   run(stack);
 }

-c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(Stack& stack) {
-  return get_executor().runAsync(stack);
+c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(
+    Stack& stack,
+    TaskLauncher taskLauncher) {
+  return get_executor().runAsync(stack, std::move(taskLauncher));
 }

 IValue GraphFunction::operator()(
@@ -25,7 +25,9 @@ struct TORCH_API GraphFunction : public Function {

   void run(Stack&& stack) override;

-  c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) override;
+  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
+      Stack& stack,
+      TaskLauncher taskLauncher = at::launch) override;

   IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs())
       override;
@@ -516,7 +516,9 @@ void GraphExecutorImplBase::run(Stack& stack) {
   last_executed_optimized_graph = plan.graph;
 }

-c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(Stack& stack) {
+c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(
+    Stack& stack,
+    TaskLauncher taskLauncher) {
   TORCH_CHECK(
       stack.size() >= num_inputs,
       "expected ",

@@ -529,13 +531,14 @@ c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(Stack& stack) {
       logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);

   struct Frame {
-    explicit Frame(ExecutionPlan eplan)
-        : plan(std::move(eplan)), state(plan.code) {}
+    explicit Frame(ExecutionPlan eplan, TaskLauncher taskLauncher)
+        : plan(std::move(eplan)), state(plan.code, std::move(taskLauncher)) {}
     ExecutionPlan plan;
     InterpreterState state;
   };
   auto frame = std::make_shared<Frame>(
-      getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()));
+      getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()),
+      std::move(taskLauncher));
   auto res = frame->state.runAsync(stack);
   last_executed_optimized_graph = frame->plan.graph;
   if (!res->completed()) {
@@ -731,8 +734,10 @@ void GraphExecutor::run(Stack& inputs) {
   return pImpl->run(inputs);
 }

-c10::intrusive_ptr<Future> GraphExecutor::runAsync(Stack& stack) {
-  return pImpl->runAsync(stack);
+c10::intrusive_ptr<Future> GraphExecutor::runAsync(
+    Stack& stack,
+    TaskLauncher taskLauncher) {
+  return pImpl->runAsync(stack, std::move(taskLauncher));
 }

 size_t GraphExecutor::getDefaultNumBailOuts() {
@@ -58,7 +58,9 @@ struct TORCH_API GraphExecutor {
   GraphExecutor(std::shared_ptr<Graph> graph, std::string function_name);

   void run(Stack& inputs);
-  c10::intrusive_ptr<Future> runAsync(Stack& stack);
+  c10::intrusive_ptr<Future> runAsync(
+      Stack& stack,
+      TaskLauncher taskLauncher = at::launch);

   // `remaining_bailout_depth` stands for the maximum number of profiled and
   // specialized recompilations allowed for the current `GraphExecutor`. if
@@ -69,7 +69,9 @@ struct GraphExecutorImplBase {

   // entry point where execution begins
   void run(Stack& stack);
-  c10::intrusive_ptr<Future> runAsync(Stack& stack);
+  c10::intrusive_ptr<Future> runAsync(
+      Stack& stack,
+      TaskLauncher taskLauncher = at::launch);

   virtual ExecutionPlan getPlanFor(
       Stack& stack,
@@ -1031,7 +1031,8 @@ struct CodeImpl {

 // InterpreterState state that and used to compute a Code
 struct InterpreterStateImpl : c10::intrusive_ptr_target {
-  InterpreterStateImpl(const Code& code) {
+  InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher)
+      : taskLauncher_(std::move(taskLauncher)) {
     enterFrame(code, 0);
   }

@@ -1057,6 +1058,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
   // including any inputs to this function
   int64_t stack_start_ = -1;
   c10::intrusive_ptr<Future> future_;
+  TaskLauncher taskLauncher_;

   // this holds all the tensors for this interpreter run
   // we don't bother minimizing the size of this vector, since the extra
@@ -1335,11 +1337,14 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
     Callback(
         c10::intrusive_ptr<InterpreterStateImpl> state,
         Stack stack)
-        : state_(std::move(state)), stack_(std::move(stack)) {
+        : stateImpl_(std::move(state)),
+          state_(stateImpl_),
+          stack_(std::move(stack)) {
       dist_autograd_context_id_ = getDistAutogradContextId();
+      state_ = InterpreterState(stateImpl_);
     }
     void operator()() {
-      at::launch(InterpreterContinuation(
+      stateImpl_->taskLauncher_(InterpreterContinuation(
           state_,
           std::move(stack_),
           dist_autograd_context_id_,

@@ -1347,6 +1352,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
     }

    private:
+    c10::intrusive_ptr<InterpreterStateImpl> stateImpl_;
     InterpreterState state_;
     Stack stack_;
     int64_t dist_autograd_context_id_;
@@ -1511,14 +1517,15 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
         InterpreterState forked_interpreter(
             forked_fn->get_executor()
                 .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
-                .code);
+                .code,
+            taskLauncher_);
         InterpreterContinuation continuation(
             forked_interpreter,
             Stack(stack.end() - inst.N, stack.end()),
             getDistAutogradContextId());
         drop(stack, inst.N);
         push(stack, forked_interpreter.getFuture());
-        at::launch(std::move(continuation));
+        taskLauncher_(std::move(continuation));
         ++frame.pc;
       } break;
       case WARN: {
@@ -1740,8 +1747,10 @@ size_t Code::register_size() const {
   return pImpl->register_size_;
 }

-InterpreterState::InterpreterState(const Code& code)
-    : pImpl(c10::make_intrusive<InterpreterStateImpl>(code)) {}
+InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher)
+    : pImpl(c10::make_intrusive<InterpreterStateImpl>(
+          code,
+          std::move(taskLauncher))) {}
 InterpreterState::~InterpreterState() = default;

 void InterpreterState::run(Stack& stack) {
@@ -10,7 +10,8 @@

 namespace at {
 class Tensor;
-}
+CAFFE2_API void launch(std::function<void()> func);
+} // namespace at
 namespace c10 {
 struct IValue;
 struct OperatorName;
@@ -32,6 +33,7 @@ struct Node;
 struct Instruction;
 using Stack = std::vector<c10::IValue>;
 using c10::ivalue::Future;
+using TaskLauncher = std::function<void(std::function<void()>)>;

 struct TORCH_API Code {
   Code() : pImpl(nullptr) {}
@@ -66,9 +68,11 @@ struct TORCH_API Code {
 };

 struct InterpreterState {
-  TORCH_API InterpreterState(const Code& code);
+  TORCH_API InterpreterState(
+      const Code& code,
+      TaskLauncher taskLauncher = at::launch);
   TORCH_API void run(Stack& stack);
-  c10::intrusive_ptr<Future> runAsync(Stack& stack);
+  TORCH_API c10::intrusive_ptr<Future> runAsync(Stack& stack);
   c10::intrusive_ptr<Future> getFuture();
   TORCH_API ~InterpreterState();
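Because InterpreterState now takes the launcher at construction time, a test can force synchronous, deterministic scheduling with an inline launcher. A minimal sketch, assuming a torch::jit::Code named `code` is already in scope (illustrative, not part of this commit):

// Inline launcher: runs each task immediately on the calling thread,
// which makes interpreter scheduling deterministic for unit tests.
torch::jit::TaskLauncher inlineLauncher = [](std::function<void()> work) {
  work();
};
torch::jit::InterpreterState interp(code, inlineLauncher);  // `code` assumed built from a graph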