make autocast cache global instead of thread-local (#86492)

Summary:

There is a memory leak because the autocast cache is thread-local:
`torch.clear_autocast_cache()` only clears the main thread's copy of the
cache, but autograd can write to the cache from a background thread, so
whatever autograd writes there is never cleared and leaks.

After some offline discussion we decided that a single global cache
protected by a lock is a practical way to deal with this, and the
performance impact of the lock should be negligible.
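To make the intent concrete, here is a minimal Python sketch of the pattern this change moves to: one process-wide cache guarded by a lock, so that a clear issued from any thread removes entries written by every thread. The actual change lives in the ATen autocast code (C++); the names below (`_cache_lock`, `_autocast_cache`, `cached_cast`, `clear_cache`) are invented for illustration only.

```python
import threading

# One cache shared by all threads, guarded by a lock (instead of one cache
# per thread). Names and structure are illustrative, not the real ATen code.
_cache_lock = threading.Lock()
_autocast_cache = {}  # (target dtype, id of source tensor) -> cached cast

def cached_cast(tensor, dtype):
    # Return a cached low-precision copy of `tensor`, casting at most once.
    key = (dtype, id(tensor))
    with _cache_lock:
        cached = _autocast_cache.get(key)
    if cached is not None:
        return cached
    casted = tensor.to(dtype)  # do the actual cast outside the lock
    with _cache_lock:
        return _autocast_cache.setdefault(key, casted)

def clear_cache():
    # Clears entries written by every thread, not just the caller's.
    with _cache_lock:
        _autocast_cache.clear()
```

With a thread-local cache, `clear_cache()` called from the main thread would leave the autograd thread's entries alive; with a global cache like the sketch above, the same call drops them.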

Test Plan:

I don't have a local repro of the original issue yet; I need to look into
how to get one.

A toy example
(https://gist.github.com/vkuzo/0d6318fe7f7cb1c505e370cd5c1a643b)
performs cache clearing as expected on the forward and backward passes.
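Roughly, such a toy example (the sketch below is an approximation of that kind of script, not the contents of the linked gist) runs a forward and backward pass under autocast and relies on the cache being cleared when the outermost autocast region exits:

```python
import torch

# Sketch of a toy forward+backward run under autocast (an approximation of
# the kind of example described above, not the linked gist). Requires CUDA.
if torch.cuda.is_available():
    model = torch.nn.Linear(8, 4).cuda()
    x = torch.randn(2, 8, device="cuda")

    with torch.autocast(device_type="cuda"):
        out = model(x)
        loss = out.float().sum()
    # Exiting the outermost autocast region calls torch.clear_autocast_cache(),
    # so the cached fp16 copy of the weight does not outlive the region.
    loss.backward()
```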

local testing:
```
python test/test_cuda.py -k autocast
python test/test_autocast.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86492
Approved by: https://github.com/ezyang
vasiliy
2022-10-28 12:15:49 -07:00
committed by PyTorch MergeBot
parent 34f523b221
commit 75dbe37909
2 changed files with 66 additions and 1 deletion

@@ -1,9 +1,12 @@
# Owner(s): ["module: unknown"]

import collections
import unittest

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
from torch.utils._python_dispatch import TorchDispatchMode


class TestAutocastCPU(TestCase):
    def setUp(self):
@@ -122,6 +125,64 @@ class TestAutocastCPU(TestCase):
        for op, args in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args, torch.float32)


@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestAutocastGPU(TestCase):
    def test_cast_cache_is_global(self):
        """
        Verifies that the autocast cache is global. This is done by
        mocking out cache clearing at the end of the forward pass,
        running forward+backward with an explicit call to autocast in the
        backward, and verifying that the weight only gets cast to float16 once.
        """

        class CustomLinear(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, w_t):
                ctx.save_for_backward(x, w_t)
                return torch.nn.functional.linear(x, w_t)

            @staticmethod
            def backward(ctx, grad_output):
                x, w_t = ctx.saved_tensors
                with torch.autocast(device_type='cuda'):
                    dL_dX = torch.matmul(grad_output, w_t)
                    dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
                return dL_dX, dL_dW

        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        weight_dtype_cast_counter = 0

        class WeightDTypeCastCounterMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if (
                    func is torch.ops.aten._to_copy.default and
                    args[0] is weight and
                    kwargs['dtype'] is torch.float16
                ):
                    nonlocal weight_dtype_cast_counter
                    weight_dtype_cast_counter += 1
                return func(*args, **kwargs)

            def __enter__(self):
                self.old_clear_cache = torch.clear_autocast_cache
                torch.clear_autocast_cache = lambda: None
                return super().__enter__()

            def __exit__(self, exc_type, exc_val, exc_tb):
                torch.clear_autocast_cache = self.old_clear_cache
                return super().__exit__(exc_type, exc_val, exc_tb)

        with WeightDTypeCastCounterMode():
            with torch.autocast(device_type='cuda'):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()

        self.assertEqual(weight_dtype_cast_counter, 1)


class TestTorchAutocast(TestCase):
    def test_autocast_fast_dtype(self):
        gpu_fast_dtype = torch.get_autocast_gpu_dtype()