#include <tuple>

#include <ATen/ATen.h>
#include <torch/library.h>

#include "utils.h"

/*
 * How to write a meta implementation for a custom operator (meta kernel):
 *
 * Meta implementations are used for shape and dtype inference, tracing, and export.
 * They do NOT perform any real computation or allocate device memory.
 * Instead, they return empty tensors with the correct shapes, dtypes, and device types.
 *
 * Steps to write a meta implementation:
 * 1. The function signature should match the operator's schema, but only use the arguments
 *    necessary to infer output shapes and dtypes.
 * 2. Use input tensor shapes, dtypes, and any relevant arguments to compute the output shapes.
 * 3. Return empty tensors (e.g., at::empty_symint, at::empty_like) with the correct shape and dtype.
 * 4. Do NOT perform any real computation or data movement.
 * 5. Register the meta implementation with the "Meta" dispatch key using TORCH_LIBRARY_IMPL or similar.
 *
 * Example:
 *   std::tuple<at::Tensor, ...> my_op_meta(
 *       at::Tensor &input, int64_t some_param) {
 *     // Infer output shape based on input and parameters
 *     auto out_shape = ...;
 *     at::Tensor out = at::empty_symint(out_shape, input.options());
 *     // Return empty tensor(s) with correct shape/dtype
 *     return {out, ...};
 *   }
 *
 * See below for real examples.
 */

namespace vllm_ascend {
namespace meta {

std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
    at::Tensor &positions, at::Tensor &query, at::Tensor &key,
    int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox) {
  // Derive per-token hidden sizes symbolically (sym_numel) so that tracing
  // and export keep working with dynamic shapes.
  auto num_tokens = positions.sym_numel();
  auto query_hidden_size = query.sym_numel() / num_tokens;
  auto key_hidden_size = key.sym_numel() / num_tokens;

  auto num_heads = query_hidden_size / head_size;
  auto num_kv_heads = key_hidden_size / head_size;
  at::Tensor query_dst =
      at::empty_symint({num_tokens, num_heads, head_size}, query.options());
  at::Tensor key_dst =
      at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());

  return {query_dst, key_dst};
}

std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
    at::Tensor &input,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index) {
  // Both outputs mirror the input's shape: the masked ids keep the input
  // dtype, while the mask is boolean.
  at::Tensor masked_input = at::empty_like(input);
  at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool));

  return {masked_input, mask};
}

}  // namespace meta
}  // namespace vllm_ascend

namespace {
// Register the meta implementations of the custom kernels for symbolic
// tracing; this also allows the custom kernels to be captured into an
// aclgraph.
TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
  // Rotary embedding meta implementation
  ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
  // Masked input and mask meta implementation
  ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
}
}  // namespace
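
/*
 * Note: these meta impls only take effect for ops whose schemas are declared
 * on the same "_C" library elsewhere in this repo (the TORCH_LIBRARY
 * registration that defines each op's signature). A minimal sketch of that
 * pairing follows; the TORCH_LIBRARY_EXPAND macro name and the exact schema
 * strings are assumptions for illustration, not copied from the actual
 * registration:
 *
 *   TORCH_LIBRARY_EXPAND(_C, ops) {
 *     ops.def(
 *         "rotary_embedding(Tensor positions, Tensor! query, Tensor! key,"
 *         " int head_size, Tensor cos_sin_cache, bool is_neox)"
 *         " -> (Tensor, Tensor)");
 *     ops.def(
 *         "get_masked_input_and_mask(Tensor input, int org_vocab_start_index,"
 *         " int org_vocab_end_index, int num_org_vocab_padding,"
 *         " int added_vocab_start_index, int added_vocab_end_index)"
 *         " -> (Tensor, Tensor)");
 *   }
 *
 * Each meta function's argument order and types must stay in sync with its
 * schema, or the dispatcher will reject the Meta-key impl at registration.
 */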