Files
pytorch/torch/nativert/kernels/C10Kernel.cpp
lichuyang 38ed608956 Better error handling in torch/nativert/* (#163308)
Replace the **runtime_error** of the vallina C++ exceptions with **TORCH_CEHCK** in **torch/nativert/***

The vallina C++ exception should not exist in the core part of pytorch for its corss-languanges trait. Comparing with the vallina C++ exceptions, TORCH_CHECK have the richer error context and It has the unified error handling mechanism. This commit replace the runtime_error with TORCH_CHECK of the files in
torch/nativert/*   .

Fixes part of #148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163308
Approved by: https://github.com/dolpm
2025-09-28 14:23:44 +00:00

266 lines
8.2 KiB
C++

#include <torch/nativert/kernels/C10Kernel.h>
#include <fmt/ostream.h>
#include <c10/util/Enumerate.h>
#include <c10/util/Exception.h>
#ifdef __SIGRID_USE_GPU__
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#endif
namespace torch::nativert {
C10Kernel::C10Kernel(
const Node* node,
OpKernelKind kind,
AliasingSpec&& aliasingSpec)
: OpKernel(node, kind),
op_(getOperatorForTarget(node->target(), node)),
schema_(op_.schema(), std::move(aliasingSpec), kind_),
arguments_(prefillStackWithStaticArgs(node, op_.schema())) {}
void C10Kernel::computeInternal(ExecutionFrame& executionFrame) const {
// Make a copy of the stack
std::vector<c10::IValue> stack = arguments_.getStackWithStaticArgs();
fillDynamicInputs(executionFrame, arguments_, stack);
// Call the op with the prepared stack.
try {
op_.callBoxed(stack);
} catch (const std::exception& ex) {
auto stackTrace = node_->getMetadata("stack_trace");
TORCH_CHECK(
false,
"Exception while executing node: ",
*node_,
"\n"
"with args:\n",
readableArgs(op_.schema(), stack),
"\n",
ex.what(),
"\n",
"Original Python stacktrace:\n",
stackTrace ? *stackTrace : "<no stack trace>")
}
// Write out results
// TODO: we store intermediates in a single table (symint and tensor alike).
// This can theoretically lead to name collisions, although based on how
// these are named I don't think it will ever happen in practice. We need to
// enforce it though.
const auto& outputValues = node_->outputs();
TORCH_CHECK(
outputValues.size() == stack.size(),
"Output size mismatch for ",
node_->toString());
for (auto&& [i, actualOutput] : c10::enumerate(stack)) {
executionFrame.setIValue(outputValues[i]->id(), std::move(actualOutput));
}
}
namespace {
std::unordered_map<std::string, c10::IValue> getSymInputs(
const ExecutionFrame& executionFrame,
const Node& node) {
std::unordered_map<std::string, c10::IValue> inputs;
for (const auto& input : node.inputs()) {
const auto& val = executionFrame.getIValue(input.value->id());
if (val.isInt() || val.isDouble() || val.isBool()) {
inputs[input.name] = val;
} else {
TORCH_CHECK(false, "unsupported type for symbolic input");
}
}
for (const auto& attribute : node.attributes()) {
if (std::holds_alternative<int64_t>(attribute.value)) {
inputs[attribute.name] = std::get<int64_t>(attribute.value);
} else if (std::holds_alternative<double>(attribute.value)) {
inputs[attribute.name] = std::get<double>(attribute.value);
} else if (std::holds_alternative<bool>(attribute.value)) {
inputs[attribute.name] = std::get<bool>(attribute.value);
} else {
TORCH_CHECK(false, "unsupported type for symbolic input");
}
}
return inputs;
}
template <typename T>
void computeScalarBinaryOp(
ExecutionFrame& executionFrame,
const Node& node,
std::enable_if_t<true, T> a,
std::enable_if_t<true, T> b) {
std::string_view target = node.target();
T out;
if (target == "_operator.add") {
out = a + b;
} else if (target == "_operator.sub") {
out = a - b;
} else if (target == "_operator.mul") {
out = a * b;
} else if (target == "_operator.pow") {
out = std::pow(a, b);
} else {
TORCH_CHECK(false, "unsupported operator for scalar binary op: ", target);
}
executionFrame.setIValue(node.outputs()[0]->id(), out);
VLOG(2) << fmt::format(
"Completed executing node: {} with a={}, b={}, out={}",
fmt::streamed(node),
a,
b,
out);
}
} // namespace
void ScalarBinaryOpKernel::computeInternal(
ExecutionFrame& executionFrame) const {
auto inputs = getSymInputs(executionFrame, *node_);
const auto& a = inputs.at("a");
const auto& b = inputs.at("b");
auto coerceToDouble = [](const c10::IValue& x) -> double {
if (x.isInt()) {
return static_cast<double>(x.toInt());
} else if (x.isDouble()) {
return x.toDouble();
} else {
TORCH_CHECK(false, "unsupported type for symbolic input");
}
};
if (a.isInt() && b.isInt()) {
computeScalarBinaryOp<int64_t>(
executionFrame, *node_, a.toInt(), b.toInt());
} else {
computeScalarBinaryOp<double>(
executionFrame, *node_, coerceToDouble(a), coerceToDouble(b));
}
}
void SymIntOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
auto inputs = getSymInputs(executionFrame, *node_);
int64_t a = inputs.at("a").toInt();
std::string_view target = node_->target();
if (target == "torch.sym_float") {
double out = static_cast<double>(a);
executionFrame.setIValue(node_->outputs()[0]->id(), out);
VLOG(2) << fmt::format(
"Completed executing node: {} with a={}, out={}",
fmt::streamed(*node_),
a,
out);
return;
}
int64_t b = inputs.at("b").toInt();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t out;
if (target == "_operator.floordiv") {
out = a / b;
} else if (target == "_operator.mod") {
out = a % b;
} else if (target == "torch.sym_max") {
out = std::max(a, b);
} else if (target == "torch.sym_min") {
out = std::min(a, b);
} else {
TORCH_CHECK(false, "unsupported operator for SymInt: ", node_->target())
}
executionFrame.setIValue(node_->outputs()[0]->id(), out);
VLOG(2) << fmt::format(
"Completed executing node: {} with a={}, b={}, out={}",
fmt::streamed(*node_),
a,
b,
out);
}
void SymBoolOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
auto inputs = getSymInputs(executionFrame, *node_);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool out;
const std::string_view target = node_->target();
if (target == "torch.sym_not") {
bool a = inputs.at("a").toBool();
out = !a;
} else if (target == "_operator.ge") {
int64_t a = inputs.at("a").toInt();
int64_t b = inputs.at("b").toInt();
out = a >= b;
} else if (target == "_operator.le") {
int64_t a = inputs.at("a").toInt();
int64_t b = inputs.at("b").toInt();
out = a <= b;
} else if (target == "_operator.eq") {
int64_t a = inputs.at("a").toInt();
int64_t b = inputs.at("b").toInt();
out = a == b;
} else if (target == "_operator.gt") {
int64_t a = inputs.at("a").toInt();
int64_t b = inputs.at("b").toInt();
out = a > b;
} else if (target == "_operator.lt") {
int64_t a = inputs.at("a").toInt();
int64_t b = inputs.at("b").toInt();
out = a < b;
} else if (target == "_operator.and_") {
bool a = inputs.at("a").toBool();
bool b = inputs.at("b").toBool();
out = a && b;
} else {
TORCH_CHECK(false, "unsupported operator for SymBool: ", node_->target())
}
executionFrame.setIValue(node_->outputs()[0]->id(), out);
}
void SymFloatOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
auto inputs = getSymInputs(executionFrame, *node_);
const std::string_view target = node_->target();
if (target == "math.trunc") {
double x = inputs.at("x").toDouble();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
int64_t out = trunc(x);
executionFrame.setIValue(node_->outputs()[0]->id(), out);
} else if (target == "torch._sym_sqrt") {
double a = inputs.at("a").toDouble();
double out = std::sqrt(a);
executionFrame.setIValue(node_->outputs()[0]->id(), out);
} else if (target == "_operator.neg") {
auto a = inputs.at("a");
c10::IValue out;
if (a.isInt()) {
out = -a.toInt();
} else if (a.isDouble()) {
out = -a.toDouble();
} else {
TORCH_CHECK(false, "unsupported type for symbolic input");
}
executionFrame.setIValue(node_->outputs()[0]->id(), out);
} else if (target == "_operator.truediv") {
auto ia = inputs.at("a");
double a = ia.isInt() ? static_cast<double>(ia.toInt()) : ia.toDouble();
auto ib = inputs.at("b");
double b = ib.isInt() ? static_cast<double>(ib.toInt()) : ib.toDouble();
double out = a / b;
executionFrame.setIValue(node_->outputs()[0]->id(), out);
} else {
TORCH_CHECK(false, "unsupported operator for SymFloat: ", node_->target());
}
}
} // namespace torch::nativert