Setup operator registration for distributed package (#31214)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31214

This sets up the basic infrastructure for distributed autograd and RPC to bind their operators to TorchScript. Since the whole distributed package is built behind the `USE_DISTRIBUTED` flag, the registration is kept separate and compiled only when the flag is on.

Test Plan: Imported from OSS

Differential Revision: D19137160

fbshipit-source-id: ff47dc4c380ebe273fe0eea9e5e3fccfbd6466d7
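For orientation, here is a minimal usage sketch (not part of the commit) of what the new binding enables: once `dist_autograd.get_gradients` is mapped to the builtin `aten::get_gradients`, it can be called from a `torch.jit.script` function, mirroring the JIT test added below. The single-process RPC setup (worker name "worker0", MASTER_ADDR/MASTER_PORT values) is an illustrative assumption, and the snippet assumes a Python 3 build with `USE_DISTRIBUTED` enabled.

# Minimal sketch; the RPC setup values below are illustrative assumptions.
import os
import torch
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker0", rank=0, world_size=1)

# The call below compiles to the newly registered aten::get_gradients op.
@torch.jit.script
def dist_get_gradients(context_id):
    # type: (int) -> (Dict[Tensor, Tensor])
    return dist_autograd.get_gradients(context_id)

with dist_autograd.context() as context_id:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)
    dist_autograd.backward([(t1 + t2).sum()])
    grads = dist_get_gradients(context_id)  # {t1: ones(3, 3), t2: ones(3, 3)}

rpc.shutdown()

Calling the scripted function outside an active `dist_autograd.context()` raises, since the registered kernel looks the context up in the DistAutogradContainer by id.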
Committed by: Facebook Github Bot
Parent: e33dea6e4e
Commit: e3fecabdcb
@@ -516,6 +516,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_resp.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/types.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/utils.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/register_distributed_ops.cpp
     )
   endif()
 endif()
@@ -11,6 +11,7 @@ import torch.distributed.rpc as rpc
 import dist_utils
 from dist_utils import dist_init
 from rpc_agent_test_fixture import RpcAgentTestFixture
+from torch.testing import FileCheck

 import threading
@@ -1568,5 +1569,34 @@ class DistAutogradTest(RpcAgentTestFixture):
         debug_info = dist_autograd._get_debug_info()
         self.assertEqual(0, int(debug_info['num_autograd_contexts']))

+
+@unittest.skipIf(
+    not torch._six.PY3, "Pytorch distributed autograd package " "does not support python2"
+)
+class DistAutogradJitTest(RpcAgentTestFixture):
+    @dist_init
+    def test_get_gradients(self):
+        dst_rank = self.rank
+
+        @torch.jit.script
+        def dist_get_gradients(context_id):
+            # type: (int) -> (Dict[Tensor, Tensor])
+            return dist_autograd.get_gradients(context_id)
+
+        FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            t3 = torch.add(t1, t2)
+
+            dist_autograd.backward([t3.sum()])
+            grads = dist_get_gradients(context_id)
+
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+            self.assertEqual(torch.ones(3, 3), grads[t1])
+            self.assertEqual(torch.ones(3, 3), grads[t2])
+
 if __name__ == '__main__':
     unittest.main()
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 from __future__ import absolute_import, division, print_function, unicode_literals

-from dist_autograd_test import DistAutogradTest
+from dist_autograd_test import DistAutogradTest, DistAutogradJitTest
 from common_distributed import MultiProcessTestCase
 from common_utils import TEST_WITH_ASAN, run_tests
@@ -14,5 +14,12 @@ class DistAutogradTestWithSpawn(MultiProcessTestCase, DistAutogradTest):
         super(DistAutogradTestWithSpawn, self).setUp()
         self._spawn_processes()

+@unittest.skipIf(TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues")
+class DistAutogradJitTestWithSpawn(MultiProcessTestCase, DistAutogradJitTest):
+
+    def setUp(self):
+        super(DistAutogradJitTestWithSpawn, self).setUp()
+        self._spawn_processes()
+
 if __name__ == '__main__':
     run_tests()
@@ -143,6 +143,7 @@ libtorch_sources = [
     "torch/csrc/jit/register_prim_ops.cpp",
     "torch/csrc/jit/register_string_ops.cpp",
     "torch/csrc/jit/register_special_ops.cpp",
+    "torch/csrc/jit/register_distributed_ops.cpp",
     "torch/csrc/jit/scope.cpp",
     "torch/csrc/jit/script/compiler.cpp",
     "torch/csrc/jit/script/edit_distance.cpp",
torch/csrc/jit/register_distributed_ops.cpp (new file)
@@ -0,0 +1,40 @@
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>

#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>

using at::Scalar;
using at::Tensor;
namespace dist_autograd = torch::distributed::autograd;

namespace torch {
namespace jit {

namespace {
at::Tensor toOptionalTensor(const c10::IValue& v) {
  if (v.isNone()) {
    return at::Tensor();
  }
  return v.toTensor();
}

at::Tensor optional_to_tensor(c10::optional<at::Tensor> v) {
  return v.has_value() ? *v : at::Tensor();
}

auto reg_distributed_ops =
    torch::RegisterOperators()
        .op("aten::get_gradients(int context_id) -> Dict(Tensor, Tensor)",
            torch::RegisterOperators::options()
                .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)
                .catchAllKernel([](int64_t context_id) {
                  const auto& autogradContext =
                      dist_autograd::DistAutogradContainer::getInstance()
                          .retrieveContext(context_id);
                  return autogradContext->getGradients();
                }));

} // namespace
} // namespace jit
} // namespace torch
@@ -2023,17 +2023,23 @@ def _get_builtin_table():
         for name in dir(mod):
             v = getattr(mod, name)
             if callable(v):
-                _builtin_table[id(v)] = "aten::" + name
+                _builtin_ops.append((v, "aten::" + name))
     for mod in _modules_containing_builtins:
         register_all(mod)

+    if not PY2:
+        _builtin_ops.append((math.gcd, "aten::gcd"))
+        _builtin_ops.append((math.isfinite, "aten::isfinite"))
+    if PY37:
+        _builtin_ops.append((math.remainder, "aten::mathremainder"))
+
+    import torch.distributed.autograd as dist_autograd
+    if dist_autograd.is_available():
+        _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients"))
+
+    # populate the _builtin_table from _builtin_ops
+    for builtin, aten_op in _builtin_ops:
+        _builtin_table[id(builtin)] = aten_op
-    if not PY2:
-        _builtin_table[id(math.gcd)] = "aten::gcd"
-        _builtin_table[id(math.isfinite)] = "aten::isfinite"
-    if PY37:
-        _builtin_table[id(math.remainder)] = "aten::mathremainder"

     return _builtin_table