Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Summary: VariadicOpConverter and FuseListUnpackConverter introduce ops that only have CPU kernels. Currently, these graph passes are run when static_dispatch is enabled. Since we plan to enable static_dispatch by default, this diff adds a check so the passes only apply to nodes whose inputs and outputs are all on CPU.

Test Plan: CI

Rollback Plan:

Differential Revision: D79295640

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159519
Approved by: https://github.com/dolpm, https://github.com/henryoier
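For context, here is a minimal sketch of how a converter pass might gate on the helper defined in the file below. The pass function name and the graph-iteration call are hypothetical, for illustration only; only areAllIOTensorsAttributesOnCpu comes from this source.

// Hypothetical usage sketch (convertNodesCpuOnly and graph.nodes() are
// illustrative names, not the actual converter API):
void convertNodesCpuOnly(torch::nativert::Graph& graph) {
  for (auto& node : graph.nodes()) {
    // The rewritten ops only have CPU kernels, so skip any node that
    // touches a non-CPU tensor or carries a non-CPU device attribute.
    if (!torch::nativert::areAllIOTensorsAttributesOnCpu(node)) {
      continue;
    }
    // ... rewrite the node into its variadic/fused form here ...
  }
}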
#include <torch/nativert/graph/GraphUtils.h>

#include <c10/core/Device.h>

#include <torch/nativert/graph/Graph.h>

namespace torch::nativert {

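// Returns true iff every tensor input and output of `node` (including
// elements of TensorList values) and every Device-typed attribute is on
// CPU. Tensors with no entry in the owning graph's tensorValuesMeta are
// skipped, i.e. they do not pin the node to a device.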
bool areAllIOTensorsAttributesOnCpu(const Node& node) {
  const auto& tensorValuesMeta = node.owningGraph()->tensorValuesMeta();

  // Check inputs
  for (auto& input : node.inputs()) {
    if (input.value->type() == Type::Kind::Tensor) {
      if (auto it = tensorValuesMeta.find(std::string{input.value->name()});
          it != tensorValuesMeta.end()) {
        const auto& device = it->second.device();
        if (!device.is_cpu()) {
          return false;
        }
      }
    } else if (input.value->type() == Type::Kind::TensorList) {
      for (const auto& el : input.value->getListElements()) {
        if (auto it = tensorValuesMeta.find(std::string{el->name()});
            it != tensorValuesMeta.end()) {
          const auto& device = it->second.device();
          if (!device.is_cpu()) {
            return false;
          }
        }
      }
    } else {
      // Other input types don't affect whether the node is on CPU
    }
  }

  // Check outputs
  for (auto& output : node.outputs()) {
    if (!output) {
      // When a node's output is a Constant, its Value* is nullptr
      // TODO: this breaks the invariant that all node outputs in the graph
      // are non-null. We should fix this.
      continue;
    }
    if (output->type() == Type::Kind::Tensor) {
      if (auto it = tensorValuesMeta.find(std::string{output->name()});
          it != tensorValuesMeta.end()) {
        const auto& device = it->second.device();
        if (!device.is_cpu()) {
          return false;
        }
      }
    } else if (output->type() == Type::Kind::TensorList) {
      for (const auto& el : output->getListElements()) {
        if (auto it = tensorValuesMeta.find(std::string{el->name()});
            it != tensorValuesMeta.end()) {
          const auto& device = it->second.device();
          if (!device.is_cpu()) {
            return false;
          }
        }
      }
    } else {
      // Other output types don't affect whether the node is on CPU
    }
  }

  // Check attributes
  for (auto& attribute : node.attributes()) {
    if (std::holds_alternative<c10::Device>(attribute.value)) {
      auto device = std::get<c10::Device>(attribute.value);
      if (!device.is_cpu()) {
        return false;
      }
    }
  }

  return true;
}

} // namespace torch::nativert