mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] fixed complex number decompositions and added isnan
This commit is contained in:
@ -473,10 +473,10 @@ def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
|
||||
# Remove special case when https://github.com/pytorch/pytorch/pull/72949 is landed.
|
||||
@register_decomposition(aten.addcmul)
|
||||
def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
|
||||
if self.is_floating_point():
|
||||
return self + value * tensor1 * tensor2
|
||||
else:
|
||||
if not self.is_floating_point() and not self.is_complex():
|
||||
return self + int(value) * tensor1 * tensor2
|
||||
else:
|
||||
return self + value * tensor1 * tensor2
|
||||
|
||||
|
||||
@register_decomposition(aten.embedding)
|
||||
@ -548,7 +548,7 @@ def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
|
||||
|
||||
@register_decomposition(aten.addmm)
|
||||
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
|
||||
if not self.is_floating_point():
|
||||
if not self.is_floating_point() and not self.is_complex():
|
||||
beta = int(beta)
|
||||
alpha = int(alpha)
|
||||
out = alpha * torch.mm(mat1, mat2)
|
||||
@ -704,6 +704,11 @@ def native_layer_norm(input: Tensor, normalized_shape: List[int], weight: Option
|
||||
return (out, mean, rstd)
|
||||
|
||||
|
||||
@register_decomposition(aten.isnan)
|
||||
def isnan(self: Tensor) -> Tensor:
|
||||
return torch.where(self != self, self.new_ones((), dtype=torch.bool), self.new_zeros((), dtype=torch.bool))
|
||||
|
||||
|
||||
@register_decomposition(aten.clamp_min)
|
||||
def clamp_min(self: Tensor, min: float):
|
||||
return torch.clamp(self, min=min)
|
||||
@ -751,31 +756,32 @@ def logical_not(self: Tensor) -> Tensor:
|
||||
# self * aten.log(other)))
|
||||
|
||||
|
||||
@register_decomposition(aten.var.correction)
|
||||
def var_decomposition(x: Tensor, dims: Optional[List[int]], correction: int = 0, keepdim: bool = False):
|
||||
if dims is None:
|
||||
dims = []
|
||||
if len(dims) == 0:
|
||||
n = x.numel()
|
||||
else:
|
||||
n = 1
|
||||
for dim in dims:
|
||||
n *= x.shape[dim]
|
||||
# These are both currently incorrect for complex numbers
|
||||
# @register_decomposition(aten.var.correction)
|
||||
# def var_decomposition(x: Tensor, dims: Optional[List[int]], correction: int = 0, keepdim: bool = False):
|
||||
# if dims is None:
|
||||
# dims = []
|
||||
# if len(dims) == 0:
|
||||
# n = x.numel()
|
||||
# else:
|
||||
# n = 1
|
||||
# for dim in dims:
|
||||
# n *= x.shape[dim]
|
||||
|
||||
mean = torch.mean(x, dims, True)
|
||||
sub = x - mean
|
||||
sq = sub * sub
|
||||
sum = torch.sum(sq, dims, keepdim)
|
||||
# mean = torch.mean(x, dims, True)
|
||||
# sub = x - mean
|
||||
# sq = sub * sub
|
||||
# sum = torch.sum(sq, dims, keepdim)
|
||||
|
||||
if correction:
|
||||
n = n - correction
|
||||
# if correction:
|
||||
# n = n - correction
|
||||
|
||||
return sum / n
|
||||
# return sum / n
|
||||
|
||||
|
||||
@register_decomposition(aten.std.correction)
|
||||
def std_decomposition(x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False):
|
||||
return torch.sqrt(torch.var(x, dims, correction=correction, keepdim=keepdim))
|
||||
# @register_decomposition(aten.std.correction)
|
||||
# def std_decomposition(x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False):
|
||||
# return torch.sqrt(torch.var(x, dims, correction=correction, keepdim=keepdim))
|
||||
|
||||
|
||||
# Questionable decompositions
|
||||
|
@ -3,6 +3,7 @@ import csv
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class CapturedOutput(object):
|
||||
@ -189,6 +190,10 @@ def remove_suffix(input_string, suffix):
|
||||
if True:
|
||||
with open('run_ops.txt', 'r') as f:
|
||||
opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
|
||||
with open('count_ops.txt', 'r') as f:
|
||||
opinfo_counts = [i.strip() for i in f.readlines()]
|
||||
opinfo_counts = defaultdict(int, {k: v for k, v in zip(opinfo_ops, opinfo_counts)})
|
||||
count_fn = lambda x: opinfo_counts[x['full_name']]
|
||||
with open('run_decompositions.txt', 'r') as f:
|
||||
decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
|
||||
gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops)], 'decompositions.txt')
|
||||
gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops), count_fn], 'decompositions.txt')
|
||||
|
@ -9,6 +9,7 @@ import torch
|
||||
from torch import Tensor
|
||||
import functools
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_device_type import ops
|
||||
@ -1272,15 +1273,13 @@ def ref_vjp_no_create(f, *primals):
|
||||
|
||||
|
||||
run_decompositions = set()
|
||||
run_ops = set()
|
||||
run_ops = defaultdict(int)
|
||||
|
||||
|
||||
class TestDecompositionOpInfo(TestCase):
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "__torch_dispatch__ is buggy")
|
||||
@ops(
|
||||
functorch_lagging_op_db + additional_op_db,
|
||||
allowed_dtypes=[torch.float32, torch.float64, torch.float16, torch.bfloat16] + [*integral_types()]
|
||||
)
|
||||
# entries in here need don't work and need to be fixed.
|
||||
# Each one of these is a bug (or needs to be investigated)
|
||||
@ -1406,7 +1405,7 @@ class TestDecompositionOpInfo(TestCase):
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
global run_ops
|
||||
run_ops.add(func)
|
||||
run_ops[func] += 1
|
||||
|
||||
def unwrap_tensor(e):
|
||||
if isinstance(e, DecompositionTensor):
|
||||
@ -1541,13 +1540,12 @@ class TestDecompositionOpInfo(TestCase):
|
||||
@unittest.skipIf(IS_FBCODE, "__torch_dispatch__ is buggy")
|
||||
def test_placeholder(self):
|
||||
global run_ops, run_decompositions
|
||||
with open('op_analysis/run_ops.txt', 'w') as f:
|
||||
def get_names(inpt):
|
||||
return sorted([x.__name__ for x in inpt])
|
||||
for op in get_names(run_ops):
|
||||
f.write(f'{op}\n')
|
||||
with open('op_analysis/run_ops.txt', 'w') as f, open('op_analysis/count_ops.txt', 'w') as g:
|
||||
for op, count in sorted(run_ops.items(), key=lambda x: x[0].__name__):
|
||||
f.write(f'{op.__name__}\n')
|
||||
g.write(f'{count}\n')
|
||||
with open('op_analysis/run_decompositions.txt', 'w') as f:
|
||||
for op in get_names(run_decompositions):
|
||||
for op in sorted([i.__name__ for i in run_decompositions]):
|
||||
f.write(f'{op}\n')
|
||||
|
||||
def test_decompositions_torchscriptable(self, device):
|
||||
|
Reference in New Issue
Block a user