// torch.tx.EvalExprShape
/**
 * Evaluate a parsed expression AST at the type level to compute its result shape.
 *
 * Branches are tried in order and the first pattern the AST matches wins:
 *
 * 1. Ternary (`op: '?:'`) — {@link BroadcastShape} of the `then` and `else`
 *    branch shapes. The shape of `cond` is not considered.
 * 2. Matrix multiplication (`op: '@'`) — {@link MatmulShape} of the operand shapes.
 * 3. Remaining binary operators (arithmetic, bitwise, comparison) —
 *    {@link BroadcastShape} of the operand shapes.
 * 4. Unary operators (`'neg'`, `'~'`) — the operand's shape, unchanged.
 * 5. Function calls carrying an `args` array — delegated to {@link EvalFunctionShape}.
 * 6. Legacy function calls carrying a single `arg` — the argument's shape when the
 *    function name is in {@link ShapePreservingFunctions}, otherwise {@link DynamicShape}.
 *    Because case 5 is checked first, a node with both `args` and `arg` uses case 5.
 * 7. Variable references (`var`) — looked up in `Shapes`; names not present in
 *    `Shapes` yield {@link DynamicShape}.
 * 8. Nodes whose `literal` field is a `string` — a scalar (`readonly []`).
 * 9. Anything else — {@link DynamicShape}.
 *
 * `AST` is checked as a naked type parameter, so passing a union of ASTs evaluates
 * each member separately and unions the resulting shapes.
 *
 * @typeParam AST - The parsed expression tree to evaluate.
 * @typeParam Shapes - Known shapes of the variables that may appear in the
 *   expression, keyed by variable name.
 */
export type EvalExprShape<AST, Shapes extends Record<string, Shape>> =
  // Ternary: result is the broadcast of the then/else branch shapes.
  // NOTE(review): `cond` is matched but its shape is ignored; if a tensor-valued
  // condition should also participate in broadcasting (as in an elementwise
  // select), it needs to be folded in here — confirm intended semantics.
  AST extends { op: '?:'; cond: infer _C; then: infer T; else: infer E }
    ? BroadcastShape<EvalExprShape<T, Shapes>, EvalExprShape<E, Shapes>>
  // Matrix multiplication.
  : AST extends { op: '@'; left: infer L; right: infer R }
    ? MatmulShape<EvalExprShape<L, Shapes>, EvalExprShape<R, Shapes>>
  // All other binary ops broadcast their operand shapes.
  : AST extends {
        op:
          | '+'
          | '-'
          | '*'
          | '/'
          | '//'
          | '%'
          | '**'
          | '|'
          | '^'
          | '&'
          | '=='
          | '!='
          | '<'
          | '<='
          | '>'
          | '>=';
        left: infer L;
        right: infer R;
      }
    ? BroadcastShape<EvalExprShape<L, Shapes>, EvalExprShape<R, Shapes>>
  // Unary ops preserve the operand's shape.
  : AST extends { op: 'neg' | '~'; arg: infer A }
    ? EvalExprShape<A, Shapes>
  // Function calls with an args array: the function's category decides the shape.
  : AST extends {
        fn: infer FnName extends string;
        args: infer Args extends readonly unknown[];
      }
    ? EvalFunctionShape<FnName, Args, Shapes>
  // Legacy: function calls with a single arg (backwards compatibility).
  : AST extends { fn: infer FnName extends string; arg: infer Arg }
    ? FnName extends ShapePreservingFunctions
      ? EvalExprShape<Arg, Shapes>
      : DynamicShape
  // Variable reference: use the declared shape when the name is known.
  : AST extends { var: infer Name extends string }
    ? Name extends keyof Shapes
      ? Shapes[Name]
      : DynamicShape
  // Literal: scalar.
  // NOTE(review): only a `string`-typed `literal` matches; if the parser can emit
  // non-string literal values they fall through to DynamicShape — confirm.
  : AST extends { literal: string }
    ? readonly [] // scalar
  : DynamicShape;