mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] Fixed pythonkey tests
This commit is contained in:
@ -48,8 +48,9 @@ class PythonTensor(torch.Tensor):
|
||||
return e.elem if isinstance(e, PythonTensor) else e
|
||||
aten_func = getattr(torch.ops.aten, func.__name__)
|
||||
proxy_args = pytree.tree_map(unwrap_proxy, args)
|
||||
proxy_out = aten_func(*proxy_args)
|
||||
real_out = aten_func(*pytree.tree_map(unwrap_tensor, args))
|
||||
proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
|
||||
proxy_out = aten_func(*proxy_args, **proxy_kwargs)
|
||||
real_out = aten_func(*pytree.tree_map(unwrap_tensor, args), **pytree.tree_map(unwrap_tensor, kwargs))
|
||||
|
||||
def wrap_with_proxy(e, idx):
|
||||
return PythonTensor(e, proxy_out[idx]) if type(e) == torch.Tensor else e
|
||||
|
20
functorch/functorch/csrc/BatchRulesStopDecomposition.cpp
Normal file
20
functorch/functorch/csrc/BatchRulesStopDecomposition.cpp
Normal file
@ -0,0 +1,20 @@
|
||||
|
||||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
// All rights reserved.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#include <functorch/csrc/BatchRulesHelper.h>
|
||||
#include <ATen/Operators.h>
|
||||
#include <functorch/csrc/PlumbingHelper.h>
|
||||
#include <functorch/csrc/BatchedFallback.h>
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
|
||||
namespace at { namespace functorch {
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
// m.impl("index_add_", torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
|
||||
}
|
||||
|
||||
}}
|
@ -148,17 +148,12 @@ class TestPythonKeyOperatorsOpInfo(TestCase):
|
||||
# entries in here need don't work and need to be fixed.
|
||||
# Each one of these is a bug
|
||||
python_fail = {
|
||||
'var',
|
||||
'std',
|
||||
'sort',
|
||||
'prod',
|
||||
'to_sparse',
|
||||
'rsub.rsub_scalar',
|
||||
'linalg.matrix_power',
|
||||
'linalg.inv',
|
||||
'linalg.cholesky',
|
||||
'linalg.eigvals',
|
||||
'tensor_split',
|
||||
'nn.functional.pad.circular',
|
||||
}
|
||||
if opinfo_in_dict(op, python_fail):
|
||||
|
Reference in New Issue
Block a user