[functorch] misc cleanup

Richard Zou
2021-04-27 06:49:33 -07:00
committed by Jon Janzen
parent ce453d449e
commit b7096ab83a
7 changed files with 118 additions and 46 deletions

View File

@ -16,6 +16,10 @@ int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_
optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val);
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
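// VMAP_SUPPORT(op, batch_rule) registers `batch_rule` as the kernel for `op`
// on functorch's batching key. Roughly speaking, PrimBatchRule7 is the
// plumbing that adapts the batch rule (which takes and returns extra
// optional<int64_t> bdim values) back to the operator's original schema;
// see Note [Adding vmap support for an operator].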
#define VMAP_SUPPORT(op, batch_rule) \
m.impl(op, PrimBatchRule7< \
decltype(&batch_rule), &batch_rule, to_operator_t<decltype(batch_rule)> \
>::apply);
}}

View File

@ -2,6 +2,76 @@
namespace at { namespace functorch {
// Note [Adding vmap support for an operator]
// Hey there! So you have an operator and you want to get it to work with vmap.
// For example, let's say you just invented the `sum.int` operator and want to make
// it so that the following works.
// >>> tensor = torch.randn(B, 3)
// >>> vmap(torch.sum, (0, None))(tensor, 0)
// There are three main ways to do so.
//
// Note [Writing batch rule for out-of-place operators]
// If your operator is out-of-place, you can write a batch rule for it.
// The batch rule defines how to perform the operator on inputs where each
// Tensor input may have an additional dimension that is being vmapped over.
// We refer to this dimension as the *batch dimension* or bdim for short.
//
// For example, let's consider writing a batch rule for
// `Tensor sum(const Tensor& self, int64_t dim)`. The batch rule's signature
// adds an optional<int64_t> batch-dim argument after each Tensor argument,
// and each Tensor return is paired with an optional<int64_t> as well. So, in
// this case, the batch rule has signature
//   tuple<Tensor, optional<int64_t>> sum_batch_rule(
//       const Tensor& self, optional<int64_t> self_bdim, int64_t dim);
//
// The vmap call above invokes the batch rule with `self = tensor`,
// `self_bdim = 0`, and `dim = 0`. Note that there are **no BatchedTensors**
// involved in this case; there exists some plumbing that automatically unwraps
// BatchedTensors before calling the batch rule.
//
// To write the logic of the batch rule, think about the semantics of the
// `sum` operation if `self` had an additional dimension (indicated by self_bdim):
// - If `self_bdim` is null, then we just do `result = self.sum(dim)` as usual.
// - If `self_bdim` is non-null, then we need to modify `dim`. `dim` is equal
//   to whatever the user passed in (0 in this case), but we should actually
//   perform the reduction over dimension 1 and do `result = self.sum(1)`,
//   because dim 0 is being vmapped over.
// Finally, we return the result as well as a new bdim:
// - If `self_bdim` is null, then there's no batch dim in the result.
// - If `self_bdim` is non-null, then we also return where the bdim ended up.
//   Since we invoked `result = self.sum(1)`, the bdim is still at dim 0.
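//
// Putting that together, a minimal sketch of such a batch rule (just an
// illustration of the recipe above, not necessarily the exact code in this
// repo; it assumes `dim` has already been wrapped to a non-negative logical
// dimension) might look like:
//   std::tuple<Tensor, optional<int64_t>> sum_batch_rule(
//       const Tensor& self, optional<int64_t> self_bdim, int64_t dim) {
//     if (!self_bdim.has_value()) {
//       // No vmapped dimension: behave exactly like the unbatched op.
//       return { self.sum(dim), nullopt };
//     }
//     // Move the vmapped dimension to the front, then shift the logical
//     // `dim` past it. The reduction removes physical dim `dim + 1`, so the
//     // batch dim stays at position 0 of the result.
//     auto self_ = self.movedim(*self_bdim, 0);
//     return { self_.sum(dim + 1), 0 };
//   }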
//
// Now that we have written `sum_batch_rule`, we have to register it inside a
// TORCH_LIBRARY_IMPL block:
//   TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
//     ...
//     VMAP_SUPPORT("sum.int", sum_batch_rule);
//     ...
//   }
//
// Note [Reusing batch rules to add vmap support for a complicated operator]
// Can't figure out how to write a batch rule for a big operation? If the
// operation can be expressed as a composition of other operations that do have
// batch rules, then that is another way to add vmap support. For example,
// consider the following schema
// func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1)
// and assume we already have batching rules for basic arithmetic operators.
//
// To add vmap support, define a decomposition using the same signature:
//   Tensor addcmul_decomp(const Tensor& self, const Tensor& tensor1,
//                         const Tensor& tensor2, const Scalar& value) {
//     auto product = at::mul(tensor1, tensor2);
//     return at::add(self, product, value);
//   }
// And register it inside a TORCH_LIBRARY_IMPL block:
//   TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
//     ...
//     m.impl("addcmul", addcmul_decomp);
//     ...
//   }
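//
// After this registration, a call along the lines of the following (B is the
// vmapped batch size, as in the example at the top of this note) works,
// because each op used in the decomposition (`mul` and `add`) hits its own
// batch rule:
// >>> x, t1, t2 = torch.randn(B, 3), torch.randn(B, 3), torch.randn(B, 3)
// >>> vmap(torch.addcmul)(x, t1, t2)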
//
// Note [Writing batch rule for in-place operators]
// TODO: This is kinda complicated. Saving this for a future date.
std::tuple<Tensor, optional<int64_t>> flatten_batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
@ -22,16 +92,9 @@ std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
return { self.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
}
TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
#define VMAP_SUPPORT(op, batch_rule) \
m.impl(op, PrimBatchRule7< \
decltype(&batch_rule), &batch_rule, to_operator_t<decltype(batch_rule)> \
>::apply);
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("flatten.using_ints", flatten_batch_rule);
VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
#undef VMAP_SUPPORT
}
}}

View File

@ -258,17 +258,17 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
auto size_physical = self_physical.getPhysicalShape(size);
auto self_physical_dim = self_physical.tensor().dim();
TORCH_CHECK(self_physical_dim <= size_physical.size(),
TORCH_CHECK((uint64_t)self_physical_dim <= size_physical.size(),
"expand: the number of sizes provided (", /*logical*/size.size(), ") ",
"must be greater or equal to the number of dimensions in the tensor (",
/*logical dim*/self.dim(), ")");
if (self_physical_dim == size_physical.size()) {
if ((uint64_t)self_physical_dim == size_physical.size()) {
auto result = self_physical.tensor().expand(size_physical, implicit);
return self_physical.getPhysicalToLogicalMap().apply(result);
}
TORCH_INTERNAL_ASSERT(self_physical_dim < size_physical.size());
TORCH_INTERNAL_ASSERT((uint64_t)self_physical_dim < size_physical.size());
// Here, we know we are expanding a (logical) tensor to a larger number
// of dimensions. We have to be careful because we can't call expand directly
// due to the presence of batch dimensions.
@ -363,7 +363,7 @@ std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntAr
// Checks if the batch dims in `bdims` appear at the front of the tensor.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
for (int64_t idx = 0; idx < bdims.size(); idx++) {
for (uint64_t idx = 0; idx < bdims.size(); idx++) {
if (bdims[idx].dim() != idx) {
return false;
}
@ -1380,7 +1380,7 @@ Tensor& BatchedTensor_requires_grad_(Tensor& self, bool requires_grad) {
}
TORCH_LIBRARY_IMPL(_, BatchedOutOfTree, m) {
TORCH_LIBRARY_IMPL(_, FT_BATCHED_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}
@ -1559,13 +1559,7 @@ Tensor matmul_decomposed(
dim_tensor1, "D and ", dim_tensor2, "D");
}
TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
#define VMAP_SUPPORT(op, batch_rule) \
m.impl(op, PrimBatchRule7< \
decltype(&batch_rule), &batch_rule, to_operator_t<decltype(batch_rule)> \
>::apply);
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
// VMAP_OUTPLACE_OP("abs", abs_batch_rule);
// m.impl("abs", PrimBatchRule7<decltype(&abs_batch_rule), &abs_batch_rule, to_operator_t<decltype(abs_batch_rule)>>::apply);
@ -1688,10 +1682,10 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
m.impl("clone", clone_batching_rule);
// m.impl("ones_like", ones_like_batching_rule);
using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, Scalar);
// using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, Scalar);
using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
using TensorScalarType = Tensor (*)(const Tensor&, Scalar);
// using TensorScalarType = Tensor (*)(const Tensor&, Scalar);
//
// #define BINARY_POINTWISE(op) \
// m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \
// m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, Scalar>);
@ -1783,7 +1777,6 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
// // COMPARISON_POINTWISE(ne);
// //
// #undef COMPARISON_POINTWISE
#undef VMAP_SUPPORT
}
}

View File

@ -4,6 +4,18 @@
namespace at {
namespace functorch {
constexpr auto kBatchedKey = c10::DispatchKey::BatchedOutOfTree;
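// The FT_* macros name functorch's dispatch keys as bare identifiers so they
// can be spelled directly inside TORCH_LIBRARY_IMPL(...) blocks (for example,
// TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m)); the constexpr aliases below
// are for ordinary code (for example, exclude.remove(kBatchedKey)).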
#define FT_BATCHED_KEY FuncTorchBatched
#define FT_VMAP_MODE_KEY FuncTorchVmapMode
#define FT_GRAD_WRAPPER_KEY FuncTorchGradWrapper
#define FT_DYNAMIC_LAYER_FRONT_MODE_KEY FuncTorchDynamicLayerFrontMode
#define FT_DYNAMIC_LAYER_BACK_MODE_KEY FuncTorchDynamicLayerBackMode
#define FT_PYTHON_KEY FuncTorchPython
constexpr auto kBatchedKey = c10::DispatchKey::FT_BATCHED_KEY;
constexpr auto kVmapModeKey = c10::DispatchKey::FT_VMAP_MODE_KEY;
constexpr auto kGradWrapperKey = c10::DispatchKey::FT_GRAD_WRAPPER_KEY;
constexpr auto kDynamicLayerFrontModeKey = c10::DispatchKey::FT_DYNAMIC_LAYER_FRONT_MODE_KEY;
constexpr auto kDynamicLayerBackModeKey = c10::DispatchKey::FT_DYNAMIC_LAYER_BACK_MODE_KEY;
constexpr auto kPythonKey = c10::DispatchKey::FT_PYTHON_KEY;
}} // namespace at::functorch

View File

@ -84,8 +84,8 @@ static DynamicLayer popDynamicLayer() {
if (dynamicLayerStack.size() == 0) {
// std::cout << "DynamicLayer off" << std::endl;
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, false);
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, false);
c10::impl::tls_set_dispatch_key_included(kDynamicLayerFrontModeKey, false);
c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, false);
}
return result;
@ -100,8 +100,8 @@ static int64_t pushDynamicLayer(DispatchKey key) {
if (layerId == 2) {
// std::cout << "DynamicLayer on" << std::endl;
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, true);
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, true);
c10::impl::tls_set_dispatch_key_included(kDynamicLayerFrontModeKey, true);
c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, true);
}
return layerId;
@ -202,11 +202,11 @@ static void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64
}
constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
DispatchKey::DynamicLayerFront,
DispatchKey::DynamicLayerBack,
DispatchKey::TensorWrapper,
kDynamicLayerFrontModeKey,
kDynamicLayerBackModeKey,
kGradWrapperKey,
// DispatchKey::Batched,
DispatchKey::BatchedOutOfTree,
kBatchedKey,
DispatchKey::InplaceOrView
}) | autograd_dispatch_keyset;
@ -245,14 +245,14 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
auto layer = dynamicLayerStack.back();
DispatchKeySet exclude = all_dynlayer_keyset;
exclude = exclude.remove(DispatchKey::DynamicLayerBack);
exclude = exclude.remove(kDynamicLayerBackModeKey);
if (layer.key() == DispatchKey::Autograd) {
exclude = exclude - autograd_dispatch_keyset;
exclude = exclude.remove(DispatchKey::InplaceOrView);
// } else if (layer.key() == DispatchKey::Batched) {
// exclude = exclude.remove(DispatchKey::Batched);
} else if (layer.key() == DispatchKey::BatchedOutOfTree) {
exclude = exclude.remove(DispatchKey::BatchedOutOfTree);
} else if (layer.key() == kBatchedKey) {
exclude = exclude.remove(kBatchedKey);
} else {
TORCH_INTERNAL_ASSERT(false);
}
@ -348,8 +348,8 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
SaveLocalDispatchKeySet save_guard;
auto keyset = c10::impl::PODLocalDispatchKeySet();
c10::impl::_force_tls_local_dispatch_key_set(keyset);
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerFront, true);
c10::impl::tls_set_dispatch_key_included(DispatchKey::DynamicLayerBack, true);
c10::impl::tls_set_dispatch_key_included(kDynamicLayerFrontModeKey, true);
c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, true);
// Re-dispatch
op.callBoxed(stack);
@ -380,11 +380,11 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
}
}
TORCH_LIBRARY_IMPL(_, DynamicLayerFront, m) {
TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>());
}
TORCH_LIBRARY_IMPL(_, DynamicLayerBack, m) {
TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}

View File

@ -64,7 +64,7 @@ c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int
key_set = key_set.add(DispatchKey::CPU);
key_set = key_set.add(DispatchKey::AutogradCPU);
}
key_set = key_set.add(DispatchKey::TensorWrapper);
key_set = key_set.add(kGradWrapperKey);
if (should_be_alive) {
return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, getLifeHandleForLevel(level));
} else {
@ -87,10 +87,10 @@ Tensor makeTensorWrapper(const Tensor& tensor, int64_t level) {
key_set = key_set.add(DispatchKey::CPU);
key_set = key_set.add(DispatchKey::AutogradCPU);
}
key_set = key_set.add(DispatchKey::TensorWrapper);
key_set = key_set.add(kGradWrapperKey);
auto life_handle = getLifeHandleForLevel(level);
auto result = at::detail::make_tensor<TensorWrapper>(key_set, tensor, level, std::move(life_handle));
TORCH_INTERNAL_ASSERT(result.key_set().has(DispatchKey::TensorWrapper));
TORCH_INTERNAL_ASSERT(result.key_set().has(kGradWrapperKey));
return result;
}
@ -161,7 +161,7 @@ const char* TensorWrapper::tensorimpl_type_name() const {
TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor) {
if (!tensor.key_set().has(DispatchKey::TensorWrapper)) {
if (!tensor.key_set().has(kGradWrapperKey)) {
return nullptr;
}
return (TensorWrapper*)(tensor.unsafeGetTensorImpl());
@ -224,7 +224,7 @@ void dead_tensor_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Sta
// TensorWrapper backend fallback: Unwrap and fallthrough.
TORCH_LIBRARY_IMPL(_, TensorWrapper, m) {
TORCH_LIBRARY_IMPL(_, FT_GRAD_WRAPPER_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dead_tensor_wrapper_fallback>());
}

View File

@ -48,7 +48,7 @@ TORCH_LIBRARY(functorch, m) {
m.def("new_empty_hack", new_empty_hack_impl);
}
TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
m.impl("new_zeros", new_zeros_hack);
m.impl("new_empty", new_empty_hack);
}