Reland: Remove remaining global set_default_dtype calls from tests (#108088)

Fixes #68972

Relands #107246

To avoid causing Meta-internal CI failures, this PR avoids always asserting that the default dtype is float in the `TestCase.setUp/tearDown` methods. Instead, the assert is only done if `TestCase._default_dtype_check_enabled == True`. `_default_dtype_check_enabled` is set to True in the `if __name__ == "__main__":` blocks of all the relevant test files that have required changes for this issue

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108088
Approved by: https://github.com/ezyang
This commit is contained in:
Kurt Mohler
2023-09-07 03:04:34 +00:00
committed by PyTorch MergeBot
parent 54e73271c7
commit 3f88e3105f
20 changed files with 921 additions and 878 deletions

View File

@ -1,8 +1,5 @@
# Owner(s): ["module: cpp"]
import torch
# NN tests use double as the default dtype
torch.set_default_dtype(torch.double)
import os
@ -59,4 +56,5 @@ if not common.IS_ARM64:
functional_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
if __name__ == "__main__":
common.TestCase._default_dtype_check_enabled = True
common.run_tests()