[NativeRT] Make VariadicOpConverter and FuseListUnpackConverter for cpu nodes only (#159519)

Summary:
VariadicOpConverter and FuseListUnpackConverter would introduce ops that only have CPU kernels.

Currently, these graph passes are run if static_dispatch is enabled.

Since we plan to enable static_dispatch by default, this diff adds an additional check so that the graph passes only rewrite nodes whose inputs and outputs are all on CPU.
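
For illustration only (not part of this diff), a minimal sketch of how such a gate could sit inside a converter; the function below is hypothetical, and only areAllIOTensorsAttributesOnCpu comes from this change:

#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/graph/GraphUtils.h>

namespace torch::nativert {

// Hypothetical converter skeleton; the real gating happens inside the
// VariadicOpConverter / FuseListUnpackConverter passes in GraphPasses.cpp.
void convertNodeIfCpuOnly(Node& node) {
  // The fused/variadic replacement ops only have CPU kernels, so skip any
  // node whose tensor inputs/outputs or device-typed attributes are not CPU.
  if (!areAllIOTensorsAttributesOnCpu(node)) {
    return;
  }
  // ... existing rewrite logic for eligible nodes ...
}

} // namespace torch::nativert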

Test Plan:
CI

Rollback Plan:

Differential Revision: D79295640

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159519
Approved by: https://github.com/dolpm, https://github.com/henryoier
Author: Sherlock Huang
Date: 2025-07-31 18:17:21 +00:00
Committed by: PyTorch MergeBot
Parent: 8a233d6000
Commit: c1722db0f7
4 changed files with 104 additions and 0 deletions


@@ -599,6 +599,7 @@ libtorch_nativert_sources = [
"torch/nativert/graph/GraphSignature.cpp",
"torch/nativert/graph/Serialization.cpp",
"torch/nativert/graph/TensorMeta.cpp",
"torch/nativert/graph/GraphUtils.cpp",
"torch/nativert/executor/DelegateExecutor.cpp",
"torch/nativert/executor/Placement.cpp",
"torch/nativert/executor/ExecutionPlanner.cpp",


@@ -10,6 +10,7 @@ set(NATIVERT_TEST_SRCS
${TORCH_ROOT}/torch/nativert/graph/Graph.cpp
${TORCH_ROOT}/torch/nativert/graph/GraphPasses.cpp
${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp
${TORCH_ROOT}/torch/nativert/graph/GraphUtils.cpp
${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp
${TORCH_ROOT}/torch/nativert/executor/OpKernel.cpp
${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp


@@ -0,0 +1,80 @@
#include <torch/nativert/graph/GraphUtils.h>

#include <c10/core/Device.h>

#include <torch/nativert/graph/Graph.h>

namespace torch::nativert {

bool areAllIOTensorsAttributesOnCpu(const Node& node) {
  const auto& tensorValuesMeta = node.owningGraph()->tensorValuesMeta();

  // Check inputs
  for (auto& input : node.inputs()) {
    if (input.value->type() == Type::Kind::Tensor) {
      if (auto it = tensorValuesMeta.find(std::string{input.value->name()});
          it != tensorValuesMeta.end()) {
        const auto& device = it->second.device();
        if (!device.is_cpu()) {
          return false;
        }
      }
    } else if (input.value->type() == Type::Kind::TensorList) {
      for (const auto& el : input.value->getListElements()) {
        if (auto it = tensorValuesMeta.find(std::string{el->name()});
            it != tensorValuesMeta.end()) {
          const auto& device = it->second.device();
          if (!device.is_cpu()) {
            return false;
          }
        }
      }
    } else {
      // other input types don't affect whether the node is on CPU or not
    }
  }

  // Check outputs
  for (auto& output : node.outputs()) {
    if (!output) {
      // When a node's output is a Constant, its Value* is nullptr.
      // TODO: this breaks the invariant that all node outputs in the graph
      // are non-null. We should fix this.
      continue;
    }
    if (output->type() == Type::Kind::Tensor) {
      if (auto it = tensorValuesMeta.find(std::string{output->name()});
          it != tensorValuesMeta.end()) {
        const auto& device = it->second.device();
        if (!device.is_cpu()) {
          return false;
        }
      }
    } else if (output->type() == Type::Kind::TensorList) {
      for (const auto& el : output->getListElements()) {
        if (auto it = tensorValuesMeta.find(std::string{el->name()});
            it != tensorValuesMeta.end()) {
          const auto& device = it->second.device();
          if (!device.is_cpu()) {
            return false;
          }
        }
      }
    } else {
      // other output types don't affect whether the node is on CPU or not
    }
  }

  // Check attributes
  for (auto& attribute : node.attributes()) {
    if (std::holds_alternative<c10::Device>(attribute.value)) {
      auto device = std::get<c10::Device>(attribute.value);
      if (!device.is_cpu()) {
        return false;
      }
    }
  }

  return true;
}

} // namespace torch::nativert


@@ -0,0 +1,22 @@
#pragma once

namespace torch::nativert {

class Node;

/**
 * Utility functions for working with Graph nodes and values.
 */

/**
 * Check whether all input/output tensors are on CPU and all device-type
 * attributes have the value 'cpu'. This is a utility function for deciding
 * whether a Node can use static-dispatch CPU kernels.
 *
 * @param node The node to check
 * @return true if all I/O tensors and device attributes are on CPU, false
 * otherwise
 */
bool areAllIOTensorsAttributesOnCpu(const Node& node);

} // namespace torch::nativert
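
As a usage sketch (the helper function below and the Graph::nodes() iteration are assumptions for illustration, not part of this diff):

#include <cstddef>

#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/graph/GraphUtils.h>

// Hypothetical helper: count how many nodes in a graph are eligible for
// CPU-only static-dispatch rewrites. Assumes Graph::nodes() yields the
// graph's nodes; adjust to the actual iteration API.
size_t countCpuOnlyNodes(const torch::nativert::Graph& graph) {
  size_t count = 0;
  for (const auto& node : graph.nodes()) {
    if (torch::nativert::areAllIOTensorsAttributesOnCpu(node)) {
      ++count;
    }
  }
  return count;
}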