Files
pytorch/torch/nativert/kernels/AutoFunctionalizeKernel.cpp
dolpm 725c327284 [nativert] add memory overlap debug assertion (#157290)
Summary: Better safe than sorry. This will throw if memory overlap is detected when using planned tensors and debug mode is enabled -- this will make our planning unit tests more robust (a minimal sketch of such a check appears after the commit metadata below).

Test Plan:
CI

Rollback Plan:

Differential Revision: D77327841

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157290
Approved by: https://github.com/SherlockNoMad, https://github.com/zhxchen17
2025-07-14 19:12:41 +00:00
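
Since the commit summary does not reproduce the assertion itself, here is a minimal sketch of what a memory-overlap debug check could look like, assuming a hypothetical helper name (assertNoOverlap) and only plain ATen/c10 APIs (data_ptr, nbytes, TORCH_INTERNAL_ASSERT_DEBUG_ONLY). The actual change landed in #157290 may be structured differently.

#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>
#include <utility>

// Returns the [begin, end) byte range backing a tensor's data pointer.
// NOTE: simplification -- strided/non-contiguous tensors would need a tighter
// extent computation than data_ptr() + nbytes().
inline std::pair<const char*, const char*> byteRange(const at::Tensor& t) {
  const char* begin = static_cast<const char*>(t.data_ptr());
  return {begin, begin + t.nbytes()};
}

// Hypothetical debug-only assertion: fails if two planned tensors share bytes.
inline void assertNoOverlap(const at::Tensor& a, const at::Tensor& b) {
  auto [aBegin, aEnd] = byteRange(a);
  auto [bBegin, bEnd] = byteRange(b);
  const bool disjoint = aEnd <= bBegin || bEnd <= aBegin;
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      disjoint, "memory overlap detected between planned tensors");
}

The byte-range comparison is intentionally coarse; it only illustrates the "throw in debug mode on overlap" behavior described in the summary.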


#include <torch/nativert/kernels/AutoFunctionalizeKernel.h>

#include <fmt/format.h>

#include <c10/util/Enumerate.h>

namespace torch::nativert {

UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node)
    : OpKernel(node),
      op_(getOperatorForTarget(
          std::get<std::string>(node->attributes()[0].value))),
      schema_(op_.schema()),
      arguments_(prefillStackWithStaticArgs(node, schema_)),
      numOutputs_(static_cast<int>(schema_.returns().size())) {
  // Record the inputs whose schema argument carries a write alias, i.e. the
  // arguments this op mutates in place.
  for (const auto& [idx, schemaArg] : c10::enumerate(schema_.arguments())) {
    if (schemaArg.alias_info() != nullptr &&
        schemaArg.alias_info()->isWrite()) {
      mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value);
    }
  }
}

void UnsafeAutoFunctionalizeKernel::computeInternal(
    ExecutionFrame& executionFrame) const {
  // Make a copy of the stack
  std::vector<c10::IValue> stack = arguments_.getStackWithStaticArgs();
  fillDynamicInputs(executionFrame, arguments_, stack);

  // Call the op with the prepared stack.
  try {
    op_.callBoxed(stack);
  } catch (const std::exception& ex) {
    // TODO: this eats the original exception type. ATen returns different
    // exception types that correspond to different Python errors (e.g.
    // IndexError, ValueError). If retaining this information is important
    // to us, we'll have to change this up a little.
    auto stackTrace = node_->getMetadata("stack_trace");
    throw std::runtime_error(fmt::format(
        "Original Python stacktrace:\n{}\n{}",
        stackTrace ? *stackTrace : "<no stack trace>",
        ex.what()));
  }

  const auto& outputValues = node_->outputs();
  for (int i = 0; i < numOutputs_; ++i) {
    executionFrame.setIValue(outputValues[i]->id(), std::move(stack.at(i)));
  }

  // Copy over mutating inputs to outputs
  int mutatingArgStartIndex = (numOutputs_ == 0) ? 1 : numOutputs_;
  for (size_t i = mutatingArgStartIndex; i < outputValues.size(); ++i) {
    executionFrame.setIValue(
        outputValues[i]->id(),
        executionFrame.getIValue(
            mutatingInputArgs_.at(i - mutatingArgStartIndex)->id(),
            true /* allowNone */));
  }
}

} // namespace torch::nativert
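
For illustration only, here is how a planning unit test might exercise the hypothetical assertNoOverlap sketch above (this is not part of AutoFunctionalizeKernel.cpp, and the test harness wiring is omitted). Two at::from_blob views are built over the same buffer so their byte ranges deliberately overlap, and the debug-only assertion is expected to fire in builds with assertions enabled.

#include <ATen/ATen.h>
#include <vector>

// Hypothetical smoke test: overlapping views over one buffer should trip the
// sketched debug assertion.
void overlapSmokeTest() {
  std::vector<float> buffer(16, 0.0f);
  at::Tensor a = at::from_blob(buffer.data(), {8}, at::kFloat);     // bytes [0, 32)
  at::Tensor b = at::from_blob(buffer.data() + 4, {8}, at::kFloat); // bytes [16, 48)
  assertNoOverlap(a, b); // expected to throw when debug assertions are enabled
}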