[functorch] Fixed pythonkey tests

This commit is contained in:
Horace He
2021-08-19 01:52:23 -07:00
committed by Jon Janzen
parent b178cf6867
commit eb26b8b97a
3 changed files with 23 additions and 7 deletions

View File

@ -48,8 +48,9 @@ class PythonTensor(torch.Tensor):
return e.elem if isinstance(e, PythonTensor) else e
aten_func = getattr(torch.ops.aten, func.__name__)
proxy_args = pytree.tree_map(unwrap_proxy, args)
proxy_out = aten_func(*proxy_args)
real_out = aten_func(*pytree.tree_map(unwrap_tensor, args))
proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
proxy_out = aten_func(*proxy_args, **proxy_kwargs)
real_out = aten_func(*pytree.tree_map(unwrap_tensor, args), **pytree.tree_map(unwrap_tensor, kwargs))
def wrap_with_proxy(e, idx):
return PythonTensor(e, proxy_out[idx]) if type(e) == torch.Tensor else e

View File

@ -0,0 +1,20 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <functorch/csrc/BatchRulesHelper.h>
#include <ATen/Operators.h>
#include <functorch/csrc/PlumbingHelper.h>
#include <functorch/csrc/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>
namespace at { namespace functorch {
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
// m.impl("index_add_", torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}
}}

View File

@ -148,17 +148,12 @@ class TestPythonKeyOperatorsOpInfo(TestCase):
# entries in here need don't work and need to be fixed.
# Each one of these is a bug
python_fail = {
'var',
'std',
'sort',
'prod',
'to_sparse',
'rsub.rsub_scalar',
'linalg.matrix_power',
'linalg.inv',
'linalg.cholesky',
'linalg.eigvals',
'tensor_split',
'nn.functional.pad.circular',
}
if opinfo_in_dict(op, python_fail):