Making batching rule for F.embedding DTensor-aware (#162117)

`vmap(F.embedding)(DTensor, DTensor)` was failing: `F.embedding`'s
batching rule generates a new tensor via `at::arange`, `at::arange`
produces a regular Tensor, and DTensor rightfully errors on operations
that mix DTensors and regular Tensors.
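
For context, a minimal repro sketch of the failing call (the single-rank
gloo setup, mesh, and shapes here are illustrative, not from the PR):

```python
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

# Single-rank setup so the sketch can run standalone.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

B, n, num_emb, dim = 2, 3, 10, 4
indices = DTensor.from_local(torch.randint(0, num_emb, (B, n)), mesh, [Replicate()])
weight = DTensor.from_local(torch.randn(B, num_emb, dim), mesh, [Replicate()])

# Both inputs carry a batch dim; the batching rule internally builds an
# at::arange-created tensor (a plain Tensor) and mixes it with the inputs.
out = torch.vmap(F.embedding)(indices, weight)  # raised before this PR
```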

This PR fixes the problem by enabling DTensor implicit replication
around just the `at::arange` call and the subsequent add operation,
sketched below.
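
Conceptually, this is the moral equivalent of wrapping those two ops in
the experimental Python-level `implicit_replication` context manager; a
sketch of the idea (the offset arithmetic is an illustration of what the
batching rule does, not the PR's actual C++ code):

```python
import torch
from torch.distributed.tensor.experimental import implicit_replication

def offset_indices(indices, batch_size, num_embeddings):
    # `indices` is a (batched) DTensor; `offsets` below is a plain Tensor.
    # Under implicit replication, DTensor treats the plain Tensor as
    # replicated, so the mixed add no longer errors.
    with implicit_replication():
        offsets = torch.arange(batch_size) * num_embeddings
        return indices + offsets.view(-1, 1)
```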

To accomplish this, I moved the DTensor implicit replication flag to
C++ (most batching rules live in C++).
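
The flag itself is just saved-and-restored state; as a rough
illustration of the pattern only (not PyTorch's actual binding or
names), a guard over a thread-local flag looks like:

```python
import contextlib
import threading

_state = threading.local()

def implicit_replication_enabled() -> bool:
    # Default is off; callers flip it only around the ops that need it
    # (in the PR, via a C++ RAII guard rather than Python).
    return getattr(_state, "enabled", False)

@contextlib.contextmanager
def implicit_replication_guard():
    prev = implicit_replication_enabled()
    _state.enabled = True
    try:
        yield
    finally:
        _state.enabled = prev
```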

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162117
Approved by: https://github.com/bdhirsh
Author: rzou
Date: 2025-09-04 14:45:59 -07:00
Committed by: PyTorch MergeBot
Parent: a00cdc1e41
Commit: 70d36e047d
10 changed files with 112 additions and 7 deletions


@@ -1088,6 +1088,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/DeviceAccelerator.cpp",
     "aten/src/ATen/Context.cpp",
     "aten/src/ATen/DLConvertor.cpp",
+    "aten/src/ATen/DTensorState.cpp",
     "aten/src/ATen/EmptyTensor.cpp",
     "aten/src/ATen/ExpandUtils.cpp",
     "aten/src/ATen/CachedTensorUtils.cpp",