Log when using fallback layer (#121)

Daniël de Kok
2025-07-31 17:18:00 +02:00
committed by GitHub
parent 6fbff7a9cb
commit bcc29915f9


@@ -920,13 +920,19 @@ def _conditionally_replace_forward(
     # layers registered with the FALLBACK mode never get rejected by
     # _validate_layer_has_mode. For such layers, we want to fall back in
     # case the layer does not support the given mode.
-    needs_fallback = Mode.TORCH_COMPILE in mode and not getattr(
+    needs_fallback_for_compile = Mode.TORCH_COMPILE in mode and not getattr(
         layer, "can_torch_compile", False
     )
-    needs_fallback |= Mode.TRAINING in mode and not getattr(layer, "has_backward", True)
+    needs_fallback_for_backward = Mode.TRAINING in mode and not getattr(
+        layer, "has_backward", True
+    )
-    if needs_fallback:
+    if needs_fallback_for_compile or needs_fallback_for_backward:
         if use_fallback:
+            if needs_fallback_for_compile:
+                logging.info("Layer does not support torch.compile, using fallback")
+            if needs_fallback_for_backward:
+                logging.info("Layer does not support backward, using fallback")
             _replace_forward(module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")