Fixes #158232

The autocast caching heuristic in `aten/src/ATen/autocast_mode.cpp:139` did not account for gradient-mode state when deciding whether to cache a casted tensor, so an entry cached under `no_grad()` could later be served while gradients were enabled. FSDP2 surfaced the symptom but is not directly related to the root cause.

~~This PR adds a `GradMode::is_enabled()` check to the caching condition, disabling caching in `no_grad()` contexts so tensors with incorrect gradient state are never stored. This ensures correctness at the cost of forgoing the cache.~~

This PR instead proposes separate caches for gradient-enabled and gradient-disabled modes, and adds tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165068
Approved by: https://github.com/ngimel, https://github.com/janeyx99
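
A minimal sketch of the two-cache idea, for illustration only: the type and function names here (`CachedCast`, `SourceKey`, `grad_mode_is_enabled`) are hypothetical stand-ins, and the real cache in `autocast_mode.cpp` (which calls `at::GradMode::is_enabled()` and keys on the source tensor) is laid out differently.

```cpp
#include <unordered_map>

struct CachedCast {};           // stand-in for a cached low-precision copy
using SourceKey = const void*;  // stand-in for the real per-tensor cache key

// Stub for at::GradMode::is_enabled(); PyTorch reads thread-local state here.
static bool grad_mode_enabled = true;
bool grad_mode_is_enabled() { return grad_mode_enabled; }

// One cache per gradient-mode state, so an entry created under no_grad()
// can never be returned while gradients are enabled, and vice versa.
static std::unordered_map<SourceKey, CachedCast> cache_grad_enabled;
static std::unordered_map<SourceKey, CachedCast> cache_grad_disabled;

// Lookups and insertions both go through the cache matching the current
// gradient mode, which keeps the two states fully isolated.
std::unordered_map<SourceKey, CachedCast>& current_cache() {
  return grad_mode_is_enabled() ? cache_grad_enabled : cache_grad_disabled;
}
```

Compared with the struck-through approach above, this keeps caching enabled inside `no_grad()` regions instead of bypassing the cache there, at the cost of potentially holding two casted copies of the same tensor.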