torch.autograd.Function
class Function

- forward(ForwardFn | ((...inputs: any[]) => Tensor | Tensor[])) – Performs the operation. This method should be overridden by all subclasses. Combined style: `forward(ctx, *args)`. Separate style: `forward(*args)` (without ctx).
- backward(BackwardFn) – Defines the formula for differentiating the operation. This method should be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many outputs as forward() returned, and it should return as many tensors as there were inputs to forward().
- setup_context(SetupContextFn), optional – Sets up the context for backward (separate style). If provided, forward() should NOT accept ctx as its first argument.
Base class to create custom autograd Functions.
To create a custom autograd.Function, subclass this class and implement the forward() and backward() static methods.
There are two ways to define forward:
- Combined style (traditional): forward(ctx, *args) where ctx is modified inline
- Separate style (PyTorch 2.0+): forward(*args) + setup_context(ctx, inputs, output)
The separate style is recommended as it composes better with torch.func transforms.
## Examples

### Combined style
```typescript
class Exp extends torch.autograd.Function {
  static forward(ctx: FunctionCtx, x: Tensor): Tensor {
    const result = x.exp();
    ctx.save_for_backward(result);
    return result;
  }

  static backward(ctx: FunctionCtx, grad_output: Tensor): [Tensor] {
    const [result] = ctx.saved_tensors;
    return [grad_output.mul(result)];
  }
}
const y = Exp.apply(x);
```

### Separate style
```typescript
class Exp extends torch.autograd.Function {
  static forward(x: Tensor): Tensor {
    return x.exp();
  }

  static setup_context(ctx: FunctionCtx, inputs: [Tensor], output: Tensor): void {
    ctx.save_for_backward(output);
  }

  static backward(ctx: FunctionCtx, grad_output: Tensor): [Tensor] {
    const [result] = ctx.saved_tensors;
    return [grad_output.mul(result)];
  }
}
const y = Exp.apply(x);
```
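The examples above assume a torch-style TypeScript binding with real Tensors. To make the forward/backward wiring concrete, here is a minimal self-contained sketch of the same Exp function using plain numbers in place of Tensors; the `Ctx` class, the `apply` helper, and the returned `grad` closure are illustrative stand-ins for the real machinery, not part of any actual API.

```typescript
// Minimal sketch of the apply/forward/backward contract, with numbers
// standing in for Tensors. All names here are illustrative only.
class Ctx {
  saved: number[] = [];
  // Mirrors ctx.save_for_backward: stash values needed by backward().
  save_for_backward(...vals: number[]): void {
    this.saved = vals;
  }
}

class Exp {
  // Combined style: forward receives the context as its first argument.
  static forward(ctx: Ctx, x: number): number {
    const result = Math.exp(x);
    ctx.save_for_backward(result); // d/dx exp(x) = exp(x), so saving the output suffices
    return result;
  }

  // backward receives one grad per forward output, returns one grad per input.
  static backward(ctx: Ctx, grad_output: number): number {
    const [result] = ctx.saved;
    return grad_output * result; // chain rule: grad_output * exp(x)
  }

  // apply() wires forward and backward together through a shared context.
  static apply(x: number): { value: number; grad: (g: number) => number } {
    const ctx = new Ctx();
    const value = Exp.forward(ctx, x);
    return { value, grad: (g) => Exp.backward(ctx, g) };
  }
}

const { value, grad } = Exp.apply(1.0);
console.log(value);   // exp(1) ≈ 2.718
console.log(grad(1)); // same value: the derivative of exp at 1 is also e
```

The key point the sketch illustrates is that `apply` owns the context object: forward writes into it, backward reads from it, and user code never constructs it directly.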
## See Also
- [PyTorch torch.autograd.Function](https://pytorch.org/docs/stable/generated/torch.autograd.Function.html)