Fix G001,G002,G003 in logs to % syntax (#97812)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97812
Approved by: https://github.com/Skylion007, https://github.com/kiukchung, https://github.com/malfet, https://github.com/mlazos
Edward Z. Yang
2023-03-31 12:53:36 -04:00
committed by PyTorch MergeBot
parent 7f9533e224
commit 5df59f957f
12 changed files with 70 additions and 77 deletions
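
For context: flake8-logging-format's G001, G002, and G003 rules flag logging calls that build their message eagerly with str.format(), %-interpolation, or string concatenation. The fix applied throughout this diff is to pass a %-style format string plus its arguments to the logger, so the message is rendered only if a handler actually emits the record. A minimal sketch of the pattern, using made-up names rather than code from any file in this diff:

import logging

log = logging.getLogger(__name__)
model_name, batch_size = "resnet50", 32

# Patterns flagged by flake8-logging-format -- the string is formatted
# even when the log level would discard the record:
log.info("batch size for {} is {}".format(model_name, batch_size))  # G001: str.format()
log.info("batch size for %s is %s" % (model_name, batch_size))      # G002: % interpolation
log.info("batch size for " + model_name)                            # G003: concatenation

# Preferred: hand the template and arguments to logging, which
# formats lazily, only when the record is actually emitted:
log.info("batch size for %s is %s", model_name, batch_size)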

View File

@@ -16,7 +16,7 @@ ignore =
     # these ignores are from flake8-comprehensions; please fix!
     C407
     # these ignores are from flake8-logging-format; please fix!
-    G001,G002,G003,G004,G100,G101,G200,G201,G202,
+    G004,G100,G101,G200,G201,G202
     # these ignores are from flake8-simplify. please fix or ignore with commented reason
     SIM105,SIM108,SIM109,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
     # flake8-simplify code styles

View File

@@ -891,7 +891,7 @@ def read_batch_size_from_file(args, filename, model_name):
             if model_name == cur_name:
                 batch_size = int(b)
     if batch_size is None:
-        log.warning("Could not find batch size for {}".format(model_name))
+        log.warning("Could not find batch size for %s", model_name)
     elif batch_size == -1:
         raise RuntimeError(
             f"Batch size is unset for {model_name} in {args.batch_size_file}"

View File

@@ -31,7 +31,7 @@ if __name__ == "__main__":
         work = process_group.allreduce(torch.rand(10).cuda(rank))
         logging.info('Waiting for allreduce to complete...')
         work.wait()
-        logging.info('Second allreduce successful: {}'.format(work.is_success()))
+        logging.info('Second allreduce successful: %s', work.is_success())
     else:
         logging.info('Aborting all other ranks.')
         os.abort()

View File

@@ -45,9 +45,11 @@ def pretty_print_buckets(buckets: List[Bucket]):
     try:
         from tabulate import tabulate
+        # TODO: Do you really want to log.info this? It would get
+        # suppressed if log level is too low
         log.info(
-            "\nDDPOptimizer bucket assignments\n"
-            + tabulate(rows, headers=headers, tablefmt="simple_grid")
+            "\nDDPOptimizer bucket assignments\n%s",
+            tabulate(rows, headers=headers, tablefmt="simple_grid"),
         )
     except ImportError:
         log.info(
@@ -318,9 +320,7 @@ class DDPOptimizer:
             else:
                 curr_submod = real_mod
-            log.debug(
-                f"\n---{n.target} graph---\n" + str(curr_submod.graph)
-            )
+            log.debug(f"\n---{n.target} graph---\n{curr_submod.graph}")
             # When calling the compiler on the submod, inputs (new_args) are expected to
             # be FakeTensors already since Dynamo would have made them FakeTensors in the
@@ -348,5 +348,5 @@ class DDPOptimizer:
         submod_compiler.run(*example_inputs)
         split_gm.recompile()
-        log.debug("\n---final graph---\n" + str(split_gm.graph) + "\n---------------\n")
+        log.debug(f"\n---final graph---\n{split_gm.graph}\n---------------\n")
         return split_gm

View File

@@ -262,9 +262,9 @@ def convert_frame_assert(
         assert code in guard_failures, "TODO(whc) any other recompile reasons?"
         log.warning(
             f"torch._dynamo hit config.cache_size_limit ({config.cache_size_limit})\n"
-            + f" function: {format_func_info(code)}\n"
-            + f" reasons: {format_guard_failures(code)}\n"
-            + f"to diagnose recompilation issues, see {troubleshooting_url}."
+            f" function: {format_func_info(code)}\n"
+            f" reasons: {format_guard_failures(code)}\n"
+            f"to diagnose recompilation issues, see {troubleshooting_url}."
         )
         unimplemented("cache_size_limit reached")

View File

@@ -307,7 +307,7 @@ or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib6
         find_tc = self.add_lib_preload(lib_type="tcmalloc")
         if not find_tc:
             msg = f"{self.msg_lib_notfound} you can use \"conda install -c conda-forge gperftools\" to install {{0}}"
-            logger.warning(msg.format("TCmalloc", "tcmalloc"))
+            logger.warning(msg.format("TCmalloc", "tcmalloc"))  # noqa: G001
         else:
             logger.info("Use TCMalloc memory allocator")
@@ -315,7 +315,7 @@ or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib6
         find_je = self.add_lib_preload(lib_type="jemalloc")
         if not find_je:
             msg = f"{self.msg_lib_notfound} you can use \"conda install -c conda-forge jemalloc\" to install {{0}}"
-            logger.warning(msg.format("Jemalloc", "jemalloc"))
+            logger.warning(msg.format("Jemalloc", "jemalloc"))  # noqa: G001
         else:
             logger.info("Use JeMalloc memory allocator")
             self.set_env("MALLOC_CONF", "oversize_threshold:1,background_thread:true,metadata_thp:auto")
@@ -371,7 +371,7 @@ Value applied: {os.environ[env_name]}. Value ignored: {env_value}")
         find_iomp = self.add_lib_preload(lib_type="iomp5")
         if not find_iomp:
             msg = f"{self.msg_lib_notfound} you can use \"conda install mkl\" to install {{0}}"
-            logger.warning(msg.format("iomp", "iomp5"))
+            logger.warning(msg.format("iomp", "iomp5"))  # noqa: G001
         else:
             logger.info("Using Intel OpenMP")
             if set_kmp_affinity:
@@ -429,9 +429,9 @@ please make sure ninstances <= total_cores)")
         num_leftover_cores = ncore_per_node % args.ncores_per_instance
         if args.ncores_per_instance > ncore_per_node:
             # too many ncores_per_instance to skip cross-node cores
-            logger.warning("there are {} core(s) per socket, but you specify {} ncores_per_instance and \
+            logger.warning("there are %s core(s) per socket, but you specify %s ncores_per_instance and \
 skip_cross_node_cores. Please make sure --ncores-per-instance < core(s) per \
-socket".format(ncore_per_node, args.ncores_per_instance))
+socket", ncore_per_node, args.ncores_per_instance)
             exit(-1)
         elif num_leftover_cores == 0:
             # aren't any cross-node cores

View File

@@ -36,7 +36,7 @@ class PostLocalSGDState:
         post_local_gradient_allreduce=True,
     ):
         logger.info(
-            "Local SGD will be started after {} iterations".format(start_localSGD_iter)
+            "Local SGD will be started after %s iterations", start_localSGD_iter
         )
         # The group used for all-reducing gradients globally.
@@ -58,7 +58,7 @@ class PostLocalSGDState:
         if self.iter == self.start_localSGD_iter:
             logger.info(
-                "Start to apply local SGD after {} iterations.".format(self.iter)
+                "Start to apply local SGD after %s iterations.", self.iter
             )

View File

@@ -106,8 +106,8 @@ def _report_compression_stats(bucket, state):
     ):
         stats = state.compression_stats()
         logger.info(
-            "Compression stats: iter {}, total before compression {}, total after compression {}, "
-            "rate {}".format(state.iter, stats[1], stats[2], stats[0])
+            "Compression stats: iter %s, total before compression %s, total after compression %s, "
+            "rate %s", state.iter, stats[1], stats[2], stats[0]
         )
         state.next_stats_report = state.iter + state.compression_stats_logging_frequency
@@ -183,19 +183,18 @@ class PowerSGDState:
         batch_tensors_with_same_shape: bool = False,
     ):
         logger.info(
-            "PowerSGD config: matrix_approximation_rank = {}; start_powerSGD_iter = {}; "
-            "min_compression_rate = {}; orthogonalization_epsilon = {}; use_error_feedback = {}; warm_start = {}; "
-            "random_seed = {}; compression_stats_logging_frequency = {}; batch_tensors_with_same_shape = {}".format(
+            "PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; "
+            "min_compression_rate = %s; orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; "
+            "random_seed = %s; compression_stats_logging_frequency = %s; batch_tensors_with_same_shape = %s",
             matrix_approximation_rank,
             start_powerSGD_iter,
             min_compression_rate,
             orthogonalization_epsilon,
             use_error_feedback,
             warm_start,
             random_seed,
             compression_stats_logging_frequency,
             batch_tensors_with_same_shape,
-            )
         )
         self.process_group = process_group
@@ -300,7 +299,7 @@ class PowerSGDState:
         if self.iter == self.start_powerSGD_iter:
             logger.info(
-                "Start to apply PowerSGD after {} iterations.".format(self.iter)
+                "Start to apply PowerSGD after %s iterations.", self.iter
             )
     def compression_stats(self):
@@ -409,9 +408,8 @@ def powerSGD_hook(
             input_tensor.add_(state.error_dict[bucket_index])
         else:
             logger.info(
-                "A zero tensor of length {} that represents local error is created.".format(
-                    total_length
-                )
+                "A zero tensor of length %s that represents local error is created.",
+                total_length
             )
             state.error_dict[bucket_index] = torch.zeros(
                 total_length, device=device, dtype=dtype
@@ -468,9 +466,8 @@ def powerSGD_hook(
         # Only log this if warm-start to avoid spamming.
         if state.warm_start:
             logger.info(
-                "Allocating contiguous memory of length {} for Ps, and of length {} for Qs, respectively.".format(
-                    total_Ps_size, total_Qs_size
-                )
+                "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.",
+                total_Ps_size, total_Qs_size
             )
             state.p_memory_dict[bucket_index] = torch.empty(
                 total_Ps_size, device=device, dtype=dtype
@@ -728,9 +725,8 @@ def batched_powerSGD_hook(
         input_tensor.add_(state.error_dict[bucket_index])
     else:
         logger.info(
-            "A zero tensor of length {} that represents local error is created.".format(
-                padded_total_length
-            )
+            "A zero tensor of length %s that represents local error is created.",
+            padded_total_length
         )
         state.error_dict[bucket_index] = torch.zeros(
             padded_total_length, device=device, dtype=input_tensor.dtype
@@ -749,9 +745,8 @@ def batched_powerSGD_hook(
         # Only log this if warm-start to avoid spamming.
         if state.warm_start:
             logger.info(
-                "Initializing low-rank tensors P and Q, each of which has a shape of {} x {}.".format(
-                    square_side_length, state.matrix_approximation_rank
-                )
+                "Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.",
+                square_side_length, state.matrix_approximation_rank
             )
     def create_low_rank_tensor(fill_random_values, rng):

View File

@@ -476,7 +476,7 @@ def _store_based_barrier(rank, store, timeout):
     """
     store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _world.group_count)
    store.add(store_key, 1)
-    logger.info("Added key: {} to store for rank: {}".format(store_key, rank))
+    logger.info("Added key: %s to store for rank: %s", store_key, rank)
     # Now wait for all workers to check in with the store.
     world_size = get_world_size()
@@ -496,9 +496,8 @@ def _store_based_barrier(rank, store, timeout):
         if timedelta(seconds=(time.time() - log_time)) > timedelta(seconds=10):
             logger.info(
                 "Waiting in store based barrier to initialize process group for "
-                "rank: {}, key: {} (world_size={}, worker_count={}, timeout={})".format(
-                    rank, store_key, world_size, worker_count, timeout
-                )
+                "rank: %s, key: %s (world_size=%s, worker_count=%s, timeout=%s)",
+                rank, store_key, world_size, worker_count, timeout
             )
             log_time = time.time()
@@ -3716,7 +3715,8 @@ def new_subgroups(
         if rank in ranks_in_subgroup:
             cur_subgroup = subgroup
             logger.info(
-                "Rank {} is assigned to subgroup {}".format(rank, ranks_in_subgroup)
+                "Rank %s is assigned to subgroup %s",
+                rank, ranks_in_subgroup
             )
     return cur_subgroup, subgroups
@@ -3828,7 +3828,7 @@ def new_subgroups_by_enumeration(
         rank_to_ranks_dict[rank] = ranks
         if my_rank == rank:
             cur_subgroup = subgroup
-            logger.info("Rank {} is assigned to subgroup {}".format(rank, ranks))
+            logger.info("Rank %s is assigned to subgroup %s", rank, ranks)
     return cur_subgroup, subgroups

View File

@@ -211,7 +211,7 @@ class EtcdRendezvous:
         last_call_timeout,
     ):
         self.client = client
-        log.info("Etcd machines: " + str(self.client.machines))
+        log.info("Etcd machines: %s", self.client.machines)
         self._prefix = prefix
         self._run_id = run_id
@@ -310,7 +310,7 @@ class EtcdRendezvous:
                 # to avoid spamming etcd
                 # FIXME: there are a few things that fall under this like
                 # etcd.EtcdKeyNotFound, etc, which could be handled more explicitly.
-                log.info("Rendezvous attempt failed, will retry. Reason: " + str(e))
+                log.info("Rendezvous attempt failed, will retry. Reason: %s", e)
                 time.sleep(1)
     def init_phase(self):
@@ -335,12 +335,12 @@ class EtcdRendezvous:
         try:
             active_version = self.try_create_rendezvous()
             state = json.loads(active_version.value)
-            log.info("New rendezvous state created: " + str(state))
+            log.info("New rendezvous state created: %s", state)
         except etcd.EtcdAlreadyExist:
             active_version, state = self.get_rdzv_state()
             # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound),
             # but this is ok for us - just means we'll restart from beginning.
-            log.info("Observed existing rendezvous state: " + str(state))
+            log.info("Observed existing rendezvous state: %s", state)
         if state["status"] == "closed":
             raise RendezvousClosedError()
@@ -365,9 +365,8 @@ class EtcdRendezvous:
         active_version, this_rank = self.join_rendezvous(expected_version)
         state = json.loads(active_version.value)
         log.info(
-            "Joined rendezvous version {} as rank {}. Full state: {}".format(
-                state["version"], this_rank, state
-            )
+            "Joined rendezvous version %s as rank %s. Full state: %s",
+            state["version"], this_rank, state
         )
         # If this worker was first to reach num_min_workers requirement,
@@ -380,10 +379,10 @@ class EtcdRendezvous:
         # when min_num_workers is reached.
         if this_rank == self._num_min_workers - 1 and state["status"] == "joinable":
-            log.info("Rank {} is responsible for join last call.".format(this_rank))
+            log.info("Rank %s is responsible for join last call.", this_rank)
             last_call_deadline = time.time() + self._last_call_timeout
             self.handle_join_last_call(expected_version, last_call_deadline)
-            log.info("Rank {} finished join last call.".format(this_rank))
+            log.info("Rank %s finished join last call.", this_rank)
         # Wait for rendezvous state to be frozen, which means a fixed set of peers
         log.info("Waiting for remaining peers.")
@@ -412,9 +411,8 @@
         state = json.loads(active_version.value)
         log.info(
-            "Rendezvous version {} is complete. Final state: {}".format(
-                state["version"], state
-            )
+            "Rendezvous version %s is complete. Final state: %s",
+            state["version"], state
         )
         # Rendezvous version number; our rank in it; world size
@@ -433,9 +431,8 @@
         # 2. if keep alives are missing, destroy it and bail out.
         active_state = self.announce_self_waiting(expected_version)
         log.info(
-            "Added self to waiting list. Rendezvous full state: {}".format(
-                active_state.value
-            )
+            "Added self to waiting list. Rendezvous full state: %s",
+            active_state.value
         )
         self.wait_for_rendezvous_to_free(expected_version)
@@ -698,9 +695,10 @@
             if key not in keep_alive_keys:
                 # This participant didn't renew their lease. We'll declare this
                 # rendezvous version as dead (but only if it hadn't changed)
-                log.info("Keep-alive key {} is not renewed.".format(key))
+                log.info("Keep-alive key %s is not renewed.", key)
                 log.info(
-                    "Rendevous version {} is incomplete. ".format(expected_version)
+                    "Rendevous version %s is incomplete. ",
+                    expected_version
                 )
                 log.info("Attempting to destroy it.")
@@ -713,9 +711,8 @@
                 )
                 log.info(
-                    "Destroyed rendezvous version {} successfully.".format(
-                        expected_version
-                    )
+                    "Destroyed rendezvous version %s successfully.",
+                    expected_version
                 )
                 # We can return (and retry) immediately

View File

@@ -161,8 +161,9 @@ class ShardedGradScaler(GradScaler):
             for tensor in grad:
                 if tensor.device != expected_device:
                     log.error(
-                        "tensor device is %s and expected device is %s"
-                        % (tensor.device, expected_device)
+                        "tensor device is %s and expected device is %s",
+                        tensor.device,
+                        expected_device,
                     )
                     raise ValueError("Gradients must be on the same device.")

View File

@@ -73,10 +73,10 @@ def _write(out_path, text):
         old_text = None
     if old_text != text:
         with open(out_path, "w") as f:
-            logger.info("Writing {}".format(out_path))
+            logger.info("Writing %s", out_path)
             f.write(text)
     else:
-        logger.info("Skipped writing {}".format(out_path))
+        logger.info("Skipped writing %s", out_path)
 def _do_instantiate_remote_module_template(