Support rpc_async call with timeout in JIT (#37884)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37884

Adds support for the rpc_timeout param in rpc_async calls from JIT, for
parity with eager mode. Done by:
1) Adding the timeout as an input to the `prim::rpc_async` node in ir_emitter.cpp when it is specified.
2) Parsing the float timeout IValue from the inputs in the `prim::rpc_async` operator, falling back to the default RPC timeout when it is not given.
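
For illustration, a minimal TorchScript sketch of the resulting call form (the worker name, the remote helper, and the 5-second timeout below are placeholders, mirroring the tests added in this change):

from typing import Dict, Tuple

import torch
import torch.distributed.rpc as rpc
from torch import Tensor


@torch.jit.script
def remote_add(x: Tensor, y: Tensor) -> Tensor:
    return x + y


@torch.jit.script
def call_with_timeout(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
):
    # The trailing float is the per-call RPC timeout in seconds; if the call
    # does not complete in time, waiting on the future raises a RuntimeError.
    fut = rpc.rpc_async(dst_worker_name, remote_add, args, kwargs, 5.0)
    return fut.wait()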

Added UTs in jit/rpc_test.
ghstack-source-id: 104083031

Test Plan: Added UTs in jit/rpc_test.

Differential Revision: D21268895

fbshipit-source-id: 34bb10a2ac08b67dd6b789121ab43e2c0e696229
Author: Rohan Varma
Date: 2020-05-14 12:41:27 -07:00
Committed by: Facebook GitHub Bot
Parent: 3300dd5227
Commit: f178bf10f1

4 changed files with 188 additions and 8 deletions

@@ -0,0 +1,21 @@
#!/usr/bin/env python3
import unittest

from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import TEST_WITH_ASAN, run_tests
from torch.testing._internal.distributed.rpc.jit.rpc_test_faulty import (
    JitFaultyAgentRpcTest,
)


@unittest.skipIf(
    TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues"
)
class JitFaultyAgentRpcTestWithSpawn(MultiProcessTestCase, JitFaultyAgentRpcTest):
    def setUp(self):
        super(JitFaultyAgentRpcTestWithSpawn, self).setUp()
        self._spawn_processes()


if __name__ == "__main__":
    run_tests()

@@ -2806,9 +2806,13 @@ struct to_ir {
// through RPC in TorchScript,
// Ideally, function value in JIT IR is first-class citizen and
// The RPC C++ entry API can take c10::Function directly.
if (apply.inputs().size() < 2 || apply.inputs().size() > 4) {
auto rpcMinInputs = 2;
auto rpcMaxInputs = 5; // NOLINT
if (apply.inputs().size() < rpcMinInputs ||
apply.inputs().size() > rpcMaxInputs) {
throw ErrorReport(apply)
<< "Possible forms of call to rpc_async(..) are\n"
<< "rpc_async(dst_worker_name, user_callable, args, kwargs, timeout)\n"
<< "rpc_async(dst_worker_name, user_callable, args, kwargs)\n"
<< "rpc_async(dst_worker_name, user_callable, args)\n"
<< "rpc_async(dst_worker_name, user_callable)\n"
@@ -2840,7 +2844,9 @@
// If `kwargs` is an empty dict, users are allowed to not pass `kwargs`.
// If `args` and `kwargs` are an empty tuple and an empty dict,
// respectively, users are allowed to not pass `args` and `kwargs`.
TreeList args_kwargs_trees(input_trees.begin() + 2, input_trees.end());
TreeList args_kwargs_timeout_trees(
input_trees.begin() + 2, input_trees.end());
// Get user callable.
const auto& callablePtrs = user_callable_function_value->callees();
@@ -2859,9 +2865,9 @@
std::vector<NamedValue> kwargs;
// Get args and kwargs as `NamedValue`s.
// Similar to getNamedValues(..) and emitAttributes(..).
if (args_kwargs_trees.size() >= 1) {
if (args_kwargs_timeout_trees.size() >= 1) {
// Unroll args from a Var that is known to be a Tuple.
auto& args_tree = args_kwargs_trees[0];
auto& args_tree = args_kwargs_timeout_trees[0];
auto entry_sugared_values = emitSugaredExpr(Expr(args_tree), 1)
->asTuple(args_tree->range(), method);
args.reserve(entry_sugared_values.size());
@@ -2895,7 +2901,7 @@
rpc_async_node->addInput(dst_worker_name_value);
rpc_async_node->addInput(userCallableQualNameValue);
for (const auto& tree : args_kwargs_trees) {
for (const auto& tree : args_kwargs_timeout_trees) {
rpc_async_node->addInput(emitExpr(Expr(tree)));
}
}

@@ -2,6 +2,7 @@
#include <ATen/core/op_registration/op_registration.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
#include <torch/csrc/jit/python/pybind_utils.h>
@@ -122,13 +123,20 @@ RegisterOperators reg_rpc_ops(
IValue emptyTuple(c10::ivalue::Tuple::create({}));
IValue emptyDict{
c10::impl::GenericDict(AnyType::get(), AnyType::get())};
// Equavalent to Python statment
// Equivalent to Python statement
// `args = args if args is not None else ()`.
auto& argsTupleIValue =
num_inputs >= 3 ? *stackIter++ : emptyTuple;
// `kwargs = kwargs if kwargs is not None else {}`.
auto& kwargsDictIValue =
num_inputs >= 4 ? *stackIter++ : emptyDict;
// IValue corresponding to placeholder for RPC timeout. Used if no
// rpc timeout is specified by user.
IValue noTimeout(torch::distributed::rpc::kUnsetRpcTimeout);
const auto rpcMaxInputs = 5;
auto& timeoutIValue =
num_inputs >= rpcMaxInputs ? *stackIter++ : noTimeout;
TORCH_INTERNAL_ASSERT(
dstWorkerIValue.isString() ||
c10::getCustomClassType<
@@ -137,6 +145,7 @@ RegisterOperators reg_rpc_ops(
TORCH_INTERNAL_ASSERT(qualifiedNameIValue.isString());
TORCH_INTERNAL_ASSERT(argsTupleIValue.isTuple());
TORCH_INTERNAL_ASSERT(kwargsDictIValue.isGenericDict());
TORCH_INTERNAL_ASSERT(timeoutIValue.isDouble());
// Get FunctionSchema for qualifiedName.
auto qualifiedName =
@@ -207,13 +216,15 @@ RegisterOperators reg_rpc_ops(
dstWorkerNameStr =
dstWorkerIValue.toCustomClass<dist_rpc::WorkerInfo>()->name_;
}
// Get RPC timeout, if specified by user.
const auto rpcTimeout = timeoutIValue.toDouble();
// Send RPC request.
auto futureIValuePtr = dist_rpc::rpcTorchscript(
dstWorkerNameStr,
qualifiedName,
functionSchema,
userCallableStack);
userCallableStack,
rpcTimeout);
// Push output to the stack.
drop(stack, num_inputs);
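
At the operator level, missing trailing inputs are filled with defaults based on how many inputs are on the stack. A Python-level sketch of the equivalent defaulting logic follows; the helper name and the placeholder value standing in for torch::distributed::rpc::kUnsetRpcTimeout are illustrative only:

from typing import Any, Dict, Optional, Tuple

# Illustrative stand-in for kUnsetRpcTimeout; the real constant tells the RPC
# layer to fall back to the agent's currently configured default timeout.
_UNSET_RPC_TIMEOUT = -1.0


def _fill_rpc_async_defaults(
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    timeout: Optional[float] = None,
) -> Tuple[Tuple[Any, ...], Dict[str, Any], float]:
    # `args = args if args is not None else ()`
    args = args if args is not None else ()
    # `kwargs = kwargs if kwargs is not None else {}`
    kwargs = kwargs if kwargs is not None else {}
    # New in this change: when no per-call timeout is given, keep the unset
    # placeholder so the default RPC timeout applies.
    timeout = timeout if timeout is not None else _UNSET_RPC_TIMEOUT
    return args, kwargs, timeout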

@@ -0,0 +1,142 @@
from typing import Dict, Tuple
import torch
import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils as dist_utils
from torch import Tensor
from torch.testing._internal.dist_utils import (
dist_init,
get_timeout_error_regex,
worker_name,
)
from torch.testing._internal.distributed.rpc.faulty_rpc_agent_test_fixture import (
FaultyRpcAgentTestFixture,
)
@torch.jit.script
def two_args_two_kwargs(
first_arg,
second_arg,
first_kwarg=torch.tensor([3, 3]),
second_kwarg=torch.tensor([4, 4]),
):
return first_arg + second_arg + first_kwarg + second_kwarg
@torch.jit.script
def rpc_async_call_remote_torchscript_in_torchscript(
dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
ret = fut.wait()
return ret
@torch.jit.script
def rpc_async_call_with_timeout(
dst_worker_name: str,
args: Tuple[Tensor, Tensor],
kwargs: Dict[str, Tensor],
timeout: float,
):
fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
ret = fut.wait()
return ret
@torch.jit.script
def rpc_async_call_with_timeout_future_ret(
dst_worker_name: str,
args: Tuple[Tensor, Tensor],
kwargs: Dict[str, Tensor],
timeout: float,
):
fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
return fut
@torch.jit.script
def rpc_async_call_future_ret(
dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
return fut
class JitFaultyAgentRpcTest(FaultyRpcAgentTestFixture):
"""
Run tests for rpc_async in JIT under the faulty agent test fixture to test
arbitrary timeouts.
"""
@dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
def test_timeout_in_torchscript_function(self):
# Call rpc_async + fut.wait() in torchscript function and ensure that
# timeout is raised.
if self.rank != 0:
return
dst_worker_name = worker_name((self.rank + 1) % self.world_size)
args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
kwargs = {
"first_kwarg": torch.tensor([2, 2]),
"second_kwarg": torch.tensor([3, 3]),
}
expected_error = get_timeout_error_regex(
dist_utils.TEST_CONFIG.rpc_backend_name
)
print("Test config is {}".format(dist_utils.TEST_CONFIG.rpc_backend_name))
# Ensure that we get a timeout if we override the default timeout and
# the RPC takes longer to execute.
with self.assertRaisesRegex(RuntimeError, expected_error):
rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5)
        # Ensure that we time out if we don't specify a timeout but the
        # default is shorter than the time the RPC takes to execute.
rpc._set_rpc_timeout(0.001)
with self.assertRaisesRegex(RuntimeError, expected_error):
rpc_async_call_remote_torchscript_in_torchscript(
dst_worker_name, args, kwargs
)
# Ensure that we run to completion if zero timeout is specified.
ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
self.assertEqual(ret, torch.tensor([8, 8]))
# reset for clean shutdown
rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
@dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
def test_timeout_in_python(self):
# Ensures timeouts are raised if we call rpc_async from within a
# torchscript function, but wait on the future in python.
if self.rank != 0:
return
dst_worker_name = worker_name((self.rank + 1) % self.world_size)
args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
kwargs = {
"first_kwarg": torch.tensor([2, 2]),
"second_kwarg": torch.tensor([3, 3]),
}
expected_error = get_timeout_error_regex(
dist_utils.TEST_CONFIG.rpc_backend_name
)
fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5)
with self.assertRaisesRegex(RuntimeError, expected_error):
fut.wait()
        # Ensure a timeout if we don't specify one but the default is shorter
        # than the time the RPC takes to execute.
rpc._set_rpc_timeout(0.001)
fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs)
with self.assertRaisesRegex(RuntimeError, expected_error):
fut.wait()
# Ensure run to completion if zero timeout is specified
fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0)
result = fut.wait()
self.assertEqual(result, torch.tensor([8, 8]))
# reset for clean shutdown
rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
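
As the tests above exercise, passing a timeout of zero disables the per-call timeout even when the process-wide default (set via rpc._set_rpc_timeout) is very small. A hedged sketch of that usage, with a placeholder destination worker and helper function:

from typing import Dict, Tuple

import torch
import torch.distributed.rpc as rpc
from torch import Tensor


@torch.jit.script
def script_add(x: Tensor, y: Tensor) -> Tensor:
    return x + y


@torch.jit.script
def call_without_per_call_timeout(dst_worker_name: str, args: Tuple[Tensor, Tensor]):
    kwargs = torch.jit.annotate(Dict[str, Tensor], {})
    # A timeout of 0 means the per-call timeout is disabled and the call is
    # allowed to run to completion.
    fut = rpc.rpc_async(dst_worker_name, script_add, args, kwargs, 0.0)
    return fut.wait()

When the process-wide default is shrunk for a test, restoring it afterwards with rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC) keeps shutdown clean, as the tests above do.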