[OpenReg] Add OSX/Windows Support for OpenReg (#159441)

As the title states.

**Changes:**

- Abstract platform-specific APIs
- Add OSX/Windows support
- Set default symbol visibility to "hidden"

Co-authored-by: @can-gaa-hou

Original PR: https://github.com/pytorch/pytorch/pull/159029
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159441
Approved by: https://github.com/albanD

Co-authored-by: jiahaochen666 <jiahaochen535@gmail.com>
This commit is contained in:
FFFrog
2025-08-25 08:03:27 +00:00
committed by PyTorch MergeBot
parent 80df27a612
commit 56ebed627a
17 changed files with 368 additions and 118 deletions

View File

@ -4,28 +4,29 @@ project(TORCH_OPENREG CXX C)
include(GNUInstallDirs)
include(CheckCXXCompilerFlag)
include(CMakeDependentOption)
set(CMAKE_SKIP_BUILD_RPATH FALSE)
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
set(CMAKE_INSTALL_RPATH "$ORIGIN/lib/:$ORIGIN/")
set(LINUX TRUE)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_INSTALL_LIBDIR lib)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_SKIP_BUILD_RPATH FALSE)
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
set(CMAKE_CXX_VISIBILITY_PRESET hidden)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
if(APPLE)
set(CMAKE_INSTALL_RPATH "@loader_path/lib;@loader_path")
elseif(UNIX)
set(CMAKE_INSTALL_RPATH "$ORIGIN/lib:$ORIGIN")
elseif(WIN32)
set(CMAKE_INSTALL_RPATH "")
endif()
set(CMAKE_INSTALL_LIBDIR lib)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch)
find_package(Torch REQUIRED)
include_directories(${PYTORCH_INSTALL_DIR}/include)
if(DEFINED PYTHON_INCLUDE_DIR)
include_directories(${PYTHON_INCLUDE_DIR})
@ -33,6 +34,8 @@ else()
message(FATAL_ERROR "Cannot find Python directory")
endif()
include(${PROJECT_SOURCE_DIR}/cmake/TorchPythonTargets.cmake)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg)
add_subdirectory(${PROJECT_SOURCE_DIR}/csrc)
add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc)

View File

@ -0,0 +1,22 @@
# Imported targets describing the prebuilt torch_python library found under
# ${PYTORCH_INSTALL_DIR}/lib.

# Start from the generic ELF file name and override it on the platforms that
# use a different file name or extension.
set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so")
if(WIN32)
  # On Windows the location points at the .lib file, which is what ends up on
  # the link line through $<TARGET_FILE:torch_python> below.
  set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/torch_python.lib")
elseif(APPLE)
  set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.dylib")
endif()

# torch_python: the imported library itself, carrying its include directory
# and the libraries it depends on (c10 and torch_cpu).
add_library(torch_python SHARED IMPORTED)
set_property(TARGET torch_python PROPERTY
  INTERFACE_INCLUDE_DIRECTORIES "${PYTORCH_INSTALL_DIR}/include")
set_property(TARGET torch_python PROPERTY
  INTERFACE_LINK_LIBRARIES c10 torch_cpu)
set_property(TARGET torch_python PROPERTY
  IMPORTED_LOCATION "${TORCH_PYTHON_IMPORTED_LOCATION}")

# torch_python_library: interface wrapper that consumers link against. It
# forwards torch_python's include directories and links the library by file
# path, followed by torch_python's own link dependencies.
add_library(torch_python_library INTERFACE IMPORTED)
set_property(TARGET torch_python_library PROPERTY
  INTERFACE_INCLUDE_DIRECTORIES
    "$<TARGET_PROPERTY:torch_python,INTERFACE_INCLUDE_DIRECTORIES>")
set_property(TARGET torch_python_library PROPERTY
  INTERFACE_LINK_LIBRARIES
    "$<TARGET_FILE:torch_python>"
    "$<TARGET_PROPERTY:torch_python,INTERFACE_LINK_LIBRARIES>")

View File

@ -6,7 +6,11 @@ file(GLOB_RECURSE SOURCE_FILES
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
target_link_libraries(${LIBRARY_NAME} PRIVATE openreg torch_cpu)
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_cpu_library openreg)
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
install(TARGETS ${LIBRARY_NAME}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
)

View File

