Files
pytorch/torch/nativert/graph/GraphUtils.cpp
Sherlock Huang c1722db0f7 [NativeRT] Make VariadicOpConverter and FuseListUnpackConverter for cpu nodes only (#159519)
Summary:
VariadicOpConverter and FuseListUnpackConverter would introduce ops that only have CPU kernels.

Currently, the graph passes are run if static_dispatch is enabled.

As we plan to enable static_dispatch by default, this diff adds an additional check so these graph passes only apply to nodes whose inputs and outputs are all on CPU.

Test Plan:
CI

Rollback Plan:

Differential Revision: D79295640

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159519
Approved by: https://github.com/dolpm, https://github.com/henryoier
2025-07-31 18:17:21 +00:00

81 lines
2.4 KiB
C++

#include <torch/nativert/graph/GraphUtils.h>
#include <c10/core/Device.h>
#include <torch/nativert/graph/Graph.h>
namespace torch::nativert {
bool areAllIOTensorsAttributesOnCpu(const Node& node) {
const auto& tensorValuesMeta = node.owningGraph()->tensorValuesMeta();
// Check inputs
for (auto& input : node.inputs()) {
if (input.value->type() == Type::Kind::Tensor) {
if (auto it = tensorValuesMeta.find(std::string{input.value->name()});
it != tensorValuesMeta.end()) {
const auto& device = it->second.device();
if (!device.is_cpu()) {
return false;
}
}
} else if (input.value->type() == Type::Kind::TensorList) {
for (const auto& el : input.value->getListElements()) {
if (auto it = tensorValuesMeta.find(std::string{el->name()});
it != tensorValuesMeta.end()) {
const auto& device = it->second.device();
if (!device.is_cpu()) {
return false;
}
}
}
} else {
// other input types doesn't affect if the node is on CPU or not
}
}
// Check outputs
for (auto& output : node.outputs()) {
if (!output) {
// When a node's output is a Constant, its Value* is nullptr
// TODO: this is breaking the invariant of all nodes outputs are non-null
// in the graph. We should fix this.
continue;
}
if (output->type() == Type::Kind::Tensor) {
if (auto it = tensorValuesMeta.find(std::string{output->name()});
it != tensorValuesMeta.end()) {
const auto& device = it->second.device();
if (!device.is_cpu()) {
return false;
}
}
} else if (output->type() == Type::Kind::TensorList) {
for (const auto& el : output->getListElements()) {
if (auto it = tensorValuesMeta.find(std::string{el->name()});
it != tensorValuesMeta.end()) {
const auto& device = it->second.device();
if (!device.is_cpu()) {
return false;
}
}
}
} else {
// other output types doesn't affect if the node is on CPU or not
}
}
// Check attributes
for (auto& attribute : node.attributes()) {
if (std::holds_alternative<c10::Device>(attribute.value)) {
auto device = std::get<c10::Device>(attribute.value);
if (!device.is_cpu()) {
return false;
}
}
}
return true;
}
} // namespace torch::nativert