[functorch] updated some decompositions and cleaned some stuff

This commit is contained in:
Horace He
2022-04-27 09:53:26 +00:00
committed by Jon Janzen
parent 151b1f4ae1
commit 9eb10d033c
5 changed files with 38 additions and 23 deletions

View File

@ -12,6 +12,7 @@ docs/build
docs/src
docs/source/generated
.DS_Store
op_analysis/*.txt
# Editor temporaries
*.swn

View File

@ -239,16 +239,25 @@ def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> T
# return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
def apply_loss_reduction(loss: Tensor, reduction: int):
if reduction == Reduction.MEAN.value:
return torch.mean(loss)
elif reduction == Reduction.SUM.value:
return torch.sum(loss)
else:
return loss
@register_decomposition(aten.l1_loss)
def l1_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value) -> Tensor:
if reduction != Reduction.NONE.value:
loss = (self - target).abs()
if reduction == Reduction.MEAN.value:
return torch.mean(loss)
else:
return torch.sum(loss)
else:
return (self - target).abs()
loss = (self - target).abs()
return apply_loss_reduction(loss, reduction)
@register_decomposition(aten.mse_loss)
def mse_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value) -> Tensor:
loss = (self - target) ** 2
return apply_loss_reduction(loss, reduction)
@register_decomposition(aten.mse_loss_backward)
@ -257,6 +266,14 @@ def mse_loss_backward(grad_output: Tensor, input: Tensor, target: Tensor, reduct
return norm * (input - target) * grad_output
@register_decomposition(aten.huber_loss)
def huber_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value, delta: float = 1.0) -> Tensor:
assert delta > 0, "huber_loss does not support non-positive values for delta."
z = (self - target).abs()
loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
return apply_loss_reduction(loss, reduction)
@register_decomposition(aten.huber_loss_backward)
def huber_loss_backward(grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float):
norm = 1. / self.numel() if reduction == Reduction.MEAN.value else 1.

View File

@ -84,7 +84,7 @@ def gen_data(special_op_lists, analysis_name):
ops = yaml.load(open('../../pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops.txt')))}
annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops')))}
from collections import defaultdict
uniq_ops = []
@ -160,6 +160,7 @@ def gen_data(special_op_lists, analysis_name):
annotate_ops(ops, is_unique=False)
with open(f"{analysis_name}", 'w') as f:
# import pdb; pdb.set_trace()
for op in ops:
info = [
op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops)
@ -176,12 +177,18 @@ def full_name_check(lst):
# Generates batching rule data
gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap')
gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap.txt')
def remove_suffix(input_string, suffix):
if suffix and input_string.endswith(suffix):
return input_string[:-len(suffix)]
return input_string
if True:
with open('run_ops.txt', 'r') as f:
opinfo_ops = [i.strip() for i in f.readlines()]
opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
with open('run_decompositions.txt', 'r') as f:
decomposed_ops = [i.strip() for i in f.readlines()]
gen_data([name_check(opinfo_ops), name_check(decomposed_ops)], 'decompositions')
decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops)], 'decompositions.txt')

View File

@ -1292,10 +1292,6 @@ class TestDecompositionOpInfo(TestCase):
xfail('linalg.tensorinv'),
xfail('to_sparse'),
skip('tensor_split'),
skip('mvlgamma', 'mvlgamma_p_1'),
skip('mvlgamma', 'mvlgamma_p_3'),
skip('mvlgamma', 'mvlgamma_p_5'),
skip('eig'),
skip('nn.functional.dropout'),
skip('_masked.softmin'),
skip('_masked.log_softmax'),
@ -1303,16 +1299,10 @@ class TestDecompositionOpInfo(TestCase):
skip('_masked.softmax'),
skip('_masked.normalize'),
xfail('linalg.lu_factor', ''),
# Some weird matmul stuff with int64 matmuls
# inplace op
skip('resize_'),
# Weird conj errors
xfail('fft.hfft2', dtypes=(torch.float32, torch.float64)),
xfail('fft.hfft', dtypes=(torch.float32, torch.float64)),
xfail('fft.hfftn', dtypes=(torch.float32, torch.float64)),
skip('nn.functional.binary_cross_entropy', ''),
skip('nn.functional.binary_cross_entropy_with_logits', '',),
skip('nn.functional.huber_loss'),
})
def test_decomposition(self, device, dtype, op):
# dtype is too confusing of a name for how we're using it