Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Clang-format powerSGD_hook.py (#54839)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54839
ghstack-source-id: 125089465
Test Plan: N/A
Reviewed By: rohan-varma
Differential Revision: D27384796
fbshipit-source-id: 8312059f6a47d60ca29f75041141bb88804e1b32
Committed by: Facebook GitHub Bot
Parent: 6c31f56bf4
Commit: 7c0941ee63
@@ -52,7 +52,11 @@ def _should_compress(
     """  # noqa
     uncompressed_size = num_rows * num_cols
     compressed_size = (num_rows + num_cols) * matrix_approximation_rank
-    return (compressed_size * min_compression_rate < uncompressed_size, uncompressed_size, compressed_size)
+    return (
+        compressed_size * min_compression_rate < uncompressed_size,
+        uncompressed_size,
+        compressed_size,
+    )
 
 
 class PowerSGDState(object):
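For context, the rewrapped return value above is essentially the whole of `_should_compress`: it reports whether an n x m gradient matrix is worth replacing by its rank-limited factors. A minimal standalone sketch of that check follows; the helper name and the concrete shapes are illustrative only, not part of the patch.

# Sketch of the volume check performed by _should_compress: an n x m matrix is
# approximated by two factors of shapes n x rank and m x rank, so the compressed
# payload holds (n + m) * rank elements instead of n * m.
def should_compress(num_rows, num_cols, matrix_approximation_rank, min_compression_rate):
    uncompressed_size = num_rows * num_cols
    compressed_size = (num_rows + num_cols) * matrix_approximation_rank
    return (
        compressed_size * min_compression_rate < uncompressed_size,
        uncompressed_size,
        compressed_size,
    )

# Example: a 1024 x 1024 gradient at rank 4 shrinks 1,048,576 elements to 8,192,
# which easily clears a min_compression_rate of 2.
print(should_compress(1024, 1024, 4, 2))  # (True, 1048576, 8192)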
@@ -126,7 +130,7 @@ class PowerSGDState(object):
                 use_error_feedback,
                 warm_start,
                 random_seed,
-                compression_stats_logging_frequency
+                compression_stats_logging_frequency,
             )
         )
 
@@ -183,7 +187,9 @@ class PowerSGDState(object):
         self.total_numel_after_compression = 0
         # We'll report compression stats every 'compression_stats_logging_frequency' iterations
         # Note that we always report compression stats at least once.
-        self.compression_stats_logging_frequency = max(1, compression_stats_logging_frequency)
+        self.compression_stats_logging_frequency = max(
+            1, compression_stats_logging_frequency
+        )
         self.next_stats_report = 0
 
     def maybe_increase_iter(self, bucket):
@@ -207,9 +213,16 @@ class PowerSGDState(object):
 
         numel_after_compression is the total number of elements after compression was applied.
         """  # noqa
-        compress_rate = self.total_numel_before_compression / self.total_numel_after_compression \
-            if self.total_numel_after_compression > 0 else 0
-        return (compress_rate, self.total_numel_before_compression, self.total_numel_after_compression)
+        compress_rate = (
+            self.total_numel_before_compression / self.total_numel_after_compression
+            if self.total_numel_after_compression > 0
+            else 0
+        )
+        return (
+            compress_rate,
+            self.total_numel_before_compression,
+            self.total_numel_after_compression,
+        )
 
 
 def powerSGD_hook(
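The rewrapped block above is the body of `PowerSGDState.compression_stats`: the rate is simply elements-before divided by elements-after, guarded against division by zero. A self-contained sketch of the same arithmetic, with illustrative numbers:

# Sketch of the compression_stats arithmetic: rate = numel before / numel after,
# falling back to 0 when nothing has been compressed yet.
def compression_stats(total_numel_before, total_numel_after):
    compress_rate = (
        total_numel_before / total_numel_after if total_numel_after > 0 else 0
    )
    return (compress_rate, total_numel_before, total_numel_after)

print(compression_stats(1048576, 8192))  # (128.0, 1048576, 8192)
print(compression_stats(0, 0))           # (0, 0, 0)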
@@ -320,7 +333,9 @@ def powerSGD_hook(
         matrix = tensor.view(tensor.shape[0], -1)
         n, m = matrix.shape
         matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
-        compress_test = _should_compress(n, m, matrix_approximation_rank, state.min_compression_rate)
+        compress_test = _should_compress(
+            n, m, matrix_approximation_rank, state.min_compression_rate
+        )
         state.total_numel_before_compression += compress_test[1]
         if compress_test[0]:
             tensors_to_compress.append(matrix)
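In the hook itself, each tensor in the bucket is viewed as a 2-D matrix, the approximation rank is capped at min(n, m), and `_should_compress` decides whether the matrix joins the compressed batch. A hedged sketch of that flattening step, with an illustrative shape not taken from the patch:

# Sketch of the per-tensor flattening that precedes the compression test.
import torch

tensor = torch.randn(64, 3, 3, 3)          # e.g. a conv-layer gradient
matrix = tensor.view(tensor.shape[0], -1)  # flattened to 64 x 27
n, m = matrix.shape
rank = min(n, m, 4)                        # the configured rank, capped at min(n, m)
print(matrix.shape, rank)                  # torch.Size([64, 27]) 4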
@@ -332,13 +347,14 @@ def powerSGD_hook(
             state.total_numel_after_compression += compress_test[1]
 
     # Accumulate and report stats
-    if bucket.is_the_last_bucket_to_allreduce() and state.iter >= state.next_stats_report:
+    if (
+        bucket.is_the_last_bucket_to_allreduce()
+        and state.iter >= state.next_stats_report
+    ):
         stats = state.compression_stats()
         logging.info(
             "Compression stats: iter {}, total before compression {}, total after compression {}, "
-            "rate {}".format(
-                state.iter, stats[1], stats[2], stats[0]
-            )
+            "rate {}".format(state.iter, stats[1], stats[2], stats[0])
         )
         state.next_stats_report = state.iter + state.compression_stats_logging_frequency
 
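The condition rewrapped above throttles the log line: stats are reported only when the last bucket allreduces and the iteration count has reached `next_stats_report`, which then advances by `compression_stats_logging_frequency`. A minimal sketch of that throttling pattern, with stand-in variable names:

# Sketch of the report-throttling logic; the names are illustrative stand-ins
# for state.iter, state.next_stats_report and compression_stats_logging_frequency.
logging_frequency = 10000
next_report = 0
for iteration in range(0, 30001, 5000):
    if iteration >= next_report:
        print(f"report compression stats at iteration {iteration}")
        next_report = iteration + logging_frequency
# Reports fire at iterations 0, 10000, 20000 and 30000 only.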
@@ -431,7 +447,9 @@ def powerSGD_hook(
         uncompressed_tensors_memory = fut.value()[0].div_(world_size)
         idx = 0
         for tensor in uncompressed_tensors:
-            tensor.copy_(uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor))
+            tensor.copy_(
+                uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor)
+            )
             idx += tensor.numel()
 
         # Since these Ps will be orthogonalized later, no need to divide them by world size.
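The final hunk rewraps the decompression step: the allreduced flat buffer is averaged by the world size and then sliced back into the original uncompressed tensors by element count. A hedged sketch of that slice-and-copy pattern with toy shapes and buffer contents:

# Sketch of unpacking a flat buffer back into tensors by element count.
import torch

uncompressed_tensors = [torch.zeros(2, 3), torch.zeros(4)]
flat_memory = torch.arange(10, dtype=torch.float32)  # stands in for the allreduced buffer
idx = 0
for tensor in uncompressed_tensors:
    tensor.copy_(flat_memory[idx : idx + tensor.numel()].view_as(tensor))
    idx += tensor.numel()
print(uncompressed_tensors[0])  # first 6 values as a 2 x 3 matrix
print(uncompressed_tensors[1])  # remaining 4 values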