@ -30,7 +30,7 @@ int device_count_impl() {
return count;
}
c10::DeviceIndex device_count() noexcept {
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
// initialize number of devices only once
static int count = []() {
try {
@ -49,17 +49,17 @@ c10::DeviceIndex device_count() noexcept {
return static_cast<c10::DeviceIndex>(count);
}
c10::DeviceIndex current_device() {
OPENREG_EXPORT c10::DeviceIndex current_device() {
c10::DeviceIndex cur_device = -1;
GetDevice(&cur_device);
return cur_device;
}
void set_device(c10::DeviceIndex device) {
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
SetDevice(device);
}
DeviceIndex ExchangeDevice(DeviceIndex device) {
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
int current_device = -1;
orGetDevice(&current_device);

View File

@ -1,5 +1,11 @@
#pragma once
#ifdef _WIN32
#define OPENREG_EXPORT __declspec(dllexport)
#else
#define OPENREG_EXPORT __attribute__((visibility("default")))
#endif
#include <c10/core/Device.h>
#include <c10/macros/Macros.h>
@ -7,10 +13,10 @@
namespace c10::openreg {
c10::DeviceIndex device_count() noexcept;
DeviceIndex current_device();
void set_device(c10::DeviceIndex device);
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
OPENREG_EXPORT c10::DeviceIndex current_device();
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
DeviceIndex ExchangeDevice(DeviceIndex device);
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
} // namespace c10::openreg

View File

@ -1,5 +1,6 @@
import multiprocessing
import os
import platform
import shutil
import subprocess
import sys
@ -9,10 +10,23 @@ from distutils.command.clean import clean
from setuptools import Extension, find_packages, setup
# Env Variables
IS_DARWIN = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
RUN_BUILD_DEPS = any(arg in {"clean", "dist_info"} for arg in sys.argv)
def make_relative_rpath_args(path):
    """Return linker arguments that add a loader-relative rpath for ``path``.

    Windows gets an empty list. macOS anchors the rpath at ``@loader_path``
    and every other platform at ``$ORIGIN``.
    """
    if IS_WINDOWS:
        return []

    anchor = "@loader_path/" if IS_DARWIN else "$ORIGIN/"
    return ["-Wl,-rpath," + anchor + path]
def get_pytorch_dir():
import torch
@ -39,9 +53,15 @@ def build_deps():
".",
"--target",
"install",
"--config",
"Release",
"--",
]
build_args += ["-j", str(multiprocessing.cpu_count())]
if IS_WINDOWS:
build_args += ["/m:" + str(multiprocessing.cpu_count())]
else:
build_args += ["-j", str(multiprocessing.cpu_count())]
command = ["cmake"] + build_args
subprocess.check_call(command, cwd=build_dir, env=os.environ)
@ -64,19 +84,47 @@ def main():
if not RUN_BUILD_DEPS:
build_deps()
if IS_WINDOWS:
# /NODEFAULTLIB makes sure we only link to DLL runtime
# and matches the flags set for protobuf and ONNX
extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"] + [
*make_relative_rpath_args("lib")
]
# /MD links against DLL runtime
# and matches the flags set for protobuf and ONNX
# /EHsc is about standard C++ exception handling
extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"]
else:
extra_link_args = [*make_relative_rpath_args("lib")]
extra_compile_args = [
"-Wall",
"-Wextra",
"-Wno-strict-overflow",
"-Wno-unused-parameter",
"-Wno-missing-field-initializers",
"-Wno-unknown-pragmas",
]
ext_modules = [
Extension(
name="torch_openreg._C",
sources=["torch_openreg/csrc/stub.c"],
language="c",
extra_compile_args=["-g", "-Wall", "-Werror"],
extra_compile_args=extra_compile_args,
libraries=["torch_bindings"],
library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")],
extra_link_args=["-Wl,-rpath,$ORIGIN/lib"],
extra_link_args=extra_link_args,
)
]
package_data = {"torch_openreg": ["lib/*.so*"]}
package_data = {
"torch_openreg": [
"lib/*.so*",
"lib/*.dylib*",
"lib/*.dll",
"lib/*.lib",
]
}
setup(
packages=find_packages(),

View File

@ -8,4 +8,8 @@ add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
install(TARGETS ${LIBRARY_NAME}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
)

View File

@ -1,38 +1,9 @@
#include <include/openreg.h>
#include "memory.h"
#include <sys/mman.h>
#include <unistd.h>
#include <cstdlib>
#include <cstring>
#include <map>
#include <mutex>
namespace openreg {
namespace internal {
class ScopedMemoryProtector {
public:
ScopedMemoryProtector(const orPointerAttributes& info)
: m_info(info), m_protected(false) {
if (m_info.type == orMemoryType::orMemoryTypeDevice) {
if (mprotect(m_info.pointer, m_info.size, PROT_READ | PROT_WRITE) ==
0) {
m_protected = true;
}
}
}
~ScopedMemoryProtector() {
if (m_protected) {
mprotect(m_info.pointer, m_info.size, PROT_NONE);
}
}
ScopedMemoryProtector(const ScopedMemoryProtector&) = delete;
ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete;
private:
orPointerAttributes m_info;
bool m_protected;
};
namespace {
class MemoryManager {
public:
@ -46,7 +17,7 @@ class MemoryManager {
return orErrorUnknown;
std::lock_guard<std::mutex> lock(m_mutex);
long page_size = sysconf(_SC_PAGESIZE);
long page_size = openreg::get_pagesize();
size_t aligned_size = ((size - 1) / page_size + 1) * page_size;
void* mem = nullptr;
int current_device = -1;
@ -54,21 +25,15 @@ class MemoryManager {
if (type == orMemoryType::orMemoryTypeDevice) {
orGetDevice(&current_device);
mem = mmap(
nullptr,
aligned_size,
PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS,
-1,
0);
if (mem == MAP_FAILED)
mem = openreg::mmap(aligned_size);
if (mem == nullptr)
return orErrorUnknown;
if (mprotect(mem, aligned_size, PROT_NONE) != 0) {
munmap(mem, aligned_size);
if (openreg::mprotect(mem, aligned_size, F_PROT_NONE) != 0) {
openreg::munmap(mem, aligned_size);
return orErrorUnknown;
}
} else {
if (posix_memalign(&mem, page_size, aligned_size) != 0) {
if (openreg::alloc(&mem, page_size, aligned_size) != 0) {
return orErrorUnknown;
}
}
@ -87,11 +52,12 @@ class MemoryManager {
if (it == m_registry.end())
return orErrorUnknown;
const auto& info = it->second;
if (info.type == orMemoryType::orMemoryTypeDevice) {
mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE);
munmap(info.pointer, info.size);
openreg::mprotect(info.pointer, info.size, F_PROT_READ | F_PROT_WRITE);
openreg::munmap(info.pointer, info.size);
} else {
::free(info.pointer);
openreg::free(info.pointer);
}
m_registry.erase(it);
return orSuccess;
@ -167,7 +133,8 @@ class MemoryManager {
if (info.type != orMemoryType::orMemoryTypeDevice) {
return orErrorUnknown;
}
if (mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE) != 0) {
if (openreg::mprotect(
info.pointer, info.size, F_PROT_READ | F_PROT_WRITE) != 0) {
return orErrorUnknown;
}
return orSuccess;
@ -179,49 +146,75 @@ class MemoryManager {
if (info.type != orMemoryType::orMemoryTypeDevice) {
return orErrorUnknown;
}
if (mprotect(info.pointer, info.size, PROT_NONE) != 0) {
if (openreg::mprotect(info.pointer, info.size, F_PROT_NONE) != 0) {
return orErrorUnknown;
}
return orSuccess;
}
private:
class ScopedMemoryProtector {
public:
ScopedMemoryProtector(const orPointerAttributes& info)
: m_info(info), m_protected(false) {
if (m_info.type == orMemoryType::orMemoryTypeDevice) {
if (openreg::mprotect(
m_info.pointer, m_info.size, F_PROT_READ | F_PROT_WRITE) == 0) {
m_protected = true;
}
}
}
~ScopedMemoryProtector() {
if (m_protected) {
openreg::mprotect(m_info.pointer, m_info.size, F_PROT_NONE);
}
}
ScopedMemoryProtector(const ScopedMemoryProtector&) = delete;
ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete;
private:
orPointerAttributes m_info;
bool m_protected;
};
MemoryManager() = default;
orPointerAttributes getPointerInfo(const void* ptr) {
auto it = m_registry.upper_bound(const_cast<void*>(ptr));
if (it == m_registry.begin())
return {};
--it;
const char* p_char = static_cast<const char*>(ptr);
const char* base_char = static_cast<const char*>(it->first);
if (p_char >= base_char && p_char < (base_char + it->second.size)) {
return it->second;
if (it != m_registry.begin()) {
--it;
const char* p_char = static_cast<const char*>(ptr);
const char* base_char = static_cast<const char*>(it->first);
if (p_char >= base_char && p_char < (base_char + it->second.size)) {
return it->second;
}
}
return {};
}
std::map<void*, orPointerAttributes> m_registry;
std::mutex m_mutex;
};
} // namespace internal
} // namespace openreg
} // namespace
orError_t orMalloc(void** devPtr, size_t size) {
return openreg::internal::MemoryManager::getInstance().allocate(
return MemoryManager::getInstance().allocate(
devPtr, size, orMemoryType::orMemoryTypeDevice);
}
orError_t orFree(void* devPtr) {
return openreg::internal::MemoryManager::getInstance().free(devPtr);
return MemoryManager::getInstance().free(devPtr);
}
orError_t orMallocHost(void** hostPtr, size_t size) {
return openreg::internal::MemoryManager::getInstance().allocate(
return MemoryManager::getInstance().allocate(
hostPtr, size, orMemoryType::orMemoryTypeHost);
}
orError_t orFreeHost(void* hostPtr) {
return openreg::internal::MemoryManager::getInstance().free(hostPtr);
return MemoryManager::getInstance().free(hostPtr);
}
orError_t orMemcpy(
@ -229,21 +222,19 @@ orError_t orMemcpy(
const void* src,
size_t count,
orMemcpyKind kind) {
return openreg::internal::MemoryManager::getInstance().memcpy(
dst, src, count, kind);
return MemoryManager::getInstance().memcpy(dst, src, count, kind);
}
orError_t orPointerGetAttributes(
orPointerAttributes* attributes,
const void* ptr) {
return openreg::internal::MemoryManager::getInstance().getPointerAttributes(
attributes, ptr);
return MemoryManager::getInstance().getPointerAttributes(attributes, ptr);
}
orError_t orMemoryUnprotect(void* devPtr) {
return openreg::internal::MemoryManager::getInstance().unprotect(devPtr);
return MemoryManager::getInstance().unprotect(devPtr);
}
orError_t orMemoryProtect(void* devPtr) {
return openreg::internal::MemoryManager::getInstance().protect(devPtr);
return MemoryManager::getInstance().protect(devPtr);
}

View File

@ -0,0 +1,98 @@
#pragma once
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <include/openreg.h>
#if defined(_WIN32)
#include <windows.h>
#else
#include <sys/mman.h>
#include <unistd.h>
#endif
// Platform-neutral protection flags accepted by openreg::mprotect().
// F_PROT_READ and F_PROT_WRITE may be combined with bitwise OR;
// F_PROT_NONE requests no access at all.
#define F_PROT_NONE 0x0
#define F_PROT_READ 0x1
#define F_PROT_WRITE 0x2

// Thin wrappers that hide the differences between the POSIX and Win32
// memory APIs. Every function is defined in this header, so each one is
// declared `inline` to keep the header safe to include from more than one
// translation unit.
namespace openreg {

// Maps `size` bytes of private, anonymous, read/write memory.
// Returns the base address of the mapping, or nullptr on failure.
inline void* mmap(size_t size) {
#if defined(_WIN32)
  return VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
#else
  void* addr = ::mmap(
      nullptr,
      size,
      PROT_READ | PROT_WRITE,
      MAP_PRIVATE | MAP_ANONYMOUS,
      -1,
      0);
  // Normalize the POSIX failure sentinel so callers can test for nullptr.
  return (addr == MAP_FAILED) ? nullptr : addr;
#endif
}

// Releases a mapping previously obtained from openreg::mmap().
// `size` is only used by the POSIX implementation.
inline void munmap(void* addr, size_t size) {
#if defined(_WIN32)
  (void)size; // MEM_RELEASE frees the whole reservation and requires size 0.
  VirtualFree(addr, 0, MEM_RELEASE);
#else
  ::munmap(addr, size);
#endif
}

// Changes the access protection of [addr, addr + size).
// `prot` is F_PROT_NONE or a combination of F_PROT_READ / F_PROT_WRITE.
// Returns 0 on success and a non-zero value on failure.
inline int mprotect(void* addr, size_t size, int prot) {
#if defined(_WIN32)
  // Win32 has no write-only page protection, so any request that includes
  // write access becomes PAGE_READWRITE; a read-only request maps to
  // PAGE_READONLY, matching what the POSIX branch grants for F_PROT_READ.
  DWORD win_prot = PAGE_NOACCESS;
  if (prot & F_PROT_WRITE) {
    win_prot = PAGE_READWRITE;
  } else if (prot & F_PROT_READ) {
    win_prot = PAGE_READONLY;
  }
  DWORD old = 0;
  return VirtualProtect(addr, size, win_prot, &old) ? 0 : -1;
#else
  int native_prot = PROT_NONE;
  if (prot & F_PROT_READ) {
    native_prot |= PROT_READ;
  }
  if (prot & F_PROT_WRITE) {
    native_prot |= PROT_WRITE;
  }
  return ::mprotect(addr, size, native_prot);
#endif
}

// Allocates `size` bytes aligned to `alignment` and stores the address in
// `*mem`. Returns 0 on success and a non-zero value on failure.
// Memory obtained here must be released with openreg::free().
inline int alloc(void** mem, size_t alignment, size_t size) {
#ifdef _WIN32
  *mem = _aligned_malloc(size, alignment);
  return *mem ? 0 : -1;
#else
  return posix_memalign(mem, alignment, size);
#endif
}

// Releases memory obtained from openreg::alloc().
inline void free(void* mem) {
#ifdef _WIN32
  // Memory from _aligned_malloc must be returned through _aligned_free.
  _aligned_free(mem);
#else
  ::free(mem);
#endif
}

// Returns the system page size in bytes.
inline long get_pagesize() {
#ifdef _WIN32
  SYSTEM_INFO si;
  GetSystemInfo(&si);
  return static_cast<long>(si.dwPageSize);
#else
  return sysconf(_SC_PAGESIZE);
#endif
}

} // namespace openreg

View File

@ -2,6 +2,12 @@
#include <cstddef>
#ifdef _WIN32
#define OPENREG_EXPORT __declspec(dllexport)
#else
#define OPENREG_EXPORT __attribute__((visibility("default")))
#endif
#ifdef __cplusplus
extern "C" {
#endif
@ -28,19 +34,19 @@ struct orPointerAttributes {
size_t size;
};
orError_t orMalloc(void** devPtr, size_t size);
orError_t orFree(void* devPtr);
orError_t orMallocHost(void** hostPtr, size_t size);
orError_t orFreeHost(void* hostPtr);
orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind);
orError_t orMemoryUnprotect(void* devPtr);
orError_t orMemoryProtect(void* devPtr);
OPENREG_EXPORT orError_t orMalloc(void** devPtr, size_t size);
OPENREG_EXPORT orError_t orFree(void* devPtr);
OPENREG_EXPORT orError_t orMallocHost(void** hostPtr, size_t size);
OPENREG_EXPORT orError_t orFreeHost(void* hostPtr);
OPENREG_EXPORT orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind);
OPENREG_EXPORT orError_t orMemoryUnprotect(void* devPtr);
OPENREG_EXPORT orError_t orMemoryProtect(void* devPtr);
orError_t orGetDeviceCount(int* count);
orError_t orSetDevice(int device);
orError_t orGetDevice(int* device);
OPENREG_EXPORT orError_t orGetDeviceCount(int* count);
OPENREG_EXPORT orError_t orSetDevice(int device);
OPENREG_EXPORT orError_t orGetDevice(int* device);
orError_t orPointerGetAttributes(
OPENREG_EXPORT orError_t orPointerGetAttributes(
orPointerAttributes* attributes,
const void* ptr);

View File

@ -1,5 +1,15 @@
import sys
import torch
if sys.platform == "win32":
from ._utils import _load_dll_libraries
_load_dll_libraries()
del _load_dll_libraries
import torch_openreg._C # type: ignore[misc]
import torch_openreg.openreg

View File

@ -0,0 +1,42 @@
import ctypes
import glob
import os
def _load_dll_libraries() -> None:
    """Preload every ``*.dll`` found in this package's ``lib`` directory.

    The ``lib`` directory is registered with ``os.add_dll_directory`` and each
    DLL in it is loaded through kernel32. When ``LoadLibraryExW`` is usable it
    is tried first; if it reports "module not found" (error 126), or is not
    usable at all, the directory is prepended to ``PATH`` once and the DLL is
    loaded with ``LoadLibraryW`` instead.

    Raises:
        OSError: if a DLL (or one of its dependencies) cannot be loaded. The
            process error mode is restored before the exception propagates.
    """
    openreg_dll_path = os.path.join(os.path.dirname(__file__), "lib")

    kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
    with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
    # 0x0001 is SEM_FAILCRITICALERRORS: report critical errors to the caller
    # instead of showing a system dialog while the DLLs are being probed.
    prev_error_mode = kernel32.SetErrorMode(0x0001)

    try:
        # Module handles are pointer sized; the default c_int restype would
        # truncate them on 64-bit Windows.
        kernel32.LoadLibraryW.restype = ctypes.c_void_p
        if with_load_library_flags:
            kernel32.LoadLibraryExW.restype = ctypes.c_void_p

        os.add_dll_directory(openreg_dll_path)

        dlls = glob.glob(os.path.join(openreg_dll_path, "*.dll"))
        path_patched = False
        for dll in dlls:
            is_loaded = False
            if with_load_library_flags:
                # 0x00001100 is LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000) |
                # LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x0100).
                res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
                last_error = ctypes.get_last_error()
                # 126 is ERROR_MOD_NOT_FOUND; that case falls through to the
                # PATH-based retry below, any other failure is fatal.
                if res is None and last_error != 126:
                    err = ctypes.WinError(last_error)
                    err.strerror += f' Error loading "{dll}" or one of its dependencies.'
                    raise err
                elif res is not None:
                    is_loaded = True
            if not is_loaded:
                if not path_patched:
                    # PATH may be unset; avoid a KeyError and a dangling
                    # separator in that case.
                    current_path = os.environ.get("PATH")
                    os.environ["PATH"] = (
                        ";".join([openreg_dll_path, current_path])
                        if current_path
                        else openreg_dll_path
                    )
                    path_patched = True
                res = kernel32.LoadLibraryW(dll)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += f' Error loading "{dll}" or one of its dependencies.'
                    raise err
    finally:
        # Always restore the caller's error mode, including on failure.
        kernel32.SetErrorMode(prev_error_mode)

View File

@ -6,7 +6,19 @@ file(GLOB_RECURSE SOURCE_FILES
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python torch_openreg)
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python_library torch_openreg)
if(WIN32)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
target_link_libraries(${LIBRARY_NAME} PRIVATE ${Python3_LIBRARIES})
elseif(APPLE)
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif()
target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib)
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
install(TARGETS ${LIBRARY_NAME}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
)

View File

@ -90,7 +90,7 @@ static PyMethodDef methods[] = {
* Therefore, it cannot be named initModule here, otherwise initModule
* in torch/csrc/Module.cpp will be called, resulting in failure.
*/
extern "C" PyObject* initOpenRegModule(void) {
extern "C" OPENREG_EXPORT PyObject* initOpenRegModule(void) {
static struct PyModuleDef openreg_C_module = {
PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods};
PyObject* mod = PyModule_Create(&openreg_C_module);

View File

@ -1,13 +1,18 @@
#include <Python.h>
extern PyObject* initOpenRegModule(void);
#ifdef _WIN32
#define OPENREG_EXPORT __declspec(dllexport)
#else
#define OPENREG_EXPORT __attribute__((visibility("default")))
#endif
extern OPENREG_EXPORT PyObject* initOpenRegModule(void);
#ifndef _WIN32
#ifdef __cplusplus
extern "C"
#endif
__attribute__((visibility("default"))) PyObject* PyInit__C(void);
#endif
OPENREG_EXPORT PyObject* PyInit__C(void);
PyMODINIT_FUNC PyInit__C(void)
{

View File

@ -28,7 +28,6 @@ from torch.multiprocessing import current_process, get_context
from torch.testing._internal.common_utils import (
get_report_path,
IS_CI,
IS_LINUX,
IS_MACOS,
retry_shell,
set_cwd,
@ -909,10 +908,6 @@ def _test_autoload(test_directory, options, enable=True):
def run_test_with_openreg(test_module, test_directory, options):
# TODO(FFFrog): Will remove this later when windows/macos are supported.
if not IS_LINUX:
return 0
openreg_dir = os.path.join(
test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg"
)

View File

@ -16,7 +16,9 @@ import torch
from torch.serialization import safe_globals
from torch.testing._internal.common_utils import (
run_tests,
skipIfMPS,
skipIfTorchDynamo,
skipIfWindows,
skipIfXpu,
TemporaryFileName,
TestCase,
@ -284,6 +286,8 @@ class TestOpenReg(TestCase):
self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc]
# Autograd
@skipIfMPS
@skipIfWindows()
def test_autograd_init(self):
# Make sure autograd is initialized
torch.ones(2, requires_grad=True, device="openreg").sum().backward()