Add pyrefly suppressions to torch/distributed (7/n) (#165002)

Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

One more PR after this one.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002
Approved by: https://github.com/oulgen
This commit is contained in:
Maggie Moss
2025-10-09 04:08:21 +00:00
committed by PyTorch MergeBot
parent ab94a0d544
commit 7457d139c5
100 changed files with 354 additions and 24 deletions

View File

@ -21,7 +21,6 @@ project-excludes = [
# ==== below will be enabled directory by directory ====
# ==== to test Pyrefly on a specific directory, simply comment it out ====
"torch/_inductor/**",
"torch/distributed/**",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",

View File

@ -243,6 +243,7 @@ def derived_types(
def get_supported_param_types():
# pyrefly: ignore # bad-assignment
data: list[tuple[Union[type, typing._SpecialForm], str, bool, bool, bool]] = [
# (python type, schema type, type[] variant, type?[] variant, type[]? variant
(Tensor, "Tensor", True, True, False),

View File

@ -155,21 +155,25 @@ class cuBLASModule:
if name == "allow_tf32":
return torch._C._get_cublas_allow_tf32()
elif name == "allow_fp16_reduced_precision_reduction":
# pyrefly: ignore # not-iterable
allow_reduced_precision, _ = (
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
)
return allow_reduced_precision
elif name == "allow_fp16_reduced_precision_reduction_split_k":
# pyrefly: ignore # not-iterable
_, allow_splitk = (
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
)
return allow_splitk
elif name == "allow_bf16_reduced_precision_reduction":
# pyrefly: ignore # not-iterable
allow_reduced_precision, _ = (
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
)
return allow_reduced_precision
elif name == "allow_bf16_reduced_precision_reduction_split_k":
# pyrefly: ignore # not-iterable
_, allow_splitk = (
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
)
@ -188,14 +192,20 @@ class cuBLASModule:
value, "allow_fp16_reduced_precision_reduction"
)
return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(
allow_reduced_precision, allow_splitk
# pyrefly: ignore # bad-argument-count
allow_reduced_precision,
# pyrefly: ignore # bad-argument-count
allow_splitk,
)
elif name == "allow_bf16_reduced_precision_reduction":
allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
value, "allow_bf16_reduced_precision_reduction"
)
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(
allow_reduced_precision, allow_splitk
# pyrefly: ignore # bad-argument-count
allow_reduced_precision,
# pyrefly: ignore # bad-argument-count
allow_splitk,
)
elif name == "allow_fp16_accumulation":
return torch._C._set_cublas_allow_fp16_accumulation(value)

View File

@ -133,8 +133,9 @@ if is_available():
# Variables prefixed with underscore are not auto imported
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
# this.
# pyrefly: ignore # deprecated
from .distributed_c10d import * # noqa: F403
from .distributed_c10d import (
from .distributed_c10d import ( # pyrefly: ignore # deprecated
_all_gather_base,
_coalescing_manager,
_CoalescingManager,

View File

@ -107,6 +107,7 @@ def contract(
# If the user passes a sequence of modules, then we assume that
# we only need to insert the state object on the root modules
# (i.e. those without a parent) among the passed-in modules.
# pyrefly: ignore # no-matching-overload
modules = _get_root_modules(list(module))
state = state_cls() # shared across all modules
registry_item = RegistryItem() # shared across all modules
@ -118,6 +119,7 @@ def contract(
all_orig_named_buffers: list[dict[str, torch.Tensor]] = []
all_orig_named_modules: list[dict[str, nn.Module]] = []
# pyrefly: ignore # bad-assignment
for module in modules:
default_all_state: dict[Callable, _State] = OrderedDict()
default_registry: dict[str, RegistryItem] = OrderedDict()
@ -144,8 +146,11 @@ def contract(
all_state.setdefault(func, state)
registry.setdefault(func.__name__, registry_item)
# pyrefly: ignore # missing-attribute
all_orig_named_params.append(OrderedDict(module.named_parameters()))
# pyrefly: ignore # missing-attribute
all_orig_named_buffers.append(OrderedDict(module.named_buffers()))
# pyrefly: ignore # missing-attribute
all_orig_named_modules.append(OrderedDict(module.named_modules()))
updated = func(inp_module, *args, **kwargs)
@ -160,9 +165,13 @@ def contract(
all_new_named_params: list[dict[str, nn.Parameter]] = []
all_new_named_buffers: list[dict[str, torch.Tensor]] = []
all_new_named_modules: list[dict[str, nn.Module]] = []
# pyrefly: ignore # bad-assignment
for module in updated_modules:
# pyrefly: ignore # missing-attribute
all_new_named_params.append(OrderedDict(module.named_parameters()))
# pyrefly: ignore # missing-attribute
all_new_named_buffers.append(OrderedDict(module.named_buffers()))
# pyrefly: ignore # missing-attribute
all_new_named_modules.append(OrderedDict(module.named_modules()))
num_orig_modules = len(all_orig_named_modules)
@ -225,6 +234,7 @@ def contract(
# TODO: verify that installed distributed paradigms are compatible with
# each other.
# pyrefly: ignore # bad-return
return updated
def get_state(module: nn.Module) -> _State:

View File

@ -100,6 +100,7 @@ class _ReplicateState(FSDPState):
for module in modules:
_insert_module_state(module, self)
self._modules = modules
# pyrefly: ignore # read-only
self._device = device
self._device_handle = _get_device_handle(device.type)
self._mp_policy = mp_policy
@ -150,6 +151,7 @@ class _ReplicateState(FSDPState):
)
state._is_root = False
self._state_ctx.all_states.append(state)
# pyrefly: ignore # bad-argument-type
visited_states.add(state)
if self._fsdp_param_group and self._auto_reshard_after_forward:
# For the root, do not reshard after forward since for training,

View File

@ -31,6 +31,7 @@ def _get_module_state(module: nn.Module) -> Optional[_State]:
"""
global _module_state_mapping
if isinstance(module, _State):
# pyrefly: ignore # redundant-cast
return cast(_State, module)
else:
# https://github.com/pytorch/pytorch/issues/107054

View File

@ -633,6 +633,7 @@ class AsyncCollectiveTensor(torch.Tensor):
if func == torch.ops.aten.view.default:
# Fast handle aten.view as a lot of view related op goes to aten.view
# eventually, this avoids pytree slowdown
# pyrefly: ignore # index-error
res = func(args[0].elem, args[1])
wrapper_res = AsyncCollectiveTensor(res)
return wrapper_res
@ -786,6 +787,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
FutureWarning,
stacklevel=3,
)
# pyrefly: ignore # redundant-cast
return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag)
else:
raise ValueError(f"Unsupported group type: {type(group)}, {group}")
@ -1164,15 +1166,17 @@ def all_gather_inplace(
for t in tensor_list:
is_scalar = t.dim() == 0
t_offset = 1 if is_scalar else t.size(0)
# pyrefly: ignore # unsupported-operation
out = output[offset] if is_scalar else output[offset : offset + t_offset]
output_splits.append(out)
# pyrefly: ignore # unsupported-operation
offset += t_offset
for dst, src in zip(tensor_list, output_splits):
dst.copy_(src)
return tensor_list
from torch.distributed.distributed_c10d import (
from torch.distributed.distributed_c10d import ( # pyrefly: ignore # deprecated
_all_gather_base as legacy_all_gather_base,
_reduce_scatter_base as legacy_reduce_scatter_base,
all_gather as legacy_all_gather,

View File

@ -37,7 +37,9 @@ class _MeshLayout(Layout):
different from that of PyCute's.
"""
# pyrefly: ignore # bad-override
shape: IntTuple
# pyrefly: ignore # bad-override
stride: IntTuple
def __post_init__(self) -> None:

View File

@ -43,6 +43,7 @@ def _sharded_op_common(op, early_stop_func, extra_check):
def wrapper(types, args=(), kwargs=None, pg=None):
_basic_validation(op, args, kwargs)
# pyrefly: ignore # index-error
st = args[0]
if kwargs is None:
kwargs = {}
@ -92,6 +93,7 @@ def _register_sharded_op_on_local_shards(
@_sharded_op_impl(op)
@_sharded_op_common(op, early_stop_func, extra_check)
def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
# pyrefly: ignore # index-error
st = args[0]
st_metadata = st.metadata()
local_shards = st.local_shards()

View File

@ -20,10 +20,13 @@ def uniform_(types, args=(), kwargs=None, pg=None):
b: the upper bound of the uniform distribution
"""
validate_param(kwargs, "kwargs")
# pyrefly: ignore # unsupported-operation
sharded_tensor = kwargs["tensor"]
validate_param(sharded_tensor, "tensor")
# pyrefly: ignore # unsupported-operation
a = kwargs["a"]
validate_param(a, "a")
# pyrefly: ignore # unsupported-operation
b = kwargs["b"]
validate_param(b, "b")
@ -43,10 +46,13 @@ def normal_(types, args=(), kwargs=None, pg=None):
std: the standard deviation of the normal distribution
"""
validate_param(kwargs, "kwargs")
# pyrefly: ignore # unsupported-operation
sharded_tensor = kwargs["tensor"]
validate_param(sharded_tensor, "tensor")
# pyrefly: ignore # unsupported-operation
mean = kwargs["mean"]
validate_param(mean, "mean")
# pyrefly: ignore # unsupported-operation
std = kwargs["std"]
validate_param(std, "std")
@ -78,12 +84,16 @@ def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
"""
validate_param(kwargs, "kwargs")
# pyrefly: ignore # unsupported-operation
sharded_tensor = kwargs["tensor"]
validate_param(sharded_tensor, "tensor")
# pyrefly: ignore # unsupported-operation
a = kwargs["a"]
validate_param(a, "a")
# pyrefly: ignore # unsupported-operation
mode = kwargs["mode"]
validate_param(mode, "mode")
# pyrefly: ignore # unsupported-operation
nonlinearity = kwargs["nonlinearity"]
validate_param(nonlinearity, "nonlinearity")
@ -103,8 +113,10 @@ def constant_(types, args=(), kwargs=None, pg=None):
val: the value to fill the tensor with
"""
validate_param(kwargs, "kwargs")
# pyrefly: ignore # unsupported-operation
sharded_tensor = kwargs["tensor"]
validate_param(sharded_tensor, "tensor")
# pyrefly: ignore # unsupported-operation
val = kwargs["val"]
validate_param(val, "val")
for shard in sharded_tensor.local_shards():
@ -137,6 +149,7 @@ def register_tensor_creation_op(op):
if kwargs is None:
kwargs = {}
# pyrefly: ignore # index-error
st = args[0]
new_st = creation_op(st.sharding_spec(), st.size(), *args[1:], **kwargs) # type: ignore[operator]

View File

@ -40,6 +40,7 @@ _register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ig
# the device property on each rank
@_sharded_op_impl(torch.Tensor.device.__get__)
def tensor_device(types, args=(), kwargs=None, pg=None):
# pyrefly: ignore # index-error
self_st = args[0]
# Validate types
if not isinstance(self_st, ShardedTensor):
@ -56,6 +57,7 @@ def tensor_device(types, args=(), kwargs=None, pg=None):
@_sharded_op_impl(torch.Tensor.is_meta.__get__) # type: ignore[attr-defined]
def st_is_meta(types, args=(), kwargs=None, pg=None):
# pyrefly: ignore # index-error
return args[0].local_tensor().is_meta
@ -196,6 +198,7 @@ _register_sharded_op_on_local_shards(
@_sharded_op_impl(torch.Tensor.requires_grad_)
def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
# pyrefly: ignore # index-error
self_st = args[0]
# Validate types
if not isinstance(self_st, ShardedTensor):

View File

@ -299,7 +299,9 @@ class ShardedTensor(ShardedTensorBase):
if self._init_rrefs:
with _sharded_tensor_lock:
global _sharded_tensor_current_id, _sharded_tensor_map
# pyrefly: ignore # bad-assignment
self._sharded_tensor_id = _sharded_tensor_current_id
# pyrefly: ignore # unsupported-operation
_sharded_tensor_map[self._sharded_tensor_id] = weakref.ref(self)
_sharded_tensor_current_id += 1

View File

@ -208,6 +208,7 @@ def build_global_metadata(
global_sharded_tensor_metadata = None
global_metadata_rank = 0
# pyrefly: ignore # bad-assignment
for rank, rank_metadata in enumerate(gathered_metadatas):
if rank_metadata is None:
continue

View File

@ -167,6 +167,7 @@ class ChunkShardingSpec(ShardingSpec):
)
tensors_to_scatter[
# pyrefly: ignore # bad-argument-type
dist.get_group_rank(process_group, remote_global_rank)
] = tensor_to_scatter

View File

@ -58,6 +58,7 @@ def _register_sharded_op_on_local_tensor(
@custom_sharding_spec_op(ChunkShardingSpec, op)
@_sharded_op_common(op, early_stop_func, extra_check)
def sharded_tensor_op_on_local_tensor(types, args=(), kwargs=None, pg=None):
# pyrefly: ignore # index-error
st = args[0]
sharding_spec = st.sharding_spec()
if len(st.local_shards()) != 1:

View File

@ -425,7 +425,9 @@ def _handle_row_wise_sharding(
else:
split_sizes = torch.cat(
(
# pyrefly: ignore # unsupported-operation
offsets[1 : offsets.size(0)] - offsets[0:-1],
# pyrefly: ignore # unsupported-operation
(input.size(0) - offsets[-1]).unsqueeze(0),
),
dim=-1,

View File

@ -195,11 +195,13 @@ def _iterate_state_dict(
ret.local_shards()[idx].tensor, non_blocking=non_blocking
)
else:
# pyrefly: ignore # missing-attribute
companion_obj.copy_(ret, non_blocking=non_blocking)
ret = companion_obj
else:
ret = {} if isinstance(ret, dict) else None
# pyrefly: ignore # bad-return
return ret
@ -797,6 +799,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
)
else:
# pyrefly: ignore # bad-argument-type
extend_list(cur_container, prev_key)
if cur_container[prev_key] is None:
cur_container[prev_key] = def_val

View File

@ -1692,6 +1692,7 @@ def empty(
@overload
# pyrefly: ignore # inconsistent-overload
def empty(
size: Sequence[_int],
*,

View File

@ -231,6 +231,7 @@ class FSDPMemTracker(MemTracker):
" or file a github issue if you need this feature."
)
# pyrefly: ignore # bad-assignment
args, kwargs = orig_fsdp_state_pre_fw(*args, **kwargs)
fsdp_state = fsdp_mod._get_fsdp_state()
@ -364,6 +365,7 @@ class FSDPMemTracker(MemTracker):
# `FSDPParamGroup.post_forward` because during AC these won't be called.
# TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
# lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
# pyrefly: ignore # missing-attribute
for module in self._root_mod.modules():
if isinstance(module, FSDPModule):
fsdp_state = module._get_fsdp_state()
@ -372,6 +374,7 @@ class FSDPMemTracker(MemTracker):
fsdp_state._pre_forward_hook_handle.remove()
fsdp_state._post_forward_hook_handle.remove()
fsdp_state._pre_forward_hook_handle = (
# pyrefly: ignore # missing-attribute
module.register_forward_pre_hook(
self._fsdp_state_pre_forward(
module, fsdp_state._pre_forward
@ -380,6 +383,7 @@ class FSDPMemTracker(MemTracker):
with_kwargs=True,
)
)
# pyrefly: ignore # missing-attribute
fsdp_state._post_forward_hook_handle = module.register_forward_hook(
self._fsdp_state_post_forward(module, fsdp_state._post_forward),
prepend=False,
@ -398,6 +402,7 @@ class FSDPMemTracker(MemTracker):
)
)
# pyrefly: ignore # missing-attribute
for buffer in self._root_mod.buffers():
self._update_and_maybe_create_winfos(
buffer,
@ -507,6 +512,7 @@ class FSDPMemTracker(MemTracker):
):
# N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns
# a new tensor which does not happen in eager mode, when a wait_tensor is called.
# pyrefly: ignore # unsupported-operation
res = args[0]
else:
res = func(*args, **kwargs or {})
@ -523,6 +529,7 @@ class FSDPMemTracker(MemTracker):
_FSDPState.PRE_FW,
_FSDPState.PRE_BW,
]:
# pyrefly: ignore # unsupported-operation
output_tensor = args[0]
self._update_and_maybe_create_winfos(
output_tensor,
@ -533,6 +540,7 @@ class FSDPMemTracker(MemTracker):
func == c10d._reduce_scatter_base_.default
and self._fsdp_state == _FSDPState.POST_BW
):
# pyrefly: ignore # unsupported-operation
input_tensor = args[1]
self._update_and_maybe_create_winfos(
input_tensor,

View File

@ -143,6 +143,7 @@ class _WeakRefInfo:
self.size = size
self.element_size = element_size
self.reftype = reftype
# pyrefly: ignore # read-only
self.device = device
self.mem_consumed = self._calculate_mem_consumed()
@ -404,6 +405,7 @@ class MemTracker(TorchDispatchMode):
# Initialize a flag to track if the total memory might drop to zero after updates.
maybe_zero = False
# Ensure the device entry exists in the current memory snapshot, initializing if necessary.
# pyrefly: ignore # no-matching-overload
dev_snap = self._curr_mem_snap.setdefault(
winfo.device, dict.fromkeys(self._ref_class, 0)
)
@ -915,6 +917,7 @@ class MemTracker(TorchDispatchMode):
self._depth += 1
return self
# pyrefly: ignore # bad-override
def __exit__(self, *args: Any) -> None:
self._depth -= 1
if self._depth == 0:
@ -932,6 +935,7 @@ class MemTracker(TorchDispatchMode):
):
# N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns
# a new tensor which does not happen in eager mode, when a wait_tensor is called.
# pyrefly: ignore # index-error
res = args[0]
else:
res = func(*args, **kwargs or {})

View File

@ -232,7 +232,9 @@ class MemoryTracker:
def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
self._cur_module_name = f"{name}.forward"
if (
# pyrefly: ignore # invalid-argument
hasattr(module, "_memory_tracker_is_root")
# pyrefly: ignore # not-callable
and module._memory_tracker_is_root
):
self._add_marker("fw_start")
@ -248,7 +250,9 @@ class MemoryTracker:
outputs: Sequence[torch.Tensor],
) -> None:
if (
# pyrefly: ignore # invalid-argument
hasattr(module, "_memory_tracker_is_root")
# pyrefly: ignore # not-callable
and module._memory_tracker_is_root
):
self._add_marker("fw_bw_boundary")

View File

@ -178,6 +178,7 @@ class ModTracker:
def custom_formatwarning(msg, category, filename, lineno, line=None):
return f"{filename}:{lineno}: {category.__name__}: {msg} \n"
# pyrefly: ignore # bad-assignment
warnings.formatwarning = custom_formatwarning
warnings.warn(
"The module hierarchy tracking maybe be messed up."

View File

@ -519,6 +519,7 @@ class RuntimeEstimator(TorchDispatchMode):
super().__enter__()
return self
# pyrefly: ignore # bad-override
def __exit__(self, *args: Any) -> None:
print(
f"Estimated ({self._estimate_mode_type})"

View File

@ -429,6 +429,7 @@ class SACEstimator(TorchDispatchMode):
# sdpa has non-deterministic seed, but might be deterministic
# if no dropout is applied
if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention":
# pyrefly: ignore # missing-attribute
is_rand_op = kwargs.get("dropout_p", 0) != 0
# 5. Create metadata information per active non-leaf module
for mod_fqn in self._mod_tracker.parents:

View File

@ -65,6 +65,7 @@ def _dequantize_tensor(tensor, qtype, quant_loss=None):
elif tensor.dtype == torch.float16 and quant_loss is None:
return tensor.float()
else:
# pyrefly: ignore # unsupported-operation
return tensor.float() / quant_loss
elif qtype == DQuantType.BFP16:
if tensor.dtype != torch.float16:

View File

@ -22,6 +22,7 @@ def _allreduce_fut(
group_to_use = process_group if process_group is not None else dist.group.WORLD
# Apply the division first to avoid overflow, especially for FP16.
# pyrefly: ignore # missing-attribute
tensor.div_(group_to_use.size())
return (
@ -59,6 +60,7 @@ def _compress_hook(
bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
group_to_use = process_group if process_group is not None else dist.group.WORLD
# pyrefly: ignore # missing-attribute
world_size = group_to_use.size()
buffer = (
@ -78,7 +80,11 @@ def _compress_hook(
if torch.compiler.is_compiling():
grad = dist._functional_collectives.all_reduce(
compressed_tensor, "sum", group_to_use
# pyrefly: ignore # bad-argument-type
compressed_tensor,
"sum",
# pyrefly: ignore # bad-argument-type
group_to_use,
)
return decompress(grad)
else:

View File

@ -66,6 +66,7 @@ def quantization_pertensor_hook(
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
rank = process_group.rank() if process_group is not None else dist.get_rank()
# pyrefly: ignore # missing-attribute
world_size = group_to_use.size()
tensor = bucket.buffer()
@ -147,6 +148,7 @@ def quantization_perchannel_hook(
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
rank = process_group.rank() if process_group is not None else dist.get_rank()
# pyrefly: ignore # missing-attribute
world_size = group_to_use.size()
tensor = bucket.buffer()

View File

@ -210,6 +210,7 @@ class Join:
"""
process_group = None
device = None
# pyrefly: ignore # bad-assignment
for joinable in self._joinables:
if process_group is None:
process_group = joinable.join_process_group

View File

@ -74,6 +74,7 @@ class HybridModel(torch.nn.Module):
assert NUM_PS * EMBEDDING_DIM >= 512
dim_normalizer = int(NUM_PS * EMBEDDING_DIM / 512)
emb_lookups_reshaped = emb_lookups_cat.reshape( # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
[emb_lookups_cat.shape[0] * dim_normalizer, 512]
)

View File

@ -45,6 +45,7 @@ def _get_logging_handler(
return (log_handler, log_handler_name)
# pyrefly: ignore # unknown-name
global _c10d_logger
_c10d_logger = _get_or_create_logger()

View File

@ -12,6 +12,10 @@ from .metadata import (
from .optimizer import load_sharded_optimizer_state_dict
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
from .quantized_hf_storage import QuantizedHuggingFaceStorageReader
# pyrefly: ignore # deprecated
from .state_dict_loader import load, load_state_dict
# pyrefly: ignore # deprecated
from .state_dict_saver import async_save, save, save_state_dict
from .storage import StorageReader, StorageWriter

View File

@ -305,6 +305,7 @@ class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
@_dcp_method_logger(**ckpt_kwargs)
def create_checkpoint_daemon_process() -> None:
global _CHECKPOINT_PROCESS
# pyrefly: ignore # bad-argument-type
_CHECKPOINT_PROCESS = _AsyncCheckpointProcess(pg_init_info=pg_init_info)
create_checkpoint_daemon_process()

View File

@ -322,6 +322,7 @@ class CheckpointProcess:
subprocess_pid = self.process.processes[0].pid
# send graceful termination to sub process
try:
# pyrefly: ignore # missing-attribute
self._parent_end.send(
WorkerRequest(
request_type=RequestType.TERMINATE_PROCESS,

View File

@ -175,6 +175,7 @@ class CheckpointReader:
# create a new map with all the keys present in source_value
target_value = dict.fromkeys(source_value.keys())
# pyrefly: ignore # missing-attribute
for key in list(target_value.keys()):
current_path = f"{key_path}.{key}" if key_path else key
if key in source_value:

View File

@ -147,12 +147,14 @@ class DefaultStager(CheckpointStager):
self._staging_stream = None
if self._config.use_async_staging:
# pyrefly: ignore # bad-assignment
self._staging_executor = ThreadPoolExecutor(max_workers=1)
if torch.accelerator.is_available():
# Note: stream needs to be initialized on the main thread after default cuda
# stream is setup/used to avoid the risk of accidentally reusing the main
# compute stream or in other cases kernels actually launching from the
# main thread.
# pyrefly: ignore # bad-assignment
self._staging_stream = torch.Stream()
if self._config.use_non_blocking_copy:

View File

@ -94,6 +94,7 @@ class ZStandard(StreamTransformExtension):
return zstandard is not None or pyzstd is not None
@staticmethod
# pyrefly: ignore # bad-override
def from_descriptor(version: str) -> "ZStandard":
if version.partition(".")[0] != "1":
raise ValueError(f"Unknown extension {version=}")
@ -216,6 +217,7 @@ class ExtensionRegistry:
ext = self.extensions.get(name)
if not ext:
raise ValueError(f"Unknown extension {name=}")
# pyrefly: ignore # bad-argument-type
return ext.from_descriptor(version)
return [from_descriptor(desc) for desc in descriptors]

View File

@ -227,6 +227,7 @@ class PGTransport:
self._work: list[Work] = []
self._pg = pg
self._timeout = timeout
# pyrefly: ignore # read-only
self._device = device
self._state_dict = state_dict

View File

@ -128,6 +128,7 @@ def set_element(
CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
)
else:
# pyrefly: ignore # bad-argument-type
extend_list(cur_container, prev_key)
if cur_container[prev_key] is None:
cur_container[prev_key] = def_val
@ -154,6 +155,7 @@ def get_element(
elif not isinstance(cur_value, Mapping) or part not in cur_value:
return default_value
# pyrefly: ignore # index-error
cur_value = cast(CONTAINER_TYPE, cur_value[part])
return cast(Optional[T], cur_value)

View File

@ -60,6 +60,7 @@ def _init_model(rank, world_size):
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
_patch_model_state_dict(model)
# pyrefly: ignore # bad-argument-type
_patch_optimizer_state_dict(model, optimizers=optim)
return model, optim
@ -92,6 +93,7 @@ def run(rank, world_size):
loss_calc = torch.nn.BCELoss()
f = None
# pyrefly: ignore # bad-assignment
for epoch in range(NUM_EPOCHS):
try:
torch.manual_seed(epoch)

View File

@ -64,6 +64,7 @@ class BroadcastingTorchSaveReader(StorageReader):
self.checkpoint_id = checkpoint_id
self.coordinator_rank = coordinator_rank
# pyrefly: ignore # bad-override
def read_metadata(self) -> Metadata:
"""Extends the default StorageReader to support building the metadata file"""
# Metadata is built in planner.set_up_planner, since we are not actually reading metadata from
@ -102,6 +103,7 @@ class BroadcastingTorchSaveReader(StorageReader):
# Broadcast the tensor from the coordinator rank
if self.is_coordinator:
pg_device = dist.distributed_c10d._get_pg_default_device()
# pyrefly: ignore # unsupported-operation
tensor = torch_state_dict[req.storage_index.fqn].to(pg_device)
else:
tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn])
@ -121,6 +123,7 @@ class BroadcastingTorchSaveReader(StorageReader):
fut.set_result(None)
return fut
# pyrefly: ignore # bad-override
def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
"""Implementation of the StorageReader method"""
self.is_coordinator = is_coordinator

View File

@ -307,6 +307,7 @@ class HuggingFaceStorageReader(FileSystemReader):
fut.set_result(None)
return fut
# pyrefly: ignore # bad-override
def read_metadata(self) -> Metadata:
from safetensors import safe_open # type: ignore[import]
from safetensors.torch import _getdtype # type: ignore[import]

View File

@ -16,6 +16,7 @@ logger = logging.getLogger()
__all__: list[str] = []
# pyrefly: ignore # unknown-name
global _dcp_logger
_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
@ -36,9 +37,11 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
checkpoint_id = kwargs.get("checkpoint_id", None)
if not checkpoint_id and (serializer := storage_writer or storage_reader):
# pyrefly: ignore # unbound-name
checkpoint_id = getattr(serializer, "checkpoint_id", None)
msg_dict["checkpoint_id"] = (
# pyrefly: ignore # unsupported-operation
str(checkpoint_id) if checkpoint_id is not None else checkpoint_id
)

View File

@ -29,6 +29,8 @@ from torch.distributed.checkpoint.planner_helpers import (
_create_read_items,
create_read_items_for_chunk_list,
)
# pyrefly: ignore # deprecated
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.distributed.checkpoint.utils import (
@ -157,6 +159,7 @@ def _get_state_dict_2d_layout(
class _ReaderWithOffset(DefaultLoadPlanner):
translation: dict[MetadataIndex, MetadataIndex]
state_dict: STATE_DICT_TYPE
# pyrefly: ignore # bad-override
metadata: Metadata
def __init__(self, fqn_to_offset: dict[str, Sequence[int]]) -> None:

View File

@ -182,12 +182,14 @@ class DefaultStager(AsyncStager):
self._staging_executor = None
self._staging_stream = None
if self._config.use_async_staging:
# pyrefly: ignore # bad-assignment
self._staging_executor = ThreadPoolExecutor(max_workers=1)
if torch.accelerator.is_available():
# Note: stream needs to be initialized on the main thread after default cuda
# stream is setup/used to avoid the risk of accidentally reusing the main
# compute stream or in other cases kernels actually launching from the
# main thread.
# pyrefly: ignore # bad-assignment
self._staging_stream = torch.Stream()
if self._config.use_non_blocking_copy:
@ -348,6 +350,7 @@ class _ReplicationStager(AsyncStager):
):
self._pg = pg
self._timeout = timeout
# pyrefly: ignore # read-only
self._device = device
self._transport = PGTransport(pg, timeout, device, None)

View File

@ -199,6 +199,7 @@ def _get_fqns(
return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
if curr_obj_name != FSDP_WRAPPED_MODULE:
# pyrefly: ignore # bad-argument-type
fqn_obj_names.append(curr_obj_name)
curr_obj = getattr(curr_obj, curr_obj_name)
elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
@ -215,6 +216,7 @@ def _get_fqns(
):
if hasattr(curr_obj, removed_fqn):
curr_obj = getattr(curr_obj, removed_fqn)
# pyrefly: ignore # bad-argument-type
fqn_obj_names.append(curr_obj_name)
if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
if i != len(obj_names) - 1:
@ -1206,6 +1208,7 @@ def _unflatten_model_state_dict(
if not state_dict:
return {}
# pyrefly: ignore # no-matching-overload
if isinstance(next(iter(state_dict.keys())), nn.Module):
warnings.warn(
"Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``"

View File

@ -329,6 +329,7 @@ def async_save(
upload_future: Future = upload_executor.execute_save(
staging_future_or_state_dict,
checkpoint_id=checkpoint_id,
# pyrefly: ignore # bad-argument-type
storage_writer=storage_writer,
planner=planner,
process_group=process_group,

View File

@ -254,6 +254,7 @@ class _DistWrapper:
if len(node_failures) > 0:
result = CheckpointException(step, node_failures)
# pyrefly: ignore # bad-argument-type
final_result = self.broadcast_object(result)
if isinstance(final_result, CheckpointException):
raise final_result
@ -302,6 +303,7 @@ class _DistWrapper:
result = map_fun()
except BaseException as e: # noqa: B036
result = CheckpointException(step, {self.rank: _wrap_exception(e)})
# pyrefly: ignore # bad-argument-type
final_result = self.broadcast_object(result)
if isinstance(final_result, CheckpointException):
raise final_result

View File

@ -114,6 +114,7 @@ def broadcast(
error_msg += f": stage {sync_obj.stage_name}"
if sync_obj.exception is not None:
error_msg += f": exception {sync_obj.exception}"
# pyrefly: ignore # invalid-inheritance
raise RuntimeError(error_msg) from sync_obj.exception
return cast(T, sync_obj.payload)
@ -183,13 +184,16 @@ def all_gather(
if len(exception_list) > 0:
raise RuntimeError( # type: ignore[misc]
error_msg, exception_list
error_msg,
exception_list,
# pyrefly: ignore # invalid-inheritance
) from exception_list[0]
return ret_list
else:
if not sync_obj.success:
raise RuntimeError(
f"all_gather failed with exception {sync_obj.exception}",
# pyrefly: ignore # invalid-inheritance
) from sync_obj.exception
return [sync_obj.payload] # type: ignore[list-item]
@ -266,10 +270,13 @@ def _summarize_ranks(ranks: Iterable[int]) -> str:
result = []
for r in ranges:
if len(r) == 1:
# pyrefly: ignore # bad-argument-type
result.append(f"{r.start}")
elif r.step == 1:
# pyrefly: ignore # bad-argument-type
result.append(f"{r.start}:{r.stop}")
else:
# pyrefly: ignore # bad-argument-type
result.append(f"{r.start}:{r.stop}:{r.step}")
return ",".join(result)

View File

@ -482,6 +482,7 @@ else:
self._init_process_groups(backend_override)
if is_initialized() and get_backend() == "threaded":
# pyrefly: ignore # bad-assignment
self._thread_id = threading.get_ident()
if _rank is None:
@ -650,6 +651,7 @@ else:
# We temporarily revert the reuse subgroup, since it breaks two internal tests.
# Temporarily reverting to resolve test timeout while root-causing.
# TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
# pyrefly: ignore # unbound-name
if bound_device_id is None or not has_split_group:
dim_group = new_group(
ranks=subgroup_ranks,

View File

@ -372,6 +372,7 @@ class BackendConfig:
def __init__(self, backend: Backend):
"""Init."""
self.device_backend_map: dict[str, Backend] = {}
# pyrefly: ignore # bad-assignment
backend = str(backend)
if backend == Backend.UNDEFINED:
@ -392,6 +393,7 @@ class BackendConfig:
# e.g. "nccl", "gloo", "ucc", "mpi"
supported_devices = Backend.backend_capability[backend.lower()]
backend_val = Backend(backend)
# pyrefly: ignore # bad-assignment
self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
elif ":" in backend.lower():
# Backend specified in "device:backend" format
@ -410,6 +412,7 @@ class BackendConfig:
f"Invalid device:backend pairing: \
{device_backend_pair_str}. {backend_str_error_message}"
)
# pyrefly: ignore # bad-assignment
device, backend = device_backend_pair
if device in self.device_backend_map:
raise ValueError(
@ -1182,6 +1185,7 @@ def _as_iterable(obj) -> collections.abc.Iterable:
def _ensure_all_tensors_same_dtype(*tensors) -> None:
last_dtype = None
# pyrefly: ignore # bad-assignment
for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)):
tensor_dtype = tensor.dtype
# Mixing complex and its element type is allowed
@ -1837,6 +1841,7 @@ def _get_split_source(pg):
split_from = pg._get_backend(pg.bound_device_id)
elif pg is _world.default_pg:
try:
# pyrefly: ignore # missing-attribute
split_from = pg._get_backend(torch.device("cuda"))
except RuntimeError:
# no cuda device associated with this backend
@ -1997,7 +2002,12 @@ def _new_process_group_helper(
if not is_gloo_available():
raise RuntimeError("Distributed package doesn't have Gloo built in")
backend_class = ProcessGroupGloo(
backend_prefix_store, group_rank, group_size, timeout=timeout
# pyrefly: ignore # bad-argument-type
backend_prefix_store,
group_rank,
group_size,
# pyrefly: ignore # bad-argument-type
timeout=timeout,
)
backend_class.options.global_ranks_in_group = global_ranks_in_group
backend_class.options.group_name = group_name
@ -2018,6 +2028,7 @@ def _new_process_group_helper(
# default backend_options for NCCL
backend_options = ProcessGroupNCCL.Options()
backend_options.is_high_priority_stream = False
# pyrefly: ignore # bad-argument-type
backend_options._timeout = timeout
if split_from:
@ -2037,7 +2048,12 @@ def _new_process_group_helper(
# RuntimeError if is_ucc_available() returns false.
backend_class = ProcessGroupUCC(
backend_prefix_store, group_rank, group_size, timeout=timeout
# pyrefly: ignore # bad-argument-type
backend_prefix_store,
group_rank,
group_size,
# pyrefly: ignore # bad-argument-type
timeout=timeout,
)
backend_type = ProcessGroup.BackendType.UCC
elif backend_str == Backend.XCCL:
@ -2046,6 +2062,7 @@ def _new_process_group_helper(
backend_options = ProcessGroupXCCL.Options()
backend_options.global_ranks_in_group = global_ranks_in_group
backend_options.group_name = group_name
# pyrefly: ignore # bad-argument-type
backend_options._timeout = timeout
backend_class = ProcessGroupXCCL(
backend_prefix_store, group_rank, group_size, backend_options
@ -2070,6 +2087,7 @@ def _new_process_group_helper(
dist_backend_opts.store = backend_prefix_store
dist_backend_opts.group_rank = group_rank
dist_backend_opts.group_size = group_size
# pyrefly: ignore # bad-argument-type
dist_backend_opts.timeout = timeout
dist_backend_opts.group_id = group_name
dist_backend_opts.global_ranks_in_group = global_ranks_in_group
@ -2113,6 +2131,7 @@ def _new_process_group_helper(
store=backend_prefix_store,
rank=group_rank,
world_size=group_size,
# pyrefly: ignore # bad-argument-type
timeout=timeout,
)
@ -3322,6 +3341,7 @@ def gather_object(
return
assert object_gather_list is not None, "Must provide object_gather_list on dst rank"
# pyrefly: ignore # unbound-name
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
tensor_size = object_size_list[i]
@ -3698,8 +3718,10 @@ def broadcast_object_list(
# has only one element, we can skip the copy.
if my_group_rank == group_src:
if len(tensor_list) == 1: # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
object_tensor = tensor_list[0]
else:
# pyrefly: ignore # unbound-name
object_tensor = torch.cat(tensor_list)
else:
object_tensor = torch.empty( # type: ignore[call-overload]
@ -3828,6 +3850,7 @@ def scatter_object_list(
broadcast(max_tensor_size, group_src=group_src, group=group)
# Scatter actual serialized objects
# pyrefly: ignore # no-matching-overload
output_tensor = torch.empty(
max_tensor_size.item(), dtype=torch.uint8, device=pg_device
)
@ -4864,16 +4887,19 @@ def barrier(
if isinstance(device_ids, list):
opts.device_ids = device_ids
# use only the first device id
# pyrefly: ignore # read-only
opts.device = torch.device(device.type, device_ids[0])
elif getattr(group, "bound_device_id", None) is not None:
# Use device id from `init_process_group(device_id=...)`
opts.device = group.bound_device_id # type: ignore[assignment]
elif device.type == "cpu" or _get_object_coll_device(group) == "cpu":
# pyrefly: ignore # read-only
opts.device = torch.device("cpu")
else:
# Use the current device set by the user. If user did not set any, this
# may use default device 0, causing issues like hang or all processes
# creating context on device 0.
# pyrefly: ignore # read-only
opts.device = device
if group.rank() == 0:
warnings.warn( # warn only once
@ -5004,6 +5030,7 @@ def _hash_ranks_to_str(ranks: list[int]) -> str:
# Takes a list of ranks and computes an integer color
def _process_group_color(ranks: list[int]) -> int:
# Convert list to tuple to make it hashable
# pyrefly: ignore # bad-assignment
ranks = tuple(ranks)
hash_value = hash(ranks)
# Split color must be:

View File

@ -333,8 +333,10 @@ class LocalElasticAgent(SimpleElasticAgent):
rank=worker.global_rank,
local_rank=local_rank,
)
# pyrefly: ignore # unsupported-operation
log_line_prefixes[local_rank] = log_line_prefix
# pyrefly: ignore # unsupported-operation
envs[local_rank] = worker_env
worker_args = list(spec.args)
worker_args = macros.substitute(worker_args, str(local_rank))

View File

@ -54,6 +54,7 @@ class Event:
if isinstance(data, str):
data_dict = json.loads(data)
data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
return Event(**data_dict)
def serialize(self) -> str:
@ -108,6 +109,7 @@ class RdzvEvent:
if isinstance(data, str):
data_dict = json.loads(data)
data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
return RdzvEvent(**data_dict)
def serialize(self) -> str:

View File

@ -142,7 +142,7 @@ Now all metrics in the group ``my_app`` will be printed to stdout as:
from typing import Optional
from .api import ( # noqa: F401
from .api import ( # noqa: F401; pyrefly: ignore # deprecated; pyrefly: ignore # deprecated
configure,
ConsoleMetricHandler,
get_elapsed_time_ms,

View File

@ -171,12 +171,15 @@ def profile(group=None):
try:
start_time = time.time()
result = func(*args, **kwargs)
# pyrefly: ignore # bad-argument-type
publish_metric(group, f"{func.__name__}.success", 1)
except Exception:
# pyrefly: ignore # bad-argument-type
publish_metric(group, f"{func.__name__}.failure", 1)
raise
finally:
publish_metric(
# pyrefly: ignore # bad-argument-type
group,
f"{func.__name__}.duration.ms",
get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined]

View File

@ -97,6 +97,7 @@ class TailLog:
n = len(log_files)
self._threadpool = None
if n > 0:
# pyrefly: ignore # bad-assignment
self._threadpool = ThreadPoolExecutor(
max_workers=n,
thread_name_prefix=f"{self.__class__.__qualname__}_{name}",

View File

@ -126,6 +126,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
return tmp
def _decode_state(self, result: etcd.EtcdResult) -> tuple[bytes, Token]:
# pyrefly: ignore # missing-attribute
base64_state = result.value.encode()
try:
@ -135,6 +136,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
"The state object is corrupt. See inner exception for details."
) from exc
# pyrefly: ignore # missing-attribute
return state, result.modifiedIndex

View File

@ -53,6 +53,7 @@ class ElasticDistributedSampler(DistributedSampler[T]):
raise TypeError("Dataset must be an instance of collections.abc.Sized")
# Cast to Sized for mypy
# pyrefly: ignore # redundant-cast
sized_dataset = cast(Sized, dataset)
if start_index >= len(sized_dataset):

View File

@ -65,6 +65,7 @@ class _FSDPDeviceHandle:
if backend is None:
try:
self.__backend = getattr(torch, device.type)
# pyrefly: ignore # read-only
self.__device = device
except AttributeError as exc:
raise AttributeError(

View File

@ -539,6 +539,7 @@ class FlatParamHandle:
# Only align addresses for `use_orig_params=True` (for now)
align_addresses = use_orig_params
self._init_get_unflat_views_fn(align_addresses)
# pyrefly: ignore # read-only
self.device = device
self._device_handle = _FSDPDeviceHandle.from_device(self.device)
self.process_group = process_group

View File

@ -220,6 +220,7 @@ def _move_states_to_device(
the future.
"""
# Follow the logic in `nn.Module._apply`
# pyrefly: ignore # bad-argument-type
for tensor in itertools.chain(params, buffers):
if tensor.device == device or tensor.device.type == "meta":
# Keep meta-device tensors on meta device for deferred init

View File

@ -232,6 +232,7 @@ class FSDPParam:
self._module_info: ParamModuleInfo = module_info
self.mesh_info = mesh_info
self.post_forward_mesh_info = post_forward_mesh_info
# pyrefly: ignore # read-only
self.device = device
self.mp_policy = mp_policy
self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
@ -554,6 +555,7 @@ class FSDPParam:
f"world size ({shard_world_size})"
)
shard_rank = self.post_forward_mesh_info.shard_mesh_rank
# pyrefly: ignore # unbound-name
sharded_numel = numel // shard_world_size
self._sharded_post_forward_param_data = (
self.all_gather_outputs[0].narrow(
@ -684,6 +686,7 @@ class FSDPParam:
self.device, non_blocking=True
)
pre_all_gather_signature = inspect.signature(
# pyrefly: ignore # missing-attribute
sharded_local_tensor.fsdp_pre_all_gather
)
num_fn_params = len(pre_all_gather_signature.parameters)
@ -701,6 +704,7 @@ class FSDPParam:
(
all_gather_inputs,
self._extensions_data.all_gather_metadata,
# pyrefly: ignore # missing-attribute
) = sharded_local_tensor.fsdp_pre_all_gather(
self.shard_mesh_from_root
)
@ -708,6 +712,7 @@ class FSDPParam:
(
all_gather_inputs,
self._extensions_data.all_gather_metadata,
# pyrefly: ignore # missing-attribute
) = sharded_local_tensor.fsdp_pre_all_gather(
self.shard_mesh_from_root,
self._orig_size,
@ -829,6 +834,7 @@ class FSDPParam:
f"instead of {self.sharded_param}"
)
self.sharded_param = new_param
# pyrefly: ignore # missing-attribute
local_tensor = new_param._local_tensor
if local_tensor.is_meta:
return

View File

@ -151,6 +151,7 @@ class FSDPParamGroup:
]
self.mesh_info = mesh_info
self.post_forward_mesh_info = post_forward_mesh_info
# pyrefly: ignore # read-only
self.device = device
self.device_handle = _get_device_handle(device.type)
self.mp_policy = mp_policy
@ -616,6 +617,7 @@ class FSDPParamGroup:
# Prefetch naively using the reverse post-forward order, which may
# have mistargeted prefetches if not all modules used in forward
# are used in this backward
# pyrefly: ignore # unbound-name
target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
self._prefetch_unshard(target_fsdp_param_group, "backward")
@ -852,6 +854,7 @@ compile the forward part if you want to use Traceable FSDP2."""
raise RuntimeError(msg)
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
# All tensors in `inputs` should require gradient
RegisterPostBackwardFunction._assert_not_tracing_fsdp()

View File

@ -96,6 +96,7 @@ class FSDPState(_State):
for module in modules:
_insert_module_state(module, self)
self._modules = modules
# pyrefly: ignore # read-only
self._device = device
self._device_handle = _get_device_handle(device.type)
self._mp_policy = mp_policy

View File

@ -50,6 +50,7 @@ def get_cls_to_fsdp_cls() -> dict[type, type]:
@overload
# pyrefly: ignore # inconsistent-overload
def fully_shard(
module: nn.Module,
*,
@ -63,6 +64,7 @@ def fully_shard(
@overload
# pyrefly: ignore # inconsistent-overload
def fully_shard(
module: list[nn.Module],
*,

View File

@ -508,6 +508,7 @@ def _init_prefetching_state(
@no_type_check
# pyrefly: ignore # bad-function-definition
def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
# TODO: we need to add additional check once we support FSDP + PiPPy.
# This check is currently sufficient, since we only support FSDP + TP.
@ -904,7 +905,10 @@ def _materialize_meta_module(
# As a contract to the user, only call `reset_parameters()` if
# the module has directly managed parameters/buffers
module_state_iter = itertools.chain(
module.parameters(recurse=False), module.buffers(recurse=False)
# pyrefly: ignore # bad-argument-type
module.parameters(recurse=False),
# pyrefly: ignore # bad-argument-type
module.buffers(recurse=False),
)
has_module_states = len(list(module_state_iter)) > 0
if has_module_states:

View File

@ -603,6 +603,7 @@ def _flatten_optim_state(
]
# Check that the unflattened parameters have the same state names
state_names = None
# pyrefly: ignore # bad-assignment
for unflat_param_state in unflat_param_states:
if unflat_param_state is None:
continue
@ -918,6 +919,7 @@ def _rekey_sharded_optim_state_dict(
flat_param_key = unflat_param_names_to_flat_param_key.get(
key.unflat_param_names, key.unflat_param_names
)
# pyrefly: ignore # unsupported-operation
rekeyed_osd_state[flat_param_key] = param_state
# Only process param_groups if it exists in sharded_osd
@ -980,6 +982,7 @@ def _get_param_id_to_param_from_optim_input(
if optim_input is None:
return dict(enumerate(model.parameters()))
try:
# pyrefly: ignore # no-matching-overload
params = cast(list[nn.Parameter], list(optim_input))
except TypeError as e:
raise TypeError(

View File

@ -354,7 +354,9 @@ class _RemoteModule(nn.Module):
_raise_not_supported(self.to.__name__)
def register_backward_hook( # type: ignore[return]
self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]]
self,
hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]],
# pyrefly: ignore # bad-return
) -> RemovableHandle:
_raise_not_supported(self.register_backward_hook.__name__)
@ -369,6 +371,7 @@ class _RemoteModule(nn.Module):
],
prepend: bool = False,
with_kwargs: bool = False,
# pyrefly: ignore # bad-return
) -> RemovableHandle:
_raise_not_supported(self.register_forward_pre_hook.__name__)
@ -380,6 +383,7 @@ class _RemoteModule(nn.Module):
],
prepend: bool = False,
with_kwargs: bool = False,
# pyrefly: ignore # bad-return
) -> RemovableHandle:
_raise_not_supported(self.register_forward_hook.__name__)
@ -400,7 +404,11 @@ class _RemoteModule(nn.Module):
)
def named_parameters( # type: ignore[return]
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
self,
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True,
# pyrefly: ignore # bad-return
) -> Iterator[tuple[str, Parameter]]:
_raise_not_supported(self.named_parameters.__name__)
@ -408,7 +416,11 @@ class _RemoteModule(nn.Module):
_raise_not_supported(self.buffers.__name__)
def named_buffers( # type: ignore[return]
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
self,
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True,
# pyrefly: ignore # bad-return
) -> Iterator[tuple[str, Tensor]]:
_raise_not_supported(self.named_buffers.__name__)
@ -572,23 +584,31 @@ class _RemoteModule(nn.Module):
remote_module = object.__new__(RemoteModule)
# pyrefly: ignore # missing-attribute
enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device)
if _module_interface_cls is not None:
# Users reply on this field to know if this generated RemoteModule is TorchScript-able.
# pyrefly: ignore # missing-attribute
remote_module.is_scriptable = True
# pyrefly: ignore # missing-attribute
remote_module._init_template(
_module_interface_cls, enable_moving_cpu_tensors_to_cuda
)
else:
# pyrefly: ignore # missing-attribute
remote_module.is_scriptable = False
# pyrefly: ignore # missing-attribute
remote_module.generated_methods = (
_NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
)
# pyrefly: ignore # missing-attribute
remote_module.module_rref = module_rref
# pyrefly: ignore # missing-attribute
remote_module._install_generated_methods()
# pyrefly: ignore # missing-attribute
remote_module._check_attribute_picklability()
return remote_module
@ -691,9 +711,11 @@ def _remote_module_receiver(
m.__dict__.update(serialized_remote_module._asdict())
# Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method.
# pyrefly: ignore # missing-attribute
m.module_rref = rpc.PyRRef._deserialize(m.module_rref)
# Install generated methods when unpickled.
# pyrefly: ignore # missing-attribute
for method in m.generated_methods:
method_name = method.__name__
method = torch.jit.export(method)

View File

@ -225,6 +225,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD):
class _Broadcast(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, src, group, tensor):
ctx.src = src
ctx.group = group
@ -236,6 +237,7 @@ class _Broadcast(Function):
return tensor
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output)
if ctx.src != ctx.rank:
@ -245,6 +247,7 @@ class _Broadcast(Function):
class _Gather(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, dst, group, tensor):
ctx.dst = dst
ctx.group = group
@ -270,6 +273,7 @@ class _Gather(Function):
class _Scatter(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, src, group, *tensors):
ctx.src = src
ctx.group = group
@ -282,12 +286,14 @@ class _Scatter(Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)
class _Reduce(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, src, op, group, tensor):
ctx.src = src
ctx.group = group
@ -296,12 +302,14 @@ class _Reduce(Function):
return tensor
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)
class _Reduce_Scatter(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, op, group, tensor, *input_tensor_list):
ctx.group = group
# Need contiguous tensors for collectives.
@ -311,12 +319,14 @@ class _Reduce_Scatter(Function):
return tensor
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return (None, None, None) + _AllGather.apply(ctx.group, grad_output)
class _AllGather(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, group, tensor):
# Need contiguous tensors for collectives.
tensor = tensor.contiguous()
@ -346,12 +356,14 @@ class _AllGather(Function):
class _AllGatherBase(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, output_tensor, input_tensor, group):
ctx.group = group
dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
return output_tensor
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
world_size = dist.get_world_size(group=ctx.group)
@ -373,6 +385,7 @@ class _AllGatherBase(Function):
class _AlltoAll(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, group, out_tensor_list, *tensors):
ctx.group = group
ctx.input_tensor_size_list = [
@ -408,6 +421,7 @@ class _AlltoAll(Function):
class _AlltoAllSingle(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
ctx.group = group
ctx.input_size = input.size()
@ -423,6 +437,7 @@ class _AlltoAllSingle(Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
tensor = torch.empty(
ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
@ -440,6 +455,7 @@ class _AlltoAllSingle(Function):
class _AllReduce(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, op, group, tensor):
ctx.group = group
ctx.op = op
@ -448,5 +464,6 @@ class _AllReduce(Function):
return tensor
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)

View File

@ -100,15 +100,19 @@ def _broadcast_object(
data = bytearray(buffer.getbuffer())
length_tensor = torch.LongTensor([len(data)]).to(device)
data_send_tensor = torch.ByteTensor(data).to(device)
# pyrefly: ignore # bad-argument-type
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
# pyrefly: ignore # bad-argument-type
dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
else:
# Receive the object
length_tensor = torch.LongTensor([0]).to(device)
# pyrefly: ignore # bad-argument-type
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
data_recv_tensor = torch.empty(
[int(length_tensor.item())], dtype=torch.uint8, device=device
)
# pyrefly: ignore # bad-argument-type
dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
obj = torch.load(buffer, map_location=device, weights_only=False)
@ -167,6 +171,7 @@ class _DDPBucketAssignment:
if len(self.parameters) == 0:
raise ValueError("Empty bucket assignment")
# DDP guarantees all parameters in the bucket have the same device
# pyrefly: ignore # read-only
self.device: torch.device = self.parameters[0].device
self.tensor: Optional[torch.Tensor] = None
@ -415,7 +420,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
self.world_size: int = dist.get_world_size(self.process_group)
self.rank: int = dist.get_rank(self.process_group)
self.global_rank: int = dist.distributed_c10d.get_global_rank(
self.process_group, self.rank
# pyrefly: ignore # bad-argument-type
self.process_group,
self.rank,
)
self._overlap_with_ddp: bool = overlap_with_ddp
@ -535,7 +542,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
self._all_state_dicts = []
for rank in range(self.world_size):
global_rank = dist.distributed_c10d.get_global_rank(
self.process_group, rank
# pyrefly: ignore # bad-argument-type
self.process_group,
rank,
)
if self.rank == to:
# Consolidate all local `state_dict`s on this rank, storing on
@ -767,7 +776,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
for dev_i_buckets in self._buckets:
bucket = dev_i_buckets[rank]
global_rank = dist.distributed_c10d.get_global_rank(
self.process_group, rank
# pyrefly: ignore # bad-argument-type
self.process_group,
rank,
)
handles.append(
dist.broadcast(
@ -780,7 +791,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
else:
param_groups = self._partition_parameters()[rank]
global_rank = dist.distributed_c10d.get_global_rank(
self.process_group, rank
# pyrefly: ignore # bad-argument-type
self.process_group,
rank,
)
for param_group in param_groups:
handles.extend(
@ -979,11 +992,14 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
for param_index, param in enumerate(bucket_params):
param_numel = param.numel()
if (
# pyrefly: ignore # unbound-name
assignment_size + param_numel >= threshold
and param_index > bucket_offset
):
assigned_rank = self._get_min_index(
size_per_rank, assigned_ranks_per_bucket[bucket_index]
# pyrefly: ignore # unbound-name
size_per_rank,
assigned_ranks_per_bucket[bucket_index],
)
# Include up to but not including the parameter that
# exceeded the threshold
@ -994,6 +1010,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
assigned_rank,
assigned_ranks_per_bucket,
)
# pyrefly: ignore # unbound-name
size_per_rank[assigned_rank] += assignment_size
bucket_offset = param_index
assignment_size = 0
@ -1001,7 +1018,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
# Assign the remainder of the bucket so that no assignment
# spans across two buckets
assigned_rank = self._get_min_index(
size_per_rank, assigned_ranks_per_bucket[bucket_index]
# pyrefly: ignore # unbound-name
size_per_rank,
assigned_ranks_per_bucket[bucket_index],
)
self._assign_bucket_subset_to_rank(
bucket_index,
@ -1010,6 +1029,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
assigned_rank,
assigned_ranks_per_bucket,
)
# pyrefly: ignore # unbound-name
size_per_rank[assigned_rank] += assignment_size
return self._bucket_assignments_per_rank_cache
@ -1088,6 +1108,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
return loss
# pyrefly: ignore # bad-override
def step(
self,
closure: Optional[Callable[[], float]] = None,

View File

@ -282,6 +282,7 @@ class LossWrapper(torch.nn.Module):
class TrivialLossWrapper(LossWrapper):
# pyrefly: ignore # bad-override
def forward(self, x, targets):
model_out = self.module(x)
return self.loss_fn(model_out, targets)

View File

@ -245,6 +245,7 @@ def stage_backward_weight(
if non_none_grads:
summed_grad = sum(non_none_grads)
valid_edges.append(GradientEdge(intermediate, 0))
# pyrefly: ignore # bad-argument-type
valid_grad_outputs.append(summed_grad)
# Break a reference cycle caused inside stage_backward_input->get_hook->hook

View File

@ -81,6 +81,7 @@ def get_schedule_ops(
raise ValueError(f"Invalid schedule: {schedule_class}")
# Instantiate the schedule class
# pyrefly: ignore # bad-instantiation, bad-argument-type
schedule_instance = schedule_class(stages, num_microbatches)
assert schedule_instance.pipeline_order is not None

View File

@ -279,6 +279,7 @@ def _shard_dict_of_args(
f"Unsupported chunk spec: {spec} and value: {v} combination."
)
# pyrefly: ignore # no-matching-overload
for _flat_split_result, _v_split in zip(
flat_split_results, v_splits, strict=True
):

View File

@ -246,6 +246,7 @@ def _format_pipeline_order(
pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
]
# Transpose the list of lists (rows to columns)
# pyrefly: ignore # no-matching-overload
transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
# Generate column labels for ranks
num_ranks = len(pipeline_order)

View File

@ -155,6 +155,7 @@ class _PipelineStageBase(ABC):
self.submod = submodule
self.stage_index = stage_index
self.num_stages = num_stages
# pyrefly: ignore # read-only
self.device = device
self.group = group

View File

@ -36,12 +36,14 @@ class _remote_device:
elif isinstance(remote_device, str):
fields = remote_device.split("/")
if len(fields) == 2:
# pyrefly: ignore # bad-assignment
self._worker_name, self._device = fields
elif len(fields) == 1:
# Check if this is a valid device.
if _remote_device._is_valid_local_device(fields[0]):
self._device = fields[0]
else:
# pyrefly: ignore # bad-assignment
self._worker_name = fields[0]
self._device = "cpu"
else:
@ -63,6 +65,7 @@ class _remote_device:
# rank:<rank>/device format, extract rank
if fields[0] == "rank" and fields[1].isdigit():
self._rank = int(fields[1]) # type: ignore[assignment]
# pyrefly: ignore # bad-assignment
self._worker_name = None
else:
raise ValueError(PARSE_ERROR)

View File

@ -93,6 +93,7 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa
result = result._replace(
query=f"{'&'.join([f'{k}={v}' for k, v in query_dict.items()])}"
)
# pyrefly: ignore # bad-assignment
url = urlunparse(result)
if result.scheme not in _rendezvous_handlers:
@ -110,6 +111,7 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
if not isinstance(world_size, numbers.Integral):
raise RuntimeError(f"`world_size` must be an integer. {world_size}")
# pyrefly: ignore # bad-argument-type
return _rendezvous_helper(url, rank, world_size, **kwargs)

View File

@ -473,6 +473,7 @@ def _rref_typeof_on_user(
T = TypeVar("T")
# pyrefly: ignore # invalid-annotation
GenericWithOneTypeVar = Generic[T]
@ -719,6 +720,7 @@ def _invoke_rpc(
is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
if is_async_exec:
# pyrefly: ignore # missing-attribute
wrapped = func._wrapped_async_rpc_function
if isinstance(wrapped, torch.jit.ScriptFunction):
func = wrapped

View File

@ -95,6 +95,7 @@ def register_backend(
BackendType.__repr__ = _backend_type_repr # type: ignore[assignment]
if BackendType.__doc__:
BackendType.__doc__ = _backend_type_doc
# pyrefly: ignore # unsupported-operation
return BackendType[backend_name]

View File

@ -48,6 +48,7 @@ else:
_TensorPipeRpcBackendOptionsBase = object # type: ignore[assignment, misc]
# pyrefly: ignore # invalid-inheritance
class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
r"""
The backend options for

View File

@ -4,6 +4,8 @@
import itertools
import torch
# pyrefly: ignore # deprecated
from torch.autograd.profiler_legacy import profile
from . import (
@ -174,11 +176,13 @@ class _server_process_global_profile(profile):
flattened_function_events = list(
itertools.chain.from_iterable(process_global_function_events)
)
# pyrefly: ignore # bad-assignment
self.function_events = torch.autograd.profiler_util.EventList(
flattened_function_events,
use_device="cuda" if self.use_cuda else None,
profile_memory=self.profile_memory,
)
# pyrefly: ignore # missing-attribute
self.function_events._build_tree()
self.process_global_function_events = process_global_function_events

View File

@ -840,6 +840,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
) from e
logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
# pyrefly: ignore # bad-instantiation
logs_specs = logs_specs_cls(
log_dir=args.log_dir,
redirects=Std.from_str(args.redirects),

View File

@ -189,7 +189,10 @@ class OpDispatcher:
local_tensor_args = (
pytree.tree_unflatten(
cast(list[object], op_info.local_args), op_info.args_tree_spec
# pyrefly: ignore # bad-argument-type
cast(list[object], op_info.local_args),
# pyrefly: ignore # bad-argument-type
op_info.args_tree_spec,
)
if op_info.args_tree_spec
else op_info.local_args
@ -361,7 +364,11 @@ class OpDispatcher:
with redistribute_context:
resharded_local_tensor = redistribute_local_tensor(
local_tensor, arg_spec, reshard_arg_spec
# pyrefly: ignore # bad-argument-type
local_tensor,
arg_spec,
# pyrefly: ignore # bad-argument-type
reshard_arg_spec,
)
new_local_args.append(resharded_local_tensor)
else:
@ -431,7 +438,11 @@ class OpDispatcher:
op_call, args_list
)
kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
op_call, v, compute_mesh
# pyrefly: ignore # bad-argument-type
op_call,
v,
# pyrefly: ignore # bad-argument-type
compute_mesh,
)
local_kwargs[k] = v
else:
@ -447,6 +458,7 @@ class OpDispatcher:
OpSchema(
op_call,
(
# pyrefly: ignore # bad-argument-type
pytree.tree_unflatten(args_schema, args_spec)
if args_spec
else tuple(args_schema)

View File

@ -171,6 +171,7 @@ def einop_rule(
global_shape, input_spec.mesh, input_spec.placements
)
cost += prod(local_shape) * input_spec.mesh.size(mesh_dim)
# pyrefly: ignore # bad-argument-type
costs.append(cost)
d_to_keep_sharding = dims[costs.index(max(costs))]
for d in dims:

View File

@ -131,6 +131,7 @@ class _NormPartial(Partial):
if self.reduce_op == "sum":
assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}"
if self.norm_type != 0 and self.norm_type != 1:
# pyrefly: ignore # unsupported-operation
return tensor**self.norm_type
return tensor
@ -138,6 +139,7 @@ class _NormPartial(Partial):
if self.reduce_op == "sum":
assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}"
if self.norm_type != 0 and self.norm_type != 1:
# pyrefly: ignore # unsupported-operation
return tensor ** (1.0 / self.norm_type)
return tensor

View File

@ -1057,7 +1057,9 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy:
)
return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype)
# pyrefly: ignore # missing-attribute
mat1_meta = local_meta(mat1_strategy.strategies[0], input_specs[0].placements)
# pyrefly: ignore # missing-attribute
mat2_meta = local_meta(mat2_strategy.strategies[0], input_specs[1].placements)
def check_valid_strides(meta: TensorMeta) -> bool:

View File

@ -336,6 +336,7 @@ def expand_to_full_mesh_op_strategy(
for specs in zip(*strategy_comb):
if specs[0] is not None:
# TODO: we should fill in tensor_meta here. If nothing else, it helps the filter strategy callback
# pyrefly: ignore # bad-argument-type
spec_list.append(DTensorSpec(mesh, specs))
else:
spec_list.append(None)

View File

@ -150,6 +150,7 @@ class _RNGStateTracker:
"""
def __init__(self, device: torch.device):
# pyrefly: ignore # read-only
self._device = device
self._device_handle = _get_device_handle(self._device.type)
if not (self._device_handle and self._device_handle.is_available()):

View File

@ -256,8 +256,10 @@ def convolution_backward_handler(
kwargs: dict[str, object],
) -> object:
# Redistribute grad_output tensor to the same placement as input tensor
# pyrefly: ignore # bad-assignment
args = list(args)
assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor)
# pyrefly: ignore # unsupported-operation
args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements)
args = tuple(args)

View File

@ -594,6 +594,7 @@ class CommDebugMode(TorchDispatchMode):
self.advanced_module_tracker.__enter__()
return self
# pyrefly: ignore # bad-override
def __exit__(self, *args):
self.advanced_module_tracker.__exit__()
super().__exit__(*args)

View File

@ -90,6 +90,7 @@ def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=Fals
op_infos.sort(key=itemgetter(count_idx), reverse=True)
headers = ["Operator", "Schema", "Total Count", "Supported"]
# pyrefly: ignore # bad-argument-type
print(tabulate(op_infos, headers=headers))
if output_csv:
@ -101,4 +102,5 @@ def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=Fals
csv_writer.writerow(headers)
# Write each table row to the CSV file
for row in op_infos:
# pyrefly: ignore # bad-argument-type
csv_writer.writerow(row)

View File

@ -90,8 +90,12 @@ class LocalShardsWrapper(torch.Tensor):
# TODO: we shall continually extend this function to support more ops if needed
if func in supported_ops:
res_shards_list = [
func(shard, *args[1:], **kwargs) for shard in args[0].shards
# pyrefly: ignore # index-error
func(shard, *args[1:], **kwargs)
# pyrefly: ignore # index-error
for shard in args[0].shards
]
# pyrefly: ignore # index-error
return LocalShardsWrapper(res_shards_list, args[0].shard_offsets)
else:
raise NotImplementedError(
@ -141,6 +145,7 @@ def run_torchrec_row_wise_even_sharding_example(rank, world_size):
local_tensor = torch.randn(local_shard_shape, device=device)
# row-wise sharding: one shard per rank
# create the local shards wrapper
# pyrefly: ignore # no-matching-overload
local_shards_wrapper = LocalShardsWrapper(
local_shards=[local_tensor],
offsets=[local_shard_offset],
@ -219,6 +224,7 @@ def run_torchrec_row_wise_uneven_sharding_example(rank, world_size):
# local shards
# row-wise sharding: one shard per rank
# create the local shards wrapper
# pyrefly: ignore # no-matching-overload
local_shards_wrapper = LocalShardsWrapper(
local_shards=[local_tensor],
offsets=[local_shard_offset],
@ -297,6 +303,7 @@ def run_torchrec_table_wise_sharding_example(rank, world_size):
local_shard_offset = torch.Size((0, 0))
# wrap local shards into a wrapper
local_shards_wrapper = (
# pyrefly: ignore # no-matching-overload
LocalShardsWrapper(
local_shards=[local_tensor],
offsets=[local_shard_offset],

View File

@ -475,6 +475,7 @@ def _templated_ring_attention(
)
sdpa_merger.step(out, logsumexp, partial)
# pyrefly: ignore # unbound-name
return *sdpa_merger.results(), *rest
@ -641,6 +642,7 @@ def _templated_ring_attention_backward(
grad_query,
grad_key,
grad_value,
# pyrefly: ignore # unbound-name
*rest,
)
@ -988,7 +990,9 @@ def _distribute_function(
def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None:
"""Restore the function that is replaced by _distribute_function."""
# pyrefly: ignore # unknown-name
global _original_functions
# pyrefly: ignore # unknown-name
global _wrapper_functions
if fn not in _replaced_functions:
@ -1021,6 +1025,7 @@ def _context_parallel_dispatcher(
placement = [Shard(seq_dim)]
all_args = []
# pyrefly: ignore # bad-assignment, bad-argument-type
for arg in itertools.chain(args, kwargs.values()):
if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor):
arg = DTensor.from_local(arg, mesh, placement, run_check=False)

View File

@ -238,6 +238,7 @@ def _local_map_wrapped(
flat_local_args.append(arg)
# pyrefly: ignore # bad-argument-type
local_args = pytree.tree_unflatten(flat_local_args, args_spec)
out = func(*local_args, **kwargs)
@ -271,6 +272,7 @@ def _local_map_wrapped(
flat_dist_out.append(out)
# pyrefly: ignore # bad-argument-type
return pytree.tree_unflatten(flat_dist_out, out_spec)
else:
return out

View File

@ -237,8 +237,11 @@ def _mark_sharding(
op_schema,
)
placement_strategies[node] = OpSpec(
# pyrefly: ignore # bad-argument-type
output_specs=_get_output_spec_from_output_sharding(output_sharding),
# pyrefly: ignore # missing-attribute
input_specs=output_sharding.redistribute_schema.args_spec
# pyrefly: ignore # missing-attribute
if output_sharding.redistribute_schema is not None
else _get_input_node_specs(node, placement_strategies),
)

View File

@ -135,9 +135,11 @@ def _rewrite_spec_if_needed(
break
if rewrite:
spec = copy.deepcopy(spec)
# pyrefly: ignore # missing-attribute
for i, placement in enumerate(spec.placements):
placement = cast(_remote_device, placement)
if placement.rank() == rank and placement.device() != tensor.device:
# pyrefly: ignore # missing-attribute
spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")
return spec

View File

@ -251,6 +251,7 @@ def _nll_loss_forward(
if weight is not None:
new_shape = list(x.shape)
new_shape[channel_dim] = -1
# pyrefly: ignore # unbound-name
w = w.expand(new_shape)
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
wsum = torch.where(target != ignore_index, wsum, 0)
@ -308,7 +309,9 @@ def _nll_loss_forward_handler(
output_placements = all_replicate_placements
# tensor inputs to _propagate_tensor_meta need to be DTensors
# pyrefly: ignore # bad-assignment
args = list(args)
# pyrefly: ignore # unsupported-operation
args[1], args[2] = target, weight
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
@ -439,8 +442,11 @@ def _nll_loss_backward_handler(
weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
# tensor inputs to _propagate_tensor_meta need to be DTensors
# pyrefly: ignore # bad-assignment
args = list(args)
# pyrefly: ignore # unsupported-operation
args[2], args[3] = target, weight
# pyrefly: ignore # unsupported-operation
args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)

View File

@ -548,6 +548,7 @@ class PrepareModuleInput(ParallelStyle):
assert self.desired_input_layouts is not None, (
"desired module inputs should not be None!"
)
# pyrefly: ignore # no-matching-overload
for inp, input_layout, desired_layout in zip(
inputs, self.input_layouts, self.desired_input_layouts
):
@ -663,6 +664,7 @@ class PrepareModuleOutput(ParallelStyle):
raise ValueError(
"module outputs and output_layouts should have same length!"
)
# pyrefly: ignore # no-matching-overload
for out, out_layout, desired_out_layout in zip(
outputs, self.output_layouts, self.desired_output_layouts
):

View File

@ -59,6 +59,7 @@ def _cast_forward_inputs(
def cast_fn(x: torch.Tensor) -> torch.Tensor:
if not torch.is_floating_point(x) or x.dtype == dtype:
return x
# pyrefly: ignore # no-matching-overload
return x.to(dtype)
return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))
@ -133,12 +134,16 @@ def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
from torch.nn.parallel.scatter_gather import _is_namedtuple
if _is_namedtuple(obj):
# pyrefly: ignore # no-matching-overload
return [type(obj)(*args) for args in zip(*map(to_map, obj))]
if isinstance(obj, tuple) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return list(zip(*map(to_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return [list(i) for i in zip(*map(to_map, obj))]
if isinstance(obj, dict) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
return [obj]