torch.transpose
function transpose<S extends Shape, D0 extends number, D1 extends number, D extends DType = DType, Dev extends DeviceType = DeviceType>(input: Tensor<S, D, Dev>, dim0: D0, dim1: D1): Tensor<TransposeDimsShapeChecked<S, D0, D1>, D, Dev>
Swaps two dimensions of a tensor and returns a new tensor.
Exchanges the positions of two specified dimensions; all other dimensions keep their sizes and positions. Use it to rearrange a tensor's layout so that it matches what a downstream operation expects, for example:
- Converting between data layouts (channels-last vs channels-first images)
- Preparing matrices for multiplication (transpose before matmul)
- Reorganizing tensor layout for broadcasting
- Aligning tensor dimensions for operations
- Swapping batch and sequence dimensions in NLP tasks
Note: swapaxes() is an alias for this function.
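The shape rule described above can be modeled at runtime in a few lines. The sketch below is plain TypeScript with no library dependency; `transposeShape` is a hypothetical helper written for illustration, not part of this library, and its range check follows the [0, rank) rule stated in the notes below rather than any confirmed runtime behavior of `torch.transpose`.

```typescript
// Hypothetical helper (not the library's API): computes the output shape of
// a two-axis swap. It only illustrates the rule that the type
// TransposeDimsShapeChecked expresses at compile time.
function transposeShape(shape: readonly number[], dim0: number, dim1: number): number[] {
  const rank = shape.length;
  // Assumption: dims must be integers in [0, rank), as the notes state.
  for (const d of [dim0, dim1]) {
    if (!Number.isInteger(d) || d < 0 || d >= rank) {
      throw new RangeError(`dimension ${d} is out of range [0, ${rank})`);
    }
  }
  const out = [...shape]; // copy so the caller's shape is not mutated
  [out[dim0], out[dim1]] = [out[dim1], out[dim0]]; // swap the two sizes
  return out; // every other entry stays where it was, so rank is unchanged
}

console.log(transposeShape([2, 3, 4], 0, 2)); // [4, 3, 2]
console.log(transposeShape([2, 3, 4], 1, 2)); // [2, 4, 3]
console.log(transposeShape([5, 10], 1, 1)); // [5, 10] (same dim: no-op)
```

Swapping entries of a copied array, rather than building a permutation, keeps the sketch aligned with the "exactly two dims" contract; arbitrary reorderings are the job of `permute()`.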
- Lightweight operation: Only swaps two dimensions, not a full rearrangement
- Rank preserved: Output has same rank as input
- Order preserved: Non-swapped dimensions stay in same positions
- Invertible: applying the same swap twice restores the original, i.e. transpose(transpose(x, a, b), a, b) has the same shape and values as x
- 2-D special case: For matrices, this is the standard matrix transpose
- Same dimensions: Using the same value for dim0 and dim1 is a no-op
- Valid dimensions: dim0 and dim1 must both lie in the range [0, rank)
- Not permute: This swaps exactly 2 dims; use permute() for arbitrary rearrangement
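The properties above follow from how elements move: an element whose index has value i on axis dim0 and j on axis dim1 lands at the same index with those two values exchanged. The sketch below makes that mapping explicit, assuming a dense row-major `number[]` buffer; `transposeFlat` is a hypothetical helper, and it copies data purely for clarity. Libraries such as PyTorch typically return a strided view instead of copying, so this illustrates the index mapping, not how this library stores tensors.

```typescript
// Hypothetical sketch: swap two axes of a dense row-major buffer by copying.
function transposeFlat(
  data: readonly number[],
  shape: readonly number[],
  dim0: number,
  dim1: number,
): { data: number[]; shape: number[] } {
  const rank = shape.length;
  if (dim0 < 0 || dim0 >= rank || dim1 < 0 || dim1 >= rank) {
    throw new RangeError(`dims must lie in [0, ${rank})`);
  }
  const size = shape.reduce((p, s) => p * s, 1);
  if (size !== data.length) {
    throw new RangeError("data length does not match shape");
  }

  const outShape = [...shape];
  [outShape[dim0], outShape[dim1]] = [outShape[dim1], outShape[dim0]];

  // Row-major strides of the output: the last axis is contiguous.
  const outStrides = new Array<number>(rank).fill(1);
  for (let k = rank - 2; k >= 0; k--) {
    outStrides[k] = outStrides[k + 1] * outShape[k + 1];
  }

  const out = new Array<number>(data.length);
  const idx = new Array<number>(rank).fill(0); // current multi-index in the input
  for (let flat = 0; flat < data.length; flat++) {
    // The input element at idx moves to idx with dim0/dim1 exchanged.
    const outIdx = [...idx];
    [outIdx[dim0], outIdx[dim1]] = [outIdx[dim1], outIdx[dim0]];
    let offset = 0;
    for (let k = 0; k < rank; k++) offset += outIdx[k] * outStrides[k];
    out[offset] = data[flat];

    // Advance idx in row-major order (odometer-style increment).
    for (let k = rank - 1; k >= 0; k--) {
      if (++idx[k] < shape[k]) break;
      idx[k] = 0;
    }
  }
  return { data: out, shape: outShape };
}

// 2-D special case: [[1, 2, 3], [4, 5, 6]] becomes [[1, 4], [2, 5], [3, 6]].
const m = transposeFlat([1, 2, 3, 4, 5, 6], [2, 3], 0, 1);
console.log(m.shape, m.data); // [3, 2] [1, 4, 2, 5, 3, 6]

// Applying the same swap again restores the original layout.
const back = transposeFlat(m.data, m.shape, 0, 1);
console.log(back.shape, back.data); // [2, 3] [1, 2, 3, 4, 5, 6]
```

When dim0 equals dim1, the swapped index equals the original index, so every element is copied to its own position; that is the no-op case in code form.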
Parameters
input: Tensor<S, D, Dev> - The input tensor (rank ≥ 2)
dim0: D0 - The first dimension to swap (0 to rank-1)
dim1: D1 - The second dimension to swap (0 to rank-1)
Returns
Tensor<TransposeDimsShapeChecked<S, D0, D1>, D, Dev> - Tensor with dimensions dim0 and dim1 swapped; same rank as the input
Examples
// Transpose matrix
const x = torch.randn(2, 3);
torch.transpose(x, 0, 1).shape; // [3, 2]
// Swap dimensions in 3D tensor
const y = torch.randn(2, 3, 4);
torch.transpose(y, 0, 2).shape; // [4, 3, 2]
torch.transpose(y, 1, 2).shape; // [2, 4, 3]
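// Convert channels-last to channels-first (images)
const images = torch.randn(32, 224, 224, 3); // [batch, height, width, channels]
const channels_first = torch.transpose(images, 1, 3); // [32, 3, 224, 224] = [batch, channels, width, height]
// Height and width end up swapped; use permute() for a true [batch, channels, height, width] layout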
// Prepare for matrix multiplication
const a = torch.randn(32, 64, 128); // [batch, seq_len, embedding]
const b = torch.randn(256, 128); // weight stored as [output, embedding]
const b_t = torch.transpose(b, 0, 1); // [embedding, output] = [128, 256]
// Inner sizes now agree: a [32, 64, 128] matmul b_t [128, 256] gives [32, 64, 256]
// Double transpose returns to original
const original = torch.randn(5, 10);
const transposed = torch.transpose(original, 0, 1); // [10, 5]
const back = torch.transpose(transposed, 0, 1); // [5, 10]
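Preparing an operand for matrix multiplication comes down to making the inner sizes agree. The sketch below reproduces that pattern without a batch dimension, using plain nested arrays; `transpose2D` and `matmul2D` are hypothetical, naive helpers written for illustration, not this library's API, and the weight is assumed to be stored as [output, input].

```typescript
// Hypothetical plain-array helpers; not the library's API.
type Matrix = number[][];

// Standard matrix transpose: out[j][i] = m[i][j].
function transpose2D(m: Matrix): Matrix {
  const rows = m.length;
  const cols = rows === 0 ? 0 : m[0].length;
  return Array.from({ length: cols }, (_, j) =>
    Array.from({ length: rows }, (_, i) => m[i][j]),
  );
}

// Naive matmul: [n, k] x [k, p] -> [n, p]; throws if the inner sizes differ.
function matmul2D(a: Matrix, b: Matrix): Matrix {
  const k = a.length === 0 ? 0 : a[0].length;
  if (b.length !== k) {
    throw new Error(`inner dimensions differ: ${k} vs ${b.length}`);
  }
  const p = k === 0 ? 0 : b[0].length;
  return a.map((row) =>
    Array.from({ length: p }, (_, j) =>
      row.reduce((sum, v, t) => sum + v * b[t][j], 0),
    ),
  );
}

// x: [2, 3] inputs; w: [4, 3] weight stored as [output, input].
const x: Matrix = [[1, 2, 3], [4, 5, 6]];
const w: Matrix = [
  [1, 0, 0],
  [0, 1, 0],
  [0, 0, 1],
  [1, 1, 1],
];
const y = matmul2D(x, transpose2D(w)); // [2, 3] x [3, 4] -> [2, 4]
console.log(y); // [[1, 2, 3, 6], [4, 5, 6, 15]]
```

Calling `matmul2D(x, w)` without the transpose throws, because the inner sizes (3 and 4) differ; swapping the weight's two axes is exactly what makes them line up.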