Add PixelUnshuffle (#49334)

Summary:
Adds an implementation of `torch.nn.PixelUnshuffle` as the inverse operation of `torch.nn.PixelShuffle`. This addresses https://github.com/pytorch/pytorch/issues/2456

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49334

Test Plan:
```
# Unit tests.
python test/test_nn.py TestNN.test_pixel_shuffle_unshuffle

# Module test.
python test/test_nn.py TestNN.test_PixelUnshuffle

# C++ API tests.
build/bin/test_api

# C++ / python parity tests.
python test/test_cpp_api_parity.py

# JIT test.
python test/test_jit.py TestJitGeneratedFunctional.test_nn_pixel_unshuffle

# Override tests.
python test/test_overrides.py

# Type hint tests.
python test/test_type_hints.py
```

Screenshots of rendered docs:
<img width="876" alt="Screen Shot 2020-12-18 at 12 19 05 PM" src="https://user-images.githubusercontent.com/75754324/102642255-6b07bb00-412b-11eb-88fa-e53e7e8ba720.png">
<img width="984" alt="Screen Shot 2020-12-18 at 12 19 26 PM" src="https://user-images.githubusercontent.com/75754324/102642276-70fd9c00-412b-11eb-8548-445082a2db02.png">
<img width="932" alt="Screen Shot 2020-12-18 at 12 19 34 PM" src="https://user-images.githubusercontent.com/75754324/102642704-19abfb80-412c-11eb-9546-95bdd1c3cf22.png">
<img width="876" alt="Screen Shot 2020-12-22 at 12 51 36 PM" src="https://user-images.githubusercontent.com/75754324/102918259-986aa680-4454-11eb-99e7-a0b4c8b3e283.png">
<img width="869" alt="Screen Shot 2020-12-22 at 12 51 44 PM" src="https://user-images.githubusercontent.com/75754324/102918274-9ef91e00-4454-11eb-94bb-91b58aff47d3.png">

Reviewed By: mruberry

Differential Revision: D25401439

Pulled By: jbschlosser

fbshipit-source-id: 209d92ce7295e51699e83616d0c62170a7ce75c8
This commit is contained in:
Joel Schlosser
2020-12-22 20:12:40 -08:00
committed by Facebook GitHub Bot
parent 461aafe389
commit 68d438c9da
20 changed files with 371 additions and 56 deletions

View File

@ -553,6 +553,7 @@ _(aten, permute) \
_(aten, pin_memory) \
_(aten, pinverse) \
_(aten, pixel_shuffle) \
_(aten, pixel_unshuffle) \
_(aten, poisson) \
_(aten, polygamma) \
_(aten, pow) \

View File

@ -14,12 +14,16 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
TORCH_CHECK(self.dim() >= 3,
"pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
self.dim(), " dimension(s)");
TORCH_CHECK(
upscale_factor > 0,
"pixel_shuffle expects a positive upscale_factor, but got ",
upscale_factor);
// Format: (B1, ..., Bn), C, H, W
int64_t c = self.size(-3);
int64_t h = self.size(-2);
int64_t w = self.size(-1);
const auto NUM_NON_BATCH_DIMS = 3;
const auto last_batch_dim = self.sizes().end() - NUM_NON_BATCH_DIMS;
const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
int64_t upscale_factor_squared = upscale_factor * upscale_factor;
TORCH_CHECK(c % upscale_factor_squared == 0,
@ -29,24 +33,81 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
int64_t oh = h * upscale_factor;
int64_t ow = w * upscale_factor;
// First, reshape to expand the channels dim from c into 3 separate dims: (oc, upscale_factor, upscale_factor).
// This allows shuffling to be done next by permuting dims.
std::vector<int64_t> expanded_shape(self.sizes().begin(), last_batch_dim);
expanded_shape.insert(expanded_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
const auto input_expanded = self.reshape(expanded_shape);
// First, reshape to split the channels dim from c into 3 separate dims: (oc,
// upscale_factor, upscale_factor). This allows shuffling to be done next by
// permuting dims.
std::vector<int64_t> added_dims_shape(
self.sizes().begin(), self_sizes_batch_end);
added_dims_shape.insert(
added_dims_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
const auto input_reshaped = self.reshape(added_dims_shape);
// Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims.
std::vector<int64_t> permutation(self.sizes().begin(), last_batch_dim);
std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
// std::iota is used to maintain the batch dims within the permutation.
// Since expansion added 2 dims, the correct batch dim offsets are now: -expanded_shape.size(), ..., -7, -6.
std::iota(permutation.begin(), permutation.end(), -expanded_shape.size());
// Since 2 dims were added, the correct batch dim offsets are now:
// -added_dims_shape.size(), ..., -7, -6.
std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
permutation.insert(permutation.end(), {-5 /* oc */, -2 /* h */, -4 /* 1st upscale_factor */, -1 /* w */,
-3 /* 2nd upscale_factor */});
const auto input_permuted = input_expanded.permute(permutation);
const auto input_permuted = input_reshaped.permute(permutation);
// Finally, upscale by collapsing (h, upscale_factor) -> a single dim (oh)
// and (w, upscale_factor) -> a single dim (ow).
std::vector<int64_t> final_shape(self.sizes().begin(), last_batch_dim);
std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
final_shape.insert(final_shape.end(), {oc, oh, ow});
return input_permuted.reshape(final_shape);
}
Tensor pixel_unshuffle(const Tensor& self, int64_t downscale_factor) {
TORCH_CHECK(self.dim() >= 3,
"pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
self.dim(), " dimension(s)");
TORCH_CHECK(
downscale_factor > 0,
"pixel_unshuffle expects a positive downscale_factor, but got ",
downscale_factor);
// Format: (B1, ..., Bn), C, H, W
int64_t c = self.size(-3);
int64_t h = self.size(-2);
int64_t w = self.size(-1);
constexpr auto NUM_NON_BATCH_DIMS = 3;
const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
TORCH_CHECK(h % downscale_factor == 0,
"pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=", h,
" is not divisible by ", downscale_factor)
TORCH_CHECK(w % downscale_factor == 0,
"pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=", w,
" is not divisible by ", downscale_factor)
int64_t downscale_factor_squared = downscale_factor * downscale_factor;
int64_t oc = c * downscale_factor_squared;
int64_t oh = h / downscale_factor;
int64_t ow = w / downscale_factor;
// First, reshape to split height dim into (oh, downscale_factor) dims and
// width dim into (ow, downscale_factor) dims. This allows unshuffling to be
// done next by permuting dims.
std::vector<int64_t> added_dims_shape(
self.sizes().begin(), self_sizes_batch_end);
added_dims_shape.insert(
added_dims_shape.end(), {c, oh, downscale_factor, ow, downscale_factor});
const auto input_reshaped = self.reshape(added_dims_shape);
// Next, unshuffle by permuting the downscale_factor dims alongside the channel dim.
std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
// std::iota is used to maintain the batch dims within the permutation.
// Since 2 dims were added, the correct batch dim offsets are now:
// -added_dims_shape.size(), ..., -7, -6.
std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
permutation.insert(permutation.end(), {-5 /* c */, -3 /* 1st downscale_factor */, -1 /*2nd downscale_factor */,
-4 /* oh */, -2 /* ow */});
const auto input_permuted = input_reshaped.permute(permutation);
// Finally, downscale by collapsing (c, downscale_factor, downscale_factor) -> a single dim (oc),
// resulting in height=oh and width=ow.
std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
final_shape.insert(final_shape.end(), {oc, oh, ow});
return input_permuted.reshape(final_shape);
}

View File

@ -3342,6 +3342,9 @@
- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
use_c10_dispatcher: full
- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
use_c10_dispatcher: full
- func: channel_shuffle(Tensor self, int groups) -> Tensor
use_c10_dispatcher: full
dispatch:

View File

@ -496,6 +496,11 @@ Vision functions
.. autofunction:: pixel_shuffle
:hidden:`pixel_unshuffle`
~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: pixel_unshuffle
:hidden:`pad`
~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -299,6 +299,7 @@ Vision Layers
:template: classtemplate.rst
nn.PixelShuffle
nn.PixelUnshuffle
nn.Upsample
nn.UpsamplingNearest2d
nn.UpsamplingBilinear2d

View File

@ -1487,6 +1487,23 @@ TEST_F(FunctionalTest, PixelShuffle) {
ASSERT_TRUE(y.allclose(y_exp));
}
TEST_F(FunctionalTest, PixelUnshuffle) {
auto x = torch::tensor(
{{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
torch::kFloat);
auto y_exp = torch::tensor(
{{{{-17, 19}, {-1, 2}},
{{7, 14}, {-3, 1}},
{{0, -2}, {-12, 14}},
{{-15, 0}, {-3, 9}}}},
torch::kFloat);
auto y = F::pixel_unshuffle(x, 2);
ASSERT_EQ(y.ndimension(), 4);
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
ASSERT_TRUE(y.allclose(y_exp));
}
TEST_F(FunctionalTest, Softplus) {
const auto size = 3;
for (const auto beta : {0.5, 1.0, 2.0}) {

View File

@ -2761,6 +2761,24 @@ TEST_F(ModulesTest, PixelShuffle) {
ASSERT_TRUE(y.allclose(y_exp));
}
TEST_F(ModulesTest, PixelUnshuffle) {
PixelUnshuffle module(/*downscale_factor=*/2);
auto x = torch::tensor(
{{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
torch::kFloat);
auto y_exp = torch::tensor(
{{{{-17, 19}, {-1, 2}},
{{7, 14}, {-3, 1}},
{{0, -2}, {-12, 14}},
{{-15, 0}, {-3, 9}}}},
torch::kFloat);
auto y = module(x);
ASSERT_EQ(y.ndimension(), 4);
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
ASSERT_TRUE(y.allclose(y_exp));
}
TEST_F(ModulesTest, Softplus) {
const auto size = 3;
for (const auto beta : {0.5, 1.0, 2.0}) {
@ -4764,6 +4782,12 @@ TEST_F(ModulesTest, PrettyPrintPixelShuffle) {
"torch::nn::PixelShuffle(upscale_factor=5)");
}
TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) {
ASSERT_EQ(
c10::str(PixelUnshuffle(PixelUnshuffleOptions(5))),
"torch::nn::PixelUnshuffle(downscale_factor=5)");
}
TEST_F(ModulesTest, PrettyPrintSoftplus) {
ASSERT_EQ(c10::str(Softplus()),
"torch::nn::Softplus(beta=1, threshold=20)");

View File

@ -125,6 +125,7 @@ torch::nn::CosineEmbeddingLoss|Yes|No
torch::nn::MultiMarginLoss|Yes|No
torch::nn::TripletMarginLoss|Yes|No
torch::nn::PixelShuffle|Yes|No
torch::nn::PixelUnshuffle|Yes|No
torch::nn::Upsample|Yes|No
torch::nn::DataParallel|No|No
torch::nn::parallel::DistributedDataParallel|No|No

View File

@ -6897,8 +6897,9 @@ class TestNN(NNTestCase):
output.backward(grad.contiguous())
self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0)
def test_pixel_shuffle(self):
def _test_pixel_shuffle_helper(num_input_dims, valid_channels_dim=True):
def test_pixel_shuffle_unshuffle(self):
def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
upscale_factor=None):
# Function to imperatively ensure pixels are shuffled to the correct locations.
# Used to validate the batch operations in pixel_shuffle.
def _verify_pixel_shuffle(input, output, upscale_factor):
@ -6911,7 +6912,7 @@ class TestNN(NNTestCase):
(c * upscale_factor ** 2)
self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])
upscale_factor = random.randint(2, 5)
upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
# If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
height = random.randint(5, 10)
@ -6925,47 +6926,76 @@ class TestNN(NNTestCase):
batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
ps = nn.PixelShuffle(upscale_factor)
pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
if num_input_dims >= 3 and valid_channels_dim:
if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
output = ps(input)
_verify_pixel_shuffle(input, output, upscale_factor)
output.backward(output.data)
self.assertEqual(input.data, input.grad.data)
# Ensure unshuffle properly inverts shuffle.
unshuffle_output = pus(output)
self.assertEqual(input, unshuffle_output)
else:
self.assertRaises(RuntimeError, lambda: ps(input))
def test_pixel_shuffle_1D():
_test_pixel_shuffle_helper(num_input_dims=1)
def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
downscale_factor=None):
downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
channels = random.randint(1, 4)
# If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
# If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
def test_pixel_shuffle_2D():
_test_pixel_shuffle_helper(num_input_dims=2)
if num_input_dims == 1:
input = torch.rand(channels, requires_grad=True)
elif num_input_dims == 2:
input = torch.rand(height, width, requires_grad=True)
else:
batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
def test_pixel_shuffle_3D_with_valid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=3)
pus = nn.PixelUnshuffle(downscale_factor)
self.assertRaises(RuntimeError, lambda: pus(input))
def test_pixel_shuffle_4D_with_valid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=4)
def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
# For 1D - 2D, this is an error case.
# For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims)
def test_pixel_shuffle_5D_with_valid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=5)
# Error cases for pixel_shuffle.
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, valid_channels_dim=False)
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=0)
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=-2)
def test_pixel_shuffle_3D_with_invalid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=3, valid_channels_dim=False)
# Error cases for pixel_unshuffle.
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
def test_pixel_shuffle_4D_with_invalid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=4, valid_channels_dim=False)
def test_pixel_shuffle_unshuffle_1D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
def test_pixel_shuffle_5D_with_invalid_channels_dim():
_test_pixel_shuffle_helper(num_input_dims=5, valid_channels_dim=False)
def test_pixel_shuffle_unshuffle_2D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
test_pixel_shuffle_1D()
test_pixel_shuffle_2D()
test_pixel_shuffle_3D_with_valid_channels_dim()
test_pixel_shuffle_4D_with_valid_channels_dim()
test_pixel_shuffle_5D_with_valid_channels_dim()
test_pixel_shuffle_3D_with_invalid_channels_dim()
test_pixel_shuffle_4D_with_invalid_channels_dim()
test_pixel_shuffle_5D_with_invalid_channels_dim()
def test_pixel_shuffle_unshuffle_3D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
def test_pixel_shuffle_unshuffle_4D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
def test_pixel_shuffle_unshuffle_5D():
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
test_pixel_shuffle_unshuffle_1D()
test_pixel_shuffle_unshuffle_2D()
test_pixel_shuffle_unshuffle_3D()
test_pixel_shuffle_unshuffle_4D()
test_pixel_shuffle_unshuffle_5D()
def test_elu_inplace_view(self):
v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True)

