Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[profiler] Allow record_function ctx manager to profile futures (#35055)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35055

This is the first step to improving the way RPCs are profiled, as suggested by Ilia. For now, since RPC can return two different types of futures, we have to implement two different code paths, one for the python eager mode future and one for the jit future. This diff implements the python eager part. We have defined a method `_call_end_callbacks_on_future` that takes in a future and schedules a `RecordFunction` to be completed as a callback on the future.

Once https://github.com/pytorch/pytorch/pull/35039 lands, we can implement the JIT codepath by registering an operator that takes a `Future(t)` as well. These code paths will be merged once the futures are merged.

ghstack-source-id: 102478180

Test Plan: Added unit tests

Differential Revision: D20452003

fbshipit-source-id: 1acdcb073bd1f63d6fb2e78277ac0be00fd6671d
Committed by: Facebook GitHub Bot
Parent: 1e054bfbdc
Commit: 752d3c281a
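For orientation before the diff: a minimal sketch (not part of this commit) of how the new eager-mode API is meant to be used. It assumes the RPC framework has already been initialized and that a peer named "worker1" exists; torch.add stands in for any profiled callable.

import torch
import torch.distributed.rpc as rpc

# Sketch only: assumes rpc.init_rpc(...) has already been called on this
# worker and that "worker1" is a valid peer.
with torch.autograd.profiler.profile() as prof:
    with torch.autograd.profiler.record_function("my_rpc") as rf:
        fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 1))
        # Instead of ending the "my_rpc" range when the context manager exits,
        # extend it until the returned future completes.
        rf._call_end_callbacks_on_future(fut)
    result = fut.wait()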
@@ -366,14 +366,47 @@ class record_function(ContextDecorator):
    """
    def __init__(self, name):
        self.name = name
        # Whether or not we should run record function's end callbacks when exiting.
        self.run_callbacks_on_exit = True

    def __enter__(self):
        self.handle = torch.ops.profiler._record_function_enter(self.name)
        return self

    def __exit__(self, *args):
        torch.ops.profiler._record_function_exit(self.handle)
        if self.run_callbacks_on_exit:
            torch.ops.profiler._record_function_exit(self.handle)
        return False

    def _call_end_callbacks_on_future(self, fut):
        """
        _call_end_callbacks_on_future is meant to be used for profiling async
        calls that return a future. Calling this function will extend recording
        beyond this scope, until the future is satisfied. It is useful for profiling
        the end to end time of asynchronous calls. This function should only be called
        once to attach the callback onto the future, and will throw if called multiple
        times.

        Arguments:
            fut: (torch.distributed.rpc.Future or torch._C.Future): future for which to schedule
                 callback for.
        """
        # Throw if we have already attached a callback onto the future.
        if not self.run_callbacks_on_exit:
            raise RuntimeError("_call_end_callbacks_on_future can only be called once.")

        # We are scheduling to run this RecordFunction's end callbacks when the
        # passed in future completes, so don't run end callbacks on exit.
        self.run_callbacks_on_exit = False
        # TODO: Currently, we have two different futures that can be returned,
        # thus, two different code paths. We should clean this up when the
        # futures are merged and rpc_async returns a consistent type (https://github.com/pytorch/pytorch/issues/34999).
        if isinstance(fut, torch.distributed.rpc.Future):
            torch.autograd._call_end_callbacks_on_fut(self.handle, fut)
        else:
            # jit Future, call jit operator
            torch.ops.profiler._call_end_callbacks_on_jit_fut(self.handle, fut)


class emit_nvtx(object):
    """Context manager that makes every autograd operation emit an NVTX range.
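The second branch above (the JIT-future path) is exercised later in this diff by a fork-based test; a condensed sketch of that usage, run entirely in one process:

import time
import torch

def sleep(t):
    time.sleep(t)

with torch.autograd.profiler.profile() as prof:
    with torch.autograd.profiler.record_function("foo") as rf:
        # torch.jit._fork returns a JIT future, so the else-branch calling
        # torch.ops.profiler._call_end_callbacks_on_jit_fut is taken.
        fut = torch.jit._fork(sleep, 1)
        rf._call_end_callbacks_on_future(fut)
    fut.wait()
# The "foo" event now spans until the forked work has finished.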
@@ -5,8 +5,12 @@
#include <torch/csrc/autograd/grad_mode.h>
#include <ATen/autocast_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/record_function_ops.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/function.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/message.h>
#endif

PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
  using namespace torch::autograd::profiler;
@@ -52,10 +56,16 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
  m.def("_enable_profiler", enableProfiler);
  m.def("_disable_profiler", disableProfiler);
  m.def("_profiler_enabled", profilerEnabled);
  m.def("_run_before_callbacks", _runBeforeCallbacks);

  py::class_<RecordFunction, std::shared_ptr<RecordFunction>>(m, "_RecordFunction")
      .def(py::init<>());
  // TODO: remove when jit future can hold PyObject (https://github.com/pytorch/pytorch/issues/34999)
#ifdef USE_DISTRIBUTED
  m.def(
      "_call_end_callbacks_on_fut",
      [](const at::Tensor& handle,
         const std::shared_ptr<torch::distributed::rpc::FutureMessage>& fut) {
        torch::autograd::profiler::_call_end_callbacks_on_fut(handle, fut);
      });
#endif

  Py_RETURN_TRUE;
}
@@ -1,3 +1,4 @@
#include <torch/csrc/autograd/record_function_ops.h>
#include <ATen/cpp_custom_type_hack.h>
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
@@ -16,7 +17,7 @@ at::Tensor record_function_enter(const std::string& name) {
  auto rec = std::make_unique<RecordFunction>(RecordScope::USER_SCOPE);
  // Only add new scope if profiling is enabled.
  if (auto* current = rec->current()) {
    AT_ASSERT(
    TORCH_INTERNAL_ASSERT(
        current->name() == StringView("profiler::_record_function_enter"));
    // RecordFunction requires parent_ to be alive for it's entire lifetime.
    // Since the currently active RecordFunction will only live for the lifetime
@@ -28,23 +29,68 @@ at::Tensor record_function_enter(const std::string& name) {
  return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions());
}

RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle) {
  auto& rec = at::cpp_custom_type_hack::cast<RecordFunction>(handle);
  return rec;
}

void record_function_exit(const at::Tensor& handle) {
  // We don't actually need to do anything with handle just need to persist the
  // lifetime until now.
  auto& rec = at::cpp_custom_type_hack::cast<RecordFunction>(handle);
  auto& rec = getRecordFunctionFromTensor(handle);
  if (auto* current = rec.current()) {
    AT_ASSERT(
        current->name() == StringView("profiler::_record_function_exit"));
    TORCH_INTERNAL_ASSERT(current->name() == StringView("profiler::_record_function_exit"));
    current->_end();
  }
  rec._end();
}

// Same as _call_end_callbacks_on_fut but takes an ivalue future.
// TODO: once python and JIT futures are merged, consolidate this with
// call_end_callbacks_on_fut (https://github.com/pytorch/pytorch/issues/34999).
void _call_end_callbacks_on_jit_fut(
    const at::Tensor& handle,
    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
  // Add a callback onto the future to mark run RecordFunction's end callbacks
  // when the future is completed.
  fut->addCallback(
      // Copy handle by value to persist after the python context manager is
      // exited.
      [handle]() {
        TORCH_INTERNAL_ASSERT(
            handle.defined(),
            "Undefined RecordFunction handle. This can happen if the handle is "
            "not correctly persisted and is destroyed before the future is "
            "realized.");
        auto& rec = getRecordFunctionFromTensor(handle);
        rec._end();
      });
}

// Internal only, do not use directly, use Python's record_function()
static auto registry =
    RegisterOperators()
        .op("profiler::_record_function_enter", &record_function_enter)
        .op("profiler::_record_function_exit", &record_function_exit);

// Needed to register JIT operator in operator registry below
c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}

jit::RegisterOperators reg_fut_ops({
    jit::Operator(
        "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> ()",
        [](jit::Stack& stack) {
          // Pop inputs, which should be a future and a tensor
          auto fut = jit::pop(stack).toFuture();
          auto tensor = jit::pop(stack).toTensor();
          _call_end_callbacks_on_jit_fut(tensor, fut);
          return 0;
        },
        aliasAnalysisFromSchema()),
});

} // namespace profiler
} // namespace autograd
} // namespace torch
torch/csrc/autograd/record_function_ops.h (new file, 44 lines)
@@ -0,0 +1,44 @@
#pragma once
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/utils/future.h>

namespace torch {
namespace autograd {
namespace profiler {
// Creates a new profiling scope using RecordFunction and invokes its starting
// callbacks.
at::Tensor record_function_enter(const std::string& name);

// Cast Tensor that was created with at::cpp_custom_type_hack back to
// RecordFunction. This is a temporary workaround until RecordFunction is
// registered as a custom C++ class
// (https://github.com/pytorch/pytorch/issues/35026).
TORCH_API RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle);

// Schedules RecordFunction's end callbacks to be run on completion of a future.
template <typename T>
void _call_end_callbacks_on_fut(
    const at::Tensor& handle,
    const std::shared_ptr<torch::utils::Future<T>> fut) {
  // Add a callback onto the future to mark run RecordFunction's end callbacks
  // when the future is completed.
  fut->addCallback(
      // Copy handle by value to persist after the python context manager is
      // exited.
      [handle]() {
        TORCH_INTERNAL_ASSERT(
            handle.defined(),
            "Undefined RecordFunction handle. This can happen if the handle is "
            "not correctly persisted and is destroyed before the future is "
            "realized.");
        auto& rec = getRecordFunctionFromTensor(handle);
        rec._end();
      });
}

// Ends the profiling scope created with record_function_enter.
void record_function_exit(const at::Tensor& handle);

} // namespace profiler
} // namespace autograd
} // namespace torch
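These registered operators are what the Python record_function context manager drives under the hood; a sketch (internal API, not something this commit asks users to call directly) of invoking them from eager Python, where the handle is the Tensor described by getRecordFunctionFromTensor above:

import torch

# The handle is a Tensor wrapping the RecordFunction via cpp_custom_type_hack;
# record_function.__enter__/__exit__ perform exactly these two calls.
handle = torch.ops.profiler._record_function_enter("my_scope")
torch.ones(100).sum()  # profiled work
torch.ops.profiler._record_function_exit(handle)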
@@ -115,24 +115,14 @@ std::shared_ptr<FutureMessage> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    bool forceGradRecording,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
    bool forceGradRecording) {
  auto msg = getMessageWithAutograd(
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording);

  auto fut = agent.send(dst, std::move(msg));
  if (rf != nullptr) {
    // Add a callback to
    // the future that captures the RecordFunction to persist it for the
    // lifetime of the future. When the future is completed, this will run the
    // end() callbacks associated with the RecordFunction, so that async RPCs
    // can be profiled correctly.
    fut->addCallback([rf]() { rf->_end(); });
  }
  return fut;
  return agent.send(dst, std::move(msg));
}

} // namespace autograd
@@ -48,9 +48,7 @@ sendMessageWithAutograd(
    rpc::RpcAgent& agent,
    const rpc::WorkerInfo& dst,
    rpc::Message&& wrappedRpcMsg,
    bool forceGradRecording = false,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf =
        nullptr);
    bool forceGradRecording = false);

} // namespace autograd
} // namespace distributed
@@ -239,6 +239,15 @@ PyObject* rpc_init(PyObject* /* unused */) {
          "_deserialize",
          &PyRRef::unpickle,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_future",
          &PyRRef::getFuture,
          py::call_guard<py::gil_scoped_release>(),
          R"(
              Returns the future that corresponds to the creation of this RRef
              on the remote node. This is for internal use cases such as profiling
              only.
          )")
      // not releasing GIL to avoid context switch
      .def("__str__", &PyRRef::str);

@@ -404,11 +413,10 @@ If the future completes with an error, an exception is thrown.
      "_invoke_rpc_builtin",
      [](const WorkerInfo& dst,
         const std::string& opName,
         const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
         const py::args& args,
         const py::kwargs& kwargs) {
        DCHECK(PyGILState_Check());
        return pyRpcBuiltin(dst, opName, rf, args, kwargs);
        return pyRpcBuiltin(dst, opName, args, kwargs);
      },
      py::call_guard<py::gil_scoped_acquire>());

@@ -416,22 +424,19 @@ If the future completes with an error, an exception is thrown.
      "_invoke_rpc_python_udf",
      [](const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors,
         const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
         std::vector<torch::Tensor>& tensors) {
        DCHECK(!PyGILState_Check());
        return pyRpcPythonUdf(dst, pickledPythonUDF, tensors, rf);
        return pyRpcPythonUdf(dst, pickledPythonUDF, tensors);
      },
      py::call_guard<py::gil_scoped_release>(),
      py::arg("dst"),
      py::arg("pickledPythonUDF"),
      py::arg("tensors"),
      py::arg("rf") = nullptr);
      py::arg("tensors"));

  module.def(
      "_invoke_rpc_torchscript",
      [](const std::string& dstWorkerName,
         const std::string& qualifiedNameStr,
         const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
         const py::tuple& argsTuple,
         const py::dict& kwargsDict) {
        // No need to catch exception here, if function can not be found,
@@ -455,8 +460,8 @@ If the future completes with an error, an exception is thrown.
              c10::nullopt);
        }
        DCHECK(!PyGILState_Check());
        c10::intrusive_ptr<c10::ivalue::Future> fut = rpcTorchscript(
            dstWorkerName, qualifiedName, functionSchema, stack, rf);
        c10::intrusive_ptr<c10::ivalue::Future> fut =
            rpcTorchscript(dstWorkerName, qualifiedName, functionSchema, stack);
        return torch::jit::PythonFutureWrapper(fut);
      },
      py::call_guard<py::gil_scoped_release>());
@@ -465,11 +470,10 @@ If the future completes with an error, an exception is thrown.
      "_invoke_remote_builtin",
      [](const WorkerInfo& dst,
         const std::string& opName,
         const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
         const py::args& args,
         const py::kwargs& kwargs) {
        DCHECK(PyGILState_Check());
        return pyRemoteBuiltin(dst, opName, rf, args, kwargs);
        return pyRemoteBuiltin(dst, opName, args, kwargs);
      },
      py::call_guard<py::gil_scoped_acquire>());

@@ -477,7 +481,6 @@ If the future completes with an error, an exception is thrown.
      "_invoke_remote_torchscript",
      [](const std::string& dstWorkerName,
         const std::string& qualifiedNameStr,
         const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
         const py::args& args,
         const py::kwargs& kwargs) {
        DCHECK(!PyGILState_Check());
@@ -495,7 +498,7 @@ If the future completes with an error, an exception is thrown.
        }
        DCHECK(!PyGILState_Check());
        auto rrefPtr = remoteTorchscript(
            dstWorkerName, qualifiedName, functionSchema, stack, rf);
            dstWorkerName, qualifiedName, functionSchema, stack);
        return PyRRef(rrefPtr);
      },
      py::call_guard<py::gil_scoped_release>());
@@ -504,16 +507,14 @@ If the future completes with an error, an exception is thrown.
      "_invoke_remote_python_udf",
      [](const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors,
         const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
         std::vector<torch::Tensor>& tensors) {
        DCHECK(!PyGILState_Check());
        return pyRemotePythonUdf(dst, pickledPythonUDF, tensors, rf);
        return pyRemotePythonUdf(dst, pickledPythonUDF, tensors);
      },
      py::call_guard<py::gil_scoped_release>(),
      py::arg("dst"),
      py::arg("pickledPythonUDF"),
      py::arg("tensors"),
      py::arg("rf") = nullptr);
      py::arg("tensors"));

  module.def(
      "get_rpc_timeout",
@@ -112,6 +112,10 @@ PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
      return rref;
    }()) {}

const std::shared_ptr<FutureMessage> PyRRef::getFuture() const {
  return rref_->getOwnerCreationFuture();
}

bool PyRRef::isOwner() const {
  return rref_->isOwner();
}
@@ -25,6 +25,10 @@ class PyRRef {
  py::tuple pickle() const;
  static PyRRef unpickle(const py::tuple& t);
  c10::IValue toIValue();
  // Future that is associated with the creation of this RRef on the remote end.
  // This is only used to get the future corresponding to the rref for profiling
  // use cases.
  const std::shared_ptr<FutureMessage> getFuture() const;

 private:
  c10::intrusive_ptr<RRef> rref_;
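On the Python side this surfaces as RRef._get_future(), which is what allows rpc.remote calls to be profiled end to end; a short sketch (internal API, assuming an initialized RPC setup with a peer named "worker1"):

import torch
import torch.distributed.rpc as rpc

# Assumes rpc.init_rpc(...) has been called and "worker1" exists.
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
# Resolves once the owner confirms creation of the RRef; record_function's end
# callbacks can be scheduled on it via rf._call_end_callbacks_on_future(...).
creation_fut = rref._get_future()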
@@ -63,8 +63,7 @@ std::shared_ptr<FutureMessage> sendPythonRemoteCall(
    const WorkerInfo& dst,
    SerializedPyObj serializedPyObj,
    const IValue& rrefId,
    const IValue& forkId,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
    const IValue& forkId) {
  auto pythonRemoteCall = std::make_unique<PythonRemoteCall>(
      std::move(serializedPyObj), rrefId, forkId);

@@ -75,8 +74,7 @@ std::shared_ptr<FutureMessage> sendPythonRemoteCall(
      *agent,
      dst,
      std::move(*pythonRemoteCall).toMessage(),
      true /*forceGradRecording*/,
      rf);
      true /*forceGradRecording*/);
}

} // namespace
@@ -119,7 +117,6 @@ py::object toPyObj(const Message& message) {
std::shared_ptr<FutureMessage> pyRpcBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
    const py::args& args,
    const py::kwargs& kwargs) {
  Stack stack;
@@ -129,13 +126,12 @@ std::shared_ptr<FutureMessage> pyRpcBuiltin(
  auto scriptCall = std::make_unique<ScriptCall>(op, std::move(stack));
  auto agent = RpcAgent::getCurrentRpcAgent();
  return sendMessageWithAutograd(
      *agent, dst, std::move(*scriptCall).toMessage(), false, rf);
      *agent, dst, std::move(*scriptCall).toMessage(), false);
}

PyRRef pyRemoteBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
    const py::args& args,
    const py::kwargs& kwargs) {
  Stack stack;
@@ -154,8 +150,9 @@ PyRRef pyRemoteBuiltin(
      op, std::move(stack), userRRef->rrefId(), userRRef->forkId());

  auto fm = sendMessageWithAutograd(
      *agent, dst, std::move(*scriptRemoteCall).toMessage(), false, rf);
      *agent, dst, std::move(*scriptRemoteCall).toMessage(), false);

  userRRef->registerOwnerCreationFuture(fm);
  ctx.addPendingUser(userRRef->forkId(), userRRef);
  fm->addCallback([forkId{userRRef->forkId()}](const FutureMessage& fm) {
    callback::confirmPendingUser(fm, forkId);
@@ -169,7 +166,9 @@ PyRRef pyRemoteBuiltin(
  auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
      op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId());
  auto fm = sendMessageWithAutograd(
      *agent, dst, std::move(*scriptRemoteCall).toMessage(), false, rf);
      *agent, dst, std::move(*scriptRemoteCall).toMessage(), false);

  ownerRRef->registerOwnerCreationFuture(fm);

  // Builtin operators does not return py::object, and hence does not require
  // GIL for destructing the potentially deleted OwerRRef.
@@ -182,8 +181,7 @@
std::shared_ptr<FutureMessage> pyRpcPythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
    std::vector<torch::Tensor>& tensors) {
  auto serializedPyObj =
      SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
  auto pythonCall = std::make_unique<PythonCall>(std::move(serializedPyObj));
@@ -193,15 +191,13 @@ std::shared_ptr<FutureMessage> pyRpcPythonUdf(
      *agent,
      dst,
      std::move(*pythonCall).toMessage(),
      true /*forceGradRecording*/,
      rf);
      true /*forceGradRecording*/);
}

PyRRef pyRemotePythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
    std::vector<torch::Tensor>& tensors) {
  auto& ctx = RRefContext::getInstance();
  auto serializedPyObj =
      SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
@@ -211,8 +207,9 @@ PyRRef pyRemotePythonUdf(
      dst,
      std::move(serializedPyObj),
      userRRef->rrefId().toIValue(),
      userRRef->forkId().toIValue(),
      rf);
      userRRef->forkId().toIValue());

  userRRef->registerOwnerCreationFuture(fm);

  ctx.addPendingUser(userRRef->forkId(), userRRef);
  fm->addCallback([forkId{userRRef->forkId()}](const FutureMessage& fm) {
@@ -227,8 +224,9 @@ PyRRef pyRemotePythonUdf(
      dst,
      std::move(serializedPyObj),
      ownerRRef->rrefId().toIValue(),
      ownerRRef->rrefId().toIValue(),
      rf);
      ownerRRef->rrefId().toIValue());

  ownerRRef->registerOwnerCreationFuture(fm);

  fm->addCallback([](const FutureMessage& fm) {
    auto deletedRRef = callback::finishCreatingOwnerRRef(fm);
@@ -1,6 +1,5 @@
#pragma once

#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/utils/pybind.h>
@@ -14,28 +13,24 @@ py::object toPyObj(const Message& message);
std::shared_ptr<FutureMessage> pyRpcBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
    const py::args& args,
    const py::kwargs& kwargs);

std::shared_ptr<FutureMessage> pyRpcPythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf);
    std::vector<torch::Tensor>& tensors);

PyRRef pyRemoteBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf,
    const py::args& args,
    const py::kwargs& kwargs);

PyRRef pyRemotePythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf);
    std::vector<torch::Tensor>& tensors);

} // namespace rpc
} // namespace distributed
@@ -219,6 +219,19 @@ class TORCH_API RRef : public RRefInterface {
    return type_;
  }

  // Save the future corresponding to the creation of this RRef on a remote
  // node. Note that this is only set when processing requests invoked with
  // rpc.remote. This is only used to get the future corresponding to the rref
  // for profiling use cases.
  inline void registerOwnerCreationFuture(std::shared_ptr<FutureMessage> fut) {
    ownerCreationFuture_ = std::move(fut);
  }

  // Get the future corresponding to the creation of this rref.
  inline std::shared_ptr<FutureMessage> getOwnerCreationFuture() const {
    return ownerCreationFuture_;
  }

  // Send delete UserRRef request to Owner,
  // if the request hasn't been sent yet.
  // There are 2 cases to call it,
@@ -240,6 +253,8 @@ class TORCH_API RRef : public RRefInterface {
  // type field to denote the type of the element that the RRef is holding
  // it could be any TypePtr that JIT support, including PyObjectType
  const TypePtr type_;
  // Future corresponding to request to create RRef on remote node.
  std::shared_ptr<FutureMessage> ownerCreationFuture_;
};

// ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user
@@ -15,8 +15,7 @@ c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
    const std::string& dstWorkerName,
    const c10::QualifiedName& qualifiedName,
    const c10::FunctionSchema& functionSchema,
    std::vector<c10::IValue>& stack,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
    std::vector<c10::IValue>& stack) {
  auto scriptCall =
      std::make_unique<ScriptCall>(qualifiedName, std::move(stack));
  auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
@@ -24,8 +23,7 @@ c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
      *rpcAgentPtr,
      rpcAgentPtr->getWorkerInfo(dstWorkerName),
      std::move(*scriptCall).toMessage(),
      true /*forceGradRecording*/,
      rf);
      true /*forceGradRecording*/);

  // Get function return type to construct c10::ivalue::Future.
  auto returns = functionSchema.returns();
@@ -55,8 +53,7 @@ c10::intrusive_ptr<RRef> remoteTorchscript(
    const std::string& dstWorkerName,
    const c10::QualifiedName& qualifiedName,
    const c10::FunctionSchema& functionSchema,
    std::vector<c10::IValue>& stack,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf) {
    std::vector<c10::IValue>& stack) {
  auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
  auto dstWorkerInfo = rpcAgentPtr->getWorkerInfo(dstWorkerName);
  auto& ctx = RRefContext::getInstance();
@@ -84,8 +81,9 @@ c10::intrusive_ptr<RRef> remoteTorchscript(
      *rpcAgentPtr,
      dstWorkerInfo,
      std::move(*scriptRemoteCall).toMessage(),
      true /*forceGradRecording*/,
      rf);
      true /*forceGradRecording*/);

  userRRefPtr->registerOwnerCreationFuture(fm);

  ctx.addPendingUser(userRRefPtr->forkId(), userRRefPtr);
  fm->addCallback([forkId{userRRefPtr->forkId()}](const FutureMessage& fm) {
@@ -108,8 +106,9 @@ c10::intrusive_ptr<RRef> remoteTorchscript(
      *rpcAgentPtr,
      dstWorkerInfo,
      std::move(*scriptRemoteCall).toMessage(),
      true /*forceGradRecording*/,
      rf);
      true /*forceGradRecording*/);

  ownerRRefPtr->registerOwnerCreationFuture(fm);

  fm->addCallback(
      [](const FutureMessage& fm) { callback::finishCreatingOwnerRRef(fm); });
@@ -24,17 +24,13 @@ c10::intrusive_ptr<c10::ivalue::Future> TORCH_API rpcTorchscript(
    const std::string& dstWorkerName,
    const c10::QualifiedName& qualifiedName,
    const c10::FunctionSchema& functionSchema,
    std::vector<c10::IValue>& stack,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf =
        nullptr);
    std::vector<c10::IValue>& stack);

c10::intrusive_ptr<RRef> TORCH_API remoteTorchscript(
    const std::string& dstWorkerName,
    const c10::QualifiedName& qualifiedName,
    const c10::FunctionSchema& functionSchema,
    std::vector<c10::IValue>& stack,
    const std::shared_ptr<torch::autograd::profiler::RecordFunction>& rf =
        nullptr);
    std::vector<c10::IValue>& stack);

} // namespace rpc
} // namespace distributed
@@ -33,7 +33,7 @@ from .internal import (
    PythonUDF,
    RPCExecMode,
    _internal_rpc_pickler,
    _start_record_function,
    _build_rpc_profiling_key,
)


@@ -423,11 +423,10 @@ def remote(to, func, args=None, kwargs=None):
    """
    qualified_name = torch.jit._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    should_profile = torch.autograd._profiler_enabled()

    # If profiling is enabled, kick off the timer and retrieve back a
    # RecordFunction instance.
    rf = None
    if torch.autograd._profiler_enabled():
    ctx_manager = contextlib.suppress()
    if should_profile:
        # Create appropriate string representation based on type of func
        # (builtin, script, python)
        if qualified_name is None:
@@ -438,28 +437,39 @@ def remote(to, func, args=None, kwargs=None):
            )
        else:
            func_name = qualified_name
        rf = _start_record_function(
        # Build RPC profiling key.
        rpc_profiling_key = _build_rpc_profiling_key(
            RPCExecMode.REMOTE,
            func_name,
            get_worker_info().name,
            dst_worker_info.name,
        )
        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)

    args = args if args else ()
    kwargs = kwargs if kwargs else {}

    if qualified_name is not None:
        return _invoke_remote_builtin(dst_worker_info, qualified_name, rf, *args, **kwargs)
    elif isinstance(func, torch.jit.ScriptFunction):
        return _invoke_remote_torchscript(
            dst_worker_info.name, torch._jit_internal._qualified_name(func), rf, *args, **kwargs
        )
    else:
        (pickled_python_udf, tensors) = _default_pickler.serialize(
            PythonUDF(func, args, kwargs)
        )
        return _invoke_remote_python_udf(dst_worker_info, pickled_python_udf, tensors, rf)
    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}
        if qualified_name is not None:
            rref = _invoke_remote_builtin(dst_worker_info, qualified_name, *args, **kwargs)
        elif isinstance(func, torch.jit.ScriptFunction):
            rref = _invoke_remote_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                *args,
                **kwargs
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            rref = _invoke_remote_python_udf(dst_worker_info, pickled_python_udf, tensors)
        # attach profiling information
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            rf._call_end_callbacks_on_future(rref._get_future())

    return rref


def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None):
    if not callable(func):
@@ -467,10 +477,13 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None):

    qualified_name = torch.jit._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    # If profiling is enabled, kick off the timer and retrieve back a
    # RecordFunction instance.
    rf = None
    if torch.autograd._profiler_enabled():

    # TODO: profiling logic does not really belong in invoke_rpc, it should be
    # added as part of a context manager or helper (https://github.com/pytorch/pytorch/issues/36360)
    should_profile = torch.autograd._profiler_enabled()

    ctx_manager = contextlib.suppress()
    if should_profile:
        # Create appropriate string representation based on type of func
        # (builtin, script, python)
        if qualified_name is None:
@@ -481,27 +494,35 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None):
            )
        else:
            func_name = qualified_name
        rf = _start_record_function(
        # Build RPC profiling key.
        rpc_profiling_key = _build_rpc_profiling_key(
            rpc_type,
            func_name,
            get_worker_info().name,
            dst_worker_info.name,
        )
        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)

    args = args if args else ()
    kwargs = kwargs if kwargs else {}
    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

    if qualified_name is not None:
        fut = _invoke_rpc_builtin(dst_worker_info, qualified_name, rf, *args, **kwargs)
    elif isinstance(func, torch.jit.ScriptFunction):
        fut = _invoke_rpc_torchscript(
            dst_worker_info.name, torch.jit._qualified_name(func), rf, args, kwargs
        )
    else:
        (pickled_python_udf, tensors) = _default_pickler.serialize(
            PythonUDF(func, args, kwargs)
        )
        fut = _invoke_rpc_python_udf(dst_worker_info, pickled_python_udf, tensors, rf)
        if qualified_name is not None:
            fut = _invoke_rpc_builtin(dst_worker_info, qualified_name, *args, **kwargs)
        elif isinstance(func, torch.jit.ScriptFunction):
            fut = _invoke_rpc_torchscript(
                dst_worker_info.name, torch.jit._qualified_name(func), args, kwargs
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            fut = _invoke_rpc_python_udf(dst_worker_info, pickled_python_udf, tensors)
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            # Schedule profiling callbacks to run when the future completes.
            rf._call_end_callbacks_on_future(fut)
    return fut

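Taken together, profiling an RPC no longer requires threading a RecordFunction through the C++ bindings; a sketch (not from the commit) of the resulting user-visible behavior, assuming two initialized workers named "worker0" and "worker1":

import torch
import torch.distributed.rpc as rpc

# Assumes rpc.init_rpc(...) has been called on both workers.
with torch.autograd.profiler.profile() as prof:
    rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
    rref.to_here()

# The remote call is recorded as an event named with the RPC profiling key,
# which contains "rpc_remote", the function name, and both worker names.
print([evt.name for evt in prof.function_events])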
@@ -14,7 +14,6 @@ import torch.distributed as dist
# objects
_thread_local_tensor_tables = threading.local()


class RPCExecMode(Enum):
    SYNC = "sync"
    ASYNC = "async"
@@ -162,6 +161,28 @@ def _handle_exception(result):
    if isinstance(result, RemoteException):
        raise result.exception_type(result.msg)

def _build_rpc_profiling_key(exec_type, func_name, current_worker_name, dst_worker_name):
    """
    Builds the key that RPC calls are profiled with using the autograd profiler.
    This will be the name of the corresponding Event recorded in the profiler.

    Arguments:
        exec_type (RPCExecMode): Type of RPC/RRef call
        func_name (str): Name of function being profiled.
        current_worker_name (str): Name of current worker.
        dst_worker_name (str): Name of the destination worker.

    Returns:
        String representing profiling key
    """
    profile_key = "rpc_{rpc_type}#{func_name}({current_worker} -> {dst_worker})".format(
        rpc_type=exec_type.value,
        func_name=func_name,
        current_worker=current_worker_name,
        dst_worker=dst_worker_name,
    )
    return profile_key


def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
    """
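For concreteness, a quick illustration of the key this helper produces (the worker and function names below are made up for the example):

from torch.distributed.rpc.internal import RPCExecMode, _build_rpc_profiling_key

key = _build_rpc_profiling_key(RPCExecMode.ASYNC, "torch.add", "worker0", "worker1")
print(key)  # rpc_async#torch.add(worker0 -> worker1)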
@@ -188,3 +188,16 @@ def initialize_pg(init_method, rank, world_size):

def worker_name(rank):
    return "worker{}".format(rank)

def get_function_event(function_events, partial_event_name):
    """
    Returns the first event that matches partial_event_name in the provided
    function_events. These function_events should be the output of
    torch.autograd.profiler.function_events().

    Args:
        function_events: function_events returned by the profiler.
        event_name (str): partial key that the event was profiled with.
    """
    event = [event for event in function_events if partial_event_name in event.name][0]
    return event
@@ -2,15 +2,24 @@ import unittest
from typing import Dict, Tuple

import torch
import time
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed.rpc.internal import _build_rpc_profiling_key, RPCExecMode
from torch import Tensor
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.dist_utils import dist_init, initialize_pg, worker_name
from torch.testing._internal.dist_utils import (
    dist_init,
    initialize_pg,
    worker_name,
    get_function_event
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)

def sleep(t):
    time.sleep(t)

def rpc_return_rref(dst):
    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
@@ -25,7 +34,6 @@ def return_value(value):
    # type: (int) -> int
    return value


class RRefAPITest:
    @dist_init
    def test_rref_is_owner(self):
@@ -96,6 +104,28 @@ def script_fork_wait_throw(invalue):
    value = torch.jit._wait(fut)
    return value

@torch.jit.script
def call_rpc_with_profiling(handle: Tensor, dst_worker_name: str) -> Tensor:
    # Call rpc_async from within ScriptFunction and ensure that we can attach
    # profiling callbacks. Note that handle here is a Tensor representation of
    # RecordFunction.
    fut = rpc.rpc_async(dst_worker_name, one_arg, (torch.tensor(1),))
    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut)
    ret = fut.wait()
    return ret


@torch.jit.script
def call_fork_with_profiling(handle: Tensor) -> Tensor:
    # Call fork from within ScriptFunction and ensure that we can attach profiling
    # callbacks to the resulting future. Note that handle here is a Tensor
    # representation of RecordFunction.
    fut = torch.jit._fork(one_arg, torch.tensor(1))
    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut)
    ret = fut.wait()
    return ret


class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
    def __init__(self, dst_worker):
        super().__init__()
@@ -884,3 +914,53 @@ class JitRpcTest(RRefAPITest, LocalRRefTest, JitRpcAsyncOpTest, RpcAgentTestFixt
                args=(torch.ones(2),))
        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
            future.wait()

    @dist_init
    def test_call_rpc_with_profiling(self):
        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
        # future from within a script function that calls rpc_async
        if self.rank == 0:
            with torch.autograd.profiler.profile() as prof:
                prof_key = _build_rpc_profiling_key(
                    RPCExecMode.ASYNC,
                    torch.jit._qualified_name(one_arg),
                    "worker0",
                    "worker1",
                )
                with torch.autograd.profiler.record_function(prof_key) as rf:
                    ret = call_rpc_with_profiling(rf.handle, "worker1")
            # TODO: Can't get a reliable time for this profiling event since
            # it's hard to estimate the execution time on the remote end for non-UDFs.
            # This can be resolved by https://github.com/pytorch/pytorch/issues/36272.
            # After that, this test should be modified to validate the function time.
            events = prof.function_events
            function_event = get_function_event(events, prof_key)
            self.assertTrue(torch.jit._qualified_name(one_arg) in function_event.name)

    def test_record_function_jit_end_callbacks_with_fork(self):
        # Ensures that we can call rf._call_end_callbacks_on_future on a jit
        # future in python eager mode with torch.jit.fork
        sleep_interval = 1
        with torch.autograd.profiler.profile() as prof:
            with torch.autograd.profiler.record_function("foo") as rf:
                fut = torch.jit._fork(sleep, sleep_interval)
                rf._call_end_callbacks_on_future(fut)
            fut.wait()

        function_events = prof.function_events
        sleep_event = get_function_event(function_events, "foo")
        self.assertEqual(sleep_event.name, "foo")
        # Validate that callbacks were fired at the right time by checking the
        # profiling event cpu time
        self.assertGreaterEqual(sleep_event.cpu_time * 1e-6, sleep_interval)

    def test_call_fork_in_jit_with_profiling(self):
        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
        # future from within a script function with torch.jit.fork
        with torch.autograd.profiler.profile() as prof:
            with torch.autograd.profiler.record_function("foo") as rf:
                ret = call_fork_with_profiling(rf.handle)

        events = prof.function_events
        function_event = get_function_event(events, "foo")
        self.assertEqual(function_event.name, "foo")
@@ -12,11 +12,17 @@ import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils as dist_utils
from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info
from torch.distributed.rpc.api import _delete_all_user_rrefs, _use_rpc_pickler
from torch.distributed.rpc.internal import PythonUDF, RPCExecMode, _internal_rpc_pickler
from torch.distributed.rpc.internal import (
    PythonUDF,
    RPCExecMode,
    _internal_rpc_pickler,
    _build_rpc_profiling_key,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import IS_MACOS, load_tests
from torch.testing._internal.dist_utils import (
    dist_init,
    get_function_event,
    get_shutdown_error_regex,
    initialize_pg,
    wait_until_node_failure,
@@ -647,6 +653,18 @@ class RpcTest(RpcAgentTestFixture):
        )
        self.assertEqual(ret, my_function(n, n + 1, n + 2))

    def test_build_rpc_profiling_key(self):
        # Tests that the name that shows up as an Event in profiling RPCs has all
        # the necessary information.
        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
            rpc_profiling_key = _build_rpc_profiling_key(
                exec_mode, "foo", "worker0", "worker1"
            )
            self.assertIn(exec_mode.value, rpc_profiling_key)
            self.assertIn("foo", rpc_profiling_key)
            self.assertIn("worker0", rpc_profiling_key)
            self.assertIn("worker1", rpc_profiling_key)

    def _profiler_test_with_rpc(self, rpc_exec_mode, func, args, use_record_function=False):
        dst = (self.rank + 1) % self.world_size
        # only run profiler on rank 1.
@@ -676,11 +694,9 @@ class RpcTest(RpcAgentTestFixture):
                record_function.__exit__()

            events = prof.function_events
            rpc_event = [
                event for event in events if rpc_exec_mode.value in event.name
            ][0]
            rpc_event = get_function_event(events, rpc_exec_mode.value)
            if use_record_function:
                scope_event = [event for event in events if "foo" in event.name][0]
                scope_event = get_function_event(events, "foo")
                # Since RPC call is within the scope, its CPU interval should be
                # contained within foo's interval.
                self.assertTrue(scope_event.cpu_interval.start < rpc_event.cpu_interval.start)
@@ -788,7 +804,45 @@ class RpcTest(RpcAgentTestFixture):
            use_record_function=True,
        )

    @dist_init
    def test_async_record_function_double_end_callbacks(self):
        num_sleep_seconds = 1
        if self.rank == 1:
            # Validate that calling the function twice results in an error.
            with torch.autograd.profiler.profile() as pf:
                with torch.autograd.profiler.record_function("foo") as rf:
                    fut = rpc.rpc_async(
                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
                    )
                    rf._call_end_callbacks_on_future(fut)
                    with self.assertRaisesRegex(
                        RuntimeError, "can only be called once."
                    ):
                        rf._call_end_callbacks_on_future(fut)
                fut.wait()

    @dist_init
    def test_async_record_function_cbs_jit_call(self):
        if self.rank == 1:
            with torch.autograd.profiler.profile() as pf:
                key = _build_rpc_profiling_key(
                    RPCExecMode.ASYNC,
                    torch.jit._qualified_name(my_script_func),
                    "worker1",
                    "worker0",
                )
                with torch.autograd.profiler.record_function(key) as rf:
                    fut = rpc.rpc_async(
                        worker_name(0), my_script_func, args=(torch.tensor(1),)
                    )
                    # Intentionally calling record_function internals
                    torch.ops.profiler._call_end_callbacks_on_jit_fut(rf.handle, fut)
                fut.wait()
            events = pf.function_events
            rpc_event = get_function_event(
                events, torch.jit._qualified_name(my_script_func)
            )
            self.assertTrue(torch.jit._qualified_name(my_script_func) in rpc_event.name)

    @dist_init
    def test_py_class_constructor(self):
@@ -1384,6 +1438,30 @@ class RpcTest(RpcAgentTestFixture):
            ),
        )

    @dist_init
    def test_rref_get_future(self):
        # Tests that we can obtain the future corresponding to the creation of
        # the RRef on remote end
        if self.rank == 0:
            # Builtin
            rref = rpc.remote(worker_name(1), torch.add, args=(1, 1))
            rref.to_here()
            fut = rref._get_future()
            self.assertIsInstance(fut, torch.distributed.rpc.Future)

            # UDF
            rref = rpc.remote(worker_name(1), foo_add, args=())
            rref.to_here()
            fut = rref._get_future()
            self.assertIsInstance(fut, torch.distributed.rpc.Future)

            # Script
            rref = rpc.remote(worker_name(1), my_script_func, args=(torch.tensor(1), ))
            rref.to_here()
            fut = rref._get_future()
            self.assertIsInstance(fut, torch.distributed.rpc.Future)


    @dist_init
    def test_rref_context_debug_info(self):
        # This test checks local states that are modified by remote workers.