mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactored implementation for upsample_nearest decompositions (#122783)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122783
Approved by: https://github.com/peterbell10
This commit is contained in:
committed by
PyTorch MergeBot
parent
bebdbb63ce
commit
6330acae76
@ -2647,71 +2647,45 @@ def get_scale_value(scales, idx):
|
|||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten.upsample_nearest1d.vec)
|
@register_decomposition(aten.upsample_nearest1d.vec)
|
||||||
|
@register_decomposition(aten.upsample_nearest2d.vec)
|
||||||
|
@register_decomposition(aten.upsample_nearest3d.vec)
|
||||||
@aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
@aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
@aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd)
|
@aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd)
|
||||||
def upsample_nearest1d_vec(input, output_size, scale_factors):
|
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
|
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
|
||||||
|
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
|
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
|
||||||
|
def _upsample_nearest_vec(
|
||||||
|
input: Tensor,
|
||||||
|
output_size: Optional[List[int]],
|
||||||
|
scale_factors: Optional[List[float]],
|
||||||
|
) -> Tensor:
|
||||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
||||||
scale = get_scale_value(scale_factors, 0)
|
scales = (
|
||||||
|
scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item]
|
||||||
return aten.upsample_nearest1d.default(input, osize, scale)
|
)
|
||||||
|
return _upsample_nearest(input, osize, scales)
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten._upsample_nearest_exact1d.vec)
|
@register_decomposition(aten._upsample_nearest_exact1d.vec)
|
||||||
|
@register_decomposition(aten._upsample_nearest_exact2d.vec)
|
||||||
|
@register_decomposition(aten._upsample_nearest_exact3d.vec)
|
||||||
@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd)
|
@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd)
|
||||||
def _upsample_nearest_exact1d_vec(input, output_size, scale_factors):
|
|
||||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
|
||||||
scale = get_scale_value(scale_factors, 0)
|
|
||||||
|
|
||||||
return aten._upsample_nearest_exact1d.default(input, osize, scale)
|
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten.upsample_nearest2d.vec)
|
|
||||||
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
|
||||||
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
|
|
||||||
def upsample_nearest2d_vec(input, output_size, scale_factors):
|
|
||||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
|
||||||
scale_h = get_scale_value(scale_factors, 0)
|
|
||||||
scale_w = get_scale_value(scale_factors, 1)
|
|
||||||
|
|
||||||
return aten.upsample_nearest2d.default(input, osize, scale_h, scale_w)
|
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten._upsample_nearest_exact2d.vec)
|
|
||||||
@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd)
|
@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd)
|
||||||
def _upsample_nearest_exact2d_vec(input, output_size, scale_factors):
|
|
||||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
|
||||||
scale_h = get_scale_value(scale_factors, 0)
|
|
||||||
scale_w = get_scale_value(scale_factors, 1)
|
|
||||||
|
|
||||||
return aten._upsample_nearest_exact2d.default(input, osize, scale_h, scale_w)
|
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten.upsample_nearest3d.vec)
|
|
||||||
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
|
||||||
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
|
|
||||||
def upsample_nearest3d_vec(input, output_size, scale_factors):
|
|
||||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
|
||||||
scale_d = get_scale_value(scale_factors, 0)
|
|
||||||
scale_h = get_scale_value(scale_factors, 1)
|
|
||||||
scale_w = get_scale_value(scale_factors, 2)
|
|
||||||
|
|
||||||
return aten.upsample_nearest3d.default(input, osize, scale_d, scale_h, scale_w)
|
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten._upsample_nearest_exact3d.vec)
|
|
||||||
@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd)
|
@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd)
|
||||||
def _upsample_nearest_exact3d_vec(input, output_size, scale_factors):
|
def _upsample_nearest_exact_vec(
|
||||||
|
input: Tensor,
|
||||||
|
output_size: Optional[List[int]],
|
||||||
|
scale_factors: Optional[List[float]],
|
||||||
|
) -> Tensor:
|
||||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
||||||
scale_d = get_scale_value(scale_factors, 0)
|
scales = (
|
||||||
scale_h = get_scale_value(scale_factors, 1)
|
scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item]
|
||||||
scale_w = get_scale_value(scale_factors, 2)
|
|
||||||
|
|
||||||
return aten._upsample_nearest_exact3d.default(
|
|
||||||
input, osize, scale_d, scale_h, scale_w
|
|
||||||
)
|
)
|
||||||
|
return _upsample_nearest(input, osize, scales, exact=True)
|
||||||
|
|
||||||
|
|
||||||
def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
|
def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
|
||||||
@ -2743,88 +2717,58 @@ def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
|
|||||||
for _ in range(num_spatial_dims - 1 - d):
|
for _ in range(num_spatial_dims - 1 - d):
|
||||||
input_indices = input_indices.unsqueeze(-1)
|
input_indices = input_indices.unsqueeze(-1)
|
||||||
indices.append(input_indices)
|
indices.append(input_indices)
|
||||||
return tuple(indices)
|
return indices
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten.upsample_nearest1d.default)
|
@register_decomposition(aten.upsample_nearest1d.default)
|
||||||
@aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
@aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
|
@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
|
||||||
@pw_cast_for_opmath
|
|
||||||
def upsample_nearest1d(
|
def upsample_nearest1d(
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
output_size: List[int],
|
output_size: List[int],
|
||||||
scales: Optional[float] = None,
|
scales: Optional[float] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
(l_indices,) = _compute_upsample_nearest_indices(input, output_size, (scales,))
|
return _upsample_nearest(input, output_size, [scales])
|
||||||
return aten._unsafe_index(input, (None, None, l_indices))
|
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten._upsample_nearest_exact1d.default)
|
@register_decomposition(aten._upsample_nearest_exact1d.default)
|
||||||
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
|
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
|
||||||
@pw_cast_for_opmath
|
def upsample_nearest_exact1d(
|
||||||
def _upsample_nearest_exact1d(
|
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
output_size: List[int],
|
output_size: List[int],
|
||||||
scales: Optional[float] = None,
|
scales: Optional[float] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
(l_indices,) = _compute_upsample_nearest_indices(
|
return _upsample_nearest(input, output_size, [scales], exact=True)
|
||||||
input, output_size, (scales,), exact=True
|
|
||||||
)
|
|
||||||
return aten._unsafe_index(input, (None, None, l_indices))
|
|
||||||
|
|
||||||
|
|
||||||
def _upsample_nearest2d_common(input, h_indices, w_indices):
|
|
||||||
result = aten._unsafe_index(input, (None, None, h_indices, w_indices))
|
|
||||||
|
|
||||||
# convert output to correct memory format, if necessary
|
|
||||||
memory_format = utils.suggest_memory_format(input)
|
|
||||||
|
|
||||||
# following "heuristic: only use channels_last path when it's faster than the contiguous path"
|
|
||||||
_, n_channels, _, _ = input.shape
|
|
||||||
if input.device.type == "cuda" and n_channels < 4:
|
|
||||||
memory_format = torch.contiguous_format
|
|
||||||
|
|
||||||
result = result.contiguous(memory_format=memory_format)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten.upsample_nearest2d.default)
|
@register_decomposition(aten.upsample_nearest2d.default)
|
||||||
@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
|
@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
|
||||||
@pw_cast_for_opmath
|
|
||||||
def upsample_nearest2d(
|
def upsample_nearest2d(
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
output_size: List[int],
|
output_size: List[int],
|
||||||
scales_h: Optional[float] = None,
|
scales_h: Optional[float] = None,
|
||||||
scales_w: Optional[float] = None,
|
scales_w: Optional[float] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
h_indices, w_indices = _compute_upsample_nearest_indices(
|
return _upsample_nearest(input, output_size, [scales_h, scales_w])
|
||||||
input, output_size, (scales_h, scales_w)
|
|
||||||
)
|
|
||||||
return _upsample_nearest2d_common(input, h_indices, w_indices)
|
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten._upsample_nearest_exact2d.default)
|
@register_decomposition(aten._upsample_nearest_exact2d.default)
|
||||||
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
|
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
|
||||||
@pw_cast_for_opmath
|
|
||||||
def _upsample_nearest_exact2d(
|
def _upsample_nearest_exact2d(
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
output_size: List[int],
|
output_size: List[int],
|
||||||
scales_h: Optional[float] = None,
|
scales_h: Optional[float] = None,
|
||||||
scales_w: Optional[float] = None,
|
scales_w: Optional[float] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
h_indices, w_indices = _compute_upsample_nearest_indices(
|
return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)
|
||||||
input, output_size, (scales_h, scales_w), exact=True
|
|
||||||
)
|
|
||||||
return _upsample_nearest2d_common(input, h_indices, w_indices)
|
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten.upsample_nearest3d.default)
|
@register_decomposition(aten.upsample_nearest3d.default)
|
||||||
@aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
@aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
|
@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
|
||||||
@pw_cast_for_opmath
|
|
||||||
def upsample_nearest3d(
|
def upsample_nearest3d(
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
output_size: List[int],
|
output_size: List[int],
|
||||||
@ -2832,18 +2776,12 @@ def upsample_nearest3d(
|
|||||||
scales_h: Optional[float] = None,
|
scales_h: Optional[float] = None,
|
||||||
scales_w: Optional[float] = None,
|
scales_w: Optional[float] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
d_indices, h_indices, w_indices = _compute_upsample_nearest_indices(
|
return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w])
|
||||||
input, output_size, (scales_d, scales_h, scales_w)
|
|
||||||
)
|
|
||||||
result = aten._unsafe_index(input, (None, None, d_indices, h_indices, w_indices))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition(aten._upsample_nearest_exact3d.default)
|
@register_decomposition(aten._upsample_nearest_exact3d.default)
|
||||||
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||||
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
|
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
|
||||||
@pw_cast_for_opmath
|
|
||||||
def _upsample_nearest_exact3d(
|
def _upsample_nearest_exact3d(
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
output_size: List[int],
|
output_size: List[int],
|
||||||
@ -2851,11 +2789,35 @@ def _upsample_nearest_exact3d(
|
|||||||
scales_h: Optional[float] = None,
|
scales_h: Optional[float] = None,
|
||||||
scales_w: Optional[float] = None,
|
scales_w: Optional[float] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
d_indices, h_indices, w_indices = _compute_upsample_nearest_indices(
|
return _upsample_nearest(
|
||||||
input, output_size, (scales_d, scales_h, scales_w), exact=True
|
input, output_size, [scales_d, scales_h, scales_w], exact=True
|
||||||
)
|
)
|
||||||
result = aten._unsafe_index(input, (None, None, d_indices, h_indices, w_indices))
|
|
||||||
|
|
||||||
|
|
||||||
|
@pw_cast_for_opmath
|
||||||
|
def _upsample_nearest(
|
||||||
|
input: Tensor,
|
||||||
|
output_size: List[int],
|
||||||
|
scales: List[Optional[float]],
|
||||||
|
exact: bool = False,
|
||||||
|
) -> Tensor:
|
||||||
|
spatial_indices = _compute_upsample_nearest_indices(
|
||||||
|
input, output_size, scales, exact=exact
|
||||||
|
)
|
||||||
|
|
||||||
|
indices = [None, None] + spatial_indices
|
||||||
|
result = aten._unsafe_index(input, indices)
|
||||||
|
|
||||||
|
if result.ndim == 4:
|
||||||
|
# convert output to correct memory format, if necessary
|
||||||
|
memory_format = utils.suggest_memory_format(input)
|
||||||
|
|
||||||
|
# following "heuristic: only use channels_last path when it's faster than the contiguous path"
|
||||||
|
n_channels = input.shape[1]
|
||||||
|
if input.device.type == "cuda" and n_channels < 4:
|
||||||
|
memory_format = torch.contiguous_format
|
||||||
|
|
||||||
|
result = result.contiguous(memory_format=memory_format)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user