mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165152 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #164991
236 lines
8.1 KiB
C++
236 lines
8.1 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/stable/stableivalue_conversions.h>
|
|
#include <array>
|
|
#include <cstdint>
|
|
#include <optional>
|
|
#include <string>
|
|
|
|
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
|
|
#include <torch/headeronly/core/ScalarType.h>
|
|
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
|
|
|
|
namespace torch::stable {
|
|
|
|
// We expect this to be the stable version of the empty_like op that takes in
|
|
// no kwargs (device, dtype, layout, memory_format). We will add kwargs
|
|
// support in the future.
|
|
inline torch::stable::Tensor empty_like(const torch::stable::Tensor& self) {
|
|
const auto num_args = 6;
|
|
std::array<StableIValue, num_args> stack{
|
|
from(self),
|
|
from(std::nullopt),
|
|
from(std::nullopt),
|
|
from(std::nullopt),
|
|
from(std::nullopt),
|
|
from(std::nullopt)};
|
|
TORCH_ERROR_CODE_CHECK(
|
|
aoti_torch_call_dispatcher("aten::empty_like", "", stack.data()));
|
|
return to<torch::stable::Tensor>(stack[0]);
|
|
}
|
|
|
|
// We expect this to be the stable version of the fill_.Scalar op
|
|
// with identical semantics to the existing fill_.Scalar op.
|
|
// A subtle nuance is that `value` is typed as a double, but it is
|
|
// actually a Scalar. This is because Scalar.h is currently not
|
|
// header-only.
|
|
inline torch::stable::Tensor fill_(
|
|
const torch::stable::Tensor& self,
|
|
double value) {
|
|
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_fill__Scalar(self.get(), value));
|
|
return self;
|
|
}
|
|
|
|
// We expect this to be the stable version of the narrow.default op.
|
|
// narrow takes in a SymInt for start and length, but these are typed as
|
|
// int64_t as SymInt is not yet header-only.
|
|
inline torch::stable::Tensor narrow(
|
|
torch::stable::Tensor& self,
|
|
int64_t dim,
|
|
int64_t start,
|
|
int64_t length) {
|
|
AtenTensorHandle ret0 = nullptr;
|
|
|
|
TORCH_ERROR_CODE_CHECK(
|
|
aoti_torch_aten_narrow(self.get(), dim, start, length, &ret0));
|
|
return torch::stable::Tensor(ret0);
|
|
}
|
|
|
|
// We expect this to be a stable version of the new_empty op that takes in
|
|
// only dtype information.
|
|
inline torch::stable::Tensor new_empty(
|
|
const torch::stable::Tensor& self,
|
|
torch::headeronly::IntHeaderOnlyArrayRef size,
|
|
std::optional<c10::ScalarType> dtype = std::nullopt) {
|
|
int32_t device_type;
|
|
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
|
|
|
|
int32_t device_index;
|
|
TORCH_ERROR_CODE_CHECK(
|
|
aoti_torch_get_device_index(self.get(), &device_index));
|
|
|
|
int32_t target_dtype;
|
|
if (dtype.has_value()) {
|
|
target_dtype = to<int32_t>(from(dtype.value()));
|
|
} else {
|
|
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
|
|
}
|
|
|
|
int32_t layout;
|
|
TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));
|
|
|
|
AtenTensorHandle ret0;
|
|
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
|
|
self.get(),
|
|
size.data(),
|
|
static_cast<int64_t>(size.size()),
|
|
&target_dtype,
|
|
&layout,
|
|
&device_type,
|
|
device_index,
|
|
nullptr, // pin_memory (nullptr for default)
|
|
&ret0));
|
|
|
|
return torch::stable::Tensor(ret0);
|
|
}
|
|
|
|
// We expect this to be a stable version of the new_zeros op that takes in
|
|
// only dtype information.
|
|
inline torch::stable::Tensor new_zeros(
|
|
const torch::stable::Tensor& self,
|
|
torch::headeronly::IntHeaderOnlyArrayRef size,
|
|
std::optional<c10::ScalarType> dtype = std::nullopt) {
|
|
int32_t device_type;
|
|
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
|
|
|
|
int32_t device_index;
|
|
TORCH_ERROR_CODE_CHECK(
|
|
aoti_torch_get_device_index(self.get(), &device_index));
|
|
|
|
int32_t target_dtype;
|
|
if (dtype.has_value()) {
|
|
target_dtype = to<int32_t>(from(dtype.value()));
|
|
} else {
|
|
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
|
|
}
|
|
|
|
int32_t layout;
|
|
TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));
|
|
|
|
AtenTensorHandle ath;
|
|
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_zeros(
|
|
self.get(),
|
|
size.data(),
|
|
static_cast<int64_t>(size.size()),
|
|
&target_dtype,
|
|
&layout,
|
|
&device_type,
|
|
device_index,
|
|
nullptr, // pin_memory (nullptr for default)
|
|
&ath));
|
|
|
|
return torch::stable::Tensor(ath);
|
|
}
|
|
|
|
// We expect this to be the stable version of the pad.default op.
|
|
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
|
|
// torch::headeronly::IntHeaderOnlyArrayRef as SymInt is not yet header-only.
|
|
inline torch::stable::Tensor pad(
|
|
const torch::stable::Tensor& self,
|
|
torch::headeronly::IntHeaderOnlyArrayRef pad,
|
|
const std::string& mode = "constant",
|
|
double value = 0.0) {
|
|
AtenTensorHandle ret0 = nullptr;
|
|
|
|
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_pad(
|
|
self.get(), pad.data(), pad.size(), mode.c_str(), &value, &ret0));
|
|
return torch::stable::Tensor(ret0);
|
|
}
|
|
|
|
// We expect the following two functions to be stable versions of the
|
|
// amax.default op with identical semantics to the existing amax.default op. If
|
|
// `keepdim` is true, the result will have the same number of dimensions as
|
|
// `self`, with the specified dimension having size 1. Otherwise, the result
|
|
// will have one fewer dimension than `self`, with the specified dimension
|
|
// removed.
|
|
|
|
// This function is an overload to compute the maximum value along each slice of
|
|
// `self` along a single dimension `dim`.
|
|
inline torch::stable::Tensor amax(
|
|
const torch::stable::Tensor& self,
|
|
int64_t dim,
|
|
bool keepdim = false) {
|
|
AtenTensorHandle ret = nullptr;
|
|
TORCH_ERROR_CODE_CHECK(
|
|
aoti_torch_aten_amax(self.get(), &dim, 1, keepdim, &ret));
|
|
return torch::stable::Tensor(ret);
|
|
}
|
|
|
|
// This function is an overload to compute the maximum value along each slice of
|
|
// `self` reducing over all the dimensions in the vector `dims`. The
|
|
// amax.default op takes in a SymInt[] as the dims argument, however dims is
|
|
// typed as use IntHeaderOnlyArrayRef here because SymInt is not yet header-only
|
|
inline torch::stable::Tensor amax(
|
|
const torch::stable::Tensor& self,
|
|
torch::headeronly::IntHeaderOnlyArrayRef dims,
|
|
bool keepdim = false) {
|
|
AtenTensorHandle ret = nullptr;
|
|
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(
|
|
self.get(),
|
|
dims.data(),
|
|
static_cast<int64_t>(dims.size()),
|
|
keepdim,
|
|
&ret));
|
|
return torch::stable::Tensor(ret);
|
|
}
|
|
|
|
// We expect this to be the stable version of the transpose op with identical
|
|
// semantics to the existing transpose.int op.
|
|
inline torch::stable::Tensor transpose(
|
|
const torch::stable::Tensor& self,
|
|
int64_t dim0,
|
|
int64_t dim1) {
|
|
const auto num_args = 3;
|
|
std::array<StableIValue, num_args> stack{from(self), from(dim0), from(dim1)};
|
|
TORCH_ERROR_CODE_CHECK(
|
|
aoti_torch_call_dispatcher("aten::transpose", "int", stack.data()));
|
|
return to<torch::stable::Tensor>(stack[0]);
|
|
}
|
|
|
|
// We expect this to be the stable version of the zero_ op with identical
|
|
// semantics to the existing zero_ op (except that it will not be called as
|
|
// a tensor method but only as a function i.e. zero_(t) not t.zero_()).
|
|
inline torch::stable::Tensor zero_(torch::stable::Tensor& self) {
|
|
const auto num_args = 1;
|
|
std::array<StableIValue, num_args> stack{from(self)};
|
|
TORCH_ERROR_CODE_CHECK(
|
|
aoti_torch_call_dispatcher("aten::zero_", "", stack.data()));
|
|
return to<torch::stable::Tensor>(stack[0]);
|
|
}
|
|
|
|
// We expect this to be the stable version of the copy_ op with
|
|
// identical semantics to the existing copy_ op.
|
|
inline torch::stable::Tensor copy_(
|
|
torch::stable::Tensor& self,
|
|
const torch::stable::Tensor& src,
|
|
std::optional<bool> non_blocking = std::nullopt) {
|
|
const auto num_args = 3;
|
|
std::array<StableIValue, num_args> stack{
|
|
from(self), from(src), from(non_blocking.value_or(false))};
|
|
TORCH_ERROR_CODE_CHECK(
|
|
aoti_torch_call_dispatcher("aten::copy_", "", stack.data()));
|
|
return to<torch::stable::Tensor>(stack[0]);
|
|
}
|
|
|
|
// We expect this to be the stable version of the clone op. We will
|
|
// add optional memory_format kwarg support in the future.
|
|
inline torch::stable::Tensor clone(const torch::stable::Tensor& self) {
|
|
const auto num_args = 2;
|
|
std::array<StableIValue, num_args> stack{from(self), from(std::nullopt)};
|
|
TORCH_ERROR_CODE_CHECK(
|
|
aoti_torch_call_dispatcher("aten::clone", "", stack.data()));
|
|
return to<torch::stable::Tensor>(stack[0]);
|
|
}
|
|
|
|
} // namespace torch::stable
|