Files
pytorch/torch/csrc/jit/runtime/register_distributed_ops.cpp
cyy 41bd0c900a [1/N] Deprecate c10::string_view and at::string (#151972)
The calls of `c10::string_view` in the code base are replaced by `std::string_view`. The calls of `at::string` are replaced by `std::string`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151972
Approved by: https://github.com/malfet
2025-04-29 07:23:52 +00:00

286 lines
10 KiB
C++

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/csrc/distributed/autograd/autograd.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/library.h>
#include <fmt/format.h>
#include <stdexcept>
namespace dist_autograd = torch::distributed::autograd;
namespace dist_rpc = torch::distributed::rpc;
namespace torch::jit {
namespace {
distributed::rpc::RegisterWorkerInfoOnce workerInfo{};
// Prepares the arguments of a TorchScript `rpc_sync` / `rpc_async` /
// `rpc_remote` call and dispatches it to the matching C++ implementation.
//
// The last `num_inputs` entries of `stack` are read, in order:
//   [0] destination worker: a `str` name or a `WorkerInfo` custom class;
//   [1] qualified name (`str`) of the TorchScript function to invoke;
//   [2] positional args tuple (only when num_inputs >= 3, otherwise `()`);
//   [3] kwargs dict (only when num_inputs >= 4, otherwise `{}`);
//   [4] RPC timeout as a double (only when num_inputs >= 5, otherwise
//       kUnsetRpcTimeout).
// On success those entries are dropped and replaced by a single value: the
// future (rpc_async), the call's result (rpc_sync) or the RRef (rpc_remote).
//
// Throws std::runtime_error when the arguments do not fit the callee's schema
// (too many positional args, a required arg with no value, or unexpected
// kwargs), when an rpc_sync call completes with an error, or when `rpc_op` is
// not one of the three supported operations.
void prepare_and_call_rpc_op(
    Stack& stack,
    int num_inputs,
    const std::string& rpc_op) {
  // Get inputs from the stack.
  auto stackIter = stack.end() - num_inputs;
  auto& dstWorkerIValue = *stackIter++;
  auto& qualifiedNameIValue = *stackIter++;
  IValue emptyTuple(c10::ivalue::Tuple::create({}));
  IValue emptyDict{c10::impl::GenericDict(AnyType::get(), AnyType::get())};
  // Equivalent to Python statement
  // `args = args if args is not None else ()`.
  auto& argsTupleIValue = num_inputs >= 3 ? *stackIter++ : emptyTuple;
  // `kwargs = kwargs if kwargs is not None else {}`.
  auto& kwargsDictIValue = num_inputs >= 4 ? *stackIter++ : emptyDict;
  // Placeholder timeout used when the caller did not pass one.
  IValue noTimeout(torch::distributed::rpc::kUnsetRpcTimeout);
  const auto rpcMaxInputs = 5;
  auto& timeoutIValue = num_inputs >= rpcMaxInputs ? *stackIter++ : noTimeout;

  TORCH_INTERNAL_ASSERT(
      dstWorkerIValue.isString() ||
      c10::getCustomClassType<c10::intrusive_ptr<dist_rpc::WorkerInfo>>() ==
          dstWorkerIValue.type());
  TORCH_INTERNAL_ASSERT(qualifiedNameIValue.isString());
  TORCH_INTERNAL_ASSERT(argsTupleIValue.isTuple());
  TORCH_INTERNAL_ASSERT(kwargsDictIValue.isGenericDict());
  TORCH_INTERNAL_ASSERT(timeoutIValue.isDouble());

  // Get FunctionSchema for qualifiedName. get_python_cu() is called with the
  // GIL held.
  auto qualifiedName = c10::QualifiedName(qualifiedNameIValue.toStringRef());
  std::shared_ptr<CompilationUnit> cuPtr;
  {
    py::gil_scoped_acquire acquire;
    cuPtr = get_python_cu();
  }
  auto& functionSchema = cuPtr->get_function(qualifiedName).getSchema();
  const auto& schemaArgs = functionSchema.arguments();

  // Build the stack for the user callable. This mirrors
  // createStackForSchema(FunctionSchema, py::args, py::kwargs), but takes the
  // arguments as IValue<Tuple> and IValue<Dict> instead.

  // Reject surplus positional arguments up front, as createStackForSchema
  // does. Otherwise they would be forwarded over RPC unchecked.
  const size_t numPositionalArgs =
      argsTupleIValue.toTupleRef().elements().size();
  if (numPositionalArgs > schemaArgs.size()) {
    throw std::runtime_error(c10::str(
        functionSchema.name(),
        "() expected at most ",
        schemaArgs.size(),
        " argument(s) but received ",
        numPositionalArgs,
        " argument(s). Declaration: ",
        functionSchema));
  }

  Stack userCallableStack;
  userCallableStack.reserve(schemaArgs.size());
  // Push the positional args from the tuple, in order.
  for (auto& elem : argsTupleIValue.toTupleRef().elements()) {
    push(userCallableStack, std::move(elem));
  }
  // Fill each remaining schema argument from kwargs, falling back to its
  // default value; an argument with neither is an error.
  size_t consumed_kwargs = 0;
  auto kwargsDict = kwargsDictIValue.toGenericDict();
  for (size_t i = userCallableStack.size(); i < schemaArgs.size(); ++i) {
    const auto& arg = schemaArgs[i];
    const auto& argName = arg.name();
    if (kwargsDict.contains(argName)) {
      push(userCallableStack, kwargsDict.at(argName));
      consumed_kwargs += 1;
    } else if (arg.default_value()) {
      push(userCallableStack, *arg.default_value());
    } else {
      throw std::runtime_error(c10::str(
          functionSchema.name(),
          "() is missing value for argument '",
          argName,
          "'. Declaration: ",
          functionSchema));
    }
  }
  // Some kwargs were not consumed above: let the schema describe which one is
  // unexpected.
  if (consumed_kwargs != kwargsDict.size()) {
    std::vector<std::string> names;
    names.reserve(kwargsDict.size());
    for (const auto& entry : kwargsDict) {
      names.emplace_back(entry.key().toStringRef());
    }
    throw std::runtime_error(functionSchema.findErrorInKwargs(names));
  }

  // Get destination WorkerName.
  std::string dstWorkerNameStr;
  if (dstWorkerIValue.isString()) {
    // ivalue::ConstantString::str_ is a const member, which can't be
    // moved, copy it here.
    dstWorkerNameStr = dstWorkerIValue.toStringRef();
  } else {
    dstWorkerNameStr =
        dstWorkerIValue.toCustomClass<dist_rpc::WorkerInfo>()->name_;
  }
  // Get RPC timeout, if specified by user.
  const auto rpcTimeout = timeoutIValue.toDouble();

  if (rpc_op == "rpc_async") {
    // Send RPC request and return the future without waiting on it.
    auto futureIValuePtr = dist_rpc::rpcTorchscript(
        dstWorkerNameStr,
        qualifiedName,
        functionSchema,
        userCallableStack,
        rpcTimeout);
    drop(stack, num_inputs);
    stack.emplace_back(std::move(futureIValuePtr));
  } else if (rpc_op == "rpc_sync") {
    // Send RPC request and block until it completes.
    auto futureIValuePtr = dist_rpc::rpcTorchscript(
        dstWorkerNameStr,
        qualifiedName,
        functionSchema,
        userCallableStack,
        rpcTimeout);
    futureIValuePtr->wait();
    if (futureIValuePtr->hasError()) {
      throw std::runtime_error(futureIValuePtr->tryRetrieveErrorMessage());
    }
    auto res = futureIValuePtr->value();
    drop(stack, num_inputs);
    stack.emplace_back(std::move(res));
  } else if (rpc_op == "rpc_remote") {
    auto rrefPtr = dist_rpc::remoteTorchscript(
        dstWorkerNameStr,
        qualifiedName,
        functionSchema,
        userCallableStack,
        rpcTimeout);
    drop(stack, num_inputs);
    stack.emplace_back(
        c10::static_intrusive_pointer_cast<c10::RRefInterface>(rrefPtr));
  } else {
    throw std::runtime_error(
        c10::str(rpc_op, "() is not supported in TorchScript!"));
  }
}
// TorchScript operators for RRef methods, distributed autograd's backward,
// and the rpc_sync / rpc_remote / rpc_async builtins. Each stack-based
// operation pops its inputs in reverse schema order.
RegisterOperators reg_rpc_ops({
    // RRef.to_here(timeout): an owner returns its local value; a user RRef
    // forwards `timeout` to UserRRef::toHere().
    Operator(
        fmt::format(
            "aten::to_here(RRef(t) self, float timeout = {}) -> t(*)",
            torch::distributed::rpc::kDefaultRpcTimeoutSeconds),
        [](Stack& stack) {
          const double timeout = pop(stack).toDouble();
          const auto rref = pop(stack).toRRef();
          if (rref->isOwner()) {
            push(
                stack,
                c10::dynamic_intrusive_pointer_cast<dist_rpc::OwnerRRef>(rref)
                    ->getValue());
            return;
          }
          push(
              stack,
              c10::dynamic_intrusive_pointer_cast<dist_rpc::UserRRef>(rref)
                  ->toHere(timeout));
        },
        aliasAnalysisFromSchema()),
    // RRef.local_value(): only valid on the owner.
    Operator(
        "aten::local_value(RRef(t) self) -> t(*)",
        [](Stack& stack) {
          const auto rref = pop(stack).toRRef();
          TORCH_CHECK(
              rref->isOwner(),
              "Can't call RRef.local_value() on a non-owner RRef.");
          push(
              stack,
              c10::static_intrusive_pointer_cast<dist_rpc::OwnerRRef>(rref)
                  ->getValue());
        },
        aliasAnalysisFromSchema()),
    Operator(
        "aten::is_owner(RRef(t) self) -> bool",
        [](Stack& stack) { push(stack, pop(stack).toRRef()->isOwner()); },
        aliasAnalysisFromSchema()),
    // RRef.owner(): wraps the owner's name and id in a WorkerInfo object.
    Operator(
        "aten::owner(RRef(t) self) -> __torch__.torch.classes.dist_rpc.WorkerInfo",
        [](Stack& stack) {
          const auto rref = pop(stack).toRRef();
          auto ownerInfo =
              torch::make_custom_class<distributed::rpc::WorkerInfo>(
                  rref->ownerName(), rref->owner());
          push(stack, std::move(ownerInfo));
        },
        aliasAnalysisFromSchema()),
    Operator(
        "aten::owner_name(RRef(t) self) -> str",
        [](Stack& stack) { push(stack, pop(stack).toRRef()->ownerName()); },
        aliasAnalysisFromSchema()),
    Operator(
        "aten::confirmed_by_owner(RRef(t) self) -> bool",
        [](Stack& stack) {
          push(stack, pop(stack).toRRef()->confirmedByOwner());
        },
        aliasAnalysisFromSchema()),
    // Distributed autograd backward pass for the given context id.
    Operator(
        "aten::dist_backward(int context_id, Tensor[] roots, bool retain_graph=False) -> ()",
        [](Stack& stack) {
          const bool retain_graph = pop(stack).toBool();
          const torch::autograd::variable_list roots =
              pop(stack).toTensorList().vec();
          const int64_t context_id = pop(stack).toInt();
          dist_autograd::backward(context_id, roots, retain_graph);
        },
        aliasAnalysisConservative()),
    // The rpc builtins take a variable number of inputs, so the count is read
    // from the graph node and captured by the returned operation.
    Operator(
        prim::rpc_sync,
        [](const Node* node) -> Operation {
          const auto num_inputs = static_cast<int>(node->inputs().size());
          return [num_inputs](Stack& stack) {
            prepare_and_call_rpc_op(stack, num_inputs, "rpc_sync");
          };
        },
        aliasAnalysisSpecialCase()),
    Operator(
        prim::rpc_remote,
        [](const Node* node) -> Operation {
          const auto num_inputs = static_cast<int>(node->inputs().size());
          return [num_inputs](Stack& stack) {
            prepare_and_call_rpc_op(stack, num_inputs, "rpc_remote");
          };
        },
        aliasAnalysisSpecialCase()),
    Operator(
        prim::rpc_async,
        [](const Node* node) -> Operation {
          const auto num_inputs = static_cast<int>(node->inputs().size());
          return [num_inputs](Stack& stack) {
            prepare_and_call_rpc_op(stack, num_inputs, "rpc_async");
          };
        },
        aliasAnalysisSpecialCase())});
// Registers the CatchAll kernel for `aten::get_gradients`. The schema itself
// is presumably defined by a TORCH_LIBRARY elsewhere, since this only calls
// m.impl(). The kernel retrieves the distributed autograd context for
// `context_id` from DistAutogradContainer and returns its getGradients().
TORCH_LIBRARY_IMPL(aten, CatchAll, m) {
  m.impl("get_gradients", [](int64_t context_id) {
    auto&& container = dist_autograd::DistAutogradContainer::getInstance();
    return container.retrieveContext(context_id)->getGradients();
  });
}
} // namespace
} // namespace torch::jit