Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
make autocast cache global instead of thread-local (#86492)
Summary: There is a memory leak because `torch.clear_autocast_cache()` clears the autocast cache from the main thread, but autograd can write to this cache from a background thread, so whatever autograd writes from that thread leaks. After some offline discussion we decided that making the cache global is a practical way to deal with this, and the performance impact of the added lock should be negligible.

Test Plan: I don't have a local repro of the original issue; I still need to look into how to get one. A toy example (https://gist.github.com/vkuzo/0d6318fe7f7cb1c505e370cd5c1a643b) does cache clearing as expected on the forward and backward passes.

Local testing:

```
python test/test_cuda.py -k autocast
python test/test_autocast.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86492
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 34f523b221
Commit: 75dbe37909
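For context, the cache at issue holds lower-precision copies of fp32 leaf tensors (typically model weights) so that each weight is cast at most once per autocast region. A hedged Python-side sketch of that behaviour (the module and tensor names below are illustrative, not taken from this PR):

```python
import torch

# Illustrative fp32 module and input; any fp32 leaf parameter behaves the same way.
lin = torch.nn.Linear(16, 16).cuda()
x = torch.randn(8, 16, device="cuda")

with torch.autocast(device_type="cuda"):
    y1 = lin(x)  # the fp32 weight is cast to float16 and the cast copy is cached
    y2 = lin(x)  # the cached float16 copy is reused; no second cast

# Exiting the outermost autocast context calls torch.clear_autocast_cache().
# Before this PR the cache was thread_local, so entries written by autograd's
# worker threads during backward could never be cleared from the main thread.
```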
C++ changes (the autocast cache in aten/src/ATen/autocast_mode.cpp):

```diff
@@ -9,6 +9,7 @@

 #include <iostream>
 #include <exception>
+#include <mutex>

 namespace at {
 namespace autocast {
@@ -64,7 +65,8 @@ namespace {
 // directly against incoming TensorImpl*s.
 using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
 using val_type = std::tuple<weakref_type, Tensor>;
-thread_local std::unordered_map<TensorImpl*, val_type> cached_casts;
+std::unordered_map<TensorImpl*, val_type> cached_casts;
+std::mutex cached_casts_mutex;

 // nesting tracks the nesting depth of the Python-side context manager.
 // When the autocast context manager exits to a nesting level that's outside
@@ -89,6 +91,7 @@ thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;
 }

 void clear_cache() {
+  const std::lock_guard<std::mutex> lock(cached_casts_mutex);
   cached_casts.clear();
 }

@@ -155,6 +158,7 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
                          arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                          arg.is_leaf() && !arg.is_view() && cache_enabled);
     if (can_try_cache) {
+      const std::lock_guard<std::mutex> lock(cached_casts_mutex);
       auto it = cached_casts.find(arg.unsafeGetTensorImpl());
       if (it != cached_casts.end()) {
         return std::get<1>(it->second);
```
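The new mutex matters because the forward and backward passes do not necessarily run on the same thread: for CUDA work, the autograd engine executes backward on a device worker thread, which is exactly how a `thread_local` cache could accumulate entries that `torch.clear_autocast_cache()` (called from the main thread) never reached. A hedged sketch that only demonstrates the threading behaviour (not part of the PR; assumes a CUDA build and device):

```python
import threading

import torch

thread_ids = {}

class RecordThread(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Runs on whichever thread called it; here, the Python main thread.
        thread_ids["forward"] = threading.get_ident()
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        # For CUDA tensors the autograd engine dispatches this to a device
        # worker thread rather than the Python main thread.
        thread_ids["backward"] = threading.get_ident()
        return grad_output * 2

x = torch.randn(4, device="cuda", requires_grad=True)
RecordThread.apply(x).sum().backward()
print(thread_ids["forward"] == thread_ids["backward"])  # typically False
```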
Python changes (test/test_autocast.py):

```diff
@@ -1,9 +1,12 @@
 # Owner(s): ["module: unknown"]

 import collections
+import unittest
+
 import torch
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
+from torch.utils._python_dispatch import TorchDispatchMode

 class TestAutocastCPU(TestCase):
     def setUp(self):
@@ -122,6 +125,64 @@ class TestAutocastCPU(TestCase):
         for op, args in self.autocast_lists.torch_need_autocast_promote:
             self._run_autocast_outofplace(op, args, torch.float32)

+@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+class TestAutocastGPU(TestCase):
+    def test_cast_cache_is_global(self):
+        """
+        Verifies that the autocast cache is global. This is done by
+        mocking out cache clearing at the end of the forward pass,
+        running forward+backward with an explicit call to autocast in the
+        backward, and verifying that the weight only gets cast to float16 once.
+        """
+
+        class CustomLinear(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x, w_t):
+                ctx.save_for_backward(x, w_t)
+                return torch.nn.functional.linear(x, w_t)
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, w_t = ctx.saved_tensors
+                with torch.autocast(device_type='cuda'):
+                    dL_dX = torch.matmul(grad_output, w_t)
+                    dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
+                return dL_dX, dL_dW
+
+        data = torch.randn(2, 3).cuda()
+        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())
+        weight_dtype_cast_counter = 0
+
+        class WeightDTypeCastCounterMode(TorchDispatchMode):
+
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                if (
+                    func is torch.ops.aten._to_copy.default and
+                    args[0] is weight and
+                    kwargs['dtype'] is torch.float16
+                ):
+                    nonlocal weight_dtype_cast_counter
+                    weight_dtype_cast_counter += 1
+                return func(*args, **kwargs)
+
+            def __enter__(self):
+                self.old_clear_cache = torch.clear_autocast_cache
+                torch.clear_autocast_cache = lambda: None
+                return super().__enter__()
+
+            def __exit__(self, exc_type, exc_val, exc_tb):
+                torch.clear_autocast_cache = self.old_clear_cache
+                return super().__exit__(exc_type, exc_val, exc_tb)
+
+        with WeightDTypeCastCounterMode():
+            with torch.autocast(device_type='cuda'):
+                output = CustomLinear.apply(data, weight)
+                s = output.sum()
+            s.backward()
+
+        self.assertEqual(weight_dtype_cast_counter, 1)
+
+
 class TestTorchAutocast(TestCase):
     def test_autocast_fast_dtype(self):
         gpu_fast_dtype = torch.get_autocast_gpu_dtype()
```
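Beyond the Test Plan commands above, the new test can be selected directly by name (assumes a CUDA-capable machine and that the usual `-k` name filter is available through PyTorch's test runner):

```
python test/test_autocast.py -k test_cast_cache_is_global
```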