torch.autograd.FunctionCtx.save_for_backward
FunctionCtx.save_for_backward(...tensors: Tensor[]): void

Saves given tensors for a future call to backward().
save_for_backward should be called at most once, in either the forward() or setup_context() method, and only with tensors.
All tensors intended to be used in the backward pass should be saved with save_for_backward (as opposed to directly on ctx) to prevent incorrect gradients and memory leaks.
Parameters
tensors (Tensor[]) - Tensors to save for a future call to backward()
Examples
class MyFunc extends torch.autograd.Function {
  static forward(ctx: FunctionCtx, x: Tensor, y: Tensor): Tensor {
    // Save the inputs needed to compute gradients in backward().
    ctx.save_for_backward(x, y);
    return x.mul(y);
  }

  static backward(ctx: FunctionCtx, grad_output: Tensor): [Tensor, Tensor] {
    // Retrieve the tensors saved during forward().
    const [x, y] = ctx.saved_tensors;
    // d(x*y)/dx = y and d(x*y)/dy = x, each scaled by grad_output.
    return [grad_output.mul(y), grad_output.mul(x)];
  }
}
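The save/restore round trip that MyFunc relies on can be illustrated with a minimal self-contained sketch. Note this is an assumption-laden stand-in: plain numbers replace tensors, and a hypothetical MockCtx class mimics only the save_for_backward / saved_tensors surface of FunctionCtx, with none of autograd's bookkeeping.

```typescript
// MockCtx: a hypothetical, simplified stand-in for FunctionCtx.
// It only models the save/restore behavior, using numbers as "tensors".
class MockCtx {
  private saved: number[] = [];

  // Each call replaces what was previously saved, which is one reason
  // save_for_backward should be called at most once.
  save_for_backward(...values: number[]): void {
    this.saved = values;
  }

  get saved_tensors(): number[] {
    return this.saved;
  }
}

// Forward pass of z = x * y, saving the inputs for the backward pass.
const ctx = new MockCtx();
const [x, y] = [3, 4];
ctx.save_for_backward(x, y);
const z = x * y;

// Backward pass: d(z)/dx = y and d(z)/dy = x, scaled by the incoming gradient.
const grad_output = 1;
const [sx, sy] = ctx.saved_tensors;
const grads = [grad_output * sy, grad_output * sx];
```

With x = 3, y = 4, and grad_output = 1, this yields z = 12 and grads = [4, 3], mirroring what MyFunc.backward would return for the same inputs.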