Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
[functorch] Added decomposition testing + display infra (pytorch/functorch#281)
* Added decomposition testing + display infra
* add a couple more decompositions
* changed some stuff
* made some changes
* Added decomposition testing + display infra
* add a couple more decompositions
* fix some decompositions
* changed some stuff
* updated generation
* fix test failures
* removed extraneous files
* fixed test failures
* fixed tests
* updated
* fixed tests again
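The mechanism underlying the whole commit is a small registry mapping aten overloads to plain Python decompositions. A minimal, self-contained sketch of that pattern follows; the decorator body is an assumption consistent with the fragments visible in the diff below, and the real code lives in functorch/_src/python_key.py:

import torch

aten = torch.ops.aten
decomposition_table = {}

def register_decomposition(aten_op):
    # Assumed body: remember the Python implementation keyed by the aten op.
    def decomposition_decorator(f):
        decomposition_table[aten_op] = f
        return f
    return decomposition_decorator

@register_decomposition(aten.sigmoid_backward)
def sigmoid_backward_decomposition(out_grad, y):
    return out_grad * (y * (1 - y))

Anything that can intercept torch.ops.aten calls, such as an FX tracer or a __torch_dispatch__ subclass, can then consult decomposition_table and run the Python body instead of the fused kernel.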
@@ -18,7 +18,7 @@ from . import _C
 # functorch transforms
 from ._src.vmap import vmap
 from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
-from ._src.python_key import make_fx, pythonkey_decompose, register_decomposition
+from ._src.python_key import make_fx

 # utilities. Maybe these should go in their own namespace in the future?
 from ._src.make_functional import (
@@ -13,10 +13,12 @@ import torch.utils._pytree as pytree
 from torch.fx import Tracer, GraphModule
 import torch.fx as fx
 import torch.fx._pytree as fx_pytree
+from torch import Tensor
 from .nnc_compile import nnc_compile
 from enum import Enum
 import warnings
 from contextlib import contextmanager

 aten = torch.ops.aten

 decomposition_table = {}
@@ -27,33 +29,81 @@ def register_decomposition(aten_op):
         return f
     return decomposition_decorator


+class Reduction(Enum):
+    NONE = 0
+    MEAN = 1
+    SUM = 2
+
+
 @register_decomposition(aten.tanh_backward)
-def tanh_backward_decomposition(out_grad, y):
-    return aten.sub(out_grad, out_grad * y * y)
+def tanh_backward_decomposition(out_grad: Tensor, y: Tensor):
+    return out_grad * (1 - y * y)


 @register_decomposition(aten.sigmoid_backward)
-def sigmoid_backward_decomposition(out_grad, y):
+def sigmoid_backward_decomposition(out_grad: Tensor, y: Tensor):
     return out_grad * (y * (1 - y))


-@register_decomposition(aten._s_where)
-def _s_where_decomposition(a, b, c):
-    return aten.where(a, b, c)
-
-
+# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
 @register_decomposition(aten.detach)
-def detach_decomposition(x):
+def detach_decomposition(x: Tensor):
     return x


 @register_decomposition(aten.softplus_backward)
-def softplus_backward_decomposition(out_grad, x, beta, threshold, out):
+# The out argument seems to always be ignored?
+def softplus_backward_decomposition(out_grad: Tensor, x: Tensor, beta: float, threshold: float, out):
     z = (x * beta).exp()
     return aten.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))


+@register_decomposition(aten.elu_backward)
+def elu_backward_decomposition(grad_output: Tensor, alpha: float, scale: float, input_scale: float, is_result: bool, self_or_result: Tensor):
+    negcoef = alpha * scale
+    poscoef = scale
+    negiptcoef = input_scale
+    if is_result:
+        return aten.where(self_or_result <= 0, grad_output * negiptcoef * (self_or_result + negcoef), self_or_result * poscoef)
+    else:
+        return aten.where(self_or_result <= 0, grad_output * negiptcoef * negcoef * aten.exp(self_or_result * negiptcoef), grad_output * poscoef)
+
+
+@register_decomposition(aten.hardsigmoid_backward)
+def hardsigmoid_backward_decomposition(grad_output: Tensor, self: Tensor):
+    return aten.where((self > -3.0) & (self < 3.0), grad_output * (1.0 / 6.0), aten.new_zeros(grad_output, ()))
+
+
+@register_decomposition(aten.hardtanh_backward)
+def hardtanh_backward_decomposition(grad_output: Tensor, self: Tensor, min_val: float, max_val: float):
+    return aten.where((self <= min_val) | (self >= max_val), aten.new_zeros(grad_output, ()), grad_output)
+
+
+@register_decomposition(aten.hardshrink_backward)
+def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
+    return aten.where((self >= -lambd) & (self <= lambd), aten.new_zeros(grad_out, ()), grad_out)
+
+
+# @register_decomposition(aten.threshold_backward)
+# def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
+#     return aten.where(self <= threshold, aten.new_zeros(grad_output, ()), grad_output)
+
+
+@register_decomposition(aten.leaky_relu_backward)
+def leaky_relu_backward(grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool):
+    return aten.where(self > 0, grad_output, grad_output * negative_slope)
+
+
+@register_decomposition(aten.mse_loss_backward)
+def mse_loss_backward_decomposition(grad_output: Tensor, input: Tensor, target: Tensor, reduction: int):
+    norm = 2. / input.numel() if reduction == Reduction.MEAN.value else 2.
+    return norm * (input - target) * grad_output
+
+
+@register_decomposition(aten.huber_loss_backward)
+def huber_loss_backward_decomposition(grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float):
+    norm = 1. / self.numel() if reduction == Reduction.MEAN.value else 1.
+    x = self - target
+    return aten.where(x < -delta, -norm * grad_output * delta, aten.where(x > delta, norm * grad_output * delta, norm * x * grad_output))
+
+
 # @register_decomposition(aten._fused_dropout)
 # def _fused_dropout_decomposition(input, p, generator=None):
 #     mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
 #     res = mask.type_as(input) * input * (1./p)
 #     return [res, mask]


+@register_decomposition(aten._s_where)
+def _s_where_canonicalization(a, b, c):
+    return aten.where(a, b, c)
+
+
 USE_DECOMPOSE = False


 @contextmanager
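The hunk cuts off at the USE_DECOMPOSE flag and the @contextmanager decorator, so the body of the context manager is not shown. A hypothetical sketch of what pythonkey_decompose (re-exported via functorch.compile in this commit) might do with that flag; this is an assumption, not the actual implementation:

from contextlib import contextmanager

USE_DECOMPOSE = False

@contextmanager
def pythonkey_decompose():
    # Hypothetical body: temporarily enable decompositions while tracing.
    global USE_DECOMPOSE
    USE_DECOMPOSE = True
    try:
        yield
    finally:
        USE_DECOMPOSE = False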
|
@@ -1,6 +1,6 @@
 from .._src.operator_authoring import pointwise_operator
 from .._src.memory_efficient_op_authoring import memory_efficient_operator_authoring, torchscript_nvfuser_compile
-from .._src.python_key import nnc_jit, make_nnc, register_decomposition
+from .._src.python_key import nnc_jit, make_nnc, register_decomposition, pythonkey_decompose, decomposition_table
 from .._src.nnc_compile import nnc_compile, get_ops
 from .._src.aot_autograd import (
     compiled_function,
@@ -76,94 +76,100 @@ def get_ops_for_key(key):
         cleaned_ops.append(i[6:].strip())
     return set(cleaned_ops)

-batched_registrations = get_ops_for_key('FuncTorchBatched')
-all_ops = get_ops_for_key(None)
-
-composite_ops = get_ops_for_key('CompositeImplicitAutograd')
-
-vmap_ops = batched_registrations
-noncomposite_ops = all_ops - composite_ops
-
-ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
-
-annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops.txt')))}
-from collections import defaultdict
-
-uniq_ops = []
-uniq_names = set()
-overload_types = defaultdict(list)
-cnt = 0
-for op in ops:
-    func_str = op['func']
-    name = func_str[:func_str.index('(')]
-    if '.' in name:
-        uniq_name = name[:name.index('.')]
-        overload_types[name[name.index('.') + 1:]].append(name)
-    else:
-        uniq_name = name
-    op['name'] = uniq_name
-    full_name = func_str[:func_str.index('(')]
-    op['full_name'] = full_name
-    ret_type = func_str[func_str.index('->') + 3:]
-    op['ret_type'] = ret_type
-    cnt += 1
-    if uniq_name in uniq_names:
-        continue
-    uniq_names.add(uniq_name)
-    uniq_ops.append(op)
-
-def annotate_ops(ops, is_unique):
-    categorization = defaultdict(int)
-    for op in ops:
-        old_tcnt = sum(categorization.values())
-        if op['name'][-1] == '_':
-            categorization['inplace'] += 1
-            op['meta'] = 'inplace'
-            continue
-        if not is_unique and 'a!' in op['func'].lower():
-            categorization['out'] += 1
-            op['meta'] = 'out'
-            continue
-        if 'conv' in op['name']:
-            categorization['conv'] += 1
-            op['meta'] = 'conv'
-            continue
-        if 'pool' in op['name']:
-            categorization['pool'] += 1
-            op['meta'] = 'pool'
-            continue
-        if 'backward' in op['name']:
-            categorization['backward'] += 1
-            op['meta'] = 'backward'
-            continue
-        if op['name'][0] == '_' and op['name'][1] != '_':
-            categorization['private'] += 1
-            op['meta'] = 'private'
-            continue
-        if 'batch_norm' in op['name']:
-            categorization['batch_norm'] += 1
-            op['meta'] = 'batch_norm'
-            continue
-        if 'Tensor' not in op['func'] or 'Tensor' not in op['ret_type']:
-            categorization['non_tensor'] += 1
-            op['meta'] = 'non_tensor'
-            continue
-        if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or 'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']:
-            categorization['backend'] += 1
-            op['meta'] = 'backend'
-            continue
-        if op['name'] in annotated_ops:
-            categorization['core'] += 1
-            op['meta'] = 'core ' + annotated_ops[op['name']]
-        else:
-            categorization['core'] += 1
-            op['meta'] = 'core unknown'
-    return categorization
-
-# categorization = annotate_ops(uniq_ops, True)
-categorization = annotate_ops(ops, False)
-
-for op in ops:
-    info = [op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops), op['full_name'] in vmap_ops]
-    print(','.join([str(i) for i in info]))
+def gen_data(special_op_lists, analysis_name):
+    all_ops = get_ops_for_key(None)
+    composite_ops = get_ops_for_key('CompositeImplicitAutograd')
+    noncomposite_ops = all_ops - composite_ops
+
+    ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
+
+    annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops.txt')))}
+    from collections import defaultdict
+
+    uniq_ops = []
+    uniq_names = set()
+    overload_types = defaultdict(list)
+    cnt = 0
+    for op in ops:
+        func_str = op['func']
+        name = func_str[:func_str.index('(')]
+        if '.' in name:
+            uniq_name = name[:name.index('.')]
+            overload_types[name[name.index('.') + 1:]].append(name)
+        else:
+            uniq_name = name
+        op['name'] = uniq_name
+        full_name = func_str[:func_str.index('(')]
+        op['full_name'] = full_name
+        ret_type = func_str[func_str.index('->') + 3:]
+        op['ret_type'] = ret_type
+        cnt += 1
+        if uniq_name in uniq_names:
+            continue
+        uniq_names.add(uniq_name)
+        uniq_ops.append(op)
+
+    def annotate_ops(ops, is_unique):
+        categorization = defaultdict(int)
+        for op in ops:
+            old_tcnt = sum(categorization.values())
+            if op['name'][-1] == '_':
+                categorization['inplace'] += 1
+                op['meta'] = 'inplace'
+                continue
+            if not is_unique and 'a!' in op['func'].lower():
+                categorization['out'] += 1
+                op['meta'] = 'out'
+                continue
+            if 'conv' in op['name']:
+                categorization['conv'] += 1
+                op['meta'] = 'conv'
+                continue
+            if 'pool' in op['name']:
+                categorization['pool'] += 1
+                op['meta'] = 'pool'
+                continue
+            if 'backward' in op['name']:
+                categorization['backward'] += 1
+                op['meta'] = 'backward'
+                continue
+            if op['name'][0] == '_' and op['name'][1] != '_':
+                categorization['private'] += 1
+                op['meta'] = 'private'
+                continue
+            if 'batch_norm' in op['name']:
+                categorization['batch_norm'] += 1
+                op['meta'] = 'batch_norm'
+                continue
+            if 'Tensor' not in op['func'] or 'Tensor' not in op['ret_type']:
+                categorization['non_tensor'] += 1
+                op['meta'] = 'non_tensor'
+                continue
+            if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or 'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']:
+                categorization['backend'] += 1
+                op['meta'] = 'backend'
+                continue
+            if op['name'] in annotated_ops:
+                categorization['core'] += 1
+                op['meta'] = 'core ' + annotated_ops[op['name']]
+            else:
+                categorization['core'] += 1
+                op['meta'] = 'core unknown'
+        return categorization
+
+    # categorization = annotate_ops(uniq_ops, True)
+    categorization = annotate_ops(ops, False)
+
+    with open(f"{analysis_name}", 'w') as f:
+        for op in ops:
+            info = [op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops)] + [op['name'] in op_list for op_list in special_op_lists]
+            f.write(','.join([str(i) for i in info]) + '\n')
+
+
+# Generates batching rule data
+# gen_data([get_ops_for_key('FuncTorchBatched')], 'vmap')
+
+if True:
+    with open('run_ops.txt', 'r') as f:
+        opinfo_ops = [i.strip() for i in f.readlines()]
+    with open('run_decompositions.txt', 'r') as f:
+        decomposed_ops = [i.strip() for i in f.readlines()]
+    gen_data([opinfo_ops, decomposed_ops], 'decompositions')
functorch/op_analysis/run_decompositions.txt (new file, 11 lines)
@@ -0,0 +1,11 @@
_s_where
elu_backward
hardshrink_backward
hardsigmoid_backward
hardtanh_backward
huber_loss_backward
leaky_relu_backward
mse_loss_backward
sigmoid_backward
softplus_backward
tanh_backward
functorch/op_analysis/run_ops.txt (new file, 364 lines; one op name per line)
@@ -0,0 +1,364 @@
__and__, __or__, _adaptive_avg_pool2d, _adaptive_avg_pool2d_backward, _adaptive_avg_pool3d,
_adaptive_avg_pool3d_backward, _cdist_backward, _cdist_forward, _conj, _conj_physical,
_det_lu_based_helper, _euclidean_dist, _fft_c2c, _fft_c2r, _fft_r2c, _histogramdd_bin_edges,
_histogramdd_from_bin_cts, _histogramdd_from_bin_tensors, _local_scalar_dense, _log_softmax,
_log_softmax_backward_data, _lu_with_info, _reshape_alias, _slow_conv2d_backward, _slow_conv2d_forward,
_softmax, _softmax_backward_data, _svd_helper, _to_copy, _unique2, _unsafe_view,
abs, acos, acosh, adaptive_max_pool2d, adaptive_max_pool2d_backward, adaptive_max_pool3d,
adaptive_max_pool3d_backward, add, add_, addbmm, addcdiv, addcmul, addcmul_, addmm, addmv, addr,
alias, all, amax, amin, aminmax, angle, any, argmax, argmin, asin, asinh, atan, atan2, atanh,
avg_pool2d, avg_pool2d_backward, avg_pool3d, avg_pool3d_backward, baddbmm, bernoulli_, bincount,
bitwise_and, bitwise_and_, bitwise_left_shift, bitwise_not, bitwise_or, bitwise_or_,
bitwise_right_shift, bitwise_xor, bmm, bucketize, cat, ceil, ceil_, celu, cholesky,
cholesky_inverse, cholesky_solve, clamp, clamp_, clamp_min, clamp_min_, clone, complex,
constant_pad_nd, copy_, copysign, cos, cosh, count_nonzero, cummax, cummin, cumprod, cumsum,
deg2rad, detach, diag, diagonal, diagonal_scatter, digamma, dist, div, div_, dot, eig, embedding,
empty_like, eq, erf, erf_, erfc, erfinv, exp, exp2, exp_, expand, expm1, fill_, flip, floor,
floor_divide, fmax, fmin, fmod, frac, frexp, full_like, gather, ge, gelu, geqrf, grid_sampler_2d,
grid_sampler_2d_backward, grid_sampler_3d, grid_sampler_3d_backward, gt, hardshrink, hardsigmoid,
hardswish, hardswish_backward, hardtanh, histogram, huber_loss, hypot, i0, igamma, igammac, im2col,
index, index_add_, index_copy_, index_fill_, index_put_, index_select, inverse, isin, isnan,
isneginf, isposinf, kthvalue, le, leaky_relu, lerp, lerp_, lgamma, linalg_cholesky_ex, linalg_cross,
linalg_eig, linalg_eigh, linalg_eigvalsh, linalg_householder_product, linalg_inv_ex, linalg_lstsq,
linalg_matrix_exp, linalg_pinv, linalg_qr, linalg_slogdet, linalg_solve, linalg_vector_norm,
log, log10, log1p, log2, log_sigmoid_backward, log_sigmoid_forward, logaddexp, logaddexp2,
logcumsumexp, logdet, logical_and, logical_not, logical_or, logical_xor, logit, logsumexp, lt,
lu_solve, lu_unpack, masked_fill_, masked_scatter_, masked_select, max, max_pool2d_with_indices,
max_pool2d_with_indices_backward, maximum, mean, median, min, minimum, mish, mkldnn_convolution,
mkldnn_convolution_backward, mm, mode, mse_loss, mul, mul_, mv, nan_to_num, nanmedian, nansum,
native_batch_norm, native_batch_norm_backward, native_group_norm, native_layer_norm,
native_layer_norm_backward, ne, neg, new_empty, nextafter, nll_loss2d_backward, nll_loss2d_forward,
nll_loss_backward, nll_loss_forward, nonzero, norm, ones_like, ormqr, permute, polar, polygamma,
pow, pow_, prod, put_, rad2deg, randint_like, randn_like, reciprocal, reciprocal_,
reflection_pad1d, reflection_pad1d_backward, reflection_pad2d, reflection_pad2d_backward,
reflection_pad3d, reflection_pad3d_backward, relu, remainder, renorm, repeat, repeat_interleave,
replication_pad1d, replication_pad1d_backward, replication_pad2d, replication_pad2d_backward,
replication_pad3d, replication_pad3d_backward, resize_as_, roll, rot90, round, rsqrt, rsub,
scatter, scatter_, scatter_add, scatter_add_, searchsorted, select, select_backward,
select_scatter, sgn, sigmoid, sign, signbit, sin, sinc, sinh, slice, slice_backward,
slice_scatter, slow_conv_dilated2d, slow_conv_dilated2d_backward, slow_conv_transpose2d,
slow_conv_transpose2d_backward, softplus, solve, sort, special_entr, special_erfcx, special_i0e,
special_i1, special_i1e, special_ndtri, special_xlog1py, special_zeta, split, split_with_sizes,
sqrt, sqrt_, squeeze, squeeze_, stack, std, std_mean, sub, sub_, sum, symeig, t, take, tan, tanh,
threshold, threshold_backward, to_sparse, topk, trace, transpose, triangular_solve, tril, tril_,
triu, triu_, trunc, unbind, unfold, unique_consecutive, unique_dim, unsqueeze, unsqueeze_,
upsample_bicubic2d, upsample_bilinear2d, upsample_linear1d, upsample_nearest1d, upsample_nearest2d,
upsample_nearest3d, upsample_trilinear3d, var, var_mean, vdot, view, view_as_real, where, xlogy,
zero_, zeros_like
File diff suppressed because it is too large.
functorch/op_analysis/vmap (new file, empty)
@@ -16,6 +16,7 @@ import unittest
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
     skipCUDAIfNoMagma
 from torch.testing._internal.common_device_type import ops, onlyCPU
+from torch.testing._internal.common_dtype import floating_types_and, integral_types
 from functorch_lagging_op_db import functorch_lagging_op_db
 from functorch_additional_op_db import additional_op_db
 from common_utils import (
@@ -32,6 +33,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from functorch import grad, vjp, vmap
 import torch.autograd.forward_ad as fwAD
 from functorch._src.eager_transforms import _as_tuple, jvp
+from functorch.compile import decomposition_table

 # Version of autograd.grad that handles outputs that don't depend on inputs
 def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
@@ -753,8 +755,170 @@ class TestOperators(TestCase):

         self.assertEqual(result_vjps, expected_vjps)


+class InplaceError(Exception):
+    def __repr__(self):
+        return "Decomposition Tensor with no elem was created (probably due to an in-place op)"
+
+
+class DecompositionTensor(torch.Tensor):
+    run_decompositions = set()
+    run_ops = set()
+
+    elem: torch.Tensor
+
+    __slots__ = ['elem']
+
+    @staticmethod
+    def __new__(cls, elem):
+        # The wrapping tensor (PythonTensor) is just a meta tensor, so it
+        # doesn't hold any memory (meta tensor is generally the preferred type
+        # of tensor you want to make a subclass from)...
+        r = torch.Tensor._make_wrapper_subclass(
+            cls, elem.size(),
+            strides=elem.stride(), storage_offset=elem.storage_offset(),
+            dtype=elem.dtype, layout=elem.layout,
+            device=elem.device, requires_grad=elem.requires_grad
+        )
+
+        # ...the real tensor is held as an element on the tensor.
+        r.elem = elem
+        return r
+
+    def __repr__(self):
+        return f"DecompositionTensor(elem={self.elem})"
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if func in decomposition_table and func != torch.ops.aten.detach:
+            decomposition = decomposition_table[func]
+            DecompositionTensor.run_decompositions.add(func)
+            return decomposition(*args, **kwargs)
+        DecompositionTensor.run_ops.add(func)
+
+        def unwrap_tensor(e):
+            if isinstance(e, DecompositionTensor):
+                if not hasattr(e, 'elem'):
+                    raise InplaceError()
+                return e.elem
+            return e
+
+        real_out = func(*tree_map(unwrap_tensor, args), **tree_map(unwrap_tensor, kwargs))
+
+        def wrap_tensor(e):
+            if e is None:
+                return DecompositionTensor(torch.empty(()))
+            return DecompositionTensor(e) if type(e) == torch.Tensor else e
+        wrapped_out = tree_map(wrap_tensor, real_out)
+        return wrapped_out
+
+
+class TestDecompositionOpInfo(TestCase):
+    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=[torch.float32, torch.float64, torch.float16, torch.bfloat16] + [*integral_types()])
+    # Entries in here don't work and need to be fixed.
+    # Each one of these is a bug (or needs to be investigated)
+    @skipOps('TestDecompositionOpInfo', 'test_decomposition', {
+        skip('view_as_complex'),
+        xfail('linalg.cholesky'),
+        xfail('linalg.inv'),
+        xfail('linalg.matrix_power'),
+        xfail('to_sparse'),
+        skip('tensor_split'),
+        skip('nn.functional.ctc_loss'),
+        skip('mvlgamma'),
+        # Some weird matmul stuff with int64 matmuls
+        skip('__rmatmul__'),
+        skip('linalg.multi_dot'),
+        skip('matmul'),
+        # Can't be compared
+        skip('empty_like'),
+        skip('new_empty'),
+        # inplace op
+        skip('resize_'),
+        # ???
+        skip('nanmean'),
+    })
+    def test_decomposition(self, device, dtype, op):
+        if dtype not in op.supported_dtypes(dtype):
+            self.skipTest("Dtype not in op's supported dtypes")
+            return
+        if is_inplace(op, op.get_op()):
+            self.skipTest("op is inplace")
+            return
+        print(op.op)
+        _requires_grad = op.supports_autograd and dtype.is_floating_point
+        # print(dtype)
+
+        # print(_requires_grad)
+        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
+        # Acquires variants to test
+
+        def wrap_tensor(x):
+            if type(x) == torch.Tensor:
+                return DecompositionTensor(x)
+            return x
+
+        def unwrap_tensor(x):
+            if type(x) == DecompositionTensor:
+                if not hasattr(x, 'elem'):
+                    raise InplaceError()
+                return x.elem
+            return x
+
+        def _assertEqual(a, b):
+            if dtype == torch.half:
+                self.assertEqual(a, b, rtol=0.001, atol=1e-04)
+            else:
+                self.assertEqual(a, b)
+
+        try:
+            func = op.get_op()
+            for sample_input in samples:
+                if _requires_grad:
+                    fn, primals = normalize_op_input_output(func, sample_input)
+                    result = fn(*primals)
+                    out, vjp_fn = ref_vjp(fn, *primals)
+                    cotangents = tree_map(lambda x: torch.randn_like(x), out)
+                    expected_grads = vjp_fn(cotangents)
+
+                    decomp_out, decomp_vjp_fn = ref_vjp(fn, *tree_map(wrap_tensor, primals))
+                    _assertEqual(out, tree_map(unwrap_tensor, decomp_out))
+
+                    decomp_grads = decomp_vjp_fn(cotangents)
+                    _assertEqual(expected_grads, tree_map(unwrap_tensor, decomp_grads))
+                else:
+                    args = [sample_input.input] + list(sample_input.args)
+                    kwargs = sample_input.kwargs
+                    orig_out = func(*args, **kwargs)
+
+                    args = tree_map(wrap_tensor, args)
+                    kwargs = tree_map(wrap_tensor, kwargs)
+                    decomp_out = func(*args, **kwargs)
+                    self.assertEqual(orig_out, tree_map(unwrap_tensor, decomp_out))
+
+        except InplaceError:
+            self.skipTest("op is inplace")
+            return
+        except RuntimeError as e:
+            if "not implemented for" in str(e):
+                self.skipTest(str(e))
+                return
+            if "Mismatch in shape: grad_output" in str(e):
+                self.skipTest("Some weird issue with autograd engine and tensor subclasses")
+            raise e
+        import gc; gc.collect()
+
+    def test_placeholder(self):
+        with open('op_analysis/run_ops.txt', 'w') as f:
+            def get_names(l):
+                return sorted([x.__name__ for x in l])
+            for op in get_names(DecompositionTensor.run_ops):
+                f.write(f'{op}\n')
+        with open('op_analysis/run_decompositions.txt', 'w') as f:
+            for op in get_names(DecompositionTensor.run_decompositions):
+                f.write(f'{op}\n')
+
+
 only_for = ("cpu", "cuda")
 instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)
+instantiate_device_type_tests(TestDecompositionOpInfo, globals(), only_for=only_for)

 if __name__ == '__main__':
     run_tests()
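For intuition, the comparison this test automates can be reproduced by hand for a single registered decomposition. An illustrative snippet (shapes and tolerance handling assumed, not part of the test file):

import torch

aten = torch.ops.aten

y = torch.randn(8).sigmoid()          # the saved forward output of sigmoid
g = torch.randn(8)                    # incoming gradient
ref = aten.sigmoid_backward(g, y)     # fused aten kernel
dec = g * (y * (1 - y))               # the decomposition registered in this commit
torch.testing.assert_close(ref, dec)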
|
@@ -23,11 +23,11 @@ from functools import partial, wraps
 import functorch
 from functorch import (
     grad, vjp, vmap, jacrev, grad_and_value,
-    make_fx, pythonkey_decompose
+    make_fx
 )
 from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
-    partition_with_recompute_fwd_in_bwd
+    partition_with_recompute_fwd_in_bwd, pythonkey_decompose, decomposition_table
 )

 from torch.testing._internal.common_device_type import ops, onlyCPU
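With the re-export moved, enabling decompositions during tracing now goes through functorch.compile. A hedged usage sketch; the exact call pattern of pythonkey_decompose is assumed rather than taken from the tests:

import torch
from functorch import make_fx, grad
from functorch.compile import pythonkey_decompose

def f(x):
    return torch.tanh(x).sum()

with pythonkey_decompose():
    # Trace the gradient computation with decompositions enabled; tanh_backward
    # should appear as its Python decomposition rather than the fused op.
    traced = make_fx(grad(f))(torch.randn(5))
print(traced.code)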
@@ -204,9 +204,7 @@ class TestPythonKey(TestCase):
         self.assertEqual(grads, grads2)


-class TestPythonKeyOperatorsOpInfo(TestCase):
-    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
-    @skipOps('TestPythonKeyOperatorsOpInfo', 'test_make_fx_exhaustive', {
+make_fx_failures = {
     xfail('to_sparse'),
     xfail('allclose'),
     xfail('rsub', 'rsub_scalar'),
@@ -216,13 +214,16 @@ class TestPythonKeyOperatorsOpInfo(TestCase):
     xfail('nn.functional.dropout'),
     xfail('linalg.eigvals'),
     xfail('nn.functional.ctc_loss'),
-    xfail('empty_like'),  # randomness
     xfail('randn_like'),  # randomness
     xfail('rand_like'),  # randomness
     xfail('randint_like'),  # randomness
     skip('new_empty'),  # nondeterministic
     skip('empty_like'),  # nondeterministic
-    })
+}
+
+
+class TestPythonKeyOperatorsOpInfo(TestCase):
+    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
+    @skipOps('TestPythonKeyOperatorsOpInfo', 'test_make_fx_exhaustive', make_fx_failures
+    )
     def test_make_fx_exhaustive(self, device, dtype, op):

         def f(args, kwargs):
@@ -390,7 +391,6 @@ class TestEagerFusionOpInfo(TestCase):
         orig_grad = get_grads(args)
         self.assertEqual(orig_grad, compiled_grad)

-
 class TestPartitioning(TestCase):
     def test_recompute_partitioning(self):
         def fn(a, b):
|