Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Fix autocast cache incorrectly retaining no_grad state (#165068)
Fixes #158232

The autocast caching heuristic in `aten/src/ATen/autocast_mode.cpp:139` did not account for the gradient mode state when deciding whether to cache. FSDP2 is not directly related.

~~This PR adds a `GradMode::is_enabled()` check to the caching condition. Caching is now disabled in `no_grad()` contexts to prevent storing tensors with incorrect gradient state. This ensures correctness at the cost of forgoing the cache in those contexts.~~

This PR instead introduces separate caches for gradient-enabled and gradient-disabled modes, and adds regression tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165068
Approved by: https://github.com/ngimel, https://github.com/janeyx99
Committed by: PyTorch MergeBot
Parent: 6dedd34c31
Commit: 5daef30b26
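For context, a minimal sketch of the failure scenario, based on the regression test added below (the Linear shapes mirror that test; this is an illustration, not the original issue's reproducer):

import torch

model = torch.nn.Linear(2, 2)
inp = torch.randn(8, 2)

with torch.autocast("cpu", dtype=torch.bfloat16):
    # Warm-up / shape-inference pass without gradients: the bf16 cast of the
    # fp32 weight produced here used to be cached with no autograd tracking.
    with torch.no_grad():
        model(inp)

    # Training pass with gradients enabled: before the fix, this could reuse
    # the cached no-grad cast.
    out = model(inp)

# Before the fix, backward() could fail with
# "element 0 of tensors does not require grad and does not have a grad_fn".
out.mean().backward()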
aten/src/ATen/autocast_mode.cpp

@@ -2,6 +2,7 @@
 
 #include <mutex>
 #include <ATen/CachedTensorUtils.h>
+#include <c10/core/GradMode.h>
 #include <c10/util/flat_hash_map.h>
 
 namespace at::autocast {
@@ -36,10 +37,29 @@ namespace {
 using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
 using val_type = std::tuple<weakref_type, Tensor>;
 
-ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
-  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
-  return cached_casts;
+// We maintain separate caches for gradient-enabled and gradient-disabled modes.
+// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
+// are not incorrectly reused in gradient-enabled contexts.
+// This fixes issue #158232 while maintaining optimal performance for both modes.
+static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() {
+  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled;
+  return cached_casts_grad_enabled;
+}
+
+static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() {
+  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled;
+  return cached_casts_grad_disabled;
+}
+
+// Helper function to get the appropriate cache based on current gradient mode.
+// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
+// preventing incorrect cache hits when gradient mode changes.
+static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
+  return at::GradMode::is_enabled() ?
+    get_cached_casts_grad_enabled() :
+    get_cached_casts_grad_disabled();
 }
 
 std::mutex cached_casts_mutex;
@@ -86,7 +106,9 @@ thread_local bool cache_enabled = true;
 
 void clear_cache() {
   const std::lock_guard<std::mutex> lock(cached_casts_mutex);
-  get_cached_casts().clear();
+  // Clear both caches to ensure consistent behavior regardless of current gradient mode
+  get_cached_casts_grad_enabled().clear();
+  get_cached_casts_grad_disabled().clear();
 }
 
 int increment_nesting() {
@@ -121,6 +143,11 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
   if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
     // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
     // See cached_casts declaration above for detailed strategy.
+    //
+    // We maintain separate caches for gradient-enabled and gradient-disabled modes
+    // (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
+    // with torch.autocast(), while maintaining optimal performance for both training and inference.
+    // This fixes issue #158232 without any performance regression.
     bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                           arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                           arg.is_leaf() && !arg.is_view() && cache_enabled &&
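The selection logic above boils down to keying the cast cache on the current gradient mode. Below is a rough, purely illustrative Python analogue of that two-cache pattern; the names _casts_grad_enabled, _get_cached_casts, and cached_cast, and the id()-keyed dicts, are stand-ins, while the real cache is the C++ flat_hash_map keyed by TensorImpl* with weak references:

import torch

# Two independent caches, mirroring get_cached_casts_grad_enabled/_disabled above.
_casts_grad_enabled: dict[int, torch.Tensor] = {}
_casts_grad_disabled: dict[int, torch.Tensor] = {}

def _get_cached_casts() -> dict[int, torch.Tensor]:
    # Pick the cache matching the current gradient mode, like the
    # at::GradMode::is_enabled() ternary in the C++ helper.
    return _casts_grad_enabled if torch.is_grad_enabled() else _casts_grad_disabled

def cached_cast(weight: torch.Tensor, to_dtype: torch.dtype) -> torch.Tensor:
    # Illustrative stand-in for the C++ cached_cast heuristic: cast once per
    # (tensor, gradient mode) and reuse the result within that mode only.
    cache = _get_cached_casts()
    key = id(weight)  # stand-in for the TensorImpl* key used in C++
    if key not in cache:
        cache[key] = weight.to(to_dtype)
    return cache[key]

Because a no_grad() forward and a gradient-enabled forward now hit different caches, the low-precision copy created without autograd tracking can no longer shadow the one that carries a grad_fn.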
test/test_autocast.py

@@ -384,6 +384,143 @@ class TestTorchAutocast(TestCase):
         with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
             torch.autocast(device_type=dev)
 
+    @skipIfTorchDynamo()
+    def test_autocast_nograd_caching_issue_158232(self):
+        """
+        Regression test for issue #158232: autocast + no_grad incompatibility
+
+        When torch.no_grad() is nested inside torch.autocast(), the autocast cache
+        must not cache tensors created in the no_grad context, because they lack
+        gradient tracking. If cached, subsequent operations in gradient-enabled mode
+        would incorrectly use the no-gradient cached version.
+
+        Before fix: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
+        After fix: Should work correctly
+        """
+        model = torch.nn.Linear(2, 2)
+        inp = torch.randn(8, 2)
+
+        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
+            # First forward pass in no_grad context (e.g., shape inference)
+            with torch.no_grad():
+                out1 = model(inp)
+                self.assertFalse(
+                    out1.requires_grad, "Output in no_grad should not require grad"
+                )
+
+            # Second forward pass with gradients enabled (e.g., training)
+            out2 = model(inp)
+            self.assertTrue(
+                out2.requires_grad,
+                "Output should require gradients after exiting no_grad",
+            )
+            self.assertIsNotNone(
+                out2.grad_fn, "Output should have grad_fn after exiting no_grad"
+            )
+
+        # Backward pass should work
+        loss = out2.mean()
+        loss.backward()
+
+        # Verify gradients were computed
+        self.assertIsNotNone(model.weight.grad)
+        self.assertIsNotNone(model.bias.grad)
+
+    @skipIfTorchDynamo()
+    def test_autocast_inference_mode_interaction(self):
+        """
+        Test that autocast works correctly with torch.inference_mode()
+
+        InferenceMode is a stricter version of no_grad that provides additional
+        performance optimizations. Verify it doesn't break with autocast.
+        """
+        model = torch.nn.Linear(2, 2)
+        inp = torch.randn(8, 2)
+
+        # Test 1: inference_mode inside autocast
+        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
+            with torch.inference_mode():
+                out1 = model(inp)
+                self.assertFalse(out1.requires_grad)
+                self.assertEqual(out1.dtype, torch.bfloat16)
+
+            # After exiting inference_mode, gradients should work
+            out2 = model(inp)
+            self.assertTrue(out2.requires_grad)
+            out2.mean().backward()
+
+        # Test 2: autocast inside inference_mode
+        with torch.inference_mode():
+            with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
+                out = model(inp)
+                self.assertFalse(out.requires_grad)
+                self.assertEqual(out.dtype, torch.bfloat16)
+
+    def test_autocast_caching_still_works_with_gradients(self):
+        """
+        Verify that autocast caching still functions correctly when gradients ARE enabled.
+
+        This test ensures the fix for #158232 didn't break normal caching behavior.
+        We can't directly observe cache hits, but we verify that repeated operations
+        with gradients enabled work correctly.
+        """
+        model = torch.nn.Linear(2, 2)
+        inp = torch.randn(8, 2)
+
+        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
+            # Multiple forward passes with gradients enabled
+            out1 = model(inp)
+            out2 = model(inp)
+            out3 = model(inp)
+
+            # All should have gradients
+            self.assertTrue(out1.requires_grad)
+            self.assertTrue(out2.requires_grad)
+            self.assertTrue(out3.requires_grad)
+
+            # All should have grad_fn
+            self.assertIsNotNone(out1.grad_fn)
+            self.assertIsNotNone(out2.grad_fn)
+            self.assertIsNotNone(out3.grad_fn)
+
+        # Backward should work on all
+        out1.mean().backward(retain_graph=True)
+        out2.mean().backward(retain_graph=True)
+        out3.mean().backward()
+
+    @skipIfTorchDynamo()
+    def test_autocast_mixed_grad_contexts(self):
+        """
+        Test complex nesting of gradient contexts within autocast.
+
+        This ensures the gradient mode check works correctly across
+        multiple transitions between gradient-enabled and disabled states.
+        """
+        model = torch.nn.Linear(2, 2)
+        inp = torch.randn(8, 2)
+
+        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
+            # Pass 1: no_grad
+            with torch.no_grad():
+                out1 = model(inp)
+                self.assertFalse(out1.requires_grad)
+
+            # Pass 2: gradients enabled
+            out2 = model(inp)
+            self.assertTrue(out2.requires_grad)
+
+            # Pass 3: no_grad again
+            with torch.no_grad():
+                out3 = model(inp)
+                self.assertFalse(out3.requires_grad)
+
+            # Pass 4: gradients enabled again
+            out4 = model(inp)
+            self.assertTrue(out4.requires_grad)
+
+        # Backward on gradient-enabled outputs
+        (out2.mean() + out4.mean()).backward()
+
 
 if __name__ == "__main__":
     run_tests()