Add Static Dispatch Kernels (#163676) (#163870)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/1951

X-link: https://github.com/pytorch/FBGEMM/pull/4927

Add a few missing static dispatch kernels for remote_ro.

Test Plan: Tested with scripts in D83028841.

Differential Revision: D83258808

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163870
Approved by: https://github.com/henryoier
This commit is contained in:
Kevin Fu
2025-09-26 03:00:07 +00:00
committed by PyTorch MergeBot
parent bbf8aa43ef
commit 67cc0e0ac9

View File

@ -1424,4 +1424,33 @@ C10_REGISTER_TYPED_CLASS(
"torch.ops.aten._to_copy.default",
OpKernel_aten__to_copy)
REGISTER_CPU_KERNEL(
"torch.ops.aten.where.ScalarOther",
aten_where_ScalarOther,
{
const auto& condition = KernelInput(0).toTensor();
const auto& self = KernelInput(1).toTensor();
const auto& other = KernelInput(2).toScalar();
KernelOutput(0) = at::where(condition, self, other);
})
REGISTER_CPU_KERNEL(
"torch.ops.aten.repeat_interleave.self_Tensor",
aten_repeat_interleave_self_Tensor,
{
const auto& self = KernelInput(0).toTensor();
const auto& repeats = KernelInput(1).toTensor();
std::optional<int64_t> dim = std::nullopt;
if (!KernelInput(2).isNone()) {
dim = KernelInput(2).toInt();
}
std::optional<int64_t> output_size = std::nullopt;
if (!KernelInput(3).isNone()) {
output_size = KernelInput(3).toInt();
}
KernelOutput(0) = at::repeat_interleave(self, repeats, dim, output_size);
})
} // namespace torch::nativert