mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Refactored implementation for upsample_nearest decompositions (#122783)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122783 Approved by: https://github.com/peterbell10
This commit is contained in:
committed by
PyTorch MergeBot
parent
bebdbb63ce
commit
6330acae76
@ -2647,71 +2647,45 @@ def get_scale_value(scales, idx):
|
||||
|
||||
|
||||
@register_decomposition(aten.upsample_nearest1d.vec)
|
||||
@register_decomposition(aten.upsample_nearest2d.vec)
|
||||
@register_decomposition(aten.upsample_nearest3d.vec)
|
||||
@aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd)
|
||||
def upsample_nearest1d_vec(input, output_size, scale_factors):
|
||||
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
|
||||
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
|
||||
def _upsample_nearest_vec(
|
||||
input: Tensor,
|
||||
output_size: Optional[List[int]],
|
||||
scale_factors: Optional[List[float]],
|
||||
) -> Tensor:
|
||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
||||
scale = get_scale_value(scale_factors, 0)
|
||||
|
||||
return aten.upsample_nearest1d.default(input, osize, scale)
|
||||
scales = (
|
||||
scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item]
|
||||
)
|
||||
return _upsample_nearest(input, osize, scales)
|
||||
|
||||
|
||||
@register_decomposition(aten._upsample_nearest_exact1d.vec)
|
||||
@register_decomposition(aten._upsample_nearest_exact2d.vec)
|
||||
@register_decomposition(aten._upsample_nearest_exact3d.vec)
|
||||
@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd)
|
||||
def _upsample_nearest_exact1d_vec(input, output_size, scale_factors):
|
||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
||||
scale = get_scale_value(scale_factors, 0)
|
||||
|
||||
return aten._upsample_nearest_exact1d.default(input, osize, scale)
|
||||
|
||||
|
||||
@register_decomposition(aten.upsample_nearest2d.vec)
|
||||
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
|
||||
def upsample_nearest2d_vec(input, output_size, scale_factors):
|
||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
||||
scale_h = get_scale_value(scale_factors, 0)
|
||||
scale_w = get_scale_value(scale_factors, 1)
|
||||
|
||||
return aten.upsample_nearest2d.default(input, osize, scale_h, scale_w)
|
||||
|
||||
|
||||
@register_decomposition(aten._upsample_nearest_exact2d.vec)
|
||||
@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd)
|
||||
def _upsample_nearest_exact2d_vec(input, output_size, scale_factors):
|
||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
||||
scale_h = get_scale_value(scale_factors, 0)
|
||||
scale_w = get_scale_value(scale_factors, 1)
|
||||
|
||||
return aten._upsample_nearest_exact2d.default(input, osize, scale_h, scale_w)
|
||||
|
||||
|
||||
@register_decomposition(aten.upsample_nearest3d.vec)
|
||||
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
|
||||
def upsample_nearest3d_vec(input, output_size, scale_factors):
|
||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
||||
scale_d = get_scale_value(scale_factors, 0)
|
||||
scale_h = get_scale_value(scale_factors, 1)
|
||||
scale_w = get_scale_value(scale_factors, 2)
|
||||
|
||||
return aten.upsample_nearest3d.default(input, osize, scale_d, scale_h, scale_w)
|
||||
|
||||
|
||||
@register_decomposition(aten._upsample_nearest_exact3d.vec)
|
||||
@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd)
|
||||
def _upsample_nearest_exact3d_vec(input, output_size, scale_factors):
|
||||
def _upsample_nearest_exact_vec(
|
||||
input: Tensor,
|
||||
output_size: Optional[List[int]],
|
||||
scale_factors: Optional[List[float]],
|
||||
) -> Tensor:
|
||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
||||
scale_d = get_scale_value(scale_factors, 0)
|
||||
scale_h = get_scale_value(scale_factors, 1)
|
||||
scale_w = get_scale_value(scale_factors, 2)
|
||||
|
||||
return aten._upsample_nearest_exact3d.default(
|
||||
input, osize, scale_d, scale_h, scale_w
|
||||
scales = (
|
||||
scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item]
|
||||
)
|
||||
return _upsample_nearest(input, osize, scales, exact=True)
|
||||
|
||||
|
||||
def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
|
||||
@ -2743,88 +2717,58 @@ def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
|
||||
for _ in range(num_spatial_dims - 1 - d):
|
||||
input_indices = input_indices.unsqueeze(-1)
|
||||
indices.append(input_indices)
|
||||
return tuple(indices)
|
||||
return indices
|
||||
|
||||
|
||||
@register_decomposition(aten.upsample_nearest1d.default)
|
||||
@aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
|
||||
@pw_cast_for_opmath
|
||||
def upsample_nearest1d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
scales: Optional[float] = None,
|
||||
) -> Tensor:
|
||||
(l_indices,) = _compute_upsample_nearest_indices(input, output_size, (scales,))
|
||||
return aten._unsafe_index(input, (None, None, l_indices))
|
||||
return _upsample_nearest(input, output_size, [scales])
|
||||
|
||||
|
||||
@register_decomposition(aten._upsample_nearest_exact1d.default)
|
||||
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
|
||||
@pw_cast_for_opmath
|
||||
def _upsample_nearest_exact1d(
|
||||
def upsample_nearest_exact1d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
scales: Optional[float] = None,
|
||||
) -> Tensor:
|
||||
(l_indices,) = _compute_upsample_nearest_indices(
|
||||
input, output_size, (scales,), exact=True
|
||||
)
|
||||
return aten._unsafe_index(input, (None, None, l_indices))
|
||||
|
||||
|
||||
def _upsample_nearest2d_common(input, h_indices, w_indices):
|
||||
result = aten._unsafe_index(input, (None, None, h_indices, w_indices))
|
||||
|
||||
# convert output to correct memory format, if necessary
|
||||
memory_format = utils.suggest_memory_format(input)
|
||||
|
||||
# following "heuristic: only use channels_last path when it's faster than the contiguous path"
|
||||
_, n_channels, _, _ = input.shape
|
||||
if input.device.type == "cuda" and n_channels < 4:
|
||||
memory_format = torch.contiguous_format
|
||||
|
||||
result = result.contiguous(memory_format=memory_format)
|
||||
return result
|
||||
return _upsample_nearest(input, output_size, [scales], exact=True)
|
||||
|
||||
|
||||
@register_decomposition(aten.upsample_nearest2d.default)
|
||||
@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
|
||||
@pw_cast_for_opmath
|
||||
def upsample_nearest2d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
scales_h: Optional[float] = None,
|
||||
scales_w: Optional[float] = None,
|
||||
) -> Tensor:
|
||||
h_indices, w_indices = _compute_upsample_nearest_indices(
|
||||
input, output_size, (scales_h, scales_w)
|
||||
)
|
||||
return _upsample_nearest2d_common(input, h_indices, w_indices)
|
||||
return _upsample_nearest(input, output_size, [scales_h, scales_w])
|
||||
|
||||
|
||||
@register_decomposition(aten._upsample_nearest_exact2d.default)
|
||||
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
|
||||
@pw_cast_for_opmath
|
||||
def _upsample_nearest_exact2d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
scales_h: Optional[float] = None,
|
||||
scales_w: Optional[float] = None,
|
||||
) -> Tensor:
|
||||
h_indices, w_indices = _compute_upsample_nearest_indices(
|
||||
input, output_size, (scales_h, scales_w), exact=True
|
||||
)
|
||||
return _upsample_nearest2d_common(input, h_indices, w_indices)
|
||||
return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)
|
||||
|
||||
|
||||
@register_decomposition(aten.upsample_nearest3d.default)
|
||||
@aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
|
||||
@pw_cast_for_opmath
|
||||
def upsample_nearest3d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
@ -2832,18 +2776,12 @@ def upsample_nearest3d(
|
||||
scales_h: Optional[float] = None,
|
||||
scales_w: Optional[float] = None,
|
||||
) -> Tensor:
|
||||
d_indices, h_indices, w_indices = _compute_upsample_nearest_indices(
|
||||
input, output_size, (scales_d, scales_h, scales_w)
|
||||
)
|
||||
result = aten._unsafe_index(input, (None, None, d_indices, h_indices, w_indices))
|
||||
|
||||
return result
|
||||
return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w])
|
||||
|
||||
|
||||
@register_decomposition(aten._upsample_nearest_exact3d.default)
|
||||
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
|
||||
@pw_cast_for_opmath
|
||||
def _upsample_nearest_exact3d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
@ -2851,11 +2789,35 @@ def _upsample_nearest_exact3d(
|
||||
scales_h: Optional[float] = None,
|
||||
scales_w: Optional[float] = None,
|
||||
) -> Tensor:
|
||||
d_indices, h_indices, w_indices = _compute_upsample_nearest_indices(
|
||||
input, output_size, (scales_d, scales_h, scales_w), exact=True
|
||||
return _upsample_nearest(
|
||||
input, output_size, [scales_d, scales_h, scales_w], exact=True
|
||||
)
|
||||
result = aten._unsafe_index(input, (None, None, d_indices, h_indices, w_indices))
|
||||
|
||||
|
||||
@pw_cast_for_opmath
|
||||
def _upsample_nearest(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
scales: List[Optional[float]],
|
||||
exact: bool = False,
|
||||
) -> Tensor:
|
||||
spatial_indices = _compute_upsample_nearest_indices(
|
||||
input, output_size, scales, exact=exact
|
||||
)
|
||||
|
||||
indices = [None, None] + spatial_indices
|
||||
result = aten._unsafe_index(input, indices)
|
||||
|
||||
if result.ndim == 4:
|
||||
# convert output to correct memory format, if necessary
|
||||
memory_format = utils.suggest_memory_format(input)
|
||||
|
||||
# following "heuristic: only use channels_last path when it's faster than the contiguous path"
|
||||
n_channels = input.shape[1]
|
||||
if input.device.type == "cuda" and n_channels < 4:
|
||||
memory_format = torch.contiguous_format
|
||||
|
||||
result = result.contiguous(memory_format=memory_format)
|
||||
return result
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user