# Out of memory, or gradient accumulation for larger models

When a model grows so large that it no longer fits on a single device, and there are no more devices to spare, the common mitigation strategy is to reduce the batch size, thereby freeing memory for the model at the expense of the data. However, smaller batches lead to noisier weight updates, which is undesirable. One solution is gradient accumulation, where the gradients of several consecutive batches are accumulated and the weights are updated only once per group, using the average of those gradients. In this article, we show how it can be implemented in practice.

# Solution

Long story short:

```python
import tensorflow as tf


# Inherit from any optimizer of choice, such as Adam.
class Optimizer(tf.keras.optimizers.Adam):
    """Optimizer that implements gradient accumulation.

    Gradients are summed into per-variable accumulators over `accumulation`
    consecutive calls to `apply_gradients`. On the last call of each cycle,
    their average is passed to the base optimizer's update rule. On every
    other call, the update is skipped and the base `iterations` counter is
    rolled back, so `iterations` counts actual weight updates only.
    """

    def __init__(self, accumulation: int = 1, **options) -> None:
        """Create an instance.

        Arguments:
            accumulation: The number of iterations to accumulate gradients
                over. If it is set to one, no accumulation is performed, and
                the gradients are applied as soon as they are computed. If it
                is set to a value greater than one, the gradients will be
                accumulated for the specified number of iterations and only
                then applied, starting a new cycle.

        All other arguments are passed to the base optimizer.

        Raises:
            ValueError: If `accumulation` is smaller than one.
        """
        # Guard against a zero or negative cycle length, which would make the
        # modulo arithmetic below divide by zero or never complete a cycle.
        if accumulation < 1:
            raise ValueError(
                f"accumulation must be a positive integer, got {accumulation!r}."
            )
        super().__init__(**options)
        self.accumulation = accumulation
        # Both are created lazily in `build`, once the variables are known.
        self._accumulation = None
        self._gradients = None

    def apply_gradients(
        self,
        gradients_variables: list[tuple[tf.Tensor, tf.Tensor]],
        **options,
    ) -> tf.Tensor:
        """Apply the gradients according to the accumulation scheme.

        Arguments:
            gradients_variables: Pairs of gradients and trainable variables.

        All other keyword arguments (e.g. `name`) are forwarded unchanged to
        the base optimizer's `apply_gradients`.

        Returns:
            The optimizer's `iterations` counter.
        """
        # Split off the gradients from the trainable variables.
        gradients, variables = zip(*list(gradients_variables))
        # Perform the initialization if needed; variable creation is lifted
        # out of any enclosing function graph.
        with tf.init_scope():
            self.build(variables)
        first = self._accumulation % self.accumulation == 0
        last = (self._accumulation + 1) % self.accumulation == 0
        # Add the new gradients to the old ones, starting afresh at the
        # beginning of each cycle. Selecting with `tf.where` instead of
        # multiplying by a float32 zero keeps the accumulator's own dtype
        # (e.g. float16) and truly discards stale values, since 0 * inf = NaN.
        for gradient, increment in zip(self._gradients, gradients):
            gradient.assign(tf.where(first, increment, gradient + increment))
        # Apply the average accumulated gradients to the trainable variables;
        # `update_step` lets the update through only at the end of a cycle.
        gradients = [gradient / self.accumulation for gradient in self._gradients]
        super().apply_gradients(zip(gradients, variables), **options)
        # Decrement the base counter incremented by the application if needed.
        self.iterations.assign_sub(tf.cast(~last, tf.int64))
        # Increment the accumulation counter.
        self._accumulation.assign_add(1)
        return self.iterations

    def update_step(self, gradient: tf.Tensor, variable: tf.Tensor) -> None:
        """Update the trainable variable with the gradient.

        The base update rule runs only on the last iteration of each cycle;
        on any other iteration, the call does nothing.
        """
        # Bound here because zero-argument `super()` cannot be used inside
        # the lambda below.
        update_step = super().update_step
        last = (self._accumulation + 1) % self.accumulation == 0
        # Allow the update to happen only at the end of each cycle.
        tf.cond(last, lambda: update_step(gradient, variable), lambda: None)

    def build(self, variables: list[tf.Tensor]) -> None:
        """Initialize the internal state."""
        super().build(variables)
        if self._gradients is None:
            # Create a counter for tracking accumulation.
            self._accumulation = self.add_variable(shape=(), dtype=tf.int64)
            # Allocate memory for accumulation, one slot per variable.
            self._gradients = [
                self.add_variable_from_reference(
                    model_variable=variable,
                    variable_name="gradient",
                )
                for variable in variables
            ]

    def get_config(self) -> dict:
        """Return the configuration, including the accumulation setting.

        Without this, a reloaded optimizer would silently fall back to the
        default of no accumulation.
        """
        config = super().get_config()
        config["accumulation"] = self.accumulation
        return config
```

It is important to note that the learning rate is *not* held constant during
accumulation. However, since it is not expected to change much from one
iteration to another, it is an adequate simplification.

# Acknowledgments

I would like to thank André Pedersen, Axel Roebel, and Tor-Arne Nordmo for their help with the implementation.