Consistently use load_torchbind_test_lib in tests (#148082)

The same code is repeated multiple times with slightly different implementations.
Use the existing function for brevity and consistency.

In the function the code from `test_export` is used which does a single `load_library` with cleaner conditions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148082
Approved by: https://github.com/angelayi
This commit is contained in:
Alexander Grund
2025-03-03 19:37:24 +00:00
committed by PyTorch MergeBot
parent 40c2505f16
commit 302c660298
5 changed files with 21 additions and 84 deletions

View File

@ -76,6 +76,7 @@ from torch.testing._internal.custom_tensor import (
CustomTensorPlainOut, CustomTensorPlainOut,
) )
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu from torch.testing._internal.triton_utils import requires_cuda, requires_gpu
from torch.testing._internal.two_tensor import TwoTensor from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._pytree import ( from torch.utils._pytree import (
@ -12662,15 +12663,7 @@ def forward(self, x):
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestExportCustomClass(TorchTestCase): class TestExportCustomClass(TorchTestCase):
def setUp(self): def setUp(self):
if IS_FBCODE: load_torchbind_test_lib()
lib_file_path = "//caffe2/test/cpp/jit:test_custom_class_registrations"
elif IS_SANDCASTLE or IS_MACOS:
raise unittest.SkipTest("non-portable load_library call used in test")
elif IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
else:
lib_file_path = find_library_location("libtorchbind_test.so")
torch.ops.load_library(str(lib_file_path))
def test_lift_custom_obj(self): def test_lift_custom_obj(self):
# TODO: fix this test once custom class tracing is implemented # TODO: fix this test once custom class tracing is implemented

View File

@ -1,5 +1,4 @@
# Owner(s): ["oncall: export"] # Owner(s): ["oncall: export"]
import unittest
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Optional from typing import Any, Optional
@ -18,15 +17,8 @@ from torch.export.exported_program import (
TensorArgument, TensorArgument,
) )
from torch.export.graph_signature import CustomObjArgument from torch.export.graph_signature import CustomObjArgument
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import run_tests, TestCase
find_library_location, from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
TestCase,
)
class GraphBuilder: class GraphBuilder:
@ -146,18 +138,7 @@ class GraphBuilder:
class TestLift(TestCase): class TestLift(TestCase):
def setUp(self): def setUp(self):
if IS_MACOS: load_torchbind_test_lib()
raise unittest.SkipTest("non-portable load_library call used in test")
elif IS_SANDCASTLE or IS_FBCODE:
torch.ops.load_library(
"//caffe2/test/cpp/jit:test_custom_class_registrations"
)
elif IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
else:
lib_file_path = find_library_location("libtorchbind_test.so")
torch.ops.load_library(str(lib_file_path))
def test_lift_basic(self): def test_lift_basic(self):
builder = GraphBuilder() builder = GraphBuilder()
@ -379,18 +360,7 @@ class TestLift(TestCase):
class ConstantAttrMapTest(TestCase): class ConstantAttrMapTest(TestCase):
def setUp(self): def setUp(self):
if IS_MACOS: load_torchbind_test_lib()
raise unittest.SkipTest("non-portable load_library call used in test")
elif IS_SANDCASTLE or IS_FBCODE:
torch.ops.load_library(
"//caffe2/test/cpp/jit:test_custom_class_registrations"
)
elif IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
else:
lib_file_path = find_library_location("libtorchbind_test.so")
torch.ops.load_library(str(lib_file_path))
def test_dict_api(self): def test_dict_api(self):
constant_attr_map = ConstantAttrMap() constant_attr_map = ConstantAttrMap()

View File

@ -5,7 +5,6 @@ import copy
import io import io
import os import os
import sys import sys
import unittest
from typing import Optional from typing import Optional
import torch import torch
@ -16,14 +15,8 @@ from torch.testing._internal.common_utils import skipIfTorchDynamo
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing import FileCheck from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
if __name__ == "__main__": if __name__ == "__main__":
@ -37,12 +30,7 @@ if __name__ == "__main__":
@skipIfTorchDynamo("skipping as a precaution") @skipIfTorchDynamo("skipping as a precaution")
class TestTorchbind(JitTestCase): class TestTorchbind(JitTestCase):
def setUp(self): def setUp(self):
if IS_SANDCASTLE or IS_MACOS or IS_FBCODE: load_torchbind_test_lib()
raise unittest.SkipTest("non-portable load_library call used in test")
lib_file_path = find_library_location("libtorchbind_test.so")
if IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
def test_torchbind(self): def test_torchbind(self):
def test_equality(f, cmp_key): def test_equality(f, cmp_key):

View File

@ -7,15 +7,8 @@ import threading
import unittest import unittest
import torch import torch
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import IS_MACOS, run_tests, TestCase
find_library_location, from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
TestCase,
)
from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary
@ -594,18 +587,10 @@ class WeakKeyDictionaryScriptObjectTestCase(TestCase):
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
unittest.TestCase.__init__(self, *args, **kw) unittest.TestCase.__init__(self, *args, **kw)
if IS_SANDCASTLE or IS_FBCODE: try:
torch.ops.load_library( load_torchbind_test_lib()
"//caffe2/test/cpp/jit:test_custom_class_registrations" except unittest.SkipTest:
) return # Skip in setup
elif IS_MACOS:
# don't load the library, just skip the tests in setUp
return
else:
lib_file_path = find_library_location("libtorchbind_test.so")
if IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
self.reference = self._reference().copy() self.reference = self._reference().copy()

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import contextlib import contextlib
from pathlib import Path
from typing import Optional from typing import Optional
import torch import torch
@ -113,15 +114,15 @@ def load_torchbind_test_lib():
IS_WINDOWS, IS_WINDOWS,
) )
if IS_SANDCASTLE or IS_FBCODE: if IS_MACOS:
torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations")
elif IS_MACOS:
raise unittest.SkipTest("non-portable load_library call used in test") raise unittest.SkipTest("non-portable load_library call used in test")
elif IS_SANDCASTLE or IS_FBCODE:
lib_file_path = Path("//caffe2/test/cpp/jit:test_custom_class_registrations")
elif IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
else: else:
lib_file_path = find_library_location("libtorchbind_test.so") lib_file_path = find_library_location("libtorchbind_test.so")
if IS_WINDOWS: torch.ops.load_library(str(lib_file_path))
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
@contextlib.contextmanager @contextlib.contextmanager