[dynamo] Clear GenerationTracker on dynamo reset (#125855)

Fixes https://github.com/pytorch/pytorch/issues/125567

Not doing this causes modules to be unspecialized when tests run in sequence, and specialized when run alone.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125855
Approved by: https://github.com/jansel
This commit is contained in:
Animesh Jain
2024-05-09 11:00:44 -07:00
committed by PyTorch MergeBot
parent 52fad83335
commit 477612c0f6
2 changed files with 8 additions and 0 deletions

View File

@ -29,6 +29,7 @@ from .eval_frame import (
reset_code,
)
from .external_utils import is_compiling
from .mutation_guard import GenerationTracker
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
__all__ = [
@ -82,6 +83,7 @@ def reset() -> None:
convert_frame.FRAME_COUNTER = 0
convert_frame.FRAME_COMPILE_COUNTER.clear()
callback_handler.clear()
GenerationTracker.clear()
def reset_code_caches() -> None:

View File

@ -83,6 +83,12 @@ class GenerationTracker:
and cls.generation_values[obj] == cls.generation
)
@classmethod
def clear(cls):
cls.generation = 0
cls.dynamic_classes = ExactWeakKeyDictionary()
cls.generation_values = ExactWeakKeyDictionary()
def is_dynamic_nn_module(obj, is_export):
"""Check for nn.Modules() created dynamically or mutated"""