diff --git a/build_variables.bzl b/build_variables.bzl index 1dda77b63750..a226249db708 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -599,6 +599,7 @@ libtorch_nativert_sources = [ "torch/nativert/graph/GraphSignature.cpp", "torch/nativert/graph/Serialization.cpp", "torch/nativert/graph/TensorMeta.cpp", + "torch/nativert/graph/GraphUtils.cpp", "torch/nativert/executor/DelegateExecutor.cpp", "torch/nativert/executor/Placement.cpp", "torch/nativert/executor/ExecutionPlanner.cpp", diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt index c05416ce0eef..0675357861f9 100644 --- a/test/cpp/nativert/CMakeLists.txt +++ b/test/cpp/nativert/CMakeLists.txt @@ -10,6 +10,7 @@ set(NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/graph/Graph.cpp ${TORCH_ROOT}/torch/nativert/graph/GraphPasses.cpp ${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp + ${TORCH_ROOT}/torch/nativert/graph/GraphUtils.cpp ${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp ${TORCH_ROOT}/torch/nativert/executor/OpKernel.cpp ${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp diff --git a/torch/nativert/graph/GraphUtils.cpp b/torch/nativert/graph/GraphUtils.cpp new file mode 100644 index 000000000000..ebe2d68cc0e7 --- /dev/null +++ b/torch/nativert/graph/GraphUtils.cpp @@ -0,0 +1,80 @@ +#include + +#include + +#include + +namespace torch::nativert { + +bool areAllIOTensorsAttributesOnCpu(const Node& node) { + const auto& tensorValuesMeta = node.owningGraph()->tensorValuesMeta(); + + // Check inputs + for (auto& input : node.inputs()) { + if (input.value->type() == Type::Kind::Tensor) { + if (auto it = tensorValuesMeta.find(std::string{input.value->name()}); + it != tensorValuesMeta.end()) { + const auto& device = it->second.device(); + if (!device.is_cpu()) { + return false; + } + } + } else if (input.value->type() == Type::Kind::TensorList) { + for (const auto& el : input.value->getListElements()) { + if (auto it = tensorValuesMeta.find(std::string{el->name()}); + it != tensorValuesMeta.end()) { + const auto& device = it->second.device(); + if (!device.is_cpu()) { + return false; + } + } + } + } else { + // other input types doesn't affect if the node is on CPU or not + } + } + + // Check outputs + for (auto& output : node.outputs()) { + if (!output) { + // When a node's output is a Constant, its Value* is nullptr + // TODO: this is breaking the invariant of all nodes outputs are non-null + // in the graph. We should fix this. + continue; + } + if (output->type() == Type::Kind::Tensor) { + if (auto it = tensorValuesMeta.find(std::string{output->name()}); + it != tensorValuesMeta.end()) { + const auto& device = it->second.device(); + if (!device.is_cpu()) { + return false; + } + } + } else if (output->type() == Type::Kind::TensorList) { + for (const auto& el : output->getListElements()) { + if (auto it = tensorValuesMeta.find(std::string{el->name()}); + it != tensorValuesMeta.end()) { + const auto& device = it->second.device(); + if (!device.is_cpu()) { + return false; + } + } + } + } else { + // other output types doesn't affect if the node is on CPU or not + } + } + + // Check attributes + for (auto& attribute : node.attributes()) { + if (std::holds_alternative(attribute.value)) { + auto device = std::get(attribute.value); + if (!device.is_cpu()) { + return false; + } + } + } + return true; +} + +} // namespace torch::nativert diff --git a/torch/nativert/graph/GraphUtils.h b/torch/nativert/graph/GraphUtils.h new file mode 100644 index 000000000000..593317ebb29b --- /dev/null +++ b/torch/nativert/graph/GraphUtils.h @@ -0,0 +1,22 @@ +#pragma once + +namespace torch::nativert { + +class Node; + +/** + * Utility functions for working with Graph nodes and values. + */ + +/** + * Check if all input/output tensors are on CPU and all device-type attributes + * have the value of 'cpu'. This is a util function to check if a Node can use + * static dispatch CPU kernels. + * + * @param node The node to check + * @return true if all I/O tensors and device attributes are on CPU, false + * otherwise + */ +bool areAllIOTensorsAttributesOnCpu(const Node& node); + +} // namespace torch::nativert