[DI] Allow explicit taskLauncher for torchscript interpreter (#46865)

Summary:
By default, TorchScript execution is single-threaded and uses the caller's thread pool. For the distributed inference use case, we want to be able to customize this behavior so that the TorchScript interpreter can execute work elsewhere. This diff allows passing an explicit taskLauncher to the TorchScript interpreter.
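
For illustration, this lets a caller route the interpreter's asynchronous work through its own launcher. A minimal usage sketch, assuming a scripted module saved as "model.pt" (the file name and the inline launcher body are illustrative, not part of this diff):

    #include <torch/script.h>
    #include <functional>
    #include <vector>

    int main() {
      // Load a TorchScript module and get its forward function.
      auto module = torch::jit::load("model.pt");
      auto& fn = module.get_method("forward").function();

      std::vector<c10::IValue> stack{module._ivalue()};

      // A custom TaskLauncher; a real deployment might hand the closure
      // to an RPC or inference thread pool instead of running it inline.
      torch::jit::TaskLauncher launcher = [](std::function<void()> work) {
        work();
      };

      // If the launcher is omitted, runAsync falls back to at::launch.
      fn.runAsync(stack, std::move(launcher))->wait();
      return 0;
    }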

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46865

Test Plan:
Unit tests pass.

fbshipit-source-id: 1d7b003926c0d1f8facc53206efb960cff8897ac

Reviewed By: houseroad

Differential Revision: D24616102

Pulled By: garroud

fbshipit-source-id: 79202b62f92d0b0baf72e4bf7aa3f05e0da91d59
Author: Gaoxiang Liu (committed by Facebook GitHub Bot)
Date: 2020-11-04 17:06:05 -08:00
Commit: 735f8cc6c2 (parent: b704cbeffe)
12 changed files with 136 additions and 23 deletions


@@ -35,7 +35,9 @@ struct BuiltinOpFunction : public Function {
     callable_(stack);
   }
 
-  c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) override {
+  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
+      Stack& stack,
+      TaskLauncher /* not used */) override {
     run(stack);
     auto res = c10::make_intrusive<c10::ivalue::Future>(stack.front().type());
     res->markCompleted(std::move(stack.front()));


@@ -8,6 +8,10 @@ namespace c10 {
 struct FunctionSchema;
 };
 
+namespace at {
+CAFFE2_API void launch(std::function<void()> func);
+}
+
 namespace torch {
 namespace jit {

@@ -17,6 +21,7 @@ struct GraphExecutor;
 using Stack = std::vector<at::IValue>;
 using Kwargs = std::unordered_map<std::string, at::IValue>;
 struct RecursiveMethodCallError : public std::exception {};
+using TaskLauncher = std::function<void(std::function<void()>)>;
 
 TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);

@@ -36,7 +41,9 @@ struct TORCH_API Function {
   virtual void run(Stack&& stack) = 0;
 
-  virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) = 0;
+  virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(
+      Stack& stack,
+      TaskLauncher taskLauncher = at::launch) = 0;
 
   virtual at::IValue operator()(
       std::vector<at::IValue> stack,


@@ -2,6 +2,9 @@
 
 #include "test/cpp/jit/test_utils.h"
 #include "torch/csrc/jit/runtime/graph_executor.h"
+#include "torch/jit.h"
+#include "torch/script.h"
+#include "torch/torch.h"
 
 namespace torch {
 namespace jit {

@@ -29,5 +32,40 @@ TEST(GraphExecutorTest, Basic_CUDA) {
   ASSERT_TRUE(almostEqual(stack[1].toTensor(), r1));
 }
 
+TEST(GraphExecutorTest, runAsync_executor) {
+  /*
+  TODO: there are some problems with C++ parsing of script programs involving
+  fork. Use the pre-generated test module below for now.
+  Tracking issue: github.com/pytorch/pytorch/issues/46368
+  The test module file was generated as follows:
+    class DemoModule(torch.nn.Module):
+      def forward(self):
+        r1 = torch.jit.fork(torch.mm, torch.rand(100, 100), torch.rand(100, 100))
+        r2 = torch.jit.fork(torch.mm, torch.rand(100, 100), torch.rand(100, 100))
+        return r1.wait() + r2.wait()
+    demo = DemoModule()
+    torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
+  */
+  std::string filePath(__FILE__);
+  auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
+  testModelFile.append("test_interpreter_async.pt");
+  auto module = load(testModelFile);
+  auto graph = module.get_method("forward").graph();
+  GraphExecutor graphExecutor(graph, "");
+  auto asyncCounter = 0;
+  std::mutex mtx;
+  // a dummy launcher that delegates to at::launch but also bumps a counter
+  auto launcher = [&](std::function<void()> f) {
+    mtx.lock();
+    ++asyncCounter;
+    mtx.unlock();
+    at::launch(std::move(f));
+  };
+  std::vector<IValue> stack;
+  stack.push_back(module._ivalue());
+  graphExecutor.runAsync(stack, launcher)->wait();
+  ASSERT_TRUE(asyncCounter > 0);
+}
+
 } // namespace jit
 } // namespace torch


@@ -1,6 +1,10 @@
 #include <gtest/gtest.h>
 
+#include <ATen/Parallel.h>
 #include "test/cpp/jit/test_utils.h"
+#include "torch/jit.h"
+#include "torch/script.h"
+#include "torch/torch.h"
 
 namespace torch {
 namespace jit {

@@ -138,5 +142,41 @@ TEST(InterpreterTest, Basic_CUDA) {
   ASSERT_TRUE(exactlyEqual(outputs[0], hx));
   ASSERT_TRUE(exactlyEqual(outputs[1], cx));
 }
 
+TEST(InterpreterTest, runAsyncBasicTest) {
+  /*
+  TODO: there are some problems with C++ parsing of script programs involving
+  fork. Use the pre-generated test module below for now.
+  Tracking issue: github.com/pytorch/pytorch/issues/46368
+  The test module file was generated as follows:
+    class DemoModule(torch.nn.Module):
+      def forward(self):
+        r1 = torch.jit.fork(torch.mm, torch.rand(100, 100), torch.rand(100, 100))
+        r2 = torch.jit.fork(torch.mm, torch.rand(100, 100), torch.rand(100, 100))
+        return r1.wait() + r2.wait()
+    demo = DemoModule()
+    torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
+  */
+  std::string filePath(__FILE__);
+  auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
+  testModelFile.append("test_interpreter_async.pt");
+  auto model = load(testModelFile);
+  auto graph = model.get_method("forward").graph();
+  Code function(graph, "");
+  auto asyncCounter = 0;
+  std::mutex mtx;
+  // a dummy launcher that delegates to at::launch but also bumps a counter
+  auto launcher = [&](std::function<void()> f) {
+    mtx.lock();
+    ++asyncCounter;
+    mtx.unlock();
+    at::launch(f);
+  };
+  std::vector<IValue> stack;
+  stack.push_back(model._ivalue());
+  InterpreterState interp(function, launcher);
+  interp.runAsync(stack)->wait();
+  ASSERT_TRUE(asyncCounter > 0);
+}
+
 } // namespace jit
 } // namespace torch

Binary file not shown.


@@ -39,8 +39,10 @@ void GraphFunction::run(Stack&& stack) {
   run(stack);
 }
 
-c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(Stack& stack) {
-  return get_executor().runAsync(stack);
+c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(
+    Stack& stack,
+    TaskLauncher taskLauncher) {
+  return get_executor().runAsync(stack, std::move(taskLauncher));
 }
 
 IValue GraphFunction::operator()(


@@ -25,7 +25,9 @@ struct TORCH_API GraphFunction : public Function {
   void run(Stack&& stack) override;
 
-  c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) override;
+  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
+      Stack& stack,
+      TaskLauncher taskLauncher = at::launch) override;
 
   IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs())
       override;


@@ -516,7 +516,9 @@ void GraphExecutorImplBase::run(Stack& stack) {
   last_executed_optimized_graph = plan.graph;
 }
 
-c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(Stack& stack) {
+c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(
+    Stack& stack,
+    TaskLauncher taskLauncher) {
   TORCH_CHECK(
       stack.size() >= num_inputs,
       "expected ",

@@ -529,13 +531,14 @@ c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(Stack& stack) {
       logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);
 
   struct Frame {
-    explicit Frame(ExecutionPlan eplan)
-        : plan(std::move(eplan)), state(plan.code) {}
+    explicit Frame(ExecutionPlan eplan, TaskLauncher taskLauncher)
+        : plan(std::move(eplan)), state(plan.code, std::move(taskLauncher)) {}
     ExecutionPlan plan;
     InterpreterState state;
   };
   auto frame = std::make_shared<Frame>(
-      getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()));
+      getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()),
+      std::move(taskLauncher));
   auto res = frame->state.runAsync(stack);
   last_executed_optimized_graph = frame->plan.graph;
   if (!res->completed()) {

@@ -731,8 +734,10 @@ void GraphExecutor::run(Stack& inputs) {
   return pImpl->run(inputs);
 }
 
-c10::intrusive_ptr<Future> GraphExecutor::runAsync(Stack& stack) {
-  return pImpl->runAsync(stack);
+c10::intrusive_ptr<Future> GraphExecutor::runAsync(
+    Stack& stack,
+    TaskLauncher taskLauncher) {
+  return pImpl->runAsync(stack, std::move(taskLauncher));
 }
 
 size_t GraphExecutor::getDefaultNumBailOuts() {


@@ -58,7 +58,9 @@ struct TORCH_API GraphExecutor {
   GraphExecutor(std::shared_ptr<Graph> graph, std::string function_name);
 
   void run(Stack& inputs);
-  c10::intrusive_ptr<Future> runAsync(Stack& stack);
+  c10::intrusive_ptr<Future> runAsync(
+      Stack& stack,
+      TaskLauncher taskLauncher = at::launch);
 
   // `remaining_bailout_depth` stands for the maximum number of profiled and
   // specialized recompilations allowed for the current `GraphExecutor`. if


@@ -69,7 +69,9 @@ struct GraphExecutorImplBase {
 
   // entry point where execution begins
   void run(Stack& stack);
-  c10::intrusive_ptr<Future> runAsync(Stack& stack);
+  c10::intrusive_ptr<Future> runAsync(
+      Stack& stack,
+      TaskLauncher taskLauncher = at::launch);
 
   virtual ExecutionPlan getPlanFor(
       Stack& stack,


@@ -1031,7 +1031,8 @@ struct CodeImpl {
 
 // InterpreterState state that is used to run a Code
 struct InterpreterStateImpl : c10::intrusive_ptr_target {
-  InterpreterStateImpl(const Code& code) {
+  InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher)
+      : taskLauncher_(std::move(taskLauncher)) {
     enterFrame(code, 0);
   }

@@ -1057,6 +1058,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
   // including any inputs to this function
   int64_t stack_start_ = -1;
   c10::intrusive_ptr<Future> future_;
+  TaskLauncher taskLauncher_;
 
   // this holds all the tensors for this interpreter run
   // we don't bother minimizing the size of this vector, since the extra

@@ -1335,11 +1337,14 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
       Callback(
           c10::intrusive_ptr<InterpreterStateImpl> state,
           Stack stack)
-          : state_(std::move(state)), stack_(std::move(stack)) {
+          : stateImpl_(std::move(state)),
+            state_(stateImpl_),
+            stack_(std::move(stack)) {
         dist_autograd_context_id_ = getDistAutogradContextId();
+        state_ = InterpreterState(stateImpl_);
       }
       void operator()() {
-        at::launch(InterpreterContinuation(
+        stateImpl_->taskLauncher_(InterpreterContinuation(
            state_,
            std::move(stack_),
            dist_autograd_context_id_,

@@ -1347,6 +1352,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
       }
 
      private:
+      c10::intrusive_ptr<InterpreterStateImpl> stateImpl_;
       InterpreterState state_;
       Stack stack_;
       int64_t dist_autograd_context_id_;

@@ -1511,14 +1517,15 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
         InterpreterState forked_interpreter(
             forked_fn->get_executor()
                 .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
-                .code);
+                .code,
+            taskLauncher_);
         InterpreterContinuation continuation(
             forked_interpreter,
             Stack(stack.end() - inst.N, stack.end()),
             getDistAutogradContextId());
         drop(stack, inst.N);
         push(stack, forked_interpreter.getFuture());
-        at::launch(std::move(continuation));
+        taskLauncher_(std::move(continuation));
         ++frame.pc;
       } break;
       case WARN: {

@@ -1740,8 +1747,10 @@ size_t Code::register_size() const {
   return pImpl->register_size_;
 }
 
-InterpreterState::InterpreterState(const Code& code)
-    : pImpl(c10::make_intrusive<InterpreterStateImpl>(code)) {}
+InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher)
+    : pImpl(c10::make_intrusive<InterpreterStateImpl>(
+          code,
+          std::move(taskLauncher))) {}
 InterpreterState::~InterpreterState() = default;
 
 void InterpreterState::run(Stack& stack) {


@@ -10,7 +10,8 @@
 
 namespace at {
 class Tensor;
-}
+CAFFE2_API void launch(std::function<void()> func);
+} // namespace at
 namespace c10 {
 struct IValue;
 struct OperatorName;

@@ -32,6 +33,7 @@ struct Node;
 struct Instruction;
 using Stack = std::vector<c10::IValue>;
 using c10::ivalue::Future;
+using TaskLauncher = std::function<void(std::function<void()>)>;
 
 struct TORCH_API Code {
   Code() : pImpl(nullptr) {}

@@ -66,9 +68,11 @@ struct TORCH_API Code {
 };
 
 struct InterpreterState {
-  TORCH_API InterpreterState(const Code& code);
+  TORCH_API InterpreterState(
+      const Code& code,
+      TaskLauncher taskLauncher = at::launch);
   TORCH_API void run(Stack& stack);
-  c10::intrusive_ptr<Future> runAsync(Stack& stack);
+  TORCH_API c10::intrusive_ptr<Future> runAsync(Stack& stack);
   c10::intrusive_ptr<Future> getFuture();
   TORCH_API ~InterpreterState();