mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[OpenReg] Add OSX/Windows Support for OpenReg (#159441)
As the title states. **Changes:** - Abstract platform-specific APIs - Add OSX/Windows support - Set default symbol visibility to "hidden" Co-authored-by: @can-gaa-hou Original PR: https://github.com/pytorch/pytorch/pull/159029 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159441 Approved by: https://github.com/albanD Co-authored-by: jiahaochen666 <jiahaochen535@gmail.com>
This commit is contained in:
@ -4,28 +4,29 @@ project(TORCH_OPENREG CXX C)
|
||||
|
||||
include(GNUInstallDirs)
|
||||
include(CheckCXXCompilerFlag)
|
||||
include(CMakeDependentOption)
|
||||
|
||||
set(CMAKE_SKIP_BUILD_RPATH FALSE)
|
||||
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
|
||||
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
|
||||
set(CMAKE_INSTALL_RPATH "$ORIGIN/lib/:$ORIGIN/")
|
||||
|
||||
set(LINUX TRUE)
|
||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
set(CMAKE_INSTALL_LIBDIR lib)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
set(CMAKE_SKIP_BUILD_RPATH FALSE)
|
||||
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
|
||||
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
|
||||
set(CMAKE_CXX_VISIBILITY_PRESET hidden)
|
||||
|
||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
|
||||
if(APPLE)
|
||||
set(CMAKE_INSTALL_RPATH "@loader_path/lib;@loader_path")
|
||||
elseif(UNIX)
|
||||
set(CMAKE_INSTALL_RPATH "$ORIGIN/lib:$ORIGIN")
|
||||
elseif(WIN32)
|
||||
set(CMAKE_INSTALL_RPATH "")
|
||||
endif()
|
||||
set(CMAKE_INSTALL_LIBDIR lib)
|
||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||
|
||||
set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch)
|
||||
find_package(Torch REQUIRED)
|
||||
include_directories(${PYTORCH_INSTALL_DIR}/include)
|
||||
|
||||
if(DEFINED PYTHON_INCLUDE_DIR)
|
||||
include_directories(${PYTHON_INCLUDE_DIR})
|
||||
@ -33,6 +34,8 @@ else()
|
||||
message(FATAL_ERROR "Cannot find Python directory")
|
||||
endif()
|
||||
|
||||
include(${PROJECT_SOURCE_DIR}/cmake/TorchPythonTargets.cmake)
|
||||
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg)
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/csrc)
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc)
|
||||
|
@ -0,0 +1,22 @@
|
||||
# Declares imported targets for the prebuilt torch_python library found under
# ${PYTORCH_INSTALL_DIR}/lib so that other targets can link against it.

# Choose the platform-specific library file that will be passed to the linker.
if(WIN32)
  # NOTE(review): this is the .lib file, not a .dll; it is what consumers
  # receive through $<TARGET_FILE:torch_python> below -- presumably intentional.
  set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/torch_python.lib")
elseif(APPLE)
  set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.dylib")
else()
  set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so")
endif()

# torch_python: the prebuilt library itself, with its headers and the
# libraries it depends on (c10, torch_cpu) attached as usage requirements.
add_library(torch_python SHARED IMPORTED)
set_property(TARGET torch_python PROPERTY
  INTERFACE_INCLUDE_DIRECTORIES "${PYTORCH_INSTALL_DIR}/include")
set_property(TARGET torch_python PROPERTY
  INTERFACE_LINK_LIBRARIES "c10;torch_cpu")
set_property(TARGET torch_python PROPERTY
  IMPORTED_LOCATION "${TORCH_PYTHON_IMPORTED_LOCATION}")

# torch_python_library: interface wrapper that forwards torch_python's include
# directories and links the library file by path together with torch_python's
# own interface link libraries.
add_library(torch_python_library INTERFACE IMPORTED)
set_property(TARGET torch_python_library PROPERTY
  INTERFACE_INCLUDE_DIRECTORIES "\$<TARGET_PROPERTY:torch_python,INTERFACE_INCLUDE_DIRECTORIES>")
