torch.nn.functional.pixel_shuffle
Pixel Shuffle: rearranges channels into spatial dimensions for super-resolution upsampling.
Reorganizes a tensor by moving data from the channel dimension into the spatial dimensions, upsampling the feature map while reducing the channel count. Converts shape (N, C*r², H, W) to (N, C, H*r, W*r), where r is the upscale factor. Used in super-resolution networks as an efficient upsampling method that preserves all information (no pooling loss). Inverse of pixel_unshuffle. Essential for:
- Super-resolution upsampling (efficient alternative to transposed convolution)
- Image upsampling in generative models (GANs, VAEs)
- Sub-pixel convolution networks (Real-ESRGAN, ESPCN)
- Efficient spatial upsampling (no learnable parameters)
- Information-preserving upsampling (all data retained in channels)
- Feature map reorganization in encoder-decoder networks
Operation detail: Rearranges each block of r² channels into an r×r spatial neighborhood. For upscale_factor r, the input has r² times as many channels as the output, and the output has r times the spatial resolution in each dimension. All information is preserved (the total element count is unchanged).
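The index mapping above can be sketched in plain JavaScript on nested arrays (illustrative only; the torch implementation operates on tensors and handles batching):

```javascript
// Sketch of the pixel-shuffle index mapping on nested arrays.
// input: [C*r*r][H][W] -> output: [C][H*r][W*r]
function pixelShuffle(input, r) {
  const inC = input.length, H = input[0].length, W = input[0][0].length;
  if (inC % (r * r) !== 0) throw new Error("channels must be divisible by r^2");
  const C = inC / (r * r);
  const out = Array.from({ length: C }, () =>
    Array.from({ length: H * r }, () => new Array(W * r))
  );
  for (let c = 0; c < C; c++)
    for (let h = 0; h < H; h++)
      for (let w = 0; w < W; w++)
        for (let i = 0; i < r; i++)
          for (let j = 0; j < r; j++)
            // channel (c*r*r + i*r + j) lands at spatial offset (i, j)
            out[c][h * r + i][w * r + j] = input[c * r * r + i * r + j][h][w];
  return out;
}
```

For example, with r=2 the four channels of a 1×1 input become the four pixels of a 2×2 output, in row-major order.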
When to use Pixel Shuffle:
- Super-resolution networks (efficient upsampling)
- Generative models needing to upsample feature maps
- When you want learnable upsampling (use convolution before pixel shuffle)
- Memory-efficient upsampling (no pooling loss)
- Training very deep networks where efficiency matters
Comparison with alternatives:
- Transposed convolution (deconvolution): learnable but slower and prone to checkerboard artifacts; pixel shuffle is deterministic
- Bilinear/Nearest interpolation: synthesizes new pixel values; pixel shuffle only reorganizes existing channel data
- Upsample + Conv: two separate steps; a convolution followed by pixel shuffle gives learned upsampling with a single cheap rearrangement
Key properties:
- Information preserving: every element is retained; data is only reorganized (no pooling loss)
- Deterministic: no learnable parameters; learning happens in the surrounding convolutions
- Efficient: a pure index rearrangement, much cheaper than learned upsampling methods
- Channel requirement: input must have exactly C*r² channels, i.e. divisible by upscale_factor²
- Inverse of unshuffle: exactly undoes pixel_unshuffle with the same factor
- Standard in super-resolution: used in ESPCN, ESRGAN, and Real-ESRGAN
- Spatial upsampling only: operates on the last two dimensions; temporal/3D data needs adaptation
- Output density: output is spatially larger with proportionally fewer channels
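The channel-divisibility requirement can be checked up front. A minimal shape-inference sketch (the helper name is illustrative, not part of the torch API):

```javascript
// Illustrative helper: compute the output shape of pixel_shuffle,
// enforcing the C % r^2 === 0 requirement. Not part of the torch API.
function pixelShuffleShape([n, c, h, w], r) {
  if (!Number.isInteger(r) || r <= 0) throw new Error("r must be a positive integer");
  if (c % (r * r) !== 0) throw new Error(`channels ${c} not divisible by r^2 = ${r * r}`);
  return [n, c / (r * r), h * r, w * r];
}
```

For example, a [1, 12, 16, 16] input with r=2 yields [1, 3, 32, 32], matching the basic example below; a 10-channel input with r=2 is rejected.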
Parameters
input Tensor – Input tensor of shape (N, C*r², H, W) or (C*r², H, W), where N is the batch size (optional), C is the number of output channels (so the input has C*r² channels), H and W are the height and width, and r is the upscale factor
upscale_factor number – Upscale factor r (must be a positive integer); the input channel count must be divisible by upscale_factor²
Returns
Tensor – Tensor of shape (N, C, H*r, W*r) or (C, H*r, W*r)
Examples
// Basic pixel shuffle: r=2 upscaling
const x = torch.randn(1, 12, 16, 16); // [N=1, C=12 (3*4), H=16, W=16]
const shuffled = torch.nn.functional.pixel_shuffle(x, 2);
// Output shape: [1, 3, 32, 32] (upscaled 2x, channels reduced 4x)
// 3 channels in 32x32 instead of 12 channels in 16x16
// Super-resolution pipeline: upscale via channels then learn features
const lr_image = torch.randn(batch, 3, 64, 64); // Low-res: 64x64
const expanded = conv_layer(lr_image); // Expand channels: [batch, 12, 64, 64]
const upsampled = torch.nn.functional.pixel_shuffle(expanded, 2); // [batch, 3, 128, 128]
const refined = refine_layer(upsampled); // Refine high-res image
// Efficient upsampling: no transposed convolution needed
// Sub-pixel CNN (ESPCN): channels to spatial via convolutions
let x = lr_input; // [batch, 3, H, W]
for (let i = 0; i < num_layers; i++) {
x = torch.nn.functional.relu(conv_layers[i](x));
}
x = final_conv(x); // Output [batch, 3*r², H, W]
const sr_output = torch.nn.functional.pixel_shuffle(x, r); // [batch, 3, H*r, W*r]
// Efficient SR without transposed convolutions
// GAN generator: progressive upsampling
let x = latent_code;
x = linear(x).reshape([batch, channels, h, w]);
for (let i = 0; i < num_upsamples; i++) {
x = conv_block(x); // [batch, c*4, h, w]
x = torch.nn.functional.pixel_shuffle(x, 2); // [batch, c, h*2, w*2]
}
// Progressive upsampling from low-res to high-res
// Inverse of pixel_unshuffle: round-trip identity
const original = torch.randn(1, 3, 8, 8); // [1, 3, 8, 8]
const unshuffled = torch.nn.functional.pixel_unshuffle(original, 2); // [1, 12, 4, 4]
const reshuffled = torch.nn.functional.pixel_shuffle(unshuffled, 2); // [1, 3, 8, 8]
// reshuffled equals original exactly (pure index rearrangement)
See Also
- PyTorch torch.nn.functional.pixel_shuffle
- torch.nn.functional.pixel_unshuffle - Inverse operation (spatial to channels)
- torch.nn.ConvTranspose2d - Learnable upsampling alternative
- torch.nn.Upsample - General upsampling (interpolation-based)
- torch.nn.functional.interpolate - Flexible upsampling with various methods