View File

@ -210,6 +210,7 @@ def gen_nn_functional(out: str) -> None:
'celu_',
'rrelu_',
'pixel_shuffle',
'pixel_unshuffle',
'channel_shuffle',
'pdist',
'cosine_similarity',

View File

@ -16,6 +16,10 @@ inline Tensor pixel_shuffle(
upscale_factor
);
}
inline Tensor pixel_unshuffle(const Tensor& input, int64_t downscale_factor) {
return torch::pixel_unshuffle(input, downscale_factor);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
@ -36,6 +40,12 @@ inline Tensor pixel_shuffle(
return detail::pixel_shuffle(input, options.upscale_factor());
}
inline Tensor pixel_unshuffle(
const Tensor& input,
const PixelUnshuffleFuncOptions& options) {
return detail::pixel_unshuffle(input, options.downscale_factor());
}
} // namespace functional
} // namespace nn
} // namespace torch

View File

@ -12,12 +12,13 @@ namespace nn {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelShuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
/// to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.PixelShuffle to learn
/// about the exact behavior of this module.
/// to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an
/// upscale factor. See
/// https://pytorch.org/docs/master/nn.html#torch.nn.PixelShuffle to learn about
/// the exact behavior of this module.
///
/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn what
/// constructor arguments are supported for this module.
/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn
/// what constructor arguments are supported for this module.
///
/// Example:
/// ```
@ -44,5 +45,42 @@ struct TORCH_API PixelShuffleImpl : public torch::nn::Cloneable<PixelShuffleImpl
/// module storage semantics.
TORCH_MODULE(PixelShuffle);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelUnshuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Reverses the PixelShuffle operation by rearranging elements in a tensor of
/// shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape :math:`(*,
/// C \times r^2, H, W)`, where r is a downscale factor. See
/// https://pytorch.org/docs/master/nn.html#torch.nn.PixelUnshuffle to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn
/// what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// PixelUnshuffle model(PixelUnshuffleOptions(5));
/// ```
struct TORCH_API PixelUnshuffleImpl
: public torch::nn::Cloneable<PixelUnshuffleImpl> {
explicit PixelUnshuffleImpl(const PixelUnshuffleOptions& options_);
/// Pretty prints the `PixelUnshuffle` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
Tensor forward(const Tensor& input);
void reset() override;
/// The options with which this `Module` was constructed.
PixelUnshuffleOptions options;
};
/// A `ModuleHolder` subclass for `PixelUnshuffleImpl`.
/// See the documentation for `PixelUnshuffleImpl` class to learn what methods
/// it provides, and examples of how to use `PixelUnshuffle` with
/// `torch::nn::PixelUnshuffleOptions`. See the documentation for `ModuleHolder`
/// to learn about PyTorch's module storage semantics.
TORCH_MODULE(PixelUnshuffle);
} // namespace nn
} // namespace torch

