torch.registerBinaryOp
function registerBinaryOp<Op extends BinaryOpNames>(config: BinaryOpConfig<Op>): void

Register a binary operation with minimal boilerplate.
Parameters
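config: BinaryOpConfig<Op> — the descriptor for the operation being registered. The example below sets `op`, `cpu`, `shaderEntry`, `dtypes` and `backward`. See the `BinaryOpConfig` type for which fields are required.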
Examples
registerBinaryOp({
op: 'mul',
cpu: (a, b) => a * b,
shaderEntry: 'mul_op',
dtypes: ['float32', 'int32'],
backward: (ctx, grad) => ({
a: grad.mul(ctx.saved_tensors[1] as AnyTensor),
b: grad.mul(ctx.saved_tensors[0] as AnyTensor),
}),
});