[nativert] Move auto_functionalize_kernel (#156454)
Summary: Torch Native Runtime RFC: https://github.com/pytorch/rfcs/pull/72

As part of the effort to open source TorchNativeRuntime (or what we call Sigmoid), we are moving the kernel implementations to torch/: fbcode/sigmoid/kernels -> fbcode/caffe2/torch/nativert/kernels

Copied from the original auto_functionalize diff summary, D53776805: This is a non-functional kernel implementation for auto_functionalize. In AutoFunctionalizeKernel, I directly call the underlying target without making a clone of the mutating inputs. This mutates the input tensors in place, which is unsafe in general. However, Sigmoid does no graph optimization or node reordering at the moment, so it is OK to take this shortcut. A proper functional implementation would make a clone of each mutating input tensor and return these new tensor instances as the AutoFunctionalizeKernel outputs. If the original exported program has "bufferMutation" or "userInputMutation" fields, Sigmoid will also need to honor those mutations.

Test Plan: See internal for test plan

Differential Revision: D76926383

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156454
Approved by: https://github.com/zhxchen17
Committed by: PyTorch MergeBot
Parent: eb331b59fe
Commit: e5ea24fb27
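For context, the summary above contrasts the unsafe shortcut taken in this commit with a proper functionalized kernel. Below is a minimal, self-contained sketch of that safer pattern (clone the mutating input, run the in-place op on the clone, return the clone as an output); the function name and the choice of add_ are illustrative assumptions, not code from this commit.

#include <ATen/ATen.h>

#include <vector>

// Hypothetical sketch of the "proper functional implementation" the summary
// describes: the in-place op mutates a clone, never the caller's tensor, and
// the clone is surfaced as a new output.
std::vector<at::Tensor> functionalizedAddOne(const at::Tensor& input) {
  at::Tensor workingCopy = input.clone(); // caller's tensor stays untouched
  workingCopy.add_(1);                    // the mutating op runs on the clone
  return {workingCopy};                   // mutated value becomes an output
}

int main() {
  at::Tensor x = at::zeros({2});
  auto outputs = functionalizedAddOne(x);
  // x still holds zeros; outputs[0] holds the mutated result, mirroring how
  // a functional AutoFunctionalizeKernel would return new tensor instances.
  return 0;
}

UnsafeAutoFunctionalizeKernel, added below, deliberately skips the clone and mutates its inputs directly, which is only safe while Sigmoid preserves node order.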
@@ -608,6 +608,7 @@ libtorch_nativert_sources = [
     "torch/nativert/common/FileUtil.cpp",
     "torch/nativert/detail/ITree.cpp",
     "torch/nativert/kernels/C10Kernel.cpp",
+    "torch/nativert/kernels/AutoFunctionalizeKernel.cpp",
 ]
 
 torch_mobile_tracer_sources = [
torch/nativert/kernels/AutoFunctionalizeKernel.cpp (new file, 64 lines)
@@ -0,0 +1,64 @@
+#include <torch/nativert/kernels/AutoFunctionalizeKernel.h>
+
+#include <fmt/format.h>
+
+#include <c10/util/Enumerate.h>
+
+namespace torch::nativert {
+
+UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node)
+    : OpKernel(node),
+      op_(getOperatorForTarget(
+          std::get<std::string>(node->attributes()[0].value))),
+      schema_(op_.schema()),
+      arguments_(prefillStackWithStaticArgs(node, schema_)) {
+  for (const auto& [idx, schemaArg] : c10::enumerate(schema_.arguments())) {
+    if (schemaArg.alias_info() != nullptr &&
+        schemaArg.alias_info()->isWrite()) {
+      mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value);
+    }
+  }
+
+  numOutputs_ = schema_.returns().size();
+}
+
+void UnsafeAutoFunctionalizeKernel::computeInternal(
+    ExecutionFrame& executionFrame) const {
+  // Make a copy of the stack
+  std::vector<c10::IValue> stack = arguments_.getStackWithStaticArgs();
+
+  fillDynamicInputs(executionFrame, arguments_, stack);
+
+  // Call the op with the prepared stack.
+  try {
+    op_.callBoxed(stack);
+  } catch (const std::exception& ex) {
+    // TODO: this eats the original exception type. ATen returns different
+    // exception types that correspond to different Python errors (e.g.
+    // IndexError, ValueError). If retaining this information is important
+    // to us, we'll have to change this up a little.
+    auto stackTrace = node_->getMetadata("stack_trace");
+    throw std::runtime_error(fmt::format(
+        "Original Python stacktrace:\n{}\n{}",
+        stackTrace ? *stackTrace : "<no stack trace>",
+        ex.what()));
+  }
+
+  const auto& outputValues = node_->outputs();
+
+  for (int i = 0; i < numOutputs_; ++i) {
+    executionFrame.setIValue(outputValues[i]->id(), std::move(stack.at(i)));
+  }
+
+  // Copy over mutating inputs to outputs
+  int mutatingArgStartIndex = (numOutputs_ == 0) ? 1 : numOutputs_;
+  for (size_t i = mutatingArgStartIndex; i < outputValues.size(); ++i) {
+    executionFrame.setIValue(
+        outputValues[i]->id(),
+        executionFrame.getIValue(
+            mutatingInputArgs_.at(i - mutatingArgStartIndex)->id(),
+            true /* allowNone */));
+  }
+}
+
+} // namespace torch::nativert
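The final loop in computeInternal above copies each mutating input into one of the node's trailing output slots. The following standalone snippet illustrates just that index arithmetic; it assumes, inferring from the (numOutputs_ == 0) ? 1 : numOutputs_ guard, that a node whose op returns nothing still reserves output slot 0, so mutated inputs start at index 1.

#include <cstdio>

// Illustrative only: mirrors the index mapping at the end of
// UnsafeAutoFunctionalizeKernel::computeInternal. Output slots are laid
// out as [op returns..., mutated inputs...].
int main() {
  // Case 1: the op returns 2 values and mutates 2 inputs -> 4 output slots.
  int numOutputs = 2;
  int totalOutputValues = 4;
  int start = (numOutputs == 0) ? 1 : numOutputs; // start == 2
  for (int i = start; i < totalOutputValues; ++i) {
    std::printf("output slot %d <- mutating input %d\n", i, i - start);
  }

  // Case 2: the op returns nothing but mutates 2 inputs -> slots 1 and 2,
  // with slot 0 assumed to be a reserved placeholder.
  numOutputs = 0;
  totalOutputValues = 3;
  start = (numOutputs == 0) ? 1 : numOutputs; // start == 1
  for (int i = start; i < totalOutputValues; ++i) {
    std::printf("output slot %d <- mutating input %d\n", i, i - start);
  }
  return 0;
}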
torch/nativert/kernels/AutoFunctionalizeKernel.h (new file, 29 lines)
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/core/function_schema.h>
+#include <c10/core/Device.h>
+
+#include <torch/nativert/executor/ExecutionFrame.h> // @manual
+#include <torch/nativert/executor/OpKernel.h> // @manual
+
+namespace torch::nativert {
+
+class UnsafeAutoFunctionalizeKernel : public OpKernel {
+ public:
+  UnsafeAutoFunctionalizeKernel() = delete; // deleted default constructor
+  UnsafeAutoFunctionalizeKernel(const Node* node);
+
+  void computeInternal(ExecutionFrame& executionFrame) const override final;
+
+ private:
+  c10::OperatorHandle op_;
+  c10::FunctionSchema schema_;
+
+  Arguments arguments_;
+
+  std::vector<Value*> mutatingInputArgs_;
+  int numOutputs_;
+};
+
+} // namespace torch::nativert
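The constructor identifies mutating inputs purely from schema alias annotations (the (a!) markers in an op's schema string) rather than from any per-op list. As a rough standalone sketch of that detection, outside the nativert runtime and assuming the torch::jit schema parser is available:

#include <torch/csrc/jit/frontend/function_schema_parser.h>

#include <iostream>

int main() {
  // "(a!)" marks an argument the op writes to; this is the same alias_info
  // signal UnsafeAutoFunctionalizeKernel's constructor keys off.
  c10::FunctionSchema schema = torch::jit::parseSchema(
      "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)");
  for (const auto& arg : schema.arguments()) {
    if (arg.alias_info() != nullptr && arg.alias_info()->isWrite()) {
      std::cout << arg.name() << " is a mutating input\n"; // prints "self"
    }
  }
  return 0;
}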