mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Consistently use load_torchbind_test_lib in tests (#148082)
The same code is repeated multiple times with slightly different implementations. Use the existing function for brevity and consistency. In the function the code from `test_export` is used which does a single `load_library` with cleaner conditions Pull Request resolved: https://github.com/pytorch/pytorch/pull/148082 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
40c2505f16
commit
302c660298
@ -76,6 +76,7 @@ from torch.testing._internal.custom_tensor import (
|
||||
CustomTensorPlainOut,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
||||
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
||||
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu
|
||||
from torch.testing._internal.two_tensor import TwoTensor
|
||||
from torch.utils._pytree import (
|
||||
@ -12662,15 +12663,7 @@ def forward(self, x):
|
||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
|
||||
class TestExportCustomClass(TorchTestCase):
|
||||
def setUp(self):
|
||||
if IS_FBCODE:
|
||||
lib_file_path = "//caffe2/test/cpp/jit:test_custom_class_registrations"
|
||||
elif IS_SANDCASTLE or IS_MACOS:
|
||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
||||
elif IS_WINDOWS:
|
||||
lib_file_path = find_library_location("torchbind_test.dll")
|
||||
else:
|
||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
load_torchbind_test_lib()
|
||||
|
||||
def test_lift_custom_obj(self):
|
||||
# TODO: fix this test once custom class tracing is implemented
|
||||
|
@ -1,5 +1,4 @@
|
||||
# Owner(s): ["oncall: export"]
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -18,15 +17,8 @@ from torch.export.exported_program import (
|
||||
TensorArgument,
|
||||
)
|
||||
from torch.export.graph_signature import CustomObjArgument
|
||||
from torch.testing._internal.common_utils import (
|
||||
find_library_location,
|
||||
IS_FBCODE,
|
||||
IS_MACOS,
|
||||
IS_SANDCASTLE,
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
@ -146,18 +138,7 @@ class GraphBuilder:
|
||||
|
||||
class TestLift(TestCase):
|
||||
def setUp(self):
|
||||
if IS_MACOS:
|
||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
||||
elif IS_SANDCASTLE or IS_FBCODE:
|
||||
torch.ops.load_library(
|
||||
"//caffe2/test/cpp/jit:test_custom_class_registrations"
|
||||
)
|
||||
elif IS_WINDOWS:
|
||||
lib_file_path = find_library_location("torchbind_test.dll")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
else:
|
||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
load_torchbind_test_lib()
|
||||
|
||||
def test_lift_basic(self):
|
||||
builder = GraphBuilder()
|
||||
@ -379,18 +360,7 @@ class TestLift(TestCase):
|
||||
|
||||
class ConstantAttrMapTest(TestCase):
|
||||
def setUp(self):
|
||||
if IS_MACOS:
|
||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
||||
elif IS_SANDCASTLE or IS_FBCODE:
|
||||
torch.ops.load_library(
|
||||
"//caffe2/test/cpp/jit:test_custom_class_registrations"
|
||||
)
|
||||
elif IS_WINDOWS:
|
||||
lib_file_path = find_library_location("torchbind_test.dll")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
else:
|
||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
load_torchbind_test_lib()
|
||||
|
||||
def test_dict_api(self):
|
||||
constant_attr_map = ConstantAttrMap()
|
||||
|
@ -5,7 +5,6 @@ import copy
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -16,14 +15,8 @@ from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import (
|
||||
find_library_location,
|
||||
IS_FBCODE,
|
||||
IS_MACOS,
|
||||
IS_SANDCASTLE,
|
||||
IS_WINDOWS,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -37,12 +30,7 @@ if __name__ == "__main__":
|
||||
@skipIfTorchDynamo("skipping as a precaution")
|
||||
class TestTorchbind(JitTestCase):
|
||||
def setUp(self):
|
||||
if IS_SANDCASTLE or IS_MACOS or IS_FBCODE:
|
||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
||||
if IS_WINDOWS:
|
||||
lib_file_path = find_library_location("torchbind_test.dll")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
load_torchbind_test_lib()
|
||||
|
||||
def test_torchbind(self):
|
||||
def test_equality(f, cmp_key):
|
||||
|
@ -7,15 +7,8 @@ import threading
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import (
|
||||
find_library_location,
|
||||
IS_FBCODE,
|
||||
IS_MACOS,
|
||||
IS_SANDCASTLE,
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_MACOS, run_tests, TestCase
|
||||
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
||||
from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary
|
||||
|
||||
|
||||
@ -594,18 +587,10 @@ class WeakKeyDictionaryScriptObjectTestCase(TestCase):
|
||||
|
||||
def __init__(self, *args, **kw):
|
||||
unittest.TestCase.__init__(self, *args, **kw)
|
||||
if IS_SANDCASTLE or IS_FBCODE:
|
||||
torch.ops.load_library(
|
||||
"//caffe2/test/cpp/jit:test_custom_class_registrations"
|
||||
)
|
||||
elif IS_MACOS:
|
||||
# don't load the library, just skip the tests in setUp
|
||||
return
|
||||
else:
|
||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
||||
if IS_WINDOWS:
|
||||
lib_file_path = find_library_location("torchbind_test.dll")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
try:
|
||||
load_torchbind_test_lib()
|
||||
except unittest.SkipTest:
|
||||
return # Skip in setup
|
||||
|
||||
self.reference = self._reference().copy()
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -113,15 +114,15 @@ def load_torchbind_test_lib():
|
||||
IS_WINDOWS,
|
||||
)
|
||||
|
||||
if IS_SANDCASTLE or IS_FBCODE:
|
||||
torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations")
|
||||
elif IS_MACOS:
|
||||
if IS_MACOS:
|
||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
||||
elif IS_SANDCASTLE or IS_FBCODE:
|
||||
lib_file_path = Path("//caffe2/test/cpp/jit:test_custom_class_registrations")
|
||||
elif IS_WINDOWS:
|
||||
lib_file_path = find_library_location("torchbind_test.dll")
|
||||
else:
|
||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
||||
if IS_WINDOWS:
|
||||
lib_file_path = find_library_location("torchbind_test.dll")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
Reference in New Issue
Block a user