torch.nn.functional.grid_sample
function grid_sample(input: Tensor, grid: Tensor, options?: GridSampleFunctionalOptions): Tensor
function grid_sample(input: Tensor, grid: Tensor, mode: 'bilinear' | 'nearest' | 'bicubic', padding_mode: 'zeros' | 'border' | 'reflection', align_corners: boolean, options?: GridSampleFunctionalOptions): Tensor
Spatial transformer grid sampling: samples the input using learned spatial transformation grids.
Samples the input feature map at arbitrary (floating-point) coordinates specified by a grid. This is the core operation of spatial transformer networks (STNs), enabling geometric transformations to be learned. Given a normalized (-1 to 1) coordinate grid, it interpolates values from the input feature map at those locations, letting networks learn spatial transformations (rotation, translation, scale, shear) end-to-end. Essential for:
- Spatial Transformer Networks (STNs) - learning geometric invariance
- Image augmentation and data augmentation during training
- Learnable geometric transformations (learned rotations, zooms, translations)
- Deformable convolutions and attention mechanisms
- Differentiable image warping and geometric correction
- Pose estimation and object alignment networks
- Object tracking and detection with spatial adaptation
- Fine-grained recognition with attended regions
Normalized Coordinates: Grid uses normalized (-1, 1) coordinates:
- (-1, -1) = top-left corner of input
- (0, 0) = center of input
- (1, 1) = bottom-right corner
- Out-of-bounds coordinates handled by padding_mode
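The mapping above can be made concrete by building an identity sampling grid. This is a plain TypeScript sketch using the align_corners=true convention (pixel p of an axis of size n maps to 2*p/(n-1) - 1), not the library API:

```typescript
// Build an identity grid in normalized coordinates: sampling with this grid
// reproduces the input unchanged. Shape is [H][W][2], last axis in (x, y) order,
// matching what grid_sample expects in its last dimension.
function identityGrid(h: number, w: number): number[][][] {
  const grid: number[][][] = [];
  for (let y = 0; y < h; y++) {
    const row: number[][] = [];
    for (let x = 0; x < w; x++) {
      // align_corners=true convention: first pixel -> -1, last pixel -> 1
      row.push([2 * x / (w - 1) - 1, 2 * y / (h - 1) - 1]);
    }
    grid.push(row);
  }
  return grid;
}

const g = identityGrid(3, 3);
console.log(g[0][0]); // [-1, -1]  top-left
console.log(g[1][1]); // [0, 0]    center
console.log(g[2][2]); // [1, 1]    bottom-right
```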
Differentiation: Bilinear (and bicubic) interpolation is fully differentiable, enabling backpropagation through the sampling step. Gradients flow from the output back to both the input feature map and the grid coordinates, which allows transformations to be learned end-to-end with the rest of the network. Nearest-neighbor sampling is piecewise constant, so its gradient with respect to the grid coordinates is zero almost everywhere (see the gotchas below).
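To see why bilinear sampling is differentiable, here is a plain TypeScript sketch of 2D bilinear interpolation at one fractional pixel coordinate (with zeros padding; this is an illustration, not the library API). The four corner weights are linear in x and y, which is exactly what lets gradients reach both the pixel values and the sampling coordinates:

```typescript
// Sample a 2D array at a fractional (x, y) pixel coordinate using bilinear
// interpolation. Out-of-range corner reads fall back to 0 (zeros padding).
function bilinear(img: number[][], x: number, y: number): number {
  const x0 = Math.floor(x), y0 = Math.floor(y);
  const x1 = x0 + 1, y1 = y0 + 1;
  const wx = x - x0, wy = y - y0; // fractional weights, linear in x and y
  const v = (row: number, col: number): number =>
    img[row] !== undefined && img[row][col] !== undefined ? img[row][col] : 0;
  return (1 - wy) * ((1 - wx) * v(y0, x0) + wx * v(y0, x1))
       +      wy  * ((1 - wx) * v(y1, x0) + wx * v(y1, x1));
}

const img = [[0, 1], [2, 3]];
console.log(bilinear(img, 0.5, 0.5)); // 1.5 (average of all four pixels)
```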
Interpolation Modes:
- 'bilinear': Smooth bilinear interpolation (differentiable, standard choice)
- 'nearest': Nearest neighbor (non-smooth at pixel boundaries)
- 'bicubic': Cubic interpolation (smoother, only for 2D)
Padding Modes (for out-of-bounds access):
- 'zeros': Fill out-of-bounds with 0
- 'border': Clamp to edge pixels
- 'reflection': Mirror at boundaries
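The three modes can be sketched in plain TypeScript for a single normalized coordinate. This is a simplified 1-D illustration (the exact reflection points in the real operation depend on align_corners), and applyPadding is a hypothetical helper, not part of the library:

```typescript
// Resolve an out-of-range normalized coordinate under each padding_mode.
// Returns null for 'zeros' to signal that the caller should substitute 0.
function applyPadding(coord: number, mode: "zeros" | "border" | "reflection"): number | null {
  if (coord >= -1 && coord <= 1) return coord; // in range: nothing to do
  switch (mode) {
    case "zeros":
      return null; // sample value becomes 0
    case "border":
      return Math.max(-1, Math.min(1, coord)); // clamp to the nearest edge
    case "reflection": {
      // Mirror about the boundaries until the coordinate lands inside [-1, 1]
      let c = coord;
      while (c < -1 || c > 1) {
        if (c > 1) c = 2 - c;
        if (c < -1) c = -2 - c;
      }
      return c;
    }
  }
  return null;
}

console.log(applyPadding(1.5, "border"));     // 1
console.log(applyPadding(1.5, "reflection")); // 0.5 (mirrored back inside)
```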
Key Points:
- Normalized coordinates crucial: the grid must use the (-1, 1) range; coordinates outside it are treated as out-of-bounds and handled by padding_mode
- Bilinear standard choice: Default mode for most applications (smooth, differentiable)
- Spatial Transformer Networks: Primary use case for learning geometric invariance
- Differentiable pipeline: Entire transformation learnable end-to-end with network
- Composition possible: Multiple grid_sample calls can be composed for complex transformations
- Batch processing: Supports different transformation per sample in batch
- Grid interpolation: Grid coordinates themselves can be interpolated (continuous transformations)
- Out-of-bounds handling: Different padding modes produce different results; choose carefully
- Coordinate system: (-1, -1) is top-left, (1, 1) is bottom-right (y increases downward, the image convention, opposite of standard math axes)
- Nearest mode gradients: nearest-neighbor sampling yields zero gradient with respect to the grid coordinates, so it cannot drive learning of the transformation
- Grid shape mismatch: grid and input batch sizes must match; mismatches raise an error
- Computational cost: proportional to the number of output locations, since each output value is sampled independently
- align_corners consistency: align_corners changes the coordinate-to-pixel mapping; use the same setting in affine_grid and grid_sample
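The align_corners difference is easiest to see in the unnormalization formulas themselves. This plain TypeScript sketch mirrors the standard mapping from a normalized coordinate in [-1, 1] to a pixel position (an illustration, not the library API):

```typescript
// Map a normalized coordinate in [-1, 1] to a pixel position along an axis
// of the given size, for both align_corners settings.
function unnormalize(coord: number, size: number, alignCorners: boolean): number {
  return alignCorners
    ? (coord + 1) / 2 * (size - 1)   // -1 and 1 hit the CENTERS of the edge pixels
    : ((coord + 1) * size - 1) / 2;  // -1 and 1 hit the OUTER EDGES of the edge pixels
}

// For a width-4 input, coordinate 1 lands on a different pixel position:
console.log(unnormalize(1, 4, true));  // 3    (center of the last pixel)
console.log(unnormalize(1, 4, false)); // 3.5  (half a pixel past the last center)
```

Mixing the two conventions between grid construction and sampling shifts every sample by up to half a pixel, which is why the settings must agree.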
Parameters
input: Tensor - Feature map to sample from, shape [N, C, H_in, W_in] (2D) or [N, C, D_in, H_in, W_in] (3D). N: batch size; C: number of channels; H_in, W_in (D_in): spatial dimensions of the input
grid: Tensor - Sampling coordinates, shape [N, H_out, W_out, 2] (2D) or [N, D_out, H_out, W_out, 3] (3D). Each coordinate is in [-1, 1] normalized space; the last dimension holds x, y (and z for 3D) coordinates
options: GridSampleFunctionalOptions (optional)
Returns
Tensor - Sampled feature map of shape [N, C, H_out, W_out] (2D) or [N, C, D_out, H_out, W_out] (3D). Same channels and batch size as the input, with spatial size taken from the grid.
Examples
// Basic spatial transformer: identity sampling grid
const image = torch.randn(8, 3, 32, 32); // [batch=8, channels=3, height=32, width=32]
// Create an identity grid (no transformation) for a single sample
const grid = torch.zeros(32, 32, 2); // [H, W, 2]
for (let h = 0; h < 32; h++) {
  for (let w = 0; w < 32; w++) {
    grid[h, w, 0] = -1 + 2 * w / 31; // x from -1 to 1 (align_corners=true convention)
    grid[h, w, 1] = -1 + 2 * h / 31; // y from -1 to 1
  }
}
// Broadcast to the batch: [32, 32, 2] -> [1, 32, 32, 2] -> [8, 32, 32, 2]
const grid_batch = grid.unsqueeze(0).expand([8, 32, 32, 2]);
const sampled = torch.nn.functional.grid_sample(image, grid_batch, 'bilinear', 'zeros', true);
// Output: [8, 3, 32, 32] - same as input (identity transformation)

// Spatial Transformer Network: learn a geometric transformation
class SpatialTransformer extends torch.nn.Module {
  private localization: torch.nn.Sequential;

  forward(x: Tensor): Tensor {
    // 1. Localization network: predict affine transformation parameters
    let theta = this.localization.forward(x); // [batch, 6] for a 2D affine transform
    theta = theta.reshape([-1, 2, 3]); // [batch, 2, 3]
    // 2. Grid generation: create a sampling grid from the transformation
    const grid = torch.nn.functional.affine_grid(
      theta,
      x.shape,
      false // align_corners
    ); // [batch, H, W, 2]
    // 3. Grid sampling: apply the transformation to the input
    return torch.nn.functional.grid_sample(
      x, grid, 'bilinear', 'zeros', false
    ); // [batch, channels, H, W]
  }
}
// The STN learns to apply geometric transformations end-to-end

// Image warping: sample with custom coordinates
const source_image = torch.randn(1, 3, 256, 256); // Source image
// Create grid that samples only the center region (zoom in)
const grid_zoom = torch.zeros(1, 256, 256, 2);
for (let h = 0; h < 256; h++) {
  for (let w = 0; w < 256; w++) {
    const x = -0.5 + 1.0 * w / 255; // map output to [-0.5, 0.5]: the center half of the input
    const y = -0.5 + 1.0 * h / 255; // sampling half the extent = 2x zoom
    grid_zoom[0, h, w, 0] = x;
    grid_zoom[0, h, w, 1] = y;
  }
}
const zoomed = torch.nn.functional.grid_sample(
source_image, grid_zoom, 'bilinear', 'border', false
);
// Output: zoomed view of the center region

// Different padding modes: handling out-of-bounds
const input = torch.randn(1, 3, 32, 32);
// Create grid with some out-of-bounds coordinates
const grid = torch.randn(1, 40, 40, 2); // Some values outside [-1, 1]
// Padding mode: zeros - fill black for out-of-bounds
const padded_zeros = torch.nn.functional.grid_sample(
input, grid, 'bilinear', 'zeros', false
); // [1, 3, 40, 40]
// Padding mode: border - repeat edge pixels
const padded_border = torch.nn.functional.grid_sample(
input, grid, 'bilinear', 'border', false
); // Seamless at boundaries
// Padding mode: reflection - mirror at boundaries
const padded_reflect = torch.nn.functional.grid_sample(
input, grid, 'bilinear', 'reflection', false
); // Natural-looking continuation at the boundaries

// Differentiable geometric transformation learning
const optimizer = new torch.optim.SGD(localization_net.parameters(), { lr: 0.01 });
const input = torch.randn(32, 3, 64, 64);
const target = torch.randn(32, 3, 64, 64); // Target after transformation
// Forward: apply learned spatial transformation
const theta = localization_net(input); // Predict transformation
const grid = torch.nn.functional.affine_grid(theta, input.shape, false);
const transformed = torch.nn.functional.grid_sample(
input, grid, 'bilinear', 'zeros', false
);
// Loss and backprop
const loss = torch.nn.functional.mse_loss(transformed, target);
optimizer.zero_grad();
loss.backward();
optimizer.step();
// The network learns transformations that match the target

See Also
- PyTorch torch.nn.functional.grid_sample
- torch.nn.functional.affine_grid - Generate grid from affine transformation matrix
- torch.nn.functional.pad - Padding (different operation, handles tensor edges)
- torch.nn.functional.interpolate - Direct resampling without grid
- torch.nn.modules.spatial_transformer.SpatialTransformer - Full STN module