torch.nn.functional.affine_grid
function affine_grid(theta: Tensor, size: readonly number[], options?: AffineGridFunctionalOptions): Tensor
function affine_grid(theta: Tensor, size: readonly number[], align_corners: boolean, options?: AffineGridFunctionalOptions): Tensor
Affine Grid Generation: converts affine transformation matrices to coordinate grids for grid_sample.
Generates a sampling grid from a batch of affine transformation matrices. For each matrix, it creates a grid of output pixel coordinates that maps back to input image coordinates. It is the standard companion to grid_sample in Spatial Transformer Networks (STNs), where the pair applies learned geometric transformations. Essential for:
- Spatial Transformer Networks (STNs) - learning geometric transformations
- Converting affine/perspective transformations to sampling grids
- Creating learnable geometric augmentation pipelines
- Image registration and alignment
- Differentiable geometric transformations in neural networks
- Data augmentation with learned transformations
- Pose estimation and object alignment
- Image correction and distortion removal
Workflow with grid_sample:
1. Localization network: predict affine parameters (6 values for 2D: rotation, translation, scale, shear)
2. Reshape to matrix: convert the 6-element vector to a 2×3 affine matrix
3. affine_grid: generate the output sampling grid from the transformation matrix
4. grid_sample: sample the input feature map using the grid
Affine Transformation Representation:
- 2D: 2×3 matrix [a, b, c; d, e, f], applied as [a b; d e] @ [x; y] + [c; f]
- 3D: 3×4 matrix [a, b, c, d; e, f, g, h; i, j, k, l]
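The 2D matrix form can be made concrete with a few lines of plain TypeScript (illustrative only, no library calls; applyAffine2d is a hypothetical helper, not part of the API):

```typescript
// Apply a 2D affine matrix [[a, b, c], [d, e, f]] to a point (x, y):
// result = (a*x + b*y + c, d*x + e*y + f)
function applyAffine2d(theta: number[][], x: number, y: number): [number, number] {
  const [[a, b, c], [d, e, f]] = theta;
  return [a * x + b * y + c, d * x + e * y + f];
}

const identity = [[1, 0, 0], [0, 1, 0]];
applyAffine2d(identity, 0.5, -0.25); // → [0.5, -0.25] (unchanged)

const shift = [[1, 0, 0.1], [0, 1, 0.2]];
applyAffine2d(shift, 0, 0); // → [0.1, 0.2]
```

The same arithmetic is what affine_grid evaluates at every normalized output coordinate to build the grid.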
Output Grid Coordinates: The grid maps output positions to input positions in normalized [-1, 1] space. Each position in the output grid contains x, y (and z for 3D) normalized coordinates that tell grid_sample where to sample from the input feature map.
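The normalized coordinate convention along one axis can be sketched in plain TypeScript (illustrative, library-free; the formulas below follow PyTorch's convention for each align_corners setting):

```typescript
// Normalized sample positions along one axis of length n.
// align_corners = true:  corners map to exactly -1 and 1 → linspace(-1, 1, n)
// align_corners = false: pixel *centers* are normalized → (2*i + 1) / n - 1
function normalizedAxis(n: number, alignCorners: boolean): number[] {
  const coords: number[] = [];
  for (let i = 0; i < n; i++) {
    coords.push(
      alignCorners
        ? (n === 1 ? 0 : -1 + (2 * i) / (n - 1))
        : (2 * i + 1) / n - 1
    );
  }
  return coords;
}

normalizedAxis(4, false); // → [-0.75, -0.25, 0.25, 0.75]
normalizedAxis(4, true);  // endpoints land exactly on -1 and 1
```

An identity theta leaves these coordinates unchanged, which is why it reproduces the input when passed through grid_sample.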
Align Corners Parameter:
- false (default): pixel centers are aligned (standard in modern PyTorch)
- true: corner-to-corner alignment (older behavior)
Key Notes:
- Inverse transformation: theta maps output coordinates to input coordinates (it tells grid_sample where to sample from), i.e. it is the inverse of the forward image transform
- Identity initialization: theta = [[1, 0, 0], [0, 1, 0]] applies no transformation
- Paired with grid_sample: always used together; affine_grid generates the grid that grid_sample consumes
- Coordinate normalization: the grid is in [-1, 1] space, not pixel coordinates
- Learnable transformation: a localization network learns the theta values during training
- Batch support: a different transformation per sample in the batch
- Differentiable: grid generation is fully differentiable for backprop
- align_corners consistency: must pass the same align_corners to grid_sample as to affine_grid
- 2D vs 3D shapes: theta must be [N, 2, 3] for 2D or [N, 3, 4] for 3D
- Size parameter: the batch size must match theta's N; the channel count is ignored, and only the spatial dimensions determine the grid
- Numerical issues: large transformations can push grid coordinates beyond [-1, 1] (extrapolation; grid_sample's padding_mode decides what those positions return)
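The extrapolation point above can be checked by hand: a scale-2 matrix maps the normalized output corner (1, 1) to (2, 2), outside [-1, 1]. A plain-array sketch (illustrative, not library code):

```typescript
// A zoom-out matrix (scale 2) applied to the normalized output corner (1, 1).
const zoomOut = [[2, 0, 0], [0, 2, 0]];
const corner = [
  zoomOut[0][0] * 1 + zoomOut[0][1] * 1 + zoomOut[0][2], // x' = 2
  zoomOut[1][0] * 1 + zoomOut[1][1] * 1 + zoomOut[1][2], // y' = 2
];
// (2, 2) lies outside [-1, 1], so grid_sample's padding_mode ('zeros',
// 'border', or 'reflection') governs what those out-of-range positions return.
const outOfRange = corner.some((c) => c < -1 || c > 1); // true
```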
Parameters
theta (Tensor) - Affine transformation matrices of shape [N, 2, 3] (2D) or [N, 3, 4] (3D). 2D: theta[i] = [[a, b, c], [d, e, f]] maps (x, y) to (a*x + b*y + c, d*x + e*y + f). 3D: theta[i] = [[a, b, c, d], [e, f, g, h], [i, j, k, l]]. Typically learned by a localization network.
size (readonly number[]) - Output size [N, C, H, W] (2D) or [N, C, D, H, W] (3D). Only the spatial dimensions (index 2 onward) determine the grid size; N must match theta's batch size and the channel count is ignored.
options (AffineGridFunctionalOptions, optional)
Returns
Tensor – Sampling grid of shape [N, H, W, 2] (2D) or [N, D, H, W, 3] (3D). Coordinates are in [-1, 1] normalized space for use with grid_sample.
Examples
// Simple 2D affine: identity transformation (no change)
const N = 8; // batch size
const theta = torch.zeros(N, 2, 3);
for (let i = 0; i < N; i++) {
  theta[i, 0, 0] = 1; // x scaling = 1
  theta[i, 1, 1] = 1; // y scaling = 1
}
// Off-diagonal and last column remain 0 (no rotation/translation)
// Each theta[i] = [[1, 0, 0], [0, 1, 0]] - identity matrix
const size = [N, 3, 32, 32]; // [batch, channels, height, width]
const grid = torch.nn.functional.affine_grid(theta, size, false);
// Output shape: [8, 32, 32, 2] - grid for bilinear sampling

// 2D affine: scaling by 0.5 (zoom in)
const theta = torch.zeros(1, 2, 3);
theta[0, 0, 0] = 0.5; // grid x range shrinks to [-0.5, 0.5]
theta[0, 1, 1] = 0.5; // grid y range shrinks to [-0.5, 0.5]
// theta = [[0.5, 0, 0], [0, 0.5, 0]]
const size = [1, 3, 64, 64];
const grid = torch.nn.functional.affine_grid(theta, size, false);
// Grid samples only the central half of the input, stretched to fill the
// output (zoomed-in image); use scale 2 instead to zoom out

// Spatial Transformer Network: learn affine transformation
class STN extends torch.nn.Module {
  private fc_loc: torch.nn.Linear; // Final layer of the localization network
  constructor() {
    super();
    // Localization head: 128-dim pooled features → 6 affine parameters (2D)
    // (in a full STN, a small conv net produces these features from the input)
    this.fc_loc = new torch.nn.Linear(128, 6);
    // Initialize to the identity transformation so training starts stable
    this.fc_loc.weight.data.fill(0);
    this.fc_loc.bias.data = torch.tensor([1, 0, 0, 0, 1, 0]).to('float32');
  }
  forward(x: Tensor, loc_features: Tensor): Tensor {
    // x: input feature map [N, C, H, W]; loc_features: [N, 128]
    // 1. Localization: predict 6 affine parameters
    const affine_params = this.fc_loc.forward(loc_features); // [N, 6]
    // 2. Reshape to 2×3 affine matrices
    const theta = affine_params.reshape([-1, 2, 3]); // [N, 2, 3]
    // 3. Generate the sampling grid for the output size
    const grid = torch.nn.functional.affine_grid(theta, x.shape, false);
    // 4. Sample the input with the grid (same align_corners as affine_grid)
    const sampled = torch.nn.functional.grid_sample(
      x, grid, 'bilinear', 'zeros', false
    );
    return sampled;
  }
}
// STN learns transformation parameters end-to-end

// 2D affine: rotation by angle θ
const angle = Math.PI / 4; // 45 degrees
const cos_a = Math.cos(angle);
const sin_a = Math.sin(angle);
const theta = torch.tensor([[
[cos_a, -sin_a, 0],
[sin_a, cos_a, 0]
]]).to('float32'); // [1, 2, 3] - rotation matrix
const size = [1, 3, 64, 64];
const grid = torch.nn.functional.affine_grid(theta, size, false);
// Can now use with grid_sample to rotate images
const rotated = torch.nn.functional.grid_sample(image, grid, 'bilinear', 'zeros', false);

// 2D affine: translation (shift)
const tx = 0.1; // shift of 0.1 in normalized coords ≈ 5% of width
const ty = 0.2; // shift of 0.2 in normalized coords ≈ 10% of height
const theta = torch.tensor([[
[1, 0, tx], // x: scale 1, translate tx
[0, 1, ty] // y: scale 1, translate ty
]]).to('float32');
const size = [1, 3, 64, 64];
const grid = torch.nn.functional.affine_grid(theta, size, false);
// Grid shifts sampling positions (translates the image)

// 3D affine: for volumetric/video data
const theta_3d = torch.zeros(2, 3, 4); // [batch=2, 3, 4]
// Identity 3D transformation for every sample in the batch
for (let i = 0; i < 2; i++) {
  theta_3d[i, 0, 0] = 1; // x scale
  theta_3d[i, 1, 1] = 1; // y scale
  theta_3d[i, 2, 2] = 1; // z scale
}
const size = [2, 3, 32, 64, 64]; // [batch, channels, depth, height, width]
const grid_3d = torch.nn.functional.affine_grid(theta_3d, size, false);
// Output shape: [2, 32, 64, 64, 3] - 3D sampling grid
See Also
- PyTorch torch.nn.functional.affine_grid
- torch.nn.functional.grid_sample - Apply grid to sample input
- torch.nn.modules.spatial_transformer.SpatialTransformer - Complete STN module
- torch.nn.functional.pad - Alternative for geometric ops (padding)
- torch.nn.functional.interpolate - Direct resampling without transformations