mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Consistently use load_torchbind_test_lib in tests (#148082)
The same code is repeated multiple times with slightly different implementations. Use the existing function for brevity and consistency. In the function the code from `test_export` is used which does a single `load_library` with cleaner conditions Pull Request resolved: https://github.com/pytorch/pytorch/pull/148082 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
40c2505f16
commit
302c660298
@ -76,6 +76,7 @@ from torch.testing._internal.custom_tensor import (
|
|||||||
CustomTensorPlainOut,
|
CustomTensorPlainOut,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
||||||
|
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
||||||
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu
|
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu
|
||||||
from torch.testing._internal.two_tensor import TwoTensor
|
from torch.testing._internal.two_tensor import TwoTensor
|
||||||
from torch.utils._pytree import (
|
from torch.utils._pytree import (
|
||||||
@ -12662,15 +12663,7 @@ def forward(self, x):
|
|||||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
|
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
|
||||||
class TestExportCustomClass(TorchTestCase):
|
class TestExportCustomClass(TorchTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
if IS_FBCODE:
|
load_torchbind_test_lib()
|
||||||
lib_file_path = "//caffe2/test/cpp/jit:test_custom_class_registrations"
|
|
||||||
elif IS_SANDCASTLE or IS_MACOS:
|
|
||||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
|
||||||
elif IS_WINDOWS:
|
|
||||||
lib_file_path = find_library_location("torchbind_test.dll")
|
|
||||||
else:
|
|
||||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
|
||||||
torch.ops.load_library(str(lib_file_path))
|
|
||||||
|
|
||||||
def test_lift_custom_obj(self):
|
def test_lift_custom_obj(self):
|
||||||
# TODO: fix this test once custom class tracing is implemented
|
# TODO: fix this test once custom class tracing is implemented
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
# Owner(s): ["oncall: export"]
|
# Owner(s): ["oncall: export"]
|
||||||
import unittest
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@ -18,15 +17,8 @@ from torch.export.exported_program import (
|
|||||||
TensorArgument,
|
TensorArgument,
|
||||||
)
|
)
|
||||||
from torch.export.graph_signature import CustomObjArgument
|
from torch.export.graph_signature import CustomObjArgument
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||||
find_library_location,
|
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
||||||
IS_FBCODE,
|
|
||||||
IS_MACOS,
|
|
||||||
IS_SANDCASTLE,
|
|
||||||
IS_WINDOWS,
|
|
||||||
run_tests,
|
|
||||||
TestCase,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphBuilder:
|
class GraphBuilder:
|
||||||
@ -146,18 +138,7 @@ class GraphBuilder:
|
|||||||
|
|
||||||
class TestLift(TestCase):
|
class TestLift(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
if IS_MACOS:
|
load_torchbind_test_lib()
|
||||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
|
||||||
elif IS_SANDCASTLE or IS_FBCODE:
|
|
||||||
torch.ops.load_library(
|
|
||||||
"//caffe2/test/cpp/jit:test_custom_class_registrations"
|
|
||||||
)
|
|
||||||
elif IS_WINDOWS:
|
|
||||||
lib_file_path = find_library_location("torchbind_test.dll")
|
|
||||||
torch.ops.load_library(str(lib_file_path))
|
|
||||||
else:
|
|
||||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
|
||||||
torch.ops.load_library(str(lib_file_path))
|
|
||||||
|
|
||||||
def test_lift_basic(self):
|
def test_lift_basic(self):
|
||||||
builder = GraphBuilder()
|
builder = GraphBuilder()
|
||||||
@ -379,18 +360,7 @@ class TestLift(TestCase):
|
|||||||
|
|
||||||
class ConstantAttrMapTest(TestCase):
|
class ConstantAttrMapTest(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
if IS_MACOS:
|
load_torchbind_test_lib()
|
||||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
|
||||||
elif IS_SANDCASTLE or IS_FBCODE:
|
|
||||||
torch.ops.load_library(
|
|
||||||
"//caffe2/test/cpp/jit:test_custom_class_registrations"
|
|
||||||
)
|
|
||||||
elif IS_WINDOWS:
|
|
||||||
lib_file_path = find_library_location("torchbind_test.dll")
|
|
||||||
torch.ops.load_library(str(lib_file_path))
|
|
||||||
else:
|
|
||||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
|
||||||
torch.ops.load_library(str(lib_file_path))
|
|
||||||
|
|
||||||
def test_dict_api(self):
|
def test_dict_api(self):
|
||||||
constant_attr_map = ConstantAttrMap()
|
constant_attr_map = ConstantAttrMap()
|
||||||
|
|||||||
@ -5,7 +5,6 @@ import copy
|
|||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -16,14 +15,8 @@ from torch.testing._internal.common_utils import skipIfTorchDynamo
|
|||||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
sys.path.append(pytorch_test_dir)
|
sys.path.append(pytorch_test_dir)
|
||||||
from torch.testing import FileCheck
|
from torch.testing import FileCheck
|
||||||
from torch.testing._internal.common_utils import (
|
|
||||||
find_library_location,
|
|
||||||
IS_FBCODE,
|
|
||||||
IS_MACOS,
|
|
||||||
IS_SANDCASTLE,
|
|
||||||
IS_WINDOWS,
|
|
||||||
)
|
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
|
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -37,12 +30,7 @@ if __name__ == "__main__":
|
|||||||
@skipIfTorchDynamo("skipping as a precaution")
|
@skipIfTorchDynamo("skipping as a precaution")
|
||||||
class TestTorchbind(JitTestCase):
|
class TestTorchbind(JitTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
if IS_SANDCASTLE or IS_MACOS or IS_FBCODE:
|
load_torchbind_test_lib()
|
||||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
|
||||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
|
||||||
if IS_WINDOWS:
|
|
||||||
lib_file_path = find_library_location("torchbind_test.dll")
|
|
||||||
torch.ops.load_library(str(lib_file_path))
|
|
||||||
|
|
||||||
def test_torchbind(self):
|
def test_torchbind(self):
|
||||||
def test_equality(f, cmp_key):
|
def test_equality(f, cmp_key):
|
||||||
|
|||||||
@ -7,15 +7,8 @@ import threading
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import IS_MACOS, run_tests, TestCase
|
||||||
find_library_location,
|
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
||||||
IS_FBCODE,
|
|
||||||
IS_MACOS,
|
|
||||||
IS_SANDCASTLE,
|
|
||||||
IS_WINDOWS,
|
|
||||||
run_tests,
|
|
||||||
TestCase,
|
|
||||||
)
|
|
||||||
from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary
|
from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary
|
||||||
|
|
||||||
|
|
||||||
@ -594,18 +587,10 @@ class WeakKeyDictionaryScriptObjectTestCase(TestCase):
|
|||||||
|
|
||||||
def __init__(self, *args, **kw):
|
def __init__(self, *args, **kw):
|
||||||
unittest.TestCase.__init__(self, *args, **kw)
|
unittest.TestCase.__init__(self, *args, **kw)
|
||||||
if IS_SANDCASTLE or IS_FBCODE:
|
try:
|
||||||
torch.ops.load_library(
|
load_torchbind_test_lib()
|
||||||
"//caffe2/test/cpp/jit:test_custom_class_registrations"
|
except unittest.SkipTest:
|
||||||
)
|
return # Skip in setup
|
||||||
elif IS_MACOS:
|
|
||||||
# don't load the library, just skip the tests in setUp
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
|
||||||
if IS_WINDOWS:
|
|
||||||
lib_file_path = find_library_location("torchbind_test.dll")
|
|
||||||
torch.ops.load_library(str(lib_file_path))
|
|
||||||
|
|
||||||
self.reference = self._reference().copy()
|
self.reference = self._reference().copy()
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -113,15 +114,15 @@ def load_torchbind_test_lib():
|
|||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
)
|
)
|
||||||
|
|
||||||
if IS_SANDCASTLE or IS_FBCODE:
|
if IS_MACOS:
|
||||||
torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations")
|
|
||||||
elif IS_MACOS:
|
|
||||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
raise unittest.SkipTest("non-portable load_library call used in test")
|
||||||
|
elif IS_SANDCASTLE or IS_FBCODE:
|
||||||
|
lib_file_path = Path("//caffe2/test/cpp/jit:test_custom_class_registrations")
|
||||||
|
elif IS_WINDOWS:
|
||||||
|
lib_file_path = find_library_location("torchbind_test.dll")
|
||||||
else:
|
else:
|
||||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
lib_file_path = find_library_location("libtorchbind_test.so")
|
||||||
if IS_WINDOWS:
|
torch.ops.load_library(str(lib_file_path))
|
||||||
lib_file_path = find_library_location("torchbind_test.dll")
|
|
||||||
torch.ops.load_library(str(lib_file_path))
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
|
|||||||
Reference in New Issue
Block a user