mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[OpenReg] Add OSX/Windows Support for OpenReg (#159441)
As the title stated. **Changes:** - Abstract platform-specific APIs - Add OSX/Windows support - Set default symbol visibility to "hidden" Co-authored-by: @can-gaa-hou Original PR:https://github.com/pytorch/pytorch/pull/159029 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159441 Approved by: https://github.com/albanD Co-authored-by: jiahaochen666 <jiahaochen535@gmail.com>
This commit is contained in:
@ -4,28 +4,29 @@ project(TORCH_OPENREG CXX C)
|
|||||||
|
|
||||||
include(GNUInstallDirs)
|
include(GNUInstallDirs)
|
||||||
include(CheckCXXCompilerFlag)
|
include(CheckCXXCompilerFlag)
|
||||||
include(CMakeDependentOption)
|
|
||||||
|
|
||||||
set(CMAKE_SKIP_BUILD_RPATH FALSE)
|
|
||||||
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
|
|
||||||
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
|
|
||||||
set(CMAKE_INSTALL_RPATH "$ORIGIN/lib/:$ORIGIN/")
|
|
||||||
|
|
||||||
set(LINUX TRUE)
|
|
||||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
|
||||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
|
||||||
|
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
set(CMAKE_C_STANDARD 11)
|
set(CMAKE_C_STANDARD 11)
|
||||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||||
|
|
||||||
set(CMAKE_INSTALL_LIBDIR lib)
|
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
set(CMAKE_SKIP_BUILD_RPATH FALSE)
|
||||||
|
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
|
||||||
|
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
|
||||||
|
set(CMAKE_CXX_VISIBILITY_PRESET hidden)
|
||||||
|
|
||||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
|
if(APPLE)
|
||||||
|
set(CMAKE_INSTALL_RPATH "@loader_path/lib;@loader_path")
|
||||||
|
elseif(UNIX)
|
||||||
|
set(CMAKE_INSTALL_RPATH "$ORIGIN/lib:$ORIGIN")
|
||||||
|
elseif(WIN32)
|
||||||
|
set(CMAKE_INSTALL_RPATH "")
|
||||||
|
endif()
|
||||||
|
set(CMAKE_INSTALL_LIBDIR lib)
|
||||||
|
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||||
|
|
||||||
set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch)
|
set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch)
|
||||||
find_package(Torch REQUIRED)
|
find_package(Torch REQUIRED)
|
||||||
include_directories(${PYTORCH_INSTALL_DIR}/include)
|
|
||||||
|
|
||||||
if(DEFINED PYTHON_INCLUDE_DIR)
|
if(DEFINED PYTHON_INCLUDE_DIR)
|
||||||
include_directories(${PYTHON_INCLUDE_DIR})
|
include_directories(${PYTHON_INCLUDE_DIR})
|
||||||
@ -33,6 +34,8 @@ else()
|
|||||||
message(FATAL_ERROR "Cannot find Python directory")
|
message(FATAL_ERROR "Cannot find Python directory")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
include(${PROJECT_SOURCE_DIR}/cmake/TorchPythonTargets.cmake)
|
||||||
|
|
||||||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg)
|
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg)
|
||||||
add_subdirectory(${PROJECT_SOURCE_DIR}/csrc)
|
add_subdirectory(${PROJECT_SOURCE_DIR}/csrc)
|
||||||
add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc)
|
add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc)
|
||||||
|
@ -0,0 +1,22 @@
|
|||||||
|
if(WIN32)
|
||||||
|
set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/torch_python.lib")
|
||||||
|
elseif(APPLE)
|
||||||
|
set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.dylib")
|
||||||
|
else()
|
||||||
|
set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
add_library(torch_python SHARED IMPORTED)
|
||||||
|
|
||||||
|
set_target_properties(torch_python PROPERTIES
|
||||||
|
INTERFACE_INCLUDE_DIRECTORIES "${PYTORCH_INSTALL_DIR}/include"
|
||||||
|
INTERFACE_LINK_LIBRARIES "c10;torch_cpu"
|
||||||
|
IMPORTED_LOCATION "${TORCH_PYTHON_IMPORTED_LOCATION}"
|
||||||
|
)
|
||||||
|
|
||||||
|
add_library(torch_python_library INTERFACE IMPORTED)
|
||||||
|
|
||||||
|
set_target_properties(torch_python_library PROPERTIES
|
||||||
|
INTERFACE_INCLUDE_DIRECTORIES "\$<TARGET_PROPERTY:torch_python,INTERFACE_INCLUDE_DIRECTORIES>"
|
||||||
|
INTERFACE_LINK_LIBRARIES "\$<TARGET_FILE:torch_python>;\$<TARGET_PROPERTY:torch_python,INTERFACE_LINK_LIBRARIES>"
|
||||||
|
)
|
@ -6,7 +6,11 @@ file(GLOB_RECURSE SOURCE_FILES
|
|||||||
|
|
||||||
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
|
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
|
||||||
|
|
||||||
target_link_libraries(${LIBRARY_NAME} PRIVATE openreg torch_cpu)
|
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_cpu_library openreg)
|
||||||
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
|
|
||||||
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
install(TARGETS ${LIBRARY_NAME}
|
||||||
|
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
)
|
||||||
|
@ -30,7 +30,7 @@ int device_count_impl() {
|
|||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
c10::DeviceIndex device_count() noexcept {
|
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
|
||||||
// initialize number of devices only once
|
// initialize number of devices only once
|
||||||
static int count = []() {
|
static int count = []() {
|
||||||
try {
|
try {
|
||||||
@ -49,17 +49,17 @@ c10::DeviceIndex device_count() noexcept {
|
|||||||
return static_cast<c10::DeviceIndex>(count);
|
return static_cast<c10::DeviceIndex>(count);
|
||||||
}
|
}
|
||||||
|
|
||||||
c10::DeviceIndex current_device() {
|
OPENREG_EXPORT c10::DeviceIndex current_device() {
|
||||||
c10::DeviceIndex cur_device = -1;
|
c10::DeviceIndex cur_device = -1;
|
||||||
GetDevice(&cur_device);
|
GetDevice(&cur_device);
|
||||||
return cur_device;
|
return cur_device;
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_device(c10::DeviceIndex device) {
|
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
|
||||||
SetDevice(device);
|
SetDevice(device);
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceIndex ExchangeDevice(DeviceIndex device) {
|
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
|
||||||
int current_device = -1;
|
int current_device = -1;
|
||||||
orGetDevice(¤t_device);
|
orGetDevice(¤t_device);
|
||||||
|
|
||||||
|
@ -1,5 +1,11 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define OPENREG_EXPORT __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define OPENREG_EXPORT __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <c10/core/Device.h>
|
#include <c10/core/Device.h>
|
||||||
#include <c10/macros/Macros.h>
|
#include <c10/macros/Macros.h>
|
||||||
|
|
||||||
@ -7,10 +13,10 @@
|
|||||||
|
|
||||||
namespace c10::openreg {
|
namespace c10::openreg {
|
||||||
|
|
||||||
c10::DeviceIndex device_count() noexcept;
|
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
|
||||||
DeviceIndex current_device();
|
OPENREG_EXPORT c10::DeviceIndex current_device();
|
||||||
void set_device(c10::DeviceIndex device);
|
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
|
||||||
|
|
||||||
DeviceIndex ExchangeDevice(DeviceIndex device);
|
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
|
||||||
|
|
||||||
} // namespace c10::openreg
|
} // namespace c10::openreg
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
@ -9,10 +10,23 @@ from distutils.command.clean import clean
|
|||||||
from setuptools import Extension, find_packages, setup
|
from setuptools import Extension, find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
|
# Env Variables
|
||||||
|
IS_DARWIN = platform.system() == "Darwin"
|
||||||
|
IS_WINDOWS = platform.system() == "Windows"
|
||||||
|
|
||||||
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
|
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
RUN_BUILD_DEPS = any(arg in {"clean", "dist_info"} for arg in sys.argv)
|
RUN_BUILD_DEPS = any(arg in {"clean", "dist_info"} for arg in sys.argv)
|
||||||
|
|
||||||
|
|
||||||
|
def make_relative_rpath_args(path):
|
||||||
|
if IS_DARWIN:
|
||||||
|
return ["-Wl,-rpath,@loader_path/" + path]
|
||||||
|
elif IS_WINDOWS:
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
return ["-Wl,-rpath,$ORIGIN/" + path]
|
||||||
|
|
||||||
|
|
||||||
def get_pytorch_dir():
|
def get_pytorch_dir():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -39,9 +53,15 @@ def build_deps():
|
|||||||
".",
|
".",
|
||||||
"--target",
|
"--target",
|
||||||
"install",
|
"install",
|
||||||
|
"--config",
|
||||||
|
"Release",
|
||||||
"--",
|
"--",
|
||||||
]
|
]
|
||||||
build_args += ["-j", str(multiprocessing.cpu_count())]
|
|
||||||
|
if IS_WINDOWS:
|
||||||
|
build_args += ["/m:" + str(multiprocessing.cpu_count())]
|
||||||
|
else:
|
||||||
|
build_args += ["-j", str(multiprocessing.cpu_count())]
|
||||||
|
|
||||||
command = ["cmake"] + build_args
|
command = ["cmake"] + build_args
|
||||||
subprocess.check_call(command, cwd=build_dir, env=os.environ)
|
subprocess.check_call(command, cwd=build_dir, env=os.environ)
|
||||||
@ -64,19 +84,47 @@ def main():
|
|||||||
if not RUN_BUILD_DEPS:
|
if not RUN_BUILD_DEPS:
|
||||||
build_deps()
|
build_deps()
|
||||||
|
|
||||||
|
if IS_WINDOWS:
|
||||||
|
# /NODEFAULTLIB makes sure we only link to DLL runtime
|
||||||
|
# and matches the flags set for protobuf and ONNX
|
||||||
|
extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"] + [
|
||||||
|
*make_relative_rpath_args("lib")
|
||||||
|
]
|
||||||
|
# /MD links against DLL runtime
|
||||||
|
# and matches the flags set for protobuf and ONNX
|
||||||
|
# /EHsc is about standard C++ exception handling
|
||||||
|
extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"]
|
||||||
|
else:
|
||||||
|
extra_link_args = [*make_relative_rpath_args("lib")]
|
||||||
|
extra_compile_args = [
|
||||||
|
"-Wall",
|
||||||
|
"-Wextra",
|
||||||
|
"-Wno-strict-overflow",
|
||||||
|
"-Wno-unused-parameter",
|
||||||
|
"-Wno-missing-field-initializers",
|
||||||
|
"-Wno-unknown-pragmas",
|
||||||
|
]
|
||||||
|
|
||||||
ext_modules = [
|
ext_modules = [
|
||||||
Extension(
|
Extension(
|
||||||
name="torch_openreg._C",
|
name="torch_openreg._C",
|
||||||
sources=["torch_openreg/csrc/stub.c"],
|
sources=["torch_openreg/csrc/stub.c"],
|
||||||
language="c",
|
language="c",
|
||||||
extra_compile_args=["-g", "-Wall", "-Werror"],
|
extra_compile_args=extra_compile_args,
|
||||||
libraries=["torch_bindings"],
|
libraries=["torch_bindings"],
|
||||||
library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")],
|
library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")],
|
||||||
extra_link_args=["-Wl,-rpath,$ORIGIN/lib"],
|
extra_link_args=extra_link_args,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
package_data = {"torch_openreg": ["lib/*.so*"]}
|
package_data = {
|
||||||
|
"torch_openreg": [
|
||||||
|
"lib/*.so*",
|
||||||
|
"lib/*.dylib*",
|
||||||
|
"lib/*.dll",
|
||||||
|
"lib/*.lib",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
|
@ -8,4 +8,8 @@ add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
|
|||||||
|
|
||||||
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
|
|
||||||
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
install(TARGETS ${LIBRARY_NAME}
|
||||||
|
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
)
|
||||||
|
@ -1,38 +1,9 @@
|
|||||||
#include <include/openreg.h>
|
#include "memory.h"
|
||||||
|
|
||||||
#include <sys/mman.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <cstring>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
|
||||||
namespace openreg {
|
namespace {
|
||||||
namespace internal {
|
|
||||||
|
|
||||||
class ScopedMemoryProtector {
|
|
||||||
public:
|
|
||||||
ScopedMemoryProtector(const orPointerAttributes& info)
|
|
||||||
: m_info(info), m_protected(false) {
|
|
||||||
if (m_info.type == orMemoryType::orMemoryTypeDevice) {
|
|
||||||
if (mprotect(m_info.pointer, m_info.size, PROT_READ | PROT_WRITE) ==
|
|
||||||
0) {
|
|
||||||
m_protected = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
~ScopedMemoryProtector() {
|
|
||||||
if (m_protected) {
|
|
||||||
mprotect(m_info.pointer, m_info.size, PROT_NONE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ScopedMemoryProtector(const ScopedMemoryProtector&) = delete;
|
|
||||||
ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete;
|
|
||||||
|
|
||||||
private:
|
|
||||||
orPointerAttributes m_info;
|
|
||||||
bool m_protected;
|
|
||||||
};
|
|
||||||
|
|
||||||
class MemoryManager {
|
class MemoryManager {
|
||||||
public:
|
public:
|
||||||
@ -46,7 +17,7 @@ class MemoryManager {
|
|||||||
return orErrorUnknown;
|
return orErrorUnknown;
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(m_mutex);
|
std::lock_guard<std::mutex> lock(m_mutex);
|
||||||
long page_size = sysconf(_SC_PAGESIZE);
|
long page_size = openreg::get_pagesize();
|
||||||
size_t aligned_size = ((size - 1) / page_size + 1) * page_size;
|
size_t aligned_size = ((size - 1) / page_size + 1) * page_size;
|
||||||
void* mem = nullptr;
|
void* mem = nullptr;
|
||||||
int current_device = -1;
|
int current_device = -1;
|
||||||
@ -54,21 +25,15 @@ class MemoryManager {
|
|||||||
if (type == orMemoryType::orMemoryTypeDevice) {
|
if (type == orMemoryType::orMemoryTypeDevice) {
|
||||||
orGetDevice(¤t_device);
|
orGetDevice(¤t_device);
|
||||||
|
|
||||||
mem = mmap(
|
mem = openreg::mmap(aligned_size);
|
||||||
nullptr,
|
if (mem == nullptr)
|
||||||
aligned_size,
|
|
||||||
PROT_READ | PROT_WRITE,
|
|
||||||
MAP_PRIVATE | MAP_ANONYMOUS,
|
|
||||||
-1,
|
|
||||||
0);
|
|
||||||
if (mem == MAP_FAILED)
|
|
||||||
return orErrorUnknown;
|
return orErrorUnknown;
|
||||||
if (mprotect(mem, aligned_size, PROT_NONE) != 0) {
|
if (openreg::mprotect(mem, aligned_size, F_PROT_NONE) != 0) {
|
||||||
munmap(mem, aligned_size);
|
openreg::munmap(mem, aligned_size);
|
||||||
return orErrorUnknown;
|
return orErrorUnknown;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (posix_memalign(&mem, page_size, aligned_size) != 0) {
|
if (openreg::alloc(&mem, page_size, aligned_size) != 0) {
|
||||||
return orErrorUnknown;
|
return orErrorUnknown;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -87,11 +52,12 @@ class MemoryManager {
|
|||||||
if (it == m_registry.end())
|
if (it == m_registry.end())
|
||||||
return orErrorUnknown;
|
return orErrorUnknown;
|
||||||
const auto& info = it->second;
|
const auto& info = it->second;
|
||||||
|
|
||||||
if (info.type == orMemoryType::orMemoryTypeDevice) {
|
if (info.type == orMemoryType::orMemoryTypeDevice) {
|
||||||
mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE);
|
openreg::mprotect(info.pointer, info.size, F_PROT_READ | F_PROT_WRITE);
|
||||||
munmap(info.pointer, info.size);
|
openreg::munmap(info.pointer, info.size);
|
||||||
} else {
|
} else {
|
||||||
::free(info.pointer);
|
openreg::free(info.pointer);
|
||||||
}
|
}
|
||||||
m_registry.erase(it);
|
m_registry.erase(it);
|
||||||
return orSuccess;
|
return orSuccess;
|
||||||
@ -167,7 +133,8 @@ class MemoryManager {
|
|||||||
if (info.type != orMemoryType::orMemoryTypeDevice) {
|
if (info.type != orMemoryType::orMemoryTypeDevice) {
|
||||||
return orErrorUnknown;
|
return orErrorUnknown;
|
||||||
}
|
}
|
||||||
if (mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE) != 0) {
|
if (openreg::mprotect(
|
||||||
|
info.pointer, info.size, F_PROT_READ | F_PROT_WRITE) != 0) {
|
||||||
return orErrorUnknown;
|
return orErrorUnknown;
|
||||||
}
|
}
|
||||||
return orSuccess;
|
return orSuccess;
|
||||||
@ -179,49 +146,75 @@ class MemoryManager {
|
|||||||
if (info.type != orMemoryType::orMemoryTypeDevice) {
|
if (info.type != orMemoryType::orMemoryTypeDevice) {
|
||||||
return orErrorUnknown;
|
return orErrorUnknown;
|
||||||
}
|
}
|
||||||
if (mprotect(info.pointer, info.size, PROT_NONE) != 0) {
|
if (openreg::mprotect(info.pointer, info.size, F_PROT_NONE) != 0) {
|
||||||
return orErrorUnknown;
|
return orErrorUnknown;
|
||||||
}
|
}
|
||||||
return orSuccess;
|
return orSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
class ScopedMemoryProtector {
|
||||||
|
public:
|
||||||
|
ScopedMemoryProtector(const orPointerAttributes& info)
|
||||||
|
: m_info(info), m_protected(false) {
|
||||||
|
if (m_info.type == orMemoryType::orMemoryTypeDevice) {
|
||||||
|
if (openreg::mprotect(
|
||||||
|
m_info.pointer, m_info.size, F_PROT_READ | F_PROT_WRITE) == 0) {
|
||||||
|
m_protected = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
~ScopedMemoryProtector() {
|
||||||
|
if (m_protected) {
|
||||||
|
openreg::mprotect(m_info.pointer, m_info.size, F_PROT_NONE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ScopedMemoryProtector(const ScopedMemoryProtector&) = delete;
|
||||||
|
ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete;
|
||||||
|
|
||||||
|
private:
|
||||||
|
orPointerAttributes m_info;
|
||||||
|
bool m_protected;
|
||||||
|
};
|
||||||
|
|
||||||
MemoryManager() = default;
|
MemoryManager() = default;
|
||||||
|
|
||||||
orPointerAttributes getPointerInfo(const void* ptr) {
|
orPointerAttributes getPointerInfo(const void* ptr) {
|
||||||
auto it = m_registry.upper_bound(const_cast<void*>(ptr));
|
auto it = m_registry.upper_bound(const_cast<void*>(ptr));
|
||||||
if (it == m_registry.begin())
|
if (it != m_registry.begin()) {
|
||||||
return {};
|
--it;
|
||||||
--it;
|
const char* p_char = static_cast<const char*>(ptr);
|
||||||
const char* p_char = static_cast<const char*>(ptr);
|
const char* base_char = static_cast<const char*>(it->first);
|
||||||
const char* base_char = static_cast<const char*>(it->first);
|
if (p_char >= base_char && p_char < (base_char + it->second.size)) {
|
||||||
if (p_char >= base_char && p_char < (base_char + it->second.size)) {
|
return it->second;
|
||||||
return it->second;
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<void*, orPointerAttributes> m_registry;
|
std::map<void*, orPointerAttributes> m_registry;
|
||||||
std::mutex m_mutex;
|
std::mutex m_mutex;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace
|
||||||
} // namespace openreg
|
|
||||||
|
|
||||||
orError_t orMalloc(void** devPtr, size_t size) {
|
orError_t orMalloc(void** devPtr, size_t size) {
|
||||||
return openreg::internal::MemoryManager::getInstance().allocate(
|
return MemoryManager::getInstance().allocate(
|
||||||
devPtr, size, orMemoryType::orMemoryTypeDevice);
|
devPtr, size, orMemoryType::orMemoryTypeDevice);
|
||||||
}
|
}
|
||||||
|
|
||||||
orError_t orFree(void* devPtr) {
|
orError_t orFree(void* devPtr) {
|
||||||
return openreg::internal::MemoryManager::getInstance().free(devPtr);
|
return MemoryManager::getInstance().free(devPtr);
|
||||||
}
|
}
|
||||||
|
|
||||||
orError_t orMallocHost(void** hostPtr, size_t size) {
|
orError_t orMallocHost(void** hostPtr, size_t size) {
|
||||||
return openreg::internal::MemoryManager::getInstance().allocate(
|
return MemoryManager::getInstance().allocate(
|
||||||
hostPtr, size, orMemoryType::orMemoryTypeHost);
|
hostPtr, size, orMemoryType::orMemoryTypeHost);
|
||||||
}
|
}
|
||||||
|
|
||||||
orError_t orFreeHost(void* hostPtr) {
|
orError_t orFreeHost(void* hostPtr) {
|
||||||
return openreg::internal::MemoryManager::getInstance().free(hostPtr);
|
return MemoryManager::getInstance().free(hostPtr);
|
||||||
}
|
}
|
||||||
|
|
||||||
orError_t orMemcpy(
|
orError_t orMemcpy(
|
||||||
@ -229,21 +222,19 @@ orError_t orMemcpy(
|
|||||||
const void* src,
|
const void* src,
|
||||||
size_t count,
|
size_t count,
|
||||||
orMemcpyKind kind) {
|
orMemcpyKind kind) {
|
||||||
return openreg::internal::MemoryManager::getInstance().memcpy(
|
return MemoryManager::getInstance().memcpy(dst, src, count, kind);
|
||||||
dst, src, count, kind);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
orError_t orPointerGetAttributes(
|
orError_t orPointerGetAttributes(
|
||||||
orPointerAttributes* attributes,
|
orPointerAttributes* attributes,
|
||||||
const void* ptr) {
|
const void* ptr) {
|
||||||
return openreg::internal::MemoryManager::getInstance().getPointerAttributes(
|
return MemoryManager::getInstance().getPointerAttributes(attributes, ptr);
|
||||||
attributes, ptr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
orError_t orMemoryUnprotect(void* devPtr) {
|
orError_t orMemoryUnprotect(void* devPtr) {
|
||||||
return openreg::internal::MemoryManager::getInstance().unprotect(devPtr);
|
return MemoryManager::getInstance().unprotect(devPtr);
|
||||||
}
|
}
|
||||||
|
|
||||||
orError_t orMemoryProtect(void* devPtr) {
|
orError_t orMemoryProtect(void* devPtr) {
|
||||||
return openreg::internal::MemoryManager::getInstance().protect(devPtr);
|
return MemoryManager::getInstance().protect(devPtr);
|
||||||
}
|
}
|
||||||
|
98
test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.h
vendored
Normal file
98
test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.h
vendored
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
|
#include <include/openreg.h>
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
#include <windows.h>
|
||||||
|
#else
|
||||||
|
#include <sys/mman.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define F_PROT_NONE 0x0
|
||||||
|
#define F_PROT_READ 0x1
|
||||||
|
#define F_PROT_WRITE 0x2
|
||||||
|
|
||||||
|
namespace openreg {
|
||||||
|
|
||||||
|
void* mmap(size_t size) {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
return VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
|
||||||
|
#else
|
||||||
|
void* addr = ::mmap(
|
||||||
|
nullptr,
|
||||||
|
size,
|
||||||
|
PROT_READ | PROT_WRITE,
|
||||||
|
MAP_PRIVATE | MAP_ANONYMOUS,
|
||||||
|
-1,
|
||||||
|
0);
|
||||||
|
return (addr == MAP_FAILED) ? nullptr : addr;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void munmap(void* addr, size_t size) {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
VirtualFree(addr, 0, MEM_RELEASE);
|
||||||
|
#else
|
||||||
|
::munmap(addr, size);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
int mprotect(void* addr, size_t size, int prot) {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
DWORD win_prot = 0;
|
||||||
|
DWORD old;
|
||||||
|
if (prot == F_PROT_NONE) {
|
||||||
|
win_prot = PAGE_NOACCESS;
|
||||||
|
} else {
|
||||||
|
win_prot = PAGE_READWRITE;
|
||||||
|
}
|
||||||
|
|
||||||
|
return VirtualProtect(addr, size, win_prot, &old) ? 0 : -1;
|
||||||
|
#else
|
||||||
|
int native_prot = 0;
|
||||||
|
if (prot == F_PROT_NONE)
|
||||||
|
native_prot = PROT_NONE;
|
||||||
|
else {
|
||||||
|
if (prot & F_PROT_READ)
|
||||||
|
native_prot |= PROT_READ;
|
||||||
|
if (prot & F_PROT_WRITE)
|
||||||
|
native_prot |= PROT_WRITE;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ::mprotect(addr, size, native_prot);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
int alloc(void** mem, size_t alignment, size_t size) {
|
||||||
|
#ifdef _WIN32
|
||||||
|
*mem = _aligned_malloc(size, alignment);
|
||||||
|
return *mem ? 0 : -1;
|
||||||
|
#else
|
||||||
|
return posix_memalign(mem, alignment, size);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void free(void* mem) {
|
||||||
|
#ifdef _WIN32
|
||||||
|
_aligned_free(mem);
|
||||||
|
#else
|
||||||
|
::free(mem);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
long get_pagesize() {
|
||||||
|
#ifdef _WIN32
|
||||||
|
SYSTEM_INFO si;
|
||||||
|
GetSystemInfo(&si);
|
||||||
|
return static_cast<long>(si.dwPageSize);
|
||||||
|
#else
|
||||||
|
return sysconf(_SC_PAGESIZE);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace openreg
|
@ -2,6 +2,12 @@
|
|||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define OPENREG_EXPORT __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define OPENREG_EXPORT __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
@ -28,19 +34,19 @@ struct orPointerAttributes {
|
|||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
orError_t orMalloc(void** devPtr, size_t size);
|
OPENREG_EXPORT orError_t orMalloc(void** devPtr, size_t size);
|
||||||
orError_t orFree(void* devPtr);
|
OPENREG_EXPORT orError_t orFree(void* devPtr);
|
||||||
orError_t orMallocHost(void** hostPtr, size_t size);
|
OPENREG_EXPORT orError_t orMallocHost(void** hostPtr, size_t size);
|
||||||
orError_t orFreeHost(void* hostPtr);
|
OPENREG_EXPORT orError_t orFreeHost(void* hostPtr);
|
||||||
orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind);
|
OPENREG_EXPORT orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind);
|
||||||
orError_t orMemoryUnprotect(void* devPtr);
|
OPENREG_EXPORT orError_t orMemoryUnprotect(void* devPtr);
|
||||||
orError_t orMemoryProtect(void* devPtr);
|
OPENREG_EXPORT orError_t orMemoryProtect(void* devPtr);
|
||||||
|
|
||||||
orError_t orGetDeviceCount(int* count);
|
OPENREG_EXPORT orError_t orGetDeviceCount(int* count);
|
||||||
orError_t orSetDevice(int device);
|
OPENREG_EXPORT orError_t orSetDevice(int device);
|
||||||
orError_t orGetDevice(int* device);
|
OPENREG_EXPORT orError_t orGetDevice(int* device);
|
||||||
|
|
||||||
orError_t orPointerGetAttributes(
|
OPENREG_EXPORT orError_t orPointerGetAttributes(
|
||||||
orPointerAttributes* attributes,
|
orPointerAttributes* attributes,
|
||||||
const void* ptr);
|
const void* ptr);
|
||||||
|
|
||||||
|
@ -1,5 +1,15 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
if sys.platform == "win32":
|
||||||
|
from ._utils import _load_dll_libraries
|
||||||
|
|
||||||
|
_load_dll_libraries()
|
||||||
|
del _load_dll_libraries
|
||||||
|
|
||||||
|
|
||||||
import torch_openreg._C # type: ignore[misc]
|
import torch_openreg._C # type: ignore[misc]
|
||||||
import torch_openreg.openreg
|
import torch_openreg.openreg
|
||||||
|
|
||||||
|
@ -0,0 +1,42 @@
|
|||||||
|
import ctypes
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def _load_dll_libraries() -> None:
|
||||||
|
openreg_dll_path = os.path.join(os.path.dirname(__file__), "lib")
|
||||||
|
|
||||||
|
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
|
||||||
|
with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
|
||||||
|
prev_error_mode = kernel32.SetErrorMode(0x0001)
|
||||||
|
|
||||||
|
kernel32.LoadLibraryW.restype = ctypes.c_void_p
|
||||||
|
if with_load_library_flags:
|
||||||
|
kernel32.LoadLibraryExW.restype = ctypes.c_void_p
|
||||||
|
|
||||||
|
os.add_dll_directory(openreg_dll_path)
|
||||||
|
|
||||||
|
dlls = glob.glob(os.path.join(openreg_dll_path, "*.dll"))
|
||||||
|
path_patched = False
|
||||||
|
for dll in dlls:
|
||||||
|
is_loaded = False
|
||||||
|
if with_load_library_flags:
|
||||||
|
res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
|
||||||
|
last_error = ctypes.get_last_error()
|
||||||
|
if res is None and last_error != 126:
|
||||||
|
err = ctypes.WinError(last_error)
|
||||||
|
err.strerror += f' Error loading "{dll}" or one of its dependencies.'
|
||||||
|
raise err
|
||||||
|
elif res is not None:
|
||||||
|
is_loaded = True
|
||||||
|
if not is_loaded:
|
||||||
|
if not path_patched:
|
||||||
|
os.environ["PATH"] = ";".join([openreg_dll_path] + [os.environ["PATH"]])
|
||||||
|
path_patched = True
|
||||||
|
res = kernel32.LoadLibraryW(dll)
|
||||||
|
if res is None:
|
||||||
|
err = ctypes.WinError(ctypes.get_last_error())
|
||||||
|
err.strerror += f' Error loading "{dll}" or one of its dependencies.'
|
||||||
|
raise err
|
||||||
|
|
||||||
|
kernel32.SetErrorMode(prev_error_mode)
|
@ -6,7 +6,19 @@ file(GLOB_RECURSE SOURCE_FILES
|
|||||||
|
|
||||||
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
|
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
|
||||||
|
|
||||||
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python torch_openreg)
|
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python_library torch_openreg)
|
||||||
|
|
||||||
|
if(WIN32)
|
||||||
|
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
||||||
|
target_link_libraries(${LIBRARY_NAME} PRIVATE ${Python3_LIBRARIES})
|
||||||
|
elseif(APPLE)
|
||||||
|
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
|
||||||
|
endif()
|
||||||
|
|
||||||
target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib)
|
target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib)
|
||||||
|
|
||||||
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
install(TARGETS ${LIBRARY_NAME}
|
||||||
|
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
)
|
||||||
|
@ -90,7 +90,7 @@ static PyMethodDef methods[] = {
|
|||||||
* Therefore, it cannot be named initModule here, otherwise initModule
|
* Therefore, it cannot be named initModule here, otherwise initModule
|
||||||
* in torch/csrc/Module.cpp will be called, resulting in failure.
|
* in torch/csrc/Module.cpp will be called, resulting in failure.
|
||||||
*/
|
*/
|
||||||
extern "C" PyObject* initOpenRegModule(void) {
|
extern "C" OPENREG_EXPORT PyObject* initOpenRegModule(void) {
|
||||||
static struct PyModuleDef openreg_C_module = {
|
static struct PyModuleDef openreg_C_module = {
|
||||||
PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods};
|
PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods};
|
||||||
PyObject* mod = PyModule_Create(&openreg_C_module);
|
PyObject* mod = PyModule_Create(&openreg_C_module);
|
||||||
|
@ -1,13 +1,18 @@
|
|||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
|
||||||
extern PyObject* initOpenRegModule(void);
|
#ifdef _WIN32
|
||||||
|
#define OPENREG_EXPORT __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define OPENREG_EXPORT __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
extern OPENREG_EXPORT PyObject* initOpenRegModule(void);
|
||||||
|
|
||||||
#ifndef _WIN32
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C"
|
extern "C"
|
||||||
#endif
|
#endif
|
||||||
__attribute__((visibility("default"))) PyObject* PyInit__C(void);
|
|
||||||
#endif
|
OPENREG_EXPORT PyObject* PyInit__C(void);
|
||||||
|
|
||||||
PyMODINIT_FUNC PyInit__C(void)
|
PyMODINIT_FUNC PyInit__C(void)
|
||||||
{
|
{
|
||||||
|
@ -28,7 +28,6 @@ from torch.multiprocessing import current_process, get_context
|
|||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
get_report_path,
|
get_report_path,
|
||||||
IS_CI,
|
IS_CI,
|
||||||
IS_LINUX,
|
|
||||||
IS_MACOS,
|
IS_MACOS,
|
||||||
retry_shell,
|
retry_shell,
|
||||||
set_cwd,
|
set_cwd,
|
||||||
@ -909,10 +908,6 @@ def _test_autoload(test_directory, options, enable=True):
|
|||||||
|
|
||||||
|
|
||||||
def run_test_with_openreg(test_module, test_directory, options):
|
def run_test_with_openreg(test_module, test_directory, options):
|
||||||
# TODO(FFFrog): Will remove this later when windows/macos are supported.
|
|
||||||
if not IS_LINUX:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
openreg_dir = os.path.join(
|
openreg_dir = os.path.join(
|
||||||
test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg"
|
test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg"
|
||||||
)
|
)
|
||||||
|
@ -16,7 +16,9 @@ import torch
|
|||||||
from torch.serialization import safe_globals
|
from torch.serialization import safe_globals
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
run_tests,
|
run_tests,
|
||||||
|
skipIfMPS,
|
||||||
skipIfTorchDynamo,
|
skipIfTorchDynamo,
|
||||||
|
skipIfWindows,
|
||||||
skipIfXpu,
|
skipIfXpu,
|
||||||
TemporaryFileName,
|
TemporaryFileName,
|
||||||
TestCase,
|
TestCase,
|
||||||
@ -284,6 +286,8 @@ class TestOpenReg(TestCase):
|
|||||||
self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc]
|
self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc]
|
||||||
|
|
||||||
# Autograd
|
# Autograd
|
||||||
|
@skipIfMPS
|
||||||
|
@skipIfWindows()
|
||||||
def test_autograd_init(self):
|
def test_autograd_init(self):
|
||||||
# Make sure autograd is initialized
|
# Make sure autograd is initialized
|
||||||
torch.ones(2, requires_grad=True, device="openreg").sum().backward()
|
torch.ones(2, requires_grad=True, device="openreg").sum().backward()
|
||||||
|
Reference in New Issue
Block a user