Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Add __main__ guards to tests (#154716)
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs. Add missing `if __name__ == "__main__":` guards to some tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154716
Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: ca0c2985d3
Commit: cf9cad31df
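For reference, a minimal sketch of the guard pattern this commit adds; the test class and file name below are hypothetical and not part of the diff:

    # hypothetical_test_example.py
    import torch
    from torch.testing._internal.common_utils import run_tests, TestCase


    class ExampleTest(TestCase):
        def test_add(self):
            self.assertEqual(torch.tensor([1]) + torch.tensor([1]), torch.tensor([2]))


    if __name__ == "__main__":
        # Without this guard, running the file directly imports the module
        # and exits without executing any tests.
        run_tests()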
@@ -16,7 +16,7 @@ from torch._dynamo.testing import same
 from torch._inductor import config
 from torch._inductor.test_case import TestCase
 from torch.testing import FileCheck
-from torch.testing._internal.common_utils import IS_FBCODE
+from torch.testing._internal.common_utils import IS_FBCODE, run_tests
 from torch.testing._internal.inductor_utils import clone_preserve_strides_offset
 from torch.utils import _pytree as pytree

@@ -205,11 +205,14 @@ def check_model(
     atol=None,
     rtol=None,
 ):
-    with torch.no_grad(), config.patch(
-        {
-            "aot_inductor.allow_stack_allocation": self.allow_stack_allocation,
-            "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface,
-        }
+    with (
+        torch.no_grad(),
+        config.patch(
+            {
+                "aot_inductor.allow_stack_allocation": self.allow_stack_allocation,
+                "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface,
+            }
+        ),
     ):
         torch.manual_seed(0)
         if not isinstance(model, types.FunctionType):
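This hunk (and the two that follow) only regroups the same two context managers using the parenthesized with form. A minimal standalone illustration of the two equivalent spellings, using contextlib.nullcontext as a stand-in for the real managers:

    from contextlib import nullcontext

    # Legacy spelling: all managers on one line.
    with nullcontext(), nullcontext():
        pass

    # Parenthesized spelling (grammar support documented since Python 3.10):
    # one manager per line, trailing comma allowed.
    with (
        nullcontext(),
        nullcontext(),
    ):
        pass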
@@ -248,11 +251,14 @@ def check_model_with_multiple_inputs(
     options=None,
     dynamic_shapes=None,
 ):
-    with torch.no_grad(), config.patch(
-        {
-            "aot_inductor.allow_stack_allocation": self.allow_stack_allocation,
-            "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface,
-        }
+    with (
+        torch.no_grad(),
+        config.patch(
+            {
+                "aot_inductor.allow_stack_allocation": self.allow_stack_allocation,
+                "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface,
+            }
+        ),
     ):
         torch.manual_seed(0)
         model = model.to(self.device)
@@ -275,11 +281,14 @@ def code_check_count(
     target_str: str,
     target_count: int,
 ):
-    with torch.no_grad(), config.patch(
-        {
-            "aot_inductor.allow_stack_allocation": self.allow_stack_allocation,
-            "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface,
-        }
+    with (
+        torch.no_grad(),
+        config.patch(
+            {
+                "aot_inductor.allow_stack_allocation": self.allow_stack_allocation,
+                "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface,
+            }
+        ),
     ):
         package_path = torch._export.aot_compile(model, example_inputs)

@@ -290,3 +299,7 @@ def code_check_count(
             target_count,
             exactly=True,
         ).run(src_code)
+
+
+if __name__ == "__main__":
+    run_tests()

@ -1,8 +1,13 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch._lazy.metrics
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
|
||||
def test_metrics():
|
||||
names = torch._lazy.metrics.counter_names()
|
||||
assert len(names) == 0, f"Expected no counter names, but got {names}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@@ -206,3 +206,10 @@ class OptimizeTest(unittest.TestCase):
     test_return_multi = maketest(ModuleReturnMulti)
     test_return_dup_tensor = maketest(ModuleReturnDupTensor)
     test_inplace_update = maketest(ModuleInplaceUpdate)
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

@@ -37,3 +37,10 @@ class TestMetaKernel(TestCase):
     def test_add_invalid_device(self):
         with self.assertRaisesRegex(RuntimeError, ".*not a lazy tensor.*"):
             _ = torch.tensor([1], device="cpu") + torch.tensor([1], device="lazy")
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

@@ -10,6 +10,7 @@ import torchvision

 import torch
 from torch import nn
+from torch.testing._internal import common_utils


 def _get_test_image_tensor():
@@ -95,3 +96,7 @@ class TestQuantizedModelsONNXRuntime(onnx_test_common._TestONNXRuntime):
             pretrained=True, quantize=True
         )
         self.run_test(model, _get_test_image_tensor())
+
+
+if __name__ == "__main__":
+    common_utils.run_tests()

@@ -160,3 +160,10 @@ class TestONNXScriptExport(common_utils.TestCase):
         )
         loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue()))
         self.assertEqual(len(loop_selu_proto.functions), 1)
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

@@ -92,3 +92,10 @@ class TestBundledImages(TestCase):
         im2_tensor = torch.ops.fb.image_decode_to_NCHW(byte_tensor, weight, bias)
         self.assertEqual(raw_data.shape, im2_tensor.shape)
         self.assertEqual(raw_data, im2_tensor, atol=0.1, rtol=1e-01)
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

@@ -8,7 +8,12 @@ from unittest.mock import patch

 import torch
 import torch.hub as hub
-from torch.testing._internal.common_utils import IS_SANDCASTLE, retry, TestCase
+from torch.testing._internal.common_utils import (
+    IS_SANDCASTLE,
+    retry,
+    run_tests,
+    TestCase,
+)


 def sum_of_state_dict(state_dict):
@@ -307,3 +312,7 @@ class TestHub(TestCase):
         torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")

         self._assert_trusted_list_is_empty()
+
+
+if __name__ == "__main__":
+    run_tests()

@@ -104,7 +104,10 @@ TESTS = discover_tests(
         "distributed/test_c10d_spawn",
         "distributions/test_transforms",
         "distributions/test_utils",
         "lazy/test_meta_kernel",
+        "lazy/test_extract_compiled_graph",
+        "test/inductor/test_aot_inductor_utils",
+        "onnx/test_onnxscript_no_runtime",
         "onnx/test_pytorch_onnx_onnxruntime_cuda",
         "onnx/test_models",
         # These are not C++ tests
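For context, a rough sketch of how a blocklist like the one above could be applied during test discovery; this is an assumption for illustration only, not the actual discover_tests implementation in test/run_test.py:

    from pathlib import Path


    def discover_test_modules(test_dir, blocklisted_tests):
        # Hypothetical helper: collect "subdir/test_name" module names under
        # test_dir, skipping anything listed in the blocklist.
        modules = []
        for path in Path(test_dir).rglob("test_*.py"):
            name = path.relative_to(test_dir).with_suffix("").as_posix()
            if name not in blocklisted_tests:
                modules.append(name)
        return sorted(modules)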