Refactored implementation for upsample_nearest decompostions (#122783)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122783
Approved by: https://github.com/peterbell10
This commit is contained in:
vfdev-5
2024-04-16 16:08:36 +00:00
committed by PyTorch MergeBot
parent bebdbb63ce
commit 6330acae76

View File

@ -2647,71 +2647,45 @@ def get_scale_value(scales, idx):
@register_decomposition(aten.upsample_nearest1d.vec) @register_decomposition(aten.upsample_nearest1d.vec)
@register_decomposition(aten.upsample_nearest2d.vec)
@register_decomposition(aten.upsample_nearest3d.vec)
@aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) @aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd) @aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd)
def upsample_nearest1d_vec(input, output_size, scale_factors): @aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
def _upsample_nearest_vec(
input: Tensor,
output_size: Optional[List[int]],
scale_factors: Optional[List[float]],
) -> Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors) osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
scale = get_scale_value(scale_factors, 0) scales = (
scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item]
return aten.upsample_nearest1d.default(input, osize, scale) )
return _upsample_nearest(input, osize, scales)
@register_decomposition(aten._upsample_nearest_exact1d.vec) @register_decomposition(aten._upsample_nearest_exact1d.vec)
@register_decomposition(aten._upsample_nearest_exact2d.vec)
@register_decomposition(aten._upsample_nearest_exact3d.vec)
@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) @aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd) @aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd)
def _upsample_nearest_exact1d_vec(input, output_size, scale_factors):
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
scale = get_scale_value(scale_factors, 0)
return aten._upsample_nearest_exact1d.default(input, osize, scale)
@register_decomposition(aten.upsample_nearest2d.vec)
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
def upsample_nearest2d_vec(input, output_size, scale_factors):
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
scale_h = get_scale_value(scale_factors, 0)
scale_w = get_scale_value(scale_factors, 1)
return aten.upsample_nearest2d.default(input, osize, scale_h, scale_w)
@register_decomposition(aten._upsample_nearest_exact2d.vec)
@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) @aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd) @aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd)
def _upsample_nearest_exact2d_vec(input, output_size, scale_factors):
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
scale_h = get_scale_value(scale_factors, 0)
scale_w = get_scale_value(scale_factors, 1)
return aten._upsample_nearest_exact2d.default(input, osize, scale_h, scale_w)
@register_decomposition(aten.upsample_nearest3d.vec)
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
def upsample_nearest3d_vec(input, output_size, scale_factors):
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
scale_d = get_scale_value(scale_factors, 0)
scale_h = get_scale_value(scale_factors, 1)
scale_w = get_scale_value(scale_factors, 2)
return aten.upsample_nearest3d.default(input, osize, scale_d, scale_h, scale_w)
@register_decomposition(aten._upsample_nearest_exact3d.vec)
@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) @aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd) @aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd)
def _upsample_nearest_exact3d_vec(input, output_size, scale_factors): def _upsample_nearest_exact_vec(
input: Tensor,
output_size: Optional[List[int]],
scale_factors: Optional[List[float]],
) -> Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors) osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
scale_d = get_scale_value(scale_factors, 0) scales = (
scale_h = get_scale_value(scale_factors, 1) scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item]
scale_w = get_scale_value(scale_factors, 2)
return aten._upsample_nearest_exact3d.default(
input, osize, scale_d, scale_h, scale_w
) )
return _upsample_nearest(input, osize, scales, exact=True)
def _compute_upsample_nearest_indices(input, output_size, scales, exact=False): def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
@ -2743,88 +2717,58 @@ def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
for _ in range(num_spatial_dims - 1 - d): for _ in range(num_spatial_dims - 1 - d):
input_indices = input_indices.unsqueeze(-1) input_indices = input_indices.unsqueeze(-1)
indices.append(input_indices) indices.append(input_indices)
return tuple(indices) return indices
@register_decomposition(aten.upsample_nearest1d.default) @register_decomposition(aten.upsample_nearest1d.default)
@aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd) @aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
@pw_cast_for_opmath
def upsample_nearest1d( def upsample_nearest1d(
input: Tensor, input: Tensor,
output_size: List[int], output_size: List[int],
scales: Optional[float] = None, scales: Optional[float] = None,
) -> Tensor: ) -> Tensor:
(l_indices,) = _compute_upsample_nearest_indices(input, output_size, (scales,)) return _upsample_nearest(input, output_size, [scales])
return aten._unsafe_index(input, (None, None, l_indices))
@register_decomposition(aten._upsample_nearest_exact1d.default) @register_decomposition(aten._upsample_nearest_exact1d.default)
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd) @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
@pw_cast_for_opmath def upsample_nearest_exact1d(
def _upsample_nearest_exact1d(
input: Tensor, input: Tensor,
output_size: List[int], output_size: List[int],
scales: Optional[float] = None, scales: Optional[float] = None,
) -> Tensor: ) -> Tensor:
(l_indices,) = _compute_upsample_nearest_indices( return _upsample_nearest(input, output_size, [scales], exact=True)
input, output_size, (scales,), exact=True
)
return aten._unsafe_index(input, (None, None, l_indices))
def _upsample_nearest2d_common(input, h_indices, w_indices):
result = aten._unsafe_index(input, (None, None, h_indices, w_indices))
# convert output to correct memory format, if necessary
memory_format = utils.suggest_memory_format(input)
# following "heuristic: only use channels_last path when it's faster than the contiguous path"
_, n_channels, _, _ = input.shape
if input.device.type == "cuda" and n_channels < 4:
memory_format = torch.contiguous_format
result = result.contiguous(memory_format=memory_format)
return result
@register_decomposition(aten.upsample_nearest2d.default) @register_decomposition(aten.upsample_nearest2d.default)
@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd) @aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
@pw_cast_for_opmath
def upsample_nearest2d( def upsample_nearest2d(
input: Tensor, input: Tensor,
output_size: List[int], output_size: List[int],
scales_h: Optional[float] = None, scales_h: Optional[float] = None,
scales_w: Optional[float] = None, scales_w: Optional[float] = None,
) -> Tensor: ) -> Tensor:
h_indices, w_indices = _compute_upsample_nearest_indices( return _upsample_nearest(input, output_size, [scales_h, scales_w])
input, output_size, (scales_h, scales_w)
)
return _upsample_nearest2d_common(input, h_indices, w_indices)
@register_decomposition(aten._upsample_nearest_exact2d.default) @register_decomposition(aten._upsample_nearest_exact2d.default)
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd) @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
@pw_cast_for_opmath
def _upsample_nearest_exact2d( def _upsample_nearest_exact2d(
input: Tensor, input: Tensor,
output_size: List[int], output_size: List[int],
scales_h: Optional[float] = None, scales_h: Optional[float] = None,
scales_w: Optional[float] = None, scales_w: Optional[float] = None,
) -> Tensor: ) -> Tensor:
h_indices, w_indices = _compute_upsample_nearest_indices( return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)
input, output_size, (scales_h, scales_w), exact=True
)
return _upsample_nearest2d_common(input, h_indices, w_indices)
@register_decomposition(aten.upsample_nearest3d.default) @register_decomposition(aten.upsample_nearest3d.default)
@aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd) @aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
@pw_cast_for_opmath
def upsample_nearest3d( def upsample_nearest3d(
input: Tensor, input: Tensor,
output_size: List[int], output_size: List[int],
@ -2832,18 +2776,12 @@ def upsample_nearest3d(
scales_h: Optional[float] = None, scales_h: Optional[float] = None,
scales_w: Optional[float] = None, scales_w: Optional[float] = None,
) -> Tensor: ) -> Tensor:
d_indices, h_indices, w_indices = _compute_upsample_nearest_indices( return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w])
input, output_size, (scales_d, scales_h, scales_w)
)
result = aten._unsafe_index(input, (None, None, d_indices, h_indices, w_indices))
return result
@register_decomposition(aten._upsample_nearest_exact3d.default) @register_decomposition(aten._upsample_nearest_exact3d.default)
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd) @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd) @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
@pw_cast_for_opmath
def _upsample_nearest_exact3d( def _upsample_nearest_exact3d(
input: Tensor, input: Tensor,
output_size: List[int], output_size: List[int],
@ -2851,11 +2789,35 @@ def _upsample_nearest_exact3d(
scales_h: Optional[float] = None, scales_h: Optional[float] = None,
scales_w: Optional[float] = None, scales_w: Optional[float] = None,
) -> Tensor: ) -> Tensor:
d_indices, h_indices, w_indices = _compute_upsample_nearest_indices( return _upsample_nearest(
input, output_size, (scales_d, scales_h, scales_w), exact=True input, output_size, [scales_d, scales_h, scales_w], exact=True
) )
result = aten._unsafe_index(input, (None, None, d_indices, h_indices, w_indices))
@pw_cast_for_opmath
def _upsample_nearest(
input: Tensor,
output_size: List[int],
scales: List[Optional[float]],
exact: bool = False,
) -> Tensor:
spatial_indices = _compute_upsample_nearest_indices(
input, output_size, scales, exact=exact
)
indices = [None, None] + spatial_indices
result = aten._unsafe_index(input, indices)
if result.ndim == 4:
# convert output to correct memory format, if necessary
memory_format = utils.suggest_memory_format(input)
# following "heuristic: only use channels_last path when it's faster than the contiguous path"
n_channels = input.shape[1]
if input.device.type == "cuda" and n_channels < 4:
memory_format = torch.contiguous_format
result = result.contiguous(memory_format=memory_format)
return result return result