[functorch] merged

Horace He
2021-04-30 13:23:53 -07:00
committed by Jon Janzen
parent 98df806b95
commit 9d36895a83
3 changed files with 4 additions and 3 deletions

View File

@@ -4,6 +4,8 @@ from . import _C
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers
from ._src.python_key import wrap_key, WrapModule
# Monkeypatching lol
_old_cross_entropy = torch.nn.functional.cross_entropy
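
For context, a minimal sketch of how the transforms re-exported by this module are typically composed; the example below is illustrative only and not part of this commit.

import torch
from functorch import grad, vmap

def loss(weights, example):
    # toy linear model; `weights` and `example` are placeholder names
    return (example @ weights).sum()

weights = torch.randn(3)
examples = torch.randn(8, 3)
# per-example gradients: differentiate w.r.t. weights, map over examples
per_example_grads = vmap(grad(loss), in_dims=(None, 0))(weights, examples)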

View File

@@ -26,6 +26,5 @@ std::tuple<Tensor,optional<int64_t>> basic_unary_batch_rule(
const Tensor& tensor, optional<int64_t> batch_dim, ExtraArgs... extra_args) {
return {Func(tensor, std::forward<ExtraArgs>(extra_args)...), batch_dim};
}
}}
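
For reference, a hedged sketch of what a concrete batch rule following the template above might look like; `abs_batch_rule` is a hypothetical example and not code from this commit.

#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <tuple>

// A unary op does not touch the batch dimension, so the rule simply applies
// the op and passes the batch dim through unchanged.
std::tuple<at::Tensor, c10::optional<int64_t>> abs_batch_rule(
    const at::Tensor& tensor, c10::optional<int64_t> batch_dim) {
  return {tensor.abs(), batch_dim};
}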

View File

@@ -212,7 +212,7 @@ constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
static void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto num_args = op.schema().arguments().size();
foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
[](const Tensor& tensor) {
auto* wrapper = maybeGetTensorWrapper(tensor);
TORCH_INTERNAL_ASSERT(wrapper == nullptr);
@@ -366,7 +366,7 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
for (int64_t arg_idx = 0; arg_idx < args_size; arg_idx++) {
auto& ivalue = (*stack)[args_front + arg_idx];
if (!ivalue.isTensor()) {
continue;
}
auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
if (!maybe_tensor_wrapper) {
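
A hedged sketch of the argument-walking pattern used in the fallback loop above, generalized into a standalone helper; `for_each_tensor_arg` and `handle_tensor` are hypothetical names, not part of this commit.

#include <ATen/ATen.h>
#include <ATen/core/stack.h>
#include <functional>

// Visit only the Tensor-typed IValues in the argument region of the stack;
// non-tensor arguments (ints, scalars, lists, ...) are skipped.
static void for_each_tensor_arg(
    torch::jit::Stack* stack, int64_t num_args,
    const std::function<void(const at::Tensor&)>& handle_tensor) {
  const auto args_front = stack->size() - num_args;
  for (int64_t arg_idx = 0; arg_idx < num_args; arg_idx++) {
    const auto& ivalue = (*stack)[args_front + arg_idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    handle_tensor(ivalue.toTensor());
  }
}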