// Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
// Re-land of https://github.com/pytorch/pytorch/pull/125242
// Pull Request resolved: https://github.com/pytorch/pytorch/pull/127034
// Approved by: https://github.com/malfet
#include <torch/csrc/api/include/torch/types.h>

#include <cstdint>
#include <iostream>
#include <string>
namespace at {
// Eager kernel for aoti_custom_ops::custom_add: returns the elementwise sum
// of the two input tensors.
Tensor custom_add_impl(Tensor t1, Tensor t2) {
  auto sum = t1.add(t2);
  return sum;
}
// Eager kernel for aoti_custom_ops::fn_with_all_inputs. Exercises every
// argument category the dispatcher supports (tensor, tensor lists, optional
// tensors, bool/int/float/Scalar/str scalars and lists, dtype, memory format,
// layout, device, plus optional variants of each). It logs every argument to
// stdout and returns `tensor` offset by an integer hash folded (via XOR) from
// all integral inputs, so callers can detect that the ints were received.
Tensor fn_with_all_inputs_impl(
    const Tensor& tensor,
    const c10::List<Tensor>& tensors,
    const c10::List<std::optional<Tensor>>& optional_tensors,
    const bool b8,
    const c10::List<bool>& b8s,
    const int64_t i64,
    const c10::List<int64_t>& i64s,
    const int64_t& symint,
    const IntArrayRef symints,
    const double f64,
    const c10::List<double>& f64s,
    const at::Scalar& scalar,
    at::ArrayRef<at::Scalar> scalars,
    const std::string& string,
    const std::vector<std::string>& strings,
    const c10::ScalarType& dtype,
    const MemoryFormat& memory_format,
    const Layout& layout,
    const Device& device,
    // optional
    const std::optional<Tensor>& o_tensor,
    const std::optional<c10::List<Tensor>>& o_tensors,
    const std::optional<bool>& o_b8,
    const std::optional<c10::List<bool>>& o_b8s,
    const std::optional<int64_t>& o_i64,
    const std::optional<c10::List<int64_t>>& o_i64s,
    const std::optional<int64_t>& o_symint,
    const std::optional<IntArrayRef>& o_symints,
    const std::optional<double>& o_f64,
    const std::optional<c10::List<double>>& o_f64s,
    const std::optional<at::Scalar>& o_scalar,
    const std::optional<at::ArrayRef<at::Scalar>>& o_scalars,
    const std::optional<std::string>& o_string,
    const std::optional<std::vector<std::string>>& o_strings,
    const std::optional<c10::ScalarType>& o_dtype,
    const std::optional<MemoryFormat>& o_memory_format,
    const std::optional<Layout>& o_layout,
    const std::optional<Device>& o_device) {
  std::cout << "tensor shape: " << tensor.sizes() << std::endl;

  // c10::List iteration yields element proxies; unwrap through IValue.
  std::cout << "tensors shape: ";
  for (auto t : tensors) {
    std::cout << t.get().toTensor().sizes() << ", ";
  }
  std::cout << std::endl;

  std::cout << "optional tensors shape: ";
  for (auto t : optional_tensors) {
    if (t.get().toOptional<Tensor>().has_value()) {
      std::cout << t.get().toTensor().sizes() << ", ";
    } else {
      std::cout << "None, ";
    }
  }
  std::cout << std::endl;

  // Wrap each value in a c10::IValue so every type prints uniformly.
  std::cout << "b8 " << c10::IValue(b8) << std::endl;
  std::cout << "b8s " << c10::IValue(b8s) << std::endl;
  std::cout << "i64 " << c10::IValue(i64) << std::endl;
  std::cout << "i64s " << c10::IValue(i64s) << std::endl;
  std::cout << "symint " << c10::IValue(symint) << std::endl;
  std::cout << "symints " << c10::IValue(symints) << std::endl;
  std::cout << "f64 " << c10::IValue(f64) << std::endl;
  std::cout << "f64s " << c10::IValue(f64s) << std::endl;
  std::cout << "scalar " << c10::IValue(scalar) << std::endl;
  std::cout << "scalars " << c10::IValue(scalars) << std::endl;
  std::cout << "string " << c10::IValue(string) << std::endl;
  std::cout << "strings " << c10::IValue(strings) << std::endl;
  std::cout << "dtype " << c10::IValue(dtype) << std::endl;
  std::cout << "memory_format " << c10::IValue(memory_format) << std::endl;
  std::cout << "layout " << c10::IValue(layout) << std::endl;
  std::cout << "device " << c10::IValue(device) << std::endl;

  // Optional arguments: print the wrapped value when present, "None"
  // otherwise (the ternary relies on IValue being constructible from both).
  std::cout << "o_tensor "
            << (o_tensor.has_value() ? c10::IValue(o_tensor.value().sizes())
                                     : "None")
            << std::endl;

  std::cout << "o_tensors shape: ";
  if (o_tensors.has_value()) {
    for (auto t : o_tensors.value()) {
      std::cout << t.get().toTensor().sizes() << ", ";
    }
  } else {
    std::cout << "None";
  }
  std::cout << std::endl;

  std::cout << "o_b8 "
            << (o_b8.has_value() ? c10::IValue(o_b8.value()) : "None")
            << std::endl;
  std::cout << "o_b8s "
            << (o_b8s.has_value() ? c10::IValue(o_b8s.value()) : "None")
            << std::endl;
  std::cout << "o_i64 "
            << (o_i64.has_value() ? c10::IValue(o_i64.value()) : "None")
            << std::endl;
  std::cout << "o_i64s "
            << (o_i64s.has_value() ? c10::IValue(o_i64s.value()) : "None")
            << std::endl;
  std::cout << "o_symint "
            << (o_symint.has_value() ? c10::IValue(o_symint.value()) : "None")
            << std::endl;
  std::cout << "o_symints "
            << (o_symints.has_value() ? c10::IValue(o_symints.value()) : "None")
            << std::endl;
  std::cout << "o_f64 "
            << (o_f64.has_value() ? c10::IValue(o_f64.value()) : "None")
            << std::endl;
  std::cout << "o_f64s "
            << (o_f64s.has_value() ? c10::IValue(o_f64s.value()) : "None")
            << std::endl;
  std::cout << "o_scalar "
            << (o_scalar.has_value() ? c10::IValue(o_scalar.value()) : "None")
            << std::endl;
  std::cout << "o_scalars "
            << (o_scalars.has_value() ? c10::IValue(o_scalars.value()) : "None")
            << std::endl;
  std::cout << "o_string "
            << (o_string.has_value() ? c10::IValue(o_string.value()) : "None")
            << std::endl;
  std::cout << "o_strings "
            << (o_strings.has_value() ? c10::IValue(o_strings.value()) : "None")
            << std::endl;
  std::cout << "o_dtype "
            << (o_dtype.has_value() ? c10::IValue(o_dtype.value()) : "None")
            << std::endl;
  std::cout << "o_memory_format "
            << (o_memory_format.has_value()
                    ? c10::IValue(o_memory_format.value())
                    : "None")
            << std::endl;
  std::cout << "o_layout "
            << (o_layout.has_value() ? c10::IValue(o_layout.value()) : "None")
            << std::endl;
  std::cout << "o_device "
            << (o_device.has_value() ? c10::IValue(o_device.value()) : "None")
            << std::endl;

  // Fold every integral input (required, list, and optional variants) into a
  // single XOR accumulator so the returned tensor depends on all of them.
  int64_t int_hash = 0;
  int_hash ^= i64;
  for (auto i : i64s) {
    int_hash ^= i;
  }
  if (o_i64.has_value()) {
    int_hash ^= o_i64.value();
  }
  if (o_i64s.has_value()) {
    for (auto i : o_i64s.value()) {
      int_hash ^= i;
    }
  }

  // SymInt arguments arrive here as concrete int64_t values (the meta kernel
  // takes the symbolic c10::SymInt counterparts).
  int_hash ^= symint;
  for (auto i : symints) {
    int_hash ^= i;
  }
  if (o_symint.has_value()) {
    int_hash ^= o_symint.value();
  }
  if (o_symints.has_value()) {
    for (auto i : o_symints.value()) {
      int_hash ^= i;
    }
  }

  return tensor + int_hash;
}
// Eager kernel for aoti_custom_ops::fn_with_default_input; the schema gives
// i64 a default of 3 ("int i=3").
Tensor fn_with_default_input_impl(const Tensor& tensor, const int64_t i64) {
  auto shifted = tensor.add(i64);
  return shifted;
}
std::tuple<Tensor, Tensor> fn_with_tuple_output_impl(
|
|
const Tensor& tensor,
|
|
const int64_t i64) {
|
|
return {tensor + i64, tensor - i64};
|
|
}
|
|
|
|
// Eager kernel for aoti_custom_ops::fn_with_list_output.
// Returns a new vector holding `t + i64` for each input tensor, in order.
std::vector<Tensor> fn_with_list_output_impl(
    TensorList tensors,
    const int64_t i64) {
  std::vector<Tensor> outputs;
  // Known result size: reserve once to avoid repeated reallocation.
  outputs.reserve(tensors.size());
  for (const auto& t : tensors) {
    outputs.emplace_back(t + i64);
  }
  return outputs;
}
std::tuple<Tensor, std::vector<Tensor>> fn_with_mix_outputs_impl(
|
|
const Tensor& tensor,
|
|
TensorList tensors) {
|
|
std::vector<Tensor> outputs;
|
|
for (auto& t : tensors) {
|
|
outputs.emplace_back(t + 2);
|
|
}
|
|
return {tensor + 1, outputs};
|
|
}
|
|
|
|
// Eager kernel with declared input mutation: increments t0 and decrements t2
// in place (the schema tags them (a!) and (b!)), then returns two fresh
// tensors derived from the read-only t1.
std::tuple<Tensor, Tensor> fn_with_input_mutation_impl(
    Tensor& t0,
    const Tensor& t1,
    Tensor& t2) {
  t0.add_(1);
  t2.sub_(1);
  auto first = t1.add(1);
  auto second = t1.add(2);
  return std::make_tuple(first, second);
}
// NOLINTBEGIN(clang-diagnostic-unused-parameter)
// Meta kernel for fn_with_all_inputs: ignores every argument except `tensor`,
// which it returns unchanged so metadata propagation sees an output identical
// to the first input. Note the signature takes c10::SymInt /
// c10::SymIntArrayRef where the eager kernel takes concrete int64_t values,
// matching the "SymInt" entries in the schema.
Tensor fn_with_all_inputs_meta(
    const Tensor& tensor,
    const c10::List<Tensor>& tensors,
    const c10::List<std::optional<Tensor>>& optional_tensors,
    const bool b8,
    const c10::List<bool>& b8s,
    const int64_t i64,
    const c10::List<int64_t>& i64s,
    const c10::SymInt& symint,
    c10::SymIntArrayRef symints,
    const double f64,
    const c10::List<double>& f64s,
    const at::Scalar& scalar,
    at::ArrayRef<at::Scalar> scalars,
    const std::string& string,
    const std::vector<std::string>& strings,
    const c10::ScalarType& dtype,
    const MemoryFormat& memory_format,
    const Layout& layout,
    const Device& device,
    // optional
    const std::optional<Tensor>& o_tensor,
    const std::optional<c10::List<Tensor>>& o_tensors,
    const std::optional<bool>& o_b8,
    const std::optional<c10::List<bool>>& o_b8s,
    const std::optional<int64_t>& o_i64,
    const std::optional<c10::List<int64_t>>& o_i64s,
    const std::optional<c10::SymInt>& o_symint,
    at::OptionalSymIntArrayRef o_symints,
    const std::optional<double>& o_f64,
    const std::optional<c10::List<double>>& o_f64s,
    const std::optional<at::Scalar>& o_scalar,
    const std::optional<at::ArrayRef<at::Scalar>>& o_scalars,
    const std::optional<std::string>& o_string,
    const std::optional<std::vector<std::string>>& o_strings,
    const std::optional<c10::ScalarType>& o_dtype,
    const std::optional<MemoryFormat>& o_memory_format,
    const std::optional<Layout>& o_layout,
    const std::optional<Device>& o_device) {
  return tensor;
}
// Meta kernel: the output mirrors the input tensor's metadata; i64 does not
// affect the result shape.
Tensor fn_with_default_input_meta(const Tensor& tensor, const int64_t i64) {
  auto out = tensor.clone();
  return out;
}
std::tuple<Tensor, Tensor> fn_with_tuple_output_meta(
|
|
const Tensor& tensor,
|
|
const int64_t i64) {
|
|
return {tensor.clone(), tensor.clone()};
|
|
}
|
|
|
|
// Meta kernel for fn_with_list_output: one clone per input tensor, each
// mirroring the corresponding input's metadata.
std::vector<Tensor> fn_with_list_output_meta(
    TensorList tensors,
    const int64_t i64) {
  std::vector<Tensor> outputs;
  // Known result size: reserve once to avoid repeated reallocation.
  outputs.reserve(tensors.size());
  for (const auto& t : tensors) {
    outputs.emplace_back(t.clone());
  }
  return outputs;
}
std::tuple<Tensor, std::vector<Tensor>> fn_with_mix_outputs_meta(
|
|
const Tensor& tensor,
|
|
TensorList tensors) {
|
|
std::vector<Tensor> outputs;
|
|
for (auto& t : tensors) {
|
|
outputs.push_back(t.clone());
|
|
}
|
|
return {tensor.clone(), outputs};
|
|
}
|
|
|
|
// Meta kernel for fn_with_input_mutation: performs no actual mutation of
// t0/t2 (only metadata is propagated at this dispatch key); both outputs
// mirror the read-only t1.
std::tuple<Tensor, Tensor> fn_with_input_mutation_meta(
    Tensor& t0,
    const Tensor& t1,
    Tensor& t2) {
  return {t1.clone(), t1.clone()};
}
// NOLINTEND(clang-diagnostic-unused-parameter)
} // namespace at
// Operator schema definitions for the aoti_custom_ops test library.
TORCH_LIBRARY(aoti_custom_ops, m) {
  m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");

  // One argument of every category the function schema language supports;
  // everything after the lone "*" entry is keyword-only.
  m.def(
      "fn_with_all_inputs(Tensor tensor, "
      "Tensor[] tensors, "
      "Tensor?[] optional_tensors, "
      "bool b8, bool[] b8s, "
      "int i64, int[] i64s, "
      "SymInt symint, SymInt[] symints, "
      "float f64, float[] f64s, "
      "Scalar scalar, Scalar[] scalars, "
      "str string, str[] strings, "
      "ScalarType dtype, "
      "MemoryFormat memory_format, "
      "Layout layout, "
      "Device device, "
      "*, "
      "Tensor? o_tensor, Tensor[]? o_tensors, "
      "bool? o_b8, bool[]? o_b8s, "
      "int? o_i64, int[]? o_i64s, "
      "SymInt? o_symint, SymInt[]? o_symints, "
      "float? o_f64, float[]? o_f64s, "
      "Scalar? o_scalar, Scalar[]? o_scalars, "
      "str? o_string, str[]? o_strings, "
      "ScalarType? o_dtype, "
      "MemoryFormat? o_memory_format, "
      "Layout? o_layout, "
      "Device? o_device) -> Tensor");

  // "int i=3" gives the second argument a schema-level default value.
  m.def("fn_with_default_input(Tensor t, int i=3) -> Tensor");

  m.def("fn_with_tuple_output(Tensor t, int i) -> (Tensor, Tensor)");

  m.def("fn_with_list_output(Tensor[] tensors, int i) -> Tensor[]");

  m.def(
      "fn_with_mix_outputs(Tensor t, Tensor[] tensors) -> (Tensor, Tensor[])");

  // "(a!)" / "(b!)" mark t0 and t2 as mutated in place by the kernel.
  m.def(
      "fn_with_input_mutation(Tensor(a!) t0, Tensor t1, Tensor(b!) t2) -> (Tensor, Tensor)");
}
// Eager kernel registrations: one CompositeExplicitAutograd implementation
// per schema defined above.
TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
  m.impl("custom_add", at::custom_add_impl);
  m.impl("fn_with_all_inputs", at::fn_with_all_inputs_impl);
  m.impl("fn_with_default_input", at::fn_with_default_input_impl);
  m.impl("fn_with_tuple_output", at::fn_with_tuple_output_impl);
  m.impl("fn_with_list_output", at::fn_with_list_output_impl);
  m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_impl);
  m.impl("fn_with_input_mutation", at::fn_with_input_mutation_impl);
}
// Meta-device kernel registrations for shape/dtype propagation. Note that
// custom_add has no dedicated meta kernel registered here.
TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
  m.impl("fn_with_all_inputs", at::fn_with_all_inputs_meta);
  m.impl("fn_with_default_input", at::fn_with_default_input_meta);
  m.impl("fn_with_tuple_output", at::fn_with_tuple_output_meta);
  m.impl("fn_with_list_output", at::fn_with_list_output_meta);
  m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_meta);
  m.impl("fn_with_input_mutation", at::fn_with_input_mutation_meta);
}