[functorch] fixed complex number decompositions and added isnan

This commit is contained in:
Horace He
2022-04-29 03:23:23 +00:00
committed by Jon Janzen
parent aba4a67b04
commit 855f939ce3
3 changed files with 44 additions and 35 deletions

View File

@ -473,10 +473,10 @@ def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
# Remove special case when https://github.com/pytorch/pytorch/pull/72949 is landed.
@register_decomposition(aten.addcmul)
def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
if self.is_floating_point():
return self + value * tensor1 * tensor2
else:
if not self.is_floating_point() and not self.is_complex():
return self + int(value) * tensor1 * tensor2
else:
return self + value * tensor1 * tensor2
@register_decomposition(aten.embedding)
@ -548,7 +548,7 @@ def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
@register_decomposition(aten.addmm)
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
if not self.is_floating_point():
if not self.is_floating_point() and not self.is_complex():
beta = int(beta)
alpha = int(alpha)
out = alpha * torch.mm(mat1, mat2)
@ -704,6 +704,11 @@ def native_layer_norm(input: Tensor, normalized_shape: List[int], weight: Option
return (out, mean, rstd)
@register_decomposition(aten.isnan)
def isnan(self: Tensor) -> Tensor:
return torch.where(self != self, self.new_ones((), dtype=torch.bool), self.new_zeros((), dtype=torch.bool))
@register_decomposition(aten.clamp_min)
def clamp_min(self: Tensor, min: float):
return torch.clamp(self, min=min)
@ -751,31 +756,32 @@ def logical_not(self: Tensor) -> Tensor:
# self * aten.log(other)))
@register_decomposition(aten.var.correction)
def var_decomposition(x: Tensor, dims: Optional[List[int]], correction: int = 0, keepdim: bool = False):
if dims is None:
dims = []
if len(dims) == 0:
n = x.numel()
else:
n = 1
for dim in dims:
n *= x.shape[dim]
# These are both currently incorrect for complex numbers
# @register_decomposition(aten.var.correction)
# def var_decomposition(x: Tensor, dims: Optional[List[int]], correction: int = 0, keepdim: bool = False):
# if dims is None:
# dims = []
# if len(dims) == 0:
# n = x.numel()
# else:
# n = 1
# for dim in dims:
# n *= x.shape[dim]
mean = torch.mean(x, dims, True)
sub = x - mean
sq = sub * sub
sum = torch.sum(sq, dims, keepdim)
# mean = torch.mean(x, dims, True)
# sub = x - mean
# sq = sub * sub
# sum = torch.sum(sq, dims, keepdim)
if correction:
n = n - correction
# if correction:
# n = n - correction
return sum / n
# return sum / n
@register_decomposition(aten.std.correction)
def std_decomposition(x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False):
return torch.sqrt(torch.var(x, dims, correction=correction, keepdim=keepdim))
# @register_decomposition(aten.std.correction)
# def std_decomposition(x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False):
# return torch.sqrt(torch.var(x, dims, correction=correction, keepdim=keepdim))
# Questionable decompositions

View File

@ -3,6 +3,7 @@ import csv
import torch
import sys
import os
from collections import defaultdict
class CapturedOutput(object):
@ -189,6 +190,10 @@ def remove_suffix(input_string, suffix):
if True:
with open('run_ops.txt', 'r') as f:
opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
with open('count_ops.txt', 'r') as f:
opinfo_counts = [i.strip() for i in f.readlines()]
opinfo_counts = defaultdict(int, {k: v for k, v in zip(opinfo_ops, opinfo_counts)})
count_fn = lambda x: opinfo_counts[x['full_name']]
with open('run_decompositions.txt', 'r') as f:
decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops)], 'decompositions.txt')
gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops), count_fn], 'decompositions.txt')

View File

@ -9,6 +9,7 @@ import torch
from torch import Tensor
import functools
import unittest
from collections import defaultdict
from contextlib import contextmanager
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_device_type import ops
@ -1272,15 +1273,13 @@ def ref_vjp_no_create(f, *primals):
run_decompositions = set()
run_ops = set()
run_ops = defaultdict(int)
class TestDecompositionOpInfo(TestCase):
@unittest.skipIf(IS_FBCODE, "__torch_dispatch__ is buggy")
@ops(
functorch_lagging_op_db + additional_op_db,
allowed_dtypes=[torch.float32, torch.float64, torch.float16, torch.bfloat16] + [*integral_types()]
)
# entries in here need don't work and need to be fixed.
# Each one of these is a bug (or needs to be investigated)
@ -1406,7 +1405,7 @@ class TestDecompositionOpInfo(TestCase):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
global run_ops
run_ops.add(func)
run_ops[func] += 1
def unwrap_tensor(e):
if isinstance(e, DecompositionTensor):
@ -1541,13 +1540,12 @@ class TestDecompositionOpInfo(TestCase):
@unittest.skipIf(IS_FBCODE, "__torch_dispatch__ is buggy")
def test_placeholder(self):
global run_ops, run_decompositions
with open('op_analysis/run_ops.txt', 'w') as f:
def get_names(inpt):
return sorted([x.__name__ for x in inpt])
for op in get_names(run_ops):
f.write(f'{op}\n')
with open('op_analysis/run_ops.txt', 'w') as f, open('op_analysis/count_ops.txt', 'w') as g:
for op, count in sorted(run_ops.items(), key=lambda x: x[0].__name__):
f.write(f'{op.__name__}\n')
g.write(f'{count}\n')
with open('op_analysis/run_decompositions.txt', 'w') as f:
for op in get_names(run_decompositions):
for op in sorted([i.__name__ for i in run_decompositions]):
f.write(f'{op}\n')
def test_decompositions_torchscriptable(self, device):