[TPU] [Perf] Improve Memory Usage Estimation (#15671)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2025-03-28 10:37:52 -07:00
committed by GitHub
parent d03308be0c
commit 038bededba

View File

@ -161,7 +161,13 @@ class TPUWorker:
# intermediate activations.
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
profiled = m["peak_bytes_used"] # Weights + intermediate activations.
current_mem = m["bytes_used"]
# Ideally we would use profiled = m["peak_bytes_used"] to
# get weights + activations. But there is memory used during
# compilation / weight loading that impacts the peak and
# there is no way to reset peak memory in XLA, So we
# use the heuristic of 2% of weights.
profiled = current_mem * 1.02
# Calculate the TPU KV cache size based on profiling.
usable_memory_size = int(total_memory_size *