Files
pytorch/torch/nativert/kernels/TritonKernel.h
dolpm 4f72d932fe re-land triton runtime implementation" (#162217)
Summary: original pr - https://github.com/pytorch/pytorch/pull/161798

Test Plan:
ci

Rollback Plan:

Differential Revision: D81724234

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162217
Approved by: https://github.com/SherlockNoMad
2025-09-06 00:52:29 +00:00

32 lines
816 B
C++

#pragma once
#include <cstdint>
#include <memory>
#include <vector>

#include <c10/core/Device.h>
#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/executor/OpKernel.h>
#include <torch/nativert/executor/triton/TritonKernelManager.h>
#include <torch/nativert/graph/Graph.h>
namespace torch::nativert {

/// OpKernel implementation for a graph node that runs a Triton kernel.
///
/// The kernel holds its TritonKernelManager (`loader_`) through a
/// std::unique_ptr, i.e. with exclusive ownership. Copying is therefore
/// impossible. Because a destructor is user-declared, the implicit move
/// operations are not generated either. The special members below state
/// this explicitly (C++ Core Guidelines C.21) instead of leaving it to
/// the implicit rules.
class TritonKernel : public OpKernel {
 public:
  TritonKernel() = delete;

  /// @param node   the graph node this kernel implements.
  /// @param reader stream reader for the serialized model. NOTE(review):
  ///               presumably the compiled kernel is loaded from it during
  ///               construction -- confirm against the .cpp.
  TritonKernel(
      const Node* node,
      caffe2::serialize::PyTorchStreamReader* reader);

  // Non-copyable and non-movable. This matches what the implicit rules
  // already produce: the unique_ptr member deletes copy, and the
  // user-declared destructor suppresses move.
  TritonKernel(const TritonKernel&) = delete;
  TritonKernel& operator=(const TritonKernel&) = delete;
  TritonKernel(TritonKernel&&) = delete;
  TritonKernel& operator=(TritonKernel&&) = delete;

  /// Defined out of line.
  ~TritonKernel() override;

  /// Executes the kernel against @p executionFrame.
  /// NOTE(review): presumably launches through loader_ using attr_ptrs_ and
  /// launch_params_ -- confirm against the .cpp.
  void computeInternal(ExecutionFrame& executionFrame) const override;

 private:
  // Exclusively owned manager for this node's Triton kernel.
  std::unique_ptr<TritonKernelManager> loader_;
  // unnamed node attributes will be passed as arguments to the kernel
  // NOTE(review): these are raw void* and their ownership is not visible in
  // this header. Presumably the out-of-line destructor releases them --
  // confirm.
  std::vector<void*> attr_ptrs_;
  // NOTE(review): presumably identifies which kernel arguments are outputs
  // -- confirm against the .cpp.
  std::vector<int64_t> output_indices_;
  // NOTE(review): presumably the launch configuration (grid/block, etc.)
  // used when invoking the kernel -- confirm.
  LaunchParams launch_params_;
};

} // namespace torch::nativert