Summary: In general, `device_` is not very useful in OpKernel, so remove it to avoid misuse. The meaning of `device_` in OpKernel is also ambiguous: for StaticDispatch kernels we always call the CPU kernel; for C10Kernel we rely on the input tensor's device and the dispatcher to determine which device to run on; and for ops that involve multiple devices, e.g. aten._to_copy(device), the meaning of a single device is ill-defined.

Test Plan: CI

Rollback Plan:

Reviewed By: henryoier, dolpm, kqfu, zhxchen17

Differential Revision: D78704840

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158944
Approved by: https://github.com/dolpm
164 lines | 4.3 KiB | C++
#include <ATen/record_function.h>

#include <ATen/CPUFunctions.h>
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/runtime/static/ops.h>

#include <c10/util/Enumerate.h>
#include <torch/nativert/kernels/PrimKernelRegistry.h>

namespace torch::nativert {

C10_DEFINE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*)
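
// Usage sketch (not part of the original file): the macro above generates a
// registry keyed by op name, so an executor can construct a prim kernel from a
// graph node through the standard c10::Registry interface. Here `node` and
// `executionFrame` are hypothetical objects owned by the caller, and a public
// OpKernel::compute() wrapper around computeInternal() is an assumption.
//
//   std::unique_ptr<OpKernel> kernel =
//       PrimKernelRegistry()->Create("prim.ListPack", node);
//   kernel->compute(executionFrame);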
namespace {

class OpKernel_prim_listpack : public OpKernel {
 public:
  explicit OpKernel_prim_listpack(const Node* node)
      : OpKernel(node, OpKernelKind::kPrimKernel) {
    auto listType = node->outputs()[0]->type();
    switch (listType.kind()) {
      case Type::Kind::TensorList:
        type_ = c10::TensorType::get();
        break;
      case Type::Kind::SymIntList:
        type_ = c10::IntType::get();
        break;
      case Type::Kind::OptionalTensorList:
        type_ = c10::OptionalType::create(c10::TensorType::get());
        break;
      default:
        TORCH_CHECK(false, "Unsupported list type: ", listType);
    }
  }

  void computeInternal(ExecutionFrame& executionFrame) const override final {
    RECORD_USER_SCOPE("nativert::OpKernel_prim_listpack");
    c10::List<c10::IValue> list(type_);
    list.reserve(numInputs());
    for (size_t i = 0; i < numInputs(); ++i) {
      if (KernelInput(i).isNone()) {
        list.emplace_back();
      } else {
        list.push_back(KernelInput(i));
      }
    }
    KernelOutput(0) = std::move(list);
  }

 private:
  c10::TypePtr type_;
};

} // namespace

C10_REGISTER_TYPED_CLASS(
    PrimKernelRegistry,
    "prim.ListPack",
    OpKernel_prim_listpack)
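
// Semantics sketch (readable from the kernel above, not part of the original
// file): prim.ListPack gathers its N inputs into a single typed c10::List,
// roughly the runtime analogue of building `[x0, x1, ..., xN-1]`, with None
// inputs preserved as empty (default-constructed) elements.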
REGISTER_PRIM_KERNEL("prim.ListUnpack", prim_listunpack, {
|
|
RECORD_USER_SCOPE("nativert::OpKernel_prim_listunpack");
|
|
auto inputListRef = KernelInput(0).toListRef();
|
|
for (const auto& [i, ivalue] : c10::enumerate(inputListRef)) {
|
|
KernelOutput(i) = ivalue;
|
|
}
|
|
})
|
|
|
|
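
// Semantics sketch (not part of the original file): prim.ListUnpack is the
// inverse of prim.ListPack; it scatters the elements of its single list input
// into one kernel output per element, i.e. roughly `x0, x1, ... = xs`.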
// Noop for input and output
REGISTER_PRIM_KERNEL("prim.Input", prim_input, {})
REGISTER_PRIM_KERNEL("prim.Output", prim_output, {})
namespace {

class OpKernel_variadic_concat : public OpKernel {
 public:
  explicit OpKernel_variadic_concat(const Node* node)
      : OpKernel(node, OpKernelKind::kPrimKernel) {
    dim_ = node_->attributes().size() > 0
        ? constantToIValue(node_->getAttribute("dim").value).toInt()
        : 0;
  }

  void computeInternal(ExecutionFrame& executionFrame) const override final {
    {
      const size_t numNodeInps = numInputs();
      auto numCatInps = numNodeInps;
      auto dim = dim_;
      if (KernelInput(numCatInps - 1).isInt()) {
        dim = KernelInput(numCatInps - 1).toInt();
        numCatInps--;
      }
      std::vector<at::Tensor> inputs(numCatInps);
      for (const auto i : c10::irange(numCatInps)) {
        inputs[i] = KernelInput(i).toTensor();
      }

      if (KernelOutput(0).isNone()) {
        KernelOutput(0) = at::cpu::cat(inputs, dim);
        return;
      }
      auto& out_t = KernelOutput(0).toTensor();
      fastResizeToZero(out_t);
      at::cpu::cat_outf(inputs, dim, out_t);
    }
  }

 private:
  int dim_;
};

} // namespace

C10_REGISTER_TYPED_CLASS(
    PrimKernelRegistry,
    "prim.VarConcat",
    OpKernel_variadic_concat)
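
// Semantics sketch (derived from the kernel above, not part of the original
// file): prim.VarConcat behaves like at::cat over its tensor inputs on CPU.
// The concat dimension comes from the node's "dim" attribute (defaulting to
// 0); an optional trailing integer input overrides it, and when the output
// IValue already holds a tensor it is resized and reused via the out-variant:
//
//   // inputs: t0, t1, ..., t{N-1} [, dim]
//   // output: at::cat({t0, ..., t{N-1}}, dim)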
namespace {

class OpKernel_variadic_stack : public OpKernel {
 public:
  explicit OpKernel_variadic_stack(const Node* node)
      : OpKernel(node, OpKernelKind::kPrimKernel) {
    dim_ = node_->attributes().size() > 0
        ? constantToIValue(node_->getAttribute("dim").value).toInt()
        : 0;
  }

  void computeInternal(ExecutionFrame& executionFrame) const override final {
    {
      const size_t numNodeInps = numInputs();
      auto numStackInps = numNodeInps;
      auto dim = dim_;
      if (KernelInput(numStackInps - 1).isInt()) {
        dim = KernelInput(numStackInps - 1).toInt();
        numStackInps--;
      }
      std::vector<at::Tensor> inputs(numStackInps);
      for (const auto i : c10::irange(numStackInps)) {
        inputs[i] = KernelInput(i).toTensor();
      }
      auto& out = KernelOutput(0);
      if (out.isNone()) {
        out = at::native::_stack_cpu(inputs, dim);
        return;
      }
      auto& out_t = out.toTensor();
      fastResizeToZero(out_t);
      at::native::_stack_out_cpu(inputs, dim, out_t);
    }
  }

 private:
  int64_t dim_;
};

} // namespace

C10_REGISTER_TYPED_CLASS(
    PrimKernelRegistry,
    "prim.VarStack",
    OpKernel_variadic_stack)
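
// Semantics sketch (derived from the kernel above, not part of the original
// file): prim.VarStack mirrors prim.VarConcat but stacks along a new
// dimension, i.e. roughly at::stack({t0, ..., t{N-1}}, dim), again with an
// optional trailing integer input overriding the "dim" attribute and an
// out-variant path (_stack_out_cpu) when the output tensor already exists.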
} // namespace torch::nativert