torch.tx.EvalExprShape
/**
 * Evaluate a parsed AST to compute the result shape.
 *
 * Node forms are tried top to bottom and the first match wins:
 * - ternary `?:` - broadcast of the `then` and `else` shapes (the `cond`
 *   operand is matched but not consulted)
 * - `@` - `MatmulShape` of the two operand shapes
 * - any other listed binary operator - `BroadcastShape` of the operand shapes
 * - unary `neg` / `~` - the operand's shape, unchanged
 * - `{ fn, args }` - delegated to `EvalFunctionShape`
 * - legacy `{ fn, arg }` - the operand's shape when `fn` is one of
 *   `ShapePreservingFunctions`, otherwise `DynamicShape`
 * - `{ var }` - looked up in `Shapes`; `DynamicShape` when the name is absent
 * - `{ literal }` - the scalar shape `readonly []`
 * - anything else - `DynamicShape`
 *
 * @typeParam AST - Parsed expression node to evaluate.
 * @typeParam Shapes - Maps variable names to their shapes.
 */
export type EvalExprShape<AST, Shapes extends Record<string, Shape>> =
  // Ternary: result is broadcast of then/else branches
  AST extends { op: '?:'; cond: infer _C; then: infer T; else: infer E }
    ? BroadcastShape<EvalExprShape<T, Shapes>, EvalExprShape<E, Shapes>>
    : // Matrix multiplication
      AST extends { op: '@'; left: infer L; right: infer R }
      ? MatmulShape<EvalExprShape<L, Shapes>, EvalExprShape<R, Shapes>>
      : // All other binary ops: broadcast shapes
        AST extends {
            op: '+' | '-' | '*' | '/' | '//' | '%' | '**' | '|' | '^' | '&' | '==' | '!=' | '<' | '<=' | '>' | '>=';
            left: infer L;
            right: infer R;
          }
        ? BroadcastShape<EvalExprShape<L, Shapes>, EvalExprShape<R, Shapes>>
        : // Unary ops: preserve shape
          AST extends { op: 'neg' | '~'; arg: infer A }
          ? EvalExprShape<A, Shapes>
          : // Function calls with args array - check function category.
            // Must stay ahead of the legacy single-arg form below.
            AST extends { fn: infer FnName extends string; args: infer Args extends readonly unknown[] }
            ? EvalFunctionShape<FnName, Args, Shapes>
            : // Legacy: function calls with single arg (backwards compatibility)
              AST extends { fn: infer FnName extends string; arg: infer Arg }
              ? FnName extends ShapePreservingFunctions
                ? EvalExprShape<Arg, Shapes>
                : DynamicShape
              : // Variable reference: resolve through the Shapes map
                AST extends { var: infer Name extends string }
                ? Name extends keyof Shapes
                  ? Shapes[Name]
                  : DynamicShape
                : // Literal: always a scalar
                  AST extends { literal: string }
                  ? readonly [] // scalar
                  : // Unrecognized node form
                    DynamicShape;