mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Improve FakeTensor cache to handle SymNode and tracing properly.

For now, when we're proxy tracing, just don't bother caching operations that contain SymNodes in the output. The problem is that the proxy tracer relies on SymNode identity and our cache doesn't preserve that. It can be fixed (and I left some notes in _validate_symbolic_output_for_caching() on how) but it's not worth it for now.

If we aren't proxy tracing then caching is fine.

Thus these changes:

1. Our cache key needs to include whether we were actively tracing or not - this way, if we create a cache entry when we weren't tracing and then try to use it when we ARE tracing, the op gets rerun.

2. If there's a SymNode in the output while we're proxy tracing, then bypass the cache.

3. Some general cleanup of the output validation - we were unnecessarily doing it as a two-step process when it could just be a single step (it's still two parts internally but only a single outer try/except).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164718
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #165266, #164717
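
As a rough illustration of changes 1 and 2 described above, here is a toy sketch of a dispatch cache. It is not the FakeTensor implementation: `ToyOpCache`, `SymNodeStandIn`, `BypassCache` and `validate_output_for_caching` are hypothetical names made up for this example, and the real cache key contains far more than what is shown. The sketch only assumes the behavior the message states: the tracing flag is part of the key, and a symbolic output seen while tracing is never stored.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple


class SymNodeStandIn:
    """Hypothetical placeholder for a symbolic value; NOT PyTorch's SymNode."""


class BypassCache(Exception):
    """Raised when a result must not be stored in the cache."""


def validate_output_for_caching(output: Any, tracing: bool) -> None:
    # One validation entry point: reject symbolic outputs only while tracing,
    # because a cached object would not preserve per-call identity.
    items = output if isinstance(output, (tuple, list)) else (output,)
    if tracing and any(isinstance(x, SymNodeStandIn) for x in items):
        raise BypassCache("symbolic output while tracing")


@dataclass
class ToyOpCache:
    entries: Dict[Tuple[Any, ...], Any] = field(default_factory=dict)
    hits: int = 0
    misses: int = 0
    bypasses: int = 0

    def dispatch(self, op: Callable[..., Any], args: Tuple[Any, ...], tracing: bool) -> Any:
        # Change 1: the tracing flag is part of the key, so an entry created
        # while not tracing is never reused while tracing.
        key = (op.__name__, args, tracing)
        if key in self.entries:
            self.hits += 1
            return self.entries[key]
        output = op(*args)
        # Change 2: a single outer try/except decides whether to store.
        try:
            validate_output_for_caching(output, tracing)
        except BypassCache:
            self.bypasses += 1
            return output
        self.misses += 1
        self.entries[key] = output
        return output


def add(a: int, b: int) -> int:
    return a + b


def sym_size(n: int) -> SymNodeStandIn:
    return SymNodeStandIn()


cache = ToyOpCache()
cache.dispatch(add, (1, 2), tracing=False)   # stored under the non-tracing key
cache.dispatch(add, (1, 2), tracing=True)    # different key, so the op is rerun
cache.dispatch(add, (1, 2), tracing=True)    # served from the tracing entry
s1 = cache.dispatch(sym_size, (3,), tracing=True)  # bypassed, not stored
s2 = cache.dispatch(sym_size, (3,), tracing=True)  # bypassed again, fresh object
```

In this sketch the two symbolic results obtained while tracing are distinct objects, which is the identity property the proxy tracer is said to rely on; with `tracing=False` the same call would be cached and return the stored object.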