#pragma once #include #include #include #include #include namespace torch::nativert { class TritonKernel : public OpKernel { public: TritonKernel() = delete; TritonKernel( const Node* node, caffe2::serialize::PyTorchStreamReader* reader); ~TritonKernel() override; void computeInternal(ExecutionFrame& executionFrame) const override; private: std::unique_ptr loader_; // unnamed node attributes will be passed as arguments to the kernel std::vector attr_ptrs_; // Storage for float attributes that were serialized as doubles std::vector float_attrs_; std::vector output_indices_; LaunchParams launch_params_; }; } // namespace torch::nativert