mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: a failed check is fatal (crit) and causes the program to exit, which is quite dangerous. Test Plan: CI Rollback Plan: Differential Revision: D78050595 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158269 Approved by: https://github.com/SherlockNoMad, https://github.com/henryoier
54 lines
1.6 KiB
C++
54 lines
1.6 KiB
C++
#include <torch/nativert/kernels/CallTorchBindKernel.h>
|
|
|
|
#include <c10/util/Enumerate.h>
|
|
|
|
#include <c10/util/Logging.h>
|
|
|
|
namespace torch::nativert {
|
|
|
|
// Builds a kernel for a `higher_order.call_torchbind` node, e.g.
//   torch.ops.higher_order.call_torchbind(arg1_1, 'add_tensor', arg0_1);
// Input 0 must be the custom (TorchBind) object. The node's single attribute
// must be a string naming the method to invoke on that object. The custom
// class and the method are resolved once here, so computeInternal only has to
// run the cached method.
//
// Throws (via TORCH_CHECK) if:
//   - the node has no inputs, or input 0 is not a custom object;
//   - the custom class cannot be found;
//   - the node does not have exactly one attribute, or it is not a string;
//   - the named method does not exist on the class.
CallTorchBindKernel::CallTorchBindKernel(const Node* node) : OpKernel(node) {
  // Guard the inputs()[0] access below: the custom object is the first input.
  TORCH_CHECK(
      !node_->inputs().empty(),
      "Expects higher_order.call_torchbind to have the custom object as its first input");
  const Value* customObjValue = node_->inputs()[0].value;
  TORCH_CHECK(
      customObjValue->type() == Type::Kind::CustomObj,
      "Expects the first input of higher_order.call_torchbind to be a custom object");

  customClassName_ = customObjValue->type().classFqn();
  customClassType_ = torch::jit::getCustomClass(customClassName_);
  // getCustomClass yields a null pointer when no class is registered under
  // this name; fail with a readable error instead of dereferencing null in
  // findMethod below.
  TORCH_CHECK(
      customClassType_ != nullptr,
      "custom class not found: ",
      customClassName_);

  // sample schema
  // torch.ops.higher_order.call_torchbind(arg1_1, 'add_tensor', arg0_1);
  TORCH_CHECK(
      node->attributes().size() == 1,
      "Expects higher_order.call_torchbind to only have a single attribute, methodName");
  const auto& attr = node->attributes()[0];

  TORCH_CHECK(
      std::holds_alternative<std::string>(attr.value),
      "method should be a string");
  methodName_ = std::get<std::string>(attr.value);
  method_ = customClassType_->findMethod(methodName_);

  TORCH_CHECK(method_ != nullptr, "method not found: ", methodName_);
}
|
|
|
|
// Executes the cached TorchBind method for this node.
//
// Every node input (the custom object first, then the method arguments) is
// read from `executionFrame` onto an IValue stack, and the method is run on
// that stack. The values left on the stack are then moved, in order, into
// the node's outputs.
//
// Throws (via TORCH_CHECK) if the number of values left on the stack differs
// from the number of node outputs.
void CallTorchBindKernel::computeInternal(
    ExecutionFrame& executionFrame) const {
  const auto& inputs = node_->inputs();

  // prepare inputs; the stack size is known up front, so allocate once.
  std::vector<c10::IValue> stack;
  stack.reserve(inputs.size());
  for (const auto& input : inputs) {
    stack.emplace_back(executionFrame.getIValue(input.value->id()));
  }

  // call the method; afterwards the stack holds the values mapped to outputs.
  method_->run(stack);

  // set outputs
  const auto& outputs = node_->outputs();
  TORCH_CHECK(
      outputs.size() == stack.size(),
      "call_torchbind method ",
      methodName_,
      " left ",
      stack.size(),
      " values on the stack but the node has ",
      outputs.size(),
      " outputs");
  for (auto&& [i, outputValue] : c10::enumerate(stack)) {
    // Move each result out of the stack; the stack is discarded afterwards.
    executionFrame.setIValue(outputs[i]->id(), std::move(outputValue));
  }
}
|
|
|
|
} // namespace torch::nativert
|