View File

@ -21,6 +21,20 @@ struct TORCH_API PixelShuffleOptions {
TORCH_ARG(int64_t, upscale_factor);
};
/// Options for the `PixelUnshuffle` module.
///
/// Example:
/// ```
/// PixelUnshuffle model(PixelUnshuffleOptions(5));
/// ```
struct TORCH_API PixelUnshuffleOptions {
/* implicit */ PixelUnshuffleOptions(int64_t downscale_factor)
: downscale_factor_(downscale_factor) {}
/// Factor to decrease spatial resolution by
TORCH_ARG(int64_t, downscale_factor);
};
namespace functional {
/// Options for `torch::nn::functional::pixel_shuffle`.
///
@ -33,6 +47,18 @@ namespace functional {
/// F::pixel_shuffle(x, F::PixelShuffleFuncOptions(2));
/// ```
using PixelShuffleFuncOptions = PixelShuffleOptions;
/// Options for `torch::nn::functional::pixel_unshuffle`.
///
/// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn
/// what arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::pixel_unshuffle(x, F::PixelUnshuffleFuncOptions(2));
/// ```
using PixelUnshuffleFuncOptions = PixelUnshuffleOptions;
} // namespace functional
} // namespace nn

View File

@ -21,5 +21,19 @@ Tensor PixelShuffleImpl::forward(
return F::detail::pixel_shuffle(input, options.upscale_factor());
}
PixelUnshuffleImpl::PixelUnshuffleImpl(const PixelUnshuffleOptions& options_)
: options(options_) {}
void PixelUnshuffleImpl::pretty_print(std::ostream& stream) const {
stream << "torch::nn::PixelUnshuffle(downscale_factor="
<< options.downscale_factor() << ")";
}
void PixelUnshuffleImpl::reset() {}
Tensor PixelUnshuffleImpl::forward(const Tensor& input) {
return F::detail::pixel_unshuffle(input, options.downscale_factor());
}
} // namespace nn
} // namespace torch

