torch.nn.functional.pixel_unshuffle
Pixel Unshuffle: inverse of pixel shuffle, reorganizing spatial dimensions into channels.
Reverses the pixel shuffle operation by rearranging spatial blocks into channels. Converts a tensor from shape (N, C, H, W) to (N, C*r², H/r, W/r), where r is the downscale factor. Opposite of pixel_shuffle: moves information from spatial locations into channels. Used in super-resolution networks, in GAN discriminators as a lossless downsampling step, and anywhere channel capacity should be traded for spatial resolution before processing. Essential for:
- Super-resolution networks (feature extraction before upsampling)
- Generative models (reducing channels in discriminators)
- Multi-scale feature extraction (spatial to channel reorganization)
- Efficient feature encoding (trading channels for spatial resolution)
- Image-to-image translation networks
- Feature pyramid networks with channel reduction
Operation detail: Inverse of pixel_shuffle. If pixel_shuffle maps (N, C·r², H, W) → (N, C, H·r, W·r), then pixel_unshuffle maps (N, C, H·r, W·r) → (N, C·r², H, W). Useful in downsampling stages.
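The index mapping behind this inverse relationship can be sketched in plain JavaScript on a nested array, without the tensor library. The helper name `pixelUnshuffle` is hypothetical and only illustrates the reorganization; output channel c draws from input channel ⌊c/r²⌋ at sub-pixel offset (⌊(c mod r²)/r⌋, c mod r):

```javascript
// Pure-JS sketch of pixel_unshuffle on one image stored as [C][H][W].
// Hypothetical helper for illustration, not part of the library API.
function pixelUnshuffle(img, r) {
  const C = img.length, H = img[0].length, W = img[0][0].length;
  if (H % r !== 0 || W % r !== 0) {
    throw new Error("H and W must be divisible by the downscale factor");
  }
  const out = [];
  for (let c = 0; c < C * r * r; c++) {
    // Output channel c reads input channel floor(c / r²) at the
    // sub-pixel offset (dh, dw) encoded in c % r².
    const src = Math.floor(c / (r * r));
    const dh = Math.floor((c % (r * r)) / r);
    const dw = c % r;
    const plane = [];
    for (let h = 0; h < H / r; h++) {
      const row = [];
      for (let w = 0; w < W / r; w++) {
        row.push(img[src][h * r + dh][w * r + dw]);
      }
      plane.push(row);
    }
    out.push(plane);
  }
  return out;
}

// One 1-channel 2x2 image becomes four 1x1 channels:
// pixelUnshuffle([[[1, 2], [3, 4]]], 2) → [[[1]], [[2]], [[3]], [[4]]]
```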
When to use Pixel Unshuffle:
- Super-resolution networks (extract spatial features to channels)
- Discriminator in GANs (downsampling with feature stacking)
- Reducing spatial resolution before feature processing
- Efficient pooling alternative (maintains information via channels)
- Feature extraction with spatial-to-channel reorganization
Comparison with alternatives:
- Max/Avg pooling: Loses information; unshuffle preserves all information
- Strided convolution: Learnable downsampling; unshuffle is deterministic reorganization
- Pixel shuffle: Opposite operation (channels to spatial); unshuffle is spatial to channels
- Reshape/permute: Manual approach; pixel_unshuffle is optimized convenience operation
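The pooling comparison above can be made concrete with a single 2×2 block in plain JavaScript (no tensor library): average pooling collapses the block to one number, while pixel_unshuffle keeps every value as a separate channel.

```javascript
// Toy comparison on one 2x2 spatial block.
const block = [[1, 2], [3, 4]];

// 2x2 average pooling reduces the block to one number; the four
// original values cannot be recovered from 2.5 alone.
const pooled = (block[0][0] + block[0][1] + block[1][0] + block[1][1]) / 4; // 2.5

// pixel_unshuffle with r=2 keeps all four values, one per output channel,
// so the operation is exactly invertible.
const unshuffled = [block[0][0], block[0][1], block[1][0], block[1][1]]; // [1, 2, 3, 4]
```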
- Information preserving: All spatial information retained in channels (no loss)
- Inverse operation: Exactly undoes pixel_shuffle with same factor
- Deterministic: No learned parameters, purely geometric reorganization
- Efficient implementation: Optimized to avoid redundant data copies
- 4D input standard: Typically used with batched image data (N, C, H, W) format
- Channel requirement: none for unshuffle; output channels are C × (downscale_factor)². The channel-divisibility constraint applies to pixel_shuffle, not pixel_unshuffle.
- Spatial divisibility: H and W must be divisible by downscale_factor
- Not learnable: Deterministic operation; use strided convolution for learned downsampling
- Information expands in channels: Output has more channels but less spatial resolution
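The shape rules above can be captured in a small validation sketch. The helper name `pixelUnshuffleShape` is hypothetical, used here only to show the divisibility check and the output-shape arithmetic:

```javascript
// Hypothetical helper: compute the output shape of pixel_unshuffle
// and enforce the spatial-divisibility requirement.
function pixelUnshuffleShape([n, c, h, w], r) {
  if (h % r !== 0 || w % r !== 0) {
    throw new Error(`H=${h} and W=${w} must be divisible by r=${r}`);
  }
  // Channels expand by r²; each spatial dimension shrinks by r.
  return [n, c * r * r, h / r, w / r];
}

// pixelUnshuffleShape([1, 3, 4, 4], 2) → [1, 12, 2, 2]
```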
Parameters
input Tensor - Input tensor of shape (N, C, H, W) where: - N: batch size - C: number of channels - H, W: height and width (must be divisible by downscale_factor)
downscale_factor number - Downscaling factor r (must be a positive integer). Output spatial dimensions are divided by r; channels are multiplied by r²
Returns
Tensor – Tensor of shape (N, C*r², H/r, W/r)
Examples
// Basic pixel unshuffle: r=2 downscaling
const x = torch.randn(1, 3, 4, 4); // [N=1, C=3, H=4, W=4]
const unshuffled = torch.nn.functional.pixel_unshuffle(x, 2);
// Output shape: [1, 3*4, 4/2, 4/2] = [1, 12, 2, 2]
// Reorganizes 3-channel 4x4 image into 12-channel 2x2 image
// Super-resolution discriminator: downsampling with unshuffle
const hr_image = torch.randn(batch, 3, 256, 256); // High-res input
const downsampled = torch.nn.functional.pixel_unshuffle(hr_image, 2); // [batch, 12, 128, 128]
const features = conv_layer(downsampled); // Process downsampled with more channels
// Discriminator efficiently downsamples while preserving information in channels
// GAN discriminator progressive downsampling
let x = input_image; // [batch, 3, 256, 256]
x = torch.nn.functional.pixel_unshuffle(x, 2); // [batch, 12, 128, 128]
x = conv_block(x);
x = torch.nn.functional.pixel_unshuffle(x, 2); // [batch, 48, 64, 64] (assuming conv_block preserves 12 channels)
x = conv_block(x);
// Progressive downsampling: spatial info moved to channels
// Inverse of pixel_shuffle: round-trip identity
const original = torch.randn(1, 4, 8, 8); // [1, 4, 8, 8]; for pixel_shuffle, C must be divisible by r² = 4
const shuffled = torch.nn.functional.pixel_shuffle(original, 2); // [1, 1, 16, 16]
const unshuffled = torch.nn.functional.pixel_unshuffle(shuffled, 2); // [1, 4, 8, 8]
// unshuffled equals original exactly: both operations are pure reorganizations, no arithmetic
// Efficient downsampling for feature maps
const features = torch.randn(batch, 64, 32, 32); // Large spatial, small channels
const downsampled = torch.nn.functional.pixel_unshuffle(features, 4); // [batch, 1024, 8, 8]
// Converts low-channel large-spatial to high-channel small-spatial
See Also
- PyTorch torch.nn.functional.pixel_unshuffle
- torch.nn.functional.pixel_shuffle - Inverse operation (channels to spatial)
- torch.nn.MaxPool2d - Lossy downsampling alternative (discards non-maximal values)
- torch.nn.AvgPool2d - Averaged downsampling alternative
- torch.nn.Conv2d - Strided convolution for learned downsampling