[OpenReg] Add OSX/Windows Support for OpenReg (#159441)

As the title states.

**Changes:**

- Abstract platform-specific APIs
- Add OSX/Windows support
- Set default symbol visibility to "hidden"

Co-authored-by: @can-gaa-hou

Original PR: https://github.com/pytorch/pytorch/pull/159029
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159441
Approved by: https://github.com/albanD

Co-authored-by: jiahaochen666 <jiahaochen535@gmail.com>
This commit is contained in:
FFFrog
2025-08-25 08:03:27 +00:00
committed by PyTorch MergeBot
parent 80df27a612
commit 56ebed627a
17 changed files with 368 additions and 118 deletions

View File

@ -4,28 +4,29 @@ project(TORCH_OPENREG CXX C)
include(GNUInstallDirs) include(GNUInstallDirs)
include(CheckCXXCompilerFlag) include(CheckCXXCompilerFlag)
include(CMakeDependentOption)
set(CMAKE_SKIP_BUILD_RPATH FALSE)
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
set(CMAKE_INSTALL_RPATH "$ORIGIN/lib/:$ORIGIN/")
set(LINUX TRUE)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_STANDARD 11) set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_INSTALL_LIBDIR lib) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_SKIP_BUILD_RPATH FALSE)
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
set(CMAKE_CXX_VISIBILITY_PRESET hidden)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1) if(APPLE)
set(CMAKE_INSTALL_RPATH "@loader_path/lib;@loader_path")
elseif(UNIX)
set(CMAKE_INSTALL_RPATH "$ORIGIN/lib:$ORIGIN")
elseif(WIN32)
set(CMAKE_INSTALL_RPATH "")
endif()
set(CMAKE_INSTALL_LIBDIR lib)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch) set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch)
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
include_directories(${PYTORCH_INSTALL_DIR}/include)
if(DEFINED PYTHON_INCLUDE_DIR) if(DEFINED PYTHON_INCLUDE_DIR)
include_directories(${PYTHON_INCLUDE_DIR}) include_directories(${PYTHON_INCLUDE_DIR})
@ -33,6 +34,8 @@ else()
message(FATAL_ERROR "Cannot find Python directory") message(FATAL_ERROR "Cannot find Python directory")
endif() endif()
include(${PROJECT_SOURCE_DIR}/cmake/TorchPythonTargets.cmake)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg)
add_subdirectory(${PROJECT_SOURCE_DIR}/csrc) add_subdirectory(${PROJECT_SOURCE_DIR}/csrc)
add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc) add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc)

View File

@ -0,0 +1,22 @@
# Imported targets describing the prebuilt torch_python library that lives
# under ${PYTORCH_INSTALL_DIR}/lib.
#
#   torch_python          SHARED IMPORTED target pointing at the platform file
#   torch_python_library  INTERFACE IMPORTED wrapper that forwards the include
#                         directories and links the library by full file path
#
# PYTORCH_INSTALL_DIR must be set by the including project.

# Pick the file that is handed to the linker on each platform.
if(WIN32)
  set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/torch_python.lib")
elseif(APPLE)
  set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.dylib")
else()
  set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so")
endif()

add_library(torch_python SHARED IMPORTED)
set_property(TARGET torch_python PROPERTY
  INTERFACE_INCLUDE_DIRECTORIES "${PYTORCH_INSTALL_DIR}/include")
set_property(TARGET torch_python PROPERTY
  INTERFACE_LINK_LIBRARIES "c10;torch_cpu")
set_property(TARGET torch_python PROPERTY
  IMPORTED_LOCATION "${TORCH_PYTHON_IMPORTED_LOCATION}")

# Consumers link this wrapper; the generator expressions are resolved at
# generate time against the torch_python target defined above.
add_library(torch_python_library INTERFACE IMPORTED)
set_property(TARGET torch_python_library PROPERTY
  INTERFACE_INCLUDE_DIRECTORIES
  "$<TARGET_PROPERTY:torch_python,INTERFACE_INCLUDE_DIRECTORIES>")