set_property(TARGET torch_python_library PROPERTY
  INTERFACE_LINK_LIBRARIES "\$<TARGET_FILE:torch_python>;\$<TARGET_PROPERTY:torch_python,INTERFACE_LINK_LIBRARIES>")
|
@ -6,7 +6,11 @@ file(GLOB_RECURSE SOURCE_FILES
|
||||
|
||||
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
|
||||
|
||||
target_link_libraries(${LIBRARY_NAME} PRIVATE openreg torch_cpu)
|
||||
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_cpu_library openreg)
|
||||
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
||||
install(TARGETS ${LIBRARY_NAME}
|
||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
)
|
||||
|
@ -30,7 +30,7 @@ int device_count_impl() {
|
||||
return count;
|
||||
}
|
||||
|
||||
c10::DeviceIndex device_count() noexcept {
|
||||
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
|
||||
// initialize number of devices only once
|
||||
static int count = []() {
|
||||
try {
|
||||
@ -49,17 +49,17 @@ c10::DeviceIndex device_count() noexcept {
|
||||
return static_cast<c10::DeviceIndex>(count);
|
||||
}
|
||||
|
||||
c10::DeviceIndex current_device() {
|
||||
OPENREG_EXPORT c10::DeviceIndex current_device() {
|
||||
c10::DeviceIndex cur_device = -1;
|
||||
GetDevice(&cur_device);
|
||||
return cur_device;
|
||||
}
|
||||
|
||||
void set_device(c10::DeviceIndex device) {
|
||||
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
|
||||
SetDevice(device);
|
||||
}
|
||||
|
||||
DeviceIndex ExchangeDevice(DeviceIndex device) {
|
||||
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
|
||||
int current_device = -1;
|
||||
orGetDevice(&current_device);
|
||||
|
||||
|
@ -1,5 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef _WIN32
|
||||
#define OPENREG_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define OPENREG_EXPORT __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
@ -7,10 +13,10 @@
|
||||
|
||||
namespace c10::openreg {
|
||||
|
||||
c10::DeviceIndex device_count() noexcept;
|
||||
DeviceIndex current_device();
|
||||
void set_device(c10::DeviceIndex device);
|
||||
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
|
||||
OPENREG_EXPORT c10::DeviceIndex current_device();
|
||||
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
|
||||
|
||||
DeviceIndex ExchangeDevice(DeviceIndex device);
|
||||
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
|
||||
|
||||
} // namespace c10::openreg
|
||||
|
@ -1,5 +1,6 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
@ -9,10 +10,23 @@ from distutils.command.clean import clean
|
||||
from setuptools import Extension, find_packages, setup
|
||||
|
||||
|
||||
# Env Variables
|
||||
IS_DARWIN = platform.system() == "Darwin"
|
||||
IS_WINDOWS = platform.system() == "Windows"
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
RUN_BUILD_DEPS = any(arg in {"clean", "dist_info"} for arg in sys.argv)
|
||||
|
||||
|
||||
def make_relative_rpath_args(path):
    """Build the linker arguments that add ``path`` as a relative rpath.

    Args:
        path: Directory, relative to the binary being linked, to put on the
            runtime library search path.

    Returns:
        A one-element list holding the ``-Wl,-rpath,...`` flag (anchored at
        ``@loader_path`` on macOS and ``$ORIGIN`` elsewhere), or an empty list
        on Windows, where no rpath flag is emitted.
    """
    if IS_DARWIN:
        rpath_prefix = "-Wl,-rpath,@loader_path/"
    elif IS_WINDOWS:
        rpath_prefix = None
    else:
        rpath_prefix = "-Wl,-rpath,$ORIGIN/"

    return [] if rpath_prefix is None else [rpath_prefix + path]
|
||||
|
||||
|
||||
def get_pytorch_dir():
|
||||
import torch
|
||||
|
||||
@ -39,9 +53,15 @@ def build_deps():
|
||||
".",
|
||||
"--target",
|
||||
"install",
|
||||
"--config",
|
||||
"Release",
|
||||
"--",
|
||||
]
|
||||
build_args += ["-j", str(multiprocessing.cpu_count())]
|
||||
|
||||
if IS_WINDOWS:
|
||||
build_args += ["/m:" + str(multiprocessing.cpu_count())]
|
||||
else:
|
||||
build_args += ["-j", str(multiprocessing.cpu_count())]
|
||||
|
||||
command = ["cmake"] + build_args
|
||||
subprocess.check_call(command, cwd=build_dir, env=os.environ)
|
||||
@ -64,19 +84,47 @@ def main():
|
||||
if not RUN_BUILD_DEPS:
|
||||
build_deps()
|
||||
|
||||
if IS_WINDOWS:
|
||||
# /NODEFAULTLIB makes sure we only link to DLL runtime
|
||||
# and matches the flags set for protobuf and ONNX
|
||||
extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"] + [
|
||||
*make_relative_rpath_args("lib")
|
||||
]
|
||||
# /MD links against DLL runtime
|
||||
# and matches the flags set for protobuf and ONNX
|
||||
# /EHsc is about standard C++ exception handling
|
||||
extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"]
|
||||
else:
|
||||
extra_link_args = [*make_relative_rpath_args("lib")]
|
||||
extra_compile_args = [
|
||||
"-Wall",
|
||||
"-Wextra",
|
||||
"-Wno-strict-overflow",
|
||||
"-Wno-unused-parameter",
|
||||
"-Wno-missing-field-initializers",
|
||||
"-Wno-unknown-pragmas",
|
||||
]
|
||||
|
||||
ext_modules = [
|
||||
Extension(
|
||||
name="torch_openreg._C",
|
||||
sources=["torch_openreg/csrc/stub.c"],
|
||||
language="c",
|
||||
extra_compile_args=["-g", "-Wall", "-Werror"],
|
||||
extra_compile_args=extra_compile_args,
|
||||
libraries=["torch_bindings"],
|
||||
library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")],
|
||||
extra_link_args=["-Wl,-rpath,$ORIGIN/lib"],
|
||||
extra_link_args=extra_link_args,
|
||||
)
|
||||
]
|
||||
|
||||
package_data = {"torch_openreg": ["lib/*.so*"]}
|
||||
package_data = {
|
||||
"torch_openreg": [
|
||||
"lib/*.so*",
|
||||
"lib/*.dylib*",
|
||||
"lib/*.dll",
|
||||
"lib/*.lib",
|
||||
]
|
||||
}
|
||||
|
||||
setup(
|
||||
packages=find_packages(),
|
||||
|
@ -8,4 +8,8 @@ add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
|
||||
|
||||
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
||||
install(TARGETS ${LIBRARY_NAME}
|
||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
)
|
||||
|
@ -1,38 +1,9 @@
|
||||
#include <include/openreg.h>
|
||||
#include "memory.h"
|
||||
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
|
||||
namespace openreg {
|
||||
namespace internal {
|
||||
|
||||
class ScopedMemoryProtector {
|
||||
public:
|
||||
ScopedMemoryProtector(const orPointerAttributes& info)
|
||||
: m_info(info), m_protected(false) {
|
||||
if (m_info.type == orMemoryType::orMemoryTypeDevice) {
|
||||
if (mprotect(m_info.pointer, m_info.size, PROT_READ | PROT_WRITE) ==
|
||||
0) {
|
||||
m_protected = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
~ScopedMemoryProtector() {
|
||||
if (m_protected) {
|
||||
mprotect(m_info.pointer, m_info.size, PROT_NONE);
|
||||
}
|
||||
}
|
||||
ScopedMemoryProtector(const ScopedMemoryProtector&) = delete;
|
||||
ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete;
|
||||
|
||||
private:
|
||||
orPointerAttributes m_info;
|
||||
bool m_protected;
|
||||
};
|
||||
namespace {
|
||||
|
||||
class MemoryManager {
|
||||
public:
|
||||
@ -46,7 +17,7 @@ class MemoryManager {
|
||||
return orErrorUnknown;
|
||||
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
long page_size = sysconf(_SC_PAGESIZE);
|
||||
long page_size = openreg::get_pagesize();
|
||||
size_t aligned_size = ((size - 1) / page_size + 1) * page_size;
|
||||
void* mem = nullptr;
|
||||
int current_device = -1;
|
||||
@ -54,21 +25,15 @@ class MemoryManager {
|
||||
if (type == orMemoryType::orMemoryTypeDevice) {
|
||||
orGetDevice(&current_device);
|
||||
|
||||
mem = mmap(
|
||||
nullptr,
|
||||
aligned_size,
|
||||
PROT_READ | PROT_WRITE,
|
||||
MAP_PRIVATE | MAP_ANONYMOUS,
|
||||
-1,
|
||||
0);
|
||||
if (mem == MAP_FAILED)
|
||||
mem = openreg::mmap(aligned_size);
|
||||
if (mem == nullptr)
|
||||
return orErrorUnknown;
|
||||
if (mprotect(mem, aligned_size, PROT_NONE) != 0) {
|
||||
munmap(mem, aligned_size);
|
||||
if (openreg::mprotect(mem, aligned_size, F_PROT_NONE) != 0) {
|
||||
openreg::munmap(mem, aligned_size);
|
||||
return orErrorUnknown;
|
||||
}
|
||||
} else {
|
||||
if (posix_memalign(&mem, page_size, aligned_size) != 0) {
|
||||
if (openreg::alloc(&mem, page_size, aligned_size) != 0) {
|
||||
return orErrorUnknown;
|
||||
}
|
||||
}
|
||||
@ -87,11 +52,12 @@ class MemoryManager {
|
||||
if (it == m_registry.end())
|
||||
return orErrorUnknown;
|
||||
const auto& info = it->second;
|
||||
|
||||
if (info.type == orMemoryType::orMemoryTypeDevice) {
|
||||
mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE);
|
||||
munmap(info.pointer, info.size);
|
||||
openreg::mprotect(info.pointer, info.size, F_PROT_READ | F_PROT_WRITE);
|
||||
openreg::munmap(info.pointer, info.size);
|
||||
} else {
|
||||
::free(info.pointer);
|
||||
openreg::free(info.pointer);
|
||||
}
|
||||
m_registry.erase(it);
|
||||
return orSuccess;
|
||||
@ -167,7 +133,8 @@ class MemoryManager {
|
||||
if (info.type != orMemoryType::orMemoryTypeDevice) {
|
||||
return orErrorUnknown;
|
||||
}
|
||||
if (mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE) != 0) {
|
||||
if (openreg::mprotect(
|
||||
info.pointer, info.size, F_PROT_READ | F_PROT_WRITE) != 0) {
|
||||
return orErrorUnknown;
|
||||
}
|
||||
return orSuccess;
|
||||
@ -179,49 +146,75 @@ class MemoryManager {
|
||||
if (info.type != orMemoryType::orMemoryTypeDevice) {
|
||||
return orErrorUnknown;
|
||||
}
|
||||
if (mprotect(info.pointer, info.size, PROT_NONE) != 0) {
|
||||
if (openreg::mprotect(info.pointer, info.size, F_PROT_NONE) != 0) {
|
||||
return orErrorUnknown;
|
||||
}
|
||||
return orSuccess;
|
||||
}
|
||||
|
||||
private:
|
||||
class ScopedMemoryProtector {
|
||||
public:
|
||||
ScopedMemoryProtector(const orPointerAttributes& info)
|
||||
: m_info(info), m_protected(false) {
|
||||
if (m_info.type == orMemoryType::orMemoryTypeDevice) {
|
||||
if (openreg::mprotect(
|
||||
m_info.pointer, m_info.size, F_PROT_READ | F_PROT_WRITE) == 0) {
|
||||
m_protected = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
~ScopedMemoryProtector() {
|
||||
if (m_protected) {
|
||||
openreg::mprotect(m_info.pointer, m_info.size, F_PROT_NONE);
|
||||
}
|
||||
}
|
||||
ScopedMemoryProtector(const ScopedMemoryProtector&) = delete;
|
||||
ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete;
|
||||
|
||||
private:
|
||||
orPointerAttributes m_info;
|
||||
bool m_protected;
|
||||
};
|
||||
|
||||
MemoryManager() = default;
|
||||
|
||||
orPointerAttributes getPointerInfo(const void* ptr) {
|
||||
auto it = m_registry.upper_bound(const_cast<void*>(ptr));
|
||||
if (it == m_registry.begin())
|
||||
return {};
|
||||
--it;
|
||||
const char* p_char = static_cast<const char*>(ptr);
|
||||
const char* base_char = static_cast<const char*>(it->first);
|
||||
if (p_char >= base_char && p_char < (base_char + it->second.size)) {
|
||||
return it->second;
|
||||
if (it != m_registry.begin()) {
|
||||
--it;
|
||||
const char* p_char = static_cast<const char*>(ptr);
|
||||
const char* base_char = static_cast<const char*>(it->first);
|
||||
if (p_char >= base_char && p_char < (base_char + it->second.size)) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
std::map<void*, orPointerAttributes> m_registry;
|
||||
std::mutex m_mutex;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace openreg
|
||||
} // namespace
|
||||
|
||||
orError_t orMalloc(void** devPtr, size_t size) {
|
||||
return openreg::internal::MemoryManager::getInstance().allocate(
|
||||
return MemoryManager::getInstance().allocate(
|
||||
devPtr, size, orMemoryType::orMemoryTypeDevice);
|
||||
}
|
||||
|
||||
orError_t orFree(void* devPtr) {
|
||||
return openreg::internal::MemoryManager::getInstance().free(devPtr);
|
||||
return MemoryManager::getInstance().free(devPtr);
|
||||
}
|
||||
|
||||
orError_t orMallocHost(void** hostPtr, size_t size) {
|
||||
return openreg::internal::MemoryManager::getInstance().allocate(
|
||||
return MemoryManager::getInstance().allocate(
|
||||
hostPtr, size, orMemoryType::orMemoryTypeHost);
|
||||
}
|
||||
|
||||
orError_t orFreeHost(void* hostPtr) {
|
||||
return openreg::internal::MemoryManager::getInstance().free(hostPtr);
|
||||
return MemoryManager::getInstance().free(hostPtr);
|
||||
}
|
||||
|
||||
orError_t orMemcpy(
|
||||
@ -229,21 +222,19 @@ orError_t orMemcpy(
|
||||
const void* src,
|
||||
size_t count,
|
||||
orMemcpyKind kind) {
|
||||
return openreg::internal::MemoryManager::getInstance().memcpy(
|
||||
dst, src, count, kind);
|
||||
return MemoryManager::getInstance().memcpy(dst, src, count, kind);
|
||||
}
|
||||
|
||||
orError_t orPointerGetAttributes(
|
||||
orPointerAttributes* attributes,
|
||||
const void* ptr) {
|
||||
return openreg::internal::MemoryManager::getInstance().getPointerAttributes(
|
||||
attributes, ptr);
|
||||
return MemoryManager::getInstance().getPointerAttributes(attributes, ptr);
|
||||
}
|
||||
|
||||
orError_t orMemoryUnprotect(void* devPtr) {
|
||||
return openreg::internal::MemoryManager::getInstance().unprotect(devPtr);
|
||||
return MemoryManager::getInstance().unprotect(devPtr);
|
||||
}
|
||||
|
||||
orError_t orMemoryProtect(void* devPtr) {
|
||||
return openreg::internal::MemoryManager::getInstance().protect(devPtr);
|
||||
return MemoryManager::getInstance().protect(devPtr);
|
||||
}
|
||||
|
98
test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.h
vendored
Normal file
98
test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.h
vendored
Normal file
@ -0,0 +1,98 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
#include <include/openreg.h>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
// Portable protection flags accepted by openreg::mprotect(). They are
// translated to the native constants inside that function.
#define F_PROT_NONE 0x0
#define F_PROT_READ 0x1
#define F_PROT_WRITE 0x2

namespace openreg {

// Every function in this namespace is defined in the header, so each one is
// declared `inline`; without it, including this header from more than one
// translation unit produces multiple-definition link errors.

// Maps `size` bytes of anonymous, private, read/write memory.
// Returns the base address, or nullptr on failure (the POSIX MAP_FAILED
// sentinel is normalized to nullptr).
inline void* mmap(size_t size) {
#if defined(_WIN32)
  return VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
#else
  void* addr = ::mmap(
      nullptr,
      size,
      PROT_READ | PROT_WRITE,
      MAP_PRIVATE | MAP_ANONYMOUS,
      -1,
      0);
  return (addr == MAP_FAILED) ? nullptr : addr;
#endif
}

// Releases a region previously obtained from openreg::mmap().
// `size` is only used on POSIX; on Windows the region is released as a whole.
inline void munmap(void* addr, size_t size) {
#if defined(_WIN32)
  (void)size; // VirtualFree with MEM_RELEASE takes a size of 0.
  VirtualFree(addr, 0, MEM_RELEASE);
#else
  ::munmap(addr, size);
#endif
}

// Changes the access protection of [addr, addr + size).
// `prot` is F_PROT_NONE or a combination of F_PROT_READ / F_PROT_WRITE.
// Returns 0 on success and a non-zero value on failure.
inline int mprotect(void* addr, size_t size, int prot) {
#if defined(_WIN32)
  DWORD win_prot = 0;
  DWORD old = 0;
  if (prot == F_PROT_NONE) {
    win_prot = PAGE_NOACCESS;
  } else if (prot == F_PROT_READ) {
    // Honor read-only requests, matching the POSIX branch below.
    win_prot = PAGE_READONLY;
  } else {
    // Any request that includes write access is mapped to read/write.
    win_prot = PAGE_READWRITE;
  }

  return VirtualProtect(addr, size, win_prot, &old) ? 0 : -1;
#else
  int native_prot = 0;
  if (prot == F_PROT_NONE)
    native_prot = PROT_NONE;
  else {
    if (prot & F_PROT_READ)
      native_prot |= PROT_READ;
    if (prot & F_PROT_WRITE)
      native_prot |= PROT_WRITE;
  }

  return ::mprotect(addr, size, native_prot);
#endif
}

// Allocates `size` bytes aligned to `alignment` and stores the pointer in
// `*mem`. Returns 0 on success and a non-zero value on failure.
// Memory obtained here must be released with openreg::free().
inline int alloc(void** mem, size_t alignment, size_t size) {
#ifdef _WIN32
  *mem = _aligned_malloc(size, alignment);
  return *mem ? 0 : -1;
#else
  return posix_memalign(mem, alignment, size);
#endif
}

// Releases memory obtained from openreg::alloc().
inline void free(void* mem) {
#ifdef _WIN32
  _aligned_free(mem);
#else
  ::free(mem);
#endif
}

// Returns the system page size in bytes.
inline long get_pagesize() {
#ifdef _WIN32
  SYSTEM_INFO si;
  GetSystemInfo(&si);
  return static_cast<long>(si.dwPageSize);
#else
  return sysconf(_SC_PAGESIZE);
#endif
}

} // namespace openreg
|
@ -2,6 +2,12 @@
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#ifdef _WIN32
|
||||
#define OPENREG_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define OPENREG_EXPORT __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@ -28,19 +34,19 @@ struct orPointerAttributes {
|
||||
size_t size;
|
||||
};
|
||||
|
||||
orError_t orMalloc(void** devPtr, size_t size);
|
||||
orError_t orFree(void* devPtr);
|
||||
orError_t orMallocHost(void** hostPtr, size_t size);
|
||||
orError_t orFreeHost(void* hostPtr);
|
||||
orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind);
|
||||
orError_t orMemoryUnprotect(void* devPtr);
|
||||
orError_t orMemoryProtect(void* devPtr);
|
||||
OPENREG_EXPORT orError_t orMalloc(void** devPtr, size_t size);
|
||||
OPENREG_EXPORT orError_t orFree(void* devPtr);
|
||||
OPENREG_EXPORT orError_t orMallocHost(void** hostPtr, size_t size);
|
||||
OPENREG_EXPORT orError_t orFreeHost(void* hostPtr);
|
||||
OPENREG_EXPORT orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind);
|
||||
OPENREG_EXPORT orError_t orMemoryUnprotect(void* devPtr);
|
||||
OPENREG_EXPORT orError_t orMemoryProtect(void* devPtr);
|
||||
|
||||
orError_t orGetDeviceCount(int* count);
|
||||
orError_t orSetDevice(int device);
|
||||
orError_t orGetDevice(int* device);
|
||||
OPENREG_EXPORT orError_t orGetDeviceCount(int* count);
|
||||
OPENREG_EXPORT orError_t orSetDevice(int device);
|
||||
OPENREG_EXPORT orError_t orGetDevice(int* device);
|
||||
|
||||
orError_t orPointerGetAttributes(
|
||||
OPENREG_EXPORT orError_t orPointerGetAttributes(
|
||||
orPointerAttributes* attributes,
|
||||
const void* ptr);
|
||||
|
||||
|
@ -1,5 +1,15 @@
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
if sys.platform == "win32":
|
||||
from ._utils import _load_dll_libraries
|
||||
|
||||
_load_dll_libraries()
|
||||
del _load_dll_libraries
|
||||
|
||||
|
||||
import torch_openreg._C # type: ignore[misc]
|
||||
import torch_openreg.openreg
|
||||
|
||||
|
@ -0,0 +1,42 @@
|
||||
import ctypes
|
||||
import glob
|
||||
import os
|
||||
|
||||
|
||||
def _load_dll_libraries() -> None:
    """Preload every ``*.dll`` found in this package's ``lib`` directory.

    Windows only. Each DLL is first loaded with ``LoadLibraryExW`` (when the
    running kernel32 exposes ``AddDllDirectory``); if that reports "module not
    found", the ``lib`` directory is prepended to ``PATH`` and the load is
    retried with ``LoadLibraryW``.

    Raises:
        OSError: if a DLL (or one of its dependencies) cannot be loaded.
    """
    openreg_dll_path = os.path.join(os.path.dirname(__file__), "lib")

    kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
    with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
    # 0x0001 == SEM_FAILCRITICALERRORS: report critical errors to the caller
    # instead of showing a system dialog while the DLLs are probed.
    prev_error_mode = kernel32.SetErrorMode(0x0001)

    try:
        kernel32.LoadLibraryW.restype = ctypes.c_void_p
        if with_load_library_flags:
            kernel32.LoadLibraryExW.restype = ctypes.c_void_p

        os.add_dll_directory(openreg_dll_path)

        dlls = glob.glob(os.path.join(openreg_dll_path, "*.dll"))
        path_patched = False
        for dll in dlls:
            is_loaded = False
            if with_load_library_flags:
                # 0x00001100 == LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
                #               LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
                res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
                last_error = ctypes.get_last_error()
                # 126 == ERROR_MOD_NOT_FOUND: tolerated here so the PATH-based
                # fallback below gets a chance; any other failure is fatal.
                if res is None and last_error != 126:
                    err = ctypes.WinError(last_error)
                    err.strerror += f' Error loading "{dll}" or one of its dependencies.'
                    raise err
                elif res is not None:
                    is_loaded = True
            if not is_loaded:
                if not path_patched:
                    # PATH may be unset; avoid a KeyError and a dangling ";".
                    current_path = os.environ.get("PATH", "")
                    os.environ["PATH"] = (
                        ";".join([openreg_dll_path, current_path])
                        if current_path
                        else openreg_dll_path
                    )
                    path_patched = True
                res = kernel32.LoadLibraryW(dll)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += f' Error loading "{dll}" or one of its dependencies.'
                    raise err
    finally:
        # Restore the process error mode even when a load failure is raised.
        kernel32.SetErrorMode(prev_error_mode)
|
@ -6,7 +6,19 @@ file(GLOB_RECURSE SOURCE_FILES
|
||||
|
||||
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
|
||||
|
||||
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python torch_openreg)
|
||||
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python_library torch_openreg)
|
||||
|
||||
if(WIN32)
|
||||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
||||
target_link_libraries(${LIBRARY_NAME} PRIVATE ${Python3_LIBRARIES})
|
||||
elseif(APPLE)
|
||||
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
|
||||
endif()
|
||||
|
||||
target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib)
|
||||
|
||||
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
||||
install(TARGETS ${LIBRARY_NAME}
|
||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
)
|
||||
|
@ -90,7 +90,7 @@ static PyMethodDef methods[] = {
|
||||
* Therefore, it cannot be named initModule here, otherwise initModule
|
||||
* in torch/csrc/Module.cpp will be called, resulting in failure.
|
||||
*/
|
||||
extern "C" PyObject* initOpenRegModule(void) {
|
||||
extern "C" OPENREG_EXPORT PyObject* initOpenRegModule(void) {
|
||||
static struct PyModuleDef openreg_C_module = {
|
||||
PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods};
|
||||
PyObject* mod = PyModule_Create(&openreg_C_module);
|
||||
|
@ -1,13 +1,18 @@
|
||||
#include <Python.h>
|
||||
|
||||
extern PyObject* initOpenRegModule(void);
|
||||
#ifdef _WIN32
|
||||
#define OPENREG_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define OPENREG_EXPORT __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
extern OPENREG_EXPORT PyObject* initOpenRegModule(void);
|
||||
|
||||
#ifndef _WIN32
|
||||
#ifdef __cplusplus
|
||||
extern "C"
|
||||
#endif
|
||||
__attribute__((visibility("default"))) PyObject* PyInit__C(void);
|
||||
#endif
|
||||
|
||||
OPENREG_EXPORT PyObject* PyInit__C(void);
|
||||
|
||||
PyMODINIT_FUNC PyInit__C(void)
|
||||
{
|
||||
|
@ -28,7 +28,6 @@ from torch.multiprocessing import current_process, get_context
|
||||
from torch.testing._internal.common_utils import (
|
||||
get_report_path,
|
||||
IS_CI,
|
||||
IS_LINUX,
|
||||
IS_MACOS,
|
||||
retry_shell,
|
||||
set_cwd,
|
||||
@ -909,10 +908,6 @@ def _test_autoload(test_directory, options, enable=True):
|
||||
|
||||
|
||||
def run_test_with_openreg(test_module, test_directory, options):
|
||||
# TODO(FFFrog): Will remove this later when windows/macos are supported.
|
||||
if not IS_LINUX:
|
||||
return 0
|
||||
|
||||
openreg_dir = os.path.join(
|
||||
test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg"
|
||||
)
|
||||
|
@ -16,7 +16,9 @@ import torch
|
||||
from torch.serialization import safe_globals
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skipIfMPS,
|
||||
skipIfTorchDynamo,
|
||||
skipIfWindows,
|
||||
skipIfXpu,
|
||||
TemporaryFileName,
|
||||
TestCase,
|
||||
@ -284,6 +286,8 @@ class TestOpenReg(TestCase):
|
||||
self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc]
|
||||
|
||||
# Autograd
|
||||
@skipIfMPS
|
||||
@skipIfWindows()
|
||||
def test_autograd_init(self):
|
||||
# Make sure autograd is initialized
|
||||
torch.ones(2, requires_grad=True, device="openreg").sum().backward()
|
||||
|
Reference in New Issue
Block a user