Compare commits

...

15 Commits

SHA1 Message Date
e105c47187 remove more lines 2025-09-15 15:51:32 -07:00
bb9d89d32b update 2025-09-15 15:15:47 -07:00
259bb6f535 update 2025-09-15 15:09:42 -07:00
9c85b9f05b fix instance logger 2025-09-15 15:07:24 -07:00
ff61b90cde update 2025-09-15 14:58:56 -07:00
d1a640b146 update 2025-09-15 14:42:20 -07:00
074af6f329 update 2025-09-15 14:41:05 -07:00
cb0624f175 update 2025-09-15 14:33:54 -07:00
da969718ad update 2025-09-15 14:30:34 -07:00
c94cca7b59 update 2025-09-15 14:21:44 -07:00
fbdb9ba538 update 2025-09-15 14:16:49 -07:00
607f616b3e update 2025-09-15 14:15:07 -07:00
eca1bf31b7 update 2025-09-15 14:12:54 -07:00
606c46aee0 update 2025-09-15 14:07:28 -07:00
d093a0afd4 No Distributed Log Spew 2025-09-15 13:57:09 -07:00
4 changed files with 109 additions and 7 deletions

demo_no_spew.py (new file, +38 lines)

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""
Demo: No log spew with distributed logging patch.
Run with: torchrun --nproc_per_node=2 demo_no_spew.py
"""
import os
import warnings
import logging

import torch
import torch.distributed as dist

# Initialize distributed
if 'RANK' in os.environ:
    dist.init_process_group('gloo')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
else:
    rank = 0
    world_size = 1

print(f"=== Process {rank}/{world_size} ===")

# Test warnings
warnings.warn("This warning should only appear ONCE (from rank 0)")

# Test logging
logging.warning("This logging should only appear ONCE (from rank 0)")

# Test the original cpp_extension case
logging.getLogger('torch.utils.cpp_extension').setLevel(logging.DEBUG)
from torch.utils.cpp_extension import _get_cuda_arch_flags
_get_cuda_arch_flags()

print(f"Process {rank} completed")

if 'RANK' in os.environ:
    dist.destroy_process_group()

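A minimal sketch of launching the same check without torchrun, via torch.multiprocessing.spawn (the _worker helper and the hard-coded port are illustrative, not part of this change):

import os
import warnings

import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    # Hypothetical launcher: each spawned process joins the same gloo group.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # With the patch applied inside init_process_group, only rank 0 emits this.
    warnings.warn("This warning should only appear ONCE (from rank 0)")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
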
torch/distributed/_distributed_logging.py (new file, +62 lines)

@@ -0,0 +1,62 @@
"""
Minimal distributed logging to prevent log spew across ranks.
Only warnings and logging statements on rank 0 are emitted.
"""
import logging
import warnings
import torch
def _is_non_rank_zero():
"""Check if we should suppress output (non-rank-0 in distributed mode)."""
return (torch.distributed.is_available() and
torch.distributed.is_initialized() and
torch.distributed.get_rank() != 0)
def _make_rank_zero_only(original_func):
"""Create a wrapper that only executes on rank 0."""
def wrapper(*args, **kwargs):
if _is_non_rank_zero():
return
return original_func(*args, **kwargs)
return wrapper
def patch_logging_for_distributed(patch_print=False):
"""
Patch warnings and logging to only emit on rank 0 in distributed mode.
Args:
patch_print: If True, also patch print() to only emit on rank 0.
Default False since print is often used for debugging.
"""
original_warn = warnings.warn
def distributed_safe_warn(message, category=None, stacklevel=1, source=None):
if _is_non_rank_zero():
return
return original_warn(message, category, stacklevel + 1, source)
# Patch warnings.warn
warnings.warn = distributed_safe_warn
# Patch logging module functions
for method_name in ['debug', 'info', 'warning', 'warn', 'error', 'critical']:
if hasattr(logging, method_name):
original_func = getattr(logging, method_name)
setattr(logging, method_name, _make_rank_zero_only(original_func))
# Patch Logger class methods to catch logger instances
for method_name in ['debug', 'info', 'warning', 'warn', 'error', 'critical']:
if hasattr(logging.Logger, method_name):
original_method = getattr(logging.Logger, method_name)
setattr(logging.Logger, method_name, _make_rank_zero_only(original_method))
# Optionally patch print
if patch_print:
import builtins
original_print = builtins.print
builtins.print = _make_rank_zero_only(original_print)

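As a usage sketch, the patch can also be applied by hand (assuming the module ships as torch.distributed._distributed_logging, matching the import added to init_process_group below); it is a no-op in a single-process run because _is_non_rank_zero() is False whenever no process group is initialized:

import logging
import warnings

from torch.distributed._distributed_logging import patch_logging_for_distributed

# Without an initialized process group, the wrappers pass everything through.
patch_logging_for_distributed()

warnings.warn("still visible in a single-process run")
logging.getLogger(__name__).warning("still visible in a single-process run")
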
torch/distributed/distributed_c10d.py

@@ -1768,6 +1768,10 @@ def init_process_group(
     )
     _update_default_pg(default_pg)

+    # Enable distributed-safe logging to prevent log spew from non-rank-0 processes
+    from torch.distributed._distributed_logging import patch_logging_for_distributed
+    patch_logging_for_distributed()
+
     _world.pg_group_ranks[GroupMember.WORLD] = {  # type: ignore[index]
         i: i
         for i in range(GroupMember.WORLD.size())  # type: ignore[attr-defined]

torch/utils/cpp_extension.py

@@ -2448,13 +2448,11 @@ def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]:
         arch_list[-1] += '+PTX'

     if not _arch_list:
-        # Only log on rank 0 in distributed settings to avoid spam
-        if not torch.distributed.is_available() or not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-            arch_list_str = ';'.join(arch_list)
-            logger.debug(
-                "TORCH_CUDA_ARCH_LIST is not set, using TORCH_CUDA_ARCH_LIST='%s' "
-                "for visible GPU architectures. Set os.environ['TORCH_CUDA_ARCH_LIST'] to override.",
-                arch_list_str)
+        arch_list_str = ';'.join(arch_list)
+        logger.debug(
+            "TORCH_CUDA_ARCH_LIST is not set, using TORCH_CUDA_ARCH_LIST='%s' "
+            "for visible GPU architectures. Set os.environ['TORCH_CUDA_ARCH_LIST'] to override.",
+            arch_list_str)
     else:
         # Deal with lists that are ' ' separated (only deal with ';' after)
         _arch_list = _arch_list.replace(' ', ';')
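For comparison only (not part of this change): the narrower, per-logger version of the same idea can be expressed with a logging.Filter attached to the offending logger instead of monkeypatching, at the cost of not covering warnings.warn or loggers created elsewhere:

import logging

import torch.distributed as dist

class _RankZeroFilter(logging.Filter):
    """Drop records on non-rank-0 processes (hypothetical alternative sketch)."""
    def filter(self, record):
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
        return True

# Attach to the logger that originally spewed across ranks.
logging.getLogger("torch.utils.cpp_extension").addFilter(_RankZeroFilter())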