mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Attach 'send' autograd function to the autograd graph as part of RPC. (#24876)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24876 This contains very basic functionality of adding 'send' autograd function to our autograd graph. The purpose of this change is to validate the basic structure proposed here makes sense. Once this makes sense, we can build upon this to address more complicated scenarios. At a high level we've added the following functionality: 1) Define a very simple 'SendRpcBackwards' autograd function. 2) Attach this function to appropriate tensors when we call an RPC. 3) Store the send function in our distributed autograd context. ghstack-source-id: 89359708 Test Plan: unit tests. Differential Revision: D16903255 fbshipit-source-id: 6c04794a8e58b199795404225fd9da0c1440460e
This commit is contained in:
committed by
Facebook Github Bot
parent
a024e1e091
commit
40cb5182e9
@ -451,6 +451,8 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
if (NOT INTERN_BUILD_MOBILE)
|
||||
list(APPEND TORCH_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
|
||||
${TORCH_SRC_DIR}/csrc/distributed/autograd/functions/sendrpc_backward.cpp
|
||||
${TORCH_SRC_DIR}/csrc/distributed/autograd/utils.cpp
|
||||
${TORCH_SRC_DIR}/csrc/distributed/rpc/future_message.cpp
|
||||
${TORCH_SRC_DIR}/csrc/distributed/rpc/message.cpp
|
||||
${TORCH_SRC_DIR}/csrc/distributed/rpc/script_call.cpp
|
||||
@ -720,6 +722,7 @@ ENDIF()
|
||||
|
||||
if (BUILD_TEST AND NOT NO_API)
|
||||
add_subdirectory(${TORCH_ROOT}/test/cpp/api ${CMAKE_BINARY_DIR}/test_api)
|
||||
add_subdirectory(${TORCH_ROOT}/test/cpp/dist_autograd ${CMAKE_BINARY_DIR}/dist_autograd)
|
||||
endif()
|
||||
|
||||
# XXX This ABI check cannot be run with arm-linux-androideabi-g++
|
||||
|
27
test/cpp/dist_autograd/CMakeLists.txt
Normal file
27
test/cpp/dist_autograd/CMakeLists.txt
Normal file
@ -0,0 +1,27 @@
|
||||
set(DIST_AUTOGRAD_TEST_DIR "${TORCH_ROOT}/test/cpp/dist_autograd")
|
||||
set(DIST_AUTOGRAD_TEST_SOURCES
|
||||
${TORCH_ROOT}/test/cpp/common/main.cpp
|
||||
${DIST_AUTOGRAD_TEST_DIR}/test_dist_autograd.cpp
|
||||
)
|
||||
|
||||
add_executable(test_dist_autograd ${DIST_AUTOGRAD_TEST_SOURCES})
|
||||
target_include_directories(test_dist_autograd PRIVATE ${ATen_CPU_INCLUDE})
|
||||
target_link_libraries(test_dist_autograd PRIVATE torch gtest)
|
||||
|
||||
if (USE_CUDA)
|
||||
target_link_libraries(test_dist_autograd PRIVATE
|
||||
${CUDA_LIBRARIES}
|
||||
${CUDA_NVRTC_LIB}
|
||||
${CUDA_CUDA_LIB}
|
||||
${TORCH_CUDA_LIBRARIES})
|
||||
|
||||
target_compile_definitions(test_dist_autograd PRIVATE "USE_CUDA")
|
||||
endif()
|
||||
|
||||
if (INSTALL_TEST)
|
||||
install(TARGETS test_dist_autograd DESTINATION bin)
|
||||
# Install PDB files for MSVC builds
|
||||
if (MSVC AND BUILD_SHARED_LIBS)
|
||||
install(FILES $<TARGET_PDB_FILE:test_dist_autograd> DESTINATION bin OPTIONAL)
|
||||
endif()
|
||||
endif()
|
53
test/cpp/dist_autograd/test_dist_autograd.cpp
Normal file
53
test/cpp/dist_autograd/test_dist_autograd.cpp
Normal file
@ -0,0 +1,53 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/csrc/distributed/autograd/utils.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
TEST(DistAutogradTest, TestSendFunction) {
|
||||
// Initialize input tensors requiring grad.
|
||||
auto options = at::TensorOptions().requires_grad(true);
|
||||
auto in1 = torch::ones({3, 3}, options);
|
||||
auto in2 = torch::ones({3, 3}, options);
|
||||
ASSERT_FALSE(in1.grad().defined());
|
||||
ASSERT_FALSE(in2.grad().defined());
|
||||
|
||||
// Attach the send autograd function to tensors.
|
||||
auto send_function =
|
||||
torch::distributed::autograd::addSendRpcBackward({in1, in2});
|
||||
ASSERT_NE(send_function, nullptr);
|
||||
|
||||
// Build loss and attach it as input to send autograd function.
|
||||
auto o1 = torch::autograd::Variable(torch::ones({3, 3}));
|
||||
auto edge = torch::autograd::Edge(send_function, 0);
|
||||
o1.set_gradient_edge(edge);
|
||||
auto o2 = torch::autograd::Variable(torch::ones({3, 3}));
|
||||
edge = torch::autograd::Edge(send_function, 1);
|
||||
o2.set_gradient_edge(edge);
|
||||
auto loss = torch::add(o1, o2);
|
||||
|
||||
// Run backwards pass and verify gradients accumulated.
|
||||
auto gradient = torch::autograd::Variable(torch::rand({3, 3}));
|
||||
loss.backward(gradient, false, false);
|
||||
ASSERT_TRUE(in1.grad().defined());
|
||||
ASSERT_TRUE(in2.grad().defined());
|
||||
}
|
||||
|
||||
TEST(DistAutogradTest, TestSendFunctionInvalidInputs) {
|
||||
auto options = at::TensorOptions().requires_grad(true);
|
||||
auto in1 = torch::ones({3, 3}, options);
|
||||
auto in2 = torch::ones({3, 3}, options);
|
||||
|
||||
// Attach the send autograd function to tensors.
|
||||
auto send_function =
|
||||
torch::distributed::autograd::addSendRpcBackward({in1, in2});
|
||||
|
||||
// Build loss and attach it as input to send autograd function.
|
||||
auto loss = torch::autograd::Variable(torch::ones({3, 3}));
|
||||
loss.set_gradient_edge(torch::autograd::Edge(send_function, 1));
|
||||
|
||||
// This should fail since the SendRpcBackward function is looking for two
|
||||
// inputs and as a result encounters an undefined grad.
|
||||
EXPECT_THROW(
|
||||
loss.backward(torch::autograd::Variable(), false, false), c10::Error);
|
||||
}
|
@ -5,7 +5,9 @@ import torch.distributed as dist
|
||||
import torch.distributed.autograd as dist_autograd
|
||||
from common_distributed import MultiProcessTestCase
|
||||
from functools import wraps
|
||||
import six
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
if not dist.is_available():
|
||||
print("c10d not available, skipping tests")
|
||||
@ -30,10 +32,14 @@ def dist_init(func):
|
||||
|
||||
return wrapper
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 0), "Pytorch distributed autograd package "
|
||||
@unittest.skipIf(not six.PY3, "Pytorch distributed autograd package "
|
||||
"does not support python2")
|
||||
class TestDistAutograd(MultiProcessTestCase):
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return 4
|
||||
|
||||
@dist_init
|
||||
def test_autograd_context(self):
|
||||
context_ids = []
|
||||
@ -48,5 +54,63 @@ class TestDistAutograd(MultiProcessTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, 'Could not find autograd context with id: {}'.format(context_id)):
|
||||
dist_autograd._retrieve_context(context_id)
|
||||
|
||||
@dist_init
|
||||
def test_autograd_send_function(self):
|
||||
dst_rank = (self.rank + 1) % self.world_size
|
||||
with dist_autograd.context() as context_id:
|
||||
t1 = torch.ones(3, 3, requires_grad=True)
|
||||
t2 = torch.zeros(3, 3, requires_grad=True)
|
||||
ret = dist.rpc('worker{}'.format(dst_rank), torch.add,
|
||||
args=(t1, t2))
|
||||
|
||||
# Get send function.
|
||||
ctx = dist_autograd._current_context()
|
||||
self.assertEqual(context_id, ctx._context_id())
|
||||
send_functions = ctx._send_functions()
|
||||
self.assertEqual(1, len(send_functions))
|
||||
|
||||
# Retrieve the next functions in the graph.
|
||||
next_funcs = send_functions[0].next_functions
|
||||
self.assertEqual(2, len(next_funcs))
|
||||
|
||||
# We should now hit t1 and t2 in the autograd graph.
|
||||
self.assertEqual('torch::autograd::AccumulateGrad', next_funcs[0][0].name())
|
||||
self.assertEqual(t1, next_funcs[0][0].variable)
|
||||
self.assertEqual(0, next_funcs[0][1])
|
||||
self.assertEqual('torch::autograd::AccumulateGrad', next_funcs[1][0].name())
|
||||
self.assertEqual(t2, next_funcs[1][0].variable)
|
||||
self.assertEqual(0, next_funcs[1][1])
|
||||
|
||||
# autograd context should be cleaned up by now.
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx = dist_autograd._retrieve_context(context_id)
|
||||
|
||||
# No autograd context available.
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx = dist_autograd._current_context()
|
||||
|
||||
@dist_init
|
||||
def test_rpc_complex_args(self):
|
||||
dst_rank = (self.rank + 1) % self.world_size
|
||||
with dist_autograd.context() as context_id:
|
||||
num_tensors = 10
|
||||
tensors = []
|
||||
for i in range(num_tensors):
|
||||
tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
|
||||
ret = dist.rpc('worker{}'.format(dst_rank), torch.stack,
|
||||
args=(tensors,))
|
||||
self.assertEqual(torch.stack(tensors), ret)
|
||||
|
||||
# Verify appropriate tensors have been attached the autograd graph.
|
||||
next_funcs = dist_autograd._current_context()._send_functions()[0].next_functions
|
||||
idx = 0
|
||||
for i in range(num_tensors):
|
||||
if i % 2 == 0:
|
||||
self.assertEqual('torch::autograd::AccumulateGrad', next_funcs[i][0].name())
|
||||
self.assertEqual(tensors[i], next_funcs[i][0].variable)
|
||||
else:
|
||||
self.assertIsNone(next_funcs[i][0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -6,6 +6,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from common_distributed import MultiProcessTestCase
|
||||
from common_utils import load_tests, run_tests
|
||||
|
||||
@ -68,6 +69,12 @@ if not dist.is_available():
|
||||
|
||||
|
||||
def _wrap_with_rpc(func):
|
||||
'''
|
||||
We use this decorator for setting up and tearing down state since
|
||||
MultiProcessTestCase runs each `test*` method in a separate process and
|
||||
each process just runs the `test*` method without actually calling
|
||||
'setUp' and 'tearDown' methods of unittest.
|
||||
'''
|
||||
def wrapper(self):
|
||||
store = dist.FileStore(self.file.name, self.world_size)
|
||||
dist.init_process_group(backend='gloo', rank=self.rank,
|
||||
@ -378,6 +385,5 @@ class RpcTest(MultiProcessTestCase):
|
||||
def test_stress_heavy_rpc(self):
|
||||
self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -49,8 +49,10 @@ libtorch_sources = [
|
||||
"torch/csrc/autograd/record_function.cpp",
|
||||
"torch/csrc/autograd/saved_variable.cpp",
|
||||
"torch/csrc/autograd/variable.cpp",
|
||||
"torch/csrc/distributed/autograd/utils.cpp",
|
||||
"torch/csrc/distributed/autograd/context/dist_autograd_container.cpp",
|
||||
"torch/csrc/distributed/autograd/context/dist_autograd_context.cpp",
|
||||
"torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp",
|
||||
"torch/csrc/distributed/rpc/future_message.cpp",
|
||||
"torch/csrc/distributed/rpc/message.cpp",
|
||||
"torch/csrc/distributed/rpc/script_call.cpp",
|
||||
@ -235,6 +237,7 @@ def add_torch_libs():
|
||||
"torch/csrc/distributed/c10d/comm.cpp",
|
||||
"torch/csrc/distributed/c10d/init.cpp",
|
||||
"torch/csrc/distributed/c10d/reducer.cpp",
|
||||
"torch/csrc/distributed/autograd/init.cpp",
|
||||
"torch/csrc/distributed/rpc/functions.cpp",
|
||||
"torch/csrc/distributed/rpc/init.cpp",
|
||||
"torch/csrc/distributed/rpc/process_group_agent.cpp",
|
||||
|
@ -1,10 +1,11 @@
|
||||
#include <Python.h>
|
||||
#include <torch/csrc/autograd/functions/accumulate_grad.h>
|
||||
#include <torch/csrc/autograd/functions/basic_ops.h>
|
||||
#include <torch/csrc/autograd/functions/tensor.h>
|
||||
#include <torch/csrc/autograd/functions/pybind.h>
|
||||
#include <torch/csrc/autograd/python_cpp_function.h>
|
||||
#include <torch/csrc/autograd/functions/tensor.h>
|
||||
#include <torch/csrc/autograd/generated/python_functions.h>
|
||||
#include <torch/csrc/autograd/python_cpp_function.h>
|
||||
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
|
||||
#include <torch/csrc/jit/python_tracer.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/tuple_parser.h>
|
||||
@ -102,6 +103,10 @@ void THPAutograd_initFunctions()
|
||||
static PyTypeObject CopyBackwardsClass;
|
||||
addClass<CopyBackwards, NoCtor>(module, CopyBackwardsClass, "CopyBackwards");
|
||||
|
||||
static PyTypeObject SendRpcBackwardClass;
|
||||
addClass<torch::distributed::autograd::SendRpcBackward, NoCtor>(
|
||||
module, SendRpcBackwardClass, "SendRpcBackward");
|
||||
|
||||
static PyTypeObject CopySlicesClass;
|
||||
addClass<CopySlices, NoCtor>(module, CopySlicesClass, "CopySlices");
|
||||
|
||||
|
@ -10,8 +10,10 @@ constexpr int64_t kContextIdMask = (1LL << kContextIdBits) - 1;
|
||||
constexpr int kMaxWorkerId = 65535;
|
||||
constexpr int64_t kMaxContextId = kContextIdMask;
|
||||
|
||||
thread_local int64_t DistAutogradContainer::current_context_id_ = -1;
|
||||
|
||||
DistAutogradContainer::DistAutogradContainer()
|
||||
: current_context_id_(0), worker_id_(0), initialized_(false) {}
|
||||
: next_context_id_(0), worker_id_(0), initialized_(false) {}
|
||||
|
||||
DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
|
||||
TORCH_CHECK(
|
||||
@ -20,7 +22,7 @@ DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
|
||||
|
||||
auto& container = getInstance();
|
||||
container.worker_id_ = worker_id;
|
||||
container.current_context_id_ = static_cast<int64_t>(worker_id)
|
||||
container.next_context_id_ = static_cast<int64_t>(worker_id)
|
||||
<< kContextIdBits;
|
||||
container.initialized_ = true;
|
||||
return container;
|
||||
@ -39,16 +41,39 @@ const DistAutogradContext& DistAutogradContainer::newContext() {
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> guard(autograd_context_lock_);
|
||||
if (current_context_id_ == std::numeric_limits<int64_t>::max() ||
|
||||
current_context_id_ >
|
||||
(kMaxContextId |
|
||||
(static_cast<int64_t>(worker_id_) << kContextIdBits))) {
|
||||
throw std::runtime_error("We have run out of autograd context ids!!!");
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
next_context_id_ < std::numeric_limits<int64_t>::max() &&
|
||||
next_context_id_ <
|
||||
(kMaxContextId |
|
||||
(static_cast<int64_t>(worker_id_) << kContextIdBits)),
|
||||
"We have run out of autograd context ids!!!");
|
||||
|
||||
autograd_context_.emplace(
|
||||
current_context_id_, DistAutogradContext(current_context_id_));
|
||||
return autograd_context_.at(current_context_id_++);
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(next_context_id_),
|
||||
std::forward_as_tuple(next_context_id_));
|
||||
|
||||
current_context_id_ = next_context_id_;
|
||||
return autograd_context_.at(next_context_id_++);
|
||||
}
|
||||
|
||||
bool DistAutogradContainer::hasValidContext() const {
|
||||
return current_context_id_ != -1;
|
||||
}
|
||||
|
||||
DistAutogradContext& DistAutogradContainer::currentContext() {
|
||||
TORCH_CHECK(
|
||||
hasValidContext(),
|
||||
"Current thread doesn't have a valid autograd context. Please wrap your "
|
||||
"code using: `with torch.distributed.autograd.context() as context_id` "
|
||||
"to generate a valid context");
|
||||
std::lock_guard<std::mutex> guard(autograd_context_lock_);
|
||||
auto it = autograd_context_.find(current_context_id_);
|
||||
TORCH_CHECK(
|
||||
it != autograd_context_.end(),
|
||||
"Couldn't find autograd context "
|
||||
"data for current autograd context id");
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void DistAutogradContainer::releaseContext(int64_t context_id) {
|
||||
@ -58,10 +83,15 @@ void DistAutogradContainer::releaseContext(int64_t context_id) {
|
||||
"Could not find autograd context with id: ",
|
||||
context_id);
|
||||
autograd_context_.erase(context_id);
|
||||
|
||||
if (current_context_id_ == context_id) {
|
||||
// Reset the thread_local current context id, since it is no longer valid.
|
||||
current_context_id_ = -1;
|
||||
}
|
||||
}
|
||||
|
||||
const DistAutogradContext& DistAutogradContainer::retrieveContext(
|
||||
int64_t context_id) {
|
||||
int64_t context_id) const {
|
||||
std::lock_guard<std::mutex> guard(autograd_context_lock_);
|
||||
TORCH_CHECK(
|
||||
autograd_context_.find(context_id) != autograd_context_.end(),
|
||||
|
@ -20,11 +20,26 @@ namespace autograd {
|
||||
// auto-incrementing id for each worker.
|
||||
class DistAutogradContainer {
|
||||
public:
|
||||
// One time initialization of the container.
|
||||
static DistAutogradContainer& init(int64_t worker_id);
|
||||
|
||||
// Retrieve the singleton instance of the container.
|
||||
static DistAutogradContainer& getInstance();
|
||||
|
||||
// Create a new context for a distributed autograd pass.
|
||||
const DistAutogradContext& newContext();
|
||||
|
||||
// Clean up resources for a given context_id once the autograd pass is done.
|
||||
void releaseContext(int64_t context_id);
|
||||
const DistAutogradContext& retrieveContext(int64_t context_id);
|
||||
|
||||
// Retrieve the autograd context for a given context_id.
|
||||
const DistAutogradContext& retrieveContext(int64_t context_id) const;
|
||||
|
||||
// Retrieves the currently active autograd context for the current thread.
|
||||
DistAutogradContext& currentContext();
|
||||
|
||||
// Checks whether or not the current thread has a valid autograd context.
|
||||
bool hasValidContext() const;
|
||||
|
||||
private:
|
||||
DistAutogradContainer();
|
||||
@ -37,7 +52,7 @@ class DistAutogradContainer {
|
||||
|
||||
// Auto incrementing context id used to identify unique autograd passes.
|
||||
// Initialized with the first 16 bits being the worker_id.
|
||||
int64_t current_context_id_;
|
||||
int64_t next_context_id_;
|
||||
|
||||
// Unique id to identify a worker in the distributed setting.
|
||||
int16_t worker_id_;
|
||||
@ -48,9 +63,12 @@ class DistAutogradContainer {
|
||||
// Whether or not the container has been initialized appropriately.
|
||||
bool initialized_;
|
||||
|
||||
// Lock to protect current_context_id_ and autograd_context map. initialized_
|
||||
// Lock to protect next_context_id_ and autograd_context map. initialized_
|
||||
// and worker_id_ are immutable.
|
||||
mutable std::mutex autograd_context_lock_;
|
||||
|
||||
// Each thread has a single autograd_context_id valid at any point in time.
|
||||
static thread_local int64_t current_context_id_;
|
||||
};
|
||||
|
||||
} // namespace autograd
|
||||
|
@ -12,6 +12,17 @@ int64_t DistAutogradContext::context_id() const {
|
||||
return context_id_;
|
||||
}
|
||||
|
||||
void DistAutogradContext::addSendFunction(
|
||||
const std::shared_ptr<SendRpcBackward>& func) {
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
sendAutogradFunctions_.push_back(func);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<SendRpcBackward>> DistAutogradContext::
|
||||
sendFunctions() const {
|
||||
return sendAutogradFunctions_;
|
||||
}
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
||||
|
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
|
||||
#include <cstdint>
|
||||
|
||||
namespace torch {
|
||||
@ -11,10 +12,27 @@ namespace autograd {
|
||||
class DistAutogradContext {
|
||||
public:
|
||||
explicit DistAutogradContext(int64_t context_id);
|
||||
|
||||
// Retrieves the autograd context id for this context.
|
||||
int64_t context_id() const;
|
||||
|
||||
// Records a 'send' autograd function for this context.
|
||||
void addSendFunction(const std::shared_ptr<SendRpcBackward>& func);
|
||||
|
||||
std::vector<std::shared_ptr<SendRpcBackward>> sendFunctions() const;
|
||||
|
||||
DistAutogradContext(const DistAutogradContext&) = delete;
|
||||
DistAutogradContext& operator=(const DistAutogradContext&) = delete;
|
||||
DistAutogradContext(DistAutogradContext&&) = delete;
|
||||
DistAutogradContext& operator=(DistAutogradContext&&) = delete;
|
||||
|
||||
private:
|
||||
const int64_t context_id_;
|
||||
|
||||
std::vector<std::shared_ptr<SendRpcBackward>> sendAutogradFunctions_;
|
||||
|
||||
// Lock to protect concurrent modification of the context.
|
||||
mutable std::mutex lock_;
|
||||
};
|
||||
|
||||
} // namespace autograd
|
||||
|
@ -0,0 +1,22 @@
|
||||
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace autograd {
|
||||
|
||||
torch::autograd::variable_list SendRpcBackward::apply(
|
||||
torch::autograd::variable_list&& grads) {
|
||||
// Each grad variable should be valid!
|
||||
for (const auto& grad : grads) {
|
||||
TORCH_CHECK(
|
||||
grad.defined(), "BUG!: SendRpcBackward didn't receive valid gradients");
|
||||
}
|
||||
|
||||
// Simply forwards the gradients over.
|
||||
// TODO: Improve this as we build out more parts of distributed autograd.
|
||||
return std::move(grads);
|
||||
}
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
24
torch/csrc/distributed/autograd/functions/sendrpc_backward.h
Normal file
24
torch/csrc/distributed/autograd/functions/sendrpc_backward.h
Normal file
@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/autograd/function.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace autograd {
|
||||
|
||||
// As part of our distributed autograd implementation, whenever we send an RPC
|
||||
// from one node to another, we add a 'SendRpcBackward' autograd function to the
|
||||
// autograd graph. This is more or less a placeholder function that is used to
|
||||
// kickoff the autograd engine on the current worker on the backward pass. The
|
||||
// edges for this autograd function are the inputs to the RPC method.
|
||||
//
|
||||
// During the backward pass, this function is queued for execution in the
|
||||
// autograd engine which eventually runs the rest of the autograd graph.
|
||||
struct TORCH_API SendRpcBackward : public torch::autograd::Node {
|
||||
torch::autograd::variable_list apply(
|
||||
torch::autograd::variable_list&& grads) override;
|
||||
};
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
@ -1,3 +1,4 @@
|
||||
#include <torch/csrc/autograd/python_cpp_function.h>
|
||||
#include <torch/csrc/distributed/autograd/context/dist_autograd_container.h>
|
||||
#include <torch/csrc/jit/pybind_utils.h>
|
||||
#include <torch/csrc/python_headers.h>
|
||||
@ -28,19 +29,40 @@ PyObject* dist_autograd_init(PyObject* /* unused */) {
|
||||
.def(
|
||||
"_context_id",
|
||||
&DistAutogradContext::context_id,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("_send_functions", [](const DistAutogradContext& ctx) {
|
||||
std::vector<py::object> funcs;
|
||||
for (const auto& sendFunction : ctx.sendFunctions()) {
|
||||
funcs.push_back(py::reinterpret_steal<py::object>(
|
||||
torch::autograd::functionToPyObject(sendFunction)));
|
||||
}
|
||||
return funcs;
|
||||
});
|
||||
|
||||
module.def("_new_context", []() {
|
||||
return DistAutogradContainer::getInstance().newContext();
|
||||
});
|
||||
module.def(
|
||||
"_new_context",
|
||||
[]() -> const DistAutogradContext& {
|
||||
return DistAutogradContainer::getInstance().newContext();
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
|
||||
module.def("_release_context", [](int64_t context_id) {
|
||||
return DistAutogradContainer::getInstance().releaseContext(context_id);
|
||||
});
|
||||
|
||||
module.def("_retrieve_context", [](int64_t context_id) {
|
||||
return DistAutogradContainer::getInstance().retrieveContext(context_id);
|
||||
});
|
||||
module.def(
|
||||
"_retrieve_context",
|
||||
[](int64_t context_id) -> const DistAutogradContext& {
|
||||
return DistAutogradContainer::getInstance().retrieveContext(context_id);
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
|
||||
module.def(
|
||||
"_current_context",
|
||||
[]() -> const DistAutogradContext& {
|
||||
return DistAutogradContainer::getInstance().currentContext();
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
|
||||
module.def("_init", [](int64_t worker_id) {
|
||||
DistAutogradContainer::init(worker_id);
|
||||
|
26
torch/csrc/distributed/autograd/utils.cpp
Normal file
26
torch/csrc/distributed/autograd/utils.cpp
Normal file
@ -0,0 +1,26 @@
|
||||
#include <torch/csrc/autograd/functions/utils.h>
|
||||
#include <torch/csrc/distributed/autograd/utils.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace autograd {
|
||||
|
||||
std::shared_ptr<SendRpcBackward> addSendRpcBackward(
|
||||
const std::vector<torch::Tensor>& tensors) {
|
||||
// Attach the appropriate autograd edges.
|
||||
std::shared_ptr<SendRpcBackward> grad_fn;
|
||||
if (torch::autograd::compute_requires_grad(tensors)) {
|
||||
grad_fn = std::make_shared<SendRpcBackward>();
|
||||
grad_fn->set_next_edges(torch::autograd::collect_next_edges(tensors));
|
||||
|
||||
// Add the appropriate input metadata for the grad_fn.
|
||||
for (const auto& tensor : tensors) {
|
||||
grad_fn->add_input_metadata(tensor);
|
||||
}
|
||||
}
|
||||
return grad_fn;
|
||||
}
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
21
torch/csrc/distributed/autograd/utils.h
Normal file
21
torch/csrc/distributed/autograd/utils.h
Normal file
@ -0,0 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace autograd {
|
||||
|
||||
// This method is used to attach the 'send' autograd function to the autograd
|
||||
// graph when we use RPC. This method creates a new 'send' autograd function
|
||||
// and attaches the provided tensors as next_edges to the 'send' function.
|
||||
//
|
||||
// Returns a shared_ptr to the autograd function, so that we can hold a
|
||||
// reference to it.
|
||||
TORCH_API std::shared_ptr<SendRpcBackward> addSendRpcBackward(
|
||||
const std::vector<torch::Tensor>& tensors);
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
@ -164,8 +164,9 @@ void ProcessGroupAgent::sync() {
|
||||
pg_->barrier()->wait();
|
||||
}
|
||||
|
||||
std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
|
||||
const WorkerId& to, Message&& message) {
|
||||
std::shared_ptr<FutureMessage> ProcessGroupAgent::sendImpl(
|
||||
const WorkerId& to,
|
||||
Message&& message) {
|
||||
TORCH_CHECK(to.id_ != (worker_id_t)pg_->getRank(),
|
||||
"ProcessGroupAgent does not support making RPC calls to self.")
|
||||
TORCH_CHECK(to.id_ < (worker_id_t)pg_->getSize(),
|
||||
|
@ -42,12 +42,6 @@ class ProcessGroupAgent : public RpcAgent {
|
||||
std::shared_ptr<c10d::ProcessGroup> pg,
|
||||
int numSendRecvThreads = 4);
|
||||
|
||||
// This method wraps the destination information and the message into a
|
||||
// SendWork object, and put the SendWork into a queue. Another thread will
|
||||
// consume SendWork from the queue and send it out.
|
||||
std::shared_ptr<FutureMessage> send(
|
||||
const WorkerId& to, Message&& message) override;
|
||||
|
||||
const WorkerId& getWorkerId(const std::string& workerName) const override;
|
||||
|
||||
void join() override;
|
||||
@ -56,6 +50,13 @@ class ProcessGroupAgent : public RpcAgent {
|
||||
|
||||
int16_t getWorkerId() override;
|
||||
|
||||
protected:
|
||||
// This method wraps the destination information and the message into a
|
||||
// SendWork object, and put the SendWork into a queue. Another thread will
|
||||
// consume SendWork from the queue and send it out.
|
||||
std::shared_ptr<FutureMessage> sendImpl(const WorkerId& to, Message&& message)
|
||||
override;
|
||||
|
||||
private:
|
||||
void collectNames();
|
||||
// put SendWork into a queue and notify the worker thread
|
||||
|
@ -31,20 +31,28 @@ std::shared_ptr<FutureMessage> py_rpc_builtin(
|
||||
const std::string& opName,
|
||||
const py::args& args,
|
||||
const py::kwargs& kwargs) {
|
||||
if (opName.rfind("aten", 0) == 0) {
|
||||
// builtin operators.
|
||||
Symbol symbol = Symbol::fromQualString(opName);
|
||||
for (const auto& op : torch::jit::getAllOperatorsFor(symbol)) {
|
||||
try {
|
||||
// FIXME: This is temporary solution. We should at least refactor
|
||||
// ``createStackForSchema`` to avoid throwing an error.
|
||||
Stack stack = torch::jit::createStackForSchema(
|
||||
op->schema(), args, kwargs, c10::nullopt);
|
||||
if (symbol.is_aten()) {
|
||||
Stack stack;
|
||||
for (const auto& op : torch::jit::getAllOperatorsFor(symbol)) {
|
||||
try {
|
||||
// FIXME: This is temporary solution. We should at least refactor
|
||||
// ``createStackForSchema`` to avoid throwing an error.
|
||||
stack = torch::jit::createStackForSchema(
|
||||
op->schema(), args, kwargs, c10::nullopt);
|
||||
|
||||
} catch (std::runtime_error& e) {
|
||||
VLOG(1) << "Couldn't match schema: " << op->schema()
|
||||
<< " to args: " << args << " and kwargs: " << kwargs
|
||||
<< ", reason: " << e.what();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Found the right op! Send it along...
|
||||
return agent.send(dst, ScriptCall(op, std::move(stack)).toMessage());
|
||||
} catch (std::runtime_error) {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AT_ERROR(
|
||||
"Failed to match operator name ",
|
||||
|
@ -1,10 +1,13 @@
|
||||
#include <torch/csrc/distributed/rpc/rpc_agent.h>
|
||||
#include <torch/csrc/distributed/autograd/context/dist_autograd_container.h>
|
||||
#include <torch/csrc/distributed/autograd/utils.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
|
||||
constexpr size_t WorkerId::MAX_NAME_LEN;
|
||||
using namespace torch::distributed::autograd;
|
||||
|
||||
RpcAgent::RpcAgent(WorkerId workerId, RequestCallback cb)
|
||||
: workerId_(std::move(workerId)), cb_(std::move(cb)) {}
|
||||
@ -15,6 +18,24 @@ const WorkerId& RpcAgent::getWorkerId() const {
|
||||
return workerId_;
|
||||
}
|
||||
|
||||
std::shared_ptr<FutureMessage> RpcAgent::send(
|
||||
const WorkerId& to,
|
||||
Message&& message) {
|
||||
// Record appropriate autograd information before sending the message over the
|
||||
// wire.
|
||||
auto& autogradContainer = DistAutogradContainer::getInstance();
|
||||
if (autogradContainer.hasValidContext()) {
|
||||
// Attach the appropriate autograd edges to the tensors found in the
|
||||
// message.
|
||||
auto grad_fn = addSendRpcBackward(message.tensors());
|
||||
|
||||
// Record the send function in our current context.
|
||||
auto& currentContext = autogradContainer.currentContext();
|
||||
currentContext.addSendFunction(grad_fn);
|
||||
}
|
||||
|
||||
return sendImpl(to, std::forward<Message>(message));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -81,8 +81,7 @@ class RpcAgent {
|
||||
// If ``message.isRequest()`` is true, the ``FutureMessage`` will be completed
|
||||
// when the response arrives. For other message types, the Future should be
|
||||
// ignored by the caller.
|
||||
virtual std::shared_ptr<FutureMessage> send(
|
||||
const WorkerId& to, Message&& message) = 0;
|
||||
std::shared_ptr<FutureMessage> send(const WorkerId& to, Message&& message);
|
||||
|
||||
// Return a reference to the ``WorkerId`` of this RpcAgent.
|
||||
// NB: not using ``c10::optional<const std::string&>`` here because we might
|
||||
@ -106,6 +105,14 @@ class RpcAgent {
|
||||
|
||||
protected:
|
||||
const WorkerId workerId_;
|
||||
|
||||
// Method that needs to be overridden by all implementations of this
|
||||
// interface. The public send() method is responsible for common
|
||||
// pre-processing shared across all implementations.
|
||||
virtual std::shared_ptr<FutureMessage> sendImpl(
|
||||
const WorkerId& to,
|
||||
Message&& message) = 0;
|
||||
const std::string workerName_;
|
||||
const RequestCallback cb_;
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user