torch.nn.functional.max_pool2d_with_indices
function max_pool2d_with_indices(input: Tensor, kernel_size: number | [number, number], options?: MaxPool2dFunctionalOptions): PoolWithIndicesResult
Applies 2D max pooling over an input signal and returns both the pooled values and their indices.
This is the functional version of nn.MaxPool2d that also returns the indices of the maximum values. The indices are useful for:
- Implementing max unpooling in encoder-decoder architectures
- Visualizing which positions contributed to pooling
- Skip connections that preserve spatial information
Notes
- The indices are flattened indices into each pooling region, suitable for passing to max_unpool2d.
- For 3D input (C, H, W), a batch dimension is added before pooling and removed from the results automatically.
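To make the index semantics concrete, here is a minimal plain-TypeScript sketch of max pooling with indices over a single channel, together with the matching unpooling step used in encoder-decoder models. It is not this library's implementation: it assumes the PyTorch convention that each index is the row-major flat position `row * W + col` in the input's H×W plane, and for brevity it supports only a square kernel with a stride (no padding, dilation, or ceil_mode). `Channel`, `maxPoolWithIndices`, and `maxUnpool` are hypothetical names.

```typescript
// Hypothetical types: one channel stored row-major as a Float64Array of size h * w.
interface Channel {
  data: Float64Array;
  h: number;
  w: number;
}

interface PoolResult {
  values: Channel;
  // Assumed layout: flat input position (row * w + col) of each maximum.
  indices: Int32Array;
}

// Max pooling with a k x k window and the given stride; no padding or dilation.
function maxPoolWithIndices(input: Channel, k: number, stride: number = k): PoolResult {
  const hOut = Math.floor((input.h - k) / stride) + 1;
  const wOut = Math.floor((input.w - k) / stride) + 1;
  const values = new Float64Array(hOut * wOut);
  const indices = new Int32Array(hOut * wOut);
  for (let oh = 0; oh < hOut; oh++) {
    for (let ow = 0; ow < wOut; ow++) {
      let best = -Infinity;
      let bestIdx = -1;
      for (let i = 0; i < k; i++) {
        for (let j = 0; j < k; j++) {
          const r = oh * stride + i;
          const c = ow * stride + j;
          const v = input.data[r * input.w + c];
          // Strict ">" keeps the first maximum met in row-major window order.
          if (bestIdx === -1 || v > best) {
            best = v;
            bestIdx = r * input.w + c;
          }
        }
      }
      values[oh * wOut + ow] = best;
      indices[oh * wOut + ow] = bestIdx;
    }
  }
  return { values: { data: values, h: hOut, w: wOut }, indices };
}

// Inverse step: scatter each pooled value back to its recorded input position.
// Every position that did not hold a window maximum becomes zero.
function maxUnpool(pooled: PoolResult, h: number, w: number): Channel {
  const data = new Float64Array(h * w);
  pooled.indices.forEach((idx, n) => {
    data[idx] = pooled.values.data[n];
  });
  return { data, h, w };
}

// Usage: a 4x4 input pooled with a 2x2 window (stride 2).
const x: Channel = {
  data: Float64Array.from([
    1, 5, 2, 0,
    3, 4, 8, 1,
    0, 2, 6, 7,
    9, 1, 3, 5,
  ]),
  h: 4,
  w: 4,
};
const pooled = maxPoolWithIndices(x, 2);
console.log(Array.from(pooled.values.data)); // values: [5, 8, 9, 7]
console.log(Array.from(pooled.indices)); // indices: [1, 6, 12, 11]
const restored = maxUnpool(pooled, 4, 4);
console.log(Array.from(restored.data)); // maxima back in place, zeros elsewhere
```

Recording the winning position during the forward pass is what lets the decoder place each value exactly where it came from, rather than spreading it over the whole window.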
Parameters
input: Tensor - Input tensor of shape (N, C, H, W) or (C, H, W).
kernel_size: number | [number, number] - Size of the pooling window: a single number for a square kernel, or [height, width] for a rectangular one.
options: MaxPool2dFunctionalOptions (optional) - Pooling settings:
  - stride: Stride of the pooling window (default: kernel_size)
  - padding: Implicit padding added to both sides of each spatial dimension (default: 0)
  - dilation: Spacing between kernel elements (default: 1)
  - ceil_mode: If true, use ceil instead of floor when computing the output shape (default: false)
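In PyTorch, these options determine each output dimension as H_out = floor((H + 2·padding − dilation·(kernel − 1) − 1) / stride) + 1, with ceil replacing floor when ceil_mode is true; in ceil mode a final window that would start in the right padding is dropped. The sketch below assumes this library follows the same rule. `poolOutputSize` is a hypothetical helper, not part of the library.

```typescript
// Hypothetical helper: output length of one spatial dimension after pooling,
// following PyTorch's shape rule (assumed to match this library).
function poolOutputSize(
  size: number,
  kernel: number,
  stride: number = kernel,
  padding = 0,
  dilation = 1,
  ceilMode = false,
): number {
  // Number of positions the window's start can advance through.
  const span = size + 2 * padding - dilation * (kernel - 1) - 1;
  let out = (ceilMode ? Math.ceil(span / stride) : Math.floor(span / stride)) + 1;
  // In ceil mode, drop a last window that would start inside the right padding.
  if (ceilMode && (out - 1) * stride >= size + padding) {
    out -= 1;
  }
  return out;
}

console.log(poolOutputSize(28, 2)); // 14
console.log(poolOutputSize(28, 3, 2, 1)); // 14
console.log(poolOutputSize(5, 2, 2, 0, 1, false)); // 2
console.log(poolOutputSize(5, 2, 2, 0, 1, true)); // 3
console.log(poolOutputSize(5, 3, 3, 1, 1, true)); // 2: a third window would start in the right padding
```

Applying the helper to H and W separately gives the (H_out, W_out) of the returned values and indices.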
Returns
PoolWithIndicesResult - Object with:
- values: Pooled output tensor of shape (N, C, H_out, W_out)
- indices: Indices of the max values, same shape as values, dtype int32
Examples
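To read an index back as a spatial position (for example, to visualize which input pixels won each window), it can be split into a row and a column. This sketch assumes the PyTorch convention, where each index is the row-major flat position `row * W + col` in a channel's H×W plane; this library's layout may differ. `unravelIndex` is a hypothetical helper, and the index values are illustrative.

```typescript
// Hypothetical helper: convert a row-major flat index into [row, col]
// for a plane of width w (assumed index layout).
function unravelIndex(flat: number, w: number): [number, number] {
  return [Math.floor(flat / w), flat % w];
}

// Illustrative indices for one channel of a 4x4 input pooled with a 2x2 window.
const channelIndices = Int32Array.from([1, 6, 12, 11]);
const positions = Array.from(channelIndices, (i) => unravelIndex(i, 4));
console.log(positions); // [[0, 1], [1, 2], [3, 0], [2, 3]]
```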
// Basic max pooling with indices
const input = torch.randn(1, 3, 28, 28); // One 3-channel 28x28 image (batch size 1)
const { values, indices } = torch.nn.functional.max_pool2d_with_indices(input, 2);
// values: shape [1, 3, 14, 14] - pooled features
// indices: shape [1, 3, 14, 14] - positions of max values
// Custom stride and padding (a new name avoids redeclaring values/indices)
const strided = torch.nn.functional.max_pool2d_with_indices(
  input, [3, 3], { stride: 2, padding: 1 }
);
// strided.values and strided.indices: shape [1, 3, 14, 14]
// Use indices for max unpooling (encoder-decoder)
const features = torch.randn(1, 16, 32, 32); // e.g. encoder feature maps
const pooled = torch.nn.functional.max_pool2d_with_indices(features, 2);
// ... later in decoder ...
const unpooled = torch.nn.functional.max_unpool2d(pooled.values, pooled.indices, 2);
See Also
- [PyTorch torch.nn.functional.max_pool2d (with return_indices=True)](https://pytorch.org/docs/stable/generated/torch.nn.functional.max_pool2d.html)
- max_pool2d - Max pooling without indices
- max_unpool2d - Inverse operation using indices
- avg_pool2d - Average pooling alternative