[functorch] merged
@@ -4,6 +4,8 @@ from . import _C
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers
from ._src.python_key import wrap_key, WrapModule


# Monkeypatching lol
_old_cross_entropy = torch.nn.functional.cross_entropy
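Note: the re-exports above form functorch's public API. A minimal sketch of how the transforms compose (per-sample gradients by lifting grad over a batch with vmap; the loss function and shapes are illustrative, not from this commit):

import torch
from functorch import grad, vmap

def loss(weight, example):
    # Scalar loss for a single example; grad() differentiates it
    # with respect to its first argument (weight).
    return (example @ weight).sum()

weight = torch.randn(3)
examples = torch.randn(8, 3)
# vmap lifts the per-example gradient function over dim 0 of `examples`.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0))(weight, examples)
assert per_sample_grads.shape == (8, 3)

The monkeypatch at the end of the hunk follows the usual save-and-delegate pattern; the replacement body is not shown in this diff, so the wrapper below is a hypothetical stand-in:

import torch.nn.functional as F

_old_cross_entropy = F.cross_entropy

def _patched_cross_entropy(input, target, *args, **kwargs):
    # Hypothetical patch body: delegate to the saved original so the
    # override can be layered on later.
    return _old_cross_entropy(input, target, *args, **kwargs)

F.cross_entropy = _patched_cross_entropy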
@@ -26,6 +26,5 @@ std::tuple<Tensor,optional<int64_t>> basic_unary_batch_rule(
    const Tensor& tensor, optional<int64_t> batch_dim, ExtraArgs... extra_args) {
  return {Func(tensor, std::forward<ExtraArgs>(extra_args)...), batch_dim};
}

}}
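The template above encodes the rule for unary ops: apply the op and pass the batch dimension through unchanged. A conceptual Python analogue of the same idea (function and names are illustrative, not part of the C++ source):

import torch

def basic_unary_batch_rule(func, tensor, batch_dim, *extra_args):
    # A unary op acts independently per example, so the physical batch
    # dimension is untouched and can be propagated as-is.
    return func(tensor, *extra_args), batch_dim

# e.g. a batched sin: the result stays batched along dim 0.
out, out_bdim = basic_unary_batch_rule(torch.sin, torch.randn(8, 3), 0)
assert out_bdim == 0 and out.shape == (8, 3)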
@@ -212,7 +212,7 @@ constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({

static void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  auto num_args = op.schema().arguments().size();
  foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
      [](const Tensor& tensor) {
        auto* wrapper = maybeGetTensorWrapper(tensor);
        TORCH_INTERNAL_ASSERT(wrapper == nullptr);
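sanityCheckStack walks this operator's arguments on the interpreter stack and asserts that none of them is still a TensorWrapper, i.e. every wrapper was peeled off before reaching this point. A rough Python analogue of the check (maybe_get_tensor_wrapper stands in for the C++ helper):

def sanity_check_stack(stack, num_args, maybe_get_tensor_wrapper):
    # Only the trailing num_args entries hold this operator's arguments;
    # by this point none of them may still be wrapped.
    for value in stack[-num_args:]:
        assert maybe_get_tensor_wrapper(value) is None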
@@ -366,7 +366,7 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
  for (int64_t arg_idx = 0; arg_idx < args_size; arg_idx++) {
    auto& ivalue = (*stack)[args_front + arg_idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
    if (!maybe_tensor_wrapper) {
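The fallback loop skips non-tensor IValues and only touches arguments that are actually TensorWrappers; plain tensors fall through. Sketched in Python under the same assumptions as the earlier sketches (torch.is_tensor and maybe_get_tensor_wrapper stand in for the C++ checks; the hunk is truncated before the unwrapping itself):

import torch

def unwrap_args(stack, args_front, args_size, maybe_get_tensor_wrapper):
    for arg_idx in range(args_size):
        value = stack[args_front + arg_idx]
        if not torch.is_tensor(value):
            continue  # non-tensor arguments pass through untouched
        if maybe_get_tensor_wrapper(value) is None:
            continue  # already a plain tensor, nothing to unwrap
        # ... the diff is truncated here; unwrapping of the wrapper
        # would follow in the full source.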