Files
pytorch/torch/nativert/kernels/C10Kernel.cpp
Shangdi Yu e4c9f6d9a2 [nativert] Move c10_kernel (#156208)
Summary:
Torch Native Runtime RFC: https://github.com/pytorch/rfcs/pull/72

As part of the effort to open source TorchNativeRuntime (or what we call Sigmoid), we are moving the kernel implementations (here, c10_kernel) to torch/:

fbcode/sigmoid/kernels -> fbcode/caffe2/torch/nativert/kernels

Test Plan:
```
buck run fbcode//mode/dev-nosan  //caffe2/test/cpp/nativert:c10_kernel_test
```

Differential Revision: D76825830

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156208
Approved by: https://github.com/zhxchen17
2025-06-19 17:36:23 +00:00

266 lines
8.3 KiB
C++

#include <torch/nativert/kernels/C10Kernel.h>

#include <cmath>
#include <limits>
#include <type_traits>

#include <c10/util/Enumerate.h>
#include <fmt/ostream.h>

#ifdef __SIGRID_USE_GPU__
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#endif
namespace torch::nativert {
// Builds a kernel that executes `node` by calling a c10 operator boxed.
//
// Parameters:
//   node         - graph node to execute; its target() selects the operator.
//   device       - device forwarded to the OpKernel base.
//   kind         - kernel kind forwarded to the OpKernel base.
//   aliasingSpec - aliasing information; moved into schema_.
//
// Members initialized here:
//   op_        - the operator resolved by getOperatorForTarget(node->target(), node).
//   schema_    - built from the operator's schema, the aliasing spec and kind_.
//   arguments_ - argument stack pre-filled with the node's static arguments by
//                prefillStackWithStaticArgs; computeInternal copies it and fills
//                in the dynamic inputs on every run.
C10Kernel::C10Kernel(
const Node* node,
c10::Device device,
OpKernelKind kind,
AliasingSpec&& aliasingSpec)
: OpKernel(node, device, kind),
op_(getOperatorForTarget(node->target(), node)),
// schema_ and arguments_ both read op_, so op_ must be declared before them in
// the class (members initialize in declaration order, not list order).
// NOTE(review): confirm the member ordering in C10Kernel.h.
// kind_ is presumably the base-class copy of `kind` set by OpKernel — confirm.
schema_(op_.schema(), std::move(aliasingSpec), kind_),
arguments_(prefillStackWithStaticArgs(node, op_.schema())) {}
void C10Kernel::computeInternal(ExecutionFrame& executionFrame) const {
  // Work on a fresh copy of the pre-filled argument stack (static arguments
  // already in place), then patch in this run's dynamic inputs from the frame.
  std::vector<c10::IValue> stack = arguments_.getStackWithStaticArgs();
  fillDynamicInputs(executionFrame, arguments_, stack);

  try {
    // Boxed call: after it returns, `stack` holds the operator's outputs.
    op_.callBoxed(stack);
  } catch (const std::exception& ex) {
    // Re-throw with enough context to locate the failure: the node, its
    // arguments, the original message, and the Python stack trace recorded
    // in the node metadata (if any).
    const auto recordedTrace = node_->getMetadata("stack_trace");
    throw std::runtime_error(fmt::format(
        "Exception while executing node: {}\n"
        "with args:\n{}\n"
        "{}\n"
        "Original Python stacktrace:\n{}",
        fmt::streamed(*node_),
        readableArgs(op_.schema(), stack),
        ex.what(),
        recordedTrace ? *recordedTrace : "<no stack trace>"));
  }

  // Publish each result under the id of the matching node output.
  // TODO: we store intermediates in a single table (symint and tensor alike).
  // This can theoretically lead to name collisions, although based on how
  // these are named I don't think it will ever happen in practice. We need to
  // enforce it though.
  const auto& outputValues = node_->outputs();
  TORCH_CHECK_EQ(outputValues.size(), stack.size())
      << "Output size mismatch for " << node_->toString();
  for (size_t outIdx = 0; outIdx < stack.size(); ++outIdx) {
    executionFrame.setIValue(
        outputValues[outIdx]->id(), std::move(stack[outIdx]));
  }
}
namespace {
// Collects the scalar values a symbolic op needs, keyed by argument name.
// Runtime inputs are read from the execution frame; node attributes are read
// from the node itself. Only int, double and bool values are accepted; any
// other input or attribute type throws std::runtime_error.
std::unordered_map<std::string, c10::IValue> getSymInputs(
    const ExecutionFrame& executionFrame,
    const Node& node) {
  std::unordered_map<std::string, c10::IValue> collected;

  for (const auto& in : node.inputs()) {
    const auto& value = executionFrame.getIValue(in.value->id());
    const bool isScalar = value.isInt() || value.isDouble() || value.isBool();
    if (!isScalar) {
      throw std::runtime_error("unsupported type for symbolic input");
    }
    collected[in.name] = value;
  }

  for (const auto& attr : node.attributes()) {
    if (const auto* asInt = std::get_if<int64_t>(&attr.value)) {
      collected[attr.name] = *asInt;
    } else if (const auto* asDouble = std::get_if<double>(&attr.value)) {
      collected[attr.name] = *asDouble;
    } else if (const auto* asBool = std::get_if<bool>(&attr.value)) {
      collected[attr.name] = *asBool;
    } else {
      throw std::runtime_error("unsupported type for symbolic input");
    }
  }

  return collected;
}
// Returns true when a * b would overflow int64_t. The check runs *before* the
// multiplication because signed overflow is undefined behavior in C++.
bool int64MulOverflows(int64_t a, int64_t b) {
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
  if (a > 0) {
    return b > 0 ? a > kMax / b : b < kMin / a;
  }
  if (b > 0) {
    return a < kMin / b;
  }
  return a != 0 && b < kMax / a;
}

// Exact base ** exponent for exponent >= 0, using exponentiation by squaring.
// std::pow goes through double, which silently loses precision above 2^53,
// and converting an out-of-range double back to int64_t is undefined behavior.
// Throws std::runtime_error if the exact result does not fit in int64_t.
int64_t int64PowNonNegative(int64_t base, int64_t exponent) {
  int64_t result = 1;
  while (exponent > 0) {
    if ((exponent & 1) != 0) {
      if (int64MulOverflows(result, base)) {
        throw std::runtime_error("integer overflow in symbolic pow");
      }
      result *= base;
    }
    exponent >>= 1;
    // Square only while another factor is still needed; a squared base that
    // overflows here would make the final result overflow as well.
    if (exponent > 0) {
      if (int64MulOverflows(base, base)) {
        throw std::runtime_error("integer overflow in symbolic pow");
      }
      base *= base;
    }
  }
  return result;
}

// Evaluates a Python binary operator (_operator.add/sub/mul/pow) on two
// scalars of type T and stores the result in the node's first output.
// The std::enable_if_t<true, T> parameters make T non-deducible, so callers
// must spell the template argument explicitly.
//
// For T = int64_t, pow follows Python semantics: a non-negative exponent
// yields an exact int (or throws on overflow), and a negative exponent yields
// a float (2 ** -1 == 0.5) instead of a truncated int.
template <typename T>
void computeScalarBinaryOp(
    ExecutionFrame& executionFrame,
    const Node& node,
    std::enable_if_t<true, T> a,
    std::enable_if_t<true, T> b) {
  std::string_view target = node.target();

  if constexpr (std::is_same_v<T, int64_t>) {
    if (target == "_operator.pow" && b < 0) {
      // int ** negative int is a float in Python; it cannot be represented as T.
      const double out =
          std::pow(static_cast<double>(a), static_cast<double>(b));
      executionFrame.setIValue(node.outputs()[0]->id(), out);
      VLOG(2) << fmt::format(
          "Completed executing node: {} with a={}, b={}, out={}",
          fmt::streamed(node),
          a,
          b,
          out);
      return;
    }
  }

  T out{};
  if (target == "_operator.add") {
    out = a + b;
  } else if (target == "_operator.sub") {
    out = a - b;
  } else if (target == "_operator.mul") {
    out = a * b;
  } else if (target == "_operator.pow") {
    if constexpr (std::is_same_v<T, int64_t>) {
      // b >= 0 is guaranteed by the early return above.
      out = int64PowNonNegative(a, b);
    } else {
      out = std::pow(a, b);
    }
  } else {
    throw std::runtime_error(
        fmt::format("unsupported operator for symbolic values: {}", target));
  }
  executionFrame.setIValue(node.outputs()[0]->id(), out);
  VLOG(2) << fmt::format(
      "Completed executing node: {} with a={}, b={}, out={}",
      fmt::streamed(node),
      a,
      b,
      out);
}
} // namespace
void ScalarBinaryOpKernel::computeInternal(
    ExecutionFrame& executionFrame) const {
  // Applies the node's binary operator to symbolic inputs "a" and "b".
  // Two ints stay in integer arithmetic; any int/double mix is promoted to
  // double. Any other value type (e.g. bool) is rejected.
  const auto symInputs = getSymInputs(executionFrame, *node_);
  const c10::IValue& lhs = symInputs.at("a");
  const c10::IValue& rhs = symInputs.at("b");

  if (lhs.isInt() && rhs.isInt()) {
    computeScalarBinaryOp<int64_t>(
        executionFrame, *node_, lhs.toInt(), rhs.toInt());
    return;
  }

  const auto asDouble = [](const c10::IValue& v) -> double {
    if (v.isInt()) {
      return static_cast<double>(v.toInt());
    }
    if (v.isDouble()) {
      return v.toDouble();
    }
    throw std::runtime_error("unsupported type for symbolic input");
  };
  computeScalarBinaryOp<double>(
      executionFrame, *node_, asDouble(lhs), asDouble(rhs));
}
namespace {
// Python-semantics integer floor division `a // b` (rounds toward -infinity).
// C++ `/` truncates toward zero, so -7 / 2 == -3, whereas -7 // 2 == -4.
// Throws on a zero divisor and on INT64_MIN // -1 (both are UB in C++).
int64_t pyFloorDivInt64(int64_t a, int64_t b) {
  if (b == 0) {
    throw std::runtime_error("integer division or modulo by zero");
  }
  if (a == std::numeric_limits<int64_t>::min() && b == -1) {
    // The exact quotient 2^63 is not representable in int64_t.
    throw std::runtime_error("integer overflow in floordiv");
  }
  int64_t quotient = a / b;
  // Truncation rounded toward zero; step down when the exact quotient is
  // negative and not an integer.
  if ((a % b != 0) && ((a < 0) != (b < 0))) {
    --quotient;
  }
  return quotient;
}

// Python-semantics integer modulo `a % b`: the result has the divisor's sign.
// C++ `%` follows the dividend's sign, so -7 % 2 == -1, whereas Python gives 1.
// Throws on a zero divisor.
int64_t pyModInt64(int64_t a, int64_t b) {
  if (b == 0) {
    throw std::runtime_error("integer division or modulo by zero");
  }
  if (b == -1) {
    // Always 0; also avoids INT64_MIN % -1, which is UB in C++.
    return 0;
  }
  int64_t remainder = a % b;
  if (remainder != 0 && ((remainder < 0) != (b < 0))) {
    remainder += b;
  }
  return remainder;
}
} // namespace

// Evaluates an int-valued symbolic op on inputs "a" (and "b" for binary ops)
// and stores the result in the node's first output. The _operator.floordiv and
// _operator.mod targets come from Python's operator module, so they follow
// Python's flooring semantics rather than C++'s truncating ones.
void SymIntOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
  auto inputs = getSymInputs(executionFrame, *node_);
  int64_t a = inputs.at("a").toInt();
  std::string_view target = node_->target();
  if (target == "torch.sym_float") {
    double out = static_cast<double>(a);
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
    VLOG(2) << fmt::format(
        "Completed executing node: {} with a={}, out={}",
        fmt::streamed(*node_),
        a,
        out);
    return;
  }
  int64_t b = inputs.at("b").toInt();
  int64_t out = 0;
  if (target == "_operator.floordiv") {
    out = pyFloorDivInt64(a, b);
  } else if (target == "_operator.mod") {
    out = pyModInt64(a, b);
  } else if (target == "torch.sym_max") {
    out = std::max(a, b);
  } else if (target == "torch.sym_min") {
    out = std::min(a, b);
  } else {
    throw std::runtime_error(
        fmt::format("unsupported operator for SymInt: {}", node_->target()));
  }
  executionFrame.setIValue(node_->outputs()[0]->id(), out);
  VLOG(2) << fmt::format(
      "Completed executing node: {} with a={}, b={}, out={}",
      fmt::streamed(*node_),
      a,
      b,
      out);
}
void SymBoolOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
  // Evaluates a bool-valued symbolic op and stores it in the node's first
  // output. Comparisons read int inputs "a" and "b"; sym_not and and_ read
  // bool inputs.
  const auto symInputs = getSymInputs(executionFrame, *node_);
  const std::string_view op = node_->target();

  // Reads "a" then "b" as ints and applies `cmp` to them.
  const auto compareInts = [&symInputs](auto&& cmp) -> bool {
    const int64_t lhs = symInputs.at("a").toInt();
    const int64_t rhs = symInputs.at("b").toInt();
    return cmp(lhs, rhs);
  };

  bool result = false;
  if (op == "torch.sym_not") {
    result = !symInputs.at("a").toBool();
  } else if (op == "_operator.ge") {
    result = compareInts([](int64_t x, int64_t y) { return x >= y; });
  } else if (op == "_operator.le") {
    result = compareInts([](int64_t x, int64_t y) { return x <= y; });
  } else if (op == "_operator.eq") {
    result = compareInts([](int64_t x, int64_t y) { return x == y; });
  } else if (op == "_operator.gt") {
    result = compareInts([](int64_t x, int64_t y) { return x > y; });
  } else if (op == "_operator.lt") {
    result = compareInts([](int64_t x, int64_t y) { return x < y; });
  } else if (op == "_operator.and_") {
    // Both operands are converted before combining (no short-circuit), so
    // toBool() is applied to "b" even when "a" is false.
    const bool lhs = symInputs.at("a").toBool();
    const bool rhs = symInputs.at("b").toBool();
    result = lhs && rhs;
  } else {
    throw std::runtime_error(
        fmt::format("unsupported operator for SymBool: {}", node_->target()));
  }
  executionFrame.setIValue(node_->outputs()[0]->id(), result);
}
// Evaluates a float-oriented symbolic op and stores the result in the node's
// first output. Supported targets:
//   math.trunc        - "x" truncated toward zero, as int64.
//   torch._sym_sqrt   - sqrt of "a", as double.
//   _operator.neg     - -"a", keeping the operand's int/double type.
//   _operator.truediv - "a" / "b", as double.
// Int operands are promoted to double wherever Python accepts them (sqrt,
// truediv). Throws std::runtime_error for unsupported targets/types, and for a
// trunc result that does not fit in int64.
void SymFloatOpKernel::computeInternal(ExecutionFrame& executionFrame) const {
  auto inputs = getSymInputs(executionFrame, *node_);
  const std::string_view target = node_->target();

  // int -> double promotion; any non-int value goes through toDouble(), which
  // rejects non-double values.
  const auto asDouble = [](const c10::IValue& v) -> double {
    return v.isInt() ? static_cast<double>(v.toInt()) : v.toDouble();
  };

  if (target == "math.trunc") {
    const c10::IValue& x = inputs.at("x");
    int64_t out = 0;
    if (x.isInt()) {
      // math.trunc of an int is the int itself.
      out = x.toInt();
    } else {
      const double value = x.toDouble();
      const double truncated = std::trunc(value);
      // Converting a double outside [-2^63, 2^63) (or NaN/inf) to int64_t is
      // undefined behavior in C++, so reject it explicitly. The bound -2^63 is
      // exactly representable as a double, and so is its negation 2^63.
      constexpr double kInt64Lower =
          static_cast<double>(std::numeric_limits<int64_t>::min());
      if (!(truncated >= kInt64Lower && truncated < -kInt64Lower)) {
        throw std::runtime_error(
            fmt::format("math.trunc: cannot convert {} to int64", value));
      }
      out = static_cast<int64_t>(truncated);
    }
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else if (target == "torch._sym_sqrt") {
    const double out = std::sqrt(asDouble(inputs.at("a")));
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else if (target == "_operator.neg") {
    auto a = inputs.at("a");
    c10::IValue out;
    if (a.isInt()) {
      out = -a.toInt();
    } else if (a.isDouble()) {
      out = -a.toDouble();
    } else {
      throw std::runtime_error("unsupported type for symbolic input");
    }
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else if (target == "_operator.truediv") {
    const double a = asDouble(inputs.at("a"));
    const double b = asDouble(inputs.at("b"));
    const double out = a / b;
    executionFrame.setIValue(node_->outputs()[0]->id(), out);
  } else {
    throw std::runtime_error(
        fmt::format("unsupported operator for SymFloat: {}", node_->target()));
  }
}
} // namespace torch::nativert