[functorch] misc cleanup
@@ -16,6 +16,10 @@ int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_
optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val);
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);

#define VMAP_SUPPORT(op, batch_rule) \
  m.impl(op, PrimBatchRule7< \
      decltype(&batch_rule), &batch_rule, to_operator_t<decltype(batch_rule)> \
  >::apply);

}}
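The VMAP_SUPPORT macro is only a registration shorthand. As a sketch of what it does, a call such as VMAP_SUPPORT("flatten.using_ints", flatten_batch_rule) (which appears in a later hunk) expands to roughly:

  m.impl("flatten.using_ints", PrimBatchRule7<
      decltype(&flatten_batch_rule), &flatten_batch_rule, to_operator_t<decltype(flatten_batch_rule)>
  >::apply);

PrimBatchRule7 and to_operator_t are the existing helpers named by the macro; nothing new is introduced by the expansion.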
@@ -2,6 +2,76 @@

namespace at { namespace functorch {

// Note [Adding vmap support for an operator]
// Hey there! So you have an operator and you want to get it to work with vmap.
// For example, let's say you just invented the `sum.int` operator and want to make
// it so that the following works.
// >>> tensor = torch.randn(B, 3)
// >>> vmap(torch.sum, (0, None))(tensor, 0)
// There are three main ways to do so.
//
// Note [Writing batch rule for out-of-place operators]
// If your operator is out-of-place, you can write a batch rule for it.
// The batch rule defines how to perform the operator on inputs where each
// Tensor input may have an additional dimension that is being vmapped over.
// We refer to this dimension as the *batch dimension*, or bdim for short.
//
// For example, let's consider writing a batch rule for
// `Tensor sum(const Tensor& self, int64_t dim)`. The signature of the
// batch rule has an additional optional<int64_t> argument after each
// Tensor argument and return. So, in this case, the batch rule has signature
//   tuple<Tensor, optional<int64_t>> sum_batch_rule(
//       const Tensor& self, optional<int64_t> self_bdim, int64_t dim);
//
// The vmap call above invokes the batch rule with `self = tensor`,
// `self_bdim = 0`, and `dim = 0`. Note that there are **no BatchedTensors**
// involved in this case; there exists some plumbing that automatically unwraps
// BatchedTensors before calling the batch rule.
//
// To write the logic of the batch rule: think about the semantics of the
// `sum` operation if `self` had an additional dimension (indicated by self_bdim):
// - If `self_bdim` is null, then we just do `result = self.sum(dim)` as usual.
// - If `self_bdim` is not-null, then we need to modify `dim`. `dim` is equal
//   to whatever the user passed in (0 in this case), but we should actually
//   perform the reduction over dimension 1 and do `result = self.sum(1)`
//   because dim 0 is being vmapped over.
// Finally, we return the result as well as a new bdim:
// - If `self_bdim` is null, then there's no batch dim in the result.
// - If `self_bdim` is not-null, then we return where the bdim is in the result.
//   Since we invoked `result = self.sum(1)`, the bdim is still at dim 0.
//
// Now that we have written `sum_batch_rule`, we have to register it inside a
// TORCH_LIBRARY_IMPL block:
//   TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
//     ...
//     VMAP_SUPPORT("sum.int", sum_batch_rule);
//     ...
//   }
//
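// As a concrete reference, here is a minimal sketch of the rule described
// above. It is illustrative only: it assumes a non-negative `dim`, uses
// Tensor::movedim to put the batch dim in front, and the batch rules in this
// codebase may be written with different helpers.
//   std::tuple<Tensor, optional<int64_t>> sum_batch_rule(
//       const Tensor& self, optional<int64_t> self_bdim, int64_t dim) {
//     if (!self_bdim.has_value()) {
//       // no vmapped dim on this input: behave like the regular op and
//       // report that the result carries no batch dim
//       return { self.sum(dim), nullopt };
//     }
//     // move the batch dim to the front so the remaining dims line up with
//     // the logical view the user sees, then shift the reduction dim past it
//     auto self_ = self.movedim(*self_bdim, 0);
//     auto result = self_.sum(dim + 1);
//     return { result, 0 };  // the batch dim stays at the front of the result
//   }
//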
// Note [Reusing batch rules to add vmap support for a complicated operator]
// Can't figure out how to write a batch rule for a big operation? If the
// operation can be expressed as a composition of other operations that do have
// batch rules, then that is another way to add vmap support. For example,
// consider the following schema:
//   func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1)
// and assume we already have batching rules for basic arithmetic operators.
//
// To add vmap support, define a decomposition using the same signature:
//   Tensor addcmul_decomp(const Tensor& self, const Tensor& tensor1,
//                         const Tensor& tensor2, const Scalar& value) {
//     auto product = at::mul(tensor1, tensor2);
//     return at::add(self, product, value);
//   }
// and register it inside a TORCH_LIBRARY_IMPL block:
//   TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
//     ...
//     m.impl("addcmul", addcmul_decomp);
//     ...
//   }
//
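// (Why this works: the decomposition is what runs when vmap reaches addcmul,
// and the at::mul / at::add calls inside it go back through the dispatcher,
// where they are handled by the arithmetic batch rules assumed above, so no
// batch rule specific to addcmul is ever needed.)
//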
// Note [Writing batch rule for in-place operators]
// TODO: This is kinda complicated. Saving this for a future date.

std::tuple<Tensor, optional<int64_t>> flatten_batch_rule(
    const Tensor& self,
    optional<int64_t> self_bdim,
@@ -22,16 +92,9 @@ std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
  return { self.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
}

TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
#define VMAP_SUPPORT(op, batch_rule) \
  m.impl(op, PrimBatchRule7< \
      decltype(&batch_rule), &batch_rule, to_operator_t<decltype(batch_rule)> \
  >::apply);

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT("flatten.using_ints", flatten_batch_rule);
  VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);

#undef VMAP_SUPPORT
}

}}
@@ -258,17 +258,17 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
  auto size_physical = self_physical.getPhysicalShape(size);
  auto self_physical_dim = self_physical.tensor().dim();

  TORCH_CHECK(self_physical_dim <= size_physical.size(),
  TORCH_CHECK((uint64_t)self_physical_dim <= size_physical.size(),
      "expand: the number of sizes provided (", /*logical*/size.size(), ") ",
      "must be greater or equal to the number of dimensions in the tensor (",
      /*logical dim*/self.dim(), ")");

  if (self_physical_dim == size_physical.size()) {
  if ((uint64_t)self_physical_dim == size_physical.size()) {
    auto result = self_physical.tensor().expand(size_physical, implicit);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }

  TORCH_INTERNAL_ASSERT(self_physical_dim < size_physical.size());
  TORCH_INTERNAL_ASSERT((uint64_t)self_physical_dim < size_physical.size());
  // Here, we know we are expanding a (logical) tensor to a larger number
  // of dimensions. We have to be careful because we can't call expand directly
  // due to the presence of batch dimensions.
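  // (Concretely: a logical tensor of shape (3,) vmapped with batch size B has
  // physical shape (B, 3). Expanding the logical tensor to size (2, 3) must
  // produce physical shape (B, 2, 3), so the new dimension has to be inserted
  // after the batch dim rather than in front of it, which a plain expand of
  // the physical tensor cannot do.)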
@@ -363,7 +363,7 @@ std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntAr

// Checks if the batch dims in `bdims` appear at the front of the tensor.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
  for (int64_t idx = 0; idx < bdims.size(); idx++) {
  for (uint64_t idx = 0; idx < bdims.size(); idx++) {
    if (bdims[idx].dim() != idx) {
      return false;
    }
@@ -1380,7 +1380,7 @@ Tensor& BatchedTensor_requires_grad_(Tensor& self, bool requires_grad) {
}


TORCH_LIBRARY_IMPL(_, BatchedOutOfTree, m) {
TORCH_LIBRARY_IMPL(_, FT_BATCHED_KEY, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}
@@ -1559,13 +1559,7 @@ Tensor matmul_decomposed(
      dim_tensor1, "D and ", dim_tensor2, "D");
}

TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {

#define VMAP_SUPPORT(op, batch_rule) \
  m.impl(op, PrimBatchRule7< \
      decltype(&batch_rule), &batch_rule, to_operator_t<decltype(batch_rule)> \
  >::apply);

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  // VMAP_OUTPLACE_OP("abs", abs_batch_rule);
  // m.impl("abs", PrimBatchRule7<decltype(&abs_batch_rule), &abs_batch_rule, to_operator_t<decltype(abs_batch_rule)>>::apply);
@@ -1688,10 +1682,10 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
  m.impl("clone", clone_batching_rule);
  // m.impl("ones_like", ones_like_batching_rule);

  using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, Scalar);
  // using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, Scalar);
  using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
  using TensorScalarType = Tensor (*)(const Tensor&, Scalar);

  // using TensorScalarType = Tensor (*)(const Tensor&, Scalar);
  //
  // #define BINARY_POINTWISE(op) \
  //   m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \
  //   m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, Scalar>);
@@ -1783,7 +1777,6 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
  // // COMPARISON_POINTWISE(ne);
  // //
  // #undef COMPARISON_POINTWISE
#undef VMAP_SUPPORT
}

}
@@ -4,6 +4,18 @@
namespace at {
namespace functorch {

constexpr auto kBatchedKey = c10::DispatchKey::BatchedOutOfTree;
#define FT_BATCHED_KEY FuncTorchBatched
#define FT_VMAP_MODE_KEY FuncTorchVmapMode
#define FT_GRAD_WRAPPER_KEY FuncTorchGradWrapper
#define FT_DYNAMIC_LAYER_FRONT_MODE_KEY FuncTorchDynamicLayerFrontMode
#define FT_DYNAMIC_LAYER_BACK_MODE_KEY FuncTorchDynamicLayerBackMode
#define FT_PYTHON_KEY FuncTorchPython

constexpr auto kBatchedKey = c10::DispatchKey::FT_BATCHED_KEY;
constexpr auto kVmapModeKey = c10::DispatchKey::FT_VMAP_MODE_KEY;
constexpr auto kGradWrapperKey = c10::DispatchKey::FT_GRAD_WRAPPER_KEY;
constexpr auto kDynamicLayerFrontModeKey = c10::DispatchKey::FT_DYNAMIC_LAYER_FRONT_MODE_KEY;
constexpr auto kDynamicLayerBackModeKey = c10::DispatchKey::FT_DYNAMIC_LAYER_BACK_MODE_KEY;
constexpr auto kPythonKey = c10::DispatchKey::FT_PYTHON_KEY;

}} // namespace at::functorch
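A side note on the pattern above, inferred from how these names are used elsewhere in this commit rather than stated anywhere: the macros exist because TORCH_LIBRARY_IMPL pastes its key argument into a c10::DispatchKey:: enumerator name, so it needs a bare token such as FT_BATCHED_KEY, while ordinary expressions can use the typed constexpr constant. A sketch of the two forms side by side, both taken from other hunks in this diff:

  TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {      // token form, becomes c10::DispatchKey::FuncTorchBatched
    VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
  }
  exclude = exclude.remove(kBatchedKey);             // value form, an ordinary c10::DispatchKey constant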
@@ -84,8 +84,8 @@ static DynamicLayer popDynamicLayer() {

  if (dynamicLayerStack.size() == 0) {
    // std::cout << "DynamicLayer off" << std::endl;
    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, false);
    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, false);
    c10::impl::tls_set_dispatch_key_included(kDynamicLayerFrontModeKey, false);
    c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, false);
  }

  return result;
@@ -100,8 +100,8 @@ static int64_t pushDynamicLayer(DispatchKey key) {

  if (layerId == 2) {
    // std::cout << "DynamicLayer on" << std::endl;
    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, true);
    c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, true);
    c10::impl::tls_set_dispatch_key_included(kDynamicLayerFrontModeKey, true);
    c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, true);
  }

  return layerId;
@@ -202,11 +202,11 @@ static void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64
}

constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
  DispatchKey::DynamicLayerFront,
  DispatchKey::DynamicLayerBack,
  DispatchKey::TensorWrapper,
  kDynamicLayerFrontModeKey,
  kDynamicLayerBackModeKey,
  kGradWrapperKey,
  // DispatchKey::Batched,
  DispatchKey::BatchedOutOfTree,
  kBatchedKey,
  DispatchKey::InplaceOrView
}) | autograd_dispatch_keyset;
@@ -245,14 +245,14 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
  auto layer = dynamicLayerStack.back();

  DispatchKeySet exclude = all_dynlayer_keyset;
  exclude = exclude.remove(DispatchKey::DynamicLayerBack);
  exclude = exclude.remove(kDynamicLayerBackModeKey);
  if (layer.key() == DispatchKey::Autograd) {
    exclude = exclude - autograd_dispatch_keyset;
    exclude = exclude.remove(DispatchKey::InplaceOrView);
  // } else if (layer.key() == DispatchKey::Batched) {
  //   exclude = exclude.remove(DispatchKey::Batched);
  } else if (layer.key() == DispatchKey::BatchedOutOfTree) {
    exclude = exclude.remove(DispatchKey::BatchedOutOfTree);
  } else if (layer.key() == kBatchedKey) {
    exclude = exclude.remove(kBatchedKey);
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
@@ -348,8 +348,8 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
  SaveLocalDispatchKeySet save_guard;
  auto keyset = c10::impl::PODLocalDispatchKeySet();
  c10::impl::_force_tls_local_dispatch_key_set(keyset);
  c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, true);
  c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, true);
  c10::impl::tls_set_dispatch_key_included(kDynamicLayerFrontModeKey, true);
  c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, true);

  // Re-dispatch
  op.callBoxed(stack);
@@ -380,11 +380,11 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
  }
}

TORCH_LIBRARY_IMPL(_, DynamicLayerFront, m) {
TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>());
}

TORCH_LIBRARY_IMPL(_, DynamicLayerBack, m) {
TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}
@@ -64,7 +64,7 @@ c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int
    key_set = key_set.add(DispatchKey::CPU);
    key_set = key_set.add(DispatchKey::AutogradCPU);
  }
  key_set = key_set.add(DispatchKey::TensorWrapper);
  key_set = key_set.add(kGradWrapperKey);
  if (should_be_alive) {
    return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, getLifeHandleForLevel(level));
  } else {
@@ -87,10 +87,10 @@ Tensor makeTensorWrapper(const Tensor& tensor, int64_t level) {
    key_set = key_set.add(DispatchKey::CPU);
    key_set = key_set.add(DispatchKey::AutogradCPU);
  }
  key_set = key_set.add(DispatchKey::TensorWrapper);
  key_set = key_set.add(kGradWrapperKey);
  auto life_handle = getLifeHandleForLevel(level);
  auto result = at::detail::make_tensor<TensorWrapper>(key_set, tensor, level, std::move(life_handle));
  TORCH_INTERNAL_ASSERT(result.key_set().has(DispatchKey::TensorWrapper));
  TORCH_INTERNAL_ASSERT(result.key_set().has(kGradWrapperKey));
  return result;
}
@@ -161,7 +161,7 @@ const char* TensorWrapper::tensorimpl_type_name() const {


TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor) {
  if (!tensor.key_set().has(DispatchKey::TensorWrapper)) {
  if (!tensor.key_set().has(kGradWrapperKey)) {
    return nullptr;
  }
  return (TensorWrapper*)(tensor.unsafeGetTensorImpl());
@@ -224,7 +224,7 @@ void dead_tensor_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Sta

// TensorWrapper backend fallback: Unwrap and fallthrough.

TORCH_LIBRARY_IMPL(_, TensorWrapper, m) {
TORCH_LIBRARY_IMPL(_, FT_GRAD_WRAPPER_KEY, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&dead_tensor_wrapper_fallback>());
}
@@ -48,7 +48,7 @@ TORCH_LIBRARY(functorch, m) {
  m.def("new_empty_hack", new_empty_hack_impl);
}

TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
  m.impl("new_zeros", new_zeros_hack);
  m.impl("new_empty", new_empty_hack);
}