set_property(TARGET torch_python_library PROPERTY
  INTERFACE_LINK_LIBRARIES
  "$<TARGET_FILE:torch_python>;$<TARGET_PROPERTY:torch_python,INTERFACE_LINK_LIBRARIES>")

View File

@ -6,7 +6,11 @@ file(GLOB_RECURSE SOURCE_FILES
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
target_link_libraries(${LIBRARY_NAME} PRIVATE openreg torch_cpu) target_link_libraries(${LIBRARY_NAME} PRIVATE torch_cpu_library openreg)
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) install(TARGETS ${LIBRARY_NAME}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
)

View File

@ -30,7 +30,7 @@ int device_count_impl() {
return count; return count;
} }
c10::DeviceIndex device_count() noexcept { OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
// initialize number of devices only once // initialize number of devices only once
static int count = []() { static int count = []() {
try { try {
@ -49,17 +49,17 @@ c10::DeviceIndex device_count() noexcept {
return static_cast<c10::DeviceIndex>(count); return static_cast<c10::DeviceIndex>(count);
} }
c10::DeviceIndex current_device() { OPENREG_EXPORT c10::DeviceIndex current_device() {
c10::DeviceIndex cur_device = -1; c10::DeviceIndex cur_device = -1;
GetDevice(&cur_device); GetDevice(&cur_device);
return cur_device; return cur_device;
} }
void set_device(c10::DeviceIndex device) { OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
SetDevice(device); SetDevice(device);
} }
DeviceIndex ExchangeDevice(DeviceIndex device) { OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
int current_device = -1; int current_device = -1;
orGetDevice(&current_device); orGetDevice(&current_device);

View File

@ -1,5 +1,11 @@
#pragma once #pragma once
#ifdef _WIN32
#define OPENREG_EXPORT __declspec(dllexport)
#else
#define OPENREG_EXPORT __attribute__((visibility("default")))
#endif
#include <c10/core/Device.h> #include <c10/core/Device.h>
#include <c10/macros/Macros.h> #include <c10/macros/Macros.h>
@ -7,10 +13,10 @@
namespace c10::openreg { namespace c10::openreg {
c10::DeviceIndex device_count() noexcept; OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
DeviceIndex current_device(); OPENREG_EXPORT c10::DeviceIndex current_device();
void set_device(c10::DeviceIndex device); OPENREG_EXPORT void set_device(c10::DeviceIndex device);
DeviceIndex ExchangeDevice(DeviceIndex device); OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
} // namespace c10::openreg } // namespace c10::openreg

View File

@ -1,5 +1,6 @@
import multiprocessing import multiprocessing
import os import os
import platform
import shutil import shutil
import subprocess import subprocess
import sys import sys
@ -9,10 +10,23 @@ from distutils.command.clean import clean
from setuptools import Extension, find_packages, setup from setuptools import Extension, find_packages, setup
# Env Variables
IS_DARWIN = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"
BASE_DIR = os.path.dirname(os.path.realpath(__file__)) BASE_DIR = os.path.dirname(os.path.realpath(__file__))
RUN_BUILD_DEPS = any(arg in {"clean", "dist_info"} for arg in sys.argv) RUN_BUILD_DEPS = any(arg in {"clean", "dist_info"} for arg in sys.argv)
def make_relative_rpath_args(path):
    """Return linker arguments that add ``path`` as a loader-relative rpath.

    macOS builds anchor the rpath at ``@loader_path``, other non-Windows
    builds at ``$ORIGIN``. Windows builds get an empty list.
    """
    if IS_DARWIN:
        prefix = "-Wl,-rpath,@loader_path/"
    elif IS_WINDOWS:
        return []
    else:
        prefix = "-Wl,-rpath,$ORIGIN/"
    return [prefix + path]
def get_pytorch_dir(): def get_pytorch_dir():
import torch import torch
@ -39,9 +53,15 @@ def build_deps():
".", ".",
"--target", "--target",
"install", "install",
"--config",
"Release",
"--", "--",
] ]
build_args += ["-j", str(multiprocessing.cpu_count())]
if IS_WINDOWS:
build_args += ["/m:" + str(multiprocessing.cpu_count())]
else:
build_args += ["-j", str(multiprocessing.cpu_count())]
command = ["cmake"] + build_args command = ["cmake"] + build_args
subprocess.check_call(command, cwd=build_dir, env=os.environ) subprocess.check_call(command, cwd=build_dir, env=os.environ)
@ -64,19 +84,47 @@ def main():
if not RUN_BUILD_DEPS: if not RUN_BUILD_DEPS:
build_deps() build_deps()
if IS_WINDOWS:
# /NODEFAULTLIB makes sure we only link to DLL runtime
# and matches the flags set for protobuf and ONNX
extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"] + [
*make_relative_rpath_args("lib")
]
# /MD links against DLL runtime
# and matches the flags set for protobuf and ONNX
# /EHsc is about standard C++ exception handling
extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"]
else:
extra_link_args = [*make_relative_rpath_args("lib")]
extra_compile_args = [
"-Wall",
"-Wextra",
"-Wno-strict-overflow",
"-Wno-unused-parameter",
"-Wno-missing-field-initializers",
"-Wno-unknown-pragmas",
]
ext_modules = [ ext_modules = [
Extension( Extension(
name="torch_openreg._C", name="torch_openreg._C",
sources=["torch_openreg/csrc/stub.c"], sources=["torch_openreg/csrc/stub.c"],
language="c", language="c",
extra_compile_args=["-g", "-Wall", "-Werror"], extra_compile_args=extra_compile_args,
libraries=["torch_bindings"], libraries=["torch_bindings"],
library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")], library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")],
extra_link_args=["-Wl,-rpath,$ORIGIN/lib"], extra_link_args=extra_link_args,
) )
] ]
package_data = {"torch_openreg": ["lib/*.so*"]} package_data = {
"torch_openreg": [
"lib/*.so*",
"lib/*.dylib*",
"lib/*.dll",
"lib/*.lib",
]
}
setup( setup(
packages=find_packages(), packages=find_packages(),

View File

@ -8,4 +8,8 @@ add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) install(TARGETS ${LIBRARY_NAME}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
)

View File

@ -1,38 +1,9 @@
#include <include/openreg.h> #include "memory.h"
#include <sys/mman.h>
#include <unistd.h>
#include <cstdlib>
#include <cstring>
#include <map> #include <map>
#include <mutex> #include <mutex>
namespace openreg { namespace {
namespace internal {
class ScopedMemoryProtector {
public:
ScopedMemoryProtector(const orPointerAttributes& info)
: m_info(info), m_protected(false) {
if (m_info.type == orMemoryType::orMemoryTypeDevice) {
if (mprotect(m_info.pointer, m_info.size, PROT_READ | PROT_WRITE) ==
0) {
m_protected = true;
}
}
}
~ScopedMemoryProtector() {
if (m_protected) {
mprotect(m_info.pointer, m_info.size, PROT_NONE);
}
}
ScopedMemoryProtector(const ScopedMemoryProtector&) = delete;
ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete;
private:
orPointerAttributes m_info;
bool m_protected;
};
class MemoryManager { class MemoryManager {
public: public:
@ -46,7 +17,7 @@ class MemoryManager {
return orErrorUnknown; return orErrorUnknown;
std::lock_guard<std::mutex> lock(m_mutex); std::lock_guard<std::mutex> lock(m_mutex);
long page_size = sysconf(_SC_PAGESIZE); long page_size = openreg::get_pagesize();
size_t aligned_size = ((size - 1) / page_size + 1) * page_size; size_t aligned_size = ((size - 1) / page_size + 1) * page_size;
void* mem = nullptr; void* mem = nullptr;
int current_device = -1; int current_device = -1;
@ -54,21 +25,15 @@ class MemoryManager {
if (type == orMemoryType::orMemoryTypeDevice) { if (type == orMemoryType::orMemoryTypeDevice) {
orGetDevice(&current_device); orGetDevice(&current_device);
mem = mmap( mem = openreg::mmap(aligned_size);
nullptr, if (mem == nullptr)
aligned_size,
PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS,
-1,
0);
if (mem == MAP_FAILED)
return orErrorUnknown; return orErrorUnknown;
if (mprotect(mem, aligned_size, PROT_NONE) != 0) { if (openreg::mprotect(mem, aligned_size, F_PROT_NONE) != 0) {
munmap(mem, aligned_size); openreg::munmap(mem, aligned_size);
return orErrorUnknown; return orErrorUnknown;
} }
} else { } else {
if (posix_memalign(&mem, page_size, aligned_size) != 0) { if (openreg::alloc(&mem, page_size, aligned_size) != 0) {
return orErrorUnknown; return orErrorUnknown;
} }
} }
@ -87,11 +52,12 @@ class MemoryManager {
if (it == m_registry.end()) if (it == m_registry.end())
return orErrorUnknown; return orErrorUnknown;
const auto& info = it->second; const auto& info = it->second;
if (info.type == orMemoryType::orMemoryTypeDevice) { if (info.type == orMemoryType::orMemoryTypeDevice) {
mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE); openreg::mprotect(info.pointer, info.size, F_PROT_READ | F_PROT_WRITE);
munmap(info.pointer, info.size); openreg::munmap(info.pointer, info.size);
} else { } else {
::free(info.pointer); openreg::free(info.pointer);
} }
m_registry.erase(it); m_registry.erase(it);
return orSuccess; return orSuccess;
@ -167,7 +133,8 @@ class MemoryManager {
if (info.type != orMemoryType::orMemoryTypeDevice) { if (info.type != orMemoryType::orMemoryTypeDevice) {
return orErrorUnknown; return orErrorUnknown;
} }
if (mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE) != 0) { if (openreg::mprotect(
info.pointer, info.size, F_PROT_READ | F_PROT_WRITE) != 0) {
return orErrorUnknown; return orErrorUnknown;
} }
return orSuccess; return orSuccess;
@ -179,49 +146,75 @@ class MemoryManager {
if (info.type != orMemoryType::orMemoryTypeDevice) { if (info.type != orMemoryType::orMemoryTypeDevice) {
return orErrorUnknown; return orErrorUnknown;
} }
if (mprotect(info.pointer, info.size, PROT_NONE) != 0) { if (openreg::mprotect(info.pointer, info.size, F_PROT_NONE) != 0) {
return orErrorUnknown; return orErrorUnknown;
} }
return orSuccess; return orSuccess;
} }
private: private:
// Scope guard that makes a device allocation readable and writable for the
// lifetime of the guard and switches it back to F_PROT_NONE on destruction.
// Non-device pointers, and device pointers whose unprotect call fails, are
// left untouched (the destructor then does nothing).
class ScopedMemoryProtector {
 public:
  ScopedMemoryProtector(const orPointerAttributes& info)
      : m_info(info),
        // m_info is declared first, so it is already initialised here.
        m_protected(
            m_info.type == orMemoryType::orMemoryTypeDevice &&
            openreg::mprotect(
                m_info.pointer, m_info.size, F_PROT_READ | F_PROT_WRITE) ==
                0) {}

  ~ScopedMemoryProtector() {
    if (!m_protected) {
      return;
    }
    openreg::mprotect(m_info.pointer, m_info.size, F_PROT_NONE);
  }

  // The guard owns a protection change, so it must not be copied.
  ScopedMemoryProtector(const ScopedMemoryProtector&) = delete;
  ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete;

 private:
  orPointerAttributes m_info; // copy of the attributes of the guarded range
  bool m_protected; // true when the destructor must re-protect the range
};
MemoryManager() = default; MemoryManager() = default;
orPointerAttributes getPointerInfo(const void* ptr) { orPointerAttributes getPointerInfo(const void* ptr) {
auto it = m_registry.upper_bound(const_cast<void*>(ptr)); auto it = m_registry.upper_bound(const_cast<void*>(ptr));
if (it == m_registry.begin()) if (it != m_registry.begin()) {
return {}; --it;
--it; const char* p_char = static_cast<const char*>(ptr);
const char* p_char = static_cast<const char*>(ptr); const char* base_char = static_cast<const char*>(it->first);
const char* base_char = static_cast<const char*>(it->first); if (p_char >= base_char && p_char < (base_char + it->second.size)) {
if (p_char >= base_char && p_char < (base_char + it->second.size)) { return it->second;
return it->second; }
} }
return {}; return {};
} }
std::map<void*, orPointerAttributes> m_registry; std::map<void*, orPointerAttributes> m_registry;
std::mutex m_mutex; std::mutex m_mutex;
}; };
} // namespace internal } // namespace
} // namespace openreg
orError_t orMalloc(void** devPtr, size_t size) { orError_t orMalloc(void** devPtr, size_t size) {
return openreg::internal::MemoryManager::getInstance().allocate( return MemoryManager::getInstance().allocate(
devPtr, size, orMemoryType::orMemoryTypeDevice); devPtr, size, orMemoryType::orMemoryTypeDevice);
} }
orError_t orFree(void* devPtr) { orError_t orFree(void* devPtr) {
return openreg::internal::MemoryManager::getInstance().free(devPtr); return MemoryManager::getInstance().free(devPtr);
} }
orError_t orMallocHost(void** hostPtr, size_t size) { orError_t orMallocHost(void** hostPtr, size_t size) {
return openreg::internal::MemoryManager::getInstance().allocate( return MemoryManager::getInstance().allocate(
hostPtr, size, orMemoryType::orMemoryTypeHost); hostPtr, size, orMemoryType::orMemoryTypeHost);
} }
orError_t orFreeHost(void* hostPtr) { orError_t orFreeHost(void* hostPtr) {
return openreg::internal::MemoryManager::getInstance().free(hostPtr); return MemoryManager::getInstance().free(hostPtr);
} }
orError_t orMemcpy( orError_t orMemcpy(
@ -229,21 +222,19 @@ orError_t orMemcpy(
const void* src, const void* src,
size_t count, size_t count,
orMemcpyKind kind) { orMemcpyKind kind) {
return openreg::internal::MemoryManager::getInstance().memcpy( return MemoryManager::getInstance().memcpy(dst, src, count, kind);
dst, src, count, kind);
} }
orError_t orPointerGetAttributes( orError_t orPointerGetAttributes(
orPointerAttributes* attributes, orPointerAttributes* attributes,
const void* ptr) { const void* ptr) {
return openreg::internal::MemoryManager::getInstance().getPointerAttributes( return MemoryManager::getInstance().getPointerAttributes(attributes, ptr);
attributes, ptr);
} }
orError_t orMemoryUnprotect(void* devPtr) { orError_t orMemoryUnprotect(void* devPtr) {
return openreg::internal::MemoryManager::getInstance().unprotect(devPtr); return MemoryManager::getInstance().unprotect(devPtr);
} }
orError_t orMemoryProtect(void* devPtr) { orError_t orMemoryProtect(void* devPtr) {
return openreg::internal::MemoryManager::getInstance().protect(devPtr); return MemoryManager::getInstance().protect(devPtr);
} }

View File

@ -0,0 +1,98 @@
#pragma once
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <include/openreg.h>
#if defined(_WIN32)
#include <windows.h>
#else
#include <sys/mman.h>
#include <unistd.h>
#endif
// Platform-neutral page protection flags accepted by openreg::mprotect().
// F_PROT_READ and F_PROT_WRITE may be OR-ed together.
#define F_PROT_NONE 0x0
#define F_PROT_READ 0x1
#define F_PROT_WRITE 0x2

// Thin wrappers over the platform memory APIs (POSIX / Win32).
// Every function is `inline` because it is defined in a header: without it,
// including this header from a second translation unit would produce
// multiple-definition link errors.
namespace openreg {

// Maps `size` bytes of anonymous, private, read/write memory.
// Returns nullptr on failure on every platform (MAP_FAILED is normalised).
inline void* mmap(size_t size) {
#if defined(_WIN32)
  return VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
#else
  void* addr = ::mmap(
      nullptr,
      size,
      PROT_READ | PROT_WRITE,
      MAP_PRIVATE | MAP_ANONYMOUS,
      -1,
      0);
  return (addr == MAP_FAILED) ? nullptr : addr;
#endif
}

// Releases a mapping obtained from openreg::mmap().
// `size` is only used on POSIX; on Windows the whole region is released.
inline void munmap(void* addr, size_t size) {
#if defined(_WIN32)
  (void)size; // MEM_RELEASE requires a size of 0
  VirtualFree(addr, 0, MEM_RELEASE);
#else
  ::munmap(addr, size);
#endif
}

// Changes the protection of [addr, addr + size) to the given F_PROT_* mask.
// Returns 0 on success and a non-zero value on failure.
inline int mprotect(void* addr, size_t size, int prot) {
#if defined(_WIN32)
  DWORD win_prot = PAGE_NOACCESS;
  if (prot & F_PROT_WRITE) {
    // There is no write-only page protection on Windows.
    win_prot = PAGE_READWRITE;
  } else if (prot & F_PROT_READ) {
    // Honour read-only requests, matching the POSIX branch below.
    win_prot = PAGE_READONLY;
  }
  DWORD old = 0;
  return VirtualProtect(addr, size, win_prot, &old) ? 0 : -1;
#else
  int native_prot = 0;
  if (prot == F_PROT_NONE) {
    native_prot = PROT_NONE;
  } else {
    if (prot & F_PROT_READ)
      native_prot |= PROT_READ;
    if (prot & F_PROT_WRITE)
      native_prot |= PROT_WRITE;
  }
  return ::mprotect(addr, size, native_prot);
#endif
}

// Allocates `size` bytes aligned to `alignment` and stores the pointer in
// *mem. Returns 0 on success and a non-zero value on failure.
// Memory must be released with openreg::free().
inline int alloc(void** mem, size_t alignment, size_t size) {
#ifdef _WIN32
  *mem = _aligned_malloc(size, alignment);
  return *mem ? 0 : -1;
#else
  return posix_memalign(mem, alignment, size);
#endif
}

// Releases memory obtained from openreg::alloc().
inline void free(void* mem) {
#ifdef _WIN32
  _aligned_free(mem);
#else
  ::free(mem);
#endif
}

// Returns the page size reported by the operating system, in bytes.
inline long get_pagesize() {
#ifdef _WIN32
  SYSTEM_INFO si;
  GetSystemInfo(&si);
  return static_cast<long>(si.dwPageSize);
#else
  return sysconf(_SC_PAGESIZE);
#endif
}

} // namespace openreg

View File

@ -2,6 +2,12 @@
#include <cstddef> #include <cstddef>
#ifdef _WIN32
#define OPENREG_EXPORT __declspec(dllexport)
#else
#define OPENREG_EXPORT __attribute__((visibility("default")))
#endif
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
@ -28,19 +34,19 @@ struct orPointerAttributes {
size_t size; size_t size;
}; };
orError_t orMalloc(void** devPtr, size_t size); OPENREG_EXPORT orError_t orMalloc(void** devPtr, size_t size);
orError_t orFree(void* devPtr); OPENREG_EXPORT orError_t orFree(void* devPtr);
orError_t orMallocHost(void** hostPtr, size_t size); OPENREG_EXPORT orError_t orMallocHost(void** hostPtr, size_t size);
orError_t orFreeHost(void* hostPtr); OPENREG_EXPORT orError_t orFreeHost(void* hostPtr);
orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind); OPENREG_EXPORT orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind);
orError_t orMemoryUnprotect(void* devPtr); OPENREG_EXPORT orError_t orMemoryUnprotect(void* devPtr);
orError_t orMemoryProtect(void* devPtr); OPENREG_EXPORT orError_t orMemoryProtect(void* devPtr);
orError_t orGetDeviceCount(int* count); OPENREG_EXPORT orError_t orGetDeviceCount(int* count);
orError_t orSetDevice(int device); OPENREG_EXPORT orError_t orSetDevice(int device);
orError_t orGetDevice(int* device); OPENREG_EXPORT orError_t orGetDevice(int* device);
orError_t orPointerGetAttributes( OPENREG_EXPORT orError_t orPointerGetAttributes(
orPointerAttributes* attributes, orPointerAttributes* attributes,
const void* ptr); const void* ptr);

View File

@ -1,5 +1,15 @@
import sys
import torch import torch
if sys.platform == "win32":
from ._utils import _load_dll_libraries
_load_dll_libraries()
del _load_dll_libraries
import torch_openreg._C # type: ignore[misc] import torch_openreg._C # type: ignore[misc]
import torch_openreg.openreg import torch_openreg.openreg

View File

@ -0,0 +1,42 @@
import ctypes
import glob
import os
def _load_dll_libraries() -> None:
    """Preload every DLL found in this package's ``lib`` directory (Windows).

    Each ``lib/*.dll`` is first loaded with ``LoadLibraryExW`` (when the
    running kernel32 exposes ``AddDllDirectory``). If that reports "module
    not found", or the API is unavailable, ``lib`` is prepended to ``PATH``
    once and the DLL is loaded with ``LoadLibraryW`` instead.

    Raises:
        OSError: a DLL (or one of its dependencies) could not be loaded.
    """
    # Win32 constants, named here so the calls below are readable.
    sem_failcriticalerrors = 0x0001
    # LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000) | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x100)
    load_library_search_flags = 0x00001100
    error_mod_not_found = 126

    openreg_dll_path = os.path.join(os.path.dirname(__file__), "lib")

    kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
    with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
    prev_error_mode = kernel32.SetErrorMode(sem_failcriticalerrors)

    # The error mode is process-wide state: restore it even when loading fails.
    try:
        kernel32.LoadLibraryW.restype = ctypes.c_void_p
        if with_load_library_flags:
            kernel32.LoadLibraryExW.restype = ctypes.c_void_p

        os.add_dll_directory(openreg_dll_path)

        dlls = glob.glob(os.path.join(openreg_dll_path, "*.dll"))
        path_patched = False
        for dll in dlls:
            is_loaded = False
            if with_load_library_flags:
                res = kernel32.LoadLibraryExW(dll, None, load_library_search_flags)
                last_error = ctypes.get_last_error()
                if res is None and last_error != error_mod_not_found:
                    err = ctypes.WinError(last_error)
                    err.strerror += f' Error loading "{dll}" or one of its dependencies.'
                    raise err
                elif res is not None:
                    is_loaded = True
            if not is_loaded:
                # Fallback: make the directory visible through PATH (once)
                # and retry with the plain loader.
                if not path_patched:
                    current_path = os.environ.get("PATH", "")
                    os.environ["PATH"] = ";".join(
                        [openreg_dll_path] + ([current_path] if current_path else [])
                    )
                    path_patched = True
                res = kernel32.LoadLibraryW(dll)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += f' Error loading "{dll}" or one of its dependencies.'
                    raise err
    finally:
        kernel32.SetErrorMode(prev_error_mode)

View File

@ -6,7 +6,19 @@ file(GLOB_RECURSE SOURCE_FILES
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python torch_openreg) target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python_library torch_openreg)
# Platform-specific handling of the Python symbols used by the bindings:
#  - Windows: link the Python library explicitly through the Python3::Python
#    imported target, which carries the library and its include directories,
#    instead of the legacy ${Python3_LIBRARIES} variable.
#  - macOS: leave undefined symbols to be resolved at load time.
if(WIN32)
  find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
  target_link_libraries(${LIBRARY_NAME} PRIVATE Python3::Python)
