mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support rpc_async call with timeout in JIT (#37884)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37884

Adds support for the rpc_timeout parameter in rpc_async calls made from JIT, for parity with eager mode. This is done by:
1) Adding the timeout as an input in ir_emitter.cpp when it is specified.
2) Parsing the float IValue from the inputs in the `prim::rpc_async` operator, falling back to the default when it is not given.
Added UTs in jit/rpc_test.

ghstack-source-id: 104083031
Test Plan: Added UTs in jit/rpc_test.
Differential Revision: D21268895
fbshipit-source-id: 34bb10a2ac08b67dd6b789121ab43e2c0e696229
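For illustration, a minimal sketch of the TorchScript call shape this change enables, condensed from the new tests below (function and worker names are illustrative, and an initialized RPC agent with a reachable peer worker is assumed):

from typing import Dict, Tuple

import torch
import torch.distributed.rpc as rpc
from torch import Tensor


@torch.jit.script
def two_args_two_kwargs(
    first_arg,
    second_arg,
    first_kwarg=torch.tensor([3, 3]),
    second_kwarg=torch.tensor([4, 4]),
):
    return first_arg + second_arg + first_kwarg + second_kwarg


@torch.jit.script
def script_rpc_async_with_timeout(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
):
    # The trailing float is the per-call RPC timeout in seconds; omitting it
    # falls back to the agent's default timeout, matching eager mode.
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
    return fut.wait()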
This commit is contained in:
committed by Facebook GitHub Bot
parent 3300dd5227
commit f178bf10f1
test/distributed/rpc/jit/faulty_agent/test_rpc_spawn.py (new file, 21 lines)
@@ -0,0 +1,21 @@
#!/usr/bin/env python3
import unittest

from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import TEST_WITH_ASAN, run_tests
from torch.testing._internal.distributed.rpc.jit.rpc_test_faulty import (
    JitFaultyAgentRpcTest,
)


@unittest.skipIf(
    TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues"
)
class JitFaultyAgentRpcTestWithSpawn(MultiProcessTestCase, JitFaultyAgentRpcTest):
    def setUp(self):
        super(JitFaultyAgentRpcTestWithSpawn, self).setUp()
        self._spawn_processes()


if __name__ == "__main__":
    run_tests()
ir_emitter.cpp
@@ -2806,9 +2806,13 @@ struct to_ir {
       // through RPC in TorchScript,
       // Ideally, function value in JIT IR is first-class citizen and
       // The RPC C++ entry API can take c10::Function directly.
-      if (apply.inputs().size() < 2 || apply.inputs().size() > 4) {
+      auto rpcMinInputs = 2;
+      auto rpcMaxInputs = 5; // NOLINT
+      if (apply.inputs().size() < rpcMinInputs ||
+          apply.inputs().size() > rpcMaxInputs) {
         throw ErrorReport(apply)
             << "Possible forms of call to rpc_async(..) are\n"
+            << "rpc_async(dst_worker_name, user_callable, args, kwargs, timeout)\n"
             << "rpc_async(dst_worker_name, user_callable, args, kwargs)\n"
             << "rpc_async(dst_worker_name, user_callable, args)\n"
             << "rpc_async(dst_worker_name, user_callable)\n"
@@ -2840,7 +2844,9 @@ struct to_ir {
       // If `kwargs` is an empty dict, users are allowed to not pass `kwargs`.
       // If `args` and `kwargs` are an empty tuple and an empty dict,
       // respectively, users are allowed to not pass `args` and `kwargs`.
-      TreeList args_kwargs_trees(input_trees.begin() + 2, input_trees.end());
+      TreeList args_kwargs_timeout_trees(
+          input_trees.begin() + 2, input_trees.end());

       // Get user callable.
       const auto& callablePtrs = user_callable_function_value->callees();
@@ -2859,9 +2865,9 @@ struct to_ir {
       std::vector<NamedValue> kwargs;
       // Get args and kwargs as `NamedValue`s.
       // Similar to getNamedValues(..) and emitAttributes(..).
-      if (args_kwargs_trees.size() >= 1) {
+      if (args_kwargs_timeout_trees.size() >= 1) {
         // Unroll args from a Var that is known to be a Tuple.
-        auto& args_tree = args_kwargs_trees[0];
+        auto& args_tree = args_kwargs_timeout_trees[0];
         auto entry_sugared_values = emitSugaredExpr(Expr(args_tree), 1)
                                         ->asTuple(args_tree->range(), method);
         args.reserve(entry_sugared_values.size());
@@ -2895,7 +2901,7 @@ struct to_ir {
       rpc_async_node->addInput(dst_worker_name_value);
       rpc_async_node->addInput(userCallableQualNameValue);

-      for (const auto& tree : args_kwargs_trees) {
+      for (const auto& tree : args_kwargs_timeout_trees) {
         rpc_async_node->addInput(emitExpr(Expr(tree)));
       }
     }
register_distributed_ops.cpp (registration of the prim::rpc_async operator)
@@ -2,6 +2,7 @@
 #include <ATen/core/op_registration/op_registration.h>
 #include <torch/csrc/distributed/autograd/context/container.h>
 #include <torch/csrc/distributed/autograd/engine/dist_engine.h>
+#include <torch/csrc/distributed/rpc/rpc_agent.h>
 #include <torch/csrc/distributed/rpc/rref_impl.h>
 #include <torch/csrc/distributed/rpc/torchscript_functions.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
@@ -122,13 +123,20 @@ RegisterOperators reg_rpc_ops(
           IValue emptyTuple(c10::ivalue::Tuple::create({}));
           IValue emptyDict{
               c10::impl::GenericDict(AnyType::get(), AnyType::get())};
-          // Equavalent to Python statment
+          // Equivalent to Python statement
           // `args = args if args is not None else ()`.
           auto& argsTupleIValue =
               num_inputs >= 3 ? *stackIter++ : emptyTuple;
           // `kwargs = kwargs if kwargs is not None else {}`.
           auto& kwargsDictIValue =
               num_inputs >= 4 ? *stackIter++ : emptyDict;
+
+          // IValue corresponding to placeholder for RPC timeout. Used if no
+          // rpc timeout is specified by user.
+          IValue noTimeout(torch::distributed::rpc::kUnsetRpcTimeout);
+          const auto rpcMaxInputs = 5;
+          auto& timeoutIValue =
+              num_inputs >= rpcMaxInputs ? *stackIter++ : noTimeout;
           TORCH_INTERNAL_ASSERT(
               dstWorkerIValue.isString() ||
               c10::getCustomClassType<
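As the comments in this hunk note, the operator mirrors Python-style defaulting for omitted inputs. A rough Python analogue, for illustration only (UNSET_TIMEOUT is a placeholder standing in for the C++ kUnsetRpcTimeout sentinel; this is not the actual implementation):

def fill_rpc_async_defaults(args=None, kwargs=None, timeout=None):
    # Placeholder for torch::distributed::rpc::kUnsetRpcTimeout, the sentinel
    # that tells the agent to apply its default timeout.
    UNSET_TIMEOUT = -1.0
    args = args if args is not None else ()        # empty tuple when args omitted
    kwargs = kwargs if kwargs is not None else {}  # empty dict when kwargs omitted
    timeout = timeout if timeout is not None else UNSET_TIMEOUT
    return args, kwargs, timeout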
@@ -137,6 +145,7 @@ RegisterOperators reg_rpc_ops(
           TORCH_INTERNAL_ASSERT(qualifiedNameIValue.isString());
           TORCH_INTERNAL_ASSERT(argsTupleIValue.isTuple());
           TORCH_INTERNAL_ASSERT(kwargsDictIValue.isGenericDict());
+          TORCH_INTERNAL_ASSERT(timeoutIValue.isDouble());

           // Get FunctionSchema for qualifiedName.
           auto qualifiedName =
@@ -207,13 +216,15 @@ RegisterOperators reg_rpc_ops(
             dstWorkerNameStr =
                 dstWorkerIValue.toCustomClass<dist_rpc::WorkerInfo>()->name_;
           }

+          // Get RPC timeout, if specified by user.
+          const auto rpcTimeout = timeoutIValue.toDouble();
           // Send RPC request.
           auto futureIValuePtr = dist_rpc::rpcTorchscript(
               dstWorkerNameStr,
               qualifiedName,
               functionSchema,
-              userCallableStack);
+              userCallableStack,
+              rpcTimeout);

           // Push output to the stack.
           drop(stack, num_inputs);
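For comparison, a sketch of the eager-mode counterpart that this operator change brings JIT in line with (worker name and callable are placeholders; an initialized RPC agent with a peer named "worker1" is assumed):

import torch
import torch.distributed.rpc as rpc


def add_tensors(x, y):
    return x + y


# Eager mode already accepts a per-call timeout (in seconds) on rpc_async;
# this commit wires the same trailing timeout through prim::rpc_async in JIT.
fut = rpc.rpc_async("worker1", add_tensors, args=(torch.ones(2), torch.ones(2)), timeout=0.5)
result = fut.wait()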
torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py (new file, 142 lines)
@@ -0,0 +1,142 @@
from typing import Dict, Tuple

import torch
import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils as dist_utils
from torch import Tensor
from torch.testing._internal.dist_utils import (
    dist_init,
    get_timeout_error_regex,
    worker_name,
)
from torch.testing._internal.distributed.rpc.faulty_rpc_agent_test_fixture import (
    FaultyRpcAgentTestFixture,
)


@torch.jit.script
def two_args_two_kwargs(
    first_arg,
    second_arg,
    first_kwarg=torch.tensor([3, 3]),
    second_kwarg=torch.tensor([4, 4]),
):
    return first_arg + second_arg + first_kwarg + second_kwarg


@torch.jit.script
def rpc_async_call_remote_torchscript_in_torchscript(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    ret = fut.wait()
    return ret


@torch.jit.script
def rpc_async_call_with_timeout(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
    ret = fut.wait()
    return ret


@torch.jit.script
def rpc_async_call_with_timeout_future_ret(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
    return fut


@torch.jit.script
def rpc_async_call_future_ret(
    dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    return fut


class JitFaultyAgentRpcTest(FaultyRpcAgentTestFixture):
    """
    Run tests for rpc_async in JIT under the faulty agent test fixture to test
    arbitrary timeouts.
    """
    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_timeout_in_torchscript_function(self):
        # Call rpc_async + fut.wait() in torchscript function and ensure that
        # timeout is raised.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        expected_error = get_timeout_error_regex(
            dist_utils.TEST_CONFIG.rpc_backend_name
        )
        print("Test config is {}".format(dist_utils.TEST_CONFIG.rpc_backend_name))
        # Ensure that we get a timeout if we override the default timeout and
        # the RPC takes longer to execute.
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5)

        # Ensure that we timeout if we don't specify a timeout but the default
        # is less than the RPC takes to execute.
        rpc._set_rpc_timeout(0.001)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc_async_call_remote_torchscript_in_torchscript(
                dst_worker_name, args, kwargs
            )

        # Ensure that we run to completion if zero timeout is specified.
        ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
        self.assertEqual(ret, torch.tensor([8, 8]))
        # reset for clean shutdown
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_timeout_in_python(self):
        # Ensures timeouts are raised if we call rpc_async from within a
        # torchscript function, but wait on the future in python.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        expected_error = get_timeout_error_regex(
            dist_utils.TEST_CONFIG.rpc_backend_name
        )

        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure timeout if we don't specify but the default is less than the
        # RPC takes to execute.
        rpc._set_rpc_timeout(0.001)
        fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure run to completion if zero timeout is specified
        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0)
        result = fut.wait()
        self.assertEqual(result, torch.tensor([8, 8]))
        # reset for clean shutdown
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)