From 726fc366a2880b8d4085d02f9a84390195e3b60e Mon Sep 17 00:00:00 2001 From: Yuxin Wu Date: Wed, 22 Mar 2023 19:09:04 +0000 Subject: [PATCH] Add missing __main__ in two unittests (#97302) Pull Request resolved: https://github.com/pytorch/pytorch/pull/97302 Approved by: https://github.com/zou3519 --- test/test_comparison_utils.py | 6 +++++- test/test_pruning_op.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/test/test_comparison_utils.py b/test/test_comparison_utils.py index fccc217bb7b2..172e2c409293 100644 --- a/test/test_comparison_utils.py +++ b/test/test_comparison_utils.py @@ -2,7 +2,7 @@ # Owner(s): ["module: internals"] import torch -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import TestCase, run_tests class TestComparisonUtils(TestCase): def test_all_equal_no_assert(self): @@ -30,3 +30,7 @@ class TestComparisonUtils(TestCase): with self.assertRaises(RuntimeError): torch._assert_tensor_metadata(t, [3], [1], torch.float) + + +if __name__ == '__main__': + run_tests() diff --git a/test/test_pruning_op.py b/test/test_pruning_op.py index 88e5a4e57be5..ef28381c1904 100644 --- a/test/test_pruning_op.py +++ b/test/test_pruning_op.py @@ -4,7 +4,7 @@ import hypothesis.strategies as st from hypothesis import given import numpy as np import torch -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import TestCase, run_tests import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() @@ -76,3 +76,7 @@ class PruningOpTest(TestCase): ) def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype): self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype) + + +if __name__ == '__main__': + run_tests()