mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #148510. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148511 Approved by: https://github.com/albanD
144 lines
4.9 KiB
C++
144 lines
4.9 KiB
C++
#include <torch/extension.h>
|
|
#include <torch/library.h>
|
|
|
|
// Bring ATen names (Tensor, Device, ScalarType, IntArrayRef, ...) into scope;
// most of this file refers to them unqualified.
using namespace at;
|
|
|
|
// Marker written by the overrides below so a caller can tell which one ran:
// 0 = empty_override, 1 = add_out_override, 2 = fake_convolution,
// 3 = fake_convolution_backward. Read back through get_test_int().
static int test_int;
|
|
|
|
// Creates a MAIA tensor that carries metadata only: its storage is zero bytes
// long and wraps a null data pointer tagged with device MAIA:0, so no element
// data is allocated.
//
// dtype: element type recorded on the tensor.
// size:  sizes the tensor reports (installed via set_sizes_contiguous).
Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
  const Device maia_device(DeviceType::MAIA, 0);

  // Zero-byte storage with no data pointer, no allocator, not resizable.
  Storage empty_storage(
      Storage::use_byte_size_t(),
      0,
      at::DataPtr(nullptr, maia_device),
      nullptr,
      false);

  auto impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      std::move(empty_storage), DispatchKey::MAIA, dtype);

  // This is a hack to work around the shape checks in _convolution.
  impl->set_sizes_contiguous(size);

  return Tensor(std::move(impl));
}
|
|
|
|
// MAIA override for aten::empty.memory_format.
//
// Sets test_int to 0 and returns a data-less MAIA tensor of the requested
// size. Only `size` and `dtype` are used (a missing dtype falls back to
// dtype_or_default); layout, device, pin_memory and memory format are ignored.
Tensor empty_override(
    IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  test_int = 0;
  const auto resolved_dtype = dtype_or_default(dtype);
  return get_tensor(scalarTypeToTypeMeta(resolved_dtype), size);
}
|
|
|
|
// MAIA override for aten::add.out.
//
// Performs no arithmetic: it sets test_int to 1 and hands back `out`
// untouched. The operands `a`, `b` and the scalar `c` are ignored.
Tensor& add_out_override(
    const Tensor& a,
    const Tensor& b,
    const Scalar& c,
    Tensor& out) {
  test_int = 1;
  return out;
}
|
|
|
|
// MAIA override for aten::convolution_overrideable.
//
// Sets test_int to 2 and returns a data-less 4-D MAIA tensor with the dtype of
// `input`. No convolution is computed; bias, stride, padding, dilation,
// transposed, output_padding and groups are ignored.
Tensor fake_convolution(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups) {
  test_int = 2;

  // Only the first 2 dimensions of the output shape are correct; the last two
  // are copied from the input unchanged.
  const int64_t dim0 = input.size(0);
  const int64_t dim1 = weight.size(0);
  const int64_t dim2 = input.size(2);
  const int64_t dim3 = input.size(3);

  return get_tensor(input.dtype(), {dim0, dim1, dim2, dim3});
}
|
|
|
|
std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
|
|
const Tensor & grad_output, const Tensor & input, const Tensor & weight,
|
|
IntArrayRef stride, IntArrayRef padding,
|
|
IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
|
|
int64_t groups, std::array<bool,3> output_mask) {
|
|
test_int = 3;
|
|
return std::tuple<Tensor, Tensor, Tensor>(
|
|
get_tensor(input.dtype(), input.sizes()),
|
|
get_tensor(weight.dtype(), weight.sizes()),
|
|
get_tensor(input.dtype(), {}));
|
|
}
|
|
|
|
// MAIA override for aten::to.dtype.
//
// Returns a new data-less MAIA tensor with the sizes of `self` and the
// requested `dtype`. No values are converted; non_blocking, copy and
// memory_format are ignored. Unlike the other overrides, this one does not
// touch test_int.
at::Tensor maia_to_dtype_override(
    const at::Tensor& self,
    at::ScalarType dtype,
    bool non_blocking,
    bool copy,
    ::std::optional<at::MemoryFormat> memory_format) {
  const auto target_meta = scalarTypeToTypeMeta(dtype);
  return get_tensor(target_meta, self.sizes());
}
|
|
|
|
// MAIA override for aten::matmul.
//
// Validates that both operands are 2-D and share dtype and device, then
// returns a data-less MAIA tensor of shape {self.size(0), other.size(1)} with
// the dtype of `self`. No multiplication is performed.
//
// Throws c10::Error (via TORCH_CHECK) when a precondition is violated. These
// are caller errors, so TORCH_CHECK is used rather than the deprecated
// AT_ASSERT, which would report them as internal assertion failures.
at::Tensor maia_matmul_override(const at::Tensor & self, const at::Tensor & other) {
  TORCH_CHECK(
      self.dim() == 2,
      "maia matmul: expected `self` to be 2-D, but got ", self.dim(), "-D");
  TORCH_CHECK(
      other.dim() == 2,
      "maia matmul: expected `other` to be 2-D, but got ", other.dim(), "-D");
  TORCH_CHECK(
      self.dtype() == other.dtype(),
      "maia matmul: dtype mismatch between `self` (", self.dtype(),
      ") and `other` (", other.dtype(), ")");
  TORCH_CHECK(
      self.device() == other.device(),
      "maia matmul: device mismatch between `self` (", self.device(),
      ") and `other` (", other.device(), ")");
  return get_tensor(self.dtype(), {self.size(0), other.size(1)});
}
|
|
|
|
// Registers the fake kernels above as the MAIA implementations of the
// corresponding aten operators.
TORCH_LIBRARY_IMPL(aten, MAIA, m) {
  // Tensor creation.
  m.impl("empty.memory_format", empty_override);

  // Arithmetic.
  m.impl("add.out", add_out_override);

  // Convolution forward / backward hooks.
  m.impl("convolution_overrideable", fake_convolution);
  m.impl("convolution_backward_overrideable", fake_convolution_backward);

  // Dtype conversion and matrix multiplication.
  m.impl("to.dtype", maia_to_dtype_override);
  m.impl("matmul", maia_matmul_override);
}
|
|
|
|
// TODO: Extend this to exercise the multi-device setting. In that case,
// we need to add a thread-local variable to track the current device.
|
|
struct MAIAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
|
static constexpr DeviceType static_type = DeviceType::MAIA;
|
|
MAIAGuardImpl() {}
|
|
MAIAGuardImpl(DeviceType t) {
|
|
AT_ASSERT(t == DeviceType::MAIA);
|
|
}
|
|
DeviceType type() const override {
|
|
return DeviceType::MAIA;
|
|
}
|
|
Device exchangeDevice(Device d) const override {
|
|
AT_ASSERT(d.type() == DeviceType::MAIA);
|
|
AT_ASSERT(d.index() == 0);
|
|
return d;
|
|
}
|
|
Device getDevice() const override {
|
|
return Device(DeviceType::MAIA, 0);
|
|
}
|
|
void setDevice(Device d) const override {
|
|
AT_ASSERT(d.type() == DeviceType::MAIA);
|
|
AT_ASSERT(d.index() == 0);
|
|
}
|
|
void uncheckedSetDevice(Device d) const noexcept override {
|
|
}
|
|
Stream getStream(Device d) const noexcept override {
|
|
return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0));
|
|
}
|
|
Stream exchangeStream(Stream s) const noexcept override {
|
|
return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0));
|
|
}
|
|
DeviceIndex deviceCount() const noexcept override {
|
|
return 1;
|
|
}
|
|
|
|
// Event-related functions
|
|
void record(void** event,
|
|
const Stream& stream,
|
|
const DeviceIndex device_index,
|
|
const EventFlag flag) const override {
|
|
TORCH_CHECK(false, "MAIA backend doesn't support events.");
|
|
}
|
|
void block(
|
|
void* event,
|
|
const Stream& stream) const override {
|
|
TORCH_CHECK(false, "MAIA backend doesn't support events.");
|
|
}
|
|
bool queryEvent(void* event) const override {
|
|
TORCH_CHECK(false, "MAIA backend doesn't support events.");
|
|
}
|
|
void destroyEvent(
|
|
void* event,
|
|
const DeviceIndex device_index) const noexcept override { }
|
|
};
|
|
|
|
// Out-of-class definition of the static constexpr member declared in
// MAIAGuardImpl. Needed before C++17 when the member is ODR-used; from C++17
// on the in-class declaration is implicitly inline and this line is redundant.
constexpr DeviceType MAIAGuardImpl::static_type;
|
|
// Register MAIAGuardImpl as the device guard implementation for the MAIA
// device type.
C10_REGISTER_GUARD_IMPL(MAIA, MAIAGuardImpl);
|
|
|
|
int get_test_int() {
|
|
return test_int;
|
|
}
|
|
|
|
// Python bindings for this extension: exposes get_test_int() so Python code
// can read the marker written by the MAIA overrides.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_test_int", &get_test_int);
}
|