[1/N] OpenReg: Replace open_registration_extension.cpp with openreg (#141815)

As described in the OpenReg [next-steps](https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/README.md#next-steps) section, we replace the current `open_registration_extension.cpp` test in PyTorch CI with openreg.

The current `open_registration_extension.cpp` contains two parts:
1. Implementations supporting the `PrivateUse1` backend.
2. Helper functions used by the UTs in `test_cpp_extensions_open_device_registration.py` and `test_transformers.py`.

For the first part, we replace it with openreg. For the second part, we will migrate the helper functions into the UT files step by step.
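Concretely, the affected tests stop hand-building the old extension and instead just import the openreg package. A minimal sketch of that flow, assuming `pytorch_openreg` is installed from `test/cpp_extensions/open_registration_extension` and that the import alone is enough to register the backend (as the test changes below suggest):

```python
# Minimal sketch, assuming the pytorch_openreg package is installed.
# Importing it is expected to register the "openreg" PrivateUse1 backend,
# so tests can create tensors on it without loading a C++ extension by hand.
import pytorch_openreg  # noqa: F401

import torch

x = torch.empty(4, 4, device="openreg")  # allocated through the openreg backend
y = x.cpu()                              # copy back to CPU for comparison
print(x.device, y.device)                # expected: openreg:0 cpu
```

In CI, the new `run_test_with_openreg` handler in `run_test.py` (see below) installs this package and extends `PYTHONPATH` before running `test_cpp_extensions_open_device_registration` and `test_transformers`.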

@albanD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141815
Approved by: https://github.com/albanD
Author: Zhenbin Lin
Date: 2025-01-14 15:59:00 +00:00
Committed by: PyTorch MergeBot
Parent: 347a74b8f5
Commit: cbb1ed2966

10 changed files with 290 additions and 501 deletions

.gitignore

@ -63,6 +63,7 @@ dropout_model.pt
test/generated_type_hints_smoketest.py
test/htmlcov
test/cpp_extensions/install/
test/cpp_extensions/open_registration_extension/install
third_party/build/
tools/coverage_plugins_package/pip-wheel-metadata/
tools/shared/_utils_internal.py


@ -33,15 +33,6 @@ static uint64_t last_abs_saved_value = 0;
static uint64_t storageImpl_counter = 0;
static uint64_t last_storageImpl_saved_value = 0;
// register guard
namespace at {
namespace detail {
C10_REGISTER_GUARD_IMPL(
PrivateUse1,
c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
}} // namespace at::detail
namespace {
@ -249,147 +240,8 @@ at::Tensor custom_add_Tensor(const at::Tensor& self, const at::Tensor& other, co
return at::empty(self.sizes(), self.options());
}
// basic abs function
at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) {
return at::native::abs_out(self, out);
}
// A dummy allocator for our custom device, that secretly uses the CPU
struct DummyCustomAllocator final : at::Allocator {
DummyCustomAllocator() = default;
at::DataPtr allocate(size_t nbytes) override {
void* data = c10::alloc_cpu(nbytes);
return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, custom_device_index)};
}
static void ReportAndDelete(void* ptr) {
if (!ptr) {
return;
}
c10::free_cpu(ptr);
}
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
default_copy_data(dest, src, count);
}
};
// Register our dummy allocator
static DummyCustomAllocator global_custom_alloc;
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc);
// basic dummy empty function, so we can directly construct tensors on the custom device
// This dummy test device will just use the CPU allocator, and ignores pinned memory.
at::Tensor custom_empty_memory_format(at::IntArrayRef size,
std::optional<at::ScalarType> dtype,
std::optional<at::Layout> layout,
std::optional<at::Device> device,
std::optional<bool> pin_memory,
std::optional<at::MemoryFormat> memory_format) {
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
return at::detail::empty_generic(size,
&global_custom_alloc,
private_use_ks,
c10::dtype_or_default(dtype),
memory_format);
}
at::Tensor custom_empty_symint(c10::IntArrayRef size,
std::optional<at::ScalarType> dtype,
std::optional<at::Layout> layout,
std::optional<at::Device> device,
std::optional<bool> pin_memory,
std::optional<at::MemoryFormat> memory_format) {
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
return at::detail::empty_generic(size,
&global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
}
at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {
// Not bothering to implement.
return self;
}
// Unsafe using dummy device data_ptr to creat a cpu tensor, and shared data_ptr.
at::Tensor unsafe_create_cpu_tensor_from_dummy_tensor(const at::Tensor& src) {
TORCH_CHECK(src.device().type() == c10::DeviceType::PrivateUse1,
"Only support dummy device.");
const auto& sizes_ = src.sizes();
const auto& strides_ = src.strides();
auto storage_offset_ = src.storage_offset();
at::detail::check_size_nonnegative(sizes_);
size_t size_bytes = at::detail::computeStorageNbytes(sizes_, strides_,
src.element_size(),
storage_offset_);
at::DataPtr data_ptr =
c10::InefficientStdFunctionContext::makeDataPtr(src.storage().mutable_data_ptr().get(),
[](void*){}, at::kCPU);
c10::Storage storage{c10::Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr),
/*allocator=*/&global_custom_alloc, /*resizeable=*/false};
constexpr c10::DispatchKeySet cpu_ks(c10::DispatchKey::CPU);
at::Tensor tensor = at::detail::make_tensor<c10::TensorImpl>(
std::move(storage), cpu_ks, src.dtype());
c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
tensor_impl->set_sizes_and_strides(sizes_, strides_);
tensor_impl->set_storage_offset(storage_offset_);
return tensor;
}
// basic dummy copy_() function, so we can copy from the custom device to/from CPU
at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
TORCH_CHECK(
self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1,
"Dummy test only allows copy from cpu -> dummy device.");
TORCH_CHECK(
dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1,
"Dummy test only allows copy from cpu -> dummy device.");
// Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
TORCH_CHECK(self.sizes() == dst.sizes());
TORCH_CHECK(self.scalar_type() == dst.scalar_type());
if (self.is_contiguous() && dst.is_contiguous()) {
std::memcpy(dst.storage().data_ptr().get(),
self.storage().data_ptr().get(),
self.storage().nbytes());
} else {
// Using cpu tensor to accomplishment stride copy.
auto convert_to_cpu_tensor = [](const at::Tensor& src) -> at::Tensor {
if (src.device().type() == c10::DeviceType::PrivateUse1) {
return unsafe_create_cpu_tensor_from_dummy_tensor(src);
} else {
return src;
}
};
at::Tensor cpu_self = convert_to_cpu_tensor(self);
at::Tensor cpu_dst = convert_to_cpu_tensor(dst);
cpu_dst.copy_(cpu_self);
}
return dst;
}
at::Tensor custom__copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) {
return custom__copy_from(self, dst, false);
}
at::Tensor custom_empty_strided(c10::IntArrayRef size,
c10::IntArrayRef stride,
std::optional<at::ScalarType> dtype_opt,
std::optional<at::Layout> layout_opt,
std::optional<at::Device> device_opt,
std::optional<bool> pin_memory_opt) {
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
auto dtype = c10::dtype_or_default(dtype_opt);
return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype);
return dst.copy_(self, false);
}
// Some set operations for the basic use case
@ -404,43 +256,6 @@ at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
return result;
}
// Some set operations for the basic use case
at::Tensor& custom_set_source_Storage_storage_offset(at::Tensor& result,
c10::Storage storage,
int64_t storage_offset,
c10::IntArrayRef size,
c10::IntArrayRef stride) {
result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
size, stride_opt,
/*resize_storage=*/!result.is_meta());
return result;
}
const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size,
std::optional<at::MemoryFormat> optional_memory_format) {
at::TensorImpl* tensor_impl = self.unsafeGetTensorImpl();
tensor_impl->set_sizes_contiguous(size);
const auto itemsize = tensor_impl->dtype().itemsize();
const auto offset = tensor_impl->storage_offset();
const auto storage_size = at::detail::computeStorageNbytesContiguous(size, itemsize, offset);
// Dummy device is using cpu allocator, so here just call cpu
// function maybe_resize_storage_cpu in aten/src/ATen/native/Resize.h
// to get a sufficient memory space.
at::native::maybe_resize_storage_cpu(tensor_impl, storage_size);
if (optional_memory_format.has_value()) {
auto memory_format =
optional_memory_format.value();
TORCH_CHECK(
memory_format != at::MemoryFormat::Preserve,
"Unsupported memory format",
memory_format);
tensor_impl->empty_tensor_restride(memory_format);
}
return self;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable(
const at::Tensor & query,
@ -504,17 +319,9 @@ custom_scaled_dot_product_fused_attention_overrideable_backward(
// This macro registers your kernels to the PyTorch Dispatcher.
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("abs.out", &custom_abs_out);
m.impl("add.Tensor", &custom_add_Tensor);
m.impl("empty.memory_format", &custom_empty_symint);
m.impl("fill_.Scalar", &custom_fill__scalar);
m.impl("_copy_from", &custom__copy_from);
m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
m.impl("empty_strided", &custom_empty_strided);
m.impl("set_.source_Storage", &custom_set_source_Storage);
m.impl("set_.source_Storage_storage_offset",&custom_set_source_Storage_storage_offset);
m.impl("resize_", &custom_resize_);
m.impl("as_strided", at::native::as_strided_tensorimpl);
m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
@ -526,10 +333,8 @@ void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack
}
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("_fused_adamw_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
@ -555,73 +360,6 @@ void set_custom_device_index(c10::DeviceIndex device_index) {
custom_device_index = device_index;
}
// a global flag used for dummy pin_memory of custom device
bool custom_pinned_flag = false;
class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
public:
// Constructors
PrivateGeneratorImpl(c10::DeviceIndex device_index) {
device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
}
~PrivateGeneratorImpl() override = default;
};
struct FooHooksArgs : public at::PrivateUse1HooksArgs {};
struct FooHooksInterface : public at::PrivateUse1HooksInterface {
FooHooksInterface(FooHooksArgs) {}
~FooHooksInterface() override = default;
const at::Generator& getDefaultGenerator(
c10::DeviceIndex device_index) const override {
static auto device_gen = at::make_generator<PrivateGeneratorImpl>(device_index);
return device_gen;
}
at::Generator getNewGenerator(c10::DeviceIndex device_index) const {
return at::make_generator<PrivateGeneratorImpl>(device_index);
}
// this is a simple implementation, custom_pinned_flag will be set as true
// once tensor.pin_memory() is called. And then tensor.is_pinned()
// always return true no matter what tensor it's called on.
bool isPinnedPtr(const void* data) const override {
return custom_pinned_flag;
}
c10::Allocator* getPinnedMemoryAllocator() const override {
custom_pinned_flag = true;
return c10::GetCPUAllocator();
}
};
TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs);
C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs)
// Using Create function to get PrivateUse1HooksInterface point from PrivateUse1HooksRegistry class.
C10_REGISTER_TYPED_CLASS(PrivateUse1HooksRegistry, "FooHooks", FooHooksInterface)
static at::PrivateUse1HooksInterface* privateuse1_hooks_local = nullptr;
static at::PrivateUse1HooksInterface* get_private_hooks() {
static c10::once_flag once;
c10::call_once(once, [] {
privateuse1_hooks_local = PrivateUse1HooksRegistry()->Create("FooHooks", {}).release();
if (!privateuse1_hooks_local) {
privateuse1_hooks_local = new FooHooksInterface(FooHooksArgs{});
}
});
return privateuse1_hooks_local;
}
void register_hook() {
at::RegisterPrivateUse1HooksInterface(get_private_hooks());
}
bool is_register_hook() {
return privateuse1_hooks_local != nullptr;
}
const at::Generator& default_generator(c10::DeviceIndex device_index) {
return at::globalContext().defaultGenerator(at::Device(c10::DeviceType::PrivateUse1, device_index));;
}
@ -683,8 +421,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("custom_set_backend_meta", &custom_set_backend_meta, "a fake set tensor BackendMeta function");
m.def("check_backend_meta", &check_backend_meta, "check if BackendMeta serialization correctly");
m.def("custom_serialization_registry", &custom_serialization_registry, "register custom serialization function");
m.def("register_hook", &register_hook, "register_hook for privateuse1");
m.def("is_register_hook", &is_register_hook, "is_register_hook for privateuse1");
m.def("default_generator", &default_generator, "default_generator for privateuse1");
m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");


@ -1,5 +1,6 @@
import ctypes
import logging
import threading
import time
import torch
@ -25,9 +26,9 @@ class Allocator:
self.allocated = {}
def malloc(self, size):
new_data = torch.empty(size, dtype=torch.uint8)
ptr = new_data.data_ptr()
self.allocated[ptr] = new_data
mem = ctypes.create_string_buffer(size)
ptr = ctypes.addressof(mem)
self.allocated[ptr] = (size, mem)
return ptr
def free(self, ptr):
@ -41,23 +42,28 @@ class Allocator:
return ptr in self.allocated
def tensor_from_meta(self, meta):
def create_tensor_from_data_ptr(ptr, size):
storage = torch._C._construct_storage_from_data_pointer(
ptr, torch.device("cpu"), size
)
return torch.Tensor(storage)
found_base = None
# Usual case, we're receiving a known Tensor
found_base = self.allocated.get(meta.data_ptr, None)
if meta.data_ptr in self.allocated:
found_base = create_tensor_from_data_ptr(
meta.data_ptr, self.allocated[meta.data_ptr][0]
)
# Might be a rewrap of another storage at a different offset
# Slow path to try and find the corresponding storage
if found_base is None:
for tag, t in self.allocated.items():
for tag, (size, _) in self.allocated.items():
# t is always a 1D uint8 storage!
if meta.data_ptr > tag and meta.data_ptr < tag + t.nelement():
if meta.data_ptr > tag and meta.data_ptr < tag + size:
# Blame @ngimel for this
slice_size = t.nelement() - (meta.data_ptr - tag)
found_base = torch.tensor((), dtype=torch.uint8).set_(
t.untyped_storage()[meta.data_ptr - tag :],
size=(slice_size,),
stride=(1,),
storage_offset=0,
)
slice_size = size - (meta.data_ptr - tag)
found_base = create_tensor_from_data_ptr(meta.data_ptr, slice_size)
# Might be an empty tensor
if found_base is None and meta.nelem_in_bytes == 0:
@ -91,6 +97,7 @@ class Driver:
super().__init__()
self.num_devices = num_devices
self.is_initialized = False
self.rlock = threading.RLock()
def _lazy_init(self):
if self.is_initialized:
@ -121,19 +128,20 @@ class Driver:
self.is_initialized = True
def exec(self, cmd, *args):
self._lazy_init()
log.info("Main process launched: %s(*%s)", cmd, safe_str(args))
with self.rlock:
self._lazy_init()
log.info("Main process launched: %s(*%s)", cmd, safe_str(args))
if cmd in Driver.registry:
res = Driver.registry[cmd](self, *args)
else:
res = self.run_on_executor(self.curr_device_idx, cmd, *args)
if cmd in Driver.registry:
res = Driver.registry[cmd](self, *args)
else:
res = self.run_on_executor(self.curr_device_idx, cmd, *args)
log.info("Main process result for %s received: %s", cmd, safe_str(res))
if res == "ERROR":
raise RuntimeError(f"Error in daemon while executing {cmd}, see logs")
else:
return res
log.info("Main process result for %s received: %s", cmd, safe_str(res))
if res == "ERROR":
raise RuntimeError(f"Error in daemon while executing {cmd}, see logs")
else:
return res
def run_on_executor(self, device_idx, cmd, *args):
req_queue, ans_queue, _ = self.devices[device_idx]


@ -3,7 +3,14 @@
// Make this a proper CPython module
static struct PyModuleDef openreg_C_module = {
PyModuleDef_HEAD_INIT,
.m_name = "pytorch_openreg._C",
"pytorch_openreg._C",
nullptr,
-1,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr
};
PyMODINIT_FUNC PyInit__C(void) {


@ -35,11 +35,25 @@ struct OpenRegAllocator final : at::Allocator {
if (!ptr || !Py_IsInitialized()) {
return;
}
py::gil_scoped_acquire acquire;
PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
// Always stash, this will be a no-op if there is no error
PyErr_Fetch(&type, &value, &traceback);
TORCH_CHECK(
get_method("free")(reinterpret_cast<openreg_ptr_t>(ptr)).cast<bool>(),
"Failed to free memory pointer at ", ptr
);
"Failed to free memory pointer at ",
ptr);
// If that user code raised an error, just print it without raising it
if (PyErr_Occurred()) {
PyErr_Print();
}
// Restore the original error
PyErr_Restore(type, value, traceback);
}
at::DeleterFnPtr raw_deleter() const override {


@ -1,5 +1,7 @@
import distutils.command.clean
import os
import shutil
import sys
from pathlib import Path
from setuptools import find_packages, setup
@ -32,6 +34,15 @@ class clean(distutils.command.clean.clean):
if __name__ == "__main__":
if sys.platform == "win32":
vc_version = os.getenv("VCToolsVersion", "")
if vc_version.startswith("14.16."):
CXX_FLAGS = ["/sdl"]
else:
CXX_FLAGS = ["/sdl", "/permissive-"]
else:
CXX_FLAGS = {"cxx": ["-g", "-Wall", "-Werror"]}
sources = list(CSRS_DIR.glob("*.cpp"))
# Note that we always compile with debug info
@ -40,7 +51,7 @@ if __name__ == "__main__":
name="pytorch_openreg._C",
sources=sorted(str(s) for s in sources),
include_dirs=[CSRS_DIR],
extra_compile_args={"cxx": ["-g", "-Wall", "-Werror"]},
extra_compile_args=CXX_FLAGS,
)
]


@ -1,6 +1,7 @@
#!/usr/bin/env python3
import argparse
import contextlib
import copy
import glob
import json
@ -905,6 +906,41 @@ def run_test(
return ret_code
def install_cpp_extensions(cpp_extensions_test_dir, env=os.environ):
# Wipe the build folder, if it exists already
cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build")
if os.path.exists(cpp_extensions_test_build_dir):
shutil.rmtree(cpp_extensions_test_build_dir)
# Build the test cpp extensions modules
cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=env)
if return_code != 0:
return None, return_code
install_directory = ""
# install directory is the one that is named site-packages
for root, directories, _ in os.walk(
os.path.join(cpp_extensions_test_dir, "install")
):
for directory in directories:
if "-packages" in directory:
install_directory = os.path.join(root, directory)
assert install_directory, "install_directory must not be empty"
return install_directory, 0
@contextlib.contextmanager
def extend_python_path(install_directory):
python_path = os.environ.get("PYTHONPATH", "")
try:
os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
yield
finally:
os.environ["PYTHONPATH"] = python_path
def try_set_cpp_stack_traces(env, command, set=True):
# Print full c++ stack traces during retries
env = env or {}
@ -1051,8 +1087,6 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
if return_code != 0:
return return_code
# "install" the test modules and run tests
python_path = os.environ.get("PYTHONPATH", "")
from shutil import copyfile
os.environ["USE_NINJA"] = shell_env["USE_NINJA"]
@ -1071,10 +1105,9 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
install_directory = os.path.join(root, directory)
assert install_directory, "install_directory must not be empty"
os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
return run_test(ShardedTest(test_module, 1, 1), test_directory, options)
with extend_python_path(install_directory):
return run_test(ShardedTest(test_module, 1, 1), test_directory, options)
finally:
os.environ["PYTHONPATH"] = python_path
if os.path.exists(test_directory + "/" + test_module + ".py"):
os.remove(test_directory + "/" + test_module + ".py")
os.environ.pop("USE_NINJA")
@ -1097,42 +1130,33 @@ def test_autoload_disable(test_module, test_directory, options):
def _test_autoload(test_directory, options, enable=True):
# Wipe the build folder, if it exists already
cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions")
cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build")
if os.path.exists(cpp_extensions_test_build_dir):
shutil.rmtree(cpp_extensions_test_build_dir)
# Build the test cpp extensions modules
cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=os.environ)
install_directory, return_code = install_cpp_extensions(cpp_extensions_test_dir)
if return_code != 0:
return return_code
# "install" the test modules and run tests
python_path = os.environ.get("PYTHONPATH", "")
try:
cpp_extensions = os.path.join(test_directory, "cpp_extensions")
install_directory = ""
# install directory is the one that is named site-packages
for root, directories, _ in os.walk(os.path.join(cpp_extensions, "install")):
for directory in directories:
if "-packages" in directory:
install_directory = os.path.join(root, directory)
assert install_directory, "install_directory must not be empty"
os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = str(int(enable))
cmd = [sys.executable, "test_autoload.py"]
return_code = shell(cmd, cwd=test_directory, env=os.environ)
return return_code
with extend_python_path(install_directory):
cmd = [sys.executable, "test_autoload.py"]
return_code = shell(cmd, cwd=test_directory, env=os.environ)
return return_code
finally:
os.environ["PYTHONPATH"] = python_path
os.environ.pop("TORCH_DEVICE_BACKEND_AUTOLOAD")
def run_test_with_openreg(test_module, test_directory, options):
openreg_dir = os.path.join(
test_directory, "cpp_extensions", "open_registration_extension"
)
install_dir, return_code = install_cpp_extensions(openreg_dir)
if return_code != 0:
return return_code
with extend_python_path(install_dir):
return run_test(test_module, test_directory, options)
def test_distributed(test_module, test_directory, options):
# MPI tests are broken with Python-3.9
mpi_available = subprocess.call(
@ -1456,6 +1480,8 @@ CUSTOM_HANDLERS = {
"test_ci_sanity_check_fail": run_ci_sanity_check,
"test_autoload_enable": test_autoload_enable,
"test_autoload_disable": test_autoload_disable,
"test_cpp_extensions_open_device_registration": run_test_with_openreg,
"test_transformers": run_test_with_openreg,
}


@ -5,12 +5,12 @@ import io
import os
import sys
import tempfile
import types
import unittest
from typing import Union
from unittest.mock import patch
import numpy as np
import pytorch_openreg # noqa: F401
import torch
import torch.testing._internal.common_utils as common
@ -31,15 +31,24 @@ TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not Non
def generate_faked_module():
class _OpenRegMod:
pass
return _OpenRegMod()
def generate_faked_module_methods():
def device_count() -> int:
return 1
def get_rng_state(device: Union[int, str, torch.device] = "foo") -> torch.Tensor:
def get_rng_state(
device: Union[int, str, torch.device] = "openreg",
) -> torch.Tensor:
# create a tensor using our custom device object.
return torch.empty(4, 4, device="foo")
return torch.empty(4, 4, device="openreg")
def set_rng_state(
new_state: torch.Tensor, device: Union[int, str, torch.device] = "foo"
new_state: torch.Tensor, device: Union[int, str, torch.device] = "openreg"
) -> None:
pass
@ -49,18 +58,13 @@ def generate_faked_module():
def current_device():
return 0
# create a new module to fake torch.foo dynamicaly
foo = types.ModuleType("foo")
foo.device_count = device_count
foo.get_rng_state = get_rng_state
foo.set_rng_state = set_rng_state
foo.is_available = is_available
foo.current_device = current_device
foo._lazy_init = lambda: None
foo.is_initialized = lambda: True
return foo
torch.openreg.device_count = device_count
torch.openreg.get_rng_state = get_rng_state
torch.openreg.set_rng_state = set_rng_state
torch.openreg.is_available = is_available
torch.openreg.current_device = current_device
torch.openreg._lazy_init = lambda: None
torch.openreg.is_initialized = lambda: True
@unittest.skipIf(IS_ARM64, "Does not work on arm")
@ -101,10 +105,8 @@ class TestCppExtensionOpenRgistration(common.TestCase):
verbose=True,
)
# register torch.foo module and foo device to torch
torch.utils.rename_privateuse1_backend("foo")
torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
torch._register_device_module("foo", generate_faked_module())
generate_faked_module_methods()
def test_base_device_registration(self):
self.assertFalse(self.module.custom_add_called())
@ -132,10 +134,10 @@ class TestCppExtensionOpenRgistration(common.TestCase):
with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
torch._register_device_module("dev", generate_faked_module())
with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
torch._register_device_module("foo", generate_faked_module())
torch._register_device_module("openreg", generate_faked_module())
# backend name can be renamed to the same name multiple times
torch.utils.rename_privateuse1_backend("foo")
torch.utils.rename_privateuse1_backend("openreg")
# backend name can't be renamed multiple times to different names.
with self.assertRaisesRegex(
@ -147,39 +149,29 @@ class TestCppExtensionOpenRgistration(common.TestCase):
with self.assertRaisesRegex(RuntimeError, "The custom device module of"):
torch.utils.generate_methods_for_privateuse1_backend()
# check whether torch.foo have been registered correctly
# check whether torch.openreg have been registered correctly
self.assertTrue(
torch.utils.backend_registration._get_custom_mod_func("device_count")() == 1
)
with self.assertRaisesRegex(RuntimeError, "Try to call torch.foo"):
with self.assertRaisesRegex(RuntimeError, "Try to call torch.openreg"):
torch.utils.backend_registration._get_custom_mod_func("func_name_")
# check attributes after registered
self.assertTrue(hasattr(torch.Tensor, "is_foo"))
self.assertTrue(hasattr(torch.Tensor, "foo"))
self.assertTrue(hasattr(torch.TypedStorage, "is_foo"))
self.assertTrue(hasattr(torch.TypedStorage, "foo"))
self.assertTrue(hasattr(torch.UntypedStorage, "is_foo"))
self.assertTrue(hasattr(torch.UntypedStorage, "foo"))
self.assertTrue(hasattr(torch.nn.Module, "foo"))
self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "is_foo"))
self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "foo"))
self.assertTrue(hasattr(torch.Tensor, "is_openreg"))
self.assertTrue(hasattr(torch.Tensor, "openreg"))
self.assertTrue(hasattr(torch.TypedStorage, "is_openreg"))
self.assertTrue(hasattr(torch.TypedStorage, "openreg"))
self.assertTrue(hasattr(torch.UntypedStorage, "is_openreg"))
self.assertTrue(hasattr(torch.UntypedStorage, "openreg"))
self.assertTrue(hasattr(torch.nn.Module, "openreg"))
self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "is_openreg"))
self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "openreg"))
def test_open_device_generator_registration_and_hooks(self):
device = self.module.custom_device()
# None of our CPU operations should call the custom add function.
self.assertFalse(self.module.custom_add_called())
# check generator registered before using
with self.assertRaisesRegex(
RuntimeError,
"Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first",
):
torch.Generator(device=device)
if self.module.is_register_hook() is False:
self.module.register_hook()
gen = torch.Generator(device=device)
self.assertTrue(gen.device == device)
@ -191,38 +183,40 @@ class TestCppExtensionOpenRgistration(common.TestCase):
def test_open_device_dispatchstub(self):
# test kernels could be reused by privateuse1 backend through dispatchstub
input_data = torch.randn(2, 2, 3, dtype=torch.float32, device="cpu")
foo_input_data = input_data.to("foo")
openreg_input_data = input_data.to("openreg")
output_data = torch.abs(input_data)
foo_output_data = torch.abs(foo_input_data)
self.assertEqual(output_data, foo_output_data.cpu())
openreg_output_data = torch.abs(openreg_input_data)
self.assertEqual(output_data, openreg_output_data.cpu())
output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
# output operand will resize flag is True in TensorIterator.
foo_input_data = input_data.to("foo")
foo_output_data = output_data.to("foo")
openreg_input_data = input_data.to("openreg")
openreg_output_data = output_data.to("openreg")
# output operand will resize flag is False in TensorIterator.
torch.abs(input_data, out=output_data[:, :, 0:6:2])
torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:2])
self.assertEqual(output_data, foo_output_data.cpu())
torch.abs(openreg_input_data, out=openreg_output_data[:, :, 0:6:2])
self.assertEqual(output_data, openreg_output_data.cpu())
# output operand will resize flag is True in TensorIterator.
# and convert output to contiguous tensor in TensorIterator.
output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
foo_input_data = input_data.to("foo")
foo_output_data = output_data.to("foo")
openreg_input_data = input_data.to("openreg")
openreg_output_data = output_data.to("openreg")
torch.abs(input_data, out=output_data[:, :, 0:6:3])
torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:3])
self.assertEqual(output_data, foo_output_data.cpu())
torch.abs(openreg_input_data, out=openreg_output_data[:, :, 0:6:3])
self.assertEqual(output_data, openreg_output_data.cpu())
def test_open_device_quantized(self):
input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to("foo")
input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to(
"openreg"
)
quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8)
self.assertEqual(quantized_tensor.device, torch.device("foo:0"))
self.assertEqual(quantized_tensor.device, torch.device("openreg:0"))
self.assertEqual(quantized_tensor.dtype, torch.qint8)
def test_open_device_random(self):
# check if torch.foo have implemented get_rng_state
with torch.random.fork_rng(device_type="foo"):
# check if torch.openreg have implemented get_rng_state
with torch.random.fork_rng(device_type="openreg"):
pass
def test_open_device_tensor(self):
@ -230,15 +224,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
# check whether print tensor.type() meets the expectation
dtypes = {
torch.bool: "torch.foo.BoolTensor",
torch.double: "torch.foo.DoubleTensor",
torch.float32: "torch.foo.FloatTensor",
torch.half: "torch.foo.HalfTensor",
torch.int32: "torch.foo.IntTensor",
torch.int64: "torch.foo.LongTensor",
torch.int8: "torch.foo.CharTensor",
torch.short: "torch.foo.ShortTensor",
torch.uint8: "torch.foo.ByteTensor",
torch.bool: "torch.openreg.BoolTensor",
torch.double: "torch.openreg.DoubleTensor",
torch.float32: "torch.openreg.FloatTensor",
torch.half: "torch.openreg.HalfTensor",
torch.int32: "torch.openreg.IntTensor",
torch.int64: "torch.openreg.LongTensor",
torch.int8: "torch.openreg.CharTensor",
torch.short: "torch.openreg.ShortTensor",
torch.uint8: "torch.openreg.ByteTensor",
}
for tt, dt in dtypes.items():
test_tensor = torch.empty(4, 4, dtype=tt, device=device)
@ -246,69 +240,69 @@ class TestCppExtensionOpenRgistration(common.TestCase):
# check whether the attributes and methods of the corresponding custom backend are generated correctly
x = torch.empty(4, 4)
self.assertFalse(x.is_foo)
self.assertFalse(x.is_openreg)
x = x.foo(torch.device("foo"))
x = x.openreg(torch.device("openreg"))
self.assertFalse(self.module.custom_add_called())
self.assertTrue(x.is_foo)
self.assertTrue(x.is_openreg)
# test different device type input
y = torch.empty(4, 4)
self.assertFalse(y.is_foo)
self.assertFalse(y.is_openreg)
y = y.foo(torch.device("foo:0"))
y = y.openreg(torch.device("openreg:0"))
self.assertFalse(self.module.custom_add_called())
self.assertTrue(y.is_foo)
self.assertTrue(y.is_openreg)
# test different device type input
z = torch.empty(4, 4)
self.assertFalse(z.is_foo)
self.assertFalse(z.is_openreg)
z = z.foo(0)
z = z.openreg(0)
self.assertFalse(self.module.custom_add_called())
self.assertTrue(z.is_foo)
self.assertTrue(z.is_openreg)
def test_open_device_packed_sequence(self):
device = self.module.custom_device() # noqa: F841
a = torch.rand(5, 3)
b = torch.tensor([1, 1, 1, 1, 1])
input = torch.nn.utils.rnn.PackedSequence(a, b)
self.assertFalse(input.is_foo)
input_foo = input.foo()
self.assertTrue(input_foo.is_foo)
self.assertFalse(input.is_openreg)
input_openreg = input.openreg()
self.assertTrue(input_openreg.is_openreg)
def test_open_device_storage(self):
# check whether the attributes and methods for storage of the corresponding custom backend are generated correctly
x = torch.empty(4, 4)
z1 = x.storage()
self.assertFalse(z1.is_foo)
self.assertFalse(z1.is_openreg)
z1 = z1.foo()
z1 = z1.openreg()
self.assertFalse(self.module.custom_add_called())
self.assertTrue(z1.is_foo)
self.assertTrue(z1.is_openreg)
with self.assertRaisesRegex(RuntimeError, "Invalid device"):
z1.foo(torch.device("cpu"))
z1.openreg(torch.device("cpu"))
z1 = z1.cpu()
self.assertFalse(self.module.custom_add_called())
self.assertFalse(z1.is_foo)
self.assertFalse(z1.is_openreg)
z1 = z1.foo(device="foo:0", non_blocking=False)
z1 = z1.openreg(device="openreg:0", non_blocking=False)
self.assertFalse(self.module.custom_add_called())
self.assertTrue(z1.is_foo)
self.assertTrue(z1.is_openreg)
with self.assertRaisesRegex(RuntimeError, "Invalid device"):
z1.foo(device="cuda:0", non_blocking=False)
z1.openreg(device="cuda:0", non_blocking=False)
# check UntypedStorage
y = torch.empty(4, 4)
z2 = y.untyped_storage()
self.assertFalse(z2.is_foo)
self.assertFalse(z2.is_openreg)
z2 = z2.foo()
z2 = z2.openreg()
self.assertFalse(self.module.custom_add_called())
self.assertTrue(z2.is_foo)
self.assertTrue(z2.is_openreg)
# check custom StorageImpl create
self.module.custom_storage_registry()
@ -316,7 +310,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
z3 = y.untyped_storage()
self.assertFalse(self.module.custom_storageImpl_called())
z3 = z3.foo()
z3 = z3.openreg()
self.assertTrue(self.module.custom_storageImpl_called())
self.assertFalse(self.module.custom_storageImpl_called())
@ -331,50 +325,48 @@ class TestCppExtensionOpenRgistration(common.TestCase):
def test_open_device_storage_pin_memory(self):
# Check if the pin_memory is functioning properly on custom device
cpu_tensor = torch.empty(3)
self.assertFalse(cpu_tensor.is_foo)
self.assertFalse(cpu_tensor.is_pinned("foo"))
self.assertFalse(cpu_tensor.is_openreg)
self.assertFalse(cpu_tensor.is_pinned("openreg"))
cpu_tensor_pin = cpu_tensor.pin_memory("foo")
self.assertTrue(cpu_tensor_pin.is_pinned("foo"))
cpu_tensor_pin = cpu_tensor.pin_memory("openreg")
self.assertTrue(cpu_tensor_pin.is_pinned("openreg"))
# Test storage pin_memory and is_pin
cpu_storage = cpu_tensor.storage()
# We implement a dummy pin_memory of no practical significance
# for custom device. Once tensor.pin_memory() has been called,
# then tensor.is_pinned() will always return true no matter
# what tensor it's called on.
self.assertTrue(cpu_storage.is_pinned("foo"))
self.assertFalse(cpu_storage.is_pinned("openreg"))
cpu_storage_pinned = cpu_storage.pin_memory("foo")
self.assertTrue(cpu_storage_pinned.is_pinned("foo"))
cpu_storage_pinned = cpu_storage.pin_memory("openreg")
self.assertTrue(cpu_storage_pinned.is_pinned("openreg"))
# Test untyped storage pin_memory and is_pin
cpu_tensor = torch.randn([3, 2, 1, 4])
cpu_untyped_storage = cpu_tensor.untyped_storage()
self.assertTrue(cpu_untyped_storage.is_pinned("foo"))
self.assertFalse(cpu_untyped_storage.is_pinned("openreg"))
cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("foo")
self.assertTrue(cpu_untyped_storage_pinned.is_pinned("foo"))
cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("openreg")
self.assertTrue(cpu_untyped_storage_pinned.is_pinned("openreg"))
@unittest.skip(
"Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function"
)
def test_open_device_serialization(self):
self.module.set_custom_device_index(-1)
storage = torch.UntypedStorage(4, device=torch.device("foo"))
self.assertEqual(torch.serialization.location_tag(storage), "foo")
storage = torch.UntypedStorage(4, device=torch.device("openreg"))
self.assertEqual(torch.serialization.location_tag(storage), "openreg")
self.module.set_custom_device_index(0)
storage = torch.UntypedStorage(4, device=torch.device("foo"))
self.assertEqual(torch.serialization.location_tag(storage), "foo:0")
storage = torch.UntypedStorage(4, device=torch.device("openreg"))
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
cpu_storage = torch.empty(4, 4).storage()
foo_storage = torch.serialization.default_restore_location(cpu_storage, "foo:0")
self.assertTrue(foo_storage.is_foo)
openreg_storage = torch.serialization.default_restore_location(
cpu_storage, "openreg:0"
)
self.assertTrue(openreg_storage.is_openreg)
# test tensor MetaData serialization
x = torch.empty(4, 4).long()
y = x.foo()
y = x.openreg()
self.assertFalse(self.module.check_backend_meta(y))
self.module.custom_set_backend_meta(y)
self.assertTrue(self.module.check_backend_meta(y))
@ -384,30 +376,30 @@ class TestCppExtensionOpenRgistration(common.TestCase):
path = os.path.join(tmpdir, "data.pt")
torch.save(y, path)
z1 = torch.load(path)
# loads correctly onto the foo backend device
self.assertTrue(z1.is_foo)
# loads correctly onto the openreg backend device
self.assertTrue(z1.is_openreg)
# loads BackendMeta data correctly
self.assertTrue(self.module.check_backend_meta(z1))
# cross-backend
z2 = torch.load(path, map_location="cpu")
# loads correctly onto the cpu backend device
self.assertFalse(z2.is_foo)
self.assertFalse(z2.is_openreg)
# loads BackendMeta data correctly
self.assertFalse(self.module.check_backend_meta(z2))
def test_open_device_storage_resize(self):
cpu_tensor = torch.randn([8])
foo_tensor = cpu_tensor.foo()
foo_storage = foo_tensor.storage()
self.assertTrue(foo_storage.size() == 8)
openreg_tensor = cpu_tensor.openreg()
openreg_storage = openreg_tensor.storage()
self.assertTrue(openreg_storage.size() == 8)
# Only register tensor resize_ function.
foo_tensor.resize_(8)
self.assertTrue(foo_storage.size() == 8)
openreg_tensor.resize_(8)
self.assertTrue(openreg_storage.size() == 8)
with self.assertRaisesRegex(TypeError, "Overflow"):
foo_tensor.resize_(8**29)
openreg_tensor.resize_(8**29)
def test_open_device_storage_type(self):
# test cpu float storage
@ -416,9 +408,9 @@ class TestCppExtensionOpenRgistration(common.TestCase):
self.assertEqual(cpu_storage.type(), "torch.FloatStorage")
# test custom float storage before defining FloatStorage
foo_tensor = cpu_tensor.foo()
foo_storage = foo_tensor.storage()
self.assertEqual(foo_storage.type(), "torch.storage.TypedStorage")
openreg_tensor = cpu_tensor.openreg()
openreg_storage = openreg_tensor.storage()
self.assertEqual(openreg_storage.type(), "torch.storage.TypedStorage")
class CustomFloatStorage:
@property
@ -431,24 +423,24 @@ class TestCppExtensionOpenRgistration(common.TestCase):
# test custom float storage after defining FloatStorage
try:
torch.foo.FloatStorage = CustomFloatStorage()
self.assertEqual(foo_storage.type(), "torch.foo.FloatStorage")
torch.openreg.FloatStorage = CustomFloatStorage()
self.assertEqual(openreg_storage.type(), "torch.openreg.FloatStorage")
# test custom int storage after defining FloatStorage
foo_tensor2 = torch.randn([8]).int().foo()
foo_storage2 = foo_tensor2.storage()
self.assertEqual(foo_storage2.type(), "torch.storage.TypedStorage")
openreg_tensor2 = torch.randn([8]).int().openreg()
openreg_storage2 = openreg_tensor2.storage()
self.assertEqual(openreg_storage2.type(), "torch.storage.TypedStorage")
finally:
torch.foo.FloatStorage = None
torch.openreg.FloatStorage = None
def test_open_device_faketensor(self):
with torch._subclasses.fake_tensor.FakeTensorMode.push():
a = torch.empty(1, device="foo")
b = torch.empty(1, device="foo:0")
a = torch.empty(1, device="openreg")
b = torch.empty(1, device="openreg:0")
result = a + b # noqa: F841
def test_open_device_named_tensor(self):
torch.empty([2, 3, 4, 5], device="foo", names=["N", "C", "H", "W"])
torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"])
# Not an open registration test - this file is just very convenient
# for testing torch.compile on custom C++ operators
@ -483,13 +475,13 @@ class TestCppExtensionOpenRgistration(common.TestCase):
def test_open_device_scalar_type_fallback(self):
z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
z = torch.triu_indices(3, 3, device="foo")
z = torch.triu_indices(3, 3, device="openreg")
self.assertEqual(z_cpu, z)
def test_open_device_tensor_type_fallback(self):
# create tensors located in custom device
x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to("foo")
y = torch.Tensor([1, 0, 2]).to("foo")
x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to("openreg")
y = torch.Tensor([1, 0, 2]).to("openreg")
# create result tensor located in cpu
z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]])
# Check that our device is correct.
@ -503,22 +495,22 @@ class TestCppExtensionOpenRgistration(common.TestCase):
# call index op, which will fallback to cpu
z_cpu = torch.Tensor([3, 1])
y = torch.Tensor([1, 0]).long().to("foo")
y = torch.Tensor([1, 0]).long().to("openreg")
z = x[y, y]
self.assertEqual(z_cpu, z)
def test_open_device_tensorlist_type_fallback(self):
# create tensors located in custom device
v_foo = torch.Tensor([1, 2, 3]).to("foo")
v_openreg = torch.Tensor([1, 2, 3]).to("openreg")
# create result tensor located in cpu
z_cpu = torch.Tensor([2, 4, 6])
# create tensorlist for foreach_add op
x = (v_foo, v_foo)
y = (v_foo, v_foo)
x = (v_openreg, v_openreg)
y = (v_openreg, v_openreg)
# Check that our device is correct.
device = self.module.custom_device()
self.assertTrue(v_foo.device == device)
self.assertFalse(v_foo.is_cpu)
self.assertTrue(v_openreg.device == device)
self.assertFalse(v_openreg.is_cpu)
# call _foreach_add op, which will fallback to cpu
z = torch._foreach_add(x, y)
@ -528,6 +520,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
# call _fused_adamw_ with undefined tensor.
self.module.fallback_with_undefined_tensor()
@skipIfTorchDynamo()
@unittest.skipIf(
np.__version__ < "1.25",
"versions < 1.25 serialize dtypes differently from how it's serialized in data_legacy_numpy",
@ -536,7 +529,6 @@ class TestCppExtensionOpenRgistration(common.TestCase):
"""
This tests the legacy _rebuild_device_tensor_from_numpy serialization path
"""
torch.utils.rename_privateuse1_backend("foo")
device = self.module.custom_device()
# Legacy data saved with _rebuild_device_tensor_from_numpy on f80ed0b8 via
@ -561,24 +553,24 @@ class TestCppExtensionOpenRgistration(common.TestCase):
b"\x03\x86q\rcnumpy\ndtype\nq\x0eX\x02\x00\x00\x00f4q\x0f\x89\x88\x87q\x10Rq\x11(K\x03X\x01"
b"\x00\x00\x00<q\x12NNNJ\xff\xff\xff\xffJ\xff\xff\xff\xffK\x00tq\x13b\x89h\x06X\x1c\x00\x00"
b"\x00\x00\x00\xc2\x80?\x00\x00\x00@\x00\x00@@\x00\x00\xc2\x80@\x00\x00\xc2\xa0@\x00\x00\xc3"
b"\x80@q\x14h\x08\x86q\x15Rq\x16tq\x17bctorch\nfloat32\nq\x18X\x05\x00\x00\x00foo:0q\x19\x89"
b"tq\x1aRq\x1bs.PK\x07\x08\xe3\xe4\x86\xecO\x01\x00\x00O\x01\x00\x00PK\x03\x04\x00\x00\x08"
b"\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11\x002\x00"
b"archive/byteorderFB.\x00ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZlittlePK\x07\x08"
b"\x80@q\x14h\x08\x86q\x15Rq\x16tq\x17bctorch\nfloat32\nq\x18X\t\x00\x00\x00openreg:0q\x19\x89"
b"tq\x1aRq\x1bs.PK\x07\x08\xdfE\xd6\xcaS\x01\x00\x00S\x01\x00\x00PK\x03\x04\x00\x00\x08"
b"\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11\x00.\x00"
b"archive/byteorderFB*\x00ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZlittlePK\x07\x08"
b"\x85=\xe3\x19\x06\x00\x00\x00\x06\x00\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0f\x00=\x00archive/versionFB9\x00"
b"ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ3\nPK\x07\x08\xd1\x9egU\x02\x00\x00"
b"\x00\x02\x00\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x1e\x002\x00archive/.data/serialization_idFB.\x00ZZZZZZZZZZZZZ"
b"ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ0636457737946401051300000027264370494161PK\x07\x08\x91\xbf"
b"\xa7\x0c(\x00\x00\x00(\x00\x00\x00PK\x01\x02\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00"
b"\xe3\xe4\x86\xecO\x01\x00\x00O\x01\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ0636457737946401051300000025273995036293PK\x07\x08\xee(\xcd"
b"\x8d(\x00\x00\x00(\x00\x00\x00PK\x01\x02\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00"
b"\xdfE\xd6\xcaS\x01\x00\x00S\x01\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00archive/data.pklPK\x01\x02\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00"
b"\x00\x00\x85=\xe3\x19\x06\x00\x00\x00\x06\x00\x00\x00\x11\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x9f\x01\x00\x00archive/byteorderPK\x01\x02\x00\x00\x00\x00\x08\x08\x00"
b"\x00\x00\x00\x00\x00\xa3\x01\x00\x00archive/byteorderPK\x01\x02\x00\x00\x00\x00\x08\x08\x00"
b"\x00\x00\x00\x00\x00\xd1\x9egU\x02\x00\x00\x00\x02\x00\x00\x00\x0f\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x16\x02\x00\x00archive/versionPK\x01\x02\x00\x00\x00\x00\x08"
b"\x08\x00\x00\x00\x00\x00\x00\x91\xbf\xa7\x0c(\x00\x00\x00(\x00\x00\x00\x1e\x00\x00\x00\x00"
b"\x08\x00\x00\x00\x00\x00\x00\xee(\xcd\x8d(\x00\x00\x00(\x00\x00\x00\x1e\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x92\x02\x00\x00archive/.data/serialization_idPK\x06"
b"\x06,\x00\x00\x00\x00\x00\x00\x00\x1e\x03-\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00"
b"\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x06\x01\x00\x00\x00\x00\x00\x008\x03\x00"
@ -612,15 +604,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
self.assertTrue(sd_loaded_cpu["x"].is_cpu)
def test_open_device_cpu_serialization(self):
torch.utils.rename_privateuse1_backend("foo")
torch.utils.rename_privateuse1_backend("openreg")
device = self.module.custom_device()
default_protocol = torch.serialization.DEFAULT_PROTOCOL
with patch.object(torch._C, "_has_storage", return_value=False):
x = torch.randn(2, 3)
x_foo = x.to(device)
sd = {"x": x_foo}
rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0]
x_openreg = x.to(device)
sd = {"x": x_openreg}
rebuild_func = x_openreg._reduce_ex_internal(default_protocol)[0]
self.assertTrue(
rebuild_func is torch._utils._rebuild_device_tensor_from_cpu_tensor
)
@ -644,7 +636,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
torch.save(sd, f)
def test_open_device_dlpack(self):
t = torch.randn(2, 3).to("foo")
t = torch.randn(2, 3).to("openreg")
capsule = torch.utils.dlpack.to_dlpack(t)
t1 = torch.from_dlpack(capsule)
self.assertTrue(t1.device == t.device)


@ -56,11 +56,6 @@ from torch.testing._internal.common_cuda import (
tf32_enabled,
)
if not IS_FBCODE:
from test_cpp_extensions_open_device_registration import (
generate_faked_module
)
if TEST_FAIRSEQ:
import fairseq.models.transformer as fairseq_transformer
@ -4016,6 +4011,8 @@ class TestAttnBias(NNTestCase):
class TestSDPAPrivateUse1Only(NNTestCase):
@classmethod
def setUpClass(cls):
import pytorch_openreg # noqa: F401
torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
cls.module = torch.utils.cpp_extension.load(
name="custom_device_extension",
@ -4026,10 +4023,6 @@ class TestSDPAPrivateUse1Only(NNTestCase):
extra_cflags=["-g"],
verbose=True,
)
# register torch.foo module and foo device to torch
torch.utils.rename_privateuse1_backend("foo")
torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
torch._register_device_module("foo", generate_faked_module())
@skipIfTorchDynamo()
def test_fused_sdp_choice_privateuseone(self):
@ -4037,9 +4030,9 @@ class TestSDPAPrivateUse1Only(NNTestCase):
make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
q_privateuse1 = q_cpu.to("foo")
k_privateuse1 = k_cpu.to("foo")
v_privateuse1 = v_cpu.to("foo")
q_privateuse1 = q_cpu.to("openreg")
k_privateuse1 = k_cpu.to("openreg")
v_privateuse1 = v_cpu.to("openreg")
assert torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1) == SDPBackend.OVERRIDEABLE.value
def test_scaled_dot_product_fused_attention_overrideable(self):
@ -4047,9 +4040,9 @@ class TestSDPAPrivateUse1Only(NNTestCase):
make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
q_privateuse1 = q_cpu.to("foo")
k_privateuse1 = k_cpu.to("foo")
v_privateuse1 = v_cpu.to("foo")
q_privateuse1 = q_cpu.to("openreg")
k_privateuse1 = k_cpu.to("openreg")
v_privateuse1 = v_cpu.to("openreg")
torch.nn.functional.scaled_dot_product_attention(
q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0)
@ -4059,16 +4052,16 @@ class TestSDPAPrivateUse1Only(NNTestCase):
shape = (batch_size, num_heads, seq_len, head_dim)
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
q_privateuse1 = q_cpu.to("foo")
k_privateuse1 = k_cpu.to("foo")
v_privateuse1 = v_cpu.to("foo")
attn_mask_privateuse1 = attn_mask.to("foo")
q_privateuse1 = q_cpu.to("openreg")
k_privateuse1 = k_cpu.to("openreg")
v_privateuse1 = v_cpu.to("openreg")
attn_mask_privateuse1 = attn_mask.to("openreg")
output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = \
torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1)
rand_upward = torch.rand(shape, device="cpu", dtype=torch.float16, requires_grad=False)
rand_upward_privateuse1 = rand_upward.to("foo")
rand_upward_privateuse1 = rand_upward.to("openreg")
grad_input_mask = [True, True, True, True]
grad_q, grad_k, grad_v, grad_attn_mask = torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
rand_upward_privateuse1, q_privateuse1, k_privateuse1, v_privateuse1, attn_mask_privateuse1,


@ -3,6 +3,7 @@
#include <c10/core/Stream.h>
#include <c10/macros/Export.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/python_headers.h>
struct THPStream {
@ -21,6 +22,6 @@ inline bool THPStream_Check(PyObject* obj) {
return THPStreamClass && PyObject_IsInstance(obj, (PyObject*)THPStreamClass);
}
PyObject* THPStream_Wrap(const c10::Stream& stream);
TORCH_PYTHON_API PyObject* THPStream_Wrap(const c10::Stream& stream);
#endif // THP_STREAM_INC