mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[nativert] Move call_torchbind_kernel (#156571)
Summary: Move the call_torchbind_kernel target from internal sigmoid to pytorch.

Test Plan:
Tested internally:
buck2 test mode/dev-nosan caffe2/test/cpp/nativert:op_kernel_test
buck build //sigmoid/core/kernels:kernel_factory
and all sandcastle tests.

Rollback Plan:

Differential Revision: D77118592

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156571
Approved by: https://github.com/zhxchen17
committed by PyTorch MergeBot
parent 795a6a0aff
commit 4c59edf0c5
@@ -613,6 +613,7 @@ libtorch_nativert_sources = [
     "torch/nativert/kernels/HigherOrderKernel.cpp",
     "torch/nativert/executor/memory/GreedyBySize.cpp",
     "torch/nativert/executor/memory/Bump.cpp",
+    "torch/nativert/kernels/CallTorchBindKernel.cpp",
 ]

 torch_mobile_tracer_sources = [
51  torch/nativert/kernels/CallTorchBindKernel.cpp  Normal file
@@ -0,0 +1,51 @@
#include <torch/nativert/kernels/CallTorchBindKernel.h>

#include <c10/util/Enumerate.h>

#include <c10/util/Logging.h>

namespace torch::nativert {

CallTorchBindKernel::CallTorchBindKernel(const Node* node) : OpKernel(node) {
  const Value* customObjValue = node_->inputs()[0].value;
  CHECK(customObjValue->type() == Type::Kind::CustomObj);

  customClassName_ = customObjValue->type().classFqn();
  customClassType_ = torch::jit::getCustomClass(customClassName_);

  // sample schema
  // torch.ops.higher_order.call_torchbind(arg1_1, 'add_tensor', arg0_1);

  CHECK(node->attributes().size() == 1)
      << "Expects higher_order.call_torchbind to only have a single attribute, methodName";
  const auto& attr = node->attributes()[0];

  CHECK(std::holds_alternative<std::string>(attr.value))
      << "method should be a string";
  methodName_ = std::get<std::string>(attr.value);
  method_ = customClassType_->findMethod(methodName_);

  CHECK(method_ != nullptr) << "method not found: " << methodName_;
}

void CallTorchBindKernel::computeInternal(
    ExecutionFrame& executionFrame) const {
  // prepare inputs
  std::vector<c10::IValue> stack;
  for (const auto& input : node_->inputs()) {
    const auto& id = input.value->id();
    stack.emplace_back(executionFrame.getIValue(id));
  }

  // call the method
  method_->run(stack);

  // set outputs
  const auto& outputs = node_->outputs();
  TORCH_CHECK_EQ(outputs.size(), stack.size());
  for (auto&& [i, outputValue] : c10::enumerate(stack)) {
    executionFrame.setIValue(outputs[i]->id(), std::move(outputValue));
  }
}

} // namespace torch::nativert
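For context on what this kernel dispatches to: a higher_order.call_torchbind node carries a TorchBind custom object plus a method-name attribute, and the kernel above resolves that method on the registered custom class and runs it on the input stack. Below is a minimal sketch, not part of this diff, of how such a class could be registered through the public torch::class_ API; the namespace my_ns, the class MyAccumulator, and its add_tensor method are hypothetical names chosen to mirror the 'add_tensor' string in the sample-schema comment above.

#include <torch/custom_class.h>
#include <torch/script.h>

// Hypothetical TorchBind class: instances of it are the CustomObj values that a
// higher_order.call_torchbind node can target.
struct MyAccumulator : torch::CustomClassHolder {
  at::Tensor state_;
  explicit MyAccumulator(at::Tensor init) : state_(std::move(init)) {}

  // The method that CallTorchBindKernel would look up via
  // customClassType_->findMethod("add_tensor") and invoke with method_->run(stack).
  at::Tensor add_tensor(at::Tensor x) {
    state_ = state_ + x;
    return state_;
  }
};

// Registration gives the class a fully-qualified name that
// torch::jit::getCustomClass() can resolve when the kernel is constructed.
TORCH_LIBRARY(my_ns, m) {
  m.class_<MyAccumulator>("MyAccumulator")
      .def(torch::init<at::Tensor>())
      .def("add_tensor", &MyAccumulator::add_tensor);
}

In a captured graph, a call such as obj.add_tensor(x) on an instance of this class would then take the shape of the sample schema quoted in the constructor: torch.ops.higher_order.call_torchbind(obj, 'add_tensor', x).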
26  torch/nativert/kernels/CallTorchBindKernel.h  Normal file
@@ -0,0 +1,26 @@
#pragma once

#include <c10/core/Device.h>
#include <torch/custom_class.h>

#include <torch/nativert/executor/ExecutionFrame.h> // @manual
#include <torch/nativert/executor/OpKernel.h> // @manual

namespace torch::nativert {

class CallTorchBindKernel : public OpKernel {
 public:
  CallTorchBindKernel() = delete; // deleted default constructor
  CallTorchBindKernel(const Node* node);

  void computeInternal(ExecutionFrame& executionFrame) const override final;

 private:
  std::string methodName_;
  torch::jit::Function* method_;

  std::string customClassName_;
  at::ClassTypePtr customClassType_;
};

} // namespace torch::nativert