View File

@ -2799,7 +2799,7 @@ pixel_shuffle = _add_docstr(torch.pixel_shuffle, r"""
pixel_shuffle(input, upscale_factor) -> Tensor
Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a
tensor of shape :math:`(*, C, H \times r, W \times r)`.
tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`.
See :class:`~torch.nn.PixelShuffle` for details.
@ -2815,6 +2815,27 @@ Examples::
torch.Size([1, 1, 12, 12])
""")
pixel_unshuffle = _add_docstr(torch.pixel_unshuffle, r"""
pixel_unshuffle(input, downscale_factor) -> Tensor
Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a
tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`.
See :class:`~torch.nn.PixelUnshuffle` for details.
Args:
input (Tensor): the input tensor
downscale_factor (int): factor to increase spatial resolution by
Examples::
>>> input = torch.randn(1, 1, 12, 12)
>>> output = torch.nn.functional.pixel_unshuffle(input, 3)
>>> print(output.size())
torch.Size([1, 9, 4, 4])
""")
channel_shuffle = _add_docstr(torch.channel_shuffle, r"""
channel_shuffle(input, groups) -> Tensor

View File

@ -24,7 +24,7 @@ from .padding import ReflectionPad1d, ReflectionPad2d, ReplicationPad1d, Replica
from .sparse import Embedding, EmbeddingBag
from .rnn import RNNBase, RNN, LSTM, GRU, \
RNNCellBase, RNNCell, LSTMCell, GRUCell
from .pixelshuffle import PixelShuffle
from .pixelshuffle import PixelShuffle, PixelUnshuffle
from .upsampling import UpsamplingNearest2d, UpsamplingBilinear2d, Upsample
from .distance import PairwiseDistance, CosineSimilarity
from .fold import Fold, Unfold
@ -50,7 +50,7 @@ __all__ = [
'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
'LSTMCell', 'GRUCell', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
'LSTMCell', 'GRUCell', 'PixelShuffle', 'PixelUnshuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',

View File

@ -6,26 +6,30 @@ from torch import Tensor
class PixelShuffle(Module):
r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor.
This is useful for implementing efficient sub-pixel convolution
with a stride of :math:`1/r`.
Look at the paper:
See the paper:
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
by Shi et. al (2016) for more details.
Note that this function can take inputs with any number of batch dimensions:
:math:`(L, H_{in}, W_{in})`, :math:`(N, L, H_{in}, W_{in})`, :math:`(N_1, N_2, L, H_{in}, W_{in})`, etc.
Args:
upscale_factor (int): factor to increase spatial resolution by
Shape:
- Input: :math:`(*, L, H_{in}, W_{in})` where :math:`L=C \times \text{upscale\_factor}^2`
- Output: :math:`(*, C, H_{out}, W_{out})` where
:math:`H_{out} = H_{in} \times \text{upscale\_factor}`
and :math:`W_{out} = W_{in} \times \text{upscale\_factor}`
- Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
- Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where
.. math::
C_{out} = C_{in} \div \text{upscale\_factor}^2
.. math::
H_{out} = H_{in} \times \text{upscale\_factor}
.. math::
W_{out} = W_{in} \times \text{upscale\_factor}
Examples::
@ -50,3 +54,53 @@ class PixelShuffle(Module):
def extra_repr(self) -> str:
return 'upscale_factor={}'.format(self.upscale_factor)
class PixelUnshuffle(Module):
r"""Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements
in a tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
:math:`(*, C \times r^2, H, W)`, where r is a downscale factor.
See the paper:
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
by Shi et. al (2016) for more details.
Args:
downscale_factor (int): factor to decrease spatial resolution by
Shape:
- Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
- Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where
.. math::
C_{out} = C_{in} \times \text{downscale\_factor}^2
.. math::
H_{out} = H_{in} \div \text{downscale\_factor}
.. math::
W_{out} = W_{in} \div \text{downscale\_factor}
Examples::
>>> pixel_unshuffle = nn.PixelUnshuffle(3)
>>> input = torch.randn(1, 1, 12, 12)
>>> output = pixel_unshuffle(input)
>>> print(output.size())
torch.Size([1, 9, 4, 4])
.. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
https://arxiv.org/abs/1609.05158
"""
__constants__ = ['downscale_factor']
downscale_factor: int
def __init__(self, downscale_factor: int) -> None:
super(PixelUnshuffle, self).__init__()
self.downscale_factor = downscale_factor
def forward(self, input: Tensor) -> Tensor:
return F.pixel_unshuffle(input, self.downscale_factor)
def extra_repr(self) -> str:
return 'downscale_factor={}'.format(self.downscale_factor)

View File

@ -706,6 +706,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.pdist: lambda input, p=2: -1,
torch.pinverse: lambda input, rcond=1e-15: -1,
torch.pixel_shuffle: lambda input, upscale_factor: -1,
torch.pixel_unshuffle: lambda input, downscale_factor: -1,
torch.poisson: lambda input, generator=None: -1,
torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1,
torch.polygamma: lambda input, n, out=None: -1,

View File

@ -2516,6 +2516,12 @@ new_module_tests = [
cpp_constructor_args='torch::nn::PixelShuffleOptions(3)',
input_size=(1, 9, 4, 4),
),
dict(
module_name='PixelUnshuffle',
constructor_args=(3,),
cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)',
input_size=(1, 1, 12, 12),
),
dict(
constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
cpp_options_args='''F::InterpolateFuncOptions()

View File

@ -140,6 +140,7 @@ nn_functional_tests = [
('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
('pixel_shuffle', (1, 9, 4, 4), (3,),),
('pixel_unshuffle', (1, 1, 12, 12), (3,),),
('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
('pad', (3, 3, 4, 2), ([1, 1],),),
('pairwise_distance', (S, S), ((S, S),),),