Add torch.utils.deterministic.fill_uninitialized_memory flag (#111377)

Part of #109802

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111377
Approved by: https://github.com/albanD, https://github.com/aaronenyeshi
This commit is contained in:
Kurt Mohler
2023-11-01 16:10:09 +00:00
committed by PyTorch MergeBot
parent cce5016653
commit fd209543d5
21 changed files with 193 additions and 46 deletions

View File

@ -1458,21 +1458,25 @@ def setLinalgBackendsToDefaultFinally(fn):
# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
def __init__(self, deterministic, *, warn_only=False):
def __init__(self, deterministic, *, warn_only=False, fill_uninitialized_memory=True):
self.deterministic = deterministic
self.warn_only = warn_only
self.fill_uninitialized_memory = fill_uninitialized_memory
def __enter__(self):
self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory
torch.use_deterministic_algorithms(
self.deterministic,
warn_only=self.warn_only)
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory
def __exit__(self, exception_type, exception_value, traceback):
torch.use_deterministic_algorithms(
self.deterministic_restore,
warn_only=self.warn_only_restore)
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore
class AlwaysWarnTypedStorageRemoval:
def __init__(self, always_warn):