From 88c6199db09372b6e2d55a5349ab545527842727 Mon Sep 17 00:00:00 2001
From: Sheng Qin
Date: Sat, 28 Jun 2025 06:34:21 +0000
Subject: [PATCH] [nativert] Move KernelFactory to PyTorch core (#156913)

Summary: KernelFactory handles the initialization of kernel nodes and the
execution of the different kernel types.

Test Plan: CI

Rollback Plan:

Differential Revision: D77346836

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156913
Approved by: https://github.com/zhxchen17
---
 build_variables.bzl                      |   1 +
 torch/nativert/kernels/KernelFactory.cpp | 270 +++++++++++++++++++++++
 torch/nativert/kernels/KernelFactory.h   |  89 ++++++++
 3 files changed, 360 insertions(+)
 create mode 100644 torch/nativert/kernels/KernelFactory.cpp
 create mode 100644 torch/nativert/kernels/KernelFactory.h

diff --git a/build_variables.bzl b/build_variables.bzl
index 77fad7cdc5cb..76f21a6c1ac5 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -617,6 +617,7 @@ libtorch_nativert_sources = [
     "torch/nativert/executor/memory/Bump.cpp",
     "torch/nativert/executor/ParallelGraphExecutor.cpp",
     "torch/nativert/kernels/CallTorchBindKernel.cpp",
+    "torch/nativert/kernels/KernelFactory.cpp",
     "torch/nativert/kernels/PrimKernelRegistry.cpp",
     "torch/nativert/executor/memory/DisjointStorageGroups.cpp",
     "torch/nativert/executor/memory/AliasAnalyzer.cpp",
diff --git a/torch/nativert/kernels/KernelFactory.cpp b/torch/nativert/kernels/KernelFactory.cpp
new file mode 100644
index 000000000000..1f72fef810d6
--- /dev/null
+++ b/torch/nativert/kernels/KernelFactory.cpp
@@ -0,0 +1,270 @@
+#include <torch/nativert/kernels/KernelFactory.h>
+
+#include <fmt/format.h>
+#include <fmt/ranges.h>
+
+#include <c10/util/Logging.h>
+#include <c10/util/Synchronized.h>
+#include <torch/nativert/executor/ExecutorConfig.h>
+#include <torch/nativert/executor/ParallelGraphExecutor.h>
+#include <torch/nativert/executor/SerialGraphExecutor.h>
+#include <torch/nativert/graph/Graph.h>
+#include <torch/nativert/kernels/AutoFunctionalizeKernel.h>
+#include <torch/nativert/kernels/C10Kernel.h>
+#include <torch/nativert/kernels/CallTorchBindKernel.h>
+#include <torch/nativert/kernels/HigherOrderKernel.h>
+#include <torch/nativert/kernels/PrimKernelRegistry.h>
+
+namespace torch::nativert {
+
+namespace {
+
+c10::Device inferTargetDevice(
+    const Node& node,
+    const std::unordered_map<std::string, TensorMeta>&
+        tensorValuesMeta,
+    const Placement& placement) {
+  if (node.target() == "prim.Input" || node.target() == "prim.Output") {
+    return c10::Device(c10::DeviceType::CPU);
+  }
+
+  std::vector<c10::Device> devices;
+  for (auto& output : node.outputs()) {
+    if (output->type() == Type::Kind::Tensor) {
+      auto it = tensorValuesMeta.find(std::string{output->name()});
+      if (it != tensorValuesMeta.end()) {
+        devices.emplace_back(it->second.device());
+      }
+    } else if (output->type() == Type::Kind::TensorList) {
+      for (const auto& el : output->getListElements()) {
+        auto it = tensorValuesMeta.find(std::string{el->name()});
+        if (it != tensorValuesMeta.end()) {
+          devices.emplace_back(it->second.device());
+        }
+      }
+    }
+  }
+
+  if (devices.empty()) {
+    return c10::Device(c10::DeviceType::CPU);
+  } else {
+    for (size_t i = 1; i < devices.size(); ++i) {
+      if (!torch::nativert::isSameDevice(devices[0], devices[i])) {
+        LOG(WARNING) << "Node " << node
+                     << " has outputs on multiple devices: " << devices[0]
+                     << " and " << devices[i];
+      }
+    }
+
+    return placement.getMappedDevice(devices[0]);
+  }
+}
+
+} // namespace
+
+inline constexpr std::string_view kSymIntOps[] = {
+    "_operator.floordiv",
+    "_operator.mod",
+    "torch.sym_int",
+    "torch.sym_float",
+    "torch.sym_ite",
+    "torch.sym_max",
+    "torch.sym_min",
+};
+
+inline constexpr std::string_view kSymBoolOps[] = {
+    "_operator.eq",
+    "_operator.ne",
+    "_operator.le",
+    "_operator.ge",
+    "_operator.lt",
+    "_operator.gt",
+    "_operator.and_",
+    "torch.sym_not",
+};
+
+inline constexpr std::string_view kSymFloatOps[] = {
+    "torch._sym_sqrt",
+    "math.trunc",
+    "_operator.neg",
+    "_operator.truediv",
+};
+
+inline constexpr std::string_view kScalarBinaryOps[] = {
"_operator.mul", + "_operator.add", + "_operator.sub", + "_operator.pow", +}; + +namespace { + +struct KernelFactoryRegistry { + std::unordered_map handlers; +}; + +c10::Synchronized& getKernelFactoryRegistry() { + static auto* registry = new c10::Synchronized(); + return *registry; +} + +} // namespace + +void KernelFactory::registerHandler( + const std::string& name, + KernelFactoryHandler handler) { + auto& registry = getKernelFactoryRegistry(); + registry.withLock([&](auto&& reg) { + if (reg.handlers.find(name) != reg.handlers.end()) { + TORCH_CHECK(false, "Handler for ", name, " already registered"); + } + reg.handlers.emplace(name, std::move(handler)); + }); +} + +ExecutionKernels KernelFactory::initializeNodeKernels( + const Graph& graph, + std::shared_ptr weights, + const torch::nativert::ExecutorConfig& executorConfig, + const Placement& placement, + std::shared_ptr pytorchStreamReader, + const MakeProxyExecutorFn& makeProxyExecutorFunc) { + std::vector> nodeKernels; + std::vector> delegateExecutors; + std::vector constFoldingExecutions; + + std::unordered_map opsWithoutStaticDispatchCount; + + VLOG(1) << fmt::format( + "PrimKernelRegistry: {}", fmt::join(PrimKernelRegistry()->Keys(), ", ")); + + std::unordered_map handlers; + getKernelFactoryRegistry().withLock( + [&](auto&& reg) { handlers = reg.handlers; }); + + for (const auto& node : graph.nodes()) { + std::string target = std::string(node.target()); + + c10::Device targetDevice = + inferTargetDevice(node, graph.tensorValuesMeta(), placement); + + bool matched = false; + for (const auto& [_, handler] : handlers) { + if (handler.match(node, executorConfig, targetDevice)) { + auto [kernel, delegate] = handler( + node, + weights, + executorConfig, + pytorchStreamReader.get(), + targetDevice); + if (kernel) { + nodeKernels.push_back(std::move(kernel)); + } + if (delegate) { + delegateExecutors.push_back(std::move(delegate)); + } + matched = true; + break; + } + } + if (matched) { + continue; + } + + if (PrimKernelRegistry()->Has(target)) { + nodeKernels.push_back(PrimKernelRegistry()->Create(target, &node)); + } else if (c10::starts_with( + node.target(), "torch.ops.higher_order.call_torchbind")) { + nodeKernels.push_back(std::make_unique(&node)); + } else if ( + c10::starts_with( + node.target(), + "torch.ops.higher_order.auto_functionalized") || + c10::starts_with( // TODO Remove this condition once the old + // pt2 archives are expired. 
+            node.target(),
+            "torch._higher_order_ops.auto_functionalize.auto_functionalized")) {
+      nodeKernels.push_back(
+          std::make_unique<UnsafeAutoFunctionalizeKernel>(&node));
+    } else if (
+        std::find(
+            std::begin(kSymIntOps), std::end(kSymIntOps), node.target()) !=
+        std::end(kSymIntOps)) {
+      nodeKernels.push_back(std::make_unique<SymIntOpKernel>(&node));
+    } else if (
+        std::find(
+            std::begin(kSymBoolOps), std::end(kSymBoolOps), node.target()) !=
+        std::end(kSymBoolOps)) {
+      nodeKernels.push_back(std::make_unique<SymBoolOpKernel>(&node));
+    } else if (
+        std::find(
+            std::begin(kSymFloatOps), std::end(kSymFloatOps), node.target()) !=
+        std::end(kSymFloatOps)) {
+      nodeKernels.push_back(std::make_unique<SymFloatOpKernel>(&node));
+    } else if (
+        std::find(
+            std::begin(kScalarBinaryOps),
+            std::end(kScalarBinaryOps),
+            node.target()) != std::end(kScalarBinaryOps)) {
+      nodeKernels.push_back(std::make_unique<ScalarBinaryOpKernel>(&node));
+    } else if (c10::starts_with(node.target(), "torch.ops.higher_order")) {
+      std::vector<std::unique_ptr<GraphExecutorBase>> graphExecutors;
+      for (const auto& attr : node.attributes()) {
+        if (std::holds_alternative<std::unique_ptr<Graph>>(attr.value)) {
+          const auto& subgraph = std::get<std::unique_ptr<Graph>>(attr.value);
+          auto executionKernels = initializeNodeKernels(
+              *subgraph, weights, executorConfig, placement);
+          CHECK(executionKernels.delegateExecutors.empty())
+              << "HigherOrderKernel does not support delegates";
+          CHECK(executionKernels.constFoldingExecutions.size() == 0)
+              << "HigherOrderKernel does not support const folding";
+          if (executorConfig.maxParallelOps > 1) {
+            graphExecutors.emplace_back(
+                std::unique_ptr<GraphExecutorBase>(new ParallelGraphExecutor(
+                    *subgraph,
+                    std::move(executionKernels.nodeKernels),
+                    executorConfig)));
+          } else {
+            graphExecutors.emplace_back(std::unique_ptr<GraphExecutorBase>(
+                new torch::nativert::SerialGraphExecutor(
+                    *subgraph,
+                    std::move(executionKernels.nodeKernels),
+                    executorConfig)));
+          }
+        }
+      }
+      if (node.target() == "torch.ops.higher_order.run_const_graph") {
+        constFoldingExecutions.push_back(
+            ConstFoldingExecution{std::move(graphExecutors[0])});
+      }
+      nodeKernels.push_back(std::make_unique<HigherOrderKernel>(
+          &node, std::move(graphExecutors)));
+    } else if (c10::starts_with(node.target(), "torch.ops")) {
+      nodeKernels.push_back(std::make_unique<C10Kernel>(&node, targetDevice));
+
+      std::string opName = std::string(node.target());
+      if (opsWithoutStaticDispatchCount.find(opName) ==
+          opsWithoutStaticDispatchCount.end()) {
+        opsWithoutStaticDispatchCount[opName] = 0;
+      }
+      opsWithoutStaticDispatchCount[opName] += 1;
+    } else {
+      TORCH_CHECK(false, "Unsupported operator: ", target);
+    }
+  }
+
+  if (executorConfig.enableStaticCPUKernels) {
+    std::stringstream ss;
+    for (const auto& [op, count] : opsWithoutStaticDispatchCount) {
+      ss << op << ": " << count << ", \n";
+    }
+    LOG(WARNING) << "Following ops are missing static dispatched kernels: \n"
+                 << ss.str();
+  }
+
+  return {
+      std::move(nodeKernels),
+      std::move(delegateExecutors),
+      std::move(constFoldingExecutions)};
+}
+} // namespace torch::nativert
diff --git a/torch/nativert/kernels/KernelFactory.h b/torch/nativert/kernels/KernelFactory.h
new file mode 100644
index 000000000000..c01d64c3a017
--- /dev/null
+++ b/torch/nativert/kernels/KernelFactory.h
@@ -0,0 +1,89 @@
+#pragma once
+
+#include <memory>
+
+#include <c10/util/function_ref.h>
+#include <torch/nativert/executor/DelegateExecutor.h>
+#include <torch/nativert/executor/GraphExecutorBase.h>
+#include <torch/nativert/executor/Placement.h>
+#include <torch/nativert/kernels/OpKernel.h>
+
+namespace torch::nativert {
+
+struct ConstFoldingExecution {
+  std::unique_ptr<GraphExecutorBase> executor;
+};
+
+struct ExecutionKernels {
+  std::vector<std::unique_ptr<OpKernel>> nodeKernels;
+  std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;
+  std::vector<ConstFoldingExecution> constFoldingExecutions;
+};
+
+class KernelFactoryHandler {
+ public:
+  using OpKernelPtr = std::unique_ptr<OpKernel>;
+  using DelegateExecutorPtr = std::unique_ptr<DelegateExecutor>;
+  using Matcher = c10::function_ref<bool(
+      const Node&,
+      const torch::nativert::ExecutorConfig&,
+      c10::Device)>;
+  using Callback =
+      c10::function_ref<std::pair<OpKernelPtr, DelegateExecutorPtr>(
+          const Node&,
+          std::shared_ptr<Weights> weights,
+          const torch::nativert::ExecutorConfig& executorConfig,
+          caffe2::serialize::PyTorchStreamReader* pytorchStreamReader,
+          c10::Device targetDevice)>;
+
+  KernelFactoryHandler(Matcher matcher, Callback callback)
+      : matcher_(matcher), callback_(callback) {}
+
+  KernelFactoryHandler() = delete;
+  KernelFactoryHandler(const KernelFactoryHandler&) = default;
+  KernelFactoryHandler& operator=(const KernelFactoryHandler&) = default;
+  KernelFactoryHandler(KernelFactoryHandler&&) = default;
+  KernelFactoryHandler& operator=(KernelFactoryHandler&&) = default;
+  ~KernelFactoryHandler() = default;
+
+  bool match(
+      const Node& node,
+      const torch::nativert::ExecutorConfig& config,
+      c10::Device device) const {
+    return matcher_(node, config, device);
+  }
+
+  std::pair<OpKernelPtr, DelegateExecutorPtr> operator()(
+      const Node& node,
+      std::shared_ptr<Weights> weights,
+      const torch::nativert::ExecutorConfig& executorConfig,
+      caffe2::serialize::PyTorchStreamReader* pytorchStreamReader,
+      c10::Device targetDevice) const {
+    return callback_(
+        node, weights, executorConfig, pytorchStreamReader, targetDevice);
+  }
+
+ private:
+  Matcher matcher_;
+  Callback callback_;
+};
+
+class KernelFactory {
+ public:
+  explicit KernelFactory() {}
+
+  ExecutionKernels initializeNodeKernels(
+      const Graph& graph,
+      std::shared_ptr<Weights> weights,
+      const torch::nativert::ExecutorConfig& executorConfig,
+      const Placement& placement,
+      std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
+          pytorchStreamReader = nullptr,
+      const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr);
+
+  static void registerHandler(
+      const std::string& name,
+      KernelFactoryHandler handler);
+};
+
+} // namespace torch::nativert
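
Usage note: the sketch below (not part of the patch) shows how a backend could
plug into the factory via KernelFactory::registerHandler. The handler name
"my_backend", the op-namespace check, and the fallback to C10Kernel are
hypothetical illustrations; only KernelFactory, KernelFactoryHandler, and the
Matcher/Callback signatures come from the code above. Free functions are used
rather than temporary lambdas because Matcher and Callback are non-owning
c10::function_ref types and must outlive the stored handler.

    // Illustrative sketch only; assumes this translation unit is linked into
    // the same binary as the nativert runtime.
    #include <torch/nativert/kernels/KernelFactory.h>

    namespace {

    using namespace torch::nativert;

    // Matcher: claim every node whose target lives under a hypothetical
    // custom op namespace.
    bool matchMyBackend(
        const Node& node,
        const ExecutorConfig& /*config*/,
        c10::Device /*device*/) {
      return c10::starts_with(node.target(), "torch.ops.my_backend");
    }

    // Callback: build a kernel for a matched node. This sketch falls back to
    // the generic C10Kernel from the patch; a real backend would return its
    // own OpKernel and/or DelegateExecutor instead.
    std::pair<
        KernelFactoryHandler::OpKernelPtr,
        KernelFactoryHandler::DelegateExecutorPtr>
    makeMyBackendKernel(
        const Node& node,
        std::shared_ptr<Weights> /*weights*/,
        const ExecutorConfig& /*executorConfig*/,
        caffe2::serialize::PyTorchStreamReader* /*pytorchStreamReader*/,
        c10::Device targetDevice) {
      return {std::make_unique<C10Kernel>(&node, targetDevice), nullptr};
    }

    // Register once at load time; registerHandler fails a TORCH_CHECK if the
    // name is already taken.
    const bool kMyBackendRegistered = []() {
      KernelFactory::registerHandler(
          "my_backend",
          KernelFactoryHandler(matchMyBackend, makeMyBackendKernel));
      return true;
    }();

    } // namespace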