mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: original pr - https://github.com/pytorch/pytorch/pull/161798 Test Plan: ci Rollback Plan: Differential Revision: D81724234 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162217 Approved by: https://github.com/SherlockNoMad
32 lines
816 B
C++
32 lines
816 B
C++
#pragma once
|
|
|
|
#include <cstdint>
#include <memory>
#include <vector>

#include <c10/core/Device.h>

#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/executor/OpKernel.h>
#include <torch/nativert/executor/triton/TritonKernelManager.h>
#include <torch/nativert/graph/Graph.h>
|
|
|
|
namespace torch::nativert {
|
|
|
|
class TritonKernel : public OpKernel {
|
|
public:
|
|
TritonKernel() = delete;
|
|
TritonKernel(
|
|
const Node* node,
|
|
caffe2::serialize::PyTorchStreamReader* reader);
|
|
~TritonKernel() override;
|
|
|
|
void computeInternal(ExecutionFrame& executionFrame) const override;
|
|
|
|
private:
|
|
std::unique_ptr<TritonKernelManager> loader_;
|
|
|
|
// unnamed node attributes will be passed as arguments to the kernel
|
|
std::vector<void*> attr_ptrs_;
|
|
std::vector<int64_t> output_indices_;
|
|
LaunchParams launch_params_;
|
|
};
|
|
|
|
} // namespace torch::nativert
|