mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[NativeRT] Make VariadicOpConverter and FuseListUnpackConverter for cpu nodes only (#159519)
Summary: VariadicOpConverter and FuseListUnpackConverter would introduce ops that only have CPU kernels. Currently, the graph passes are ran if static_dispatch is enabled. As we plan to enable static_dispatch by default, this diff add the additional check for the graph pass to only work on the node that has all the inputs/outputs on CPU. Test Plan: CI Rollback Plan: Differential Revision: D79295640 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159519 Approved by: https://github.com/dolpm, https://github.com/henryoier
This commit is contained in:
committed by
PyTorch MergeBot
parent
8a233d6000
commit
c1722db0f7
@ -599,6 +599,7 @@ libtorch_nativert_sources = [
|
||||
"torch/nativert/graph/GraphSignature.cpp",
|
||||
"torch/nativert/graph/Serialization.cpp",
|
||||
"torch/nativert/graph/TensorMeta.cpp",
|
||||
"torch/nativert/graph/GraphUtils.cpp",
|
||||
"torch/nativert/executor/DelegateExecutor.cpp",
|
||||
"torch/nativert/executor/Placement.cpp",
|
||||
"torch/nativert/executor/ExecutionPlanner.cpp",
|
||||
|
@ -10,6 +10,7 @@ set(NATIVERT_TEST_SRCS
|
||||
${TORCH_ROOT}/torch/nativert/graph/Graph.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/GraphPasses.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/GraphUtils.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/OpKernel.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp
|
||||
|
80
torch/nativert/graph/GraphUtils.cpp
Normal file
80
torch/nativert/graph/GraphUtils.cpp
Normal file
@ -0,0 +1,80 @@
|
||||
#include <torch/nativert/graph/GraphUtils.h>
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
|
||||
#include <torch/nativert/graph/Graph.h>
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
bool areAllIOTensorsAttributesOnCpu(const Node& node) {
|
||||
const auto& tensorValuesMeta = node.owningGraph()->tensorValuesMeta();
|
||||
|
||||
// Check inputs
|
||||
for (auto& input : node.inputs()) {
|
||||
if (input.value->type() == Type::Kind::Tensor) {
|
||||
if (auto it = tensorValuesMeta.find(std::string{input.value->name()});
|
||||
it != tensorValuesMeta.end()) {
|
||||
const auto& device = it->second.device();
|
||||
if (!device.is_cpu()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if (input.value->type() == Type::Kind::TensorList) {
|
||||
for (const auto& el : input.value->getListElements()) {
|
||||
if (auto it = tensorValuesMeta.find(std::string{el->name()});
|
||||
it != tensorValuesMeta.end()) {
|
||||
const auto& device = it->second.device();
|
||||
if (!device.is_cpu()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// other input types doesn't affect if the node is on CPU or not
|
||||
}
|
||||
}
|
||||
|
||||
// Check outputs
|
||||
for (auto& output : node.outputs()) {
|
||||
if (!output) {
|
||||
// When a node's output is a Constant, its Value* is nullptr
|
||||
// TODO: this is breaking the invariant of all nodes outputs are non-null
|
||||
// in the graph. We should fix this.
|
||||
continue;
|
||||
}
|
||||
if (output->type() == Type::Kind::Tensor) {
|
||||
if (auto it = tensorValuesMeta.find(std::string{output->name()});
|
||||
it != tensorValuesMeta.end()) {
|
||||
const auto& device = it->second.device();
|
||||
if (!device.is_cpu()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if (output->type() == Type::Kind::TensorList) {
|
||||
for (const auto& el : output->getListElements()) {
|
||||
if (auto it = tensorValuesMeta.find(std::string{el->name()});
|
||||
it != tensorValuesMeta.end()) {
|
||||
const auto& device = it->second.device();
|
||||
if (!device.is_cpu()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// other output types doesn't affect if the node is on CPU or not
|
||||
}
|
||||
}
|
||||
|
||||
// Check attributes
|
||||
for (auto& attribute : node.attributes()) {
|
||||
if (std::holds_alternative<c10::Device>(attribute.value)) {
|
||||
auto device = std::get<c10::Device>(attribute.value);
|
||||
if (!device.is_cpu()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace torch::nativert
|
22
torch/nativert/graph/GraphUtils.h
Normal file
22
torch/nativert/graph/GraphUtils.h
Normal file
@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
namespace torch::nativert {
|
||||
|
||||
class Node;
|
||||
|
||||
/**
|
||||
* Utility functions for working with Graph nodes and values.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Check if all input/output tensors are on CPU and all device-type attributes
|
||||
* have the value of 'cpu'. This is a util function to check if a Node can use
|
||||
* static dispatch CPU kernels.
|
||||
*
|
||||
* @param node The node to check
|
||||
* @return true if all I/O tensors and device attributes are on CPU, false
|
||||
* otherwise
|
||||
*/
|
||||
bool areAllIOTensorsAttributesOnCpu(const Node& node);
|
||||
|
||||
} // namespace torch::nativert
|
Reference in New Issue
Block a user