torch.Tensor.masked_select
Tensor.masked_select(mask: Tensor): Tensor<DynamicShape, D, Dev>

Selects elements where a boolean mask is true, returning a 1D flattened result.
Extracts all elements where the corresponding mask value is true and packs them into a 1D tensor. Useful for conditional element selection, filtering, and masking operations in neural networks. Works element-wise across the entire tensor.
Important limitation: Due to WebGPU's lack of dynamic sizing, this returns a tensor
with shape [inputSize] where inputSize is the total number of elements. The actual
selected elements are packed at the beginning. Use masked_select_async() to get the
exact shape (requires dynamic sizing).
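The fixed-size packing behavior described above can be illustrated with a plain-TypeScript CPU sketch. `maskedSelectFixed` is a hypothetical helper written for this illustration, not part of the library's API:

```typescript
// CPU sketch of the fixed-size output semantics: the result always has
// input.length elements, with selected values packed at the front.
function maskedSelectFixed(input: number[], mask: boolean[]): number[] {
  const out = new Array<number>(input.length).fill(0); // shape [inputSize]
  let n = 0;
  for (let i = 0; i < input.length; i++) {
    if (mask[i]) out[n++] = input[i]; // valid values packed at the beginning
  }
  // out[n..] is unspecified in the real GPU kernel; zero-filled here
  return out;
}

const packed = maskedSelectFixed(
  [1, 2, 3, 4, 5, 6],
  [false, false, false, true, true, true]
);
console.log(packed); // [4, 5, 6, 0, 0, 0] — only the first 3 entries are valid
```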
Common use cases:
- Filter elements based on conditions (e.g., values > threshold)
- Extract non-zero elements
- Conditional loss computation (mask out certain samples)
- Attention weights masking
- Pruning and sparsification
Limitations and notes:
- Exact shape unavailable: For GPU tensors, the true count can't be known synchronously. The result shape is always [inputSize], but only the first N elements are valid.
- 1D output: Always flattens selected elements to 1D regardless of input shape.
- Mask shape: Mask must be broadcastable with input (usually same shape).
- Async alternative: Use masked_select_async() to get the exact shape (slower).
- Gradient flow: Fully differentiable for backpropagation.
- WebGPU limitation: Must synchronously return fixed size. Actual count unknown.
- Garbage elements: Use only first N elements from result; rest are undefined.
- Memory overhead: For sparsely selected tensors, significant memory is wasted on the invalid tail.
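Since only the first N outputs are valid, N can be recovered on the host as the number of true entries in the mask. A minimal plain-TypeScript sketch (`trimPacked` is a hypothetical helper, not part of this library's API):

```typescript
// The number of true entries in the flattened mask equals the number of
// valid elements at the front of masked_select's fixed-size output.
function trimPacked(packed: number[], mask: boolean[]): number[] {
  const n = mask.reduce((count, m) => count + (m ? 1 : 0), 0);
  return packed.slice(0, n); // drop the undefined tail
}

const mask = [false, true, true, false];
console.log(trimPacked([2, 3, 0, 0], mask)); // → [2, 3]
```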
Parameters
mask: Tensor - Boolean tensor, same shape as input (or broadcastable). true = include, false = exclude.
Returns
Tensor<DynamicShape, D, Dev> – 1D tensor with selected elements packed at the beginning (may include garbage after the true count)
Examples
// Filter values greater than threshold
const x = torch.tensor([[1, 2, 3], [4, 5, 6]]);
const mask = x.gt(3); // [[false, false, false], [true, true, true]]
const selected = x.masked_select(mask); // [4, 5, 6, ?, ?, ?] (last 3 are garbage)
// Extract non-zero elements
const nonzero = x.masked_select(x.ne(0));
// Conditional loss masking - zero out padding tokens
const logits = torch.randn(batch_size, seq_len, vocab_size);
const padding_mask = torch.ones(batch_size, seq_len); // 1=real token, 0=padding
const valid_logits = logits.masked_select(padding_mask.unsqueeze(-1).gt(0)); // boolean mask, broadcast over vocab_size
// Attention mask pattern: mask out future positions
const attention = torch.randn(batch, seq_len, seq_len);
const future_mask = torch.tril(torch.ones(seq_len, seq_len)); // Lower triangular
const masked_attn = attention.masked_select(future_mask.unsqueeze(0).gt(0)); // boolean mask, broadcast over batch
// Efficient outlier removal (keep values in 5th to 95th percentile)
const data = torch.randn(1000);
const q05 = torch.kthvalue(data, 50).values; // 5th percentile
const q95 = torch.kthvalue(data, 950).values; // 95th percentile
const inliers = data.masked_select(data.ge(q05).logical_and(data.le(q95)));
See Also
- PyTorch torch.masked_select()
- masked_select_async - Async variant that returns exact shape
- where - Conditional selection returning full tensor size
- nonzero - Returns indices of non-zero elements
- gather - Index-based element selection
- scatter - Index-based element placement
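The contrast with where noted above can be sketched in plain TypeScript (hypothetical helpers, not this library's API): where preserves the input's full shape by substituting a fallback value, while masked_select compacts the selected elements:

```typescript
// where-style: same length as input, unselected positions replaced by `fill`.
function whereLike(input: number[], mask: boolean[], fill: number): number[] {
  return input.map((v, i) => (mask[i] ? v : fill));
}

// masked_select-style (exact-shape variant): compacted, length = true count.
function maskedSelectLike(input: number[], mask: boolean[]): number[] {
  return input.filter((_, i) => mask[i]);
}

const x = [1, 2, 3, 4];
const m = [true, false, true, false];
console.log(whereLike(x, m, 0));     // [1, 0, 3, 0]
console.log(maskedSelectLike(x, m)); // [1, 3]
```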