#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/extension.h>

#include <ATen/EmptyTensor.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/ops/view.h>

#include <unordered_map>

// Counters observed from Python (via custom_add_called / custom_storageImpl_called)
// to verify that our custom kernels actually ran.
static uint64_t add_counter = 0;
static uint64_t last_saved_value = 0;

static uint64_t storageImpl_counter = 0;
static uint64_t last_storageImpl_saved_value = 0;
// A dummy StorageImpl for our custom device that secretly uses the CPU.
c10::intrusive_ptr<c10::StorageImpl> make_custom_storage_impl(
    c10::StorageImpl::use_byte_size_t,
    c10::SymInt size_bytes,
    c10::DataPtr data_ptr,
    c10::Allocator* allocator,
    bool resizable) {
  c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl;
  if (data_ptr == nullptr) {
    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
        c10::StorageImpl::use_byte_size_t(), size_bytes, allocator, resizable);
  } else {
    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
        c10::StorageImpl::use_byte_size_t(), size_bytes, std::move(data_ptr), allocator, resizable);
  }
  storageImpl_counter += 1;
  return custom_storage_impl;
}

// Register our dummy StorageImpl creation method for the PrivateUse1 device type.
void custom_storage_registry() {
  c10::SetStorageImplCreate(c10::DeviceType::PrivateUse1, &make_custom_storage_impl);
}

// Report whether make_custom_storage_impl has run since the last check.
bool custom_storageImpl_called() {
  if (storageImpl_counter > last_storageImpl_saved_value) {
    last_storageImpl_saved_value = storageImpl_counter;
    return true;
  }
  return false;
}

// basic dummy add function
at::Tensor custom_add_Tensor(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {
  add_counter += 1;
  // Since this custom device is just for testing, not bothering to implement kernels.
  return at::empty(self.sizes(), self.options());
}

at::Tensor custom__copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) {
  return dst.copy_(self, false);
}

// Some set operations for the basic use case
at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
  int64_t new_size = static_cast<int64_t>(src.nbytes() / result.dtype().itemsize());
  c10::IntArrayRef stride = {};
  result.unsafeGetTensorImpl()->set_storage_offset(0);
  at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
  at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
                               new_size, stride_opt,
                               /*resize_storage=*/!result.is_meta());
  return result;
}

// This macro does the heavy lifting.
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
// Later in this file, we map a custom device to the PrivateUse1 device type,
// which allows user code that puts a tensor on your custom_device to eventually get plumbed
// into the kernels registered here.
//
// This macro registers your kernels to the PyTorch Dispatcher.
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("add.Tensor", &custom_add_Tensor);
  m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
  m.impl("set_.source_Storage", &custom_set_source_Storage);
}
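
// Illustrative sketch (not part of the original test; the helper name is
// hypothetical): with the registrations above in place, an ordinary at::add call
// on PrivateUse1 tensors is routed by the dispatcher to custom_add_Tensor, which
// is what custom_add_called() observes.
at::Tensor example_dispatch_to_custom_add() {
  // Create CPU tensors and move them to the custom device, the same way
  // fallback_with_undefined_tensor does further down.
  at::Tensor a = at::empty({2, 2}).to(at::DeviceType::PrivateUse1);
  at::Tensor b = at::empty({2, 2}).to(at::DeviceType::PrivateUse1);
  // Dispatches on the PrivateUse1 key and lands in custom_add_Tensor.
  return at::add(a, b);
}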

// Boxed fallback: any op routed here is redirected to the corresponding CPU kernel.
void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  at::native::cpu_fallback(op, stack);
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("_fused_adamw_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
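
// Illustrative sketch (not part of the original test; the helper name is
// hypothetical): triu_indices has no PrivateUse1 kernel in this file, so a call on
// the custom device is boxed and redirected to the CPU implementation by the
// custom_cpu_fallback registered above.
at::Tensor example_triu_indices_fallback() {
  return at::triu_indices(
      /*row=*/3, /*col=*/3, /*offset=*/0,
      at::TensorOptions().device(at::DeviceType::PrivateUse1).dtype(at::kLong));
}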

// This basic implementation doesn't bother dealing with different device indices
// (e.g. custom_device:0 vs. custom_device:1).
// We could do that by letting the user pass in a device index in our exposed device function.
// Note that if you do that, you'll also need to register a device guard to core.
// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`.
c10::Device get_custom_device() {
  return c10::Device(c10::DeviceType::PrivateUse1, 0);
}
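
// Sketch only (kept as a comment so this file's behavior is unchanged): if the
// extension did handle multiple device indices, a device guard could be registered
// roughly like this, e.g. with the no-op guard shipped in
// c10/core/impl/DeviceGuardImplInterface.h:
//
//   C10_REGISTER_GUARD_IMPL(
//       PrivateUse1,
//       c10::impl::NoOpDeviceGuardImpl<c10::DeviceType::PrivateUse1>);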

bool custom_add_called() {
  bool called = false;
  if (add_counter > last_saved_value) {
    called = true;
    last_saved_value = add_counter;
  }
  return called;
}

// Call _fused_adamw_ (registered above to the boxed CPU fallback) with an
// undefined tensor in one of its TensorList arguments.
void fallback_with_undefined_tensor() {
  at::Tensor first = at::empty({2, 3}).to(at::DeviceType::PrivateUse1);
  at::Tensor second = at::Tensor();
  at::Tensor step = at::empty({}).fill_(2).to(at::DeviceType::PrivateUse1);
  at::Tensor grad_scale = at::empty({}).fill_(0.00001).to(at::DeviceType::PrivateUse1);
  at::Tensor found_inf = at::empty({}).fill_(1).to(at::DeviceType::PrivateUse1);
  at::TensorList tensors = {first, first};
  at::TensorList undefined_tensors = {first, second};
  at::TensorList steps = {step, step};
  return at::_fused_adamw_(tensors, tensors, tensors, tensors, undefined_tensors,
                           steps, 0.001, 0.9, 0.999, 1e-2, 1e-8, false, false,
                           grad_scale, found_inf);
}

// Here, we're exposing a custom device object that corresponds to our custom backend.
// We do this using pybind: exposing an "extension_name.custom_device()" function in Python
// that's implemented in C++.
// The implementation in this file maps directly to the `PrivateUse1` device type.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("custom_device", &get_custom_device, "get custom device object");
  m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
  m.def("custom_storage_registry", &custom_storage_registry, "set custom StorageImpl create method");
  m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom StorageImpl create method was called");
  m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");
}
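
// For illustration only: on the Python side this module would typically be used as
// below. The module name comes from TORCH_EXTENSION_NAME at build time;
// "pytorch_custom_device" is an assumption, not the actual name used by the test.
//
//   import torch
//   import pytorch_custom_device as ext
//
//   dev = ext.custom_device()      # a PrivateUse1 torch.device
//   ext.custom_storage_registry()  # install make_custom_storage_impl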