Conversation

Contributor
@yashwantbezawada commented Dec 25, 2025

Fixes #43010

What does this PR do?

Adds the @torch.no_grad() decorator to the cache layer update() methods that use in-place operations, preventing torch.func.grad from failing with "in-place operation would mutate a captured Tensor" errors.
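
As a rough illustration of the pattern (a minimal sketch, not the actual transformers source; the class below is a simplified stand-in for StaticLayer with made-up buffer shapes):

```python
import torch


class StaticLayerSketch:
    """Simplified stand-in for a static cache layer with preallocated KV buffers."""

    def __init__(self, max_cache_len: int, head_dim: int):
        self.keys = torch.zeros(max_cache_len, head_dim)
        self.values = torch.zeros(max_cache_len, head_dim)

    @torch.no_grad()  # keep the in-place writes below out of autograd
    def update(self, key_states, value_states, cache_position):
        # index_copy_ mutates the preallocated buffers in place; without the
        # decorator this is the kind of mutation torch.func.grad rejects.
        self.keys.index_copy_(0, cache_position, key_states)
        self.values.index_copy_(0, cache_position, value_states)
        return self.keys, self.values
```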

Why only StaticLayer and StaticSlidingWindowLayer?

After investigation, I found that only these two classes need the decorator:

| Cache Layer | Operation | In-place? | Needs @torch.no_grad()? |
| --- | --- | --- | --- |
| DynamicLayer | torch.cat() | No | No (would break gradient flow) |
| DynamicSlidingWindowLayer | torch.cat() | No | No (would break gradient flow) |
| StaticLayer | index_copy_() | Yes | Yes |
| StaticSlidingWindowLayer | copy_(), index_copy_() | Yes | Yes |
| QuantizedLayer | torch.cat() | No | No (would break gradient flow) |

Adding the decorator to DynamicLayer (and subclasses) would break gradient flow because torch.cat() creates new tensors that participate in the computation graph. Models like T5 use DynamicCache and need gradients to flow through cached key/values.
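
To make the distinction concrete, here is a small plain-PyTorch sketch (not the transformers classes) contrasting the two update styles:

```python
import torch

k_old = torch.randn(2, 4, requires_grad=True)
k_new = torch.randn(1, 4, requires_grad=True)

# Dynamic-style update: torch.cat builds a new tensor that stays in the
# autograd graph, so gradients flow back to both old and new key states.
dyn_cache = torch.cat([k_old, k_new], dim=0)
dyn_cache.sum().backward()
print(k_old.grad is not None, k_new.grad is not None)  # True True

# Static-style update: a preallocated buffer is written in place. Doing the
# write under no_grad keeps the mutation out of autograd, which is what the
# decorator on StaticLayer.update() provides.
static_cache = torch.zeros(3, 4)
with torch.no_grad():
    static_cache.index_copy_(0, torch.tensor([2]), k_new)
```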

Changes

Added the @torch.no_grad() decorator to:

  • StaticLayer.update()
  • StaticSlidingWindowLayer.update()

Testing

  • All 9 cache unit tests pass
  • Verified torch.func.grad works with StaticCache
  • Verified gradient flow preserved for DynamicCache

Decorates all cache layer update() methods with @torch.no_grad() to prevent
PyTorch autograd from complaining about tensor version changes when computing
gradients with respect to model inputs.

This follows the same pattern used by optimizer.step() methods and is safe
because cache updates are only used during inference/generation, never during
training.

Methods decorated:
- DynamicLayer.update()
- DynamicSlidingWindowLayer.update()
- StaticLayer.update()
- StaticSlidingWindowLayer.update()
- QuantizedLayer.update()

Fixes huggingface#43010

Remove @torch.no_grad() from DynamicLayer, DynamicSlidingWindowLayer,
and QuantizedLayer since they use torch.cat() (not in-place) and need
gradient flow preserved.

Keep @torch.no_grad() only on StaticLayer and StaticSlidingWindowLayer
which use index_copy_() and copy_() (in-place operations) that cause
torch.func.grad to fail.
Linked issue: Cache's (and Layer's) update(...) method to be decorated with @torch.no_grad (#43010)