pytorch/torch/nativert/kernels/AutoFunctionalizeKernel.cpp
Shangdi Yu e5ea24fb27 [nativert] Move auto_functionalize_kernel (#156454)
Summary:
Torch Native Runtime RFC: https://github.com/pytorch/rfcs/pull/72

As part of the effort to open source TorchNativeRuntime (or what we call Sigmoid), we are moving the kernel implementations to torch/:

fbcode/sigmoid/kernels -> fbcode/caffe2/torch/nativert/kernels

Copied from original auto_functionalize Diff Summary D53776805:

This is a non-functionalized (unsafe) kernel implementation for auto_functionalize.

In AutoFunctionalizeKernel, I call the underlying target directly, without making clones of the mutating inputs.

This mutates the input tensors in place, which is unsafe in general.

However, Sigmoid is not doing any graph optimization or node reordering at the moment, so it is OK to take this shortcut, as the snippet below illustrates.
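
To make the hazard concrete, here is a small standalone ATen snippet (hypothetical, not from this diff) showing what a reader of a mutated tensor would observe if execution order changed:

// Hypothetical illustration of the aliasing hazard; not part of the diff.
#include <ATen/ATen.h>

void aliasingHazard() {
  at::Tensor x = at::ones({2}); // x = {1, 1}
  at::Tensor y = x.add(1);      // reader: y = {2, 2}, computed from x
  x.add_(1);                    // in-place mutation, like the unsafe kernel
  // If the mutation were reordered before the read, y would be {3, 3}
  // instead of {2, 2}; with no reordering, execution order keeps this safe.
}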

In the proper functional implementation, it will:

- make a clone of each mutating input tensor, and
- return these new tensor instances as the AutoFunctionalizeKernel outputs.

If the original exported program has "bufferMutation" or "userInputMutation" fields, Sigmoid will also need to honor those mutations (a sketch of the functionalized variant follows).
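
A minimal sketch of that functionalized variant, reusing the same members and helpers as the unsafe kernel below (op_, schema_, arguments_, fillDynamicInputs); the class name and the stack-index bookkeeping are assumptions for illustration:

// Hypothetical functionalized computeInternal(); a sketch, not the real
// implementation from this diff.
void FunctionalizedSketch::computeInternal(
    ExecutionFrame& executionFrame) const {
  std::vector<c10::IValue> stack = arguments_.getStackWithStaticArgs();
  fillDynamicInputs(executionFrame, arguments_, stack);

  // Replace each mutable input with a clone so the original stays untouched,
  // keeping a handle to each clone so it can be returned afterwards.
  std::vector<c10::IValue> mutatedClones;
  for (const auto& [idx, schemaArg] : c10::enumerate(schema_.arguments())) {
    if (schemaArg.alias_info() != nullptr &&
        schemaArg.alias_info()->isWrite() && stack.at(idx).isTensor()) {
      stack.at(idx) = stack.at(idx).toTensor().clone();
      mutatedClones.push_back(stack.at(idx)); // shares storage with the clone
    }
  }

  op_.callBoxed(stack); // the op now mutates the clones in place

  // The clones (holding the post-mutation values) would then be written to
  // the node's trailing outputs, and any bufferMutation / userInputMutation
  // bookkeeping applied, instead of mutating the original tensors.
}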

Test Plan: See internal for test plan

Differential Revision: D76926383

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156454
Approved by: https://github.com/zhxchen17
2025-06-20 19:53:16 +00:00


#include <torch/nativert/kernels/AutoFunctionalizeKernel.h>
#include <fmt/format.h>
#include <c10/util/Enumerate.h>
namespace torch::nativert {

UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node)
    : OpKernel(node),
      op_(getOperatorForTarget(
          std::get<std::string>(node->attributes()[0].value))),
      schema_(op_.schema()),
      arguments_(prefillStackWithStaticArgs(node, schema_)) {
  // Collect the inputs that the schema marks as written to (mutable alias
  // info); computeInternal() forwards these to the node's trailing outputs.
  for (const auto& [idx, schemaArg] : c10::enumerate(schema_.arguments())) {
    if (schemaArg.alias_info() != nullptr &&
        schemaArg.alias_info()->isWrite()) {
      mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value);
    }
  }

  numOutputs_ = schema_.returns().size();
}
void UnsafeAutoFunctionalizeKernel::computeInternal(
    ExecutionFrame& executionFrame) const {
  // Make a copy of the stack
  std::vector<c10::IValue> stack = arguments_.getStackWithStaticArgs();
  fillDynamicInputs(executionFrame, arguments_, stack);

  // Call the op with the prepared stack.
  try {
    op_.callBoxed(stack);
  } catch (const std::exception& ex) {
    // TODO: this eats the original exception type. ATen returns different
    // exception types that correspond to different Python errors (e.g.
    // IndexError, ValueError). If retaining this information is important
    // to us, we'll have to change this up a little.
    auto stackTrace = node_->getMetadata("stack_trace");
    throw std::runtime_error(fmt::format(
        "Original Python stacktrace:\n{}\n{}",
        stackTrace ? *stackTrace : "<no stack trace>",
        ex.what()));
  }

  // The op's own return values populate the leading outputs.
  const auto& outputValues = node_->outputs();
  for (int i = 0; i < numOutputs_; ++i) {
    executionFrame.setIValue(outputValues[i]->id(), std::move(stack.at(i)));
  }

  // Copy over mutating inputs to the remaining outputs. When the op has no
  // returns, the mutated arguments start at output index 1.
  int mutatingArgStartIndex = (numOutputs_ == 0) ? 1 : numOutputs_;
  for (size_t i = mutatingArgStartIndex; i < outputValues.size(); ++i) {
    executionFrame.setIValue(
        outputValues[i]->id(),
        executionFrame.getIValue(
            mutatingInputArgs_.at(i - mutatingArgStartIndex)->id(),
            true /* allowNone */));
  }
}

} // namespace torch::nativert