pytorch/torch/nativert/kernels/PrimKernelRegistry.h
zhxchen17 c06164a9c5 [nativert][ez] Remove unused dist collectives ops. (#159220)
Removes the dependency on c10d/ in ExecutionFrame.h; we don't need c10d::Work in the frame.

Differential Revision: [D79041618](https://our.internmc.facebook.com/intern/diff/D79041618/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159220
Approved by: https://github.com/SherlockNoMad, https://github.com/dolpm
2025-07-28 16:03:14 +00:00

#pragma once

#include <ATen/ATen.h>

#include <torch/nativert/executor/OpKernel.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/kernels/C10Kernel.h>

namespace torch::nativert {

// Convenience accessors for a kernel's inputs/outputs in the current frame.
#define KernelInput(id) input(id, executionFrame)
#define KernelOutput(id) output(id, executionFrame)

TORCH_DECLARE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*);

// Declares an OpKernel subclass whose computeInternal body is the trailing
// macro arguments, then registers it under `name` in PrimKernelRegistry.
#define REGISTER_PRIM_KERNEL(name, id, ...)                    \
  class OpKernel_##id : public OpKernel {                      \
   public:                                                     \
    OpKernel_##id(const Node* node)                            \
        : OpKernel(node, OpKernelKind::kPrimKernel) {}         \
    void computeInternal(                                      \
        ExecutionFrame& executionFrame) const override final { \
      __VA_ARGS__;                                             \
    }                                                          \
  };                                                           \
  C10_REGISTER_TYPED_CLASS(PrimKernelRegistry, name, OpKernel_##id)

// Debug check: resizing to zero must not reallocate the underlying storage.
inline bool checkResizedDataPtr(at::Tensor& t) {
  auto const prev_data_ptr = t.data_ptr();
  t.resize_({0});
  return prev_data_ptr == t.data_ptr();
}

// Resizes a tensor to zero elements without touching its storage; cheaper
// than resize_ because it only rewrites the sizes on the TensorImpl.
inline void fastResizeToZero(at::Tensor& t) {
  t.unsafeGetTensorImpl()->set_sizes_contiguous({0});
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t));
}

} // namespace torch::nativert
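
A minimal sketch of how a translation unit would use this registry. The op name "prim.CloneExample", the id `clone_example`, and the kernel body are illustrative assumptions, not registrations from the PyTorch source; it assumes KernelInput/KernelOutput resolve to c10::IValue accessors on the enclosing OpKernel, as the macros above suggest.

// Hypothetical registration sketch (not part of this header).
#include <torch/nativert/kernels/PrimKernelRegistry.h>

namespace torch::nativert {

REGISTER_PRIM_KERNEL("prim.CloneExample", clone_example, {
  // Expands to a class OpKernel_clone_example whose computeInternal
  // runs this body against the current ExecutionFrame.
  const auto& src = KernelInput(0).toTensor();
  KernelOutput(0) = src.clone();
});

} // namespace torch::nativert

The string "prim.CloneExample" is the registry key an executor would look up by node target. fastResizeToZero, by contrast, is aimed at kernels that reuse a previously allocated output tensor: shrinking it to zero elements between runs keeps its storage alive for the next write instead of releasing and reallocating it.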