Making batching rule for F.embedding DTensor-aware (#162117)

`vmap(F.embedding)(DTensor, DTensor)` was failing because F.embedding's
batching rule creates a new tensor via at::arange; at::arange produces a
regular tensor, and DTensor rightfully errors on mixed DTensor-regular
Tensor operations.

This PR fixes the problem by enabling DTensor implicit replication for
just the at::arange call and the subsequent add operation.

In order to accomplish this I move the DTensor implicit replication flag
to C++ (most batching rules are in C++).

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162117
Approved by: https://github.com/bdhirsh
This commit is contained in:
rzou
2025-09-04 14:45:59 -07:00
committed by PyTorch MergeBot
parent a00cdc1e41
commit 70d36e047d
10 changed files with 112 additions and 7 deletions

View File

@ -0,0 +1,17 @@
#include <ATen/DTensorState.h>

namespace at {
namespace {
// Thread-local flag: when true, DTensor treats plain (non-DTensor) Tensors in
// mixed operations as implicitly replicated instead of raising an error.
// Defaults to false — implicit replication is opt-in (it is unsafe to enable
// globally; see OpDispatcher._allow_implicit_replication on the Python side).
thread_local bool kDTensorAllowImplicitReplication = false;
} // namespace

// Returns the calling thread's implicit-replication setting.
bool get_dtensor_allow_implicit_replication() {
  return kDTensorAllowImplicitReplication;
}

// Overwrites the calling thread's implicit-replication setting.
// Propagated across threads by at::ThreadLocalState.
void set_dtensor_allow_implicit_replication(bool enabled) {
  kDTensorAllowImplicitReplication = enabled;
}

} // namespace at

View File

@ -0,0 +1,34 @@
#pragma once
#include <c10/macros/Macros.h>

namespace at {

// Thread-local toggle: when true, DTensor operations that receive a plain
// (non-DTensor) Tensor treat it as implicitly replicated instead of erroring.
// Backed by a thread_local in DTensorState.cpp and captured/restored by
// at::ThreadLocalState.
TORCH_API bool get_dtensor_allow_implicit_replication();
TORCH_API void set_dtensor_allow_implicit_replication(bool enabled);

// RAII guard that enables DTensor implicit replication for its lifetime and
// restores the previous thread-local value on destruction. Copy and move are
// deleted so the save/restore pairing cannot be duplicated or lost.
struct DTensorAllowImplicitReplication {
  DTensorAllowImplicitReplication()
      : prev_dtensor_allow_implicit_replication_(
            get_dtensor_allow_implicit_replication()) {
    set_dtensor_allow_implicit_replication(true);
  }

  DTensorAllowImplicitReplication(const DTensorAllowImplicitReplication&) =
      delete;
  DTensorAllowImplicitReplication& operator=(
      const DTensorAllowImplicitReplication&) = delete;
  DTensorAllowImplicitReplication(DTensorAllowImplicitReplication&&) = delete;
  DTensorAllowImplicitReplication& operator=(
      DTensorAllowImplicitReplication&&) = delete;

  ~DTensorAllowImplicitReplication() {
    // Restore rather than blindly set false, so nested guards compose.
    set_dtensor_allow_implicit_replication(
        prev_dtensor_allow_implicit_replication_);
  }

 private:
  // Value to restore when the guard goes out of scope.
  bool prev_dtensor_allow_implicit_replication_;
};

} // namespace at

View File

@ -8,6 +8,7 @@
#include <ATen/record_function.h>
#include <ATen/SavedTensorHooks.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/DTensorState.h>
namespace at {
@ -19,6 +20,7 @@ ThreadLocalState::ThreadLocalState()
torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
dtensor_allow_implicit_replication_(at::get_dtensor_allow_implicit_replication()),
saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
for(size_t i=0; i<autocast_dtypes_.size(); i++) {
@ -52,6 +54,8 @@ void ThreadLocalState::setThreadLocalState(
c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
at::set_dtensor_allow_implicit_replication(state.dtensor_allow_implicit_replication_);
c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);

View File

@ -75,6 +75,8 @@ class TORCH_API ThreadLocalState {
bool functionalization_reapply_views_state_;
bool dtensor_allow_implicit_replication_;
// TLS for arbitrary python objects that is registered via hooks
at::impl::ThreadLocalPythonObjects saved_objects_;

View File

@ -7,6 +7,7 @@
#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/DTensorState.h>
#include <utility>
@ -44,8 +45,13 @@ static std::tuple<Tensor, std::optional<int64_t>> embedding_batch_rule(
const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
auto indices_ = moveBatchDimToFront(indices, indices_bdim);
const auto range = getStepTensor(indices, batch_size, num_embeddings);
indices_ = indices_ + range;
{
// getStepTensor returns a regular Tensor. If indices_ is a DTensor
// we want to allow this mixed DTensor-Tensor operation.
at::DTensorAllowImplicitReplication guard;
const auto range = getStepTensor(indices, batch_size, num_embeddings);
indices_ = indices_ + range;
}
auto result = at::embedding_symint(weight_, indices_, std::move(padding_idx), scale_grad_by_freq, sparse);
return std::make_tuple(std::move(result), 0);
}

View File

@ -1088,6 +1088,7 @@ aten_cpu_source_non_codegen_list = [
"aten/src/ATen/DeviceAccelerator.cpp",
"aten/src/ATen/Context.cpp",
"aten/src/ATen/DLConvertor.cpp",
"aten/src/ATen/DTensorState.cpp",
"aten/src/ATen/EmptyTensor.cpp",
"aten/src/ATen/ExpandUtils.cpp",
"aten/src/ATen/CachedTensorUtils.cpp",

View File

@ -848,6 +848,30 @@ class DTensorMeshTest(DTensorTestBase):
self.assertEqual(local_shard.shape, (4, 3))
self.assertEqual(local_shard, torch.ones(4, 3) + torch.ones(3))
@with_comms
def test_vmap_embedding(self):
    # vmap(F.embedding) over DTensor inputs: the batching rule's internal
    # at::arange step tensor must mix with DTensor indices without erroring.
    device_mesh = self.build_device_mesh()
    num_batches, sequence_length, embed_dim = 2, 6, 32

    local_indices = torch.zeros(num_batches, sequence_length, dtype=torch.int64)
    local_indices[0, 1] = 1
    local_indices[1, 3] = 1
    local_indices[1, 5] = 1
    sharded_indices = DTensor.from_local(local_indices, device_mesh, [Shard(0)])

    local_table = torch.randn(num_batches, 8, embed_dim, dtype=torch.float32)
    sharded_table = DTensor.from_local(local_table, device_mesh, [Shard(0)])

    batched = torch.vmap(F.embedding)(sharded_indices, sharded_table)

    # Reference: run embedding per batch element, then stack the results.
    per_batch = [
        F.embedding(sharded_indices[i], sharded_table[i])
        for i in range(num_batches)
    ]
    reference = torch.stack(per_batch)

    self.assertEqual(batched.to_local(), reference.to_local())
@with_comms
def test_auto_implicit_replication(self):
mesh = self.build_device_mesh()

View File

@ -1852,6 +1852,9 @@ class _SetExcludeDispatchKeyGuard:
def __enter__(self): ...
def __exit__(self, *exc_info: object) -> None: ...
# Thread-local DTensor implicit-replication flag (bound to
# at::get/set_dtensor_allow_implicit_replication in C++); consumed by
# torch.distributed's OpDispatcher._allow_implicit_replication property.
def _get_dtensor_allow_implicit_replication() -> _bool: ...
def _set_dtensor_allow_implicit_replication(value: _bool) -> None: ...
# Defined in torch/csrc/utils/schema_info.h
class _SchemaInfo:

View File

@ -2,6 +2,7 @@
#include <torch/csrc/utils/python_dispatch.h>
#include <ATen/ATen.h>
#include <ATen/DTensorState.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
@ -1045,6 +1046,13 @@ void initDispatchBindings(PyObject* module) {
m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);
m.def(
"_get_dtensor_allow_implicit_replication",
&at::get_dtensor_allow_implicit_replication);
m.def(
"_set_dtensor_allow_implicit_replication",
&at::set_dtensor_allow_implicit_replication);
using c10::impl::TorchDispatchModeKey;
py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
.value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)

View File

@ -121,11 +121,17 @@ class OpDispatcher:
aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
}
# This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)
# as implicitly replicated or we throw error to user.
# NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave
# it as False by default.
self._allow_implicit_replication = False
# This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)
# as implicitly replicated or we throw error to user.
# NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave
# it as False by default.
@property
def _allow_implicit_replication(self) -> bool:
    """Whether plain (non-DTensor) Tensors are treated as implicitly replicated.

    Backed by a C++ thread-local (via torch._C) so that C++ code — e.g.
    functorch batching rules — observes and toggles the same flag.
    """
    return torch._C._get_dtensor_allow_implicit_replication()

@_allow_implicit_replication.setter
def _allow_implicit_replication(self, value: bool) -> None:
    # Property setters' return values are discarded by the descriptor
    # protocol, so do not propagate the None returned by the C binding.
    torch._C._set_dtensor_allow_implicit_replication(value)
def dispatch(
self,