[profiler] Allow record_function ctx manager to profile futures (#35055)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35055

This is the first step toward improving the way RPCs are profiled, as suggested by Ilia. For now, since RPC can return two different types of futures, we have to implement two different code paths: one for the Python eager-mode future and one for the JIT future.

This diff implements the Python eager part. We define a method `_call_end_callbacks_on_future` on the `record_function` context manager that takes in a future and schedules the `RecordFunction`'s end callbacks to run as a callback when that future completes.
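
For illustration, a rough sketch of how this is used from Python eager mode (not part of this diff; it assumes an already-initialized RPC group, and the worker name "worker1" and the use of torch.add are placeholders):

    import torch
    import torch.distributed.rpc as rpc

    with torch.autograd.profiler.profile() as prof:
        with torch.autograd.profiler.record_function("my_rpc") as rf:
            fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 1))
            # Extend the "my_rpc" range past __exit__ until the future is satisfied.
            rf._call_end_callbacks_on_future(fut)
        fut.wait()
    # The "my_rpc" event now covers the end-to-end time of the async RPC.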

Once https://github.com/pytorch/pytorch/pull/35039 lands, we can implement the JIT code path by registering an operator that takes a `Future(t)` as well.

These code paths will be merged once the futures are merged.
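
The same Python API already covers jit futures by dispatching to the `profiler::_call_end_callbacks_on_jit_fut` operator added in this diff. A minimal sketch, mirroring the torch.jit._fork unit test added here (the sleep helper is just an example workload):

    import time
    import torch

    def sleeper(t):
        time.sleep(t)

    with torch.autograd.profiler.profile() as prof:
        with torch.autograd.profiler.record_function("foo") as rf:
            # torch.jit._fork returns a jit future, so _call_end_callbacks_on_future
            # takes the jit operator code path.
            fut = torch.jit._fork(sleeper, 1)
            rf._call_end_callbacks_on_future(fut)
        fut.wait()
    # The "foo" event's CPU time now includes the time the forked task took.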
ghstack-source-id: 102478180

Test Plan: Added unit tests

Differential Revision: D20452003

fbshipit-source-id: 1acdcb073bd1f63d6fb2e78277ac0be00fd6671d
Rohan Varma
2020-04-20 12:35:39 -07:00
committed by Facebook GitHub Bot
parent 1e054bfbdc
commit 752d3c281a
19 changed files with 476 additions and 130 deletions

View File

@ -366,14 +366,47 @@ class record_function(ContextDecorator):
"""
def __init__(self, name):
self.name = name
# Whether or not we should run record function's end callbacks when exiting.
self.run_callbacks_on_exit = True
def __enter__(self):
self.handle = torch.ops.profiler._record_function_enter(self.name)
return self
def __exit__(self, *args):
torch.ops.profiler._record_function_exit(self.handle)
if self.run_callbacks_on_exit:
torch.ops.profiler._record_function_exit(self.handle)
return False
def _call_end_callbacks_on_future(self, fut):
"""
_call_end_callbacks_on_future is meant to be used for profiling async
calls that return a future. Calling this function will extend recording
beyond this scope, until the future is satisfied. It is useful for profiling
the end-to-end time of asynchronous calls. This function should only be called
once to attach the callback onto the future, and will throw if called multiple
times.
Arguments:
fut (torch.distributed.rpc.Future or torch._C.Future): future for which to schedule
the callback.
"""
# Throw if we have already attached a callback onto the future.
if not self.run_callbacks_on_exit:
raise RuntimeError("_call_end_callbacks_on_future can only be called once.")
# We are scheduling to run this RecordFunction's end callbacks when the
# passed in future completes, so don't run end callbacks on exit.
self.run_callbacks_on_exit = False
# TODO: Currently, we have two different futures that can be returned,
# thus, two different code paths. We should clean this up when the
# futures are merged and rpc_async returns a consistent type (https://github.com/pytorch/pytorch/issues/34999).
if isinstance(fut, torch.distributed.rpc.Future):
torch.autograd._call_end_callbacks_on_fut(self.handle, fut)
else:
# jit Future, call jit operator
torch.ops.profiler._call_end_callbacks_on_jit_fut(self.handle, fut)
class emit_nvtx(object):
"""Context manager that makes every autograd operation emit an NVTX range.

View File

@ -5,8 +5,12 @@
#include <torch/csrc/autograd/grad_mode.h>
#include <ATen/autocast_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/record_function_ops.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/function.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/message.h>
#endif
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
using namespace torch::autograd::profiler;
@ -52,10 +56,16 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
m.def("_enable_profiler", enableProfiler);
m.def("_disable_profiler", disableProfiler);
m.def("_profiler_enabled", profilerEnabled);
m.def("_run_before_callbacks", _runBeforeCallbacks);
py::class_<RecordFunction, std::shared_ptr<RecordFunction>>(m, "_RecordFunction")
.def(py::init<>());
// TODO: remove when jit future can hold PyObject (https://github.com/pytorch/pytorch/issues/34999)
#ifdef USE_DISTRIBUTED
m.def(
"_call_end_callbacks_on_fut",
[](const at::Tensor& handle,
const std::shared_ptr<torch::distributed::rpc::FutureMessage>& fut) {
torch::autograd::profiler::_call_end_callbacks_on_fut(handle, fut);
});
#endif
Py_RETURN_TRUE;
}

View File

@ -1,3 +1,4 @@
#include <torch/csrc/autograd/record_function_ops.h>
#include <ATen/cpp_custom_type_hack.h>
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
@ -16,7 +17,7 @@ at::Tensor record_function_enter(const std::string& name) {
auto rec = std::make_unique<RecordFunction>(RecordScope::USER_SCOPE);
// Only add new scope if profiling is enabled.
if (auto* current = rec->current()) {
AT_ASSERT(
TORCH_INTERNAL_ASSERT(
current->name() == StringView("profiler::_record_function_enter"));
// RecordFunction requires parent_ to be alive for its entire lifetime.
// Since the currently active RecordFunction will only live for the lifetime
@ -28,23 +29,68 @@ at::Tensor record_function_enter(const std::string& name) {
return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions());
}
RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle) {
auto& rec = at::cpp_custom_type_hack::cast<RecordFunction>(handle);
return rec;
}
void record_function_exit(const at::Tensor& handle) {
// We don't actually need to do anything with the handle; we just need to
// persist its lifetime until now.
auto& rec = at::cpp_custom_type_hack::cast<RecordFunction>(handle);
auto& rec = getRecordFunctionFromTensor(handle);
if (auto* current = rec.current()) {
AT_ASSERT(
current->name() == StringView("profiler::_record_function_exit"));
TORCH_INTERNAL_ASSERT(current->name() == StringView("profiler::_record_function_exit"));
current->_end();
}
rec._end();
}
// Same as _call_end_callbacks_on_fut but takes an ivalue future.
// TODO: once python and JIT futures are merged, consolidate this with
// call_end_callbacks_on_fut (https://github.com/pytorch/pytorch/issues/34999).
void _call_end_callbacks_on_jit_fut(
const at::Tensor& handle,
const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
// Add a callback onto the future to run this RecordFunction's end callbacks
// when the future is completed.
fut->addCallback(
// Copy handle by value to persist after the python context manager is
// exited.
[handle]() {
TORCH_INTERNAL_ASSERT(
handle.defined(),
"Undefined RecordFunction handle. This can happen if the handle is "
"not correctly persisted and is destroyed before the future is "
"realized.");
auto& rec = getRecordFunctionFromTensor(handle);
rec._end();
});
}
// Internal only, do not use directly, use Python's record_function()
static auto registry =
RegisterOperators()
.op("profiler::_record_function_enter", &record_function_enter)
.op("profiler::_record_function_exit", &record_function_exit);
// Needed to register JIT operator in operator registry below
c10::AliasAnalysisKind aliasAnalysisFromSchema() {
return c10::AliasAnalysisKind::FROM_SCHEMA;
}
jit::RegisterOperators reg_fut_ops({
jit::Operator(
"profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> ()",
[](jit::Stack& stack) {
// Pop inputs, which should be a future and a tensor
auto fut = jit::pop(stack).toFuture();
auto tensor = jit::pop(stack).toTensor();
_call_end_callbacks_on_jit_fut(tensor, fut);
return 0;
},
aliasAnalysisFromSchema()),
});
} // namespace profiler
} // namespace autograd
} // namespace torch

View File

@ -0,0 +1,44 @@
#pragma once
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/utils/future.h>
namespace torch {
namespace autograd {
namespace profiler {
// Creates a new profiling scope using RecordFunction and invokes its starting
// callbacks.
at::Tensor record_function_enter(const std::string& name);
// Cast Tensor that was created with at::cpp_custom_type_hack back to
// RecordFunction. This is a temporary workaround until RecordFunction is
// registered as a custom C++ class
// (https://github.com/pytorch/pytorch/issues/35026).
TORCH_API RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle);
// Schedules RecordFunction's end callbacks to be run on completion of a future.
template <typename T>
void _call_end_callbacks_on_fut(
const at::Tensor& handle,
const std::shared_ptr<torch::utils::Future<T>> fut) {
// Add a callback onto the future to run this RecordFunction's end callbacks
// when the future is completed.
fut->addCallback(
// Copy handle by value to persist after the python context manager is
// exited.
[handle]() {
TORCH_INTERNAL_ASSERT(
handle.defined(),
"Undefined RecordFunction handle. This can happen if the handle is "
"not correctly persisted and is destroyed before the future is "
"realized.");
auto& rec = getRecordFunctionFromTensor(handle);
rec._end();
});
}
// Ends the profiling scope created with record_function_enter.
void record_function_exit(const at::Tensor& handle);
} // namespace profiler
} // namespace autograd
} // namespace torch

View File

@ -115,24 +115,14 @@ std::shared_ptr<FutureMessage> sendMessageWithAutograd(
RpcAgent& agent,
const WorkerInfo& dst,
torch::distributed::rpc::Message&& wrappedRpcMsg,
bool forceGradRecording,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
bool forceGradRecording) {
auto msg = getMessageWithAutograd(
dst.id_,
std::move(wrappedRpcMsg),
MessageType::FORWARD_AUTOGRAD_REQ,
forceGradRecording);
auto fut = agent.send(dst, std::move(msg));
if (rf != nullptr) {
// Add a callback to
// the future that captures the RecordFunction to persist it for the
// lifetime of the future. When the future is completed, this will run the
// end() callbacks associated with the RecordFunction, so that async RPCs
// can be profiled correctly.
fut->addCallback([rf]() { rf->_end(); });
}
return fut;
return agent.send(dst, std::move(msg));
}
} // namespace autograd

View File

@ -48,9 +48,7 @@ sendMessageWithAutograd(
rpc::RpcAgent& agent,
const rpc::WorkerInfo& dst,
rpc::Message&& wrappedRpcMsg,
bool forceGradRecording = false,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf =
nullptr);
bool forceGradRecording = false);
} // namespace autograd
} // namespace distributed

View File

@ -239,6 +239,15 @@ PyObject* rpc_init(PyObject* /* unused */) {
"_deserialize",
&PyRRef::unpickle,
py::call_guard<py::gil_scoped_release>())
.def(
"_get_future",
&PyRRef::getFuture,
py::call_guard<py::gil_scoped_release>(),
R"(
Returns the future that corresponds to the creation of this RRef
on the remote node. This is for internal use cases such as profiling
only.
)")
// not releasing GIL to avoid context switch
.def("__str__", &PyRRef::str);
@ -404,11 +413,10 @@ If the future completes with an error, an exception is thrown.
"_invoke_rpc_builtin",
[](const WorkerInfo& dst,
const std::string& opName,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
const py::args& args,
const py::kwargs& kwargs) {
DCHECK(PyGILState_Check());
return pyRpcBuiltin(dst, opName, rf, args, kwargs);
return pyRpcBuiltin(dst, opName, args, kwargs);
},
py::call_guard<py::gil_scoped_acquire>());
@ -416,22 +424,19 @@ If the future completes with an error, an exception is thrown.
"_invoke_rpc_python_udf",
[](const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
std::vector<torch::Tensor>& tensors) {
DCHECK(!PyGILState_Check());
return pyRpcPythonUdf(dst, pickledPythonUDF, tensors, rf);
return pyRpcPythonUdf(dst, pickledPythonUDF, tensors);
},
py::call_guard<py::gil_scoped_release>(),
py::arg("dst"),
py::arg("pickledPythonUDF"),
py::arg("tensors"),
py::arg("rf") = nullptr);
py::arg("tensors"));
module.def(
"_invoke_rpc_torchscript",
[](const std::string& dstWorkerName,
const std::string& qualifiedNameStr,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
const py::tuple& argsTuple,
const py::dict& kwargsDict) {
// No need to catch exception here, if function can not be found,
@ -455,8 +460,8 @@ If the future completes with an error, an exception is thrown.
c10::nullopt);
}
DCHECK(!PyGILState_Check());
c10::intrusive_ptr<c10::ivalue::Future> fut = rpcTorchscript(
dstWorkerName, qualifiedName, functionSchema, stack, rf);
c10::intrusive_ptr<c10::ivalue::Future> fut =
rpcTorchscript(dstWorkerName, qualifiedName, functionSchema, stack);
return torch::jit::PythonFutureWrapper(fut);
},
py::call_guard<py::gil_scoped_release>());
@ -465,11 +470,10 @@ If the future completes with an error, an exception is thrown.
"_invoke_remote_builtin",
[](const WorkerInfo& dst,
const std::string& opName,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
const py::args& args,
const py::kwargs& kwargs) {
DCHECK(PyGILState_Check());
return pyRemoteBuiltin(dst, opName, rf, args, kwargs);
return pyRemoteBuiltin(dst, opName, args, kwargs);
},
py::call_guard<py::gil_scoped_acquire>());
@ -477,7 +481,6 @@ If the future completes with an error, an exception is thrown.
"_invoke_remote_torchscript",
[](const std::string& dstWorkerName,
const std::string& qualifiedNameStr,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
const py::args& args,
const py::kwargs& kwargs) {
DCHECK(!PyGILState_Check());
@ -495,7 +498,7 @@ If the future completes with an error, an exception is thrown.
}
DCHECK(!PyGILState_Check());
auto rrefPtr = remoteTorchscript(
dstWorkerName, qualifiedName, functionSchema, stack, rf);
dstWorkerName, qualifiedName, functionSchema, stack);
return PyRRef(rrefPtr);
},
py::call_guard<py::gil_scoped_release>());
@ -504,16 +507,14 @@ If the future completes with an error, an exception is thrown.
"_invoke_remote_python_udf",
[](const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
std::vector<torch::Tensor>& tensors) {
DCHECK(!PyGILState_Check());
return pyRemotePythonUdf(dst, pickledPythonUDF, tensors, rf);
return pyRemotePythonUdf(dst, pickledPythonUDF, tensors);
},
py::call_guard<py::gil_scoped_release>(),
py::arg("dst"),
py::arg("pickledPythonUDF"),
py::arg("tensors"),
py::arg("rf") = nullptr);
py::arg("tensors"));
module.def(
"get_rpc_timeout",

View File

@ -112,6 +112,10 @@ PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
return rref;
}()) {}
const std::shared_ptr<FutureMessage> PyRRef::getFuture() const {
return rref_->getOwnerCreationFuture();
}
bool PyRRef::isOwner() const {
return rref_->isOwner();
}

View File

@ -25,6 +25,10 @@ class PyRRef {
py::tuple pickle() const;
static PyRRef unpickle(const py::tuple& t);
c10::IValue toIValue();
// Future that is associated with the creation of this RRef on the remote end.
// This is only used to get the future corresponding to the rref for profiling
// use cases.
const std::shared_ptr<FutureMessage> getFuture() const;
private:
c10::intrusive_ptr<RRef> rref_;

View File

@ -63,8 +63,7 @@ std::shared_ptr<FutureMessage> sendPythonRemoteCall(
const WorkerInfo& dst,
SerializedPyObj serializedPyObj,
const IValue& rrefId,
const IValue& forkId,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
const IValue& forkId) {
auto pythonRemoteCall = std::make_unique<PythonRemoteCall>(
std::move(serializedPyObj), rrefId, forkId);
@ -75,8 +74,7 @@ std::shared_ptr<FutureMessage> sendPythonRemoteCall(
*agent,
dst,
std::move(*pythonRemoteCall).toMessage(),
true /*forceGradRecording*/,
rf);
true /*forceGradRecording*/);
}
} // namespace
@ -119,7 +117,6 @@ py::object toPyObj(const Message& message) {
std::shared_ptr<FutureMessage> pyRpcBuiltin(
const WorkerInfo& dst,
const std::string& opName,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
const py::args& args,
const py::kwargs& kwargs) {
Stack stack;
@ -129,13 +126,12 @@ std::shared_ptr<FutureMessage> pyRpcBuiltin(
auto scriptCall = std::make_unique<ScriptCall>(op, std::move(stack));
auto agent = RpcAgent::getCurrentRpcAgent();
return sendMessageWithAutograd(
*agent, dst, std::move(*scriptCall).toMessage(), false, rf);
*agent, dst, std::move(*scriptCall).toMessage(), false);
}
PyRRef pyRemoteBuiltin(
const WorkerInfo& dst,
const std::string& opName,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
const py::args& args,
const py::kwargs& kwargs) {
Stack stack;
@ -154,8 +150,9 @@ PyRRef pyRemoteBuiltin(
op, std::move(stack), userRRef->rrefId(), userRRef->forkId());
auto fm = sendMessageWithAutograd(
*agent, dst, std::move(*scriptRemoteCall).toMessage(), false, rf);
*agent, dst, std::move(*scriptRemoteCall).toMessage(), false);
userRRef->registerOwnerCreationFuture(fm);
ctx.addPendingUser(userRRef->forkId(), userRRef);
fm->addCallback([forkId{userRRef->forkId()}](const FutureMessage& fm) {
callback::confirmPendingUser(fm, forkId);
@ -169,7 +166,9 @@ PyRRef pyRemoteBuiltin(
auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId());
auto fm = sendMessageWithAutograd(
*agent, dst, std::move(*scriptRemoteCall).toMessage(), false, rf);
*agent, dst, std::move(*scriptRemoteCall).toMessage(), false);
ownerRRef->registerOwnerCreationFuture(fm);
// Builtin operators do not return py::object, and hence do not require the
// GIL for destructing the potentially deleted OwnerRRef.
@ -182,8 +181,7 @@ PyRRef pyRemoteBuiltin(
std::shared_ptr<FutureMessage> pyRpcPythonUdf(
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
std::vector<torch::Tensor>& tensors) {
auto serializedPyObj =
SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
auto pythonCall = std::make_unique<PythonCall>(std::move(serializedPyObj));
@ -193,15 +191,13 @@ std::shared_ptr<FutureMessage> pyRpcPythonUdf(
*agent,
dst,
std::move(*pythonCall).toMessage(),
true /*forceGradRecording*/,
rf);
true /*forceGradRecording*/);
}
PyRRef pyRemotePythonUdf(
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
std::vector<torch::Tensor>& tensors) {
auto& ctx = RRefContext::getInstance();
auto serializedPyObj =
SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
@ -211,8 +207,9 @@ PyRRef pyRemotePythonUdf(
dst,
std::move(serializedPyObj),
userRRef->rrefId().toIValue(),
userRRef->forkId().toIValue(),
rf);
userRRef->forkId().toIValue());
userRRef->registerOwnerCreationFuture(fm);
ctx.addPendingUser(userRRef->forkId(), userRRef);
fm->addCallback([forkId{userRRef->forkId()}](const FutureMessage& fm) {
@ -227,8 +224,9 @@ PyRRef pyRemotePythonUdf(
dst,
std::move(serializedPyObj),
ownerRRef->rrefId().toIValue(),
ownerRRef->rrefId().toIValue(),
rf);
ownerRRef->rrefId().toIValue());
ownerRRef->registerOwnerCreationFuture(fm);
fm->addCallback([](const FutureMessage& fm) {
auto deletedRRef = callback::finishCreatingOwnerRRef(fm);

View File

@ -1,6 +1,5 @@
#pragma once
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/utils/pybind.h>
@ -14,28 +13,24 @@ py::object toPyObj(const Message& message);
std::shared_ptr<FutureMessage> pyRpcBuiltin(
const WorkerInfo& dst,
const std::string& opName,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
const py::args& args,
const py::kwargs& kwargs);
std::shared_ptr<FutureMessage> pyRpcPythonUdf(
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf);
std::vector<torch::Tensor>& tensors);
PyRRef pyRemoteBuiltin(
const WorkerInfo& dst,
const std::string& opName,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
const py::args& args,
const py::kwargs& kwargs);
PyRRef pyRemotePythonUdf(
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf);
std::vector<torch::Tensor>& tensors);
} // namespace rpc
} // namespace distributed

View File

@ -219,6 +219,19 @@ class TORCH_API RRef : public RRefInterface {
return type_;
}
// Save the future corresponding to the creation of this RRef on a remote
// node. Note that this is only set when processing requests invoked with
// rpc.remote. This is only used to get the future corresponding to the rref
// for profiling use cases.
inline void registerOwnerCreationFuture(std::shared_ptr<FutureMessage> fut) {
ownerCreationFuture_ = std::move(fut);
}
// Get the future corresponding to the creation of this rref.
inline std::shared_ptr<FutureMessage> getOwnerCreationFuture() const {
return ownerCreationFuture_;
}
// Send delete UserRRef request to Owner,
// if the request hasn't been sent yet.
// There are 2 cases to call it,
@ -240,6 +253,8 @@ class TORCH_API RRef : public RRefInterface {
// type field to denote the type of the element that the RRef is holding
// it could be any TypePtr that JIT support, including PyObjectType
const TypePtr type_;
// Future corresponding to request to create RRef on remote node.
std::shared_ptr<FutureMessage> ownerCreationFuture_;
};
// ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user

View File

@ -15,8 +15,7 @@ c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
const std::string& dstWorkerName,
const c10::QualifiedName& qualifiedName,
const c10::FunctionSchema& functionSchema,
std::vector<c10::IValue>& stack,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
std::vector<c10::IValue>& stack) {
auto scriptCall =
std::make_unique<ScriptCall>(qualifiedName, std::move(stack));
auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
@ -24,8 +23,7 @@ c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
*rpcAgentPtr,
rpcAgentPtr->getWorkerInfo(dstWorkerName),
std::move(*scriptCall).toMessage(),
true /*forceGradRecording*/,
rf);
true /*forceGradRecording*/);
// Get function return type to construct c10::ivalue::Future.
auto returns = functionSchema.returns();
@ -55,8 +53,7 @@ c10::intrusive_ptr<RRef> remoteTorchscript(
const std::string& dstWorkerName,
const c10::QualifiedName& qualifiedName,
const c10::FunctionSchema& functionSchema,
std::vector<c10::IValue>& stack,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
std::vector<c10::IValue>& stack) {
auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
auto dstWorkerInfo = rpcAgentPtr->getWorkerInfo(dstWorkerName);
auto& ctx = RRefContext::getInstance();
@ -84,8 +81,9 @@ c10::intrusive_ptr<RRef> remoteTorchscript(
*rpcAgentPtr,
dstWorkerInfo,
std::move(*scriptRemoteCall).toMessage(),
true /*forceGradRecording*/,
rf);
true /*forceGradRecording*/);
userRRefPtr->registerOwnerCreationFuture(fm);
ctx.addPendingUser(userRRefPtr->forkId(), userRRefPtr);
fm->addCallback([forkId{userRRefPtr->forkId()}](const FutureMessage& fm) {
@ -108,8 +106,9 @@ c10::intrusive_ptr<RRef> remoteTorchscript(
*rpcAgentPtr,
dstWorkerInfo,
std::move(*scriptRemoteCall).toMessage(),
true /*forceGradRecording*/,
rf);
true /*forceGradRecording*/);
ownerRRefPtr->registerOwnerCreationFuture(fm);
fm->addCallback(
[](const FutureMessage& fm) { callback::finishCreatingOwnerRRef(fm); });

View File

@ -24,17 +24,13 @@ c10::intrusive_ptr<c10::ivalue::Future> TORCH_API rpcTorchscript(
const std::string& dstWorkerName,
const c10::QualifiedName& qualifiedName,
const c10::FunctionSchema& functionSchema,
std::vector<c10::IValue>& stack,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf =
nullptr);
std::vector<c10::IValue>& stack);
c10::intrusive_ptr<RRef> TORCH_API remoteTorchscript(
const std::string& dstWorkerName,
const c10::QualifiedName& qualifiedName,
const c10::FunctionSchema& functionSchema,
std::vector<c10::IValue>& stack,
const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf =
nullptr);
std::vector<c10::IValue>& stack);
} // namespace rpc
} // namespace distributed

View File

@ -33,7 +33,7 @@ from .internal import (
PythonUDF,
RPCExecMode,
_internal_rpc_pickler,
_start_record_function,
_build_rpc_profiling_key,
)
@ -423,11 +423,10 @@ def remote(to, func, args=None, kwargs=None):
"""
qualified_name = torch.jit._find_builtin(func)
dst_worker_info = _to_worker_info(to)
should_profile = torch.autograd._profiler_enabled()
# If profiling is enabled, kick off the timer and retrieve back a
# RecordFunction instance.
rf = None
if torch.autograd._profiler_enabled():
ctx_manager = contextlib.suppress()
if should_profile:
# Create appropriate string representation based on type of func
# (builtin, script, python)
if qualified_name is None:
@ -438,28 +437,39 @@ def remote(to, func, args=None, kwargs=None):
)
else:
func_name = qualified_name
rf = _start_record_function(
# Build RPC profiling key.
rpc_profiling_key = _build_rpc_profiling_key(
RPCExecMode.REMOTE,
func_name,
get_worker_info().name,
dst_worker_info.name,
)
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
args = args if args else ()
kwargs = kwargs if kwargs else {}
if qualified_name is not None:
return _invoke_remote_builtin(dst_worker_info, qualified_name, rf, *args, **kwargs)
elif isinstance(func, torch.jit.ScriptFunction):
return _invoke_remote_torchscript(
dst_worker_info.name, torch._jit_internal._qualified_name(func), rf, *args, **kwargs
)
else:
(pickled_python_udf, tensors) = _default_pickler.serialize(
PythonUDF(func, args, kwargs)
)
return _invoke_remote_python_udf(dst_worker_info, pickled_python_udf, tensors, rf)
with ctx_manager as rf:
args = args if args else ()
kwargs = kwargs if kwargs else {}
if qualified_name is not None:
rref = _invoke_remote_builtin(dst_worker_info, qualified_name, *args, **kwargs)
elif isinstance(func, torch.jit.ScriptFunction):
rref = _invoke_remote_torchscript(
dst_worker_info.name,
torch._jit_internal._qualified_name(func),
*args,
**kwargs
)
else:
(pickled_python_udf, tensors) = _default_pickler.serialize(
PythonUDF(func, args, kwargs)
)
rref = _invoke_remote_python_udf(dst_worker_info, pickled_python_udf, tensors)
# attach profiling information
if should_profile:
assert torch.autograd._profiler_enabled()
assert rf is not None
rf._call_end_callbacks_on_future(rref._get_future())
return rref
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None):
if not callable(func):
@ -467,10 +477,13 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None):
qualified_name = torch.jit._find_builtin(func)
dst_worker_info = _to_worker_info(to)
# If profiling is enabled, kick off the timer and retrieve back a
# RecordFunction instance.
rf = None
if torch.autograd._profiler_enabled():
# TODO: profiling logic does not really belong in invoke_rpc, it should be
# added as part of a context manager or helper (https://github.com/pytorch/pytorch/issues/36360)
should_profile = torch.autograd._profiler_enabled()
ctx_manager = contextlib.suppress()
if should_profile:
# Create appropriate string representation based on type of func
# (builtin, script, python)
if qualified_name is None:
@ -481,27 +494,35 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None):
)
else:
func_name = qualified_name
rf = _start_record_function(
# Build RPC profiling key.
rpc_profiling_key = _build_rpc_profiling_key(
rpc_type,
func_name,
get_worker_info().name,
dst_worker_info.name,
)
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
args = args if args else ()
kwargs = kwargs if kwargs else {}
with ctx_manager as rf:
args = args if args else ()
kwargs = kwargs if kwargs else {}
if qualified_name is not None:
fut = _invoke_rpc_builtin(dst_worker_info, qualified_name, rf, *args, **kwargs)
elif isinstance(func, torch.jit.ScriptFunction):
fut = _invoke_rpc_torchscript(
dst_worker_info.name, torch.jit._qualified_name(func), rf, args, kwargs
)
else:
(pickled_python_udf, tensors) = _default_pickler.serialize(
PythonUDF(func, args, kwargs)
)
fut = _invoke_rpc_python_udf(dst_worker_info, pickled_python_udf, tensors, rf)
if qualified_name is not None:
fut = _invoke_rpc_builtin(dst_worker_info, qualified_name, *args, **kwargs)
elif isinstance(func, torch.jit.ScriptFunction):
fut = _invoke_rpc_torchscript(
dst_worker_info.name, torch.jit._qualified_name(func), args, kwargs
)
else:
(pickled_python_udf, tensors) = _default_pickler.serialize(
PythonUDF(func, args, kwargs)
)
fut = _invoke_rpc_python_udf(dst_worker_info, pickled_python_udf, tensors)
if should_profile:
assert torch.autograd._profiler_enabled()
assert rf is not None
# Schedule profiling callbacks to run when the future completes.
rf._call_end_callbacks_on_future(fut)
return fut

View File

@ -14,7 +14,6 @@ import torch.distributed as dist
# objects
_thread_local_tensor_tables = threading.local()
class RPCExecMode(Enum):
SYNC = "sync"
ASYNC = "async"
@ -162,6 +161,28 @@ def _handle_exception(result):
if isinstance(result, RemoteException):
raise result.exception_type(result.msg)
def _build_rpc_profiling_key(exec_type, func_name, current_worker_name, dst_worker_name):
"""
Builds the key used to profile RPC calls with the autograd profiler.
This key is the name of the corresponding Event recorded in the profiler.
Arguments:
exec_type (RPCExecMode): Type of RPC/RRef call
func_name (str): Name of function being profiled.
current_worker_name (str): Name of current worker.
dst_worker_name (str): Name of the destination worker.
Returns:
String representing profiling key
"""
profile_key = "rpc_{rpc_type}#{func_name}({current_worker} -> {dst_worker})".format(
rpc_type=exec_type.value,
func_name=func_name,
current_worker=current_worker_name,
dst_worker=dst_worker_name,
)
return profile_key
def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
"""

View File

@ -188,3 +188,16 @@ def initialize_pg(init_method, rank, world_size):
def worker_name(rank):
return "worker{}".format(rank)
def get_function_event(function_events, partial_event_name):
"""
Returns the first event that matches partial_event_name in the provided
function_events. These function_events should come from the output of
torch.autograd.profiler.profile() (i.e. prof.function_events).
Args:
function_events: function_events returned by the profiler.
partial_event_name (str): partial key that the event was profiled with.
"""
event = [event for event in function_events if partial_event_name in event.name][0]
return event

View File

@ -2,15 +2,24 @@ import unittest
from typing import Dict, Tuple
import torch
import time
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed.rpc.internal import _build_rpc_profiling_key, RPCExecMode
from torch import Tensor
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.dist_utils import dist_init, initialize_pg, worker_name
from torch.testing._internal.dist_utils import (
dist_init,
initialize_pg,
worker_name,
get_function_event
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
def sleep(t):
time.sleep(t)
def rpc_return_rref(dst):
return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
@ -25,7 +34,6 @@ def return_value(value):
# type: (int) -> int
return value
class RRefAPITest:
@dist_init
def test_rref_is_owner(self):
@ -96,6 +104,28 @@ def script_fork_wait_throw(invalue):
value = torch.jit._wait(fut)
return value
@torch.jit.script
def call_rpc_with_profiling(handle: Tensor, dst_worker_name: str) -> Tensor:
# Call rpc_async from within ScriptFunction and ensure that we can attach
# profiling callbacks. Note that handle here is a Tensor representation of
# RecordFunction.
fut = rpc.rpc_async(dst_worker_name, one_arg, (torch.tensor(1),))
torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut)
ret = fut.wait()
return ret
@torch.jit.script
def call_fork_with_profiling(handle: Tensor) -> Tensor:
# Call fork from within ScriptFunction and ensure that we can attach profiling
# callbacks to the resulting future. Note that handle here is a Tensor
# representation of RecordFunction.
fut = torch.jit._fork(one_arg, torch.tensor(1))
torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut)
ret = fut.wait()
return ret
class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
def __init__(self, dst_worker):
super().__init__()
@ -884,3 +914,53 @@ class JitRpcTest(RRefAPITest, LocalRRefTest, JitRpcAsyncOpTest, RpcAgentTestFixt
args=(torch.ones(2),))
with self.assertRaisesRegex(Exception, ".*Expected error.*"):
future.wait()
@dist_init
def test_call_rpc_with_profiling(self):
# Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
# future from within a script function that calls rpc_async
if self.rank == 0:
with torch.autograd.profiler.profile() as prof:
prof_key = _build_rpc_profiling_key(
RPCExecMode.ASYNC,
torch.jit._qualified_name(one_arg),
"worker0",
"worker1",
)
with torch.autograd.profiler.record_function(prof_key) as rf:
ret = call_rpc_with_profiling(rf.handle, "worker1")
# TODO: Can't get a reliable time for this profiling event since
# it's hard to estimate the execution time on the remote end for non-UDFs.
# This can be resolved by https://github.com/pytorch/pytorch/issues/36272.
# After that, this test should be modified to validate the function time.
events = prof.function_events
function_event = get_function_event(events, prof_key)
self.assertTrue(torch.jit._qualified_name(one_arg) in function_event.name)
def test_record_function_jit_end_callbacks_with_fork(self):
# Ensures that we can call rf._call_end_callbacks_on_future on a jit
# future in python eager mode with torch.jit.fork
sleep_interval = 1
with torch.autograd.profiler.profile() as prof:
with torch.autograd.profiler.record_function("foo") as rf:
fut = torch.jit._fork(sleep, sleep_interval)
rf._call_end_callbacks_on_future(fut)
fut.wait()
function_events = prof.function_events
sleep_event = get_function_event(function_events, "foo")
self.assertEqual(sleep_event.name, "foo")
# Validate that callbacks were fired at the right time by checking the
# profiling event cpu time
self.assertGreaterEqual(sleep_event.cpu_time * 1e-6, sleep_interval)
def test_call_fork_in_jit_with_profiling(self):
# Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
# future from within a script function with torch.jit.fork
with torch.autograd.profiler.profile() as prof:
with torch.autograd.profiler.record_function("foo") as rf:
ret = call_fork_with_profiling(rf.handle)
events = prof.function_events
function_event = get_function_event(events, "foo")
self.assertEqual(function_event.name, "foo")

View File

@ -12,11 +12,17 @@ import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils as dist_utils
from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info
from torch.distributed.rpc.api import _delete_all_user_rrefs, _use_rpc_pickler
from torch.distributed.rpc.internal import PythonUDF, RPCExecMode, _internal_rpc_pickler
from torch.distributed.rpc.internal import (
PythonUDF,
RPCExecMode,
_internal_rpc_pickler,
_build_rpc_profiling_key,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import IS_MACOS, load_tests
from torch.testing._internal.dist_utils import (
dist_init,
get_function_event,
get_shutdown_error_regex,
initialize_pg,
wait_until_node_failure,
@ -647,6 +653,18 @@ class RpcTest(RpcAgentTestFixture):
)
self.assertEqual(ret, my_function(n, n + 1, n + 2))
def test_build_rpc_profiling_key(self):
# Tests that the name that shows up as an Event in profiling RPCs has all
# the necessary information.
for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
rpc_profiling_key = _build_rpc_profiling_key(
exec_mode, "foo", "worker0", "worker1"
)
self.assertIn(exec_mode.value, rpc_profiling_key)
self.assertIn("foo", rpc_profiling_key)
self.assertIn("worker0", rpc_profiling_key)
self.assertIn("worker1", rpc_profiling_key)
def _profiler_test_with_rpc(self, rpc_exec_mode, func, args, use_record_function=False):
dst = (self.rank + 1) % self.world_size
# only run profiler on rank 1.
@ -676,11 +694,9 @@ class RpcTest(RpcAgentTestFixture):
record_function.__exit__()
events = prof.function_events
rpc_event = [
event for event in events if rpc_exec_mode.value in event.name
][0]
rpc_event = get_function_event(events, rpc_exec_mode.value)
if use_record_function:
scope_event = [event for event in events if "foo" in event.name][0]
scope_event = get_function_event(events, "foo")
# Since RPC call is within the scope, its CPU interval should be
# contained within foo's interval.
self.assertTrue(scope_event.cpu_interval.start < rpc_event.cpu_interval.start)
@ -788,7 +804,45 @@ class RpcTest(RpcAgentTestFixture):
use_record_function=True,
)
@dist_init
def test_async_record_function_double_end_callbacks(self):
num_sleep_seconds = 1
if self.rank == 1:
# Validate that calling the function twice results in an error.
with torch.autograd.profiler.profile() as pf:
with torch.autograd.profiler.record_function("foo") as rf:
fut = rpc.rpc_async(
worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
)
rf._call_end_callbacks_on_future(fut)
with self.assertRaisesRegex(
RuntimeError, "can only be called once."
):
rf._call_end_callbacks_on_future(fut)
fut.wait()
@dist_init
def test_async_record_function_cbs_jit_call(self):
if self.rank == 1:
with torch.autograd.profiler.profile() as pf:
key = _build_rpc_profiling_key(
RPCExecMode.ASYNC,
torch.jit._qualified_name(my_script_func),
"worker1",
"worker0",
)
with torch.autograd.profiler.record_function(key) as rf:
fut = rpc.rpc_async(
worker_name(0), my_script_func, args=(torch.tensor(1),)
)
# Intentionally calling record_function internals
torch.ops.profiler._call_end_callbacks_on_jit_fut(rf.handle, fut)
fut.wait()
events = pf.function_events
rpc_event = get_function_event(
events, torch.jit._qualified_name(my_script_func)
)
self.assertTrue(torch.jit._qualified_name(my_script_func) in rpc_event.name)
@dist_init
def test_py_class_constructor(self):
@ -1384,6 +1438,30 @@ class RpcTest(RpcAgentTestFixture):
),
)
@dist_init
def test_rref_get_future(self):
# Tests that we can obtain the future corresponding to the creation of
# the RRef on remote end
if self.rank == 0:
# Builtin
rref = rpc.remote(worker_name(1), torch.add, args=(1, 1))
rref.to_here()
fut = rref._get_future()
self.assertIsInstance(fut, torch.distributed.rpc.Future)
# UDF
rref = rpc.remote(worker_name(1), foo_add, args=())
rref.to_here()
fut = rref._get_future()
self.assertIsInstance(fut, torch.distributed.rpc.Future)
# Script
rref = rpc.remote(worker_name(1), my_script_func, args=(torch.tensor(1), ))
rref.to_here()
fut = rref._get_future()
self.assertIsInstance(fut, torch.distributed.rpc.Future)
@dist_init
def test_rref_context_debug_info(self):
# This test checks local states that are modified by remote workers.