torch.MatmulShape
```typescript
export type MatmulShape<A extends Shape, B extends Shape> =
  // Handle dynamic shapes
  number extends A['length']
    ? DynamicShape
    : number extends B['length']
      ? DynamicShape
      : // 1D @ 1D: dot product
        A['length'] extends 1
        ? B['length'] extends 1
          ? A[0] extends B[0]
            ? readonly []
            : matmul_error_inner_dimensions_do_not_match<A[0] & number, B[0] & number, A, B>
          : // 1D @ 2D
            B['length'] extends 2
            ? A[0] extends B[0]
              ? readonly [B[1] & number]
              : matmul_error_inner_dimensions_do_not_match<A[0] & number, B[0] & number, A, B>
            : DynamicShape
        : // 2D @ 1D
          A['length'] extends 2
          ? B['length'] extends 1
            ? A[1] extends B[0]
              ? readonly [A[0] & number]
              : matmul_error_inner_dimensions_do_not_match<A[1] & number, B[0] & number, A, B>
            : // 2D @ 2D
              B['length'] extends 2
              ? A[1] extends B[0]
                ? readonly [A[0] & number, B[1] & number]
                : matmul_error_inner_dimensions_do_not_match<A[1] & number, B[0] & number, A, B>
              : DynamicShape
          : // ND @ ND (batched): use helper
            MatmulBatchedShape<A, B>;
```

Type parameters:
- A extends Shape
- B extends Shape

Compute output shape of matmul with various input dimensions.
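Handles:
- 2D @ 2D: [M, K] @ [K, N] → [M, N]
- 1D @ 1D: [K] @ [K] → [] (dot product)
- 2D @ 1D: [M, K] @ [K] → [M]
- 1D @ 2D: [K] @ [K, N] → [N]
- Batched: [...Batch, M, K] @ [...Batch, K, N] → [...Batch, M, N] with batch broadcasting, delegated to MatmulBatchedShape whenever A is neither 1D nor 2D
- Mismatched inner dimensions (K) resolve to a matmul_error_inner_dimensions_do_not_match type
- Non-literal lengths, and a 1D or 2D left operand paired with a right operand of any other rank, resolve to DynamicShape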
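The rules listed above can be mirrored at runtime to check expected shapes. The sketch below is hypothetical (`matmulShape` and `broadcastBatch` are illustrative names, not part of this library) and follows `torch.matmul`'s documented semantics: a 1D operand is temporarily promoted to 2D, batch dimensions are broadcast right-aligned, and the promoted dimensions are dropped from the result. It is more permissive than the type: combinations such as 2D @ 3D, which MatmulShape resolves to DynamicShape, get a concrete shape here, and mismatched inner dimensions throw instead of producing an error type.

```typescript
// Hypothetical runtime mirror of the MatmulShape rules; not part of the library.
// Broadcast two batch shapes right-aligned; missing leading dims count as size 1.
function broadcastBatch(a: readonly number[], b: readonly number[]): number[] {
  const len = Math.max(a.length, b.length);
  const out: number[] = [];
  for (let i = 0; i < len; i++) {
    const ia = i - (len - a.length);
    const ib = i - (len - b.length);
    const da = ia >= 0 ? a[ia] : 1;
    const db = ib >= 0 ? b[ib] : 1;
    if (da !== db && da !== 1 && db !== 1) {
      throw new Error(`cannot broadcast batch dimensions ${da} and ${db}`);
    }
    out.push(da === 1 ? db : da);
  }
  return out;
}

function matmulShape(a: readonly number[], b: readonly number[]): number[] {
  if (a.length === 0 || b.length === 0) {
    throw new Error("matmul needs inputs that are at least 1D");
  }
  // Promote 1D operands: [K] -> [1, K] on the left, [K] -> [K, 1] on the right.
  const a2 = a.length === 1 ? [1, a[0]] : [...a];
  const b2 = b.length === 1 ? [b[0], 1] : [...b];
  const m = a2[a2.length - 2];
  const k = a2[a2.length - 1];
  const kb = b2[b2.length - 2];
  const n = b2[b2.length - 1];
  if (k !== kb) {
    throw new Error(`inner dimensions do not match: ${k} vs ${kb}`);
  }
  const out = [...broadcastBatch(a2.slice(0, -2), b2.slice(0, -2)), m, n];
  // Remove the dimensions introduced by promotion.
  if (a.length === 1) out.splice(out.length - 2, 1);
  if (b.length === 1) out.pop();
  return out;
}

// Usage: a batched left operand against a plain 2D right operand.
console.log(matmulShape([5, 2, 3], [3, 4])); // → [5, 2, 4]
```

Promoting 1D operands to 2D lets a single code path handle every row of the list above; the type-level version instead spells out each rank combination explicitly.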