mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] Added decomposition testing + display infra (pytorch/functorch#281)
* Added decomposition testing + display infra * add a couple more decompositions * changed some stuff * made some changes * Added decomposition testing + display infra * add a couple more decompositions * fix some decompositions * changed some stuff * updated generation * fix test failures * removed extraneous files * fixed test failures * fixed tests * updated * fixed tests again
This commit is contained in:
@ -18,7 +18,7 @@ from . import _C
|
||||
# functorch transforms
|
||||
from ._src.vmap import vmap
|
||||
from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
|
||||
from ._src.python_key import make_fx, pythonkey_decompose, register_decomposition
|
||||
from ._src.python_key import make_fx
|
||||
|
||||
# utilities. Maybe these should go in their own namespace in the future?
|
||||
from ._src.make_functional import (
|
||||
|
@ -13,10 +13,12 @@ import torch.utils._pytree as pytree
|
||||
from torch.fx import Tracer, GraphModule
|
||||
import torch.fx as fx
|
||||
import torch.fx._pytree as fx_pytree
|
||||
from torch import Tensor
|
||||
from .nnc_compile import nnc_compile
|
||||
from enum import Enum
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
decomposition_table = {}
|
||||
@ -27,33 +29,81 @@ def register_decomposition(aten_op):
|
||||
return f
|
||||
return decomposition_decorator
|
||||
|
||||
class Reduction(Enum):
|
||||
NONE = 0
|
||||
MEAN = 1
|
||||
SUM = 2
|
||||
|
||||
@register_decomposition(aten.tanh_backward)
|
||||
def tanh_backward_decomposition(out_grad, y):
|
||||
return aten.sub(out_grad, out_grad * y * y)
|
||||
def tanh_backward_decomposition(out_grad: Tensor, y: Tensor):
|
||||
return out_grad * (1 - y * y)
|
||||
|
||||
@register_decomposition(aten.sigmoid_backward)
|
||||
def sigmoid_backward_decomposition(out_grad, y):
|
||||
def sigmoid_backward_decomposition(out_grad: Tensor, y: Tensor):
|
||||
return out_grad * (y * (1 - y))
|
||||
|
||||
@register_decomposition(aten._s_where)
|
||||
def _s_where_decomposition(a, b, c):
|
||||
return aten.where(a, b, c)
|
||||
|
||||
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
|
||||
@register_decomposition(aten.detach)
|
||||
def detach_decomposition(x):
|
||||
def detach_decomposition(x: Tensor):
|
||||
return x
|
||||
|
||||
@register_decomposition(aten.softplus_backward)
|
||||
def softplus_backward_decomposition(out_grad, x, beta, threshold, out):
|
||||
# The out argument seems to always be ignored?
|
||||
def softplus_backward_decomposition(out_grad: Tensor, x: Tensor, beta: float, threshold: float, out):
|
||||
z = (x * beta).exp()
|
||||
return aten.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))
|
||||
|
||||
@register_decomposition(aten.elu_backward)
|
||||
def elu_backward_decomposition(grad_output: Tensor, alpha: float, scale: float, input_scale: float, is_result: bool, self_or_result: Tensor):
|
||||
negcoef = alpha * scale
|
||||
poscoef = scale
|
||||
negiptcoef = input_scale
|
||||
if is_result:
|
||||
return aten.where(self_or_result <= 0, grad_output * negiptcoef * (self_or_result + negcoef), self_or_result * poscoef)
|
||||
else:
|
||||
return aten.where(self_or_result <= 0, grad_output * negiptcoef * negcoef * aten.exp(self_or_result * negiptcoef), grad_output * poscoef)
|
||||
|
||||
@register_decomposition(aten.hardsigmoid_backward)
|
||||
def hardsigmoid_backward_decomposition(grad_output: Tensor, self: Tensor):
|
||||
return aten.where((self > -3.0) & (self < 3.0), grad_output * (1.0/6.0), aten.new_zeros(grad_output, ()))
|
||||
|
||||
@register_decomposition(aten.hardtanh_backward)
|
||||
def hardtanh_backward_decomposition(grad_output: Tensor, self: Tensor, min_val: float, max_val: float):
|
||||
return aten.where((self <= min_val) | (self >= max_val), aten.new_zeros(grad_output, ()), grad_output)
|
||||
|
||||
@register_decomposition(aten.hardshrink_backward)
|
||||
def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
|
||||
return aten.where((self >= -lambd) & (self <= lambd), aten.new_zeros(grad_out, ()), grad_out)
|
||||
|
||||
# @register_decomposition(aten.threshold_backward)
|
||||
# def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
|
||||
# return aten.where(self <= threshold, aten.new_zeros(grad_output, ()), grad_output)
|
||||
|
||||
@register_decomposition(aten.leaky_relu_backward)
|
||||
def leaky_relu_backward(grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool):
|
||||
return aten.where(self > 0, grad_output, grad_output * negative_slope)
|
||||
|
||||
@register_decomposition(aten.mse_loss_backward)
|
||||
def mse_loss_backward_decomposition(grad_output: Tensor, input: Tensor, target: Tensor, reduction: int):
|
||||
norm = 2./input.numel() if reduction == Reduction.MEAN.value else 2.
|
||||
return norm * (input - target) * grad_output
|
||||
|
||||
@register_decomposition(aten.huber_loss_backward)
|
||||
def huber_loss_backward_decomposition(grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float):
|
||||
norm = 1./self.numel() if reduction == Reduction.MEAN.value else 1.
|
||||
x = self - target
|
||||
return aten.where(x < -delta, -norm * grad_output * delta, aten.where(x > delta, norm * grad_output * delta, norm * x * grad_output))
|
||||
|
||||
# @register_decomposition(aten._fused_dropout)
|
||||
# def _fused_dropout_decomposition(input, p, generator=None):
|
||||
# mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
|
||||
# res = mask.type_as(input) * input * (1./p)
|
||||
# return [res, mask]
|
||||
|
||||
@register_decomposition(aten._s_where)
|
||||
def _s_where_canonicalization(a, b, c):
|
||||
return aten.where(a, b, c)
|
||||
|
||||
USE_DECOMPOSE = False
|
||||
|
||||
@contextmanager
|
||||
|
@ -1,6 +1,6 @@
|
||||
from .._src.operator_authoring import pointwise_operator
|
||||
from .._src.memory_efficient_op_authoring import memory_efficient_operator_authoring, torchscript_nvfuser_compile
|
||||
from .._src.python_key import nnc_jit, make_nnc, register_decomposition
|
||||
from .._src.python_key import nnc_jit, make_nnc, register_decomposition, pythonkey_decompose, decomposition_table
|
||||
from .._src.nnc_compile import nnc_compile, get_ops
|
||||
from .._src.aot_autograd import (
|
||||
compiled_function,
|
||||
|
@ -76,94 +76,100 @@ def get_ops_for_key(key):
|
||||
cleaned_ops.append(i[6:].strip())
|
||||
return set(cleaned_ops)
|
||||
|
||||
batched_registrations = get_ops_for_key('FuncTorchBatched')
|
||||
all_ops = get_ops_for_key(None)
|
||||
def gen_data(special_op_lists, analysis_name):
|
||||
all_ops = get_ops_for_key(None)
|
||||
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
|
||||
noncomposite_ops = all_ops - composite_ops
|
||||
|
||||
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
|
||||
ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
|
||||
|
||||
annotated_ops = {a.strip(): b.strip() for a,b in list(csv.reader(open('annotated_ops.txt')))}
|
||||
from collections import defaultdict
|
||||
|
||||
vmap_ops = batched_registrations
|
||||
noncomposite_ops = all_ops - composite_ops
|
||||
|
||||
ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
|
||||
|
||||
annotated_ops = {a.strip(): b.strip() for a,b in list(csv.reader(open('annotated_ops.txt')))}
|
||||
from collections import defaultdict
|
||||
|
||||
uniq_ops = []
|
||||
uniq_names = set()
|
||||
overload_types = defaultdict(list)
|
||||
cnt = 0
|
||||
for op in ops:
|
||||
func_str = op['func']
|
||||
name = func_str[:func_str.index('(')]
|
||||
if '.' in name:
|
||||
uniq_name = name[:name.index('.')]
|
||||
overload_types[name[name.index('.') + 1:]].append(name)
|
||||
else:
|
||||
uniq_name = name
|
||||
op['name'] = uniq_name
|
||||
full_name = func_str[:func_str.index('(')]
|
||||
op['full_name'] = full_name
|
||||
ret_type = func_str[func_str.index('->') + 3:]
|
||||
op['ret_type'] = ret_type
|
||||
cnt += 1
|
||||
if uniq_name in uniq_names:
|
||||
continue
|
||||
uniq_names.add(uniq_name)
|
||||
uniq_ops.append(op)
|
||||
|
||||
def annotate_ops(ops, is_unique):
|
||||
categorization = defaultdict(int)
|
||||
uniq_ops = []
|
||||
uniq_names = set()
|
||||
overload_types = defaultdict(list)
|
||||
cnt = 0
|
||||
for op in ops:
|
||||
old_tcnt = sum(categorization.values())
|
||||
if op['name'][-1] == '_':
|
||||
categorization['inplace'] += 1
|
||||
op['meta'] = 'inplace'
|
||||
continue
|
||||
if not is_unique and 'a!' in op['func'].lower():
|
||||
categorization['out'] += 1
|
||||
op['meta'] = 'out'
|
||||
continue
|
||||
if 'conv' in op['name']:
|
||||
categorization['conv'] += 1
|
||||
op['meta'] = 'conv'
|
||||
continue
|
||||
if 'pool' in op['name']:
|
||||
categorization['pool'] += 1
|
||||
op['meta'] = 'pool'
|
||||
continue
|
||||
if 'backward' in op['name']:
|
||||
categorization['backward'] += 1
|
||||
op['meta'] = 'backward'
|
||||
continue
|
||||
if op['name'][0] == '_' and op['name'][1] != '_':
|
||||
categorization['private'] += 1
|
||||
op['meta'] = 'private'
|
||||
continue
|
||||
if 'batch_norm' in op['name']:
|
||||
categorization['batch_norm'] += 1
|
||||
op['meta'] = 'batch_norm'
|
||||
continue
|
||||
if 'Tensor' not in op['func'] or'Tensor' not in op['ret_type']:
|
||||
categorization['non_tensor'] += 1
|
||||
op['meta'] = 'non_tensor'
|
||||
continue
|
||||
if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or 'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']:
|
||||
categorization['backend'] += 1
|
||||
op['meta'] = 'backend'
|
||||
continue
|
||||
if op['name'] in annotated_ops:
|
||||
categorization['core'] += 1
|
||||
op['meta'] = 'core ' + annotated_ops[op['name']]
|
||||
func_str = op['func']
|
||||
name = func_str[:func_str.index('(')]
|
||||
if '.' in name:
|
||||
uniq_name = name[:name.index('.')]
|
||||
overload_types[name[name.index('.') + 1:]].append(name)
|
||||
else:
|
||||
categorization['core'] += 1
|
||||
op['meta'] = 'core unknown'
|
||||
return categorization
|
||||
uniq_name = name
|
||||
op['name'] = uniq_name
|
||||
full_name = func_str[:func_str.index('(')]
|
||||
op['full_name'] = full_name
|
||||
ret_type = func_str[func_str.index('->') + 3:]
|
||||
op['ret_type'] = ret_type
|
||||
cnt += 1
|
||||
if uniq_name in uniq_names:
|
||||
continue
|
||||
uniq_names.add(uniq_name)
|
||||
uniq_ops.append(op)
|
||||
|
||||
# categorization = annotate_ops(uniq_ops, True)
|
||||
categorization = annotate_ops(ops, False)
|
||||
def annotate_ops(ops, is_unique):
|
||||
categorization = defaultdict(int)
|
||||
for op in ops:
|
||||
old_tcnt = sum(categorization.values())
|
||||
if op['name'][-1] == '_':
|
||||
categorization['inplace'] += 1
|
||||
op['meta'] = 'inplace'
|
||||
continue
|
||||
if not is_unique and 'a!' in op['func'].lower():
|
||||
categorization['out'] += 1
|
||||
op['meta'] = 'out'
|
||||
continue
|
||||
if 'conv' in op['name']:
|
||||
categorization['conv'] += 1
|
||||
op['meta'] = 'conv'
|
||||
continue
|
||||
if 'pool' in op['name']:
|
||||
categorization['pool'] += 1
|
||||
op['meta'] = 'pool'
|
||||
continue
|
||||
if 'backward' in op['name']:
|
||||
categorization['backward'] += 1
|
||||
op['meta'] = 'backward'
|
||||
continue
|
||||
if op['name'][0] == '_' and op['name'][1] != '_':
|
||||
categorization['private'] += 1
|
||||
op['meta'] = 'private'
|
||||
continue
|
||||
if 'batch_norm' in op['name']:
|
||||
categorization['batch_norm'] += 1
|
||||
op['meta'] = 'batch_norm'
|
||||
continue
|
||||
if 'Tensor' not in op['func'] or'Tensor' not in op['ret_type']:
|
||||
categorization['non_tensor'] += 1
|
||||
op['meta'] = 'non_tensor'
|
||||
continue
|
||||
if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or 'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']:
|
||||
categorization['backend'] += 1
|
||||
op['meta'] = 'backend'
|
||||
continue
|
||||
if op['name'] in annotated_ops:
|
||||
categorization['core'] += 1
|
||||
op['meta'] = 'core ' + annotated_ops[op['name']]
|
||||
else:
|
||||
categorization['core'] += 1
|
||||
op['meta'] = 'core unknown'
|
||||
return categorization
|
||||
|
||||
for op in ops:
|
||||
info = [op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops), op['full_name'] in vmap_ops]
|
||||
print(','.join([str(i) for i in info]))
|
||||
# categorization = annotate_ops(uniq_ops, True)
|
||||
categorization = annotate_ops(ops, False)
|
||||
|
||||
with open(f"{analysis_name}", 'w') as f:
|
||||
for op in ops:
|
||||
info = [op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops)] + [op['name'] in op_list for op_list in special_op_lists]
|
||||
f.write(','.join([str(i) for i in info]) + '\n')
|
||||
|
||||
# Generates batching rule data
|
||||
# gen_data([get_ops_for_key('FuncTorchBatched')], 'vmap')
|
||||
if True:
|
||||
with open('run_ops.txt', 'r') as f:
|
||||
opinfo_ops = [i.strip() for i in f.readlines()]
|
||||
with open('run_decompositions.txt', 'r') as f:
|
||||
decomposed_ops = [i.strip() for i in f.readlines()]
|
||||
gen_data([opinfo_ops, decomposed_ops], 'decompositions')
|
11
functorch/op_analysis/run_decompositions.txt
Normal file
11
functorch/op_analysis/run_decompositions.txt
Normal file
@ -0,0 +1,11 @@
|
||||
_s_where
|
||||
elu_backward
|
||||
hardshrink_backward
|
||||
hardsigmoid_backward
|
||||
hardtanh_backward
|
||||
huber_loss_backward
|
||||
leaky_relu_backward
|
||||
mse_loss_backward
|
||||
sigmoid_backward
|
||||
softplus_backward
|
||||
tanh_backward
|
364
functorch/op_analysis/run_ops.txt
Normal file
364
functorch/op_analysis/run_ops.txt
Normal file
@ -0,0 +1,364 @@
|
||||
__and__
|
||||
__or__
|
||||
_adaptive_avg_pool2d
|
||||
_adaptive_avg_pool2d_backward
|
||||
_adaptive_avg_pool3d
|
||||
_adaptive_avg_pool3d_backward
|
||||
_cdist_backward
|
||||
_cdist_forward
|
||||
_conj
|
||||
_conj_physical
|
||||
_det_lu_based_helper
|
||||
_euclidean_dist
|
||||
_fft_c2c
|
||||
_fft_c2r
|
||||
_fft_r2c
|
||||
_histogramdd_bin_edges
|
||||
_histogramdd_from_bin_cts
|
||||
_histogramdd_from_bin_tensors
|
||||
_local_scalar_dense
|
||||
_log_softmax
|
||||
_log_softmax_backward_data
|
||||
_lu_with_info
|
||||
_reshape_alias
|
||||
_slow_conv2d_backward
|
||||
_slow_conv2d_forward
|
||||
_softmax
|
||||
_softmax_backward_data
|
||||
_svd_helper
|
||||
_to_copy
|
||||
_unique2
|
||||
_unsafe_view
|
||||
abs
|
||||
acos
|
||||
acosh
|
||||
adaptive_max_pool2d
|
||||
adaptive_max_pool2d_backward
|
||||
adaptive_max_pool3d
|
||||
adaptive_max_pool3d_backward
|
||||
add
|
||||
add_
|
||||
addbmm
|
||||
addcdiv
|
||||
addcmul
|
||||
addcmul_
|
||||
addmm
|
||||
addmv
|
||||
addr
|
||||
alias
|
||||
all
|
||||
amax
|
||||
amin
|
||||
aminmax
|
||||
angle
|
||||
any
|
||||
argmax
|
||||
argmin
|
||||
asin
|
||||
asinh
|
||||
atan
|
||||
atan2
|
||||
atanh
|
||||
avg_pool2d
|
||||
avg_pool2d_backward
|
||||
avg_pool3d
|
||||
avg_pool3d_backward
|
||||
baddbmm
|
||||
bernoulli_
|
||||
bincount
|
||||
bitwise_and
|
||||
bitwise_and_
|
||||
bitwise_left_shift
|
||||
bitwise_not
|
||||
bitwise_or
|
||||
bitwise_or_
|
||||
bitwise_right_shift
|
||||
bitwise_xor
|
||||
bmm
|
||||
bucketize
|
||||
cat
|
||||
ceil
|
||||
ceil_
|
||||
celu
|
||||
cholesky
|
||||
cholesky_inverse
|
||||
cholesky_solve
|
||||
clamp
|
||||
clamp_
|
||||
clamp_min
|
||||
clamp_min_
|
||||
clone
|
||||
complex
|
||||
constant_pad_nd
|
||||
copy_
|
||||
copysign
|
||||
cos
|
||||
cosh
|
||||
count_nonzero
|
||||
cummax
|
||||
cummin
|
||||
cumprod
|
||||
cumsum
|
||||
deg2rad
|
||||
detach
|
||||
diag
|
||||
diagonal
|
||||
diagonal_scatter
|
||||
digamma
|
||||
dist
|
||||
div
|
||||
div_
|
||||
dot
|
||||
eig
|
||||
embedding
|
||||
empty_like
|
||||
eq
|
||||
erf
|
||||
erf_
|
||||
erfc
|
||||
erfinv
|
||||
exp
|
||||
exp2
|
||||
exp_
|
||||
expand
|
||||
expm1
|
||||
fill_
|
||||
flip
|
||||
floor
|
||||
floor_divide
|
||||
fmax
|
||||
fmin
|
||||
fmod
|
||||
frac
|
||||
frexp
|
||||
full_like
|
||||
gather
|
||||
ge
|
||||
gelu
|
||||
geqrf
|
||||
grid_sampler_2d
|
||||
grid_sampler_2d_backward
|
||||
grid_sampler_3d
|
||||
grid_sampler_3d_backward
|
||||
gt
|
||||
hardshrink
|
||||
hardsigmoid
|
||||
hardswish
|
||||
hardswish_backward
|
||||
hardtanh
|
||||
histogram
|
||||
huber_loss
|
||||
hypot
|
||||
i0
|
||||
igamma
|
||||
igammac
|
||||
im2col
|
||||
index
|
||||
index_add_
|
||||
index_copy_
|
||||
index_fill_
|
||||
index_put_
|
||||
index_select
|
||||
inverse
|
||||
isin
|
||||
isnan
|
||||
isneginf
|
||||
isposinf
|
||||
kthvalue
|
||||
le
|
||||
leaky_relu
|
||||
lerp
|
||||
lerp_
|
||||
lgamma
|
||||
linalg_cholesky_ex
|
||||
linalg_cross
|
||||
linalg_eig
|
||||
linalg_eigh
|
||||
linalg_eigvalsh
|
||||
linalg_householder_product
|
||||
linalg_inv_ex
|
||||
linalg_lstsq
|
||||
linalg_matrix_exp
|
||||
linalg_pinv
|
||||
linalg_qr
|
||||
linalg_slogdet
|
||||
linalg_solve
|
||||
linalg_vector_norm
|
||||
log
|
||||
log10
|
||||
log1p
|
||||
log2
|
||||
log_sigmoid_backward
|
||||
log_sigmoid_forward
|
||||
logaddexp
|
||||
logaddexp2
|
||||
logcumsumexp
|
||||
logdet
|
||||
logical_and
|
||||
logical_not
|
||||
logical_or
|
||||
logical_xor
|
||||
logit
|
||||
logsumexp
|
||||
lt
|
||||
lu_solve
|
||||
lu_unpack
|
||||
masked_fill_
|
||||
masked_scatter_
|
||||
masked_select
|
||||
max
|
||||
max_pool2d_with_indices
|
||||
max_pool2d_with_indices_backward
|
||||
maximum
|
||||
mean
|
||||
median
|
||||
min
|
||||
minimum
|
||||
mish
|
||||
mkldnn_convolution
|
||||
mkldnn_convolution_backward
|
||||
mm
|
||||
mode
|
||||
mse_loss
|
||||
mul
|
||||
mul_
|
||||
mv
|
||||
nan_to_num
|
||||
nanmedian
|
||||
nansum
|
||||
native_batch_norm
|
||||
native_batch_norm_backward
|
||||
native_group_norm
|
||||
native_layer_norm
|
||||
native_layer_norm_backward
|
||||
ne
|
||||
neg
|
||||
new_empty
|
||||
nextafter
|
||||
nll_loss2d_backward
|
||||
nll_loss2d_forward
|
||||
nll_loss_backward
|
||||
nll_loss_forward
|
||||
nonzero
|
||||
norm
|
||||
ones_like
|
||||
ormqr
|
||||
permute
|
||||
polar
|
||||
polygamma
|
||||
pow
|
||||
pow_
|
||||
prod
|
||||
put_
|
||||
rad2deg
|
||||
randint_like
|
||||
randn_like
|
||||
reciprocal
|
||||
reciprocal_
|
||||
reflection_pad1d
|
||||
reflection_pad1d_backward
|
||||
reflection_pad2d
|
||||
reflection_pad2d_backward
|
||||
reflection_pad3d
|
||||
reflection_pad3d_backward
|
||||
relu
|
||||
remainder
|
||||
renorm
|
||||
repeat
|
||||
repeat_interleave
|
||||
replication_pad1d
|
||||
replication_pad1d_backward
|
||||
replication_pad2d
|
||||
replication_pad2d_backward
|
||||
replication_pad3d
|
||||
replication_pad3d_backward
|
||||
resize_as_
|
||||
roll
|
||||
rot90
|
||||
round
|
||||
rsqrt
|
||||
rsub
|
||||
scatter
|
||||
scatter_
|
||||
scatter_add
|
||||
scatter_add_
|
||||
searchsorted
|
||||
select
|
||||
select_backward
|
||||
select_scatter
|
||||
sgn
|
||||
sigmoid
|
||||
sign
|
||||
signbit
|
||||
sin
|
||||
sinc
|
||||
sinh
|
||||
slice
|
||||
slice_backward
|
||||
slice_scatter
|
||||
slow_conv_dilated2d
|
||||
slow_conv_dilated2d_backward
|
||||
slow_conv_transpose2d
|
||||
slow_conv_transpose2d_backward
|
||||
softplus
|
||||
solve
|
||||
sort
|
||||
special_entr
|
||||
special_erfcx
|
||||
special_i0e
|
||||
special_i1
|
||||
special_i1e
|
||||
special_ndtri
|
||||
special_xlog1py
|
||||
special_zeta
|
||||
split
|
||||
split_with_sizes
|
||||
sqrt
|
||||
sqrt_
|
||||
squeeze
|
||||
squeeze_
|
||||
stack
|
||||
std
|
||||
std_mean
|
||||
sub
|
||||
sub_
|
||||
sum
|
||||
symeig
|
||||
t
|
||||
take
|
||||
tan
|
||||
tanh
|
||||
threshold
|
||||
threshold_backward
|
||||
to_sparse
|
||||
topk
|
||||
trace
|
||||
transpose
|
||||
triangular_solve
|
||||
tril
|
||||
tril_
|
||||
triu
|
||||
triu_
|
||||
trunc
|
||||
unbind
|
||||
unfold
|
||||
unique_consecutive
|
||||
unique_dim
|
||||
unsqueeze
|
||||
unsqueeze_
|
||||
upsample_bicubic2d
|
||||
upsample_bilinear2d
|
||||
upsample_linear1d
|
||||
upsample_nearest1d
|
||||
upsample_nearest2d
|
||||
upsample_nearest3d
|
||||
upsample_trilinear3d
|
||||
var
|
||||
var_mean
|
||||
vdot
|
||||
view
|
||||
view_as_real
|
||||
where
|
||||
xlogy
|
||||
zero_
|
||||
zeros_like
|
File diff suppressed because it is too large
Load Diff
0
functorch/op_analysis/vmap
Normal file
0
functorch/op_analysis/vmap
Normal file
@ -16,6 +16,7 @@ import unittest
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
|
||||
skipCUDAIfNoMagma
|
||||
from torch.testing._internal.common_device_type import ops, onlyCPU
|
||||
from torch.testing._internal.common_dtype import floating_types_and, integral_types
|
||||
from functorch_lagging_op_db import functorch_lagging_op_db
|
||||
from functorch_additional_op_db import additional_op_db
|
||||
from common_utils import (
|
||||
@ -32,6 +33,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
|
||||
from functorch import grad, vjp, vmap
|
||||
import torch.autograd.forward_ad as fwAD
|
||||
from functorch._src.eager_transforms import _as_tuple, jvp
|
||||
from functorch.compile import decomposition_table
|
||||
|
||||
# Version of autograd.grad that handles outputs that don't depend on inputs
|
||||
def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
|
||||
@ -753,8 +755,170 @@ class TestOperators(TestCase):
|
||||
|
||||
self.assertEqual(result_vjps, expected_vjps)
|
||||
|
||||
class InplaceError(Exception):
|
||||
def __repr__(self):
|
||||
return "Decomposition Tensor with no elem was created (probably due to an in-place op)"
|
||||
|
||||
class DecompositionTensor(torch.Tensor):
|
||||
run_decompositions = set()
|
||||
run_ops = set()
|
||||
|
||||
elem: torch.Tensor
|
||||
|
||||
__slots__ = ['elem']
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
# The wrapping tensor (PythonTensor) is just a meta tensor, so it
|
||||
# doesn't hold any memory (meta tensor is generally the preferred type
|
||||
# of tensor you want to make a subclass from)...
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls, elem.size(),
|
||||
strides=elem.stride(), storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype, layout=elem.layout,
|
||||
device=elem.device, requires_grad=elem.requires_grad
|
||||
)
|
||||
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
r.elem = elem
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
return f"DecompositionTensor(elem={self.elem})"
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
if func in decomposition_table and func != torch.ops.aten.detach:
|
||||
decomposition = decomposition_table[func]
|
||||
DecompositionTensor.run_decompositions.add(func)
|
||||
return decomposition(*args, **kwargs)
|
||||
DecompositionTensor.run_ops.add(func)
|
||||
def unwrap_tensor(e):
|
||||
if isinstance(e, DecompositionTensor):
|
||||
if not hasattr(e, 'elem'):
|
||||
raise InplaceError()
|
||||
return e.elem
|
||||
return e
|
||||
|
||||
real_out = func(*tree_map(unwrap_tensor, args), **tree_map(unwrap_tensor, kwargs))
|
||||
|
||||
def wrap_tensor(e):
|
||||
if e is None:
|
||||
return DecompositionTensor(torch.empty(()))
|
||||
return DecompositionTensor(e) if type(e) == torch.Tensor else e
|
||||
wrapped_out = tree_map(wrap_tensor, real_out)
|
||||
return wrapped_out
|
||||
|
||||
|
||||
class TestDecompositionOpInfo(TestCase):
|
||||
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=[torch.float32, torch.float64, torch.float16, torch.bfloat16] + [*integral_types()] )
|
||||
# entries in here need don't work and need to be fixed.
|
||||
# Each one of these is a bug (or needs to be investigated)
|
||||
@skipOps('TestDecompositionOpInfo', 'test_decomposition', {
|
||||
skip('view_as_complex'),
|
||||
xfail('linalg.cholesky'),
|
||||
xfail('linalg.inv'),
|
||||
xfail('linalg.matrix_power'),
|
||||
xfail('to_sparse'),
|
||||
skip('tensor_split'),
|
||||
skip('nn.functional.ctc_loss'),
|
||||
skip('mvlgamma'),
|
||||
# Some weird matmul stuff with int64 matmuls
|
||||
skip('__rmatmul__'),
|
||||
skip('linalg.multi_dot'),
|
||||
skip('matmul'),
|
||||
# Can't be compared
|
||||
skip('empty_like'),
|
||||
skip('new_empty'),
|
||||
# inplace op
|
||||
skip('resize_'),
|
||||
# ???
|
||||
skip('nanmean'),
|
||||
})
|
||||
def test_decomposition(self, device, dtype, op):
|
||||
if dtype not in op.supported_dtypes(dtype):
|
||||
self.skipTest("Dtype not in op's supported dtypes")
|
||||
return
|
||||
if is_inplace(op, op.get_op()):
|
||||
self.skipTest("op is inplace")
|
||||
return
|
||||
print(op.op)
|
||||
_requires_grad = op.supports_autograd and dtype.is_floating_point
|
||||
# print(dtype)
|
||||
|
||||
# print(_requires_grad)
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
|
||||
# Acquires variants to test
|
||||
def wrap_tensor(x):
|
||||
if type(x) == torch.Tensor:
|
||||
return DecompositionTensor(x)
|
||||
return x
|
||||
|
||||
def unwrap_tensor(x):
|
||||
if type(x) == DecompositionTensor:
|
||||
if not hasattr(x, 'elem'):
|
||||
raise InplaceError()
|
||||
return x.elem
|
||||
return x
|
||||
|
||||
def _assertEqual(a, b):
|
||||
if dtype == torch.half:
|
||||
self.assertEqual(a, b, rtol=0.001, atol=1e-04)
|
||||
else:
|
||||
self.assertEqual(a, b)
|
||||
|
||||
try:
|
||||
func = op.get_op()
|
||||
for sample_input in samples:
|
||||
if _requires_grad:
|
||||
fn, primals = normalize_op_input_output(func, sample_input)
|
||||
result = fn(*primals)
|
||||
out, vjp_fn = ref_vjp(fn, *primals)
|
||||
cotangents = tree_map(lambda x: torch.randn_like(x), out)
|
||||
expected_grads = vjp_fn(cotangents)
|
||||
|
||||
decomp_out, decomp_vjp_fn = ref_vjp(fn, *tree_map(wrap_tensor, primals))
|
||||
_assertEqual(out, tree_map(unwrap_tensor, decomp_out))
|
||||
|
||||
decomp_grads = decomp_vjp_fn(cotangents)
|
||||
_assertEqual(expected_grads, tree_map(unwrap_tensor, decomp_grads))
|
||||
|
||||
else:
|
||||
args = [sample_input.input] + list(sample_input.args)
|
||||
kwargs = sample_input.kwargs
|
||||
orig_out = func(*args, **kwargs)
|
||||
|
||||
args = tree_map(wrap_tensor, args)
|
||||
kwargs = tree_map(wrap_tensor, kwargs)
|
||||
decomp_out = func(*args, **kwargs)
|
||||
self.assertEqual(orig_out, tree_map(unwrap_tensor, decomp_out))
|
||||
|
||||
|
||||
except InplaceError:
|
||||
self.skipTest("op is inplace")
|
||||
return
|
||||
except RuntimeError as e:
|
||||
if "not implemented for" in str(e):
|
||||
self.skipTest(str(e))
|
||||
return
|
||||
if "Mismatch in shape: grad_output" in str(e):
|
||||
self.skipTest("Some weird issue with autograd engine and tensor subclasses")
|
||||
raise e
|
||||
import gc; gc.collect()
|
||||
|
||||
def test_placeholder(self):
|
||||
with open('op_analysis/run_ops.txt', 'w') as f:
|
||||
def get_names(l):
|
||||
return sorted([x.__name__ for x in l])
|
||||
for op in get_names(DecompositionTensor.run_ops):
|
||||
f.write(f'{op}\n')
|
||||
with open('op_analysis/run_decompositions.txt', 'w') as f:
|
||||
for op in get_names(DecompositionTensor.run_decompositions):
|
||||
f.write(f'{op}\n')
|
||||
|
||||
only_for = ("cpu", "cuda")
|
||||
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)
|
||||
instantiate_device_type_tests(TestDecompositionOpInfo, globals(), only_for=only_for)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -23,11 +23,11 @@ from functools import partial, wraps
|
||||
import functorch
|
||||
from functorch import (
|
||||
grad, vjp, vmap, jacrev, grad_and_value,
|
||||
make_fx, pythonkey_decompose
|
||||
make_fx
|
||||
)
|
||||
from functorch.compile import (
|
||||
nnc_jit, compiled_function, compiled_module,
|
||||
partition_with_recompute_fwd_in_bwd
|
||||
partition_with_recompute_fwd_in_bwd, pythonkey_decompose, decomposition_table
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_device_type import ops, onlyCPU
|
||||
@ -204,9 +204,7 @@ class TestPythonKey(TestCase):
|
||||
self.assertEqual(grads, grads2)
|
||||
|
||||
|
||||
class TestPythonKeyOperatorsOpInfo(TestCase):
|
||||
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps('TestPythonKeyOperatorsOpInfo', 'test_make_fx_exhaustive', {
|
||||
make_fx_failures = {
|
||||
xfail('to_sparse'),
|
||||
xfail('allclose'),
|
||||
xfail('rsub', 'rsub_scalar'),
|
||||
@ -216,13 +214,16 @@ class TestPythonKeyOperatorsOpInfo(TestCase):
|
||||
xfail('nn.functional.dropout'),
|
||||
xfail('linalg.eigvals'),
|
||||
xfail('nn.functional.ctc_loss'),
|
||||
xfail('empty_like'), # randomness
|
||||
xfail('randn_like'), # randomness
|
||||
xfail('rand_like'), # randomness
|
||||
xfail('randint_like'), # randomness
|
||||
skip('new_empty'), # nondeterministic
|
||||
skip('empty_like'), # nondeterministic
|
||||
})
|
||||
}
|
||||
class TestPythonKeyOperatorsOpInfo(TestCase):
|
||||
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps('TestPythonKeyOperatorsOpInfo', 'test_make_fx_exhaustive', make_fx_failures
|
||||
)
|
||||
def test_make_fx_exhaustive(self, device, dtype, op):
|
||||
|
||||
def f(args, kwargs):
|
||||
@ -390,7 +391,6 @@ class TestEagerFusionOpInfo(TestCase):
|
||||
orig_grad = get_grads(args)
|
||||
self.assertEqual(orig_grad, compiled_grad)
|
||||
|
||||
|
||||
class TestPartitioning(TestCase):
|
||||
def test_recompute_partitioning(self):
|
||||
def fn(a, b):
|
||||
|
Reference in New Issue
Block a user