Summary: att
Test Plan: ci
Rollback Plan:
Differential Revision: D81731425
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162353
Approved by: https://github.com/yiming0416
44 lines
1.2 KiB
C++
#include <torch/nativert/kernels/ETCallDelegateKernel.h>

#include <torch/nativert/executor/ETDelegateExecutor.h>

namespace torch::nativert {

ETCallDelegateKernel::ETCallDelegateKernel(
    const Node* node,
    ETDelegateExecutor& delegateExecutor)
    : OpKernel(node), delegateExecutor_(delegateExecutor) {
  // The delegate only exchanges tensors, so reject any non-tensor
  // inputs or outputs up front.
  for (const auto& input : node_->inputs()) {
    TORCH_CHECK(input.value->type() == Type::Kind::Tensor);
  }

  for (const auto* output : node_->outputs()) {
    TORCH_CHECK(output->type() == Type::Kind::Tensor);
  }
}

void ETCallDelegateKernel::computeInternal(
    ExecutionFrame& executionFrame) const {
  // Gather the input tensors from the execution frame.
  std::vector<at::Tensor> inputs;
  inputs.reserve(numInputs());

  for (const auto& input : node_->inputs()) {
    inputs.emplace_back(executionFrame.getTensor(input.value->id()));
  }

  // Run the delegate and write its outputs back into the frame.
  auto outputs = delegateExecutor_.run(inputs);
  const auto& node_outputs = node_->outputs();
  TORCH_CHECK(outputs.size() == node_outputs.size());

  // Move each output tensor into the frame rather than copying it.
  size_t i = 0;
  for (auto begin = std::make_move_iterator(outputs.begin()),
            end = std::make_move_iterator(outputs.end());
       begin != end;
       ++begin) {
    executionFrame.setIValue(node_outputs[i]->id(), *begin);
    i++;
  }
}

} // namespace torch::nativert
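
As a side note, the write-back loop in computeInternal uses move iterators so each delegate output is moved, not copied, into the frame. Below is a minimal, self-contained sketch of that same pattern using plain standard-library types; the Frame struct and the run() function are hypothetical stand-ins for ExecutionFrame and ETDelegateExecutor::run, not part of torch::nativert.

#include <iostream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for ExecutionFrame: maps value ids to payloads.
struct Frame {
  std::unordered_map<int, std::string> values;
  void set(int id, std::string v) { values[id] = std::move(v); }
};

// Hypothetical stand-in for the delegate: one output per input.
std::vector<std::string> run(const std::vector<std::string>& inputs) {
  std::vector<std::string> outputs;
  outputs.reserve(inputs.size());
  for (const auto& in : inputs) {
    outputs.push_back(in + "_out");
  }
  return outputs;
}

int main() {
  Frame frame;
  std::vector<int> output_ids = {10, 11};
  auto outputs = run({"a", "b"});

  // Same shape as the loop in computeInternal: walk the delegate's outputs
  // with move iterators so each element is moved into the frame.
  size_t i = 0;
  for (auto begin = std::make_move_iterator(outputs.begin()),
            end = std::make_move_iterator(outputs.end());
       begin != end;
       ++begin) {
    frame.set(output_ids[i], *begin);
    ++i;
  }

  for (const auto& [id, v] : frame.values) {
    std::cout << id << " -> " << v << "\n";
  }
}

With at::Tensor the move mainly avoids reference-count churn rather than a deep copy, but the control flow is the same.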