Revert "158232 Fix autocast cache incorrectly retaining no_grad state (#165068)"

This reverts commit 5daef30b26b794d237fbbc399c1d47ec0380200a.

Reverted https://github.com/pytorch/pytorch/pull/165068 on behalf of https://github.com/jeffdaily due to This broke ROCm CI. test/test_transformers.py::TestTransformersCUDA::test_transformerencoder_fastpath_use_torchscript_False_enable_nested_tensor_True_use_autocast_True_d_model_256_cuda [GH job link](https://github.com/pytorch/pytorch/actions/runs/18572589089/job/52952074008) [HUD commit link](5daef30b26) ([comment](https://github.com/pytorch/pytorch/pull/165068#issuecomment-3413184445))
PyTorch MergeBot
2025-10-16 23:08:26 +00:00
parent 98a488c9aa
commit d2c82bafb7
2 changed files with 4 additions and 168 deletions

View File

@@ -2,7 +2,6 @@
 #include <mutex>
 #include <ATen/CachedTensorUtils.h>
-#include <c10/core/GradMode.h>
 #include <c10/util/flat_hash_map.h>
 namespace at::autocast {
@@ -37,29 +36,10 @@ namespace {
 using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
 using val_type = std::tuple<weakref_type, Tensor>;
-// We maintain separate caches for gradient-enabled and gradient-disabled modes.
-// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
-// are not incorrectly reused in gradient-enabled contexts.
-// This fixes issue #158232 while maintaining optimal performance for both modes.
-static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() {
-  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled;
-  return cached_casts_grad_enabled;
+ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
+  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
+  return cached_casts;
 }
-static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() {
-  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled;
-  return cached_casts_grad_disabled;
-}
-// Helper function to get the appropriate cache based on current gradient mode.
-// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
-// preventing incorrect cache hits when gradient mode changes.
-static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
-  return at::GradMode::is_enabled() ?
-    get_cached_casts_grad_enabled() :
-    get_cached_casts_grad_disabled();
-}
 std::mutex cached_casts_mutex;
@@ -106,9 +86,7 @@ thread_local bool cache_enabled = true;
 void clear_cache() {
   const std::lock_guard<std::mutex> lock(cached_casts_mutex);
-  // Clear both caches to ensure consistent behavior regardless of current gradient mode
-  get_cached_casts_grad_enabled().clear();
-  get_cached_casts_grad_disabled().clear();
+  get_cached_casts().clear();
 }
 int increment_nesting() {
@@ -143,11 +121,6 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
   if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
     // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
     // See cached_casts declaration above for detailed strategy.
-    //
-    // We maintain separate caches for gradient-enabled and gradient-disabled modes
-    // (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
-    // with torch.autocast(), while maintaining optimal performance for both training and inference.
-    // This fixes issue #158232 without any performance regression.
     bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                           arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                           arg.is_leaf() && !arg.is_view() && cache_enabled &&

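For context on the caching behavior this revert restores: the comments removed above describe tensors cast and cached under torch.no_grad() being reused once gradients are re-enabled (issue #158232). The snippet below is a schematic Python model of that single shared cache, not the ATen implementation; cached_cast, _cached_casts, and the id()-based key are illustrative stand-ins for the C++ get_cached_casts() map keyed by TensorImpl*.

    import torch

    # Single cache shared across gradient modes (stand-in for the C++ flat_hash_map).
    _cached_casts: dict[int, torch.Tensor] = {}

    def cached_cast(weight: torch.Tensor, to_type: torch.dtype = torch.bfloat16) -> torch.Tensor:
        key = id(weight)  # stand-in for the TensorImpl* key
        if key not in _cached_casts:
            # If this first runs inside torch.no_grad(), the cast is recorded
            # without requires_grad and without a grad_fn.
            _cached_casts[key] = weight.to(to_type)
        return _cached_casts[key]

    w = torch.randn(2, 2, requires_grad=True)  # fp32 leaf, like a model weight

    with torch.no_grad():
        low1 = cached_cast(w)   # cast created and cached with no autograd history

    low2 = cached_cast(w)       # grad mode is back on, but the cache hit returns the detached cast
    print(low2.requires_grad)   # False -> a later backward() through low2 would fail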
View File

@@ -384,143 +384,6 @@ class TestTorchAutocast(TestCase):
         with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
             torch.autocast(device_type=dev)
 
-    @skipIfTorchDynamo()
-    def test_autocast_nograd_caching_issue_158232(self):
-        """
-        Regression test for issue #158232: autocast + no_grad incompatibility
-
-        When torch.no_grad() is nested inside torch.autocast(), the autocast cache
-        must not cache tensors created in the no_grad context, because they lack
-        gradient tracking. If cached, subsequent operations in gradient-enabled mode
-        would incorrectly use the no-gradient cached version.
-
-        Before fix: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
-        After fix: Should work correctly
-        """
-        model = torch.nn.Linear(2, 2)
-        inp = torch.randn(8, 2)
-
-        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
-            # First forward pass in no_grad context (e.g., shape inference)
-            with torch.no_grad():
-                out1 = model(inp)
-                self.assertFalse(
-                    out1.requires_grad, "Output in no_grad should not require grad"
-                )
-
-            # Second forward pass with gradients enabled (e.g., training)
-            out2 = model(inp)
-            self.assertTrue(
-                out2.requires_grad,
-                "Output should require gradients after exiting no_grad",
-            )
-            self.assertIsNotNone(
-                out2.grad_fn, "Output should have grad_fn after exiting no_grad"
-            )
-
-        # Backward pass should work
-        loss = out2.mean()
-        loss.backward()
-
-        # Verify gradients were computed
-        self.assertIsNotNone(model.weight.grad)
-        self.assertIsNotNone(model.bias.grad)
-
-    @skipIfTorchDynamo()
-    def test_autocast_inference_mode_interaction(self):
-        """
-        Test that autocast works correctly with torch.inference_mode()
-
-        InferenceMode is a stricter version of no_grad that provides additional
-        performance optimizations. Verify it doesn't break with autocast.
-        """
-        model = torch.nn.Linear(2, 2)
-        inp = torch.randn(8, 2)
-
-        # Test 1: inference_mode inside autocast
-        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
-            with torch.inference_mode():
-                out1 = model(inp)
-                self.assertFalse(out1.requires_grad)
-                self.assertEqual(out1.dtype, torch.bfloat16)
-
-            # After exiting inference_mode, gradients should work
-            out2 = model(inp)
-            self.assertTrue(out2.requires_grad)
-            out2.mean().backward()
-
-        # Test 2: autocast inside inference_mode
-        with torch.inference_mode():
-            with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
-                out = model(inp)
-                self.assertFalse(out.requires_grad)
-                self.assertEqual(out.dtype, torch.bfloat16)
-
-    def test_autocast_caching_still_works_with_gradients(self):
-        """
-        Verify that autocast caching still functions correctly when gradients ARE enabled.
-
-        This test ensures the fix for #158232 didn't break normal caching behavior.
-        We can't directly observe cache hits, but we verify that repeated operations
-        with gradients enabled work correctly.
-        """
-        model = torch.nn.Linear(2, 2)
-        inp = torch.randn(8, 2)
-
-        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
-            # Multiple forward passes with gradients enabled
-            out1 = model(inp)
-            out2 = model(inp)
-            out3 = model(inp)
-
-        # All should have gradients
-        self.assertTrue(out1.requires_grad)
-        self.assertTrue(out2.requires_grad)
-        self.assertTrue(out3.requires_grad)
-
-        # All should have grad_fn
-        self.assertIsNotNone(out1.grad_fn)
-        self.assertIsNotNone(out2.grad_fn)
-        self.assertIsNotNone(out3.grad_fn)
-
-        # Backward should work on all
-        out1.mean().backward(retain_graph=True)
-        out2.mean().backward(retain_graph=True)
-        out3.mean().backward()
-
-    @skipIfTorchDynamo()
-    def test_autocast_mixed_grad_contexts(self):
-        """
-        Test complex nesting of gradient contexts within autocast.
-
-        This ensures the gradient mode check works correctly across
-        multiple transitions between gradient-enabled and disabled states.
-        """
-        model = torch.nn.Linear(2, 2)
-        inp = torch.randn(8, 2)
-
-        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
-            # Pass 1: no_grad
-            with torch.no_grad():
-                out1 = model(inp)
-                self.assertFalse(out1.requires_grad)
-
-            # Pass 2: gradients enabled
-            out2 = model(inp)
-            self.assertTrue(out2.requires_grad)
-
-            # Pass 3: no_grad again
-            with torch.no_grad():
-                out3 = model(inp)
-                self.assertFalse(out3.requires_grad)
-
-            # Pass 4: gradients enabled again
-            out4 = model(inp)
-            self.assertTrue(out4.requires_grad)
-
-        # Backward on gradient-enabled outputs
-        (out2.mean() + out4.mean()).backward()
 
 if __name__ == "__main__":
     run_tests()
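With the fix reverted, one user-level way to sidestep the stale cached cast when mixing torch.no_grad() and torch.autocast() is to disable the weight-cast cache via the documented cache_enabled flag (the thread-local cache_enabled seen in the C++ hunk above). A minimal sketch, assuming CPU autocast with bfloat16 as in the removed tests:

    import torch

    model = torch.nn.Linear(2, 2)
    inp = torch.randn(8, 2)

    # cache_enabled=False re-casts the fp32 weights on every forward instead of
    # reusing a cached cast, so the no_grad-time cast never leaks into the
    # gradient-enabled pass.
    with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=False):
        with torch.no_grad():
            _ = model(inp)      # e.g., shape inference without autograd tracking
        out = model(inp)        # fresh cast, carries grad_fn

    out.mean().backward()                 # succeeds
    print(model.weight.grad is not None)  # True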