torch.Tensor.reshape
Tensor.reshape<A extends number>(d1: A): Tensor<readonly [A], D, Dev>
Tensor.reshape(shape: readonly number[]): Tensor<DynamicShape, D, Dev>
Reshapes the tensor to a new shape without changing its data.
Returns a new tensor with the same total number of elements but a different shape. No data is copied; the result is a new view of the same underlying buffer (similar to NumPy's reshape in the cases where it can return a view). This is the usual tool for adapting tensor shapes between layers, flattening batches, and similar tasks.
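To make the view semantics concrete, here is a minimal sketch in plain TypeScript. `MiniTensor` is a hypothetical, contiguous row-major stand-in and not this library's implementation: reshape checks the element count and returns a new object that points at the same buffer, so a write through one tensor is visible through the other.

```typescript
// Hypothetical minimal tensor (NOT the library's implementation):
// a flat buffer plus a row-major, contiguous shape.
class MiniTensor {
  readonly data: Float32Array;
  readonly shape: readonly number[];

  constructor(data: Float32Array, shape: readonly number[]) {
    this.data = data;
    this.shape = shape;
  }

  // Total number of elements = product of the dimensions.
  get numel(): number {
    return this.shape.reduce((a, b) => a * b, 1);
  }

  // Returns a new MiniTensor that shares `data` with this one; only the shape changes.
  reshape(shape: readonly number[]): MiniTensor {
    const n = shape.reduce((a, b) => a * b, 1);
    if (n !== this.numel) {
      throw new Error(`cannot reshape [${this.shape}] (${this.numel} elements) into [${shape}]`);
    }
    return new MiniTensor(this.data, shape); // same buffer, new metadata
  }

  // Row-major indexing: offset = sum(idx[i] * stride[i]), last dimension fastest.
  get(...idx: number[]): number {
    if (idx.length !== this.shape.length) throw new Error("wrong number of indices");
    let offset = 0;
    let stride = 1;
    for (let i = this.shape.length - 1; i >= 0; i--) {
      offset += idx[i] * stride;
      stride *= this.shape[i];
    }
    return this.data[offset];
  }
}

const x = new MiniTensor(Float32Array.from([0, 1, 2, 3, 4, 5]), [6]);
const y = x.reshape([2, 3]);
y.data[4] = 40;           // write through the [2, 3] view
console.log(x.get(4));    // 40: x sees the change because the buffer is shared
console.log(y.get(1, 1)); // 40: element (1, 1) of a [2, 3] tensor is flat index 4
```

Because only `shape` is replaced, the cost of this reshape does not depend on how many elements the buffer holds.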
Common use cases:
- Flattening for fully connected layers: (batch, height, width, channels) -> (batch, -1)
- Unflattening back to spatial: (batch, features) -> (batch, height, width, channels)
- Adapting shapes for broadcasting or for operations that expect a particular rank
- Adding or removing a batch dimension
- Converting between different representation formats
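Notes
- View semantics: no data is copied; only the shape metadata used to index the buffer changes. The cost is therefore independent of the number of elements (O(1) in the data size), including for GPU tensors.
- -1 inference: at most one dimension may be -1, and its size is inferred. The product of the other dimensions must divide the total number of elements evenly, and the -1 dimension takes the quotient. Without a -1, the product of all dimensions must equal the total number of elements.
- Gradient flow: reshape is differentiable; during backpropagation, the incoming gradient is reshaped back to the input's shape.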
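The -1 inference rule can be sketched as a small standalone function. `inferShape` is a hypothetical helper written for illustration, not part of the library's public API; it assumes the library follows the same rule as PyTorch (at most one -1, and the known dimensions must divide the element count evenly).

```typescript
// Hypothetical helper (not the library's API): resolves a target shape that
// may contain one -1, given the source tensor's total element count.
function inferShape(numel: number, shape: readonly number[]): number[] {
  const inferIdx = shape.indexOf(-1);
  if (inferIdx !== shape.lastIndexOf(-1)) {
    throw new Error("only one dimension can be -1");
  }

  // Product of the explicitly given dimensions.
  let known = 1;
  for (const d of shape) {
    if (d === -1) continue;
    if (!Number.isInteger(d) || d < 0) throw new Error(`invalid dimension ${d}`);
    known *= d;
  }

  // No -1: the product of all dimensions must match exactly.
  if (inferIdx === -1) {
    if (known !== numel) throw new Error(`shape [${shape}] is invalid for ${numel} elements`);
    return [...shape];
  }

  // One -1: the known product must divide numel evenly; -1 becomes the quotient.
  if (known === 0 || numel % known !== 0) {
    throw new Error(`shape [${shape}] is invalid for ${numel} elements`);
  }
  const out = [...shape];
  out[inferIdx] = numel / known;
  return out;
}

// Flattening a [32, 3, 224, 224] batch: 3 * 224 * 224 = 150528 features per image.
console.log(inferShape(32 * 3 * 224 * 224, [32, -1])); // [32, 150528]
console.log(inferShape(6, [-1, 2]));                   // [3, 2]
```

Rejecting a second -1 up front keeps the rule unambiguous: with two unknowns there is no unique way to split the quotient.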
Parameters
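d1: A – Size of the single output dimension (first overload).
shape: readonly number[] – Target shape (second overload); one entry may be -1.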
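The first overload declares `d1` with a generic type `A extends number`, which lets TypeScript carry a literal argument into the result type `readonly [A]`. The sketch below isolates that typing pattern; `Shape1` and `shape1` are hypothetical names used only for illustration and are not part of the library.

```typescript
// Hypothetical stand-in for the first overload's typing: because A is
// constrained to number, TypeScript infers the literal type of the argument.
type Shape1<A extends number> = readonly [A];

function shape1<A extends number>(d1: A): Shape1<A> {
  return [d1] as const;
}

const s = shape1(6);   // static type: readonly [6]
const n: number = 6;
const t = shape1(n);   // static type: readonly [number] (size not known at compile time)
console.log(s, t);
```

When the size comes from a runtime value, as with `t`, the result type widens to `readonly [number]`; the second overload, which returns `DynamicShape`, covers shapes that are not tracked statically.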
Returns
Tensor<readonly [A], D, Dev> (first overload) or Tensor<DynamicShape, D, Dev> (second overload) – Tensor with the new shape that shares data with the input.
Examples
// Basic reshape
const x = torch.arange(6); // Shape [6]
const y = x.reshape(2, 3); // Shape [2, 3]
const z = x.reshape([3, 2]); // Alternative array form
// Flatten for FC layer
const image_batch = torch.randn(32, 3, 224, 224); // 32 images
const flattened = image_batch.reshape(32, -1); // [32, 150528] (-1 auto-inferred)
// Unflatten back
const unflattened = flattened.reshape(32, 3, 224, 224);
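// Collapse batch and sequence dimensions
const batch_size = 8, seq_len = 16, hidden_dim = 64; // example sizes
const batch = torch.randn(batch_size, seq_len, hidden_dim);
const collapsed = batch.reshape(-1, hidden_dim); // [batch_size * seq_len, hidden_dim]
// `model` is a placeholder for any function mapping [N, hidden_dim] -> [N, out_dim]
const predictions = model(collapsed);
const expanded = predictions.reshape(batch_size, seq_len, -1); // [batch_size, seq_len, out_dim]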
// Add batch dimension
const single_sample = torch.randn(224, 224, 3);
const batched = single_sample.reshape(1, 224, 224, 3); // [1, 224, 224, 3]
See Also
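- PyTorch Tensor.reshape()
- view - Alias for reshape, named after PyTorch's Tensor.view (in PyTorch itself, view never copies and requires a compatible memory layout, whereas reshape may copy)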
- flatten - Reshape to 1D (convenience)
- squeeze - Remove size-1 dimensions
- unsqueeze - Add size-1 dimensions