diff --git a/torch/nativert/kernels/KernelRegistry.cpp b/torch/nativert/kernels/KernelRegistry.cpp index f416210cc393..f157257e733b 100644 --- a/torch/nativert/kernels/KernelRegistry.cpp +++ b/torch/nativert/kernels/KernelRegistry.cpp @@ -1424,4 +1424,33 @@ C10_REGISTER_TYPED_CLASS( "torch.ops.aten._to_copy.default", OpKernel_aten__to_copy) +REGISTER_CPU_KERNEL( + "torch.ops.aten.where.ScalarOther", + aten_where_ScalarOther, + { + const auto& condition = KernelInput(0).toTensor(); + const auto& self = KernelInput(1).toTensor(); + const auto& other = KernelInput(2).toScalar(); + + KernelOutput(0) = at::where(condition, self, other); + }) + +REGISTER_CPU_KERNEL( + "torch.ops.aten.repeat_interleave.self_Tensor", + aten_repeat_interleave_self_Tensor, + { + const auto& self = KernelInput(0).toTensor(); + const auto& repeats = KernelInput(1).toTensor(); + std::optional dim = std::nullopt; + if (!KernelInput(2).isNone()) { + dim = KernelInput(2).toInt(); + } + std::optional output_size = std::nullopt; + if (!KernelInput(3).isNone()) { + output_size = KernelInput(3).toInt(); + } + + KernelOutput(0) = at::repeat_interleave(self, repeats, dim, output_size); + }) + } // namespace torch::nativert