Clang-format powerSGD_hook.py (#54839)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54839

ghstack-source-id: 125089465

Test Plan: N/A

Reviewed By: rohan-varma

Differential Revision: D27384796

fbshipit-source-id: 8312059f6a47d60ca29f75041141bb88804e1b32
This commit is contained in:
Yi Wang
2021-03-30 09:25:19 -07:00
committed by Facebook GitHub Bot
parent 6c31f56bf4
commit 7c0941ee63

View File

@ -52,7 +52,11 @@ def _should_compress(
""" # noqa
uncompressed_size = num_rows * num_cols
compressed_size = (num_rows + num_cols) * matrix_approximation_rank
return (compressed_size * min_compression_rate < uncompressed_size, uncompressed_size, compressed_size)
return (
compressed_size * min_compression_rate < uncompressed_size,
uncompressed_size,
compressed_size,
)
class PowerSGDState(object):
@ -126,7 +130,7 @@ class PowerSGDState(object):
use_error_feedback,
warm_start,
random_seed,
compression_stats_logging_frequency
compression_stats_logging_frequency,
)
)
@ -183,7 +187,9 @@ class PowerSGDState(object):
self.total_numel_after_compression = 0
# We'll report compression stats every 'compression_stats_logging_frequency' iterations
# Note that we always report compression stats at least once.
self.compression_stats_logging_frequency = max(1, compression_stats_logging_frequency)
self.compression_stats_logging_frequency = max(
1, compression_stats_logging_frequency
)
self.next_stats_report = 0
def maybe_increase_iter(self, bucket):
@ -207,9 +213,16 @@ class PowerSGDState(object):
numel_after_compression is the total number of elements after compression was applied.
""" # noqa
compress_rate = self.total_numel_before_compression / self.total_numel_after_compression \
if self.total_numel_after_compression > 0 else 0
return (compress_rate, self.total_numel_before_compression, self.total_numel_after_compression)
compress_rate = (
self.total_numel_before_compression / self.total_numel_after_compression
if self.total_numel_after_compression > 0
else 0
)
return (
compress_rate,
self.total_numel_before_compression,
self.total_numel_after_compression,
)
def powerSGD_hook(
@ -320,7 +333,9 @@ def powerSGD_hook(
matrix = tensor.view(tensor.shape[0], -1)
n, m = matrix.shape
matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
compress_test = _should_compress(n, m, matrix_approximation_rank, state.min_compression_rate)
compress_test = _should_compress(
n, m, matrix_approximation_rank, state.min_compression_rate
)
state.total_numel_before_compression += compress_test[1]
if compress_test[0]:
tensors_to_compress.append(matrix)
@ -332,13 +347,14 @@ def powerSGD_hook(
state.total_numel_after_compression += compress_test[1]
# Accumulate and report stats
if bucket.is_the_last_bucket_to_allreduce() and state.iter >= state.next_stats_report:
if (
bucket.is_the_last_bucket_to_allreduce()
and state.iter >= state.next_stats_report
):
stats = state.compression_stats()
logging.info(
"Compression stats: iter {}, total before compression {}, total after compression {}, "
"rate {}".format(
state.iter, stats[1], stats[2], stats[0]
)
"rate {}".format(state.iter, stats[1], stats[2], stats[0])
)
state.next_stats_report = state.iter + state.compression_stats_logging_frequency
@ -431,7 +447,9 @@ def powerSGD_hook(
uncompressed_tensors_memory = fut.value()[0].div_(world_size)
idx = 0
for tensor in uncompressed_tensors:
tensor.copy_(uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor))
tensor.copy_(
uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor)
)
idx += tensor.numel()
# Since these Ps will be orthogonalized later, no need to divide them by world size.