What the self.training guard is doing in Transformers
In GradientCheckpointingLayer.__call__, Transformers only routes through checkpointing when:
if self.gradient_checkpointing and self.training:
and, when it does, it also forcibly disables cache-related kwargs (use_cache=False, past_key_values=None, etc.) and logs a warning once. (GitHub)
That pairing (“only in training” + “disable caching”) exists for two practical reasons:
1) Checkpointing is a backward-time optimization
PyTorch checkpointing saves memory by not storing activations in forward and instead recomputing forward during backward. If you are not going to run backward (typical eval/inference), there is little to gain. (PyTorch Docs)
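The recompute behavior is easy to observe directly by counting how many times the checkpointed function runs:

```python
# The checkpointed function runs once in forward (activations discarded),
# then again during backward (activations recomputed on demand).
import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}


def block(x):
    calls["n"] += 1
    return torch.relu(x) * 2


x = torch.randn(3, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
forward_calls = calls["n"]    # forward pass only

y.sum().backward()
backward_calls = calls["n"]   # forward was re-run during backward
```

If you never call `backward()`, the second run never happens, which is exactly why checkpointing has nothing to offer in pure inference.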
2) Caching (past_key_values) is often critical in eval/generation
Transformers treats KV caching as incompatible with checkpointing, and explicitly rewrites cache args when checkpointing is active. (GitHub)
If you allowed checkpointing in eval mode, you’d also be enabling this “turn cache off” behavior in eval, which can break assumptions and/or slow generation paths. A concrete example is a reported generate() misbehavior triggered by gradient_checkpointing=True → use_cache=False → past_key_values=None. (GitHub)
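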
What would happen if you change it to if self.gradient_checkpointing:?
You’d be enabling checkpointing in eval mode too, which has different outcomes depending on whether gradients are enabled and whether anything requires grad.
Case A — eval() + torch.no_grad() / inference_mode() (most inference)
Likely outcome: mostly downside
- No meaningful memory savings, because there is no backward graph to save activations for. (PyTorch Docs)
- Potential overhead/noise: you may see warnings of the form “None of the inputs have requires_grad=True. Gradients will be None” if checkpoint is invoked when nothing participates in autograd (this is a common complaint in PyTorch land). (PyTorch Forums)
- Behavior change: any call passing use_cache=True / past_key_values would get those rewritten (disabled) even in eval. (GitHub)
This is particularly harmful for autoregressive decoding, where caching is the main performance lever (and can even affect code paths, as in the generate() issue). (GitHub)
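A quick check of what checkpointing does under no_grad: the output carries no autograd graph, so there were no activations to save in the first place:

```python
# Under torch.no_grad() there is no backward graph, so checkpointing
# saves nothing; the output is a plain tensor with no grad history.
import torch
from torch.utils.checkpoint import checkpoint

lin = torch.nn.Linear(4, 4).eval()
x = torch.randn(1, 4)

with torch.no_grad():
    out = checkpoint(lin, x, use_reentrant=False)
# Nothing was saved and nothing will be recomputed: pure overhead.
```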
Case B — eval() + grads enabled, but all params frozen and inputs don’t require grad
Likely outcome: checkpointing is ineffective and may block gradients
- With all parameters at requires_grad=False, nothing inside the checkpointed region participates in autograd.
- For PyTorch’s reentrant checkpoint variant (use_reentrant=True), PyTorch documents a hard requirement: at least one input and one output must have requires_grad=True, or the checkpointed part will not produce gradients. (PyTorch Docs)
- Result: you can get warnings / “grads will be None” behavior and you still don’t get a useful gradient signal. (PyTorch Forums)
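This failure mode is reproducible in a few lines with a fully frozen module and reentrant checkpointing:

```python
# Fully frozen module + reentrant checkpoint: PyTorch warns that no input
# requires grad, and the output carries no grad history.
import warnings

import torch
from torch.utils.checkpoint import checkpoint

lin = torch.nn.Linear(4, 4)
for p in lin.parameters():
    p.requires_grad_(False)   # frozen model

x = torch.randn(2, 4)         # input does not require grad either

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = checkpoint(lin, x, use_reentrant=True)

# The "None of the inputs have requires_grad=True" warning fires, and the
# output is detached from autograd: no useful gradient signal is possible.
warned = any("requires_grad" in str(w.message) for w in caught)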
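This failure mode is reproducible in a few lines with a fully frozen module and reentrant checkpointing:

```python
# Fully frozen module + reentrant checkpoint: PyTorch warns that no input
# requires grad, and the output carries no grad history.
import warnings

import torch
from torch.utils.checkpoint import checkpoint

lin = torch.nn.Linear(4, 4)
for p in lin.parameters():
    p.requires_grad_(False)   # frozen model

x = torch.randn(2, 4)         # input does not require grad either

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = checkpoint(lin, x, use_reentrant=True)

# The warning "None of the inputs have requires_grad=True" fires, and the
# output is detached from autograd: no useful gradient signal is possible.
warned = any("requires_grad" in str(w.message) for w in caught)
```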
Case C — eval() + grads enabled, and you do want gradients (e.g., wrt inputs/embeddings or a small unfrozen subset)
This is the one scenario where “eval + checkpointing” can make sense
- There’s an active Transformers issue describing exactly this need: keep eval() (disable dropout for deterministic gradients) while using checkpointing to fit large models for gradient-based metrics. (GitHub)
- But you must ensure something requires grad through the checkpointed region:
- If you want gradients w.r.t. input embeddings, you typically need to make them require grad (Transformers provides enable_input_require_grads() for this use case). (Hugging Face)
- If you are using reentrant checkpointing, Transformers’ own docstring warns that inputs requiring gradients (e.g., hidden states) must be passed as positional args (not kwargs) for gradients to propagate correctly. (GitHub)
- You still have to accept that the KV cache will be disabled when checkpointing is active. (GitHub)
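The working pattern can be shown in pure PyTorch. This is a sketch of the mechanism enable_input_require_grads() implements (a forward hook marking the embedding output as requiring grad), not the Transformers code itself:

```python
# eval() + fully frozen model, but gradients w.r.t. the embedding output:
# a forward hook forces requires_grad, and non-reentrant checkpointing
# lets autograd flow through the checkpointed block.
import torch
from torch.utils.checkpoint import checkpoint

embed = torch.nn.Embedding(10, 8)
block = torch.nn.Linear(8, 8)
for p in list(embed.parameters()) + list(block.parameters()):
    p.requires_grad_(False)   # fully frozen
embed.eval()
block.eval()


def make_output_require_grad(module, inputs, output):
    # Same idea as Transformers' enable_input_require_grads() hook.
    output.requires_grad_(True)
    return output


embed.register_forward_hook(make_output_require_grad)

ids = torch.tensor([[1, 2, 3]])
hidden = embed(ids)                                   # now requires grad
out = checkpoint(block, hidden, use_reentrant=False)  # non-reentrant
out.sum().backward()
# hidden.grad now holds gradients w.r.t. the embedding output
```

Note the use of use_reentrant=False: with the reentrant variant, the frozen-parameter setup above runs into the "input and output must require grad" constraint discussed earlier.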
Why binding to training is a blunt but pragmatic choice
self.training is being used as a proxy for “this is a backward-producing regime where checkpointing is worth it, and we’re okay disabling caches.”
There are multiple ecosystem reports asking for a different proxy: gate checkpointing on “grad mode” instead of “training mode.” For example, a Diffusers issue proposes if torch.is_grad_enabled() and self.gradient_checkpointing: specifically to enable checkpointing in eval for LoRA/frozen-model scenarios while avoiding inference/no_grad() paths. (GitHub)
What you should do instead (recommended approach)
If your goal is eval-mode forward behavior + gradients, the more semantically correct condition is:
- Enable checkpointing when gradients are enabled, not when module.training is true.
Conceptually: self.gradient_checkpointing and torch.is_grad_enabled() (often also combined with “any tensor input requires grad” to avoid warnings).
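A hypothetical helper makes the gate concrete (should_checkpoint is an illustrative name, not a Transformers API):

```python
# Gate checkpointing on grad mode + grad-requiring inputs, not on
# module.training. should_checkpoint is a hypothetical helper.
import torch


def should_checkpoint(module, *tensors):
    return (
        getattr(module, "gradient_checkpointing", False)
        and torch.is_grad_enabled()
        and any(t.requires_grad for t in tensors if isinstance(t, torch.Tensor))
    )


m = torch.nn.Linear(2, 2).eval()     # eval-mode forward behavior
m.gradient_checkpointing = True
x = torch.randn(1, 2, requires_grad=True)

in_eval = should_checkpoint(m, x)          # True: grads are enabled
with torch.no_grad():
    in_inference = should_checkpoint(m, x)  # False: no_grad inference path
```

This enables the eval-plus-gradients use case while leaving no_grad()/inference_mode() paths (and their caching behavior) untouched.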
Also, strongly consider using the non-reentrant checkpoint variant where possible, because it removes the “input/output must require grad” requirement that bites frozen-model setups. PyTorch explicitly states that this requirement applies to reentrant but not non-reentrant. (PyTorch Docs)
Transformers supports passing checkpoint kwargs via gradient_checkpointing_enable(gradient_checkpointing_kwargs=...). (Hugging Face)
Bottom line for your exact stated setup (eval + all params frozen)
- If you are not computing gradients (typical eval/inference): changing the condition will mostly add risk (cache disabled, possible warnings) with little benefit. (GitHub)
- If you are computing gradients (e.g., wrt inputs/embeddings): it can work, but only if you ensure something requires grad through the checkpointed region (often via enable_input_require_grads()), and you accept that caching will be forced off. (Hugging Face)