torch.nn.functional.channel_shuffle
Channel Shuffle: rearranges channels by dividing into groups and shuffling order.
Reorganizes channels by splitting them into groups and interleaving them, improving feature mixing between grouped layers. For an input with C channels divided into G groups of C/G channels each, the operation reshapes (N, C, H, W) to (N, G, C/G, H, W), permutes to (N, C/G, G, H, W), and reshapes back to (N, C, H, W). It is an efficient operation that increases feature diversity without any learnable parameters. Essential for:
- ShuffleNet architectures (efficient mobile networks)
- Improving feature mixing between layers without computation cost
- Channel-wise cross-group information flow
- Group convolution layers (after grouped conv, shuffle channels between groups)
- Lightweight neural networks with limited computation budgets
- Mobile and embedded deployment scenarios
- Efficient feature extraction with minimal memory footprint
How Channel Shuffle works:
- Reshape (N, C, H, W) → (N, G, C/G, H, W) dividing channels into G groups
- Permute to (N, C/G, G, H, W) to interleave groups
- Reshape back to (N, C, H, W) with the shuffled channel order
Result: channels from different groups are now interleaved in the output
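The three steps above can be traced on channel indices alone, since spatial positions are untouched. A minimal TypeScript sketch, using a hypothetical helper `channelShuffleOrder` that is not part of the torch API:

```typescript
// Hypothetical helper (not part of the torch API): computes the output
// channel order produced by channel shuffle for C channels and G groups.
function channelShuffleOrder(C: number, G: number): number[] {
  if (C % G !== 0) throw new Error("C must be divisible by G");
  const perGroup = C / G;
  // Step 1: reshape the C channel indices into G rows of C/G indices
  const groups: number[][] = [];
  for (let g = 0; g < G; g++) {
    groups.push(Array.from({ length: perGroup }, (_, i) => g * perGroup + i));
  }
  // Steps 2-3: transpose (G, C/G) -> (C/G, G), then flatten back to C
  const out: number[] = [];
  for (let i = 0; i < perGroup; i++) {
    for (let g = 0; g < G; g++) out.push(groups[g][i]);
  }
  return out;
}

console.log(channelShuffleOrder(16, 2));
// → [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]
```

Channel 0 of each group ends up first, channel 1 of each group next, and so on, which is exactly the interleaving the worked example below shows.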
When to use Channel Shuffle:
- After grouped convolutions (to mix features between groups)
- In efficient architectures (ShuffleNet, MobileNet variants)
- When you want feature mixing without computation cost
- Mobile/embedded models where efficiency is critical
- Between blocks to increase feature diversity
- Replacing expensive layers in ultra-lightweight networks
Comparison with alternatives:
- Pointwise convolution: Learnable feature mixing; shuffle is deterministic
- Depthwise separable conv: More computation; shuffle is cheaper feature mixing
- 1x1 convolution: Learning overhead; shuffle has zero learnable parameters
- No shuffle: Groups are isolated; shuffle enables cross-group information
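To make the cost contrast with learnable mixing concrete, here is a back-of-the-envelope comparison counting multiply-accumulate operations (MACs); the helper name is illustrative, not a library API:

```typescript
// MACs for a 1x1 (pointwise) convolution over an H x W feature map:
// each output pixel of each output channel sums over all input channels.
function pointwiseConvMACs(H: number, W: number, cin: number, cout: number): number {
  return H * W * cin * cout;
}

const convCost = pointwiseConvMACs(28, 28, 256, 256); // ~51.4 million MACs
const shuffleCost = 0; // channel shuffle is a pure permutation: zero MACs
console.log(convCost, shuffleCost);
```

This is why ShuffleNet replaces some of the pointwise mixing in grouped architectures with shuffles: the cross-group information flow comes for free.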
Key properties:
- Deterministic operation: No randomness; the same input always produces the same output
- No learnable parameters: Pure rearrangement, zero computational overhead
- Efficient: no arithmetic, just a memory permutation (a single linear pass over the data)
- Reversible: shuffling again with C/G groups restores the original channel order
- Channel divisibility: Groups must divide channels evenly (C % groups == 0)
- Feature mixing: Allows features from different groups to interact later
- ShuffleNet standard: Key operation in ShuffleNet architecture for efficiency
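The reversibility property can be checked directly: shuffling with G groups and then with C/G groups restores the original order. A pure-index sketch (the helper name is an assumption, not the torch API):

```typescript
// Apply the channel-shuffle permutation to an array of length C with G groups:
// output position i*G + g receives input element g*(C/G) + i.
function shuffle<T>(channels: T[], G: number): T[] {
  const C = channels.length;
  const per = C / G;
  const out: T[] = [];
  for (let i = 0; i < per; i++) {
    for (let g = 0; g < G; g++) out.push(channels[g * per + i]);
  }
  return out;
}

const original = Array.from({ length: 12 }, (_, i) => i);
const once = shuffle(original, 3);      // → [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
const restored = shuffle(once, 12 / 3); // shuffling with C/G = 4 groups inverts it
console.log(restored);                  // → [0, 1, 2, ..., 11]
```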
Caveats:
- Groups must divide channels: C % groups must equal 0
- Minimum 3D input: Requires at least (C, H, W) dimensions; batch is optional
- Channel ordering changes: Channels are reordered; don't rely on original order
- No effect with groups=1 or groups=C: a single group (or one channel per group) leaves the order unchanged
- Memory layout: May not be cache-optimal after shuffle (but usually negligible)
- Numerical precision: Shuffle itself doesn't affect values, but downstream ops might
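Since an indivisible channel count is the most common failure mode, a small validation guard (hypothetical helper, not part of the library) can catch it before the call:

```typescript
// Hypothetical guard: validates arguments before calling channel_shuffle.
function assertShuffleCompatible(channels: number, groups: number): void {
  if (!Number.isInteger(groups) || groups < 1) {
    throw new RangeError(`groups must be a positive integer, got ${groups}`);
  }
  if (channels % groups !== 0) {
    throw new RangeError(`channels (${channels}) must be divisible by groups (${groups})`);
  }
}

assertShuffleCompatible(16, 4); // OK: 4 groups of 4 channels
// assertShuffleCompatible(16, 3) would throw: 16 % 3 !== 0
```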
Parameters
input (Tensor) - Input tensor of shape [N, C, H, W] or [N, C, D, H, W]
- N: batch size
- C: number of channels (must be divisible by groups)
- H, W (, D): spatial dimensions (height, width, and depth for 3D inputs)
groups (number) - Number of groups to divide the channels into (typical values: 2 to 8)
- C must be divisible by groups
- Each group contains C/groups channels
- groups=1 and groups=C are both no-ops: each leaves the channel order unchanged
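Both edge values of groups leave the order unchanged, which the permutation formula makes easy to verify (pure-index sketch with a hypothetical helper, not the torch API):

```typescript
// Output channel k receives input channel (k % G) * (C / G) + floor(k / G).
function shuffleOrder(C: number, G: number): number[] {
  const per = C / G;
  const out: number[] = [];
  for (let i = 0; i < per; i++) {
    for (let g = 0; g < G; g++) out.push(g * per + i);
  }
  return out;
}

console.log(shuffleOrder(8, 1)); // → [0, 1, 2, 3, 4, 5, 6, 7] (one group: no-op)
console.log(shuffleOrder(8, 8)); // → [0, 1, 2, 3, 4, 5, 6, 7] (one channel per group: no-op)
console.log(shuffleOrder(8, 2)); // → [0, 4, 1, 5, 2, 6, 3, 7]
```

With G=C each "group" holds a single channel, so interleaving the groups simply reproduces the original order; only intermediate values of G actually mix channels.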
Returns
Tensor - Tensor of the same shape as input with channels shuffled. The internal channel order changes, but all values are preserved.
Examples
// Basic channel shuffle: 2 groups
const x = torch.randn(8, 16, 28, 28); // [N=8, C=16, H=28, W=28]
const shuffled = torch.nn.functional.channel_shuffle(x, 2);
// Output: [8, 16, 28, 28] with channels shuffled between 2 groups of 8
// Original: [0-7, 8-15] → Shuffled: channels interleaved from both groups

// ShuffleNet block: group conv → shuffle → group conv
let x = torch.randn(batch, channels, height, width);
// Grouped convolution (e.g., groups=8, reduces computation)
const grouped_conv = new torch.nn.Conv2d(channels, out_channels, 1, 1, 0, 8);
x = grouped_conv.forward(x); // [batch, out_channels, height, width]
// Channel shuffle to mix features between groups
x = torch.nn.functional.channel_shuffle(x, 8);
// Next grouped convolution operates on mixed features
const grouped_conv2 = new torch.nn.Conv2d(out_channels, channels2, 1, 1, 0, 8);
x = grouped_conv2.forward(x);
// Result: ShuffleNet block pattern enables feature mixing with efficiency

// Mobile network: efficient feature extraction
class EfficientBlock extends torch.nn.Module {
  private pw_conv1: torch.nn.Conv2d; // 1x1 pointwise
  private gconv: torch.nn.Conv2d;    // 3x3 grouped
  private pw_conv2: torch.nn.Conv2d; // 1x1 pointwise
  private groups: number = 8;        // Number of groups for efficiency

  constructor(in_channels: number, mid_channels: number, out_channels: number) {
    super();
    this.pw_conv1 = new torch.nn.Conv2d(in_channels, mid_channels, 1);
    this.gconv = new torch.nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, this.groups);
    this.pw_conv2 = new torch.nn.Conv2d(mid_channels, out_channels, 1);
  }

  forward(x: Tensor): Tensor {
    x = torch.nn.functional.relu(this.pw_conv1.forward(x));
    x = torch.nn.functional.relu(this.gconv.forward(x));
    x = torch.nn.functional.channel_shuffle(x, this.groups); // Mix between groups
    x = this.pw_conv2.forward(x);
    return x;
  }
}
// Efficient architecture using group convolutions + shuffle

// Comparison: shuffle effect on channel ordering
const x = torch.arange(16).reshape(1, 16, 1, 1).to('float32');
// Channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
const shuffled = torch.nn.functional.channel_shuffle(x, 2);
// Groups: [0-7] and [8-15]
// After shuffle: channels interleaved → [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]
// This exact pattern follows from the reshape → permute → reshape definition

// 3D convolution with channel shuffle for video/volumetric data
const video_features = torch.randn(batch, channels, depth, height, width);
const shuffled_3d = torch.nn.functional.channel_shuffle(video_features, 4);
// Works with any spatial dimensions (2D, 3D, or higher)
See Also
- PyTorch torch.nn.functional.channel_shuffle
- torch.nn.functional.grouped_mm - Grouped operations for efficiency
- torch.nn.Conv2d - With groups parameter for grouped convolution
- torch.nn.functional.permute - Generic permutation for any tensor
- torch.nn.ShuffleNet - Module using channel shuffle