Making batching rule for F.embedding DTensor-aware (#162117)
`vmap(F.embedding)(DTensor, DTensor)` was failing because F.embedding's batching rule generates a new tensor via at::arange, at::arange produces a regular Tensor, and DTensor rightfully errors on mixed DTensor-regular-Tensor operations. This PR fixes the problem by activating DTensor implicit replication around just the at::arange and the subsequent add operation. To accomplish this, I moved the DTensor implicit-replication flag to C++ (most batching rules are in C++).

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162117
Approved by: https://github.com/bdhirsh
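For context, here is a minimal sketch of the previously failing call. It mirrors the new test added below; the 2-rank CPU mesh and shapes are illustrative assumptions, and it requires an initialized process group.

    import torch
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Shard

    # Illustrative setup: assumes torch.distributed is initialized with 2 ranks.
    mesh = init_device_mesh("cpu", (2,))
    indices = DTensor.from_local(torch.zeros(2, 6, dtype=torch.int64), mesh, [Shard(0)])
    weight = DTensor.from_local(torch.randn(2, 8, 32), mesh, [Shard(0)])

    # Before this PR, the batching rule's internal at::arange produced a plain
    # Tensor, and mixing it with the DTensor indices raised an error under vmap.
    out = torch.vmap(F.embedding)(indices, weight)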
aten/src/ATen/DTensorState.cpp (new file, +17)
@@ -0,0 +1,17 @@
+#include <ATen/DTensorState.h>
+
+namespace at {
+
+namespace {
+thread_local bool kDTensorAllowImplicitReplication = false;
+}
+
+bool get_dtensor_allow_implicit_replication() {
+  return kDTensorAllowImplicitReplication;
+}
+
+void set_dtensor_allow_implicit_replication(bool enabled) {
+  kDTensorAllowImplicitReplication = enabled;
+}
+
+} // namespace at
aten/src/ATen/DTensorState.h (new file, +34)
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+
+namespace at {
+
+TORCH_API bool get_dtensor_allow_implicit_replication();
+TORCH_API void set_dtensor_allow_implicit_replication(bool enabled);
+
+struct DTensorAllowImplicitReplication {
+  DTensorAllowImplicitReplication()
+      : prev_dtensor_allow_implicit_replication_(
+            get_dtensor_allow_implicit_replication()) {
+    set_dtensor_allow_implicit_replication(true);
+  }
+
+  DTensorAllowImplicitReplication(const DTensorAllowImplicitReplication&) =
+      delete;
+  DTensorAllowImplicitReplication& operator=(
+      const DTensorAllowImplicitReplication&) = delete;
+  DTensorAllowImplicitReplication(DTensorAllowImplicitReplication&&) = delete;
+  DTensorAllowImplicitReplication& operator=(
+      DTensorAllowImplicitReplication&&) = delete;
+
+  ~DTensorAllowImplicitReplication() {
+    set_dtensor_allow_implicit_replication(
+        prev_dtensor_allow_implicit_replication_);
+  }
+
+ private:
+  bool prev_dtensor_allow_implicit_replication_;
+};
+
+} // namespace at
aten/src/ATen/ThreadLocalState.cpp
@@ -8,6 +8,7 @@
 #include <ATen/record_function.h>
 #include <ATen/SavedTensorHooks.h>
 #include <ATen/FunctionalTensorWrapper.h>
+#include <ATen/DTensorState.h>
 
 namespace at {
 
@@ -19,6 +20,7 @@ ThreadLocalState::ThreadLocalState()
       torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
       python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
       saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
+      dtensor_allow_implicit_replication_(at::get_dtensor_allow_implicit_replication()),
       saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {
 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
   for(size_t i=0; i<autocast_dtypes_.size(); i++) {
@@ -52,6 +54,8 @@ void ThreadLocalState::setThreadLocalState(
 
   c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
 
+  at::set_dtensor_allow_implicit_replication(state.dtensor_allow_implicit_replication_);
+
   c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
 
   c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
aten/src/ATen/ThreadLocalState.h
@@ -75,6 +75,8 @@ class TORCH_API ThreadLocalState {
 
   bool functionalization_reapply_views_state_;
 
+  bool dtensor_allow_implicit_replication_;
+
   // TLS for arbitrary python objects that is registered via hooks
   at::impl::ThreadLocalPythonObjects saved_objects_;
 
aten/src/ATen/functorch/BatchRulesModules.cpp
@@ -7,6 +7,7 @@
 #include <ATen/functorch/BatchRulesHelper.h>
 #include <ATen/functorch/PlumbingHelper.h>
 #include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/DTensorState.h>
 
 #include <utility>
 
@@ -44,8 +45,13 @@ static std::tuple<Tensor, std::optional<int64_t>> embedding_batch_rule(
   const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
   auto indices_ = moveBatchDimToFront(indices, indices_bdim);
 
-  const auto range = getStepTensor(indices, batch_size, num_embeddings);
-  indices_ = indices_ + range;
+  {
+    // getStepTensor returns a regular Tensor. If indices_ is a DTensor
+    // we want to allow this mixed DTensor-Tensor operation.
+    at::DTensorAllowImplicitReplication guard;
+    const auto range = getStepTensor(indices, batch_size, num_embeddings);
+    indices_ = indices_ + range;
+  }
   auto result = at::embedding_symint(weight_, indices_, std::move(padding_idx), scale_grad_by_freq, sparse);
   return std::make_tuple(std::move(result), 0);
 }
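For intuition: the batch rule flattens the batched weight into one big embedding table and shifts each batch's indices into its own slice, and getStepTensor supplies that shift via the at::arange mentioned in the PR description. A rough Python sketch of the arithmetic follows (a hypothetical re-derivation; getStepTensor's body is not part of this diff):

    import torch

    batch_size, num_embeddings, dim, seq_len = 2, 8, 4, 6
    weight = torch.randn(batch_size, num_embeddings, dim)
    indices = torch.randint(num_embeddings, (batch_size, seq_len))

    # Flatten the batch of tables into one table ...
    weight_ = weight.reshape(batch_size * num_embeddings, dim)
    # ... and shift batch i's indices into rows [i*num_embeddings, (i+1)*num_embeddings).
    range_ = torch.arange(batch_size).view(-1, 1) * num_embeddings
    shifted = indices + range_

    out = torch.nn.functional.embedding(shifted, weight_)
    expected = torch.stack([torch.nn.functional.embedding(indices[i], weight[i])
                            for i in range(batch_size)])
    assert torch.allclose(out, expected)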
build_variables.bzl
@@ -1088,6 +1088,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/DeviceAccelerator.cpp",
     "aten/src/ATen/Context.cpp",
     "aten/src/ATen/DLConvertor.cpp",
+    "aten/src/ATen/DTensorState.cpp",
     "aten/src/ATen/EmptyTensor.cpp",
     "aten/src/ATen/ExpandUtils.cpp",
     "aten/src/ATen/CachedTensorUtils.cpp",
test/distributed/tensor/test_dtensor.py
@@ -848,6 +848,30 @@ class DTensorMeshTest(DTensorTestBase):
         self.assertEqual(local_shard.shape, (4, 3))
         self.assertEqual(local_shard, torch.ones(4, 3) + torch.ones(3))
 
+    @with_comms
+    def test_vmap_embedding(self):
+        mesh = self.build_device_mesh()
+        batch_size, seq_len = 2, 6
+        output_dim = 32
+
+        indices = torch.zeros(*(batch_size, seq_len), dtype=torch.int64)
+        indices[0, 1] = 1
+        indices[1, 3] = 1
+        indices[1, 5] = 1
+        indices = DTensor.from_local(indices, mesh, [Shard(0)])
+
+        emb = torch.randn(
+            *(batch_size, 8, output_dim),
+            dtype=torch.float32,
+        )
+        emb = DTensor.from_local(emb, mesh, [Shard(0)])
+        result = torch.vmap(F.embedding)(indices, emb)
+        expected = [F.embedding(indices[i], emb[i]) for i in range(batch_size)]
+        expected = torch.stack(expected)
+        local_result = result.to_local()
+        local_expected = expected.to_local()
+        self.assertEqual(local_result, local_expected)
+
     @with_comms
     def test_auto_implicit_replication(self):
         mesh = self.build_device_mesh()
torch/_C/__init__.pyi
@@ -1852,6 +1852,9 @@ class _SetExcludeDispatchKeyGuard:
     def __enter__(self): ...
     def __exit__(self, *exc_info: object) -> None: ...
 
+def _get_dtensor_allow_implicit_replication() -> _bool: ...
+def _set_dtensor_allow_implicit_replication(value: _bool) -> None: ...
+
 # Defined in torch/csrc/utils/schema_info.h
 
 class _SchemaInfo:
torch/csrc/utils/python_dispatch.cpp
@@ -2,6 +2,7 @@
 #include <torch/csrc/utils/python_dispatch.h>
 
 #include <ATen/ATen.h>
+#include <ATen/DTensorState.h>
 #include <ATen/FuncTorchTLS.h>
 #include <ATen/FunctionalTensorWrapper.h>
 #include <ATen/TensorSubclassLikeUtils.h>
@@ -1045,6 +1046,13 @@ void initDispatchBindings(PyObject* module) {
   m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
   m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);
 
+  m.def(
+      "_get_dtensor_allow_implicit_replication",
+      &at::get_dtensor_allow_implicit_replication);
+  m.def(
+      "_set_dtensor_allow_implicit_replication",
+      &at::set_dtensor_allow_implicit_replication);
+
   using c10::impl::TorchDispatchModeKey;
   py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
       .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
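From Python, these private bindings allow a save/restore pattern that mirrors the C++ DTensorAllowImplicitReplication RAII guard. A minimal sketch using exactly the torch._C functions registered above (private API, for illustration only):

    import torch

    # Save, set, and restore the thread-local flag, like the C++ guard does.
    prev = torch._C._get_dtensor_allow_implicit_replication()
    torch._C._set_dtensor_allow_implicit_replication(True)
    try:
        ...  # ops that mix DTensor and regular Tensor inputs
    finally:
        torch._C._set_dtensor_allow_implicit_replication(prev)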
torch/distributed/tensor/_dispatch.py
@@ -125,7 +125,13 @@ class OpDispatcher:
         # as implicitly replicated or we throw error to user.
         # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave
         # it as False by default.
-        self._allow_implicit_replication = False
+    @property
+    def _allow_implicit_replication(self) -> bool:
+        return torch._C._get_dtensor_allow_implicit_replication()
+
+    @_allow_implicit_replication.setter
+    def _allow_implicit_replication(self, value: bool) -> None:
+        return torch._C._set_dtensor_allow_implicit_replication(value)
 
     def dispatch(
         self,
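Because the property now delegates to the C++ thread-local flag, existing Python code that sets OpDispatcher._allow_implicit_replication keeps working, and C++ batching rules observe the same value. A quick sketch (DTensor._op_dispatcher is a private implementation detail, assumed here for illustration):

    import torch
    from torch.distributed.tensor import DTensor

    # Assumption: DTensor holds a singleton OpDispatcher on the private
    # attribute _op_dispatcher. Writing the Python property updates the
    # C++ thread-local flag, so both sides see the same value.
    dispatcher = DTensor._op_dispatcher
    dispatcher._allow_implicit_replication = True
    assert torch._C._get_dtensor_allow_implicit_replication()
    dispatcher._allow_implicit_replication = False
    assert not torch._C._get_dtensor_allow_implicit_replication()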