make autocast cache global instead of thread-local (#86492)

Summary:

There is a memory leak: `torch.clear_autocast_cache()` clears the
autocast cache only on the thread that calls it (in practice, the main
thread), but autograd can write to this cache from a background thread,
so whatever autograd writes from that thread is never cleared and leaks.

After some offline discussion we decided that a single global cache,
protected by a mutex, is a practical way to deal with this; the
performance impact of taking the lock should be negligible.
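
To illustrate the pattern (a hypothetical sketch in the spirit of the new
`TestAutocastGPU` test added below; names are illustrative and a CUDA device
is assumed): a custom autograd `Function` re-enters autocast in its backward,
so the fp32 leaf `weight` goes through `cached_cast` on the autograd engine's
worker thread rather than the main thread.

```
import torch

class LinearWithAutocastInBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w_t):
        ctx.save_for_backward(x, w_t)
        return torch.nn.functional.linear(x, w_t)

    @staticmethod
    def backward(ctx, grad_output):
        x, w_t = ctx.saved_tensors
        # Re-entering autocast here means the matmuls below go through
        # cached_cast on the autograd engine's worker thread, which caches
        # an fp16 copy of the fp32 leaf `w_t`.
        with torch.autocast(device_type="cuda"):
            grad_x = torch.matmul(grad_output, w_t)
            grad_w = torch.matmul(x.t(), grad_output).t()
        return grad_x, grad_w

data = torch.randn(2, 3, device="cuda")
weight = torch.nn.Parameter(torch.randn(4, 3, device="cuda"))

with torch.autocast(device_type="cuda"):
    out = LinearWithAutocastInBackward.apply(data, weight)
    out.sum().backward()
# With a global cache, the forward (main thread) and backward (worker
# thread) share the fp16 copy of `weight`, so it is cast only once, and
# torch.clear_autocast_cache() drops entries regardless of which thread
# inserted them.
```

The fix itself is in the C++ diff below: `cached_casts` loses its
`thread_local` qualifier, and both the lookup in `cached_cast` and
`clear_cache` now take `cached_casts_mutex`.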

Test Plan:

I don't have a local repro of the original issue yet; I still need to look
into how to get one.

A toy example
(https://gist.github.com/vkuzo/0d6318fe7f7cb1c505e370cd5c1a643b)
clears the cache as expected on the forward and backward passes.

local testing:
```
python test/test_cuda.py -k autocast
python test/test_autocast.py
```
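
Assuming the new test lands in `test/test_autocast.py` (the second file in the
diff below), it can also be targeted directly with
`python test/test_autocast.py -k TestAutocastGPU`.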

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86492
Approved by: https://github.com/ezyang
Author: vasiliy
Date: 2022-10-28 12:15:49 -07:00
Committed by: PyTorch MergeBot
Parent: 34f523b221
Commit: 75dbe37909
2 changed files with 66 additions and 1 deletion

@@ -9,6 +9,7 @@
#include <iostream>
#include <exception>
#include <mutex>
namespace at {
namespace autocast {
@@ -64,7 +65,8 @@ namespace {
// directly against incoming TensorImpl*s.
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
using val_type = std::tuple<weakref_type, Tensor>;
thread_local std::unordered_map<TensorImpl*, val_type> cached_casts;
std::unordered_map<TensorImpl*, val_type> cached_casts;
std::mutex cached_casts_mutex;
// nesting tracks the nesting depth of the Python-side context manager.
// When the autocast context manager exits to a nesting level that's outside
@@ -89,6 +91,7 @@ thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;
}

void clear_cache() {
  const std::lock_guard<std::mutex> lock(cached_casts_mutex);
  cached_casts.clear();
}
@@ -155,6 +158,7 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
                          arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                          arg.is_leaf() && !arg.is_view() && cache_enabled);
    if (can_try_cache) {
      const std::lock_guard<std::mutex> lock(cached_casts_mutex);
      auto it = cached_casts.find(arg.unsafeGetTensorImpl());
      if (it != cached_casts.end()) {
        return std::get<1>(it->second);

@@ -1,9 +1,12 @@
# Owner(s): ["module: unknown"]
import collections
import unittest
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
from torch.utils._python_dispatch import TorchDispatchMode

class TestAutocastCPU(TestCase):
    def setUp(self):
@@ -122,6 +125,64 @@ class TestAutocastCPU(TestCase):
        for op, args in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args, torch.float32)

@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestAutocastGPU(TestCase):
    def test_cast_cache_is_global(self):
        """
        Verifies that the autocast cache is global. This is done by
        mocking out cache clearing at the end of the forward pass,
        running forward+backward with an explicit call to autocast in the
        backward, and verifying that the weight only gets cast to float16 once.
        """

        class CustomLinear(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, w_t):
                ctx.save_for_backward(x, w_t)
                return torch.nn.functional.linear(x, w_t)

            @staticmethod
            def backward(ctx, grad_output):
                x, w_t = ctx.saved_tensors
                with torch.autocast(device_type='cuda'):
                    dL_dX = torch.matmul(grad_output, w_t)
                    dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
                return dL_dX, dL_dW

        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())
        weight_dtype_cast_counter = 0

        class WeightDTypeCastCounterMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if (
                    func is torch.ops.aten._to_copy.default and
                    args[0] is weight and
                    kwargs['dtype'] is torch.float16
                ):
                    nonlocal weight_dtype_cast_counter
                    weight_dtype_cast_counter += 1
                return func(*args, **kwargs)

            def __enter__(self):
                self.old_clear_cache = torch.clear_autocast_cache
                torch.clear_autocast_cache = lambda: None
                return super().__enter__()

            def __exit__(self, exc_type, exc_val, exc_tb):
                torch.clear_autocast_cache = self.old_clear_cache
                return super().__exit__(exc_type, exc_val, exc_tb)

        with WeightDTypeCastCounterMode():
            with torch.autocast(device_type='cuda'):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()

        self.assertEqual(weight_dtype_cast_counter, 1)


class TestTorchAutocast(TestCase):
    def test_autocast_fast_dtype(self):
        gpu_fast_dtype = torch.get_autocast_gpu_dtype()