mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable UFMT format on test/test_package.py test/test_per_overload_api.py (#125834)
Fixes some files in https://github.com/pytorch/pytorch/issues/123062 Run lintrunner on files: test/test_package.py test/test_per_overload_api.py ```bash $ lintrunner -a --take UFMT --all-files ok No lint issues. Successfully applied all patches. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/125834 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
ed8a560845
commit
74a0ef8f8c
@ -1114,8 +1114,6 @@ exclude_patterns = [
|
||||
'test/test_optim.py',
|
||||
'test/test_out_dtype_op.py',
|
||||
'test/test_overrides.py',
|
||||
'test/test_package.py',
|
||||
'test/test_per_overload_api.py',
|
||||
'test/test_prims.py',
|
||||
'test/test_proxy_tensor.py',
|
||||
'test/test_pruning_op.py',
|
||||
|
@ -1,25 +1,28 @@
|
||||
# Owner(s): ["oncall: package/deploy"]
|
||||
|
||||
from package.test_resources import TestResources # noqa: F401
|
||||
from package.test_model import ModelTest # noqa: F401
|
||||
from package.package_a.test_all_leaf_modules_tracer import ( # noqa: F401
|
||||
TestAllLeafModulesTracer,
|
||||
)
|
||||
from package.package_a.test_nn_module import TestNnModule # noqa: F401
|
||||
from package.test_analyze import TestAnalyze # noqa: F401
|
||||
from package.test_dependency_api import TestDependencyAPI # noqa: F401
|
||||
from package.test_dependency_hooks import TestDependencyHooks # noqa: F401
|
||||
from package.test_digraph import TestDiGraph # noqa: F401
|
||||
from package.test_directory_reader import DirectoryReaderTest # noqa: F401
|
||||
from package.test_glob_group import TestGlobGroup # noqa: F401
|
||||
from package.test_importer import TestImporter # noqa: F401
|
||||
from package.test_load_bc_packages import TestLoadBCPackages # noqa: F401
|
||||
from package.test_mangling import TestMangling # noqa: F401
|
||||
from package.test_misc import TestMisc # noqa: F401
|
||||
from package.test_directory_reader import DirectoryReaderTest # noqa: F401
|
||||
from package.test_importer import TestImporter # noqa: F401
|
||||
from package.test_glob_group import TestGlobGroup # noqa: F401
|
||||
from package.test_package_script import TestPackageScript # noqa: F401
|
||||
from package.test_save_load import TestSaveLoad # noqa: F401
|
||||
from package.test_repackage import TestRepackage # noqa: F401
|
||||
from package.test_model import ModelTest # noqa: F401
|
||||
from package.test_package_fx import TestPackageFX # noqa: F401
|
||||
from package.test_dependency_hooks import TestDependencyHooks # noqa: F401
|
||||
from package.test_load_bc_packages import TestLoadBCPackages # noqa: F401
|
||||
from package.test_analyze import TestAnalyze # noqa: F401
|
||||
from package.test_digraph import TestDiGraph # noqa: F401
|
||||
from package.package_a.test_all_leaf_modules_tracer import TestAllLeafModulesTracer # noqa: F401
|
||||
from package.package_a.test_nn_module import TestNnModule # noqa: F401
|
||||
from package.test_package_script import TestPackageScript # noqa: F401
|
||||
from package.test_repackage import TestRepackage # noqa: F401
|
||||
from package.test_resources import TestResources # noqa: F401
|
||||
from package.test_save_load import TestSaveLoad # noqa: F401
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
run_tests()
|
||||
|
@ -1,7 +1,9 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
import torch
|
||||
import copy
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class TestPerOverloadAPI(TestCase):
|
||||
def test_basics_opoverloadpacket(self):
|
||||
@ -10,8 +12,8 @@ class TestPerOverloadAPI(TestCase):
|
||||
add_packet = torch.ops.aten.add
|
||||
|
||||
# class attributes
|
||||
self.assertEqual(add_packet.__name__, 'add')
|
||||
self.assertEqual(str(add_packet), 'aten.add')
|
||||
self.assertEqual(add_packet.__name__, "add")
|
||||
self.assertEqual(str(add_packet), "aten.add")
|
||||
|
||||
# callable
|
||||
self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
@ -36,8 +38,8 @@ class TestPerOverloadAPI(TestCase):
|
||||
add_tensoroverload = add_packet.Tensor
|
||||
|
||||
# class attributes
|
||||
self.assertEqual(str(add_tensoroverload), 'aten.add.Tensor')
|
||||
self.assertEqual(add_tensoroverload.__name__, 'add.Tensor')
|
||||
self.assertEqual(str(add_tensoroverload), "aten.add.Tensor")
|
||||
self.assertEqual(add_tensoroverload.__name__, "add.Tensor")
|
||||
self.assertEqual(add_tensoroverload.overloadpacket, add_packet)
|
||||
|
||||
# deepcopy is a no-op
|
||||
@ -48,10 +50,14 @@ class TestPerOverloadAPI(TestCase):
|
||||
self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
|
||||
|
||||
# pretty print
|
||||
self.assertEqual(repr(add_tensoroverload), "<OpOverload(op='aten.add', overload='Tensor')>")
|
||||
self.assertEqual(
|
||||
repr(add_tensoroverload), "<OpOverload(op='aten.add', overload='Tensor')>"
|
||||
)
|
||||
|
||||
# callable
|
||||
self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
self.assertEqual(
|
||||
add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5)
|
||||
)
|
||||
|
||||
a = torch.tensor(2)
|
||||
b = torch.tensor(0)
|
||||
@ -65,8 +71,9 @@ class TestPerOverloadAPI(TestCase):
|
||||
y = torch.randn(5, 3)
|
||||
self.assertEqual(
|
||||
torch.ops.aten.linear.default.decompose(x, y),
|
||||
torch.ops.aten.linear.default(x, y)
|
||||
torch.ops.aten.linear.default(x, y),
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user