mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)

[functorch] updated some decompositions and cleaned some stuff

functorch/.gitignore (vendored)
@@ -12,6 +12,7 @@ docs/build
 docs/src
 docs/source/generated
 .DS_Store
+op_analysis/*.txt
 
 # Editor temporaries
 *.swn
@@ -239,16 +239,25 @@ def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> T
     # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
 
 
+def apply_loss_reduction(loss: Tensor, reduction: int):
+    if reduction == Reduction.MEAN.value:
+        return torch.mean(loss)
+    elif reduction == Reduction.SUM.value:
+        return torch.sum(loss)
+    else:
+        return loss
+
+
 @register_decomposition(aten.l1_loss)
 def l1_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value) -> Tensor:
-    if reduction != Reduction.NONE.value:
-        loss = (self - target).abs()
-        if reduction == Reduction.MEAN.value:
-            return torch.mean(loss)
-        else:
-            return torch.sum(loss)
-    else:
-        return (self - target).abs()
+    loss = (self - target).abs()
+    return apply_loss_reduction(loss, reduction)
+
+
+@register_decomposition(aten.mse_loss)
+def mse_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value) -> Tensor:
+    loss = (self - target) ** 2
+    return apply_loss_reduction(loss, reduction)
 
 
 @register_decomposition(aten.mse_loss_backward)
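Note (not part of the diff): the new apply_loss_reduction helper centralizes the none/mean/sum handling that l1_loss previously open-coded, and the added mse_loss decomposition reuses it. A minimal standalone sketch of the intended equivalence, using local copies of the decompositions and ATen's Reduction enum values (none=0, mean=1, sum=2) rather than the functorch internals:

# Standalone sanity check: the decomposed losses should match the eager
# torch.nn.functional kernels for every reduction mode.
import torch

NONE, MEAN, SUM = 0, 1, 2  # ATen Reduction enum values

def apply_loss_reduction(loss, reduction):
    if reduction == MEAN:
        return torch.mean(loss)
    elif reduction == SUM:
        return torch.sum(loss)
    else:
        return loss

def l1_loss(self, target, reduction=MEAN):
    return apply_loss_reduction((self - target).abs(), reduction)

def mse_loss(self, target, reduction=MEAN):
    return apply_loss_reduction((self - target) ** 2, reduction)

x, y = torch.randn(4, 3), torch.randn(4, 3)
for red_int, red_str in [(NONE, 'none'), (MEAN, 'mean'), (SUM, 'sum')]:
    torch.testing.assert_close(l1_loss(x, y, red_int),
                               torch.nn.functional.l1_loss(x, y, reduction=red_str))
    torch.testing.assert_close(mse_loss(x, y, red_int),
                               torch.nn.functional.mse_loss(x, y, reduction=red_str))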
@@ -257,6 +266,14 @@ def mse_loss_backward(grad_output: Tensor, input: Tensor, target: Tensor, reduct
     return norm * (input - target) * grad_output
 
 
+@register_decomposition(aten.huber_loss)
+def huber_loss(self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value, delta: float = 1.0) -> Tensor:
+    assert delta > 0, "huber_loss does not support non-positive values for delta."
+    z = (self - target).abs()
+    loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
+    return apply_loss_reduction(loss, reduction)
+
+
 @register_decomposition(aten.huber_loss_backward)
 def huber_loss_backward(grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float):
     norm = 1. / self.numel() if reduction == Reduction.MEAN.value else 1.
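Note (not part of the diff): the huber_loss decomposition uses the standard piecewise definition, quadratic for |self - target| < delta and linear beyond it, then defers the reduction to apply_loss_reduction. A standalone sketch comparing that piecewise form against the eager kernel (the names below are local to the example):

# Standalone sanity check: the piecewise huber_loss decomposition should agree
# with torch.nn.functional.huber_loss for a few delta values.
import torch

def huber_loss_decomp(self, target, delta=1.0):
    z = (self - target).abs()
    # quadratic branch for |self - target| < delta, linear branch otherwise
    return torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))

x, y = torch.randn(8), torch.randn(8)
for delta in (0.5, 1.0, 2.0):
    expected = torch.nn.functional.huber_loss(x, y, reduction='none', delta=delta)
    torch.testing.assert_close(huber_loss_decomp(x, y, delta), expected)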
@@ -84,7 +84,7 @@ def gen_data(special_op_lists, analysis_name):
 
     ops = yaml.load(open('../../pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
 
-    annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops.txt')))}
+    annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops')))}
     from collections import defaultdict
 
     uniq_ops = []
@@ -160,6 +160,7 @@ def gen_data(special_op_lists, analysis_name):
 
     annotate_ops(ops, is_unique=False)
     with open(f"{analysis_name}", 'w') as f:
+        # import pdb; pdb.set_trace()
         for op in ops:
            info = [
                op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops)
@@ -176,12 +177,18 @@ def full_name_check(lst):
 
 
 # Generates batching rule data
-gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap')
+gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap.txt')
 
 
+def remove_suffix(input_string, suffix):
+    if suffix and input_string.endswith(suffix):
+        return input_string[:-len(suffix)]
+    return input_string
+
+
 if True:
     with open('run_ops.txt', 'r') as f:
-        opinfo_ops = [i.strip() for i in f.readlines()]
+        opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
     with open('run_decompositions.txt', 'r') as f:
-        decomposed_ops = [i.strip() for i in f.readlines()]
-    gen_data([name_check(opinfo_ops), name_check(decomposed_ops)], 'decompositions')
+        decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
+    gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops)], 'decompositions.txt')
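Note (not part of the diff): remove_suffix only strips a trailing '.default', so overload names read from run_ops.txt / run_decompositions.txt line up with the base names gen_data expects; the exact contents of those files are an assumption here. A tiny illustration:

# Entries are assumed to look like "add.default" or "mul.Scalar"; only a
# trailing ".default" is stripped, everything else passes through unchanged.
def remove_suffix(input_string, suffix):
    if suffix and input_string.endswith(suffix):
        return input_string[:-len(suffix)]
    return input_string

assert remove_suffix('add.default', '.default') == 'add'
assert remove_suffix('mul.Scalar', '.default') == 'mul.Scalar'  # other overloads untouched
assert remove_suffix('add', '.default') == 'add'                # bare names pass through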
@@ -1292,10 +1292,6 @@ class TestDecompositionOpInfo(TestCase):
         xfail('linalg.tensorinv'),
         xfail('to_sparse'),
         skip('tensor_split'),
-        skip('mvlgamma', 'mvlgamma_p_1'),
-        skip('mvlgamma', 'mvlgamma_p_3'),
-        skip('mvlgamma', 'mvlgamma_p_5'),
-        skip('eig'),
         skip('nn.functional.dropout'),
         skip('_masked.softmin'),
         skip('_masked.log_softmax'),
@@ -1303,16 +1299,10 @@ class TestDecompositionOpInfo(TestCase):
        skip('_masked.softmax'),
        skip('_masked.normalize'),
        xfail('linalg.lu_factor', ''),
        # Some weird matmul stuff with int64 matmuls
        # inplace op
        skip('resize_'),
        # Weird conj errors
        xfail('fft.hfft2', dtypes=(torch.float32, torch.float64)),
        xfail('fft.hfft', dtypes=(torch.float32, torch.float64)),
        xfail('fft.hfftn', dtypes=(torch.float32, torch.float64)),
        skip('nn.functional.binary_cross_entropy', ''),
        skip('nn.functional.binary_cross_entropy_with_logits', '',),
        skip('nn.functional.huber_loss'),
    })
    def test_decomposition(self, device, dtype, op):
        # dtype is too confusing of a name for how we're using it