mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add PixelUnshuffle (#49334)
Summary: Adds an implementation of `torch.nn.PixelUnshuffle` as the inverse operation of `torch.nn.PixelShuffle`. This addresses https://github.com/pytorch/pytorch/issues/2456 Pull Request resolved: https://github.com/pytorch/pytorch/pull/49334 Test Plan: ``` # Unit tests. python test/test_nn.py TestNN.test_pixel_shuffle_unshuffle # Module test. python test/test_nn.py TestNN.test_PixelUnshuffle # C++ API tests. build/bin/test_api # C++ / python parity tests. python test/test_cpp_api_parity.py # JIT test. python test/test_jit.py TestJitGeneratedFunctional.test_nn_pixel_unshuffle # Override tests. python test/test_overrides.py # Type hint tests. python test/test_type_hints.py ``` Screenshots of rendered docs: <img width="876" alt="Screen Shot 2020-12-18 at 12 19 05 PM" src="https://user-images.githubusercontent.com/75754324/102642255-6b07bb00-412b-11eb-88fa-e53e7e8ba720.png"> <img width="984" alt="Screen Shot 2020-12-18 at 12 19 26 PM" src="https://user-images.githubusercontent.com/75754324/102642276-70fd9c00-412b-11eb-8548-445082a2db02.png"> <img width="932" alt="Screen Shot 2020-12-18 at 12 19 34 PM" src="https://user-images.githubusercontent.com/75754324/102642704-19abfb80-412c-11eb-9546-95bdd1c3cf22.png"> <img width="876" alt="Screen Shot 2020-12-22 at 12 51 36 PM" src="https://user-images.githubusercontent.com/75754324/102918259-986aa680-4454-11eb-99e7-a0b4c8b3e283.png"> <img width="869" alt="Screen Shot 2020-12-22 at 12 51 44 PM" src="https://user-images.githubusercontent.com/75754324/102918274-9ef91e00-4454-11eb-94bb-91b58aff47d3.png"> Reviewed By: mruberry Differential Revision: D25401439 Pulled By: jbschlosser fbshipit-source-id: 209d92ce7295e51699e83616d0c62170a7ce75c8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
461aafe389
commit
68d438c9da
@ -553,6 +553,7 @@ _(aten, permute) \
|
||||
_(aten, pin_memory) \
|
||||
_(aten, pinverse) \
|
||||
_(aten, pixel_shuffle) \
|
||||
_(aten, pixel_unshuffle) \
|
||||
_(aten, poisson) \
|
||||
_(aten, polygamma) \
|
||||
_(aten, pow) \
|
||||
|
@ -14,12 +14,16 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
|
||||
TORCH_CHECK(self.dim() >= 3,
|
||||
"pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
|
||||
self.dim(), " dimension(s)");
|
||||
TORCH_CHECK(
|
||||
upscale_factor > 0,
|
||||
"pixel_shuffle expects a positive upscale_factor, but got ",
|
||||
upscale_factor);
|
||||
// Format: (B1, ..., Bn), C, H, W
|
||||
int64_t c = self.size(-3);
|
||||
int64_t h = self.size(-2);
|
||||
int64_t w = self.size(-1);
|
||||
const auto NUM_NON_BATCH_DIMS = 3;
|
||||
const auto last_batch_dim = self.sizes().end() - NUM_NON_BATCH_DIMS;
|
||||
const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
|
||||
|
||||
int64_t upscale_factor_squared = upscale_factor * upscale_factor;
|
||||
TORCH_CHECK(c % upscale_factor_squared == 0,
|
||||
@ -29,24 +33,81 @@ Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
|
||||
int64_t oh = h * upscale_factor;
|
||||
int64_t ow = w * upscale_factor;
|
||||
|
||||
// First, reshape to expand the channels dim from c into 3 separate dims: (oc, upscale_factor, upscale_factor).
|
||||
// This allows shuffling to be done next by permuting dims.
|
||||
std::vector<int64_t> expanded_shape(self.sizes().begin(), last_batch_dim);
|
||||
expanded_shape.insert(expanded_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
|
||||
const auto input_expanded = self.reshape(expanded_shape);
|
||||
// First, reshape to split the channels dim from c into 3 separate dims: (oc,
|
||||
// upscale_factor, upscale_factor). This allows shuffling to be done next by
|
||||
// permuting dims.
|
||||
std::vector<int64_t> added_dims_shape(
|
||||
self.sizes().begin(), self_sizes_batch_end);
|
||||
added_dims_shape.insert(
|
||||
added_dims_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
|
||||
const auto input_reshaped = self.reshape(added_dims_shape);
|
||||
|
||||
// Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims.
|
||||
std::vector<int64_t> permutation(self.sizes().begin(), last_batch_dim);
|
||||
std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
|
||||
// std::iota is used to maintain the batch dims within the permutation.
|
||||
// Since expansion added 2 dims, the correct batch dim offsets are now: -expanded_shape.size(), ..., -7, -6.
|
||||
std::iota(permutation.begin(), permutation.end(), -expanded_shape.size());
|
||||
// Since 2 dims were added, the correct batch dim offsets are now:
|
||||
// -added_dims_shape.size(), ..., -7, -6.
|
||||
std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
|
||||
permutation.insert(permutation.end(), {-5 /* oc */, -2 /* h */, -4 /* 1st upscale_factor */, -1 /* w */,
|
||||
-3 /* 2nd upscale_factor */});
|
||||
const auto input_permuted = input_expanded.permute(permutation);
|
||||
const auto input_permuted = input_reshaped.permute(permutation);
|
||||
|
||||
// Finally, upscale by collapsing (h, upscale_factor) -> a single dim (oh)
|
||||
// and (w, upscale_factor) -> a single dim (ow).
|
||||
std::vector<int64_t> final_shape(self.sizes().begin(), last_batch_dim);
|
||||
std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
|
||||
final_shape.insert(final_shape.end(), {oc, oh, ow});
|
||||
return input_permuted.reshape(final_shape);
|
||||
}
|
||||
|
||||
|
||||
Tensor pixel_unshuffle(const Tensor& self, int64_t downscale_factor) {
|
||||
TORCH_CHECK(self.dim() >= 3,
|
||||
"pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
|
||||
self.dim(), " dimension(s)");
|
||||
TORCH_CHECK(
|
||||
downscale_factor > 0,
|
||||
"pixel_unshuffle expects a positive downscale_factor, but got ",
|
||||
downscale_factor);
|
||||
// Format: (B1, ..., Bn), C, H, W
|
||||
int64_t c = self.size(-3);
|
||||
int64_t h = self.size(-2);
|
||||
int64_t w = self.size(-1);
|
||||
constexpr auto NUM_NON_BATCH_DIMS = 3;
|
||||
const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
|
||||
|
||||
TORCH_CHECK(h % downscale_factor == 0,
|
||||
"pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=", h,
|
||||
" is not divisible by ", downscale_factor)
|
||||
TORCH_CHECK(w % downscale_factor == 0,
|
||||
"pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=", w,
|
||||
" is not divisible by ", downscale_factor)
|
||||
int64_t downscale_factor_squared = downscale_factor * downscale_factor;
|
||||
int64_t oc = c * downscale_factor_squared;
|
||||
int64_t oh = h / downscale_factor;
|
||||
int64_t ow = w / downscale_factor;
|
||||
|
||||
// First, reshape to split height dim into (oh, downscale_factor) dims and
|
||||
// width dim into (ow, downscale_factor) dims. This allows unshuffling to be
|
||||
// done next by permuting dims.
|
||||
std::vector<int64_t> added_dims_shape(
|
||||
self.sizes().begin(), self_sizes_batch_end);
|
||||
added_dims_shape.insert(
|
||||
added_dims_shape.end(), {c, oh, downscale_factor, ow, downscale_factor});
|
||||
const auto input_reshaped = self.reshape(added_dims_shape);
|
||||
|
||||
// Next, unshuffle by permuting the downscale_factor dims alongside the channel dim.
|
||||
std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
|
||||
// std::iota is used to maintain the batch dims within the permutation.
|
||||
// Since 2 dims were added, the correct batch dim offsets are now:
|
||||
// -added_dims_shape.size(), ..., -7, -6.
|
||||
std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size());
|
||||
permutation.insert(permutation.end(), {-5 /* c */, -3 /* 1st downscale_factor */, -1 /*2nd downscale_factor */,
|
||||
-4 /* oh */, -2 /* ow */});
|
||||
const auto input_permuted = input_reshaped.permute(permutation);
|
||||
|
||||
// Finally, downscale by collapsing (c, downscale_factor, downscale_factor) -> a single dim (oc),
|
||||
// resulting in height=oh and width=ow.
|
||||
std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
|
||||
final_shape.insert(final_shape.end(), {oc, oh, ow});
|
||||
return input_permuted.reshape(final_shape);
|
||||
}
|
||||
|
@ -3342,6 +3342,9 @@
|
||||
- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
|
||||
- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
|
||||
- func: channel_shuffle(Tensor self, int groups) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
dispatch:
|
||||
|
@ -496,6 +496,11 @@ Vision functions
|
||||
|
||||
.. autofunction:: pixel_shuffle
|
||||
|
||||
:hidden:`pixel_unshuffle`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: pixel_unshuffle
|
||||
|
||||
:hidden:`pad`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@ -299,6 +299,7 @@ Vision Layers
|
||||
:template: classtemplate.rst
|
||||
|
||||
nn.PixelShuffle
|
||||
nn.PixelUnshuffle
|
||||
nn.Upsample
|
||||
nn.UpsamplingNearest2d
|
||||
nn.UpsamplingBilinear2d
|
||||
|
@ -1487,6 +1487,23 @@ TEST_F(FunctionalTest, PixelShuffle) {
|
||||
ASSERT_TRUE(y.allclose(y_exp));
|
||||
}
|
||||
|
||||
TEST_F(FunctionalTest, PixelUnshuffle) {
|
||||
auto x = torch::tensor(
|
||||
{{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
|
||||
torch::kFloat);
|
||||
auto y_exp = torch::tensor(
|
||||
{{{{-17, 19}, {-1, 2}},
|
||||
{{7, 14}, {-3, 1}},
|
||||
{{0, -2}, {-12, 14}},
|
||||
{{-15, 0}, {-3, 9}}}},
|
||||
torch::kFloat);
|
||||
auto y = F::pixel_unshuffle(x, 2);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 4);
|
||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
|
||||
ASSERT_TRUE(y.allclose(y_exp));
|
||||
}
|
||||
|
||||
TEST_F(FunctionalTest, Softplus) {
|
||||
const auto size = 3;
|
||||
for (const auto beta : {0.5, 1.0, 2.0}) {
|
||||
|
@ -2761,6 +2761,24 @@ TEST_F(ModulesTest, PixelShuffle) {
|
||||
ASSERT_TRUE(y.allclose(y_exp));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PixelUnshuffle) {
|
||||
PixelUnshuffle module(/*downscale_factor=*/2);
|
||||
auto x = torch::tensor(
|
||||
{{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
|
||||
torch::kFloat);
|
||||
auto y_exp = torch::tensor(
|
||||
{{{{-17, 19}, {-1, 2}},
|
||||
{{7, 14}, {-3, 1}},
|
||||
{{0, -2}, {-12, 14}},
|
||||
{{-15, 0}, {-3, 9}}}},
|
||||
torch::kFloat);
|
||||
auto y = module(x);
|
||||
|
||||
ASSERT_EQ(y.ndimension(), 4);
|
||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
|
||||
ASSERT_TRUE(y.allclose(y_exp));
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, Softplus) {
|
||||
const auto size = 3;
|
||||
for (const auto beta : {0.5, 1.0, 2.0}) {
|
||||
@ -4764,6 +4782,12 @@ TEST_F(ModulesTest, PrettyPrintPixelShuffle) {
|
||||
"torch::nn::PixelShuffle(upscale_factor=5)");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) {
|
||||
ASSERT_EQ(
|
||||
c10::str(PixelUnshuffle(PixelUnshuffleOptions(5))),
|
||||
"torch::nn::PixelUnshuffle(downscale_factor=5)");
|
||||
}
|
||||
|
||||
TEST_F(ModulesTest, PrettyPrintSoftplus) {
|
||||
ASSERT_EQ(c10::str(Softplus()),
|
||||
"torch::nn::Softplus(beta=1, threshold=20)");
|
||||
|
@ -125,6 +125,7 @@ torch::nn::CosineEmbeddingLoss|Yes|No
|
||||
torch::nn::MultiMarginLoss|Yes|No
|
||||
torch::nn::TripletMarginLoss|Yes|No
|
||||
torch::nn::PixelShuffle|Yes|No
|
||||
torch::nn::PixelUnshuffle|Yes|No
|
||||
torch::nn::Upsample|Yes|No
|
||||
torch::nn::DataParallel|No|No
|
||||
torch::nn::parallel::DistributedDataParallel|No|No
|
||||
|
@ -6897,8 +6897,9 @@ class TestNN(NNTestCase):
|
||||
output.backward(grad.contiguous())
|
||||
self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0)
|
||||
|
||||
def test_pixel_shuffle(self):
|
||||
def _test_pixel_shuffle_helper(num_input_dims, valid_channels_dim=True):
|
||||
def test_pixel_shuffle_unshuffle(self):
|
||||
def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
|
||||
upscale_factor=None):
|
||||
# Function to imperatively ensure pixels are shuffled to the correct locations.
|
||||
# Used to validate the batch operations in pixel_shuffle.
|
||||
def _verify_pixel_shuffle(input, output, upscale_factor):
|
||||
@ -6911,7 +6912,7 @@ class TestNN(NNTestCase):
|
||||
(c * upscale_factor ** 2)
|
||||
self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])
|
||||
|
||||
upscale_factor = random.randint(2, 5)
|
||||
upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
|
||||
# If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
|
||||
channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
|
||||
height = random.randint(5, 10)
|
||||
@ -6925,47 +6926,76 @@ class TestNN(NNTestCase):
|
||||
batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
|
||||
input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
|
||||
ps = nn.PixelShuffle(upscale_factor)
|
||||
pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
|
||||
|
||||
if num_input_dims >= 3 and valid_channels_dim:
|
||||
if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
|
||||
output = ps(input)
|
||||
_verify_pixel_shuffle(input, output, upscale_factor)
|
||||
output.backward(output.data)
|
||||
self.assertEqual(input.data, input.grad.data)
|
||||
|
||||
# Ensure unshuffle properly inverts shuffle.
|
||||
unshuffle_output = pus(output)
|
||||
self.assertEqual(input, unshuffle_output)
|
||||
else:
|
||||
self.assertRaises(RuntimeError, lambda: ps(input))
|
||||
|
||||
def test_pixel_shuffle_1D():
|
||||
_test_pixel_shuffle_helper(num_input_dims=1)
|
||||
def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
|
||||
downscale_factor=None):
|
||||
downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
|
||||
channels = random.randint(1, 4)
|
||||
# If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
|
||||
height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
|
||||
# If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
|
||||
width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
|
||||
|
||||
def test_pixel_shuffle_2D():
|
||||
_test_pixel_shuffle_helper(num_input_dims=2)
|
||||
if num_input_dims == 1:
|
||||
input = torch.rand(channels, requires_grad=True)
|
||||
elif num_input_dims == 2:
|
||||
input = torch.rand(height, width, requires_grad=True)
|
||||
else:
|
||||
batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
|
||||
input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
|
||||
|
||||
def test_pixel_shuffle_3D_with_valid_channels_dim():
|
||||
_test_pixel_shuffle_helper(num_input_dims=3)
|
||||
pus = nn.PixelUnshuffle(downscale_factor)
|
||||
self.assertRaises(RuntimeError, lambda: pus(input))
|
||||
|
||||
def test_pixel_shuffle_4D_with_valid_channels_dim():
|
||||
_test_pixel_shuffle_helper(num_input_dims=4)
|
||||
def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
|
||||
# For 1D - 2D, this is an error case.
|
||||
# For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
|
||||
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims)
|
||||
|
||||
def test_pixel_shuffle_5D_with_valid_channels_dim():
|
||||
_test_pixel_shuffle_helper(num_input_dims=5)
|
||||
# Error cases for pixel_shuffle.
|
||||
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, valid_channels_dim=False)
|
||||
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=0)
|
||||
_test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=-2)
|
||||
|
||||
def test_pixel_shuffle_3D_with_invalid_channels_dim():
|
||||
_test_pixel_shuffle_helper(num_input_dims=3, valid_channels_dim=False)
|
||||
# Error cases for pixel_unshuffle.
|
||||
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
|
||||
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
|
||||
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
|
||||
_test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
|
||||
|
||||
def test_pixel_shuffle_4D_with_invalid_channels_dim():
|
||||
_test_pixel_shuffle_helper(num_input_dims=4, valid_channels_dim=False)
|
||||
def test_pixel_shuffle_unshuffle_1D():
|
||||
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
|
||||
|
||||
def test_pixel_shuffle_5D_with_invalid_channels_dim():
|
||||
_test_pixel_shuffle_helper(num_input_dims=5, valid_channels_dim=False)
|
||||
def test_pixel_shuffle_unshuffle_2D():
|
||||
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
|
||||
|
||||
test_pixel_shuffle_1D()
|
||||
test_pixel_shuffle_2D()
|
||||
test_pixel_shuffle_3D_with_valid_channels_dim()
|
||||
test_pixel_shuffle_4D_with_valid_channels_dim()
|
||||
test_pixel_shuffle_5D_with_valid_channels_dim()
|
||||
test_pixel_shuffle_3D_with_invalid_channels_dim()
|
||||
test_pixel_shuffle_4D_with_invalid_channels_dim()
|
||||
test_pixel_shuffle_5D_with_invalid_channels_dim()
|
||||
def test_pixel_shuffle_unshuffle_3D():
|
||||
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
|
||||
|
||||
def test_pixel_shuffle_unshuffle_4D():
|
||||
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
|
||||
|
||||
def test_pixel_shuffle_unshuffle_5D():
|
||||
_test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
|
||||
|
||||
test_pixel_shuffle_unshuffle_1D()
|
||||
test_pixel_shuffle_unshuffle_2D()
|
||||
test_pixel_shuffle_unshuffle_3D()
|
||||
test_pixel_shuffle_unshuffle_4D()
|
||||
test_pixel_shuffle_unshuffle_5D()
|
||||
|
||||
def test_elu_inplace_view(self):
|
||||
v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True)
|
||||
|
@ -210,6 +210,7 @@ def gen_nn_functional(out: str) -> None:
|
||||
'celu_',
|
||||
'rrelu_',
|
||||
'pixel_shuffle',
|
||||
'pixel_unshuffle',
|
||||
'channel_shuffle',
|
||||
'pdist',
|
||||
'cosine_similarity',
|
||||
|
@ -16,6 +16,10 @@ inline Tensor pixel_shuffle(
|
||||
upscale_factor
|
||||
);
|
||||
}
|
||||
|
||||
inline Tensor pixel_unshuffle(const Tensor& input, int64_t downscale_factor) {
|
||||
return torch::pixel_unshuffle(input, downscale_factor);
|
||||
}
|
||||
} // namespace detail
|
||||
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
|
||||
|
||||
@ -36,6 +40,12 @@ inline Tensor pixel_shuffle(
|
||||
return detail::pixel_shuffle(input, options.upscale_factor());
|
||||
}
|
||||
|
||||
inline Tensor pixel_unshuffle(
|
||||
const Tensor& input,
|
||||
const PixelUnshuffleFuncOptions& options) {
|
||||
return detail::pixel_unshuffle(input, options.downscale_factor());
|
||||
}
|
||||
|
||||
} // namespace functional
|
||||
} // namespace nn
|
||||
} // namespace torch
|
||||
|
@ -12,12 +12,13 @@ namespace nn {
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelShuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
/// Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
|
||||
/// to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
|
||||
/// See https://pytorch.org/docs/master/nn.html#torch.nn.PixelShuffle to learn
|
||||
/// about the exact behavior of this module.
|
||||
/// to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an
|
||||
/// upscale factor. See
|
||||
/// https://pytorch.org/docs/master/nn.html#torch.nn.PixelShuffle to learn about
|
||||
/// the exact behavior of this module.
|
||||
///
|
||||
/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn what
|
||||
/// constructor arguments are supported for this module.
|
||||
/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn
|
||||
/// what constructor arguments are supported for this module.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
@ -44,5 +45,42 @@ struct TORCH_API PixelShuffleImpl : public torch::nn::Cloneable<PixelShuffleImpl
|
||||
/// module storage semantics.
|
||||
TORCH_MODULE(PixelShuffle);
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelUnshuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
/// Reverses the PixelShuffle operation by rearranging elements in a tensor of
|
||||
/// shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape :math:`(*,
|
||||
/// C \times r^2, H, W)`, where r is a downscale factor. See
|
||||
/// https://pytorch.org/docs/master/nn.html#torch.nn.PixelUnshuffle to learn
|
||||
/// about the exact behavior of this module.
|
||||
///
|
||||
/// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn
|
||||
/// what constructor arguments are supported for this module.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// PixelUnshuffle model(PixelUnshuffleOptions(5));
|
||||
/// ```
|
||||
struct TORCH_API PixelUnshuffleImpl
|
||||
: public torch::nn::Cloneable<PixelUnshuffleImpl> {
|
||||
explicit PixelUnshuffleImpl(const PixelUnshuffleOptions& options_);
|
||||
|
||||
/// Pretty prints the `PixelUnshuffle` module into the given `stream`.
|
||||
void pretty_print(std::ostream& stream) const override;
|
||||
|
||||
Tensor forward(const Tensor& input);
|
||||
|
||||
void reset() override;
|
||||
|
||||
/// The options with which this `Module` was constructed.
|
||||
PixelUnshuffleOptions options;
|
||||
};
|
||||
|
||||
/// A `ModuleHolder` subclass for `PixelUnshuffleImpl`.
|
||||
/// See the documentation for `PixelUnshuffleImpl` class to learn what methods
|
||||
/// it provides, and examples of how to use `PixelUnshuffle` with
|
||||
/// `torch::nn::PixelUnshuffleOptions`. See the documentation for `ModuleHolder`
|
||||
/// to learn about PyTorch's module storage semantics.
|
||||
TORCH_MODULE(PixelUnshuffle);
|
||||
|
||||
} // namespace nn
|
||||
} // namespace torch
|
||||
|
@ -21,6 +21,20 @@ struct TORCH_API PixelShuffleOptions {
|
||||
TORCH_ARG(int64_t, upscale_factor);
|
||||
};
|
||||
|
||||
/// Options for the `PixelUnshuffle` module.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// PixelUnshuffle model(PixelUnshuffleOptions(5));
|
||||
/// ```
|
||||
struct TORCH_API PixelUnshuffleOptions {
|
||||
/* implicit */ PixelUnshuffleOptions(int64_t downscale_factor)
|
||||
: downscale_factor_(downscale_factor) {}
|
||||
|
||||
/// Factor to decrease spatial resolution by
|
||||
TORCH_ARG(int64_t, downscale_factor);
|
||||
};
|
||||
|
||||
namespace functional {
|
||||
/// Options for `torch::nn::functional::pixel_shuffle`.
|
||||
///
|
||||
@ -33,6 +47,18 @@ namespace functional {
|
||||
/// F::pixel_shuffle(x, F::PixelShuffleFuncOptions(2));
|
||||
/// ```
|
||||
using PixelShuffleFuncOptions = PixelShuffleOptions;
|
||||
|
||||
/// Options for `torch::nn::functional::pixel_unshuffle`.
|
||||
///
|
||||
/// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn
|
||||
/// what arguments are supported.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// namespace F = torch::nn::functional;
|
||||
/// F::pixel_unshuffle(x, F::PixelUnshuffleFuncOptions(2));
|
||||
/// ```
|
||||
using PixelUnshuffleFuncOptions = PixelUnshuffleOptions;
|
||||
} // namespace functional
|
||||
|
||||
} // namespace nn
|
||||
|
@ -21,5 +21,19 @@ Tensor PixelShuffleImpl::forward(
|
||||
return F::detail::pixel_shuffle(input, options.upscale_factor());
|
||||
}
|
||||
|
||||
PixelUnshuffleImpl::PixelUnshuffleImpl(const PixelUnshuffleOptions& options_)
|
||||
: options(options_) {}
|
||||
|
||||
void PixelUnshuffleImpl::pretty_print(std::ostream& stream) const {
|
||||
stream << "torch::nn::PixelUnshuffle(downscale_factor="
|
||||
<< options.downscale_factor() << ")";
|
||||
}
|
||||
|
||||
void PixelUnshuffleImpl::reset() {}
|
||||
|
||||
Tensor PixelUnshuffleImpl::forward(const Tensor& input) {
|
||||
return F::detail::pixel_unshuffle(input, options.downscale_factor());
|
||||
}
|
||||
|
||||
} // namespace nn
|
||||
} // namespace torch
|
||||
|
@ -2799,7 +2799,7 @@ pixel_shuffle = _add_docstr(torch.pixel_shuffle, r"""
|
||||
pixel_shuffle(input, upscale_factor) -> Tensor
|
||||
|
||||
Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a
|
||||
tensor of shape :math:`(*, C, H \times r, W \times r)`.
|
||||
tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`.
|
||||
|
||||
See :class:`~torch.nn.PixelShuffle` for details.
|
||||
|
||||
@ -2815,6 +2815,27 @@ Examples::
|
||||
torch.Size([1, 1, 12, 12])
|
||||
""")
|
||||
|
||||
pixel_unshuffle = _add_docstr(torch.pixel_unshuffle, r"""
|
||||
pixel_unshuffle(input, downscale_factor) -> Tensor
|
||||
|
||||
Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a
|
||||
tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
|
||||
:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`.
|
||||
|
||||
See :class:`~torch.nn.PixelUnshuffle` for details.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
downscale_factor (int): factor to increase spatial resolution by
|
||||
|
||||
Examples::
|
||||
|
||||
>>> input = torch.randn(1, 1, 12, 12)
|
||||
>>> output = torch.nn.functional.pixel_unshuffle(input, 3)
|
||||
>>> print(output.size())
|
||||
torch.Size([1, 9, 4, 4])
|
||||
""")
|
||||
|
||||
channel_shuffle = _add_docstr(torch.channel_shuffle, r"""
|
||||
channel_shuffle(input, groups) -> Tensor
|
||||
|
||||
|
@ -24,7 +24,7 @@ from .padding import ReflectionPad1d, ReflectionPad2d, ReplicationPad1d, Replica
|
||||
from .sparse import Embedding, EmbeddingBag
|
||||
from .rnn import RNNBase, RNN, LSTM, GRU, \
|
||||
RNNCellBase, RNNCell, LSTMCell, GRUCell
|
||||
from .pixelshuffle import PixelShuffle
|
||||
from .pixelshuffle import PixelShuffle, PixelUnshuffle
|
||||
from .upsampling import UpsamplingNearest2d, UpsamplingBilinear2d, Upsample
|
||||
from .distance import PairwiseDistance, CosineSimilarity
|
||||
from .fold import Fold, Unfold
|
||||
@ -50,7 +50,7 @@ __all__ = [
|
||||
'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
|
||||
'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
|
||||
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
|
||||
'LSTMCell', 'GRUCell', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
|
||||
'LSTMCell', 'GRUCell', 'PixelShuffle', 'PixelUnshuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
|
||||
'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
|
||||
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
|
||||
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
|
||||
|
@ -6,26 +6,30 @@ from torch import Tensor
|
||||
|
||||
class PixelShuffle(Module):
|
||||
r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
|
||||
to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
|
||||
to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor.
|
||||
|
||||
This is useful for implementing efficient sub-pixel convolution
|
||||
with a stride of :math:`1/r`.
|
||||
|
||||
Look at the paper:
|
||||
See the paper:
|
||||
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
|
||||
by Shi et. al (2016) for more details.
|
||||
|
||||
Note that this function can take inputs with any number of batch dimensions:
|
||||
:math:`(L, H_{in}, W_{in})`, :math:`(N, L, H_{in}, W_{in})`, :math:`(N_1, N_2, L, H_{in}, W_{in})`, etc.
|
||||
|
||||
Args:
|
||||
upscale_factor (int): factor to increase spatial resolution by
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(*, L, H_{in}, W_{in})` where :math:`L=C \times \text{upscale\_factor}^2`
|
||||
- Output: :math:`(*, C, H_{out}, W_{out})` where
|
||||
:math:`H_{out} = H_{in} \times \text{upscale\_factor}`
|
||||
and :math:`W_{out} = W_{in} \times \text{upscale\_factor}`
|
||||
- Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
|
||||
- Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where
|
||||
|
||||
.. math::
|
||||
C_{out} = C_{in} \div \text{upscale\_factor}^2
|
||||
|
||||
.. math::
|
||||
H_{out} = H_{in} \times \text{upscale\_factor}
|
||||
|
||||
.. math::
|
||||
W_{out} = W_{in} \times \text{upscale\_factor}
|
||||
|
||||
Examples::
|
||||
|
||||
@ -50,3 +54,53 @@ class PixelShuffle(Module):
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return 'upscale_factor={}'.format(self.upscale_factor)
|
||||
|
||||
|
||||
class PixelUnshuffle(Module):
|
||||
r"""Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements
|
||||
in a tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
|
||||
:math:`(*, C \times r^2, H, W)`, where r is a downscale factor.
|
||||
|
||||
See the paper:
|
||||
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
|
||||
by Shi et. al (2016) for more details.
|
||||
|
||||
Args:
|
||||
downscale_factor (int): factor to decrease spatial resolution by
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
|
||||
- Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where
|
||||
|
||||
.. math::
|
||||
C_{out} = C_{in} \times \text{downscale\_factor}^2
|
||||
|
||||
.. math::
|
||||
H_{out} = H_{in} \div \text{downscale\_factor}
|
||||
|
||||
.. math::
|
||||
W_{out} = W_{in} \div \text{downscale\_factor}
|
||||
|
||||
Examples::
|
||||
|
||||
>>> pixel_unshuffle = nn.PixelUnshuffle(3)
|
||||
>>> input = torch.randn(1, 1, 12, 12)
|
||||
>>> output = pixel_unshuffle(input)
|
||||
>>> print(output.size())
|
||||
torch.Size([1, 9, 4, 4])
|
||||
|
||||
.. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
|
||||
https://arxiv.org/abs/1609.05158
|
||||
"""
|
||||
__constants__ = ['downscale_factor']
|
||||
downscale_factor: int
|
||||
|
||||
def __init__(self, downscale_factor: int) -> None:
|
||||
super(PixelUnshuffle, self).__init__()
|
||||
self.downscale_factor = downscale_factor
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return F.pixel_unshuffle(input, self.downscale_factor)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return 'downscale_factor={}'.format(self.downscale_factor)
|
||||
|
@ -706,6 +706,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.pdist: lambda input, p=2: -1,
|
||||
torch.pinverse: lambda input, rcond=1e-15: -1,
|
||||
torch.pixel_shuffle: lambda input, upscale_factor: -1,
|
||||
torch.pixel_unshuffle: lambda input, downscale_factor: -1,
|
||||
torch.poisson: lambda input, generator=None: -1,
|
||||
torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1,
|
||||
torch.polygamma: lambda input, n, out=None: -1,
|
||||
|
@ -2516,6 +2516,12 @@ new_module_tests = [
|
||||
cpp_constructor_args='torch::nn::PixelShuffleOptions(3)',
|
||||
input_size=(1, 9, 4, 4),
|
||||
),
|
||||
dict(
|
||||
module_name='PixelUnshuffle',
|
||||
constructor_args=(3,),
|
||||
cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)',
|
||||
input_size=(1, 1, 12, 12),
|
||||
),
|
||||
dict(
|
||||
constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
|
||||
cpp_options_args='''F::InterpolateFuncOptions()
|
||||
|
@ -140,6 +140,7 @@ nn_functional_tests = [
|
||||
('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
||||
('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
|
||||
('pixel_shuffle', (1, 9, 4, 4), (3,),),
|
||||
('pixel_unshuffle', (1, 1, 12, 12), (3,),),
|
||||
('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
|
||||
('pad', (3, 3, 4, 2), ([1, 1],),),
|
||||
('pairwise_distance', (S, S), ((S, S),),),
|
||||
|
Reference in New Issue
Block a user