mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Making batching rule for F.embedding DTensor-aware (#162117)
`vmap(F.embedding)(DTensor, DTensor)` was failing because F.embedding's batching rule generates a new tensor via at::arange, at::arange produces a regular tensor, and DTensor rightfully errors on mixed DTensor/regular-Tensor operations.

This PR fixes the problem by activating DTensor implicit replication around just the at::arange call and the subsequent add operation. To accomplish this, the DTensor implicit replication flag is moved to C++ (most batching rules are in C++).

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162117
Approved by: https://github.com/bdhirsh
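The mechanism described above, a scoped flag that temporarily allows mixing a distributed tensor type with plain tensors by treating the plain tensors as replicated, can be sketched in plain Python. All names below (`FakeDTensor`, `allow_implicit_replication`) are hypothetical illustrations of the pattern, not the actual PyTorch API.

```python
import contextlib
import threading

# Thread-local toggle standing in for the DTensor implicit-replication flag.
_state = threading.local()

def implicit_replication_enabled():
    return getattr(_state, "enabled", False)

@contextlib.contextmanager
def allow_implicit_replication():
    # Enable the flag only for the duration of the block, restoring the
    # previous value afterward (mirroring an RAII-style guard in C++).
    prev = implicit_replication_enabled()
    _state.enabled = True
    try:
        yield
    finally:
        _state.enabled = prev

class FakeDTensor:
    """Toy stand-in for a distributed tensor wrapping a plain value."""

    def __init__(self, value):
        self.value = value

    def add(self, other):
        # Mixed DTensor / plain-tensor op: an error by default, unless the
        # implicit-replication flag is active, in which case the plain
        # operand is treated as if it were replicated on every rank.
        if not isinstance(other, FakeDTensor):
            if not implicit_replication_enabled():
                raise RuntimeError("mixed DTensor / plain tensor operation")
            other = FakeDTensor(other)
        return FakeDTensor(self.value + other.value)
```

Under this sketch, `FakeDTensor(1).add(2)` raises, while the same call inside a `with allow_implicit_replication():` block succeeds; the batching-rule fix applies the analogous guard only around the at::arange and add steps.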
@@ -1088,6 +1088,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/DeviceAccelerator.cpp",
     "aten/src/ATen/Context.cpp",
     "aten/src/ATen/DLConvertor.cpp",
+    "aten/src/ATen/DTensorState.cpp",
     "aten/src/ATen/EmptyTensor.cpp",
     "aten/src/ATen/ExpandUtils.cpp",
     "aten/src/ATen/CachedTensorUtils.cpp",