#include <torch/csrc/jit/runtime/symbolic_script.h>

#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/runtime/operator.h>

namespace torch::jit {
namespace {
std::mutex lock;
const std::vector<std::string> functions = {
    R"(
        #### HELPER FUNCTIONS ###
        #### PREFIX: AD_ ###
        #### SCHEMA NOT SAVED IN CACHE ###

        def AD_unsqueeze_multiple(t,
                                  dims: List[int],
                                  n_dims: int):
            seen = [False] * n_dims
            for i in range(len(dims)):
                seen[dims[i]] = True

            for d in range(n_dims):
                if seen[d]:
                    t = t.unsqueeze(d)
            return t

        def AD_sum_backward(grad,
                            sizes: List[int],
                            dims: Optional[List[int]],
                            keepdim: bool):
            if not keepdim and len(sizes) > 0:
                if dims is None:
                    return grad.expand(sizes)
                elif len(dims) == 1:
                    return grad.unsqueeze(dims[0]).expand(sizes)
                else:
                    res = AD_unsqueeze_multiple(grad, dims, len(sizes))
                    return res.expand(sizes)
            else:
                return grad.expand(sizes)

        def AD_logsumexp_backward(grad, self, result,
                                  dim: List[int],
                                  keepdim: bool):
            if not keepdim and self.dim() != 0:
                n_dims = len(self.size())
                grad = AD_unsqueeze_multiple(grad, dim, n_dims)
                result = AD_unsqueeze_multiple(result, dim, n_dims)
            return grad * (self - result).exp()

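        # Illustrative example (not part of the original formulas): for a sum over
        # sizes = [2, 3] with dims = [1] and keepdim = False, AD_sum_backward
        # unsqueezes grad from shape [2] to [2, 1] and then expands it to [2, 3],
        # so every element that was reduced receives the same incoming gradient.
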
        def mean_0(self, *, dtype: Optional[int]):
            self_size = self.size()
            self_numel = self.numel()
            self_scalar_type = self.dtype
            def backward(grad_output):
                return grad_output.expand(self_size).to(self_scalar_type) / self_numel, None

            return torch.mean(self, dtype=dtype), backward

        def mean_1(self,
                   dim: Optional[List[int]],
                   keepdim: bool,
                   *,
                   dtype: Optional[int]):
            self_size = self.size()
            self_scalar_type = self.dtype
            def backward(grad_output):
                grad_self = AD_sum_backward(grad_output, self_size, dim, keepdim).to(self_scalar_type) / AD_safe_size(self_size, dim)
                return grad_self, None, None, None

            return torch.mean(self, dim, keepdim, dtype=dtype), backward

        def logsumexp(self,
                      dim: List[int],
                      keepdim: bool):
            result = torch.logsumexp(self, dim, keepdim)
            self_dim = self.dim()
            def backward(grad_output):
                grad_self = AD_logsumexp_backward(grad_output, self, result, dim, keepdim)
                return grad_self, None, None

            return result, backward

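        # Note (added for clarity): the factor (self - result).exp() used in
        # AD_logsumexp_backward is softmax(self) along `dim`, because
        # exp(self - logsumexp(self)) == exp(self) / sum(exp(self)).
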
def AD_bool_to_int(b: bool):
|
|
# FIXME: torchscript: int - bool
|
|
if b:
|
|
i = 1
|
|
else:
|
|
i = 0
|
|
return i
|
|
|
|
def AD_var_backward_0(grad, self, correction: number):
|
|
# FIXME: torchscript: div(float, float)
|
|
return grad * (self - self.mean()) * 2.0 / (self.numel() - correction)
|
|
|
|
def AD_safe_size(sizes: List[int],
|
|
dims: Optional[List[int]]):
|
|
if len(sizes) == 0:
|
|
return 1
|
|
|
|
size = 1
|
|
|
|
if dims is None:
|
|
for s in sizes:
|
|
size *= s
|
|
|
|
else:
|
|
for i in range(len(dims)):
|
|
d = dims[i]
|
|
size *= sizes[d]
|
|
|
|
return size
|
|
|
|
def AD_var_backward_1(grad,
|
|
self,
|
|
dim: List[int],
|
|
correction: number,
|
|
keepdim: bool):
|
|
if self.dim() == 0:
|
|
return AD_var_backward_0(grad, self, correction)
|
|
self_size = self.size()
|
|
if not keepdim and self.dim() > 1:
|
|
grad = AD_unsqueeze_multiple(grad, dim, len(self_size))
|
|
|
|
# FIXME: torchscript: div(float, float)
|
|
return grad * (self - self.mean(dim, True)) * 2.0 / (AD_safe_size(self_size, dim) - correction)
|
|
|
|
def AD_var_backward_2(grad,
|
|
self,
|
|
dim: Optional[List[int]],
|
|
correction: Optional[number],
|
|
keepdim: bool):
|
|
if correction is None:
|
|
correction = 1
|
|
if self.dim() == 0 or dim is None:
|
|
return AD_var_backward_0(grad, self, correction)
|
|
|
|
return AD_var_backward_1(grad, self, dim, correction, keepdim)
|
|
|
|
def std_0(self,
|
|
unbiased: bool=True):
|
|
std_out = torch.std(self, unbiased)
|
|
def backward(grad_output):
|
|
correction = AD_bool_to_int(unbiased)
|
|
grad_self = AD_var_backward_0(grad_output / (std_out * 2), self, correction)
|
|
return grad_self, None
|
|
|
|
return std_out, backward
|
|
|
|
def std_1(self,
|
|
dim: Optional[List[int]],
|
|
unbiased: bool,
|
|
keepdim: bool):
|
|
std_out = torch.std(self, dim, unbiased, keepdim)
|
|
def backward(grad_output):
|
|
correction = AD_bool_to_int(unbiased)
|
|
grad_self = AD_var_backward_2(grad_output / (std_out * 2), self, dim, correction, keepdim)
|
|
return grad_self, None, None, None
|
|
|
|
return std_out, backward
|
|
|
|
def std_2(self,
|
|
dim: Optional[List[int]],
|
|
*,
|
|
correction: Optional[number],
|
|
keepdim: bool):
|
|
std_out = torch.std(self, dim, correction=correction, keepdim=keepdim)
|
|
def backward(grad_output):
|
|
grad_self = AD_var_backward_2(grad_output / (std_out * 2), self, dim, correction, keepdim)
|
|
return grad_self, None, None, None
|
|
|
|
return std_out, backward
|
|
|
|
def var_0(self,
|
|
unbiased: bool=True):
|
|
def backward(grad_output):
|
|
correction = AD_bool_to_int(unbiased)
|
|
grad_self = AD_var_backward_0(grad_output, self, correction)
|
|
return grad_self, None
|
|
|
|
return torch.var(self, unbiased), backward
|
|
|
|
def var_1(self,
|
|
dim: Optional[List[int]],
|
|
unbiased: bool,
|
|
keepdim: bool):
|
|
def backward(grad_output):
|
|
correction = AD_bool_to_int(unbiased)
|
|
grad_self = AD_var_backward_2(grad_output, self, dim, correction, keepdim)
|
|
return grad_self, None, None, None
|
|
|
|
return torch.var(self, dim, unbiased, keepdim), backward
|
|
|
|
def var_2(self,
|
|
dim: Optional[List[int]],
|
|
*,
|
|
correction: Optional[number],
|
|
keepdim: bool):
|
|
def backward(grad_output):
|
|
grad_self = AD_var_backward_2(grad_output, self, dim, correction, keepdim)
|
|
return grad_self, None, None, None
|
|
|
|
return torch.var(self, dim, correction=correction, keepdim=keepdim), backward
|
|
|
|
def tanh(self):
|
|
output = torch.tanh(self)
|
|
def backward(grad_output):
|
|
return grad_output * (1 - output * output)
|
|
|
|
return output, backward
|
|
|
|
def AD_index_select_backward(grad,
|
|
dim: int,
|
|
indices,
|
|
sizes: List[int],
|
|
keepdim: bool):
|
|
if not keepdim and len(sizes) > 0:
|
|
grad = grad.unsqueeze(dim)
|
|
indices = indices.unsqueeze(dim)
|
|
|
|
# FIXME: torchscript: torch.zeros(sizes, grad.options())
|
|
return torch.zeros(sizes).to(grad).scatter_(dim, indices, grad)
|
|
|
|
# def topk(self,
|
|
# k: int,
|
|
# dim: int = -1,
|
|
# largest: bool = True,
|
|
# sorted: bool = True):
|
|
# result0, result1 = torch.topk(self, k, dim, largest, sorted)
|
|
# self_size = self.size()
|
|
# def backward(grad_output):
|
|
# grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, True)
|
|
# return grad_self, None, None, None, None
|
|
|
|
# return result0, result1, backward
|
|
|
|
# def kthvalue(self,
|
|
# k: int,
|
|
# dim: int,
|
|
# keepdim: bool):
|
|
# result0, result1 = torch.kthvalue(self, k, dim, keepdim)
|
|
# self_size = self.size()
|
|
# def backward(grad_output):
|
|
# grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, keepdim)
|
|
# return grad_self, None, None, None
|
|
|
|
# return result0, result1, backward
|
|
|
|
def AD_mm_backward_self(grad, mat2):
|
|
return grad.mm(mat2.t())
|
|
|
|
def AD_mm_backward_mat2(grad, self):
|
|
return self.t().mm(grad)
|
|
|
|
def mm(self, mat2):
|
|
def backward(grad_output):
|
|
grad_self = AD_mm_backward_self(grad_output, mat2)
|
|
grad_mat2 = AD_mm_backward_mat2(grad_output, self)
|
|
return grad_self, grad_mat2
|
|
|
|
return torch.mm(self, mat2), backward
|
|
|
|
def AD_permute_backward(grad,
|
|
fwd_dims: List[int]):
|
|
ndims = len(fwd_dims)
|
|
dims = [0] * ndims
|
|
|
|
for i in range(ndims):
|
|
dims[fwd_dims[i]] = i
|
|
|
|
return grad.permute(dims)
|
|
|
|
def permute(self,
|
|
dims: List[int]):
|
|
def backward(grad_output):
|
|
grad_self = AD_permute_backward(grad_output, dims)
|
|
return grad_self, None
|
|
|
|
return torch.permute(self, dims), backward
|
|
|
|
def AD_select_backward(grad,
|
|
input_sizes: List[int],
|
|
dim: int,
|
|
index: int):
|
|
# FIXME: torchscript: torch.zeros(sizes, grad.options())
|
|
grad_input = torch.zeros(input_sizes).to(grad)
|
|
grad_input.select(dim, index).copy_(grad)
|
|
return grad_input
|
|
|
|
# TODO: fix torch.zeros(sizes, grad.options()) before enabling select, topk, kthvalue
|
|
# def select(self,
|
|
# dim: int,
|
|
# index: int):
|
|
# self_size = self.size()
|
|
# def backward(grad_output):
|
|
# grad_self = AD_select_backward(grad_output, self_size, dim, index)
|
|
# return grad_self, None, None
|
|
|
|
# return torch.select(self, dim, index), backward
|
|
|
|
def AD_slice_backward(grad,
|
|
input_sizes: List[int],
|
|
dim: int,
|
|
start: int,
|
|
end: int,
|
|
step: int):
|
|
# FIXME: torchscript: torch.zeros(sizes, grad.options())
|
|
grad_input = torch.zeros(input_sizes).to(grad)
|
|
grad_input.slice(dim, start, end, step).copy_(grad)
|
|
return grad_input
|
|
|
|
# DON'T enable slice unless we can correctly handle view ops in graph executor.
|
|
# It triggers failure of TestJit.test_sample in test_distributions.py.
|
|
# def slice(self,
|
|
# dim: int=0,
|
|
# start: int=0,
|
|
# end: int=9223372036854775807,
|
|
# step: int=1):
|
|
# def backward(grad_output):
|
|
# grad_self = AD_slice_backward(grad_output, self.size(), dim, start, end, step)
|
|
# return grad_self, None, None, None, None
|
|
|
|
# return torch.slice(self, dim, start, end, step), backward
|
|
|
|
def AD_unsqueeze_to_0(self,
|
|
sizes: List[int]):
|
|
ndims = len(sizes)
|
|
for i in range(ndims):
|
|
if sizes[i] == 1:
|
|
self = self.unsqueeze(i)
|
|
|
|
return self
|
|
|
|
def AD_unsqueeze_to_1(self,
|
|
dim: int,
|
|
sizes: List[int]):
|
|
if len(sizes) > 0 and sizes[dim] == 1:
|
|
return self.unsqueeze(dim)
|
|
return self
|
|
|
|
def squeeze_0(self):
|
|
self_size = self.size()
|
|
def backward(grad_output):
|
|
grad_self = AD_unsqueeze_to_0(grad_output, self_size)
|
|
return grad_self
|
|
|
|
return torch.squeeze(self), backward
|
|
|
|
def squeeze_1(self,
|
|
dim: int):
|
|
self_size = self.size()
|
|
def backward(grad_output):
|
|
grad_self = AD_unsqueeze_to_1(grad_output, dim, self_size)
|
|
return grad_self, None
|
|
|
|
return torch.squeeze(self, dim), backward
|
|
|
|
        def AD_infer_size(a: List[int],
                          b: List[int]):
            dimsA = len(a)
            dimsB = len(b)

            ndim = dimsA if dimsA > dimsB else dimsB
            expand_sizes = [0] * ndim

            for i in range(ndim):
                idx = - i - 1
                sizeA = a[dimsA + idx] if (dimsA + idx) >= 0 else 1
                sizeB = b[dimsB + idx] if (dimsB + idx) >= 0 else 1

                # Assert sizeA == sizeB or sizeA == 1 or sizeB == 1
                expand_sizes[ndim + idx] = sizeB if sizeA == 1 else sizeA

            return expand_sizes

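        # Illustrative example (added for clarity): AD_infer_size([2, 3, 4], [3, 1])
        # aligns both shapes from the right, treats missing leading dimensions as
        # size 1, and returns the broadcast shape [2, 3, 4].
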
def AD_bmm_backward_self(grad, mat2):
|
|
return grad.bmm(mat2.transpose(1, 2))
|
|
|
|
def AD_bmm_backward_mat2(grad, self):
|
|
return self.transpose(1, 2).bmm(grad)
|
|
|
|
def bmm(self, mat2):
|
|
def backward(grad_output):
|
|
grad_self = AD_bmm_backward_self(grad_output, mat2)
|
|
grad_mat2 = AD_bmm_backward_mat2(grad_output, self)
|
|
return grad_self, grad_mat2
|
|
return torch.bmm(self, mat2), backward
|
|
)",
|
|
R"(
|
|
def AD_mat_transpose(mat):
|
|
dim = mat.dim()
|
|
if dim == 1:
|
|
out = mat
|
|
elif dim == 2:
|
|
out = mat.t()
|
|
else:
|
|
dims = rangelist(dim)
|
|
dims[-1] = dim - 2
|
|
dims[-2] = dim - 1
|
|
out = mat.permute(dims)
|
|
return out
|
|
|
|
# In matmul backward case of [b, m, n] * [b, n, p] => [m, p],
|
|
# instead of doing [b, m, p] and then reduce to [m, p]
|
|
# which potentially uses large intermediate of size b*m*p,
|
|
# we do [m, bn] * [bn, p] to avoid having the large
|
|
# intermediate, thus reduces max memory usage.
|
|
def AD_matmul_bw_special_fold(mat1, mat2):
|
|
mat1_transpose = AD_mat_transpose(mat1)
|
|
mat1_fold = mat1_transpose.reshape(-1, mat1_transpose.size()[-1])
|
|
mat2_fold = mat2.reshape(-1, mat2.size()[-1])
|
|
return mat1_fold.t().mm(mat2_fold)
|
|
|
|
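        # Illustrative shape walk-through (added for clarity, not part of the
        # original formulas): for mat1 of shape [b, x, y] and mat2 of shape
        # [b, y, z], AD_matmul_bw_special_fold computes an [x, b*y] @ [b*y, z]
        # product, which equals sum(mat1[i] @ mat2[i] for i in range(b)) of shape
        # [x, z], without ever materializing a [b, x, z] intermediate.
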
def AD_matmul_bw_size(mat1, mat2,
|
|
out_size: List[int]):
|
|
dim1 = mat1.dim()
|
|
dim2 = mat2.dim()
|
|
dim_out = len(out_size)
|
|
if dim1 == 0 or dim2 == 0:
|
|
out = mat1 * mat2
|
|
elif dim_out == 2 and dim1 == dim2 and dim1 >=3:
|
|
out = AD_matmul_bw_special_fold(mat1, mat2)
|
|
elif dim_out == 1 and dim1 - dim2 == 1 and dim1 >= 3:
|
|
mat2_unsqueeze = mat2.unsqueeze(-1)
|
|
out = AD_matmul_bw_special_fold(mat1, mat2_unsqueeze)
|
|
out = out.squeeze(-1)
|
|
elif dim1 + dim2 == dim_out:
|
|
if dim2 == 1:
|
|
target_dim2 = 0
|
|
else:
|
|
target_dim2 = -2
|
|
out = torch.matmul(mat1.unsqueeze(dim1), mat2.unsqueeze(target_dim2))
|
|
elif dim_out == dim1 - dim2:
|
|
out = torch.matmul(mat1, mat2.unsqueeze(dim2)).squeeze(-1)
|
|
elif dim_out == dim2 - dim1:
|
|
out = torch.matmul(mat1.unsqueeze(-2), mat2).squeeze(-2)
|
|
else:
|
|
out = torch.matmul(mat1, mat2)
|
|
return out
|
|
|
|
def matmul(self, other):
|
|
def backward(grad_output):
|
|
self_size = self.size()
|
|
other_size = other.size()
|
|
grad_self = AD_matmul_bw_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
|
|
grad_other = AD_matmul_bw_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
|
|
return grad_self, grad_other
|
|
|
|
return torch.matmul(self, other), backward
|
|
|
|
        def linear(input : Tensor,
                   weight : Tensor,
                   bias : Optional[Tensor]):
            result = torch.linear(input, weight, bias)

            def backward(grad_output):
                if bias is not None:
                    grad_bias = grad_output._grad_sum_to_size(bias.size())
                else:
                    grad_bias = None

                weight_size = weight.size()
                grad_input = torch.matmul(grad_output, weight)
                grad_weight = torch.matmul(grad_output.reshape(-1, weight_size[0]).t(), input.reshape(-1, weight_size[1]))
                # Note: calling unchecked_unwrap_optional is only safe when we
                # return grad_bias directly back to bias, because in the case
                # where `bias is None` the unwrapped grad_bias is simply pruned away.
                return grad_input, grad_weight, grad_bias.unchecked_unwrap_optional
            return result, backward
    )",
|
|
R"(
|
|
def addcmul(self,
|
|
tensor1,
|
|
tensor2,
|
|
*,
|
|
value: number):
|
|
result = torch.addcmul(self, tensor1, tensor2, value=value)
|
|
self_size = torch._size_if_not_equal(self.size(), result.size())
|
|
tensor1_size = torch._size_if_not_equal(tensor1.size(), result.size())
|
|
tensor2_size = torch._size_if_not_equal(tensor2.size(), result.size())
|
|
def backward(grad_output):
|
|
grad = grad_output * value
|
|
grad_tensor1 = (grad * tensor2)._grad_sum_to_size(tensor1_size)
|
|
grad_tensor2 = (grad * tensor1)._grad_sum_to_size(tensor2_size)
|
|
return grad_output._grad_sum_to_size(self_size), grad_tensor1, grad_tensor2, None
|
|
return result, backward
|
|
|
|
def _autocast_to_full_precision(self, cuda_enabled : bool, cpu_enabled : bool):
|
|
self_dtype = self.dtype
|
|
def backward(grad_output):
|
|
return grad_output.to(self_dtype), None, None
|
|
|
|
return torch._autocast_to_full_precision(self, cuda_enabled, cpu_enabled), backward
|
|
|
|
def _autocast_to_reduced_precision(self,
|
|
cuda_enabled : bool,
|
|
cpu_enabled : bool,
|
|
cuda_dtype : int,
|
|
cpu_dtype : int):
|
|
self_dtype = self.dtype
|
|
def backward(grad_output):
|
|
return grad_output.to(self_dtype), None, None, None, None
|
|
|
|
return torch._autocast_to_reduced_precision(self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype), backward
|
|
|
|
def _dim_arange(like,
|
|
dim: int):
|
|
def backward(grad_output):
|
|
return None, None
|
|
|
|
return torch._dim_arange(like, dim), backward
|
|
|
|
def contiguous(self, *, memory_format: int=0):
|
|
def backward(grad_output):
|
|
return grad_output, None
|
|
|
|
return self.contiguous(memory_format=memory_format), backward
|
|
|
|
def dot(self, tensor):
|
|
def backward(grad_output):
|
|
return grad_output * tensor, grad_output * self
|
|
|
|
return torch.dot(self, tensor), backward
|
|
|
|
def erf(self):
|
|
def backward(grad_output):
|
|
# Precomputed constant C = 2.0 / math.sqrt(math.pi)
|
|
C = 1.1283791670955126
|
|
return C * torch.exp(- self * self) * grad_output
|
|
|
|
return torch.erf(self), backward
|
|
|
|
def expand(self,
|
|
size: List[int],
|
|
*,
|
|
implicit: bool=False):
|
|
result = torch.expand(self, size, implicit=implicit)
|
|
self_size = torch._size_if_not_equal(self.size(), result.size())
|
|
|
|
def backward(grad_output):
|
|
return grad_output._grad_sum_to_size(self_size), None, None
|
|
|
|
return result, backward
|
|
|
|
def expand_as(self, other):
|
|
result = torch.expand_as(self, other)
|
|
self_size = torch._size_if_not_equal(self.size(), result.size())
|
|
|
|
def backward(grad_output):
|
|
return grad_output._grad_sum_to_size(self_size), None
|
|
|
|
return result, backward
|
|
|
|
def full_like(self,
|
|
fill_value: float):
|
|
def backward(grad_output):
|
|
return None, None
|
|
|
|
return torch.full_like(self, fill_value, memory_format=1), backward
|
|
|
|
def lerp_0(self,
|
|
end,
|
|
weight: number):
|
|
result = torch.lerp(self, end, weight)
|
|
self_size = torch._size_if_not_equal(self.size(), result.size())
|
|
end_size = torch._size_if_not_equal(end.size(), result.size())
|
|
|
|
def backward(grad_output):
|
|
grad_self = (grad_output * (1 - float(weight)))._grad_sum_to_size(self_size)
|
|
grad_end = (grad_output * float(weight))._grad_sum_to_size(end_size)
|
|
return grad_self, grad_end, None
|
|
return result, backward
|
|
|
|
def lerp_1(self,
|
|
end,
|
|
weight):
|
|
result = torch.lerp(self, end, weight)
|
|
self_size = torch._size_if_not_equal(self.size(), result.size())
|
|
end_size = torch._size_if_not_equal(end.size(), result.size())
|
|
weight_size = torch._size_if_not_equal(weight.size(), result.size())
|
|
|
|
def backward(grad_output):
|
|
grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self_size)
|
|
grad_end = (grad_output * weight)._grad_sum_to_size(end_size)
|
|
grad_weight = (grad_output * (end - self))._grad_sum_to_size(weight_size)
|
|
return grad_self, grad_end, grad_weight
|
|
|
|
return result, backward
|
|
|
|
def reshape(self,
|
|
shape: List[int]):
|
|
self_size = self.size()
|
|
|
|
def backward(grad_output):
|
|
return grad_output.reshape(self_size), None
|
|
|
|
return torch.reshape(self, shape), backward
|
|
|
|
def split(self,
|
|
split_size: int,
|
|
dim: int):
|
|
def backward(grad_outputs: List[Tensor]):
|
|
grad_self = torch.cat(grad_outputs, dim)
|
|
return grad_self, None, None
|
|
|
|
return torch.split(self, split_size, dim), backward
|
|
|
|
def split_with_sizes(self,
|
|
split_sizes: List[int],
|
|
dim: int):
|
|
def backward(grad_outputs: List[Tensor]):
|
|
size = len(grad_outputs)
|
|
grad_self = torch.cat(grad_outputs, dim)
|
|
return grad_self, None, None
|
|
|
|
return torch.split_with_sizes(self, split_sizes, dim), backward
|
|
|
|
def stack(tensors: List[Tensor],
|
|
dim: int=0):
|
|
def backward(grad_output):
|
|
grad_tensors = torch.unbind(grad_output, dim)
|
|
return grad_tensors, None
|
|
|
|
return torch.stack(tensors, dim), backward
|
|
|
|
def unbind(self,
|
|
dim: int):
|
|
def backward(grad_outputs: List[Tensor]):
|
|
grad_self = torch.stack(grad_outputs, dim)
|
|
return grad_self, None
|
|
|
|
return torch.unbind(self, dim), backward
|
|
|
|
def cat(tensors: List[Tensor],
|
|
dim: int):
|
|
size = len(tensors)
|
|
split_sizes = [0] * size
|
|
for i in range(size):
|
|
if tensors[i].size() != [0]:
|
|
split_sizes[i] = tensors[i].size()[dim]
|
|
|
|
def backward(grad_output):
|
|
grad_tensors = torch.split_with_sizes(grad_output, split_sizes, dim)
|
|
return grad_tensors, None
|
|
|
|
return torch.cat(tensors, dim), backward
|
|
|
|
def index(self,
|
|
indices: List[Tensor]):
|
|
def backward(grad_output):
|
|
grad_self = torch.zeros_like(self, memory_format=1).index_put_(indices, grad_output, True)
|
|
return grad_self, None
|
|
|
|
return torch.index(self, indices), backward
|
|
|
|
def meshgrid(tensors: List[Tensor]):
|
|
size = len(tensors)
|
|
sizes = [0] * size
|
|
for i in range(size):
|
|
if tensors[i].dim() != 0:
|
|
sizes[i] = tensors[i].size()[0]
|
|
def backward(grad_outputs: List[Tensor]):
|
|
grads_tensors = []
|
|
for i in range(size):
|
|
view_shape = [1] * size
|
|
if sizes[i] == 0:
|
|
view_shape[i] = 1
|
|
grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape(()))
|
|
else:
|
|
view_shape[i] = sizes[i]
|
|
grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape([sizes[i]]))
|
|
return grads_tensors
|
|
return torch.meshgrid(tensors), backward
|
|
|
|
def mv(self, vec):
|
|
def backward(grad_output):
|
|
return grad_output.ger(vec), self.t().mv(grad_output)
|
|
|
|
return torch.mv(self, vec), backward
|
|
|
|
def nonzero(self):
|
|
def backward(grad_output):
|
|
return None
|
|
|
|
return torch.nonzero(self), backward
|
|
|
|
def ones_like(self):
|
|
def backward(grad_output):
|
|
return None
|
|
|
|
return torch.ones_like(self, memory_format=1), backward
|
|
|
|
def pow_0(self,
|
|
exponent: number):
|
|
def backward(grad_output):
|
|
if float(exponent) == 0.0:
|
|
grad_self = torch.zeros_like(self, memory_format=1)
|
|
else:
|
|
grad_self = grad_output * exponent * torch.pow(self, float(exponent) - 1)
|
|
return grad_self, None
|
|
|
|
return torch.pow(self, exponent), backward
|
|
|
|
def pow_1(self, exponent):
|
|
result = torch.pow(self, exponent)
|
|
self_size = torch._size_if_not_equal(self.size(), result.size())
|
|
exponent_size = torch._size_if_not_equal(exponent.size(), result.size())
|
|
|
|
def backward(grad_output):
|
|
grad_self = torch.where(exponent == 0.0, torch.zeros_like(self, memory_format=1), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self_size)
|
|
grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent_size)
|
|
return grad_self, grad_exponent
|
|
|
|
return result, backward
|
|
|
|
def pow_2(self: number,
|
|
exponent):
|
|
def backward(grad_output):
|
|
grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(float(self))
|
|
return None, grad_exponent
|
|
|
|
return torch.pow(self, exponent), backward
|
|
|
|
def rsub_0(self,
|
|
other,
|
|
alpha: number):
|
|
result = torch.rsub(self, other, alpha=alpha)
|
|
self_size = torch._size_if_not_equal(self.size(), result.size())
|
|
other_size = torch._size_if_not_equal(other.size(), result.size())
|
|
def backward(grad_output):
|
|
grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
|
|
grad_other = (grad_output)._grad_sum_to_size(other_size)
|
|
return grad_self, grad_other, None
|
|
|
|
return result, backward
|
|
|
|
def rsub_1(self,
|
|
other: number,
|
|
alpha: number):
|
|
def backward(grad_output):
|
|
grad_self = (- grad_output * alpha)
|
|
return grad_self, None, None
|
|
|
|
return torch.rsub(self, other, alpha), backward
|
|
|
|
def sqrt(self):
|
|
result = torch.sqrt(self)
|
|
def backward(grad_output):
|
|
return grad_output / (2 * result)
|
|
|
|
return result, backward
|
|
|
|
def t(self):
|
|
def backward(grad_output):
|
|
return torch.t(grad_output)
|
|
|
|
return torch.t(self), backward
|
|
|
|
def to_0(self,
|
|
device: Optional[Device],
|
|
dtype: Optional[int],
|
|
non_blocking: bool,
|
|
copy: bool):
|
|
self_device = self.device
|
|
self_dtype = self.dtype
|
|
if device is not None:
|
|
result = self.to(device, dtype=dtype, non_blocking=non_blocking, copy=copy)
|
|
else:
|
|
result = self.to(dtype, non_blocking=non_blocking, copy=copy)
|
|
def backward(grad_output):
|
|
grad_self = grad_output.to(self_device, dtype=self_dtype, non_blocking=non_blocking, copy=copy)
|
|
return grad_self, None, None, None, None
|
|
|
|
return result, backward
|
|
|
|
|
|
def to_1(self,
|
|
dtype: int,
|
|
non_blocking: bool,
|
|
copy: bool):
|
|
self_dtype = self.dtype
|
|
def backward(grad_output):
|
|
grad_self = grad_output.to(self_dtype, non_blocking, copy)
|
|
return grad_self, None, None, None
|
|
|
|
return self.to(dtype=dtype, non_blocking=non_blocking, copy=copy), backward
|
|
|
|
def to_2(self,
|
|
other,
|
|
non_blocking: bool,
|
|
copy: bool):
|
|
def backward(grad_output):
|
|
grad_self = grad_output.to(self, non_blocking, copy)
|
|
return grad_self, None, None, None
|
|
|
|
return self.to(other, non_blocking=non_blocking, copy=copy), backward
|
|
|
|
def transpose(self,
|
|
dim0: int,
|
|
dim1: int):
|
|
def backward(grad_output):
|
|
return torch.transpose(grad_output, dim0, dim1), None, None
|
|
|
|
return torch.transpose(self, dim0, dim1), backward
|
|
|
|
def view(self,
|
|
size: List[int]):
|
|
self_size = self.size()
|
|
def backward(grad_output):
|
|
return grad_output.reshape(self_size), None
|
|
|
|
return torch.view(self, size), backward
|
|
)",
|
|
R"(
|
|
def AD_sizes_if_not_equal_multi_0(t1, t2, res):
|
|
return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())
|
|
|
|
def mul_0(self, other):
|
|
result = self * other
|
|
self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
|
|
|
|
def backward(grad_output):
|
|
grad_self = (grad_output * other)._grad_sum_to_size(self_size)
|
|
grad_other = (grad_output * self)._grad_sum_to_size(other_size)
|
|
return grad_self, grad_other
|
|
|
|
return result, backward
|
|
|
|
def mul_1(self, other: number):
|
|
def backward(grad_output):
|
|
return grad_output * other, None
|
|
return self * other, backward
|
|
|
|
def div_0(self, other):
|
|
result = self / other
|
|
self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
|
|
|
|
def backward(grad_output):
|
|
grad_self = (grad_output / other)._grad_sum_to_size(self_size)
|
|
grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other_size)
|
|
return grad_self, grad_other
|
|
|
|
return result, backward
|
|
|
|
def div_1(self, other: number):
|
|
def backward(grad_output):
|
|
return grad_output / other, None
|
|
return self / other, backward
|
|
|
|
def div_2(self, other, *, rounding_mode: Optional[str]):
|
|
result = torch.div(self, other, rounding_mode=rounding_mode)
|
|
self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
|
|
def backward(grad_output):
|
|
if rounding_mode is None:
|
|
grad_self = (grad_output / other)._grad_sum_to_size(self_size)
|
|
grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other_size)
|
|
else:
|
|
grad_self = torch.zeros_like(self)
|
|
grad_other = torch.zeros_like(other)
|
|
|
|
return grad_self, grad_other, None
|
|
|
|
return result, backward
|
|
|
|
def div_3(self, other: number, *, rounding_mode: Optional[str]):
|
|
result = torch.div(self, other, rounding_mode=rounding_mode)
|
|
def backward(grad_output):
|
|
if rounding_mode is None:
|
|
grad_self = (grad_output / other)
|
|
else:
|
|
grad_self = torch.zeros_like(self, memory_format=1)
|
|
return grad_self, None, None
|
|
return result, backward
|
|
|
|
def max(self, other):
|
|
result = torch.max(self, other)
|
|
self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
|
|
|
|
def backward(grad_output):
|
|
grad_self = (grad_output * (self > other).type_as(grad_output))._grad_sum_to_size(self_size)
|
|
grad_other = (grad_output * (other > self).type_as(grad_output))._grad_sum_to_size(other_size)
|
|
return grad_self, grad_other
|
|
|
|
return result, backward
|
|
|
|
def min(self, other):
|
|
def backward(grad_output):
|
|
grad_self = (grad_output * (self < other).type_as(grad_output))._grad_sum_to_size(self.size())
|
|
grad_other = (grad_output * (other < self).type_as(grad_output))._grad_sum_to_size(other.size())
|
|
return grad_self, grad_other
|
|
|
|
return torch.min(self, other), backward
|
|
|
|
def sigmoid(self):
|
|
result = torch.sigmoid(self)
|
|
def backward(grad_output):
|
|
return (1 - result) * result * grad_output
|
|
|
|
return result, backward
|
|
|
|
# Share backward with threshold
|
|
def relu(self):
|
|
result = torch.relu(self)
|
|
def backward(grad_output):
|
|
return grad_output * (result > 0).type_as(result)
|
|
|
|
return result, backward
|
|
|
|
def relu6(self):
|
|
result = torch.relu6(self)
|
|
def backward(grad_output):
|
|
return grad_output * ((result > 0) & (result < 6.0))
|
|
|
|
return result, backward
|
|
|
|
def leaky_relu(self, negative_slope: number):
|
|
result = torch.leaky_relu(self, negative_slope)
|
|
def backward(grad_output):
|
|
return grad_output * torch.where(self > 0, 1.0, negative_slope).type_as(result), None
|
|
return result, backward
|
|
|
|
def gelu(self : Tensor, *, approximate : str):
|
|
result = torch.gelu(self, approximate=approximate)
|
|
def backward(grad_output):
|
|
return torch.gelu_backward(grad_output, self, approximate=approximate), None
|
|
return result, backward
|
|
|
|
def silu(self):
|
|
result = torch.silu(self)
|
|
def backward(grad_output):
|
|
input_sigmoid = torch.sigmoid(self)
|
|
return grad_output * (input_sigmoid * (1 + self * (1 - input_sigmoid)))
|
|
return result, backward
|
|
|
|
        def hardswish(self):
            result = torch.hardswish(self)
            def backward(grad_output):
                m = (self >= 3.).type_as(result)
                m = torch.where((self > -3.) & (self < 3.), self / 3. + .5, m)
                return grad_output * m
            return result, backward

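        # Gradient note (added for clarity): the mask above implements the piecewise
        # derivative of hardswish, i.e. 0 for self <= -3, self / 3 + 0.5 for
        # -3 < self < 3, and 1 for self >= 3, the convention used by
        # torch.nn.functional.hardswish.
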
        def hardsigmoid(self):
            result = torch.hardsigmoid(self)
            def backward(grad_output):
                m = (self > -3.) & (self < 3.)
                lhs = grad_output * (1.0 / 6.0)
                return torch.where(m, lhs, m.type_as(self))
            return result, backward

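        # Gradient note (added for clarity): hardsigmoid(x) = clamp(x / 6 + 0.5, 0, 1),
        # so its derivative is 1/6 strictly inside (-3, 3) and 0 elsewhere, which is
        # exactly what the mask `m` selects above.
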
def erfc(self):
|
|
def backward(grad_output):
|
|
# Precomputed constant C = -2.0 / math.sqrt(math.pi)
|
|
C = -1.1283791670955126
|
|
return C * torch.exp(-self * self) * grad_output
|
|
|
|
return torch.erfc(self), backward
|
|
|
|
def exp(self):
|
|
result = torch.exp(self)
|
|
def backward(grad_output):
|
|
return grad_output * result
|
|
|
|
return result, backward
|
|
|
|
def neg(self):
|
|
def backward(grad_output):
|
|
return grad_output.neg()
|
|
|
|
return torch.neg(self), backward
|
|
|
|
def where(condition, self, other):
|
|
result = torch.where(condition, self, other)
|
|
self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
|
|
def backward(grad_output):
|
|
grad_self = (grad_output * condition.type_as(grad_output))._grad_sum_to_size(self_size)
|
|
grad_other = (grad_output * (condition.bitwise_not()).type_as(grad_output))._grad_sum_to_size(other_size)
|
|
return None, grad_self, grad_other
|
|
|
|
return result, backward
|
|
|
|
def type_as(self, other):
|
|
def backward(grad_output):
|
|
return grad_output.type_as(self), None
|
|
|
|
return torch.type_as(self, other), backward
|
|
|
|
def unsqueeze(self, dim: int):
|
|
def backward(grad_output):
|
|
return grad_output.squeeze(dim), None
|
|
|
|
return torch.unsqueeze(self, dim), backward
|
|
|
|
def abs(self):
|
|
def backward(grad_output):
|
|
return grad_output * self.sign()
|
|
|
|
return torch.abs(self), backward
|
|
|
|
def acos(self):
|
|
def backward(grad_output):
|
|
return grad_output * -((-self * self + 1).rsqrt())
|
|
|
|
return torch.acos(self), backward
|
|
|
|
def asin(self):
|
|
def backward(grad_output):
|
|
return grad_output * (-self * self + 1).rsqrt()
|
|
|
|
return torch.asin(self), backward
|
|
|
|
def atan(self):
|
|
def backward(grad_output):
|
|
return grad_output / (self * self + 1)
|
|
|
|
return torch.atan(self), backward
|
|
|
|
def ceil(self):
|
|
def backward(grad_output):
|
|
return torch.zeros_like(grad_output, memory_format=1)
|
|
|
|
return torch.ceil(self), backward
|
|
|
|
def cos(self):
|
|
def backward(grad_output):
|
|
return grad_output * -self.sin()
|
|
|
|
return torch.cos(self), backward
|
|
|
|
def cosh(self):
|
|
def backward(grad_output):
|
|
return grad_output * self.sinh()
|
|
|
|
return torch.cosh(self), backward
|
|
|
|
def expm1(self):
|
|
result = torch.expm1(self)
|
|
def backward(grad_output):
|
|
return grad_output * (result + 1)
|
|
|
|
return result, backward
|
|
|
|
def floor(self):
|
|
def backward(grad_output):
|
|
return torch.zeros_like(grad_output, memory_format=1)
|
|
|
|
return torch.floor(self), backward
|
|
|
|
def frac(self):
|
|
def backward(grad_output):
|
|
return grad_output
|
|
|
|
return torch.frac(self), backward
|
|
|
|
def log(self):
|
|
def backward(grad_output):
|
|
return grad_output.div(self)
|
|
|
|
return torch.log(self), backward
|
|
|
|
def log10(self):
|
|
def backward(grad_output):
|
|
return grad_output / (self * 2.3025850929940456)
|
|
|
|
return torch.log10(self), backward
|
|
|
|
def log1p(self):
|
|
def backward(grad_output):
|
|
return grad_output / (self + 1)
|
|
|
|
return torch.log1p(self), backward
|
|
|
|
def log2(self):
|
|
def backward(grad_output):
|
|
return grad_output / (self * 0.6931471805599453)
|
|
|
|
return torch.log2(self), backward
|
|
|
|
# TODO: Fix rand_like to match expected format
|
|
# def rand_like(self, *, memory_format: Optional[int]):
|
|
# def backward(grad_output):
|
|
# return None
|
|
|
|
# return torch.rand_like(self, memory_format=memory_format), backward
|
|
|
|
def reciprocal(self):
|
|
result = torch.reciprocal(self)
|
|
def backward(grad_output):
|
|
return -grad_output * result * result
|
|
|
|
return result, backward
|
|
|
|
def round(self):
|
|
def backward(grad_output):
|
|
return torch.zeros_like(grad_output, memory_format=1)
|
|
|
|
return torch.round(self), backward
|
|
|
|
def rsqrt(self):
|
|
result = torch.rsqrt(self)
|
|
def backward(grad_output):
|
|
return -grad_output * result * result * result / 2
|
|
|
|
return result, backward
|
|
|
|
def sin(self):
|
|
def backward(grad_output):
|
|
return grad_output * self.cos()
|
|
|
|
return torch.sin(self), backward
|
|
|
|
def sinh(self):
|
|
def backward(grad_output):
|
|
return grad_output * self.cosh()
|
|
|
|
return torch.sinh(self), backward
|
|
|
|
def tan(self):
|
|
result = torch.tan(self)
|
|
def backward(grad_output):
|
|
return grad_output * (1. + result * result)
|
|
|
|
return result, backward
|
|
|
|
def trunc(self):
|
|
def backward(grad_output):
|
|
return torch.zeros_like(grad_output, memory_format=1)
|
|
|
|
return torch.trunc(self), backward
|
|
|
|
def _grad_sum_to_size(self,
|
|
size: Optional[List[int]]):
|
|
result = torch._grad_sum_to_size(self, size)
|
|
self_size = torch._size_if_not_equal(self.size(), result.size())
|
|
|
|
def backward(grad_output):
|
|
if self_size is None:
|
|
grad_input = grad_output
|
|
else:
|
|
grad_input = grad_output.expand(self_size)
|
|
return grad_input, None
|
|
|
|
return result, backward
|
|
)",
|
|
R"(
|
|
def batch_norm(input : Tensor,
|
|
weight : Optional[Tensor],
|
|
bias : Optional[Tensor],
|
|
running_mean : Optional[Tensor],
|
|
running_var : Optional[Tensor],
|
|
training : bool,
|
|
momentum : float,
|
|
eps : float,
|
|
cudnn_enabled : bool):
|
|
|
|
output, save1, save2, reserve, impl_idx = torch._batch_norm_impl_index(
|
|
input, weight, bias, running_mean, running_var, training,
|
|
momentum, eps, cudnn_enabled)
|
|
has_weight = weight is not None
|
|
has_bias = bias is not None
|
|
|
|
def backward(grad_output):
|
|
dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
|
|
impl_idx, input, grad_output, weight, running_mean, running_var,
|
|
save1, save2, training, eps, [True, has_weight, has_bias], reserve)
|
|
return dinput, dweight, dbias, None, None, None, None, None, None
|
|
|
|
return output, backward
|
|
|
|
def layer_norm(input : Tensor,
|
|
normalized_shape : List[int],
|
|
weight : Optional[Tensor],
|
|
bias : Optional[Tensor],
|
|
eps : float,
|
|
cudnn_enable : bool):
|
|
|
|
output, mean, rstd = torch.native_layer_norm(input, normalized_shape, weight, bias, eps)
|
|
|
|
def backward(grad_output):
|
|
output_mask = [True, weight is not None, bias is not None]
|
|
grad_input, grad_weight, grad_bias = torch.native_layer_norm_backward(grad_output, input, normalized_shape, mean, rstd, weight, bias, output_mask)
|
|
return grad_input, None, grad_weight, grad_bias, None, None
|
|
return output, backward
|
|
|
|
def dropout(input,
|
|
p: float,
|
|
train: bool):
|
|
# if `train == false` we need to set `p1m` to 0 so `scale == 1`
|
|
p1m = (1. - p) * float(train)
|
|
scale = 1. / (float(p1m == 0.) + p1m)
|
|
res,mask = torch.native_dropout(input, p, train)
|
|
|
|
def backward(grad_output):
|
|
grad_input = torch.native_dropout_backward(grad_output, mask, scale)
|
|
return grad_input, None, None
|
|
return res, backward
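        # Worked example (added for clarity): with p = 0.25 and train = True,
        # p1m = 0.75 and scale = 1 / 0.75 = 4/3, the usual inverted-dropout
        # rescaling; with train = False, p1m = 0.0 and scale = 1.0, so the
        # backward rescaling is a no-op on grad_output.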
|
|
|
|
def embedding(weight,
|
|
indices,
|
|
padding_idx: int,
|
|
scale_grad_by_freq: bool,
|
|
sparse: bool):
|
|
weight_size_0 = weight.size()[0]
|
|
def backward(grad_output):
|
|
grad_weight = torch.embedding_backward(grad_output, indices, weight_size_0, padding_idx, scale_grad_by_freq, sparse)
|
|
return grad_weight, None, None, None, None
|
|
|
|
return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward
|
|
|
|
def log_softmax(self, dim: int, dtype: Optional[int]):
|
|
result = torch.log_softmax(self, dim, dtype)
|
|
def backward(grad_output):
|
|
grad_self = torch._log_softmax_backward_data(grad_output, result, dim, self.dtype)
|
|
return grad_self, None, None
|
|
|
|
return result, backward
|
|
|
|
def nll_loss(self, target, weight: Optional[Tensor], reduction: int, ignore_index: int):
|
|
result, total_weight = torch.nll_loss_forward(self, target, weight, reduction, ignore_index)
|
|
def backward(grad):
|
|
return torch.nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight), None, None, None, None
|
|
return result, backward
|
|
|
|
def softmax(self, dim: int, dtype: Optional[int]):
|
|
result = torch.softmax(self, dim, dtype)
|
|
def backward(grad_output):
|
|
grad_self = torch._softmax_backward_data(grad_output, result, dim, self.dtype)
|
|
return grad_self, None, None
|
|
|
|
return result, backward
|
|
)",
|
|
R"(
|
|
def AD_adaptive_avg_pool3d_backward(grad,
|
|
self,
|
|
output_size: List[int]):
|
|
if output_size[0] == 1 and output_size[1] == 1 and output_size[2] == 1:
|
|
self_size = self.size()
|
|
grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2] * self_size[-3])
|
|
else:
|
|
grad_self = torch._adaptive_avg_pool3d_backward(grad, self)
|
|
|
|
return grad_self
|
|
|
|
def AD_adaptive_avg_pool2d_backward(grad,
|
|
self,
|
|
output_size: List[int]):
|
|
if output_size[0] == 1 and output_size[1] == 1:
|
|
self_size = self.size()
|
|
grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2])
|
|
else:
|
|
grad_self = torch._adaptive_avg_pool2d_backward(grad, self)
|
|
|
|
return grad_self
|
|
|
|
def AD_adaptive_avg_pool1d_backward(grad,
|
|
input,
|
|
output_size: List[int]):
|
|
output_size_2d = [1, output_size[0]]
|
|
grad_input = AD_adaptive_avg_pool2d_backward(grad.unsqueeze(2), input.unsqueeze(2), output_size_2d).squeeze(2)
|
|
return grad_input
|
|
|
|
def adaptive_avg_pool1d(self,
|
|
output_size: List[int]):
|
|
def backward(grad_output):
|
|
grad_self = AD_adaptive_avg_pool1d_backward(grad_output, self, output_size)
|
|
return grad_self, None
|
|
|
|
return torch.adaptive_avg_pool1d(self, output_size), backward
|
|
|
|
def adaptive_avg_pool2d(self,
|
|
output_size: List[int]):
|
|
def backward(grad_output):
|
|
# self is used in backward, no need to pass in its size explicitly
|
|
grad_self = AD_adaptive_avg_pool2d_backward(grad_output, self, output_size)
|
|
return grad_self, None
|
|
return torch.adaptive_avg_pool2d(self, output_size), backward
|
|
|
|
def adaptive_avg_pool3d(self,
|
|
output_size: List[int]):
|
|
def backward(grad_output):
|
|
grad_self = AD_adaptive_avg_pool3d_backward(grad_output, self, output_size)
|
|
return grad_self, None
|
|
|
|
return torch.adaptive_avg_pool3d(self, output_size), backward
|
|
|
|
def avg_pool2d(self,
|
|
kernel_size: List[int],
|
|
stride: List[int],
|
|
padding: List[int],
|
|
ceil_mode: bool,
|
|
count_include_pad: bool,
|
|
divisor_override: Optional[int]):
|
|
def backward(grad_output):
|
|
grad_self = torch.avg_pool2d_backward(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
|
|
return grad_self, None, None, None, None, None, None
|
|
|
|
return torch.avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override), backward
|
|
|
|
def max_pool2d(self,
|
|
kernel_size: List[int],
|
|
stride: List[int],
|
|
padding: List[int],
|
|
dilation: List[int],
|
|
ceil_mode: bool):
|
|
output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
|
|
def backward(grad_output):
|
|
grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
|
|
return grad_self, None, None, None, None, None
|
|
return output, backward
|
|
|
|
def max_pool2d_with_indices(self,
|
|
kernel_size: List[int],
|
|
stride: List[int],
|
|
padding: List[int],
|
|
dilation: List[int],
|
|
ceil_mode: bool):
|
|
output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
|
|
def backward(grad_output):
|
|
grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
|
|
return grad_self, None, None, None, None, None
|
|
return output, indices, backward
|
|
)",
|
|
R"(
|
|
def AD_sizes_if_not_equal_multi_1(t1, t2, res):
|
|
return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())
|
|
|
|
def add_0(self,
|
|
other,
|
|
*,
|
|
alpha: number):
|
|
result = torch.add(self, other, alpha=alpha)
|
|
self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
|
|
def backward(grad_output):
|
|
grad_other = (grad_output * alpha)._grad_sum_to_size(other_size)
|
|
grad_self = (grad_output)._grad_sum_to_size(self_size)
|
|
return grad_self, grad_other, None
|
|
return result, backward
|
|
|
|
def add_1(self,
|
|
other: number,
|
|
alpha: number):
|
|
def backward(grad_output):
|
|
return grad_output, None, None
|
|
return torch.add(self, other, alpha=alpha), backward
|
|
|
|
def sub_0(self,
|
|
other,
|
|
*,
|
|
alpha: number):
|
|
result = torch.sub(self, other, alpha=alpha)
|
|
self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
|
|
def backward(grad_output):
|
|
grad_other = (-grad_output * alpha)._grad_sum_to_size(other_size)
|
|
grad_self = (grad_output)._grad_sum_to_size(self_size)
|
|
return grad_self, grad_other, None
|
|
return result , backward
|
|
|
|
def sub_1(self,
|
|
other: number,
|
|
alpha: number):
|
|
def backward(grad_output):
|
|
return grad_output, None, None
|
|
return torch.sub(self, other, alpha=alpha), backward
|
|
|
|
def threshold(self,
|
|
threshold: number,
|
|
value: number):
|
|
def backward(grad_output):
|
|
mask = (self >= threshold).type_as(self)
|
|
return grad_output * mask, None, None
|
|
return torch.threshold(self, threshold, value), backward
|
|
|
|
def softplus(self,
|
|
beta: number,
|
|
threshold: number):
|
|
result = torch.softplus(self, beta, threshold)
|
|
def backward(grad_output):
|
|
z = torch.exp(result * beta)
|
|
return torch.where((result * beta) > threshold, grad_output, grad_output * (z - 1.) / z), None, None
|
|
return result, backward
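        # Derivation note (added for clarity): with z = exp(result * beta)
        # = 1 + exp(beta * self), the factor (z - 1.) / z equals sigmoid(beta * self),
        # the derivative of softplus; above the threshold the forward pass is treated
        # as linear, so the gradient is passed through unchanged.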
|
|
|
|
def fmod(self,
|
|
other: number):
|
|
def backward(grad_output):
|
|
return grad_output, None
|
|
return torch.fmod(self, other), backward
|
|
|
|
def remainder(self,
|
|
other: number):
|
|
def backward(grad_output):
|
|
return grad_output, None
|
|
return torch.remainder(self, other), backward
|
|
|
|
def addmm(self,
|
|
mat1,
|
|
mat2,
|
|
*,
|
|
beta: number,
|
|
alpha: number):
|
|
result = torch.addmm(self, mat1, mat2, beta=beta, alpha=alpha)
|
|
self_size = torch._size_if_not_equal(self.size(), result.size())
|
|
def backward(grad_output):
|
|
self_grad = (grad_output * beta)._grad_sum_to_size(self_size)
|
|
mat1_grad = grad_output.mm(mat2.t()) * alpha
|
|
mat2_grad = mat1.t().mm(grad_output) * alpha
|
|
return self_grad, mat1_grad, mat2_grad, None, None
|
|
return result, backward
|
|
|
|
# Comparison operators
|
|
def lt(self, other: number):
|
|
def backward(grad_output):
|
|
return None, None
|
|
return torch.lt(self, other), backward
|
|
|
|
def le(self, other: number):
|
|
def backward(grad_output):
|
|
return None, None
|
|
return torch.le(self, other), backward
|
|
|
|
def gt(self, other: number):
|
|
def backward(grad_output):
|
|
return None, None
|
|
return torch.gt(self, other), backward
|
|
|
|
def ge(self, other: number):
|
|
def backward(grad_output):
|
|
return None, None
|
|
return torch.ge(self, other), backward
|
|
|
|
def eq(self, other: number):
|
|
def backward(grad_output):
|
|
return None, None
|
|
return torch.eq(self, other), backward
|
|
|
|
def ne(self, other: number):
|
|
def backward(grad_output):
|
|
return None, None
|
|
return torch.ne(self, other), backward
|
|
|
|
def hardshrink(self, lambd: number):
|
|
def backward(grad_output):
|
|
mask = ((self > lambd) | (self < -lambd))
|
|
return grad_output * mask, None
|
|
return torch.hardshrink(self, lambd=lambd), backward
|
|
|
|
def hardtanh(self, min_val: number, max_val: number):
|
|
def backward(grad_output):
|
|
mask = ((self >= min_val) * (self <= max_val))
|
|
return grad_output * mask, None, None
|
|
return torch.hardtanh(self, min_val=min_val, max_val=max_val), backward
|
|
|
|
def clamp_1(self,
|
|
min: Optional[number],
|
|
max: Optional[number]):
|
|
def backward(grad_output):
|
|
if min is not None and max is not None:
|
|
mask = ((self >= float(min)) * (self <= float(max))).type_as(self)
|
|
return grad_output * mask, None, None
|
|
elif min is not None:
|
|
mask = (self >= float(min)).type_as(self)
|
|
return grad_output * mask, None, None
|
|
elif max is not None:
|
|
mask = (self <= float(max)).type_as(self)
|
|
return grad_output * mask, None, None
|
|
else: #min is None and max is None
|
|
return grad_output, None, None
|
|
return torch.clamp(self, min=min, max=max), backward
|
|
|
|
def clamp_2(self,
|
|
min: Optional[Tensor],
|
|
max: Optional[Tensor]):
|
|
def backward(grad_output):
|
|
if min is not None and max is not None:
|
|
mask = ((self >= min) * (self <= max)).type_as(self)
|
|
return grad_output * mask, None, None
|
|
elif min is not None:
|
|
mask = (self >= min).type_as(self)
|
|
return grad_output * mask, None, None
|
|
elif max is not None:
|
|
mask = (self <= max).type_as(self)
|
|
return grad_output * mask, None, None
|
|
else: #min is None and max is None
|
|
return grad_output, None, None
|
|
return torch.clamp(self, min=min, max=max), backward
|
|
)"};
|
|
|
|
std::unordered_map<std::string, GradientPair> schema_to_graphs;

// This map is a workaround to cache compiled GradientPairs. Ideally each graph
// should be compiled only once and saved in the Operator structure.
// This should be done along with merging into native_functions.yaml.
std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;

// CompilationUnit that holds all these Functions and keeps them alive.
CompilationUnit compilation_unit;
} // anonymous namespace

static std::pair<std::shared_ptr<Graph>, Value*> extractClosure(
|
|
Value* closure) {
|
|
TORCH_CHECK(
|
|
closure->node()->kind() == prim::TupleConstruct,
|
|
"closure must be a literal tuple construct");
|
|
Value* fn = closure->node()->inputs().at(0);
|
|
Value* context = closure->node()->inputs().at(1);
|
|
|
|
TORCH_CHECK(
|
|
fn->node()->kind() == prim::Closure,
|
|
"closure tuple must contain a prim::Closure");
|
|
return std::make_pair(fn->node()->g(attr::Subgraph), context);
|
|
}
|
|
|
|
static Argument originalReturnType(const TupleTypePtr& tup) {
|
|
TORCH_CHECK(tup->elements().size() > 1);
|
|
if (tup->elements().size() == 2)
|
|
return Argument("", tup->elements().at(0));
|
|
std::vector<TypePtr> types = tup->elements().vec();
|
|
types.pop_back();
|
|
return Argument("", TupleType::create(std::move(types)));
|
|
}
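// Note (added for clarity): the forward graphs built from the definitions above
// return (original outputs..., state for backward); originalReturnType drops that
// trailing element so the derived schema advertises only the original outputs.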

// In torchscript AD formulas, we define {func_0, func_1, ...} as
// overloaded functions of `func`.
// Remove the suffix before adding the schema string to the map
// schema_to_graphs.
static std::string overloadedSchemaString(const FunctionSchema& schema) {
  const auto& schema_name = schema.name();
  auto pos = schema_name.find_last_of('_');
  auto schema_name_suffix = schema_name.substr(pos + 1);
  std::string schema_string = canonicalSchemaString(schema);
  if (!schema_name_suffix.empty() &&
      schema_name_suffix.find_first_not_of("0123456789") == std::string::npos) {
    schema_string.replace(
        schema_string.find(schema_name),
        schema_name.length(),
        schema_name.substr(0, pos));
  }

  return schema_string;
}
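// Example (illustrative): a TorchScript definition named "mean_1" is registered
// under the base name "mean", so its canonical schema string lines up with the
// corresponding aten::mean overload when gradientInfoForSchema looks it up.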

static bool isHelperFunction(const std::string& method_name) {
  std::string helper_prefix = "AD_";
  return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
}

static void loadModule(const CompilationUnit& module) {
|
|
for (const auto& method : module.get_functions()) {
|
|
if (isHelperFunction(method->name()))
|
|
continue;
|
|
|
|
GradientPair pair;
|
|
pair.forward = toGraphFunction(*method).graph();
|
|
|
|
// lookup the backward function
|
|
Node* forward_tuple = pair.forward->outputs().at(0)->node();
|
|
|
|
    if (forward_tuple->kind() != prim::TupleConstruct) {
      throw(
          ErrorReport(forward_tuple->sourceRange())
          << "gradient must return a literal tuple");
    }

Value* context = nullptr;
|
|
std::tie(pair.backward, context) =
|
|
extractClosure(forward_tuple->inputs().back());
|
|
|
|
// checks that num forward graph inputs equals num backward graph outputs
|
|
TORCH_CHECK(
|
|
pair.forward->inputs().size() ==
|
|
unpackOutputs(pair.backward->outputs().vec()).size(),
|
|
"The autodiff implementation of ",
|
|
method->name(),
|
|
" backward() returns an incorrect number of values: ",
|
|
unpackOutputs(pair.backward->outputs().vec()).size(),
|
|
" instead of ",
|
|
pair.forward->inputs().size());
|
|
|
|
// do surgery on the forward function to remove the closure tuple and
|
|
// replace it with the context variable:
|
|
// backward = (<lambda>, context_tuple)
|
|
// return original, backward
|
|
// -----
|
|
// return original, context_tuple
|
|
std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
|
|
new_inputs.back() = context;
|
|
Value* new_tuple =
|
|
pair.forward->appendNode(pair.forward->createTuple(new_inputs))
|
|
->output();
|
|
pair.forward->eraseOutput(0);
|
|
pair.forward->registerOutput(new_tuple);
|
|
forward_tuple->destroy();
|
|
|
|
// derive schema from original function's schema:
|
|
const FunctionSchema& loaded_schema = method->getSchema();
|
|
FunctionSchema actual_schema(
|
|
Symbol::aten(loaded_schema.name()),
|
|
loaded_schema.overload_name(),
|
|
loaded_schema.arguments(),
|
|
{originalReturnType(new_tuple->type()->expect<TupleType>())});
|
|
|
|
// modify canonical string for function overloading
|
|
// prefer not to modify the schema name
|
|
auto schema_string = overloadedSchemaString(actual_schema);
|
|
|
|
schema_to_graphs[schema_string] = std::move(pair);
|
|
}
|
|
}
|
|
|
|
static void loadFunctions() {
|
|
for (const std::string& str : functions) {
|
|
compilation_unit.define(std::nullopt, str, nativeResolver(), nullptr);
|
|
}
|
|
loadModule(compilation_unit);
|
|
}
|
|
|
|
std::optional<GradientPair> gradientInfoForSchema(
|
|
const FunctionSchema& schema) {
|
|
std::lock_guard<std::mutex> guard(lock);
|
|
if (schema_to_graphs.empty()) {
|
|
loadFunctions();
|
|
}
|
|
auto cache_it = cached_gradient_pairs.find(&schema);
|
|
if (cache_it != cached_gradient_pairs.end()) {
|
|
return cache_it->second;
|
|
} else {
|
|
auto schema_str = canonicalSchemaString(schema);
|
|
// For debugging AD change:
|
|
// std::cout << "Looking for " << schema_str << std::endl;
|
|
auto sym_script_it = schema_to_graphs.find(schema_str);
|
|
|
|
if (sym_script_it != schema_to_graphs.end()) {
|
|
cached_gradient_pairs.emplace_hint(
|
|
cache_it, &schema, sym_script_it->second);
|
|
return sym_script_it->second;
|
|
}
|
|
}
|
|
return std::nullopt;
|
|
}
|
|
|
|
bool hasGradientInfoForSchema(const FunctionSchema& schema) {
|
|
return gradientInfoForSchema(schema).has_value();
|
|
}
|
|
|
|
} // namespace torch::jit
|