torch.meshgrid
function meshgrid(...args: any[]): Tensor[]
Creates coordinate grids from 1-D tensors.
Takes k 1-D tensors and broadcasts them into k grids of dimension k that together cover every combination of their coordinates. Common uses:
- Coordinate generation: Creating 2D/3D grids for sampling
- Image processing: Creating pixel coordinate grids for warping/interpolation
- Function evaluation: Evaluating functions over a grid of points
- Spatial operations: Computing distances and convolutions over a grid
- Data visualization: Creating mesh for plotting surfaces
- Attention patterns: Computing attention over spatial grids
Two indexing modes:
- 'ij' (default): matrix-style; the i-th input varies along the i-th output dimension (the first input varies along rows)
- 'xy': Cartesian-style; the first input varies along columns and the second along rows (standard for 2D plotting)
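To make the two modes concrete, here is a minimal plain-array sketch for two inputs. It does not call the library; `meshgrid2d` and `Indexing` are hypothetical names introduced only for illustration, and the sketch assumes the semantics described above (which match PyTorch's documented behavior).

```typescript
// Hypothetical plain-array sketch of 2-D meshgrid semantics (not the library's code).
type Indexing = "ij" | "xy";

// 'ij': A[i][j] = a[i], B[i][j] = b[j]; both have shape [a.length, b.length].
// 'xy': A[r][c] = a[c], B[r][c] = b[r]; both have shape [b.length, a.length].
function meshgrid2d(a: number[], b: number[], indexing: Indexing = "ij"): [number[][], number[][]] {
  if (indexing === "ij") {
    const A = a.map((ai) => b.map(() => ai));
    const B = a.map(() => b.map((bj) => bj));
    return [A, B];
  }
  const A = b.map(() => a.map((ac) => ac));
  const B = b.map((br) => a.map(() => br));
  return [A, B];
}

const x = [0, 0.5, 1];
const y = [10, 20];
const [Xij, Yij] = meshgrid2d(x, y, "ij");
const [Xxy, Yxy] = meshgrid2d(x, y, "xy");
console.log(JSON.stringify(Xij)); // [[0,0],[0.5,0.5],[1,1]]
console.log(JSON.stringify(Xxy)); // [[0,0.5,1],[0,0.5,1]]
```

With two inputs, the 'xy' grids are simply the transposes of the 'ij' grids.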
Implementation: Efficient broadcasting; the full grid is not constructed explicitly in memory.
Notes:
- 1-D inputs only: every input must be exactly 1-D; other shapes raise an error
- Output shape: (n1, n2, ..., nk) with 'ij' indexing, where ni is the length of the i-th input; 'xy' swaps the first two dimensions, giving (n2, n1, n3, ..., nk)
- 'xy' vs 'ij': 'xy' is standard for 2D plotting; 'ij' for matrix operations
- First-class tensors: returned grids are full-fledged tensors
- Differentiable: gradients flow through the grids back to the 1-D inputs
Pitfalls:
- Indexing confusion: for two inputs the 'xy' grids are the transposes of the 'ij' grids, so using the wrong mode silently swaps axes; choose deliberately
- Memory for 3D+: each grid has n1 × n2 × ... × nk elements, so operations that materialize it (such as arithmetic on the grids) can consume significant memory for large 3D+ grids
- Dtypes: all inputs should have compatible dtypes
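The shape rule and the memory trade-off above can be sketched with plain arrays. This is an illustration under stated assumptions, not this library's implementation: `meshgridShape`, `axisOf`, and `gridAt` are hypothetical helpers. `gridAt` reads a grid value straight from the 1-D inputs, which is the idea behind a broadcast view; the counts show why materializing even one grid is far larger than the inputs.

```typescript
// Hypothetical plain-array sketch of the meshgrid shape and broadcasting rules
// (illustrative only; not this library's implementation).
type Indexing = "ij" | "xy";

// Output shape for inputs of the given lengths: (n1, ..., nk) for 'ij',
// with the first two dimensions swapped for 'xy'.
function meshgridShape(lengths: number[], indexing: Indexing = "ij"): number[] {
  const shape = [...lengths];
  if (indexing === "xy" && shape.length >= 2) {
    [shape[0], shape[1]] = [shape[1], shape[0]];
  }
  return shape;
}

// Output axis along which input d varies: its own axis for 'ij';
// for 'xy' the first two inputs trade axes.
function axisOf(d: number, k: number, indexing: Indexing): number {
  return indexing === "xy" && k >= 2 && d < 2 ? 1 - d : d;
}

// Value of the d-th grid at a multi-index, read directly from the 1-D inputs
// instead of from a stored grid.
function gridAt(inputs: number[][], d: number, index: number[], indexing: Indexing = "ij"): number {
  return inputs[d][index[axisOf(d, inputs.length, indexing)]];
}

const inputs = [
  Array.from({ length: 100 }, (_, i) => i),
  Array.from({ length: 200 }, (_, i) => i * 10),
  Array.from({ length: 300 }, (_, i) => i * 100),
];
const lengths = inputs.map((v) => v.length);

console.log(JSON.stringify(meshgridShape(lengths, "ij"))); // [100,200,300]
console.log(JSON.stringify(meshgridShape(lengths, "xy"))); // [200,100,300]

// Numbers held by the inputs vs. numbers in ONE fully materialized grid.
const stored = lengths.reduce((s, n) => s + n, 0);
const perGrid = lengths.reduce((p, n) => p * n, 1);
console.log(stored, perGrid); // 600 6000000

// The same grid point in both modes: 'xy' swaps the first two index positions.
console.log(gridAt(inputs, 1, [5, 7, 42], "ij")); // 70
console.log(gridAt(inputs, 1, [7, 5, 42], "xy")); // 70
```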
Parameters
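args: any[] – the 1-D input tensors, optionally followed by the indexing mode ('ij' or 'xy') as the last argument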
Returns
Tensor[] – One grid per input. With 'ij' indexing each grid has shape (n1, n2, ..., nk), where ni is the length of the i-th input; with 'xy' the first two dimensions are swapped.
Examples
// 2D Cartesian grid (standard plotting)
const x = torch.linspace(0, 1, 5); // [0.0, 0.25, 0.5, 0.75, 1.0]
const y = torch.linspace(0, 2, 3); // [0, 1, 2]
const [X, Y] = torch.meshgrid(x, y, 'xy'); // Cartesian indexing
// X shape: [3, 5] - x-coordinates (columns vary)
// Y shape: [3, 5] - y-coordinates (rows vary)
// Index-style grid (matrix indexing)
const [X_ij, Y_ij] = torch.meshgrid(x, y, 'ij'); // Index style
// X_ij shape: [5, 3] - x-coordinates (rows vary)
// Y_ij shape: [5, 3] - y-coordinates (columns vary)
// 3D grid for volume operations
const x = torch.linspace(-1, 1, 10);
const y = torch.linspace(-1, 1, 10);
const z = torch.linspace(-1, 1, 10);
const [X, Y, Z] = torch.meshgrid(x, y, z);
// Each shape: [10, 10, 10]
// Points where X.pow(2).add(Y.pow(2)).add(Z.pow(2)) ≤ 1 lie inside the unit sphere
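The sphere test on the same 10 × 10 × 10 grid can be checked with a plain-array sketch. It assumes 'ij' layout (inside[i][j][k] pairs xs[i], ys[j], zs[k]); the local `linspace` is a hypothetical stand-in for `torch.linspace`, not the library function.

```typescript
// Hypothetical plain-array sketch of the unit-sphere test on a 10 x 10 x 10 'ij' grid.
function linspace(start: number, end: number, steps: number): number[] {
  if (steps === 1) return [start];
  return Array.from({ length: steps }, (_, i) => start + ((end - start) * i) / (steps - 1));
}

const xs = linspace(-1, 1, 10);
const ys = linspace(-1, 1, 10);
const zs = linspace(-1, 1, 10);

// inside[i][j][k] is true when (xs[i], ys[j], zs[k]) satisfies x² + y² + z² <= 1.
const inside: boolean[][][] = xs.map((x) =>
  ys.map((y) => zs.map((z) => x * x + y * y + z * z <= 1)),
);

// Count the grid points that fall inside the sphere.
let count = 0;
for (const plane of inside) {
  for (const row of plane) {
    for (const v of row) {
      if (v) count++;
    }
  }
}
console.log(`${count} of ${xs.length * ys.length * zs.length} grid points lie inside the unit sphere`);
```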
// Function evaluation over grid
const x = torch.linspace(-2, 2, 100);
const y = torch.linspace(-2, 2, 100);
const [X, Y] = torch.meshgrid(x, y, 'xy');
const Z = X.pow(2).add(Y.pow(2)); // f(x,y) = x² + y²
// Z shape: [100, 100] - function values at each grid point
// Pixel coordinates for image warping
const height = 256;
const width = 256;
const y_coords = torch.arange(height).div(height);
const x_coords = torch.arange(width).div(width);
const [X, Y] = torch.meshgrid(x_coords, y_coords, 'xy');
// X, Y: [256, 256] - normalized pixel coordinates in [0, 1)
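// Pairwise squared distance matrix
const points = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]]); // 4 points
const x = points.select(1, 0); // x-coordinates: [0, 1, 0, 1]
const y = points.select(1, 1); // y-coordinates: [0, 0, 1, 1]
const [Xi, Xj] = torch.meshgrid(x, x, 'ij'); // Xi[i][j] = x[i], Xj[i][j] = x[j]
const [Yi, Yj] = torch.meshgrid(y, y, 'ij'); // Yi[i][j] = y[i], Yj[i][j] = y[j]
const dist2 = Xi.sub(Xj).pow(2).add(Yi.sub(Yj).pow(2)); // [4, 4] squared distances between points i and j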
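As a cross-check on expected values, the pairwise squared distances between the same four points can be computed with plain arrays. Entry [i][j] pairs point i (varying along rows) with point j (varying along columns), the layout an 'ij' meshgrid of a coordinate vector with itself produces. `pairwiseSquaredDistances` is a hypothetical helper, not a library function.

```typescript
// Hypothetical plain-array sketch: pairwise squared distances between 2-D points.
// Row i holds point i, column j holds point j (the 'ij' layout).
function pairwiseSquaredDistances(points: [number, number][]): number[][] {
  const xs = points.map(([x]) => x);
  const ys = points.map(([, y]) => y);
  return xs.map((xi, i) =>
    xs.map((xj, j) => (xi - xj) ** 2 + (ys[i] - ys[j]) ** 2),
  );
}

const points: [number, number][] = [[0, 0], [1, 0], [0, 1], [1, 1]];
const dist2 = pairwiseSquaredDistances(points);
console.log(JSON.stringify(dist2)); // [[0,1,1,2],[1,0,2,1],[1,2,0,1],[2,1,1,0]]
```

Taking the square root of each entry gives the Euclidean distances; the squared form avoids the root when only comparisons are needed.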
// Spatial attention pattern
const height = 16;
const width = 16;
const h_indices = torch.arange(height);
const w_indices = torch.arange(width);
const [H, W] = torch.meshgrid(h_indices, w_indices, 'ij');
// H[i][j] = i (row index), W[i][j] = j (column index); both [16, 16] - position indices for attention masking
See Also
- PyTorch torch.meshgrid()
- linspace - Create linearly spaced 1D tensors (often used with meshgrid)
- arange - Create integer-spaced 1D tensors
- stack - Stack tensors (alternative to meshgrid for some cases)
- broadcast_to - Broadcasting operation (used internally)
- cartesian_prod - Cartesian product (related operation)