Setup operator registration for distributed package (#31214)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31214

This sets up the basic infrastructure for distributed autograd and RPC to
bind their operators to TorchScript. Since the whole distributed package
is built behind the `USE_DISTRIBUTED` flag, we separate the registration
and build it only when the flag is on.

Test Plan: Imported from OSS

Differential Revision: D19137160

fbshipit-source-id: ff47dc4c380ebe273fe0eea9e5e3fccfbd6466d7
Wanchao Liang
2019-12-17 17:24:47 -08:00
committed by Facebook Github Bot
parent e33dea6e4e
commit e3fecabdcb
6 changed files with 92 additions and 7 deletions


@ -516,6 +516,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/distributed/rpc/script_resp.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/types.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/utils.cpp
${TORCH_SRC_DIR}/csrc/jit/register_distributed_ops.cpp
)
endif()
endif()


@ -11,6 +11,7 @@ import torch.distributed.rpc as rpc
import dist_utils
from dist_utils import dist_init
from rpc_agent_test_fixture import RpcAgentTestFixture
from torch.testing import FileCheck
import threading
@ -1568,5 +1569,34 @@ class DistAutogradTest(RpcAgentTestFixture):
        debug_info = dist_autograd._get_debug_info()
        self.assertEqual(0, int(debug_info['num_autograd_contexts']))


@unittest.skipIf(
    not torch._six.PY3, "Pytorch distributed autograd package " "does not support python2"
)
class DistAutogradJitTest(RpcAgentTestFixture):
    @dist_init
    def test_get_gradients(self):
        dst_rank = self.rank

        @torch.jit.script
        def dist_get_gradients(context_id):
            # type: (int) -> (Dict[Tensor, Tensor])
            return dist_autograd.get_gradients(context_id)

        FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))

        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            t3 = torch.add(t1, t2)

            dist_autograd.backward([t3.sum()])
            grads = dist_get_gradients(context_id)

            self.assertEqual(2, len(grads))
            self.assertIn(t1, grads)
            self.assertIn(t2, grads)
            self.assertEqual(torch.ones(3, 3), grads[t1])
            self.assertEqual(torch.ones(3, 3), grads[t2])


if __name__ == '__main__':
    unittest.main()


@ -1,7 +1,7 @@
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function, unicode_literals
from dist_autograd_test import DistAutogradTest
from dist_autograd_test import DistAutogradTest, DistAutogradJitTest
from common_distributed import MultiProcessTestCase
from common_utils import TEST_WITH_ASAN, run_tests
@ -14,5 +14,12 @@ class DistAutogradTestWithSpawn(MultiProcessTestCase, DistAutogradTest):
        super(DistAutogradTestWithSpawn, self).setUp()
        self._spawn_processes()


@unittest.skipIf(TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues")
class DistAutogradJitTestWithSpawn(MultiProcessTestCase, DistAutogradJitTest):
    def setUp(self):
        super(DistAutogradJitTestWithSpawn, self).setUp()
        self._spawn_processes()


if __name__ == '__main__':
    run_tests()


@ -143,6 +143,7 @@ libtorch_sources = [
"torch/csrc/jit/register_prim_ops.cpp",
"torch/csrc/jit/register_string_ops.cpp",
"torch/csrc/jit/register_special_ops.cpp",
"torch/csrc/jit/register_distributed_ops.cpp",
"torch/csrc/jit/scope.cpp",
"torch/csrc/jit/script/compiler.cpp",
"torch/csrc/jit/script/edit_distance.cpp",


@ -0,0 +1,40 @@
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>

using at::Scalar;
using at::Tensor;
namespace dist_autograd = torch::distributed::autograd;

namespace torch {
namespace jit {
namespace {

at::Tensor toOptionalTensor(const c10::IValue& v) {
  if (v.isNone()) {
    return at::Tensor();
  }
  return v.toTensor();
}

at::Tensor optional_to_tensor(c10::optional<at::Tensor> v) {
  return v.has_value() ? *v : at::Tensor();
}

auto reg_distributed_ops =
    torch::RegisterOperators()
        .op("aten::get_gradients(int context_id) -> Dict(Tensor, Tensor)",
            torch::RegisterOperators::options()
                .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)
                .catchAllKernel([](int64_t context_id) {
                  const auto& autogradContext =
                      dist_autograd::DistAutogradContainer::getInstance()
                          .retrieveContext(context_id);
                  return autogradContext->getGradients();
                }));

} // namespace
} // namespace jit
} // namespace torch
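
The kernel looks up the autograd context by id in the DistAutogradContainer singleton and returns its accumulated gradients as a Dict(Tensor, Tensor). As a hedged illustration (not part of this PR), operators registered through torch::RegisterOperators with an aten:: schema are usually also reachable from Python via the torch.ops namespace, assuming RPC has been initialized and a distributed backward pass has run:

import torch
import torch.distributed.autograd as dist_autograd

with dist_autograd.context() as context_id:
    # ... run a distributed backward pass here ...
    # Same operator the TorchScript graph calls; returns Dict[Tensor, Tensor].
    grads = torch.ops.aten.get_gradients(context_id)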


@ -2023,17 +2023,23 @@ def _get_builtin_table():
        for name in dir(mod):
            v = getattr(mod, name)
            if callable(v):
                _builtin_table[id(v)] = "aten::" + name
                _builtin_ops.append((v, "aten::" + name))
    for mod in _modules_containing_builtins:
        register_all(mod)

    if not PY2:
        _builtin_ops.append((math.gcd, "aten::gcd"))
        _builtin_ops.append((math.isfinite, "aten::isfinite"))
    if PY37:
        _builtin_ops.append((math.remainder, "aten::mathremainder"))

    import torch.distributed.autograd as dist_autograd
    if dist_autograd.is_available():
        _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients"))

    # populate the _builtin_table from _builtin_ops
    for builtin, aten_op in _builtin_ops:
        _builtin_table[id(builtin)] = aten_op

    if not PY2:
        _builtin_table[id(math.gcd)] = "aten::gcd"
        _builtin_table[id(math.isfinite)] = "aten::isfinite"
    if PY37:
        _builtin_table[id(math.remainder)] = "aten::mathremainder"
    return _builtin_table
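
With this mapping in place, `torch.jit.script` resolves the Python callable `dist_autograd.get_gradients` to the registered `aten::get_gradients` operator through the builtin table. A rough way to inspect the mapping, assuming `_get_builtin_table` lives in `torch.jit` as in this version of the codebase:

import torch
import torch.distributed.autograd as dist_autograd

if dist_autograd.is_available():
    table = torch.jit._get_builtin_table()
    # Expected to print "aten::get_gradients" once this change is in.
    print(table.get(id(dist_autograd.get_gradients)))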