Compare commits

...

5 Commits

Author SHA1 Message Date
56b43f4fec Perform appropriate CUDA stream synchronization in distributed autograd. (#53929) (#54358)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53929

The local autograd engine performs appropriate stream synchronization
between autograd nodes in the graph to ensure a consumer's stream is
synchronized with the producer's stream before executing the consumer.

However, in the case of distributed autograd, the SendRpcBackward function receives
gradients over the wire, and TensorPipe uses its own pool of streams for this
purpose. As a result, the tensors are received on TensorPipe's stream pool, but
SendRpcBackward runs on a different stream during the backward pass, and there
is no logic to synchronize these streams.

To fix this, I've enhanced DistEngine to synchronize these streams
appropriately when it receives grads over the wire.
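For reference, the synchronization this boils down to is the usual event-based pattern from the c10 stream API (see the dist_engine.cpp hunk further below for the actual change); the helper below is only an illustrative sketch, and its name and free-function form are mine, not part of the change:

```cpp
#include <c10/core/Event.h>
#include <c10/core/Stream.h>

// Illustrative sketch (not the literal DistEngine code): make `consumer`
// wait until all work already queued on `producer` has finished, without
// blocking the host or doing a full device synchronization.
void sync_consumer_with_producer(const c10::Stream& producer,
                                 const c10::Stream& consumer) {
  if (producer == consumer) {
    return;  // work on the same stream is already ordered
  }
  c10::Event event{c10::DeviceType::CUDA};
  event.record(producer);  // mark the point the producer stream has reached
  consumer.wait(event);    // consumer stream waits on-device for that point
}
```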
ghstack-source-id: 124055277

(Note: this ignores all push blocking failures!)

Test Plan:
1) Added a unit test which reproduced the issue.
2) waitforbuildbot.

Reviewed By: walterddr, wanchaol

Differential Revision: D27025307

fbshipit-source-id: 2944854e688e001cb3989d2741727b30d9278414

Co-authored-by: Pritam Damania <pritam.damania@fb.com>
2021-03-23 19:28:21 -07:00
6c394614f0 [CI] Install compatible cmath for Win builds (#54556)
* [CI] Install older cmath during Windows build (#54431)

Summary:
Based on peterjc123's analysis, the `cmath` shipped after 26bbe2ad50 (diff-3fa97ceb95d524432661f01d4b34509c6d261a2f7f45ddcf26f79f55b3eec88a) causes a lot of CUDA code to fail to compile with:
```
error: calling a __host__ function("__copysignf") from a __host__ __device__ function("c10::guts::detail::apply_impl< ::at::native::AUnaryFunctor< ::>  &,     ::std::tuple<float >  &, (unsigned long long)0ull > ") is not allowed
```
Workaround for https://github.com/pytorch/pytorch/issues/54382

Pull Request resolved: https://github.com/pytorch/pytorch/pull/54431

Reviewed By: anjali411

Differential Revision: D27234299

Pulled By: malfet

fbshipit-source-id: b3f1fef941341222cc10cb27346fcf4a1d522a0c

* [CI] Install compatible cmath for Win binary builds (#54527)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54527

Reviewed By: walterddr

Differential Revision: D27269528

Pulled By: malfet

fbshipit-source-id: 4afdc706598f3a6ad296468dfb77a70433ae7d0f
2021-03-23 19:02:01 -07:00
7c3c293ea7 [1.8] Don't build TensorPipe CMA backend on old glibc versions (#54491)
Some users who are building from source on old glibc versions are hitting an issue where TensorPipe uses the process_vm_readv syscall, which is not wrapped by their glibc. This PR checks for that condition in CMake and disables that backend in such cases.

This should have no effect on PyTorch's official builds; it should just help people who are building from source.
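As a rough illustration, the CMake check can be thought of as a small compile/link probe along these lines (a sketch under my own assumptions, not the literal check added by this PR): if the toolchain's glibc does not provide a wrapper for `process_vm_readv`, the probe fails and the CMA backend stays disabled.

```cpp
// Hypothetical probe, not the PR's actual CMake snippet: builds and links
// only when glibc exposes a wrapper for process_vm_readv.
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <sys/uio.h>

int main() {
  struct iovec local{};
  struct iovec remote{};
  // With an old glibc this call has no wrapper, so the probe fails to link.
  ssize_t n = process_vm_readv(/*pid=*/0, &local, 1, &remote, 1, /*flags=*/0);
  return n < 0 ? 1 : 0;
}
```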
2021-03-23 15:56:26 -07:00
9d43171746 [1.8.1] Replace thrust with cub in randperm (#54537)
Summary:
Benchmark of
```python
%timeit torch.randperm(100000, device='cuda'); torch.cuda.synchronize()
```
thrust:
```
5.76 ms ± 42.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
cub:
```
3.02 ms ± 32.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

The sync in the thrust sort is also removed.

Warning:
Thrust supports 64-bit indexing, but cub doesn't, so this is a functional regression. However, `torch.randperm(2**31, device='cuda')` fails with OOM on a 40GB A100, and `torch.randperm(2**32, device='cuda')` fails with OOM on an 80GB A100, so I think this functional regression has low impact and is acceptable.
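For reference, `cub::DeviceRadixSort::SortPairs` is a two-pass API: the first call, with a null temporary buffer, only computes `temp_storage_bytes`, and the second call performs the sort. A minimal standalone sketch of that pattern (using plain `cudaMalloc` here instead of the PyTorch caching allocator used in the diff below):

```cpp
#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Standalone sketch of the two-phase cub radix-sort call used by the change;
// all pointers are device pointers and `stream` is the stream to sort on.
void sort_pairs_on_stream(const int* d_keys_in, int* d_keys_out,
                          const int* d_values_in, int* d_values_out,
                          int n, cudaStream_t stream) {
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  // Pass 1: d_temp_storage == nullptr, only fills in temp_storage_bytes.
  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                  d_keys_in, d_keys_out,
                                  d_values_in, d_values_out,
                                  n, 0, sizeof(int) * 8, stream);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  // Pass 2: the actual radix sort, enqueued asynchronously on `stream`.
  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                  d_keys_in, d_keys_out,
                                  d_values_in, d_values_out,
                                  n, 0, sizeof(int) * 8, stream);
  cudaFree(d_temp_storage);
}
```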

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53841

Reviewed By: albanD

Differential Revision: D26993453

Pulled By: ngimel

fbshipit-source-id: 39dd128559d53dbb01cab1585e5462cb5f3cceca

Co-authored-by: Xiang Gao <qasdfgtyuiop@gmail.com>
2021-03-23 15:45:20 -07:00
f3c950e04e various doc building cleanups (#54141) 2021-03-23 11:23:02 -07:00
17 changed files with 154 additions and 61 deletions

View File

@ -697,6 +697,11 @@ jobs:
executor: <<parameters.executor>>
steps:
- checkout
- run:
name: _HACK_ Install CUDA compatible cmath
no_output_timeout: 1m
command: |
powershell .circleci/scripts/vs_install_cmath.ps1
- run:
name: Install Cuda
no_output_timeout: 30m
@ -1084,6 +1089,11 @@ jobs:
steps:
# See Note [Workspace for CircleCI scripts] in job-specs-setup.yml
- checkout
- run:
name: _HACK_ Install CUDA compatible cmath
no_output_timeout: 1m
command: |
powershell .circleci/scripts/vs_install_cmath.ps1
- run:
<<: *binary_checkout
- run:

View File

@ -111,14 +111,6 @@ popd
git rm -rf "$install_path" || true
mv "$pt_checkout/docs/build/html" "$install_path"
# Add the version handler by search and replace.
# XXX: Consider moving this to the docs Makefile or site build
if [ "$is_master_doc" = true ]; then
find "$install_path" -name "*.html" -print0 | xargs -0 perl -pi -w -e "s@master\s+\((\d\.\d\.[A-Fa-f0-9]+\+[A-Fa-f0-9]+)\s+\)@<a href='http://pytorch.org/docs/versions.html'>\1 \&#x25BC</a>@g"
else
find "$install_path" -name "*.html" -print0 | xargs -0 perl -pi -w -e "s@master\s+\((\d\.\d\.[A-Fa-f0-9]+\+[A-Fa-f0-9]+)\s+\)@<a href='http://pytorch.org/docs/versions.html'>$version \&#x25BC</a>@g"
fi
# Prevent Google from indexing $install_path/_modules. This folder contains
# generated source files.
# NB: the following only works on gnu sed. The sed shipped with mac os is different.

View File

@ -0,0 +1,5 @@
$CMATH_DOWNLOAD_LINK = "https://raw.githubusercontent.com/microsoft/STL/12c684bba78f9b032050526abdebf14f58ca26a3/stl/inc/cmath"
$VC14_28_INSTALL_PATH="C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.28.29910\include"
curl.exe --retry 3 -kL $CMATH_DOWNLOAD_LINK --output "$home\cmath"
Move-Item -Path "$home\cmath" -Destination "$VC14_28_INSTALL_PATH" -Force

View File

@ -293,6 +293,11 @@
steps:
# See Note [Workspace for CircleCI scripts] in job-specs-setup.yml
- checkout
- run:
name: _HACK_ Install CUDA compatible cmath
no_output_timeout: 1m
command: |
powershell .circleci/scripts/vs_install_cmath.ps1
- run:
<<: *binary_checkout
- run:

View File

@ -256,6 +256,11 @@ jobs:
executor: <<parameters.executor>>
steps:
- checkout
- run:
name: _HACK_ Install CUDA compatible cmath
no_output_timeout: 1m
command: |
powershell .circleci/scripts/vs_install_cmath.ps1
- run:
name: Install Cuda
no_output_timeout: 30m

View File

@ -13,10 +13,12 @@
#include <thrust/sort.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <cub/cub.cuh>
#include <algorithm>
#include <cstddef>
#include <cmath>
#include <limits>
namespace at {
namespace native {
@ -102,28 +104,40 @@ Tensor& randperm_out_cuda(Tensor& result, int64_t n, c10::optional<Generator> ge
// Generate random values for the keys array
AT_DISPATCH_ALL_TYPES(
result.scalar_type(), "randperm_out_cuda", [&] {
TORCH_CHECK(n <= std::numeric_limits<int>::max(),
"randperm of tensors larger than INT_MAX is not supported yet in pytorch");
auto keys = at::empty(result.sizes(), result.options()).random_(generator);
auto keys_data = thrust::device_ptr<scalar_t>(keys.data_ptr<scalar_t>());
auto range = at::arange(n, result.options());
auto keys_tmp = at::empty_like(keys);
// shuffled_data points to the underlying data of the output tensor if the tensor is contiguous; otherwise it
// points to a new tensor.
Tensor shuffled;
thrust::device_ptr<scalar_t> shuffled_data;
scalar_t *shuffled_data;
if (result.is_contiguous()) {
shuffled_data = thrust::device_ptr<scalar_t>(result.data_ptr<scalar_t>());
shuffled_data = result.data_ptr<scalar_t>();
} else {
shuffled = at::empty(n, result.options());
shuffled_data = thrust::device_ptr<scalar_t>(shuffled.data_ptr<scalar_t>());
shuffled_data = shuffled.data_ptr<scalar_t>();
}
auto state = globalContext().getTHCState();
THCThrustAllocator thrustAlloc(state);
auto policy = thrust::cuda::par(thrustAlloc).on(at::cuda::getCurrentCUDAStream());
thrust::sequence(policy, shuffled_data, shuffled_data + n);
// Use the sorted order of keys to rearrange the result array
thrust::sort_by_key(policy, keys_data, keys_data + n, shuffled_data);
void *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs(
nullptr, temp_storage_bytes,
keys.data_ptr<scalar_t>(), keys_tmp.data_ptr<scalar_t>(),
range.data_ptr<scalar_t>(), shuffled_data, n,
0, sizeof(scalar_t) * 8, at::cuda::getCurrentCUDAStream());
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(temp_storage_bytes);
cub::DeviceRadixSort::SortPairs(
dataPtr.get(), temp_storage_bytes,
keys.data_ptr<scalar_t>(), keys_tmp.data_ptr<scalar_t>(),
range.data_ptr<scalar_t>(), shuffled_data, n,
0, sizeof(scalar_t) * 8, at::cuda::getCurrentCUDAStream());
if (!result.is_contiguous()) {
result.copy_(shuffled);

View File

@ -1,34 +0,0 @@
{% extends "!layout.html" %}
<link rel="canonical" href="{{ theme_canonical_url }}{{ pagename }}.html" />
{% block menu %}
{{ super() }}
{% endblock %}
{% block footer %}
{{ super() }}
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-90545585-1', 'auto');
ga('send', 'pageview');
</script>
<script async src="https://www.googletagmanager.com/gtag/js?id=UA-117752657-2"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'UA-117752657-2');
</script>
<img height="1" width="1" style="border-style:none;" alt="" src="https://www.googleadservices.com/pagead/conversion/795629140/?label=txkmCPmdtosBENSssfsC&amp;guid=ON&amp;script=0"/>
{% endblock %}

View File

@ -1,18 +1,31 @@
{% extends "!layout.html" %}
<link rel="canonical" href="{{ theme_canonical_url }}{{ pagename }}.html" />
{% block menu %}
{% if release == "master" %}
<div>
<a style="color:#F05732" href="{{ theme_canonical_url }}{{ pagename }}.html">
You are viewing unstable developer preview docs.
Click here to view docs for latest stable release.
</a>
</div>
{% endif %}
{{ super() }}
{% endblock %}
{% block sidebartitle %}
<div class="version">
<a href='https://pytorch.org/docs/versions.html'>{{ version }} &#x25BC</a>
</div>
{% include "searchbox.html" %}
{% endblock %}
{% block footer %}
{{ super() }}
<script script type="text/javascript">
var collapsedSections = ['Notes', 'Language Bindings', 'Libraries', 'Community'];
</script>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),

View File

@ -4,7 +4,7 @@ Complex Numbers
===============
Complex numbers are numbers that can be expressed in the form :math:`a + bj`, where a and b are real numbers,
and *j* is a solution of the equation :math:`x^2 = 1`. Complex numbers frequently occur in mathematics and
and *j* is a solution of the equation :math:`x^2 = -1`. Complex numbers frequently occur in mathematics and
engineering, especially in signal processing. Traditionally many users and libraries (e.g., TorchAudio) have
handled complex numbers by representing the data in float tensors with shape :math:`(..., 2)` where the last
dimension contains the real and imaginary values.

View File

@ -75,8 +75,6 @@ napoleon_use_ivar = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
if RELEASE:
templates_path = ['_templates-stable'] + templates_path
# TODO: document these and remove them from here.
@ -170,6 +168,8 @@ if RELEASE:
html_title = " ".join((project, torch.__version__, "documentation"))
else:
html_title = " ".join((project, torch.__version__[:version_end], "documentation"))
version = torch.__version__
release = version
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.

View File

@ -4823,7 +4823,7 @@ If :math:`m < n`, :func:`lstsq` solves the least-norm problem:
.. math::
\begin{array}{ll}
\begin{array}{llll}
\min_X & \|X\|_2 & \text{subject to} & AX = B.
\end{array}

View File

@ -1,6 +1,7 @@
#include <queue>
#include <ATen/Parallel.h>
#include <c10/core/Event.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/distributed/autograd/context/container.h>
@ -423,8 +424,27 @@ std::shared_ptr<c10::ivalue::Future> DistEngine::
std::shared_ptr<c10::ivalue::Future> DistEngine::executeSendFunctionAsync(
const ContextPtr& autogradContext,
const std::shared_ptr<Node>& sendFunction,
const std::shared_ptr<SendRpcBackward>& sendFunction,
bool retainGraph) {
// Typically the local autograd engine ensures stream synchronizations between
// nodes in the graph. However, for distributed autograd the sendFunction
// inputs might have been retrieved over the wire on a separate stream and the
// sendFunction itself runs on a different stream. As a result, we need to
// manually synchronize those two streams here.
const auto& send_backward_stream = sendFunction->stream(c10::DeviceType::CUDA);
if (send_backward_stream) {
for (const auto& grad : sendFunction->getGrads()) {
const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
const auto default_stream = guard.getStream(grad.device());
if (send_backward_stream != default_stream) {
auto event = c10::Event{c10::DeviceType::CUDA};
event.record(default_stream);
send_backward_stream->wait(event);
}
}
}
std::unique_lock<std::mutex> lock(initializedContextIdsLock_);
if (initializedContextIds_.find(autogradContext->contextId()) ==
initializedContextIds_.end()) {

View File

@ -46,7 +46,7 @@ class TORCH_API DistEngine {
// The gradients are accumulated in the provided autograd context.
std::shared_ptr<c10::ivalue::Future> executeSendFunctionAsync(
const ContextPtr& autogradContext,
const std::shared_ptr<torch::autograd::Node>& sendFunction,
const std::shared_ptr<SendRpcBackward>& sendFunction,
bool retainGraph);
// Number of backward passes currently running for the Distributed Engine.

View File

@ -23,6 +23,10 @@ void SendRpcBackward::setGrads(const torch::autograd::variable_list& grads) {
grads_ = grads;
}
const torch::autograd::variable_list& SendRpcBackward::getGrads() const {
return grads_;
}
} // namespace autograd
} // namespace distributed
} // namespace torch

View File

@ -25,6 +25,9 @@ struct TORCH_API SendRpcBackward : public torch::autograd::Node {
// computation.
void setGrads(const torch::autograd::variable_list& grads);
// Retrieve the grads for the function.
const torch::autograd::variable_list& getGrads() const;
private:
torch::autograd::variable_list grads_;
};

View File

@ -3,6 +3,7 @@ import threading
import time
import unittest
from enum import Enum
import random
import torch
from datetime import timedelta
import torch.distributed as dist
@ -2266,3 +2267,58 @@ class TensorPipeDistAutogradTest(RpcAgentTestFixture):
self.assertEqual(t2.device, grads[t2].device)
rpc.shutdown()
class MyRemoteCompute(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
input = input * 2.0
return input
class MyLocalCompute(torch.nn.Module):
def __init__(self, next_stage):
super().__init__()
self.next_stage = next_stage
def forward(self, input):
return self.next_stage.rpc_sync().forward(input)
@skip_if_lt_x_gpu(4)
def test_dist_autograd_sync_streams(self):
options = self.rpc_backend_options
dst = worker_name((self.rank + 1) % self.world_size)
# The reverse of this device mapping should be used for the backward pass.
options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})
rpc.init_rpc(
name=worker_name(self.rank),
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=options,
)
remote_compute = rpc.remote(dst, TensorPipeDistAutogradTest.MyRemoteCompute)
local_compute = TensorPipeDistAutogradTest.MyLocalCompute(remote_compute)
for _ in range(10):
input = torch.rand([1000, 10000], device=self.rank, requires_grad=True)
# Run local autograd
result = input * 2.0
r = random.random()
loss = result.sum() * r
loss.backward()
# Run distributed autograd
with dist_autograd.context() as context_id:
result = local_compute(input)
loss = result.sum() * r
dist_autograd.backward(context_id, [loss])
# Compare grads.
grads = dist_autograd.get_gradients(context_id)
self.assertEqual(input.grad, grads[input])
rpc.shutdown()