torch.rearrange
function rearrange<S extends Shape, P extends string, A extends AxesRecord = Record<string, never>>(tensor: Tensor<S>, pattern: P, options?: RearrangeOptions<A>): Tensor<ValidatedRearrangeShape<P, S, A>>
function rearrange<S extends Shape, P extends string, A extends AxesRecord = Record<string, never>>(tensor: Tensor<S>, pattern: P, axes: A, options?: RearrangeOptions<A>): Tensor<ValidatedRearrangeShape<P, S, A>>
Rearranges tensor dimensions using einops-style pattern notation.
The rearrange operation provides a powerful, readable way to reshape, transpose, and decompose tensors using pattern strings. It's especially useful for:
- Reshaping tensors for attention mechanisms
- Converting between image formats (BCHW ↔ BHWC)
- Splitting and merging dimensions
- Making tensor manipulations self-documenting
- Pattern syntax: Use spaces to separate axes, parentheses to group axes for merging/splitting, and '->' to separate the input pattern from the output pattern.
- Ellipsis: Use ... to match zero or more batch dimensions that should be preserved in the same position.
- Axis sizes: When decomposing an axis (e.g., '(h w)'), you must provide all but one size in the axes parameter; the last is inferred.
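To make the inference rule concrete: the inferred axis gets whatever size makes the product of the group's axes equal the original dimension. A minimal sketch of that arithmetic (a hypothetical helper, not the library's actual implementation):

```typescript
// Sketch of how the one unspecified axis size in a group like '(h d)'
// can be inferred: divide the original dimension by the product of the
// known sizes. Hypothetical helper for illustration only.
function inferAxisSize(dimSize: number, knownSizes: number[]): number {
  const known = knownSizes.reduce((a, b) => a * b, 1);
  if (dimSize % known !== 0) {
    throw new Error(`dimension ${dimSize} is not divisible by ${known}`);
  }
  return dimSize / known;
}

// For 'b (h d) -> b h d' on a (4, 64) tensor with { h: 8 }:
// d is inferred as 64 / 8 = 8.
console.log(inferAxisSize(64, [8])); // 8
```

This is also why a decomposition fails when the known sizes do not evenly divide the original dimension.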
Parameters
tensor: Tensor<S> - Input tensor to rearrange
pattern: P - Einops pattern string in the form "input_pattern -> output_pattern". Axes can be named (b c h w), grouped (b (h w)), or use ellipsis (...) for batch dims.
axes: A, optional - Sizes of named axes when decomposing a dimension, e.g. { h: 8 } for the pattern 'b (h d) -> b h d' (second overload)
options: RearrangeOptions<A>, optional - Optional settings:
- axes: Same as the axes parameter, for the options-style API
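Based on the two overloads above, the same decomposition could be written with either calling style (a sketch; shapes are illustrative):

```typescript
const seq = torch.randn(4, 64); // (batch, features)

// Axes-parameter style (second overload): sizes as the third argument
const a = torch.rearrange(seq, 'b (h d) -> b h d', { h: 8 });

// Options-style (first overload): sizes nested under the axes key
const b = torch.rearrange(seq, 'b (h d) -> b h d', { axes: { h: 8 } });
```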
Returns
Tensor<ValidatedRearrangeShape<P, S, A>> – Tensor rearranged according to the pattern
Examples
// Transpose: swap height and width
const x = torch.randn(2, 3, 4); // (batch, height, width)
const y = torch.rearrange(x, 'b h w -> b w h'); // (2, 4, 3)
// Flatten spatial dimensions
const img = torch.randn(8, 3, 32, 32); // (batch, channels, H, W)
const flat = torch.rearrange(img, 'b c h w -> b (c h w)'); // (8, 3072)
// Split a dimension (must specify size)
const seq = torch.randn(4, 64); // (batch, features)
const heads = torch.rearrange(seq, 'b (h d) -> b h d', { h: 8 }); // (4, 8, 8)
// Attention: reshape for multi-head attention
const qkv = torch.randn(2, 100, 512); // (batch, seq_len, embed_dim)
const multihead = torch.rearrange(qkv, 'b n (h d) -> b h n d', { h: 8 });
// Shape: (2, 8, 100, 64) - 8 heads, 64 dim each
// Image format conversion
const bchw = torch.randn(4, 3, 224, 224);
const bhwc = torch.rearrange(bchw, 'b c h w -> b h w c'); // For display
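The ellipsis syntax described above is not shown in the examples; a sketch of how it might look (shapes are illustrative):

```typescript
// Ellipsis: preserve any number of leading batch dimensions unchanged
const t = torch.randn(2, 3, 4, 5);
const swapped = torch.rearrange(t, '... h w -> ... w h'); // (2, 3, 5, 4)
```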
// Patch embedding for Vision Transformer
const image = torch.randn(1, 3, 224, 224);
const patches = torch.rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', { p1: 16, p2: 16 });
// Shape: (1, 196, 768) - 196 patches of 768 features each
See Also
- [einops.rearrange (third-party library commonly used with PyTorch)](https://einops.rocks/api/rearrange/)
- reshape - Simple reshape without pattern notation
- permute - Dimension permutation without reshape
- einsum - Einstein summation for contractions