elseif(APPLE)
  set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif()
target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib) target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib)
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) install(TARGETS ${LIBRARY_NAME}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
)

View File

@ -90,7 +90,7 @@ static PyMethodDef methods[] = {
* Therefore, it cannot be named initModule here, otherwise initModule * Therefore, it cannot be named initModule here, otherwise initModule
* in torch/csrc/Module.cpp will be called, resulting in failure. * in torch/csrc/Module.cpp will be called, resulting in failure.
*/ */
extern "C" PyObject* initOpenRegModule(void) { extern "C" OPENREG_EXPORT PyObject* initOpenRegModule(void) {
static struct PyModuleDef openreg_C_module = { static struct PyModuleDef openreg_C_module = {
PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods}; PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods};
PyObject* mod = PyModule_Create(&openreg_C_module); PyObject* mod = PyModule_Create(&openreg_C_module);

View File

@ -1,13 +1,18 @@
#include <Python.h> #include <Python.h>
extern PyObject* initOpenRegModule(void); #ifdef _WIN32
#define OPENREG_EXPORT __declspec(dllexport)
#else
#define OPENREG_EXPORT __attribute__((visibility("default")))
#endif
extern OPENREG_EXPORT PyObject* initOpenRegModule(void);
#ifndef _WIN32
#ifdef __cplusplus #ifdef __cplusplus
extern "C" extern "C"
#endif #endif
__attribute__((visibility("default"))) PyObject* PyInit__C(void);
#endif OPENREG_EXPORT PyObject* PyInit__C(void);
PyMODINIT_FUNC PyInit__C(void) PyMODINIT_FUNC PyInit__C(void)
{ {

View File

@ -28,7 +28,6 @@ from torch.multiprocessing import current_process, get_context
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
get_report_path, get_report_path,
IS_CI, IS_CI,
IS_LINUX,
IS_MACOS, IS_MACOS,
retry_shell, retry_shell,
set_cwd, set_cwd,
@ -909,10 +908,6 @@ def _test_autoload(test_directory, options, enable=True):
def run_test_with_openreg(test_module, test_directory, options): def run_test_with_openreg(test_module, test_directory, options):
# TODO(FFFrog): Will remove this later when windows/macos are supported.
if not IS_LINUX:
return 0
openreg_dir = os.path.join( openreg_dir = os.path.join(
test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg" test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg"
) )

View File

@ -16,7 +16,9 @@ import torch
from torch.serialization import safe_globals from torch.serialization import safe_globals
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
run_tests, run_tests,
skipIfMPS,
skipIfTorchDynamo, skipIfTorchDynamo,
skipIfWindows,
skipIfXpu, skipIfXpu,
TemporaryFileName, TemporaryFileName,
TestCase, TestCase,
@ -284,6 +286,8 @@ class TestOpenReg(TestCase):
self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc] self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc]
# Autograd # Autograd
@skipIfMPS
@skipIfWindows()
def test_autograd_init(self): def test_autograd_init(self):
# Make sure autograd is initialized # Make sure autograd is initialized
torch.ones(2, requires_grad=True, device="openreg").sum().backward() torch.ones(2, requires_grad=True, device="openreg").sum().backward()