torch.Tensor.transpose
Tensor.transpose<D0 extends number, D1 extends number>(dim0: D0, dim1: D1): Tensor<TransposeDimsShapeChecked<S, D0, D1>, D, Dev>
Swaps two dimensions of the tensor, creating a new tensor with reordered axes.
Transposes (swaps) the specified pair of dimensions while keeping all other dimensions unchanged.
This is a general-purpose transpose for any two dimensions of an N-dimensional tensor, unlike
t() which only works on 2D tensors. Commonly used for:
- Moving batch dimensions to different positions
- Converting between channel-first and channel-last tensor layouts
- Preparing tensors for specific operations (batch matrix multiplication, etc.)
- Implementing attention mechanisms with specific tensor layouts
- Handling time-series data (moving time/sequence dimension)
- Efficient: Returns a view-like result without copying data when possible. The underlying memory is not rearranged; only the interpretation of the axes (the strides) changes.
- Supports negative indexing: Use negative indices to count from the end (-1 is last dimension).
- Commutative: transpose(0, 1) and transpose(1, 0) produce identical results (swapping is symmetric).
- Gradient flow: Fully differentiable. Gradients automatically flow back through the transpose.
- Transposing a dimension with itself is a no-op (returns a copy).
- Dimension indices must be valid. Out-of-range dimensions throw an error.
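The guarantees above (view semantics, negative indexing, symmetry, same-dimension no-op, out-of-range errors) can be sketched against NumPy's swapaxes, which implements the same axis-swap semantics; this does not exercise the tensor library itself.

```python
import numpy as np

x = np.zeros((2, 3, 4))

# View semantics: swapaxes returns a view over the same memory
v = np.swapaxes(x, 0, 2)
assert v.shape == (4, 3, 2)
assert v.base is x  # no data was copied

# Negative indexing counts from the end (-1 is the last dimension)
assert np.swapaxes(x, -2, -1).shape == (2, 4, 3)

# Swapping is symmetric: (0, 1) and (1, 0) agree
a = np.swapaxes(x, 0, 1)
b = np.swapaxes(x, 1, 0)
assert a.shape == b.shape == (3, 2, 4)

# Swapping a dimension with itself leaves the shape unchanged
assert np.swapaxes(x, 1, 1).shape == x.shape

# Out-of-range dimensions raise an error
raised = False
try:
    np.swapaxes(x, 0, 3)  # only axes 0..2 (or -3..-1) are valid
except Exception:
    raised = True
assert raised
```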
Parameters
dim0: D0 - First dimension to swap (supports negative indexing, e.g., -1 for last dimension)
dim1: D1 - Second dimension to swap (supports negative indexing)
Returns
Tensor<TransposeDimsShapeChecked<S, D0, D1>, D, Dev> – New tensor with dimensions dim0 and dim1 swapped
Examples
// Basic 3D transpose - swap dimensions 0 and 2
const x = torch.zeros(2, 3, 4);
x.transpose(0, 2); // Shape: [4, 3, 2]
// Channel format conversion: NHWC to NCHW
const image_nhwc = torch.randn(batch, height, width, channels); // [32, 224, 224, 3]
const image_nchw = image_nhwc.transpose(1, 3); // [32, 3, 224, 224] - but H and W also trade places (NCWH)
// The shapes match NCHW here only because height === width; permute(0, 3, 1, 2) gives true NCHW
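The caveat above is easiest to see with a non-square image. The following NumPy sketch (using swapaxes/transpose, which mirror the transpose/permute semantics here) uses made-up dimensions H=4, W=6 to show that a single swap produces NCWH, not NCHW:

```python
import numpy as np

# Hypothetical non-square batch: N=2, H=4, W=6, C=3
nhwc = np.zeros((2, 4, 6, 3))

# Swapping dims 1 and 3 trades H with C, so W and H end up reversed: [N, C, W, H]
swapped = np.swapaxes(nhwc, 1, 3)
assert swapped.shape == (2, 3, 6, 4)

# A full permutation keeps H before W: true NCHW
nchw = np.transpose(nhwc, (0, 3, 1, 2))
assert nchw.shape == (2, 3, 4, 6)
```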
// Preparing for batch matrix multiplication
const matrices = torch.randn(batch_size, m, n);
const transposed = matrices.transpose(-2, -1); // [batch_size, n, m]
const result = torch.matmul(transposed, matrices); // [batch_size, n, m] @ [batch_size, m, n] -> [batch_size, n, n]
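The batch-matmul pattern above can be checked with NumPy, whose @ operator batches over leading dimensions the same way; each output slice is the Gram matrix Mᵀ M of the corresponding input slice (a sketch, not the tensor library's own API):

```python
import numpy as np

batch, m, n = 2, 3, 4
mats = np.random.default_rng(0).normal(size=(batch, m, n))

# Swap the last two axes, then batch-multiply: [batch, n, m] @ [batch, m, n]
gram = np.swapaxes(mats, -2, -1) @ mats
assert gram.shape == (batch, n, n)

# Each slice is M^T @ M, a symmetric (Gram) matrix
assert np.allclose(gram[0], mats[0].T @ mats[0])
assert np.allclose(gram[0], gram[0].T)
```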
// Moving sequence dimension in transformers
const sequence = torch.randn(batch, seq_len, d_model); // [32, 100, 768]
const swapped = sequence.transpose(0, 1); // [100, 32, 768] - sequence-first layout
// Negative indexing (same operations as above, more readable)
const matrices = torch.randn(2, 3, 4);
matrices.transpose(-2, -1); // Same as transpose(1, 2)
See Also
- PyTorch tensor.transpose(dim0, dim1)
- permute - More general: reorder all dimensions in any order
- t - Shorthand for transposing 2D tensors (matrices)
- transpose_ - In-place version
- swapaxes - Alias for transpose with different parameter order