torch.nn.Bilinear
new Bilinear(in1_features: number, in2_features: number, out_features: number, options?: BilinearOptions)
weight(Parameter)
bias(Parameter | null) - readonly
in1_features(number) - readonly
in2_features(number) - readonly
out_features(number)
Bilinear transformation: applies a learned bilinear form to pairs of inputs.
Computes output as: y = x1^T W x2 + b for each of out_features separate weight matrices W. Captures interactions between two input vectors via learnable weights. Essential for:
- Matching/ranking pairs of inputs (e.g., image-text matching, question-answering)
- Relation extraction (finding relationships between entity pairs)
- Similarity learning between two different input spaces
- Attention mechanisms (comparing query-key-value interactions)
- Cross-modal learning (matching between different modalities)
Unlike a Linear layer applied to the concatenation of x1 and x2, which combines the inputs only additively, Bilinear captures their multiplicative interaction. For each output dimension, it learns a separate weight matrix for the bilinear form.
When to use Bilinear:
- Pair-wise comparisons (e.g., does image match caption?)
- Relation extraction from entity pairs
- Similarity scoring between two vectors
- Learning fine-grained interactions between input spaces
- Cross-modality alignment tasks
Trade-offs:
- vs Linear: Bilinear captures interactions; Linear treats inputs independently
- Parameters: O(in1 * in2 * out) vs O((in1 + in2) * out) for Linear
- Expressiveness: Much more expressive for pair-wise functions
- Computational cost: More expensive than Linear (matrix multiplications per output)
- Use case: Only when you need to model interactions between input pairs
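To make the parameter-count trade-off concrete, the formulas above can be evaluated directly. The sketch below uses plain TypeScript arithmetic (no library dependency) with 256-dimensional inputs and a single output, the same sizes as the matcher example further down:

```typescript
// Parameter counts for in1 = 256, in2 = 256, out = 1
const in1 = 256, in2 = 256, out = 1;

// Bilinear: one [in1, in2] weight matrix per output, plus one bias per output
const bilinearParams = (in1 * in2 + 1) * out; // 65537

// Linear on the concatenated inputs: one weight row of length in1 + in2
// per output, plus one bias per output
const linearParams = (in1 + in2 + 1) * out; // 513

console.log(bilinearParams / linearParams); // Bilinear needs ~128x more parameters here
```

The gap grows multiplicatively with the input dimensions, which is why the low-rank approximation mentioned below becomes relevant for large inputs.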
Computation: For each output dimension k:
- output[k] = x1^T @ W[k] @ x2 + b[k], where W[k] is an [in1_features, in2_features] weight matrix for dimension k.
With batch: output[b, k] = x1[b]^T @ W[k] @ x2[b] + b[k]
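A minimal worked example of the per-output computation, using plain arrays rather than the tensor API so the arithmetic is visible:

```typescript
// One output dimension with in1_features = 2, in2_features = 3
const x1 = [1, 2];      // [in1]
const x2 = [1, 0, -1];  // [in2]
const W = [             // [in1, in2] weight matrix for output dim 0
  [1, 0, 2],
  [0, 1, 0],
];
const b = 0.5;

// y = x1^T @ W @ x2 + b = sum over i, j of x1[i] * W[i][j] * x2[j], plus b
let y = b;
for (let i = 0; i < x1.length; i++) {
  for (let j = 0; j < x2.length; j++) {
    y += x1[i] * W[i][j] * x2[j];
  }
}
// x1^T @ W = [1, 2, 2]; dotted with x2 = [1, 0, -1] gives -1; plus b gives -0.5
console.log(y); // -0.5
```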
- Bilinear form: Captures pairwise interactions between input spaces
- Weight shape: [out, in1, in2] - separate bilinear form for each output
- Computation: More expensive than Linear due to pairwise products
- Parameter count: (in1 * in2 + 1) * out parameters (much larger than Linear)
- Interaction modeling: Essential when you need to model relationships between two inputs
- Symmetry: Not symmetric in general (bilinear(x, y) != bilinear(y, x))
- Low-rank approximation: If in1 and in2 are large, consider decomposing W into U @ V^T
- High parameter count: Can be very expensive for high-dimensional inputs
- Computational cost: Slower than Linear for equivalent input/output dims
- Overfitting risk: Many parameters - may need regularization or dropout
- Memory intensive: Weight matrix is [out, in1, in2] which can be large
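The low-rank idea above can be sketched by hand. Replacing each [in1, in2] matrix W with U @ V^T, where U is [in1, r] and V is [in2, r], turns the bilinear form into a dot product of two r-dimensional projections: x1^T @ W @ x2 = (U^T @ x1) . (V^T @ x2), costing (in1 + in2) * r parameters per output instead of in1 * in2. This is a hand-rolled illustration, not a library API:

```typescript
// result[k] = sum over i of M[i][k] * v[i], i.e. M^T @ v
function matVecT(M: number[][], v: number[]): number[] {
  const r = M[0].length;
  const out = new Array(r).fill(0);
  for (let i = 0; i < M.length; i++) {
    for (let k = 0; k < r; k++) out[k] += M[i][k] * v[i];
  }
  return out;
}

// Low-rank bilinear form: (U^T @ x1) . (V^T @ x2)
function lowRankBilinear(x1: number[], x2: number[], U: number[][], V: number[][]): number {
  const a = matVecT(U, x1); // [r]
  const c = matVecT(V, x2); // [r]
  return a.reduce((sum, ai, k) => sum + ai * c[k], 0);
}

// Rank-1 example: U = [[1], [2]], V = [[3], [0], [1]], so W = U @ V^T = [[3, 0, 1], [6, 0, 2]]
const score = lowRankBilinear([1, 1], [1, 0, 1], [[1], [2]], [[3], [0], [1]]);
console.log(score); // (U^T x1 = [3]) . (V^T x2 = [4]) = 12
```

Choosing the rank r trades expressiveness against parameter count; r = min(in1, in2) recovers a full bilinear form.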
Examples
// Basic bilinear transformation
const bilinear = new torch.nn.Bilinear(10, 20, 5);
const x1 = torch.randn([32, 10]); // batch of 32, first input with 10 features
const x2 = torch.randn([32, 20]); // batch of 32, second input with 20 features
const output = bilinear.forward(x1, x2); // [32, 5]
// Each of the 5 outputs is computed via x1^T @ W @ x2 + b

// Image-text matching with bilinear compatibility
class ImageTextMatcher extends torch.nn.Module {
  image_encoder: torch.nn.Linear;
  text_encoder: torch.nn.Linear;
  matcher: torch.nn.Bilinear;

  constructor() {
    super();
    this.image_encoder = new torch.nn.Linear(2048, 256); // ResNet features -> 256D
    this.text_encoder = new torch.nn.Linear(768, 256); // BERT features -> 256D
    this.matcher = new torch.nn.Bilinear(256, 256, 1); // Output: single compatibility score
  }

  forward(image_features: torch.Tensor, text_features: torch.Tensor): torch.Tensor {
    const img_embedding = this.image_encoder.forward(image_features); // [B, 256]
    const txt_embedding = this.text_encoder.forward(text_features); // [B, 256]
    const compatibility = this.matcher.forward(img_embedding, txt_embedding); // [B, 1]
    return torch.sigmoid(compatibility); // Convert to [0, 1] probability
  }
}

// Relation extraction: score pair relationships
const question_dim = 768; // BERT question embedding
const passage_dim = 768; // BERT passage embedding
const num_relations = 5; // Number of relation types (same answer, no answer, etc.)
const relation_scorer = new torch.nn.Bilinear(question_dim, passage_dim, num_relations);
const question = torch.randn([32, 768]); // Batch of questions
const passages = torch.randn([32, 768]); // Corresponding passages
const relation_scores = relation_scorer.forward(question, passages); // [32, 5]
// Predict relation type from highest score

// Attention-like interaction scoring
const seq_len = 10;
const key_dim = 64;
const query_dim = 64;
// Alternative to softmax attention for interaction scoring
const scorer = new torch.nn.Bilinear(query_dim, key_dim, 1);
const queries = torch.randn([1, seq_len, query_dim]); // [1, 10, 64]
const keys = torch.randn([1, seq_len, key_dim]); // [1, 10, 64]
// Score each query-key pair
const scores = torch.zeros([1, seq_len, seq_len]);
for (let i = 0; i < seq_len; i++) {
  for (let j = 0; j < seq_len; j++) {
    // Illustrative only: O(seq_len^2) separate forward calls; batch the pairs in practice
    scores[0][i][j] = scorer.forward(queries[0][i].unsqueeze(0), keys[0][j].unsqueeze(0))[0][0];
  }
}