diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index b22e7a1f6149..06c4a63497d7 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -119,7 +119,7 @@ bmm_template = TritonTemplate( cache_codegen_enabled_for_template=True, ) -aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out") +aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out", op_overload=aten.bmm.out) aten_bmm_dtype = ExternKernelChoice( torch.bmm, "at::_bmm_out_dtype_cuda",