Consistently use load_torchbind_test_lib in tests (#148082)

The same code is repeated multiple times with slightly different implementations.
Use the existing function for brevity and consistency.

In the function the code from `test_export` is used which does a single `load_library` with cleaner conditions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148082
Approved by: https://github.com/angelayi
This commit is contained in:
Alexander Grund
2025-03-03 19:37:24 +00:00
committed by PyTorch MergeBot
parent 40c2505f16
commit 302c660298
5 changed files with 21 additions and 84 deletions

View File

@ -76,6 +76,7 @@ from torch.testing._internal.custom_tensor import (
CustomTensorPlainOut,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._pytree import (
@ -12662,15 +12663,7 @@ def forward(self, x):
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestExportCustomClass(TorchTestCase):
def setUp(self):
if IS_FBCODE:
lib_file_path = "//caffe2/test/cpp/jit:test_custom_class_registrations"
elif IS_SANDCASTLE or IS_MACOS:
raise unittest.SkipTest("non-portable load_library call used in test")
elif IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
else:
lib_file_path = find_library_location("libtorchbind_test.so")
torch.ops.load_library(str(lib_file_path))
load_torchbind_test_lib()
def test_lift_custom_obj(self):
# TODO: fix this test once custom class tracing is implemented

View File

@ -1,5 +1,4 @@
# Owner(s): ["oncall: export"]
import unittest
from collections import OrderedDict
from typing import Any, Optional
@ -18,15 +17,8 @@ from torch.export.exported_program import (
TensorArgument,
)
from torch.export.graph_signature import CustomObjArgument
from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
TestCase,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
class GraphBuilder:
@ -146,18 +138,7 @@ class GraphBuilder:
class TestLift(TestCase):
def setUp(self):
if IS_MACOS:
raise unittest.SkipTest("non-portable load_library call used in test")
elif IS_SANDCASTLE or IS_FBCODE:
torch.ops.load_library(
"//caffe2/test/cpp/jit:test_custom_class_registrations"
)
elif IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
else:
lib_file_path = find_library_location("libtorchbind_test.so")
torch.ops.load_library(str(lib_file_path))
load_torchbind_test_lib()
def test_lift_basic(self):
builder = GraphBuilder()
@ -379,18 +360,7 @@ class TestLift(TestCase):
class ConstantAttrMapTest(TestCase):
def setUp(self):
if IS_MACOS:
raise unittest.SkipTest("non-portable load_library call used in test")
elif IS_SANDCASTLE or IS_FBCODE:
torch.ops.load_library(
"//caffe2/test/cpp/jit:test_custom_class_registrations"
)
elif IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
else:
lib_file_path = find_library_location("libtorchbind_test.so")
torch.ops.load_library(str(lib_file_path))
load_torchbind_test_lib()
def test_dict_api(self):
constant_attr_map = ConstantAttrMap()

View File

@ -5,7 +5,6 @@ import copy
import io
import os
import sys
import unittest
from typing import Optional
import torch
@ -16,14 +15,8 @@ from torch.testing._internal.common_utils import skipIfTorchDynamo
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
if __name__ == "__main__":
@ -37,12 +30,7 @@ if __name__ == "__main__":
@skipIfTorchDynamo("skipping as a precaution")
class TestTorchbind(JitTestCase):
def setUp(self):
if IS_SANDCASTLE or IS_MACOS or IS_FBCODE:
raise unittest.SkipTest("non-portable load_library call used in test")
lib_file_path = find_library_location("libtorchbind_test.so")
if IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
load_torchbind_test_lib()
def test_torchbind(self):
def test_equality(f, cmp_key):

View File

@ -7,15 +7,8 @@ import threading
import unittest
import torch
from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
TestCase,
)
from torch.testing._internal.common_utils import IS_MACOS, run_tests, TestCase
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary
@ -594,18 +587,10 @@ class WeakKeyDictionaryScriptObjectTestCase(TestCase):
def __init__(self, *args, **kw):
unittest.TestCase.__init__(self, *args, **kw)
if IS_SANDCASTLE or IS_FBCODE:
torch.ops.load_library(
"//caffe2/test/cpp/jit:test_custom_class_registrations"
)
elif IS_MACOS:
# don't load the library, just skip the tests in setUp
return
else:
lib_file_path = find_library_location("libtorchbind_test.so")
if IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
try:
load_torchbind_test_lib()
except unittest.SkipTest:
return # Skip in setup
self.reference = self._reference().copy()

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import contextlib
from pathlib import Path
from typing import Optional
import torch
@ -113,15 +114,15 @@ def load_torchbind_test_lib():
IS_WINDOWS,
)
if IS_SANDCASTLE or IS_FBCODE:
torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations")
elif IS_MACOS:
if IS_MACOS:
raise unittest.SkipTest("non-portable load_library call used in test")
elif IS_SANDCASTLE or IS_FBCODE:
lib_file_path = Path("//caffe2/test/cpp/jit:test_custom_class_registrations")
elif IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
else:
lib_file_path = find_library_location("libtorchbind_test.so")
if IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
torch.ops.load_library(str(lib_file_path))
@contextlib.contextmanager