mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[NativeRT] Make VariadicOpConverter and FuseListUnpackConverter for cpu nodes only (#159519)
Summary: VariadicOpConverter and FuseListUnpackConverter would introduce ops that only have CPU kernels. Currently, the graph passes are ran if static_dispatch is enabled. As we plan to enable static_dispatch by default, this diff add the additional check for the graph pass to only work on the node that has all the inputs/outputs on CPU. Test Plan: CI Rollback Plan: Differential Revision: D79295640 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159519 Approved by: https://github.com/dolpm, https://github.com/henryoier
This commit is contained in:
committed by
PyTorch MergeBot
parent
8a233d6000
commit
c1722db0f7
@ -599,6 +599,7 @@ libtorch_nativert_sources = [
|
|||||||
"torch/nativert/graph/GraphSignature.cpp",
|
"torch/nativert/graph/GraphSignature.cpp",
|
||||||
"torch/nativert/graph/Serialization.cpp",
|
"torch/nativert/graph/Serialization.cpp",
|
||||||
"torch/nativert/graph/TensorMeta.cpp",
|
"torch/nativert/graph/TensorMeta.cpp",
|
||||||
|
"torch/nativert/graph/GraphUtils.cpp",
|
||||||
"torch/nativert/executor/DelegateExecutor.cpp",
|
"torch/nativert/executor/DelegateExecutor.cpp",
|
||||||
"torch/nativert/executor/Placement.cpp",
|
"torch/nativert/executor/Placement.cpp",
|
||||||
"torch/nativert/executor/ExecutionPlanner.cpp",
|
"torch/nativert/executor/ExecutionPlanner.cpp",
|
||||||
|
@ -10,6 +10,7 @@ set(NATIVERT_TEST_SRCS
|
|||||||
${TORCH_ROOT}/torch/nativert/graph/Graph.cpp
|
${TORCH_ROOT}/torch/nativert/graph/Graph.cpp
|
||||||
${TORCH_ROOT}/torch/nativert/graph/GraphPasses.cpp
|
${TORCH_ROOT}/torch/nativert/graph/GraphPasses.cpp
|
||||||
${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp
|
${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp
|
||||||
|
${TORCH_ROOT}/torch/nativert/graph/GraphUtils.cpp
|
||||||
${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp
|
${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp
|
||||||
${TORCH_ROOT}/torch/nativert/executor/OpKernel.cpp
|
${TORCH_ROOT}/torch/nativert/executor/OpKernel.cpp
|
||||||
${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp
|
${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp
|
||||||
|
80
torch/nativert/graph/GraphUtils.cpp
Normal file
80
torch/nativert/graph/GraphUtils.cpp
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
#include <torch/nativert/graph/GraphUtils.h>
|
||||||
|
|
||||||
|
#include <c10/core/Device.h>
|
||||||
|
|
||||||
|
#include <torch/nativert/graph/Graph.h>
|
||||||
|
|
||||||
|
namespace torch::nativert {
|
||||||
|
|
||||||
|
bool areAllIOTensorsAttributesOnCpu(const Node& node) {
|
||||||
|
const auto& tensorValuesMeta = node.owningGraph()->tensorValuesMeta();
|
||||||
|
|
||||||
|
// Check inputs
|
||||||
|
for (auto& input : node.inputs()) {
|
||||||
|
if (input.value->type() == Type::Kind::Tensor) {
|
||||||
|
if (auto it = tensorValuesMeta.find(std::string{input.value->name()});
|
||||||
|
it != tensorValuesMeta.end()) {
|
||||||
|
const auto& device = it->second.device();
|
||||||
|
if (!device.is_cpu()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (input.value->type() == Type::Kind::TensorList) {
|
||||||
|
for (const auto& el : input.value->getListElements()) {
|
||||||
|
if (auto it = tensorValuesMeta.find(std::string{el->name()});
|
||||||
|
it != tensorValuesMeta.end()) {
|
||||||
|
const auto& device = it->second.device();
|
||||||
|
if (!device.is_cpu()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// other input types doesn't affect if the node is on CPU or not
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check outputs
|
||||||
|
for (auto& output : node.outputs()) {
|
||||||
|
if (!output) {
|
||||||
|
// When a node's output is a Constant, its Value* is nullptr
|
||||||
|
// TODO: this is breaking the invariant of all nodes outputs are non-null
|
||||||
|
// in the graph. We should fix this.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (output->type() == Type::Kind::Tensor) {
|
||||||
|
if (auto it = tensorValuesMeta.find(std::string{output->name()});
|
||||||
|
it != tensorValuesMeta.end()) {
|
||||||
|
const auto& device = it->second.device();
|
||||||
|
if (!device.is_cpu()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (output->type() == Type::Kind::TensorList) {
|
||||||
|
for (const auto& el : output->getListElements()) {
|
||||||
|
if (auto it = tensorValuesMeta.find(std::string{el->name()});
|
||||||
|
it != tensorValuesMeta.end()) {
|
||||||
|
const auto& device = it->second.device();
|
||||||
|
if (!device.is_cpu()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// other output types doesn't affect if the node is on CPU or not
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check attributes
|
||||||
|
for (auto& attribute : node.attributes()) {
|
||||||
|
if (std::holds_alternative<c10::Device>(attribute.value)) {
|
||||||
|
auto device = std::get<c10::Device>(attribute.value);
|
||||||
|
if (!device.is_cpu()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace torch::nativert
|
22
torch/nativert/graph/GraphUtils.h
Normal file
22
torch/nativert/graph/GraphUtils.h
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace torch::nativert {
|
||||||
|
|
||||||
|
class Node;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility functions for working with Graph nodes and values.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if all input/output tensors are on CPU and all device-type attributes
|
||||||
|
* have the value of 'cpu'. This is a util function to check if a Node can use
|
||||||
|
* static dispatch CPU kernels.
|
||||||
|
*
|
||||||
|
* @param node The node to check
|
||||||
|
* @return true if all I/O tensors and device attributes are on CPU, false
|
||||||
|
* otherwise
|
||||||
|
*/
|
||||||
|
bool areAllIOTensorsAttributesOnCpu(const Node& node);
|
||||||
|
|
||||||
|
} // namespace torch::nativert
|
Reference in New Issue
Block a user