[functorch] Added decomposition testing + display infra (pytorch/functorch#281)

* Added decomposition testing + display infra

* add a couple more decompositions

* changed some stuff

* made some changes

* Added decomposition testing + display infra

* add a couple more decompositions

* fix some decompositions

* changed some stuff

* updated generation

* fix test failures

* removed extraneous files

* fixed test failures

* fixed tests

* updated

* fixed tests again
Author: Horace He
Date: 2021-11-29 18:19:36 -05:00
Committed by: Jon Janzen
Parent: 0fa9f7af83
Commit: fe9aac72d1
10 changed files with 698 additions and 2291 deletions

View File

@@ -18,7 +18,7 @@ from . import _C
# functorch transforms
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
from ._src.python_key import make_fx, pythonkey_decompose, register_decomposition
from ._src.python_key import make_fx
# utilities. Maybe these should go in their own namespace in the future?
from ._src.make_functional import (

View File

@@ -13,10 +13,12 @@ import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
import torch.fx as fx
import torch.fx._pytree as fx_pytree
from torch import Tensor
from .nnc_compile import nnc_compile
from enum import Enum
import warnings
from contextlib import contextmanager
aten = torch.ops.aten
decomposition_table = {}
@@ -27,33 +29,81 @@ def register_decomposition(aten_op):
return f
return decomposition_decorator
class Reduction(Enum):
NONE = 0
MEAN = 1
SUM = 2
@register_decomposition(aten.tanh_backward)
def tanh_backward_decomposition(out_grad, y):
return aten.sub(out_grad, out_grad * y * y)
def tanh_backward_decomposition(out_grad: Tensor, y: Tensor):
return out_grad * (1 - y * y)
@register_decomposition(aten.sigmoid_backward)
def sigmoid_backward_decomposition(out_grad, y):
def sigmoid_backward_decomposition(out_grad: Tensor, y: Tensor):
return out_grad * (y * (1 - y))
@register_decomposition(aten._s_where)
def _s_where_decomposition(a, b, c):
return aten.where(a, b, c)
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
@register_decomposition(aten.detach)
def detach_decomposition(x):
def detach_decomposition(x: Tensor):
return x
@register_decomposition(aten.softplus_backward)
def softplus_backward_decomposition(out_grad, x, beta, threshold, out):
# The out argument seems to always be ignored?
def softplus_backward_decomposition(out_grad: Tensor, x: Tensor, beta: float, threshold: float, out):
z = (x * beta).exp()
return aten.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))
@register_decomposition(aten.elu_backward)
def elu_backward_decomposition(grad_output: Tensor, alpha: float, scale: float, input_scale: float, is_result: bool, self_or_result: Tensor):
negcoef = alpha * scale
poscoef = scale
negiptcoef = input_scale
if is_result:
return aten.where(self_or_result <= 0, grad_output * negiptcoef * (self_or_result + negcoef), self_or_result * poscoef)
else:
return aten.where(self_or_result <= 0, grad_output * negiptcoef * negcoef * aten.exp(self_or_result * negiptcoef), grad_output * poscoef)
@register_decomposition(aten.hardsigmoid_backward)
def hardsigmoid_backward_decomposition(grad_output: Tensor, self: Tensor):
return aten.where((self > -3.0) & (self < 3.0), grad_output * (1.0/6.0), aten.new_zeros(grad_output, ()))
@register_decomposition(aten.hardtanh_backward)
def hardtanh_backward_decomposition(grad_output: Tensor, self: Tensor, min_val: float, max_val: float):
return aten.where((self <= min_val) | (self >= max_val), aten.new_zeros(grad_output, ()), grad_output)
@register_decomposition(aten.hardshrink_backward)
def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
return aten.where((self >= -lambd) & (self <= lambd), aten.new_zeros(grad_out, ()), grad_out)
# @register_decomposition(aten.threshold_backward)
# def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
# return aten.where(self <= threshold, aten.new_zeros(grad_output, ()), grad_output)
@register_decomposition(aten.leaky_relu_backward)
def leaky_relu_backward(grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool):
return aten.where(self > 0, grad_output, grad_output * negative_slope)
@register_decomposition(aten.mse_loss_backward)
def mse_loss_backward_decomposition(grad_output: Tensor, input: Tensor, target: Tensor, reduction: int):
norm = 2./input.numel() if reduction == Reduction.MEAN.value else 2.
return norm * (input - target) * grad_output
@register_decomposition(aten.huber_loss_backward)
def huber_loss_backward_decomposition(grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float):
norm = 1./self.numel() if reduction == Reduction.MEAN.value else 1.
x = self - target
return aten.where(x < -delta, -norm * grad_output * delta, aten.where(x > delta, norm * grad_output * delta, norm * x * grad_output))
# @register_decomposition(aten._fused_dropout)
# def _fused_dropout_decomposition(input, p, generator=None):
# mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
# res = mask.type_as(input) * input * (1./p)
# return [res, mask]
@register_decomposition(aten._s_where)
def _s_where_canonicalization(a, b, c):
return aten.where(a, b, c)
USE_DECOMPOSE = False
@contextmanager
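
A minimal usage sketch of the pieces this file now exposes (register_decomposition, decomposition_table, and the pythonkey_decompose context manager). The silu_backward decomposition and the toy function below are illustrative, not part of this commit, and the import paths assume the functorch.compile exports added in this PR:

import torch
from functorch import grad, make_fx
from functorch.compile import register_decomposition, pythonkey_decompose

aten = torch.ops.aten

# Registering a decomposition stores it in decomposition_table, keyed by the aten op.
@register_decomposition(aten.silu_backward)
def silu_backward_decomposition(grad_output: torch.Tensor, x: torch.Tensor):
    sig = torch.sigmoid(x)
    return grad_output * sig * (1 + x * (1 - sig))

def f(x):
    return torch.tanh(x).sum()

# Inside pythonkey_decompose(), make_fx traces through the decomposition table,
# so the backward graph records the decomposed form of tanh_backward rather
# than the fused aten kernel.
with pythonkey_decompose():
    traced = make_fx(grad(f))(torch.randn(3))
print(traced.graph)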

View File

@@ -1,6 +1,6 @@
from .._src.operator_authoring import pointwise_operator
from .._src.memory_efficient_op_authoring import memory_efficient_operator_authoring, torchscript_nvfuser_compile
from .._src.python_key import nnc_jit, make_nnc, register_decomposition
from .._src.python_key import nnc_jit, make_nnc, register_decomposition, pythonkey_decompose, decomposition_table
from .._src.nnc_compile import nnc_compile, get_ops
from .._src.aot_autograd import (
compiled_function,

View File

@@ -76,94 +76,100 @@ def get_ops_for_key(key):
cleaned_ops.append(i[6:].strip())
return set(cleaned_ops)
batched_registrations = get_ops_for_key('FuncTorchBatched')
all_ops = get_ops_for_key(None)
def gen_data(special_op_lists, analysis_name):
all_ops = get_ops_for_key(None)
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
noncomposite_ops = all_ops - composite_ops
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
annotated_ops = {a.strip(): b.strip() for a,b in list(csv.reader(open('annotated_ops.txt')))}
from collections import defaultdict
vmap_ops = batched_registrations
noncomposite_ops = all_ops - composite_ops
ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
annotated_ops = {a.strip(): b.strip() for a,b in list(csv.reader(open('annotated_ops.txt')))}
from collections import defaultdict
uniq_ops = []
uniq_names = set()
overload_types = defaultdict(list)
cnt = 0
for op in ops:
func_str = op['func']
name = func_str[:func_str.index('(')]
if '.' in name:
uniq_name = name[:name.index('.')]
overload_types[name[name.index('.') + 1:]].append(name)
else:
uniq_name = name
op['name'] = uniq_name
full_name = func_str[:func_str.index('(')]
op['full_name'] = full_name
ret_type = func_str[func_str.index('->') + 3:]
op['ret_type'] = ret_type
cnt += 1
if uniq_name in uniq_names:
continue
uniq_names.add(uniq_name)
uniq_ops.append(op)
def annotate_ops(ops, is_unique):
categorization = defaultdict(int)
uniq_ops = []
uniq_names = set()
overload_types = defaultdict(list)
cnt = 0
for op in ops:
old_tcnt = sum(categorization.values())
if op['name'][-1] == '_':
categorization['inplace'] += 1
op['meta'] = 'inplace'
continue
if not is_unique and 'a!' in op['func'].lower():
categorization['out'] += 1
op['meta'] = 'out'
continue
if 'conv' in op['name']:
categorization['conv'] += 1
op['meta'] = 'conv'
continue
if 'pool' in op['name']:
categorization['pool'] += 1
op['meta'] = 'pool'
continue
if 'backward' in op['name']:
categorization['backward'] += 1
op['meta'] = 'backward'
continue
if op['name'][0] == '_' and op['name'][1] != '_':
categorization['private'] += 1
op['meta'] = 'private'
continue
if 'batch_norm' in op['name']:
categorization['batch_norm'] += 1
op['meta'] = 'batch_norm'
continue
if 'Tensor' not in op['func'] or'Tensor' not in op['ret_type']:
categorization['non_tensor'] += 1
op['meta'] = 'non_tensor'
continue
if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or 'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']:
categorization['backend'] += 1
op['meta'] = 'backend'
continue
if op['name'] in annotated_ops:
categorization['core'] += 1
op['meta'] = 'core ' + annotated_ops[op['name']]
func_str = op['func']
name = func_str[:func_str.index('(')]
if '.' in name:
uniq_name = name[:name.index('.')]
overload_types[name[name.index('.') + 1:]].append(name)
else:
categorization['core'] += 1
op['meta'] = 'core unknown'
return categorization
uniq_name = name
op['name'] = uniq_name
full_name = func_str[:func_str.index('(')]
op['full_name'] = full_name
ret_type = func_str[func_str.index('->') + 3:]
op['ret_type'] = ret_type
cnt += 1
if uniq_name in uniq_names:
continue
uniq_names.add(uniq_name)
uniq_ops.append(op)
# categorization = annotate_ops(uniq_ops, True)
categorization = annotate_ops(ops, False)
def annotate_ops(ops, is_unique):
categorization = defaultdict(int)
for op in ops:
old_tcnt = sum(categorization.values())
if op['name'][-1] == '_':
categorization['inplace'] += 1
op['meta'] = 'inplace'
continue
if not is_unique and 'a!' in op['func'].lower():
categorization['out'] += 1
op['meta'] = 'out'
continue
if 'conv' in op['name']:
categorization['conv'] += 1
op['meta'] = 'conv'
continue
if 'pool' in op['name']:
categorization['pool'] += 1
op['meta'] = 'pool'
continue
if 'backward' in op['name']:
categorization['backward'] += 1
op['meta'] = 'backward'
continue
if op['name'][0] == '_' and op['name'][1] != '_':
categorization['private'] += 1
op['meta'] = 'private'
continue
if 'batch_norm' in op['name']:
categorization['batch_norm'] += 1
op['meta'] = 'batch_norm'
continue
if 'Tensor' not in op['func'] or'Tensor' not in op['ret_type']:
categorization['non_tensor'] += 1
op['meta'] = 'non_tensor'
continue
if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or 'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']:
categorization['backend'] += 1
op['meta'] = 'backend'
continue
if op['name'] in annotated_ops:
categorization['core'] += 1
op['meta'] = 'core ' + annotated_ops[op['name']]
else:
categorization['core'] += 1
op['meta'] = 'core unknown'
return categorization
for op in ops:
info = [op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops), op['full_name'] in vmap_ops]
print(','.join([str(i) for i in info]))
# categorization = annotate_ops(uniq_ops, True)
categorization = annotate_ops(ops, False)
with open(f"{analysis_name}", 'w') as f:
for op in ops:
info = [op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops)] + [op['name'] in op_list for op_list in special_op_lists]
f.write(','.join([str(i) for i in info]) + '\n')
# Generates batching rule data
# gen_data([get_ops_for_key('FuncTorchBatched')], 'vmap')
if True:
with open('run_ops.txt', 'r') as f:
opinfo_ops = [i.strip() for i in f.readlines()]
with open('run_decompositions.txt', 'r') as f:
decomposed_ops = [i.strip() for i in f.readlines()]
gen_data([opinfo_ops, decomposed_ops], 'decompositions')
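
For reference, a hedged sketch of how the rows written by gen_data() for the 'decompositions' analysis can be read back. The columns follow the info list above: full_name, meta category, whether the op is CompositeImplicitAutograd, whether it appeared in the OpInfo run (run_ops.txt), and whether a decomposition was exercised (run_decompositions.txt). The sample row values are illustrative, not taken from a real run:

import csv

with open('decompositions') as f:
    rows = list(csv.reader(f))

# A row might look like: ['tanh_backward', 'backward', 'False', 'True', 'True']
decomposed = [r[0] for r in rows if r[-1] == 'True']
print(f"{len(decomposed)} ops hit a registered decomposition during the OpInfo run")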

View File

@@ -0,0 +1,11 @@
_s_where
elu_backward
hardshrink_backward
hardsigmoid_backward
hardtanh_backward
huber_loss_backward
leaky_relu_backward
mse_loss_backward
sigmoid_backward
softplus_backward
tanh_backward

View File

@@ -0,0 +1,364 @@
__and__
__or__
_adaptive_avg_pool2d
_adaptive_avg_pool2d_backward
_adaptive_avg_pool3d
_adaptive_avg_pool3d_backward
_cdist_backward
_cdist_forward
_conj
_conj_physical
_det_lu_based_helper
_euclidean_dist
_fft_c2c
_fft_c2r
_fft_r2c
_histogramdd_bin_edges
_histogramdd_from_bin_cts
_histogramdd_from_bin_tensors
_local_scalar_dense
_log_softmax
_log_softmax_backward_data
_lu_with_info
_reshape_alias
_slow_conv2d_backward
_slow_conv2d_forward
_softmax
_softmax_backward_data
_svd_helper
_to_copy
_unique2
_unsafe_view
abs
acos
acosh
adaptive_max_pool2d
adaptive_max_pool2d_backward
adaptive_max_pool3d
adaptive_max_pool3d_backward
add
add_
addbmm
addcdiv
addcmul
addcmul_
addmm
addmv
addr
alias
all
amax
amin
aminmax
angle
any
argmax
argmin
asin
asinh
atan
atan2
atanh
avg_pool2d
avg_pool2d_backward
avg_pool3d
avg_pool3d_backward
baddbmm
bernoulli_
bincount
bitwise_and
bitwise_and_
bitwise_left_shift
bitwise_not
bitwise_or
bitwise_or_
bitwise_right_shift
bitwise_xor
bmm
bucketize
cat
ceil
ceil_
celu
cholesky
cholesky_inverse
cholesky_solve
clamp
clamp_
clamp_min
clamp_min_
clone
complex
constant_pad_nd
copy_
copysign
cos
cosh
count_nonzero
cummax
cummin
cumprod
cumsum
deg2rad
detach
diag
diagonal
diagonal_scatter
digamma
dist
div
div_
dot
eig
embedding
empty_like
eq
erf
erf_
erfc
erfinv
exp
exp2
exp_
expand
expm1
fill_
flip
floor
floor_divide
fmax
fmin
fmod
frac
frexp
full_like
gather
ge
gelu
geqrf
grid_sampler_2d
grid_sampler_2d_backward
grid_sampler_3d
grid_sampler_3d_backward
gt
hardshrink
hardsigmoid
hardswish
hardswish_backward
hardtanh
histogram
huber_loss
hypot
i0
igamma
igammac
im2col
index
index_add_
index_copy_
index_fill_
index_put_
index_select
inverse
isin
isnan
isneginf
isposinf
kthvalue
le
leaky_relu
lerp
lerp_
lgamma
linalg_cholesky_ex
linalg_cross
linalg_eig
linalg_eigh
linalg_eigvalsh
linalg_householder_product
linalg_inv_ex
linalg_lstsq
linalg_matrix_exp
linalg_pinv
linalg_qr
linalg_slogdet
linalg_solve
linalg_vector_norm
log
log10
log1p
log2
log_sigmoid_backward
log_sigmoid_forward
logaddexp
logaddexp2
logcumsumexp
logdet
logical_and
logical_not
logical_or
logical_xor
logit
logsumexp
lt
lu_solve
lu_unpack
masked_fill_
masked_scatter_
masked_select
max
max_pool2d_with_indices
max_pool2d_with_indices_backward
maximum
mean
median
min
minimum
mish
mkldnn_convolution
mkldnn_convolution_backward
mm
mode
mse_loss
mul
mul_
mv
nan_to_num
nanmedian
nansum
native_batch_norm
native_batch_norm_backward
native_group_norm
native_layer_norm
native_layer_norm_backward
ne
neg
new_empty
nextafter
nll_loss2d_backward
nll_loss2d_forward
nll_loss_backward
nll_loss_forward
nonzero
norm
ones_like
ormqr
permute
polar
polygamma
pow
pow_
prod
put_
rad2deg
randint_like
randn_like
reciprocal
reciprocal_
reflection_pad1d
reflection_pad1d_backward
reflection_pad2d
reflection_pad2d_backward
reflection_pad3d
reflection_pad3d_backward
relu
remainder
renorm
repeat
repeat_interleave
replication_pad1d
replication_pad1d_backward
replication_pad2d
replication_pad2d_backward
replication_pad3d
replication_pad3d_backward
resize_as_
roll
rot90
round
rsqrt
rsub
scatter
scatter_
scatter_add
scatter_add_
searchsorted
select
select_backward
select_scatter
sgn
sigmoid
sign
signbit
sin
sinc
sinh
slice
slice_backward
slice_scatter
slow_conv_dilated2d
slow_conv_dilated2d_backward
slow_conv_transpose2d
slow_conv_transpose2d_backward
softplus
solve
sort
special_entr
special_erfcx
special_i0e
special_i1
special_i1e
special_ndtri
special_xlog1py
special_zeta
split
split_with_sizes
sqrt
sqrt_
squeeze
squeeze_
stack
std
std_mean
sub
sub_
sum
symeig
t
take
tan
tanh
threshold
threshold_backward
to_sparse
topk
trace
transpose
triangular_solve
tril
tril_
triu
triu_
trunc
unbind
unfold
unique_consecutive
unique_dim
unsqueeze
unsqueeze_
upsample_bicubic2d
upsample_bilinear2d
upsample_linear1d
upsample_nearest1d
upsample_nearest2d
upsample_nearest3d
upsample_trilinear3d
var
var_mean
vdot
view
view_as_real
where
xlogy
zero_
zeros_like

File diff suppressed because it is too large

View File

View File

@@ -16,6 +16,7 @@ import unittest
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
skipCUDAIfNoMagma
from torch.testing._internal.common_device_type import ops, onlyCPU
from torch.testing._internal.common_dtype import floating_types_and, integral_types
from functorch_lagging_op_db import functorch_lagging_op_db
from functorch_additional_op_db import additional_op_db
from common_utils import (
@@ -32,6 +33,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
from functorch import grad, vjp, vmap
import torch.autograd.forward_ad as fwAD
from functorch._src.eager_transforms import _as_tuple, jvp
from functorch.compile import decomposition_table
# Version of autograd.grad that handles outputs that don't depend on inputs
def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
@@ -753,8 +755,170 @@ class TestOperators(TestCase):
self.assertEqual(result_vjps, expected_vjps)
class InplaceError(Exception):
def __repr__(self):
return "Decomposition Tensor with no elem was created (probably due to an in-place op)"
class DecompositionTensor(torch.Tensor):
run_decompositions = set()
run_ops = set()
elem: torch.Tensor
__slots__ = ['elem']
@staticmethod
def __new__(cls, elem):
# The wrapping tensor (DecompositionTensor) is just a meta tensor, so it
# doesn't hold any memory (meta tensor is generally the preferred type
# of tensor you want to make a subclass from)...
r = torch.Tensor._make_wrapper_subclass(
cls, elem.size(),
strides=elem.stride(), storage_offset=elem.storage_offset(),
dtype=elem.dtype, layout=elem.layout,
device=elem.device, requires_grad=elem.requires_grad
)
# ...the real tensor is held as an element on the tensor.
r.elem = elem
return r
def __repr__(self):
return f"DecompositionTensor(elem={self.elem})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if func in decomposition_table and func != torch.ops.aten.detach:
decomposition = decomposition_table[func]
DecompositionTensor.run_decompositions.add(func)
return decomposition(*args, **kwargs)
DecompositionTensor.run_ops.add(func)
def unwrap_tensor(e):
if isinstance(e, DecompositionTensor):
if not hasattr(e, 'elem'):
raise InplaceError()
return e.elem
return e
real_out = func(*tree_map(unwrap_tensor, args), **tree_map(unwrap_tensor, kwargs))
def wrap_tensor(e):
if e is None:
return DecompositionTensor(torch.empty(()))
return DecompositionTensor(e) if type(e) == torch.Tensor else e
wrapped_out = tree_map(wrap_tensor, real_out)
return wrapped_out
class TestDecompositionOpInfo(TestCase):
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=[torch.float32, torch.float64, torch.float16, torch.bfloat16] + [*integral_types()] )
# entries in here don't work and need to be fixed.
# Each one of these is a bug (or needs to be investigated)
@skipOps('TestDecompositionOpInfo', 'test_decomposition', {
skip('view_as_complex'),
xfail('linalg.cholesky'),
xfail('linalg.inv'),
xfail('linalg.matrix_power'),
xfail('to_sparse'),
skip('tensor_split'),
skip('nn.functional.ctc_loss'),
skip('mvlgamma'),
# Some weird matmul stuff with int64 matmuls
skip('__rmatmul__'),
skip('linalg.multi_dot'),
skip('matmul'),
# Can't be compared
skip('empty_like'),
skip('new_empty'),
# inplace op
skip('resize_'),
# ???
skip('nanmean'),
})
def test_decomposition(self, device, dtype, op):
if dtype not in op.supported_dtypes(dtype):
self.skipTest("Dtype not in op's supported dtypes")
return
if is_inplace(op, op.get_op()):
self.skipTest("op is inplace")
return
print(op.op)
_requires_grad = op.supports_autograd and dtype.is_floating_point
# print(dtype)
# print(_requires_grad)
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
# Acquires variants to test
def wrap_tensor(x):
if type(x) == torch.Tensor:
return DecompositionTensor(x)
return x
def unwrap_tensor(x):
if type(x) == DecompositionTensor:
if not hasattr(x, 'elem'):
raise InplaceError()
return x.elem
return x
def _assertEqual(a, b):
if dtype == torch.half:
self.assertEqual(a, b, rtol=0.001, atol=1e-04)
else:
self.assertEqual(a, b)
try:
func = op.get_op()
for sample_input in samples:
if _requires_grad:
fn, primals = normalize_op_input_output(func, sample_input)
result = fn(*primals)
out, vjp_fn = ref_vjp(fn, *primals)
cotangents = tree_map(lambda x: torch.randn_like(x), out)
expected_grads = vjp_fn(cotangents)
decomp_out, decomp_vjp_fn = ref_vjp(fn, *tree_map(wrap_tensor, primals))
_assertEqual(out, tree_map(unwrap_tensor, decomp_out))
decomp_grads = decomp_vjp_fn(cotangents)
_assertEqual(expected_grads, tree_map(unwrap_tensor, decomp_grads))
else:
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
orig_out = func(*args, **kwargs)
args = tree_map(wrap_tensor, args)
kwargs = tree_map(wrap_tensor, kwargs)
decomp_out = func(*args, **kwargs)
self.assertEqual(orig_out, tree_map(unwrap_tensor, decomp_out))
except InplaceError:
self.skipTest("op is inplace")
return
except RuntimeError as e:
if "not implemented for" in str(e):
self.skipTest(str(e))
return
if "Mismatch in shape: grad_output" in str(e):
self.skipTest("Some weird issue with autograd engine and tensor subclasses")
raise e
import gc; gc.collect()
def test_placeholder(self):
with open('op_analysis/run_ops.txt', 'w') as f:
def get_names(l):
return sorted([x.__name__ for x in l])
for op in get_names(DecompositionTensor.run_ops):
f.write(f'{op}\n')
with open('op_analysis/run_decompositions.txt', 'w') as f:
for op in get_names(DecompositionTensor.run_decompositions):
f.write(f'{op}\n')
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)
instantiate_device_type_tests(TestDecompositionOpInfo, globals(), only_for=only_for)
if __name__ == '__main__':
run_tests()
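
Outside the OpInfo harness, the per-op check above boils down to comparing a registered decomposition against the eager aten kernel. A minimal standalone sketch (not part of the commit; the tensor shapes and the use of assert_close are illustrative):

import torch
from functorch.compile import decomposition_table

aten = torch.ops.aten

y = torch.tanh(torch.randn(32, dtype=torch.float64))
grad_out = torch.randn_like(y)

expected = aten.tanh_backward(grad_out, y)                          # reference kernel
decomposed = decomposition_table[aten.tanh_backward](grad_out, y)   # registered decomposition

torch.testing.assert_close(expected, decomposed)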

View File

@@ -23,11 +23,11 @@ from functools import partial, wraps
import functorch
from functorch import (
grad, vjp, vmap, jacrev, grad_and_value,
make_fx, pythonkey_decompose
make_fx
)
from functorch.compile import (
nnc_jit, compiled_function, compiled_module,
partition_with_recompute_fwd_in_bwd
partition_with_recompute_fwd_in_bwd, pythonkey_decompose, decomposition_table
)
from torch.testing._internal.common_device_type import ops, onlyCPU
@@ -204,9 +204,7 @@ class TestPythonKey(TestCase):
self.assertEqual(grads, grads2)
class TestPythonKeyOperatorsOpInfo(TestCase):
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@skipOps('TestPythonKeyOperatorsOpInfo', 'test_make_fx_exhaustive', {
make_fx_failures = {
xfail('to_sparse'),
xfail('allclose'),
xfail('rsub', 'rsub_scalar'),
@@ -216,13 +214,16 @@ class TestPythonKeyOperatorsOpInfo(TestCase):
xfail('nn.functional.dropout'),
xfail('linalg.eigvals'),
xfail('nn.functional.ctc_loss'),
xfail('empty_like'), # randomness
xfail('randn_like'), # randomness
xfail('rand_like'), # randomness
xfail('randint_like'), # randomness
skip('new_empty'), # nondeterministic
skip('empty_like'), # nondeterministic
})
}
class TestPythonKeyOperatorsOpInfo(TestCase):
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@skipOps('TestPythonKeyOperatorsOpInfo', 'test_make_fx_exhaustive', make_fx_failures
)
def test_make_fx_exhaustive(self, device, dtype, op):
def f(args, kwargs):
@@ -390,7 +391,6 @@ class TestEagerFusionOpInfo(TestCase):
orig_grad = get_grads(args)
self.assertEqual(orig_grad, compiled_grad)
class TestPartitioning(TestCase):
def test_recompute_partitioning(self):
def fn(a, b):