Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62336

This PR was generated by removing `const` from all types of nodes in the NNC IR and fixing the compilation errors that resulted from that change. This is the first step in making all NNC mutations in-place.

Test Plan: Imported from OSS

Reviewed By: iramazanli

Differential Revision: D30049829

Pulled By: navahgar

fbshipit-source-id: ed14e2d2ca0559ffc0b92ac371f405579c85dd63
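To make the summary concrete, here is a hypothetical sketch of the kind of signature rewrite the PR describes across the NNC IR (the `simplify` name is illustrative only, not taken from this diff):

    // Before: passes receive and return const node pointers.
    const Expr* simplify(const Expr* expr);

    // After: const is removed so passes can mutate IR nodes in place.
    Expr* simplify(Expr* expr);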
276 lines
7.4 KiB
C++
#include <torch/csrc/jit/tensorexpr/external_functions.h>

#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/xnnpack/OpContext.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>

namespace torch {
namespace jit {
namespace tensorexpr {

std::vector<at::Tensor> constructTensors(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes) {
  std::vector<void*> buf_data_vec;
  std::vector<std::vector<int64_t>> buf_dims_vec;
  std::vector<c10::ScalarType> buf_dtypes_vec;
  int64_t buf_dims_idx = 0;
  for (auto i : c10::irange(bufs_num)) {
    buf_data_vec.push_back(buf_data[i]);
    buf_dims_vec.emplace_back();
    // NOLINTNEXTLINE(clang-diagnostic-unused-variable,clang-analyzer-deadcode.DeadStores)
    for (auto dim : c10::irange(buf_ranks[i])) {
      buf_dims_vec[i].push_back(buf_dims[buf_dims_idx++]);
    }
    buf_dtypes_vec.push_back(static_cast<c10::ScalarType>(buf_dtypes[i]));
  }

  std::vector<at::Tensor> tensors;
  for (auto i : c10::irange(buf_data_vec.size())) {
    auto options = at::TensorOptions()
                       .dtype(buf_dtypes_vec[i])
                       .layout(at::kStrided)
                       .device(at::kCPU) // TODO: support GPUs too
                       .requires_grad(false);
    tensors.emplace_back(
        at::from_blob(buf_data_vec[i], buf_dims_vec[i], options));
  }
  return tensors;
}
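
// Illustrative sketch (not part of the upstream file): how a caller might lay
// out the flattened metadata that constructTensors expects. buf_dims
// concatenates each buffer's sizes, so buffer i owns the next buf_ranks[i]
// entries. All names and values below are hypothetical.
//
//   float a[6], b[2];
//   void* buf_data[] = {a, b};
//   int64_t buf_ranks[] = {2, 1}; // a is 2-D, b is 1-D
//   int64_t buf_dims[] = {2, 3, 2}; // a: 2x3, b: 2
//   int8_t buf_dtypes[] = {
//       static_cast<int8_t>(c10::ScalarType::Float),
//       static_cast<int8_t>(c10::ScalarType::Float)};
//   auto tensors =
//       constructTensors(2, buf_data, buf_ranks, buf_dims, buf_dtypes);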

#ifdef C10_MOBILE
extern "C" {
#endif

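// All nnc_aten_* entry points share the same C-style ABI: buffer 0 is the
// output, the remaining buffers are inputs, and extra_args carries the
// operator's scalar arguments. For conv2d, extra_args (when provided) is
// {strideH, strideW, paddingH, paddingW, dilationH, dilationW, groups}.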
void nnc_aten_conv2d(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& x = tensors[1];
  const at::Tensor& w = tensors[2];
  if (args_num > 0) {
    // If the extra arguments are provided, the bias tensor must also be
    // present.
    TORCH_INTERNAL_ASSERT(args_num == 7 && bufs_num == 4);
    const at::Tensor& b = tensors[3];

    int64_t strideH = extra_args[0];
    int64_t strideW = extra_args[1];
    int64_t paddingH = extra_args[2];
    int64_t paddingW = extra_args[3];
    int64_t dilationH = extra_args[4];
    int64_t dilationW = extra_args[5];
    int64_t groups = extra_args[6];

    try {
      r = at::conv2d(
          x,
          w,
          b,
          {strideH, strideW},
          {paddingH, paddingW},
          {dilationH, dilationW},
          groups);
    } catch (...) {
    }
  } else {
    try {
      r = at::conv2d(x, w);
    } catch (...) {
    }
  }

  // TODO: use an out= variant of conv2d, if one becomes available, to avoid
  // this copy.
  memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel());
}
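
// extra_args for adaptive_avg_pool2d: {outH} or {outH, outW}; a single entry
// yields a square output.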
void nnc_aten_adaptive_avg_pool2d(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& x = tensors[1];
  int64_t H = extra_args[0];
  int64_t W = H;
  if (args_num > 1) {
    W = extra_args[1];
  }
  try {
    at::adaptive_avg_pool2d_out(r, x, {H, W});
  } catch (...) {
  }
}
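
// extra_args for mean: the args_num dimensions to reduce over.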
void nnc_aten_mean(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& x = tensors[1];
  std::vector<int64_t> mean_dims(args_num);
  if (args_num > 0) {
    memcpy(mean_dims.data(), extra_args, sizeof(int64_t) * args_num);
  }
  try {
    at::mean_out(r, x, mean_dims);
  } catch (...) {
  }
}
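
// extra_args for addmm: {alpha, beta}, currently limited to integer values
// (see the TODO below).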
void nnc_aten_addmm(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& x = tensors[1];
  const at::Tensor& y = tensors[2];
  const at::Tensor& z = tensors[3];
  // TODO: handle other alpha and beta dtypes, e.g. alpha=0.6, beta=0.2
  int64_t alpha = extra_args[0], beta = extra_args[1];

  try {
    at::addmm_out(r, x, y, z, alpha, beta);
  } catch (...) {
  }
}

// Only provides the first output; the second output is just a copy of one of
// the inputs.
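// The trailing extra_args map onto the boolean flags of
// at::triangular_solve_out: extra_args[0] -> upper, extra_args[2] ->
// transpose, extra_args[3] -> unitriangular (extra_args[1] is unused here).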
void nnc_aten_triangular_solve(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
  at::Tensor& r = tensors[0];
  at::Tensor r2 = tensors[2].clone();
  const at::Tensor& input = tensors[1];
  const at::Tensor& A = tensors[2];
  try {
    at::triangular_solve_out(
        r, r2, input, A, extra_args[0], extra_args[2], extra_args[3]);
  } catch (...) {
  }
}

#ifdef USE_XNNPACK
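
// In the prepacked runners the last buffer is not tensor data but a pointer
// to a prepacked XNNPACK op context, so one fewer tensor is constructed and
// the context is cast straight out of buf_data.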
void nnc_prepacked_linear_clamp_run(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  using namespace at::native::xnnpack;

  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_dtypes);

  const at::Tensor& x = tensors[1];
  auto context = reinterpret_cast<LinearOpContext*>(buf_data[2]);
  at::Tensor output = context->run(x);
  memcpy(
      buf_data[0], output.data_ptr(), output.element_size() * output.numel());
}

void nnc_prepacked_conv2d_clamp_run(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  using namespace at::native::xnnpack;

  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_dtypes);

  const at::Tensor& x = tensors[1];
  auto context = reinterpret_cast<Conv2dOpContext*>(buf_data[2]);
  at::Tensor output = context->run(x);
  memcpy(
      buf_data[0], output.data_ptr(), output.element_size() * output.numel());
}

#endif // USE_XNNPACK

#ifndef C10_MOBILE
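
// On non-mobile builds each kernel is registered under its string name so
// that NNC-generated code can resolve it through the external functions
// registry; on mobile builds the functions are instead exported with C
// linkage via the extern "C" block above.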
const static RegisterNNCExternalFunction nnc_conv2d(
    "nnc_aten_conv2d",
    nnc_aten_conv2d);
const static RegisterNNCExternalFunction nnc_adaptive_avg_pool2d(
    "nnc_aten_adaptive_avg_pool2d",
    nnc_aten_adaptive_avg_pool2d);
const static RegisterNNCExternalFunction nnc_mean(
    "nnc_aten_mean",
    nnc_aten_mean);
const static RegisterNNCExternalFunction nnc_addmm(
    "nnc_aten_addmm",
    nnc_aten_addmm);

const static RegisterNNCExternalFunction nnc_triangular_solve(
    "nnc_aten_triangular_solve",
    nnc_aten_triangular_solve);

#ifdef USE_XNNPACK
const static RegisterNNCExternalFunction reg_nnc_prepacked_linear_clamp_run(
    "nnc_prepacked_linear_clamp_run",
    nnc_prepacked_linear_clamp_run);
const static RegisterNNCExternalFunction reg_nnc_prepacked_conv2d_clamp_run(
    "nnc_prepacked_conv2d_clamp_run",
    nnc_prepacked_conv2d_clamp_run);
#endif // USE_XNNPACK

#endif // C10_MOBILE

#ifdef C10_MOBILE
} // extern "C"
#endif

} // namespace tensorexpr
} // namespace jit
} // namespace torch