mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add pyrefly suppressions to torch/distributed (7/n) (#165002)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 One more PR after this one. Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: delete lines in the pyrefly.toml file from the project-excludes field step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199 after: INFO 0 errors (6,884 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002 Approved by: https://github.com/oulgen
This commit is contained in:
committed by
PyTorch MergeBot
parent
ab94a0d544
commit
7457d139c5
@ -21,7 +21,6 @@ project-excludes = [
|
||||
# ==== below will be enabled directory by directory ====
|
||||
# ==== to test Pyrefly on a specific directory, simply comment it out ====
|
||||
"torch/_inductor/**",
|
||||
"torch/distributed/**",
|
||||
# formatting issues
|
||||
"torch/linalg/__init__.py",
|
||||
"torch/package/importer.py",
|
||||
|
@ -243,6 +243,7 @@ def derived_types(
|
||||
|
||||
|
||||
def get_supported_param_types():
|
||||
# pyrefly: ignore # bad-assignment
|
||||
data: list[tuple[Union[type, typing._SpecialForm], str, bool, bool, bool]] = [
|
||||
# (python type, schema type, type[] variant, type?[] variant, type[]? variant
|
||||
(Tensor, "Tensor", True, True, False),
|
||||
|
@ -155,21 +155,25 @@ class cuBLASModule:
|
||||
if name == "allow_tf32":
|
||||
return torch._C._get_cublas_allow_tf32()
|
||||
elif name == "allow_fp16_reduced_precision_reduction":
|
||||
# pyrefly: ignore # not-iterable
|
||||
allow_reduced_precision, _ = (
|
||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
|
||||
)
|
||||
return allow_reduced_precision
|
||||
elif name == "allow_fp16_reduced_precision_reduction_split_k":
|
||||
# pyrefly: ignore # not-iterable
|
||||
_, allow_splitk = (
|
||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
|
||||
)
|
||||
return allow_splitk
|
||||
elif name == "allow_bf16_reduced_precision_reduction":
|
||||
# pyrefly: ignore # not-iterable
|
||||
allow_reduced_precision, _ = (
|
||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
|
||||
)
|
||||
return allow_reduced_precision
|
||||
elif name == "allow_bf16_reduced_precision_reduction_split_k":
|
||||
# pyrefly: ignore # not-iterable
|
||||
_, allow_splitk = (
|
||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
|
||||
)
|
||||
@ -188,14 +192,20 @@ class cuBLASModule:
|
||||
value, "allow_fp16_reduced_precision_reduction"
|
||||
)
|
||||
return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(
|
||||
allow_reduced_precision, allow_splitk
|
||||
# pyrefly: ignore # bad-argument-count
|
||||
allow_reduced_precision,
|
||||
# pyrefly: ignore # bad-argument-count
|
||||
allow_splitk,
|
||||
)
|
||||
elif name == "allow_bf16_reduced_precision_reduction":
|
||||
allow_reduced_precision, allow_splitk = self._parse_reduction_setting(
|
||||
value, "allow_bf16_reduced_precision_reduction"
|
||||
)
|
||||
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(
|
||||
allow_reduced_precision, allow_splitk
|
||||
# pyrefly: ignore # bad-argument-count
|
||||
allow_reduced_precision,
|
||||
# pyrefly: ignore # bad-argument-count
|
||||
allow_splitk,
|
||||
)
|
||||
elif name == "allow_fp16_accumulation":
|
||||
return torch._C._set_cublas_allow_fp16_accumulation(value)
|
||||
|
@ -133,8 +133,9 @@ if is_available():
|
||||
# Variables prefixed with underscore are not auto imported
|
||||
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
|
||||
# this.
|
||||
# pyrefly: ignore # deprecated
|
||||
from .distributed_c10d import * # noqa: F403
|
||||
from .distributed_c10d import (
|
||||
from .distributed_c10d import ( # pyrefly: ignore # deprecated
|
||||
_all_gather_base,
|
||||
_coalescing_manager,
|
||||
_CoalescingManager,
|
||||
|
@ -107,6 +107,7 @@ def contract(
|
||||
# If the user passes a sequence of modules, then we assume that
|
||||
# we only need to insert the state object on the root modules
|
||||
# (i.e. those without a parent) among the passed-in modules.
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
modules = _get_root_modules(list(module))
|
||||
state = state_cls() # shared across all modules
|
||||
registry_item = RegistryItem() # shared across all modules
|
||||
@ -118,6 +119,7 @@ def contract(
|
||||
all_orig_named_buffers: list[dict[str, torch.Tensor]] = []
|
||||
all_orig_named_modules: list[dict[str, nn.Module]] = []
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
for module in modules:
|
||||
default_all_state: dict[Callable, _State] = OrderedDict()
|
||||
default_registry: dict[str, RegistryItem] = OrderedDict()
|
||||
@ -144,8 +146,11 @@ def contract(
|
||||
all_state.setdefault(func, state)
|
||||
registry.setdefault(func.__name__, registry_item)
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
all_orig_named_params.append(OrderedDict(module.named_parameters()))
|
||||
# pyrefly: ignore # missing-attribute
|
||||
all_orig_named_buffers.append(OrderedDict(module.named_buffers()))
|
||||
# pyrefly: ignore # missing-attribute
|
||||
all_orig_named_modules.append(OrderedDict(module.named_modules()))
|
||||
|
||||
updated = func(inp_module, *args, **kwargs)
|
||||
@ -160,9 +165,13 @@ def contract(
|
||||
all_new_named_params: list[dict[str, nn.Parameter]] = []
|
||||
all_new_named_buffers: list[dict[str, torch.Tensor]] = []
|
||||
all_new_named_modules: list[dict[str, nn.Module]] = []
|
||||
# pyrefly: ignore # bad-assignment
|
||||
for module in updated_modules:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
all_new_named_params.append(OrderedDict(module.named_parameters()))
|
||||
# pyrefly: ignore # missing-attribute
|
||||
all_new_named_buffers.append(OrderedDict(module.named_buffers()))
|
||||
# pyrefly: ignore # missing-attribute
|
||||
all_new_named_modules.append(OrderedDict(module.named_modules()))
|
||||
|
||||
num_orig_modules = len(all_orig_named_modules)
|
||||
@ -225,6 +234,7 @@ def contract(
|
||||
# TODO: verify that installed distributed paradigms are compatible with
|
||||
# each other.
|
||||
|
||||
# pyrefly: ignore # bad-return
|
||||
return updated
|
||||
|
||||
def get_state(module: nn.Module) -> _State:
|
||||
|
@ -100,6 +100,7 @@ class _ReplicateState(FSDPState):
|
||||
for module in modules:
|
||||
_insert_module_state(module, self)
|
||||
self._modules = modules
|
||||
# pyrefly: ignore # read-only
|
||||
self._device = device
|
||||
self._device_handle = _get_device_handle(device.type)
|
||||
self._mp_policy = mp_policy
|
||||
@ -150,6 +151,7 @@ class _ReplicateState(FSDPState):
|
||||
)
|
||||
state._is_root = False
|
||||
self._state_ctx.all_states.append(state)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
visited_states.add(state)
|
||||
if self._fsdp_param_group and self._auto_reshard_after_forward:
|
||||
# For the root, do not reshard after forward since for training,
|
||||
|
@ -31,6 +31,7 @@ def _get_module_state(module: nn.Module) -> Optional[_State]:
|
||||
"""
|
||||
global _module_state_mapping
|
||||
if isinstance(module, _State):
|
||||
# pyrefly: ignore # redundant-cast
|
||||
return cast(_State, module)
|
||||
else:
|
||||
# https://github.com/pytorch/pytorch/issues/107054
|
||||
|
@ -633,6 +633,7 @@ class AsyncCollectiveTensor(torch.Tensor):
|
||||
if func == torch.ops.aten.view.default:
|
||||
# Fast handle aten.view as a lot of view related op goes to aten.view
|
||||
# eventually, this avoids pytree slowdown
|
||||
# pyrefly: ignore # index-error
|
||||
res = func(args[0].elem, args[1])
|
||||
wrapper_res = AsyncCollectiveTensor(res)
|
||||
return wrapper_res
|
||||
@ -786,6 +787,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
|
||||
FutureWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
# pyrefly: ignore # redundant-cast
|
||||
return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag)
|
||||
else:
|
||||
raise ValueError(f"Unsupported group type: {type(group)}, {group}")
|
||||
@ -1164,15 +1166,17 @@ def all_gather_inplace(
|
||||
for t in tensor_list:
|
||||
is_scalar = t.dim() == 0
|
||||
t_offset = 1 if is_scalar else t.size(0)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
out = output[offset] if is_scalar else output[offset : offset + t_offset]
|
||||
output_splits.append(out)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
offset += t_offset
|
||||
for dst, src in zip(tensor_list, output_splits):
|
||||
dst.copy_(src)
|
||||
return tensor_list
|
||||
|
||||
|
||||
from torch.distributed.distributed_c10d import (
|
||||
from torch.distributed.distributed_c10d import ( # pyrefly: ignore # deprecated
|
||||
_all_gather_base as legacy_all_gather_base,
|
||||
_reduce_scatter_base as legacy_reduce_scatter_base,
|
||||
all_gather as legacy_all_gather,
|
||||
|
@ -37,7 +37,9 @@ class _MeshLayout(Layout):
|
||||
different from that of PyCute's.
|
||||
"""
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
shape: IntTuple
|
||||
# pyrefly: ignore # bad-override
|
||||
stride: IntTuple
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
@ -43,6 +43,7 @@ def _sharded_op_common(op, early_stop_func, extra_check):
|
||||
def wrapper(types, args=(), kwargs=None, pg=None):
|
||||
_basic_validation(op, args, kwargs)
|
||||
|
||||
# pyrefly: ignore # index-error
|
||||
st = args[0]
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
@ -92,6 +93,7 @@ def _register_sharded_op_on_local_shards(
|
||||
@_sharded_op_impl(op)
|
||||
@_sharded_op_common(op, early_stop_func, extra_check)
|
||||
def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
|
||||
# pyrefly: ignore # index-error
|
||||
st = args[0]
|
||||
st_metadata = st.metadata()
|
||||
local_shards = st.local_shards()
|
||||
|
@ -20,10 +20,13 @@ def uniform_(types, args=(), kwargs=None, pg=None):
|
||||
b: the upper bound of the uniform distribution
|
||||
"""
|
||||
validate_param(kwargs, "kwargs")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
sharded_tensor = kwargs["tensor"]
|
||||
validate_param(sharded_tensor, "tensor")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
a = kwargs["a"]
|
||||
validate_param(a, "a")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
b = kwargs["b"]
|
||||
validate_param(b, "b")
|
||||
|
||||
@ -43,10 +46,13 @@ def normal_(types, args=(), kwargs=None, pg=None):
|
||||
std: the standard deviation of the normal distribution
|
||||
"""
|
||||
validate_param(kwargs, "kwargs")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
sharded_tensor = kwargs["tensor"]
|
||||
validate_param(sharded_tensor, "tensor")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
mean = kwargs["mean"]
|
||||
validate_param(mean, "mean")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
std = kwargs["std"]
|
||||
validate_param(std, "std")
|
||||
|
||||
@ -78,12 +84,16 @@ def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
|
||||
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
|
||||
"""
|
||||
validate_param(kwargs, "kwargs")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
sharded_tensor = kwargs["tensor"]
|
||||
validate_param(sharded_tensor, "tensor")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
a = kwargs["a"]
|
||||
validate_param(a, "a")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
mode = kwargs["mode"]
|
||||
validate_param(mode, "mode")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
nonlinearity = kwargs["nonlinearity"]
|
||||
validate_param(nonlinearity, "nonlinearity")
|
||||
|
||||
@ -103,8 +113,10 @@ def constant_(types, args=(), kwargs=None, pg=None):
|
||||
val: the value to fill the tensor with
|
||||
"""
|
||||
validate_param(kwargs, "kwargs")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
sharded_tensor = kwargs["tensor"]
|
||||
validate_param(sharded_tensor, "tensor")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
val = kwargs["val"]
|
||||
validate_param(val, "val")
|
||||
for shard in sharded_tensor.local_shards():
|
||||
@ -137,6 +149,7 @@ def register_tensor_creation_op(op):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# pyrefly: ignore # index-error
|
||||
st = args[0]
|
||||
|
||||
new_st = creation_op(st.sharding_spec(), st.size(), *args[1:], **kwargs) # type: ignore[operator]
|
||||
|
@ -40,6 +40,7 @@ _register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ig
|
||||
# the device property on each rank
|
||||
@_sharded_op_impl(torch.Tensor.device.__get__)
|
||||
def tensor_device(types, args=(), kwargs=None, pg=None):
|
||||
# pyrefly: ignore # index-error
|
||||
self_st = args[0]
|
||||
# Validate types
|
||||
if not isinstance(self_st, ShardedTensor):
|
||||
@ -56,6 +57,7 @@ def tensor_device(types, args=(), kwargs=None, pg=None):
|
||||
|
||||
@_sharded_op_impl(torch.Tensor.is_meta.__get__) # type: ignore[attr-defined]
|
||||
def st_is_meta(types, args=(), kwargs=None, pg=None):
|
||||
# pyrefly: ignore # index-error
|
||||
return args[0].local_tensor().is_meta
|
||||
|
||||
|
||||
@ -196,6 +198,7 @@ _register_sharded_op_on_local_shards(
|
||||
|
||||
@_sharded_op_impl(torch.Tensor.requires_grad_)
|
||||
def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
|
||||
# pyrefly: ignore # index-error
|
||||
self_st = args[0]
|
||||
# Validate types
|
||||
if not isinstance(self_st, ShardedTensor):
|
||||
|
@ -299,7 +299,9 @@ class ShardedTensor(ShardedTensorBase):
|
||||
if self._init_rrefs:
|
||||
with _sharded_tensor_lock:
|
||||
global _sharded_tensor_current_id, _sharded_tensor_map
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self._sharded_tensor_id = _sharded_tensor_current_id
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
_sharded_tensor_map[self._sharded_tensor_id] = weakref.ref(self)
|
||||
_sharded_tensor_current_id += 1
|
||||
|
||||
|
@ -208,6 +208,7 @@ def build_global_metadata(
|
||||
global_sharded_tensor_metadata = None
|
||||
global_metadata_rank = 0
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
for rank, rank_metadata in enumerate(gathered_metadatas):
|
||||
if rank_metadata is None:
|
||||
continue
|
||||
|
@ -167,6 +167,7 @@ class ChunkShardingSpec(ShardingSpec):
|
||||
)
|
||||
|
||||
tensors_to_scatter[
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
dist.get_group_rank(process_group, remote_global_rank)
|
||||
] = tensor_to_scatter
|
||||
|
||||
|
@ -58,6 +58,7 @@ def _register_sharded_op_on_local_tensor(
|
||||
@custom_sharding_spec_op(ChunkShardingSpec, op)
|
||||
@_sharded_op_common(op, early_stop_func, extra_check)
|
||||
def sharded_tensor_op_on_local_tensor(types, args=(), kwargs=None, pg=None):
|
||||
# pyrefly: ignore # index-error
|
||||
st = args[0]
|
||||
sharding_spec = st.sharding_spec()
|
||||
if len(st.local_shards()) != 1:
|
||||
|
@ -425,7 +425,9 @@ def _handle_row_wise_sharding(
|
||||
else:
|
||||
split_sizes = torch.cat(
|
||||
(
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
offsets[1 : offsets.size(0)] - offsets[0:-1],
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
(input.size(0) - offsets[-1]).unsqueeze(0),
|
||||
),
|
||||
dim=-1,
|
||||
|
@ -195,11 +195,13 @@ def _iterate_state_dict(
|
||||
ret.local_shards()[idx].tensor, non_blocking=non_blocking
|
||||
)
|
||||
else:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
companion_obj.copy_(ret, non_blocking=non_blocking)
|
||||
ret = companion_obj
|
||||
else:
|
||||
ret = {} if isinstance(ret, dict) else None
|
||||
|
||||
# pyrefly: ignore # bad-return
|
||||
return ret
|
||||
|
||||
|
||||
@ -797,6 +799,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
|
||||
CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
|
||||
)
|
||||
else:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
extend_list(cur_container, prev_key)
|
||||
if cur_container[prev_key] is None:
|
||||
cur_container[prev_key] = def_val
|
||||
|
@ -1692,6 +1692,7 @@ def empty(
|
||||
|
||||
|
||||
@overload
|
||||
# pyrefly: ignore # inconsistent-overload
|
||||
def empty(
|
||||
size: Sequence[_int],
|
||||
*,
|
||||
|
@ -231,6 +231,7 @@ class FSDPMemTracker(MemTracker):
|
||||
" or file a github issue if you need this feature."
|
||||
)
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
args, kwargs = orig_fsdp_state_pre_fw(*args, **kwargs)
|
||||
|
||||
fsdp_state = fsdp_mod._get_fsdp_state()
|
||||
@ -364,6 +365,7 @@ class FSDPMemTracker(MemTracker):
|
||||
# `FSDPParamGroup.post_forward` because during AC these won't be called.
|
||||
# TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
|
||||
# lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for module in self._root_mod.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
fsdp_state = module._get_fsdp_state()
|
||||
@ -372,6 +374,7 @@ class FSDPMemTracker(MemTracker):
|
||||
fsdp_state._pre_forward_hook_handle.remove()
|
||||
fsdp_state._post_forward_hook_handle.remove()
|
||||
fsdp_state._pre_forward_hook_handle = (
|
||||
# pyrefly: ignore # missing-attribute
|
||||
module.register_forward_pre_hook(
|
||||
self._fsdp_state_pre_forward(
|
||||
module, fsdp_state._pre_forward
|
||||
@ -380,6 +383,7 @@ class FSDPMemTracker(MemTracker):
|
||||
with_kwargs=True,
|
||||
)
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
fsdp_state._post_forward_hook_handle = module.register_forward_hook(
|
||||
self._fsdp_state_post_forward(module, fsdp_state._post_forward),
|
||||
prepend=False,
|
||||
@ -398,6 +402,7 @@ class FSDPMemTracker(MemTracker):
|
||||
)
|
||||
)
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for buffer in self._root_mod.buffers():
|
||||
self._update_and_maybe_create_winfos(
|
||||
buffer,
|
||||
@ -507,6 +512,7 @@ class FSDPMemTracker(MemTracker):
|
||||
):
|
||||
# N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns
|
||||
# a new tensor which does not happen in eager mode, when a wait_tensor is called.
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
res = args[0]
|
||||
else:
|
||||
res = func(*args, **kwargs or {})
|
||||
@ -523,6 +529,7 @@ class FSDPMemTracker(MemTracker):
|
||||
_FSDPState.PRE_FW,
|
||||
_FSDPState.PRE_BW,
|
||||
]:
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
output_tensor = args[0]
|
||||
self._update_and_maybe_create_winfos(
|
||||
output_tensor,
|
||||
@ -533,6 +540,7 @@ class FSDPMemTracker(MemTracker):
|
||||
func == c10d._reduce_scatter_base_.default
|
||||
and self._fsdp_state == _FSDPState.POST_BW
|
||||
):
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
input_tensor = args[1]
|
||||
self._update_and_maybe_create_winfos(
|
||||
input_tensor,
|
||||
|
@ -143,6 +143,7 @@ class _WeakRefInfo:
|
||||
self.size = size
|
||||
self.element_size = element_size
|
||||
self.reftype = reftype
|
||||
# pyrefly: ignore # read-only
|
||||
self.device = device
|
||||
self.mem_consumed = self._calculate_mem_consumed()
|
||||
|
||||
@ -404,6 +405,7 @@ class MemTracker(TorchDispatchMode):
|
||||
# Initialize a flag to track if the total memory might drop to zero after updates.
|
||||
maybe_zero = False
|
||||
# Ensure the device entry exists in the current memory snapshot, initializing if necessary.
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
dev_snap = self._curr_mem_snap.setdefault(
|
||||
winfo.device, dict.fromkeys(self._ref_class, 0)
|
||||
)
|
||||
@ -915,6 +917,7 @@ class MemTracker(TorchDispatchMode):
|
||||
self._depth += 1
|
||||
return self
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
self._depth -= 1
|
||||
if self._depth == 0:
|
||||
@ -932,6 +935,7 @@ class MemTracker(TorchDispatchMode):
|
||||
):
|
||||
# N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns
|
||||
# a new tensor which does not happen in eager mode, when a wait_tensor is called.
|
||||
# pyrefly: ignore # index-error
|
||||
res = args[0]
|
||||
else:
|
||||
res = func(*args, **kwargs or {})
|
||||
|
@ -232,7 +232,9 @@ class MemoryTracker:
|
||||
def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
|
||||
self._cur_module_name = f"{name}.forward"
|
||||
if (
|
||||
# pyrefly: ignore # invalid-argument
|
||||
hasattr(module, "_memory_tracker_is_root")
|
||||
# pyrefly: ignore # not-callable
|
||||
and module._memory_tracker_is_root
|
||||
):
|
||||
self._add_marker("fw_start")
|
||||
@ -248,7 +250,9 @@ class MemoryTracker:
|
||||
outputs: Sequence[torch.Tensor],
|
||||
) -> None:
|
||||
if (
|
||||
# pyrefly: ignore # invalid-argument
|
||||
hasattr(module, "_memory_tracker_is_root")
|
||||
# pyrefly: ignore # not-callable
|
||||
and module._memory_tracker_is_root
|
||||
):
|
||||
self._add_marker("fw_bw_boundary")
|
||||
|
@ -178,6 +178,7 @@ class ModTracker:
|
||||
def custom_formatwarning(msg, category, filename, lineno, line=None):
|
||||
return f"{filename}:{lineno}: {category.__name__}: {msg} \n"
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
warnings.formatwarning = custom_formatwarning
|
||||
warnings.warn(
|
||||
"The module hierarchy tracking maybe be messed up."
|
||||
|
@ -519,6 +519,7 @@ class RuntimeEstimator(TorchDispatchMode):
|
||||
super().__enter__()
|
||||
return self
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
print(
|
||||
f"Estimated ({self._estimate_mode_type})"
|
||||
|
@ -429,6 +429,7 @@ class SACEstimator(TorchDispatchMode):
|
||||
# sdpa has non-deterministic seed, but might be deterministic
|
||||
# if no dropout is applied
|
||||
if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention":
|
||||
# pyrefly: ignore # missing-attribute
|
||||
is_rand_op = kwargs.get("dropout_p", 0) != 0
|
||||
# 5. Create metadata information per active non-leaf module
|
||||
for mod_fqn in self._mod_tracker.parents:
|
||||
|
@ -65,6 +65,7 @@ def _dequantize_tensor(tensor, qtype, quant_loss=None):
|
||||
elif tensor.dtype == torch.float16 and quant_loss is None:
|
||||
return tensor.float()
|
||||
else:
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
return tensor.float() / quant_loss
|
||||
elif qtype == DQuantType.BFP16:
|
||||
if tensor.dtype != torch.float16:
|
||||
|
@ -22,6 +22,7 @@ def _allreduce_fut(
|
||||
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
||||
|
||||
# Apply the division first to avoid overflow, especially for FP16.
|
||||
# pyrefly: ignore # missing-attribute
|
||||
tensor.div_(group_to_use.size())
|
||||
|
||||
return (
|
||||
@ -59,6 +60,7 @@ def _compress_hook(
|
||||
bucket: dist.GradBucket,
|
||||
) -> torch.futures.Future[torch.Tensor]:
|
||||
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
||||
# pyrefly: ignore # missing-attribute
|
||||
world_size = group_to_use.size()
|
||||
|
||||
buffer = (
|
||||
@ -78,7 +80,11 @@ def _compress_hook(
|
||||
|
||||
if torch.compiler.is_compiling():
|
||||
grad = dist._functional_collectives.all_reduce(
|
||||
compressed_tensor, "sum", group_to_use
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
compressed_tensor,
|
||||
"sum",
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
group_to_use,
|
||||
)
|
||||
return decompress(grad)
|
||||
else:
|
||||
|
@ -66,6 +66,7 @@ def quantization_pertensor_hook(
|
||||
"""
|
||||
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
||||
rank = process_group.rank() if process_group is not None else dist.get_rank()
|
||||
# pyrefly: ignore # missing-attribute
|
||||
world_size = group_to_use.size()
|
||||
|
||||
tensor = bucket.buffer()
|
||||
@ -147,6 +148,7 @@ def quantization_perchannel_hook(
|
||||
"""
|
||||
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
||||
rank = process_group.rank() if process_group is not None else dist.get_rank()
|
||||
# pyrefly: ignore # missing-attribute
|
||||
world_size = group_to_use.size()
|
||||
|
||||
tensor = bucket.buffer()
|
||||
|
@ -210,6 +210,7 @@ class Join:
|
||||
"""
|
||||
process_group = None
|
||||
device = None
|
||||
# pyrefly: ignore # bad-assignment
|
||||
for joinable in self._joinables:
|
||||
if process_group is None:
|
||||
process_group = joinable.join_process_group
|
||||
|
@ -74,6 +74,7 @@ class HybridModel(torch.nn.Module):
|
||||
assert NUM_PS * EMBEDDING_DIM >= 512
|
||||
dim_normalizer = int(NUM_PS * EMBEDDING_DIM / 512)
|
||||
emb_lookups_reshaped = emb_lookups_cat.reshape( # type: ignore[possibly-undefined]
|
||||
# pyrefly: ignore # unbound-name
|
||||
[emb_lookups_cat.shape[0] * dim_normalizer, 512]
|
||||
)
|
||||
|
||||
|
@ -45,6 +45,7 @@ def _get_logging_handler(
|
||||
return (log_handler, log_handler_name)
|
||||
|
||||
|
||||
# pyrefly: ignore # unknown-name
|
||||
global _c10d_logger
|
||||
_c10d_logger = _get_or_create_logger()
|
||||
|
||||
|
@ -12,6 +12,10 @@ from .metadata import (
|
||||
from .optimizer import load_sharded_optimizer_state_dict
|
||||
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
|
||||
from .quantized_hf_storage import QuantizedHuggingFaceStorageReader
|
||||
|
||||
# pyrefly: ignore # deprecated
|
||||
from .state_dict_loader import load, load_state_dict
|
||||
|
||||
# pyrefly: ignore # deprecated
|
||||
from .state_dict_saver import async_save, save, save_state_dict
|
||||
from .storage import StorageReader, StorageWriter
|
||||
|
@ -305,6 +305,7 @@ class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
|
||||
@_dcp_method_logger(**ckpt_kwargs)
|
||||
def create_checkpoint_daemon_process() -> None:
|
||||
global _CHECKPOINT_PROCESS
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
_CHECKPOINT_PROCESS = _AsyncCheckpointProcess(pg_init_info=pg_init_info)
|
||||
|
||||
create_checkpoint_daemon_process()
|
||||
|
@ -322,6 +322,7 @@ class CheckpointProcess:
|
||||
subprocess_pid = self.process.processes[0].pid
|
||||
# send graceful termination to sub process
|
||||
try:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
self._parent_end.send(
|
||||
WorkerRequest(
|
||||
request_type=RequestType.TERMINATE_PROCESS,
|
||||
|
@ -175,6 +175,7 @@ class CheckpointReader:
|
||||
# create a new map with all the keys present in source_value
|
||||
target_value = dict.fromkeys(source_value.keys())
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for key in list(target_value.keys()):
|
||||
current_path = f"{key_path}.{key}" if key_path else key
|
||||
if key in source_value:
|
||||
|
@ -147,12 +147,14 @@ class DefaultStager(CheckpointStager):
|
||||
self._staging_stream = None
|
||||
|
||||
if self._config.use_async_staging:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self._staging_executor = ThreadPoolExecutor(max_workers=1)
|
||||
if torch.accelerator.is_available():
|
||||
# Note: stream needs to be initialized on the main thread after default cuda
|
||||
# stream is setup/used to avoid the risk of accidentally reusing the main
|
||||
# compute stream or in other cases kernels actually launching from the
|
||||
# main thread.
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self._staging_stream = torch.Stream()
|
||||
|
||||
if self._config.use_non_blocking_copy:
|
||||
|
@ -94,6 +94,7 @@ class ZStandard(StreamTransformExtension):
|
||||
return zstandard is not None or pyzstd is not None
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def from_descriptor(version: str) -> "ZStandard":
|
||||
if version.partition(".")[0] != "1":
|
||||
raise ValueError(f"Unknown extension {version=}")
|
||||
@ -216,6 +217,7 @@ class ExtensionRegistry:
|
||||
ext = self.extensions.get(name)
|
||||
if not ext:
|
||||
raise ValueError(f"Unknown extension {name=}")
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return ext.from_descriptor(version)
|
||||
|
||||
return [from_descriptor(desc) for desc in descriptors]
|
||||
|
@ -227,6 +227,7 @@ class PGTransport:
|
||||
self._work: list[Work] = []
|
||||
self._pg = pg
|
||||
self._timeout = timeout
|
||||
# pyrefly: ignore # read-only
|
||||
self._device = device
|
||||
self._state_dict = state_dict
|
||||
|
||||
|
@ -128,6 +128,7 @@ def set_element(
|
||||
CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
|
||||
)
|
||||
else:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
extend_list(cur_container, prev_key)
|
||||
if cur_container[prev_key] is None:
|
||||
cur_container[prev_key] = def_val
|
||||
@ -154,6 +155,7 @@ def get_element(
|
||||
elif not isinstance(cur_value, Mapping) or part not in cur_value:
|
||||
return default_value
|
||||
|
||||
# pyrefly: ignore # index-error
|
||||
cur_value = cast(CONTAINER_TYPE, cur_value[part])
|
||||
return cast(Optional[T], cur_value)
|
||||
|
||||
|
@ -60,6 +60,7 @@ def _init_model(rank, world_size):
|
||||
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
|
||||
|
||||
_patch_model_state_dict(model)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
_patch_optimizer_state_dict(model, optimizers=optim)
|
||||
|
||||
return model, optim
|
||||
@ -92,6 +93,7 @@ def run(rank, world_size):
|
||||
loss_calc = torch.nn.BCELoss()
|
||||
|
||||
f = None
|
||||
# pyrefly: ignore # bad-assignment
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
try:
|
||||
torch.manual_seed(epoch)
|
||||
|
@ -64,6 +64,7 @@ class BroadcastingTorchSaveReader(StorageReader):
|
||||
self.checkpoint_id = checkpoint_id
|
||||
self.coordinator_rank = coordinator_rank
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def read_metadata(self) -> Metadata:
|
||||
"""Extends the default StorageReader to support building the metadata file"""
|
||||
# Metadata is built in planner.set_up_planner, since we are not actually reading metadata from
|
||||
@ -102,6 +103,7 @@ class BroadcastingTorchSaveReader(StorageReader):
|
||||
# Broadcast the tensor from the coordinator rank
|
||||
if self.is_coordinator:
|
||||
pg_device = dist.distributed_c10d._get_pg_default_device()
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
tensor = torch_state_dict[req.storage_index.fqn].to(pg_device)
|
||||
else:
|
||||
tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn])
|
||||
@ -121,6 +123,7 @@ class BroadcastingTorchSaveReader(StorageReader):
|
||||
fut.set_result(None)
|
||||
return fut
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
|
||||
"""Implementation of the StorageReader method"""
|
||||
self.is_coordinator = is_coordinator
|
||||
|
@ -307,6 +307,7 @@ class HuggingFaceStorageReader(FileSystemReader):
|
||||
fut.set_result(None)
|
||||
return fut
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def read_metadata(self) -> Metadata:
|
||||
from safetensors import safe_open # type: ignore[import]
|
||||
from safetensors.torch import _getdtype # type: ignore[import]
|
||||
|
@ -16,6 +16,7 @@ logger = logging.getLogger()
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
# pyrefly: ignore # unknown-name
|
||||
global _dcp_logger
|
||||
_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
|
||||
|
||||
@ -36,9 +37,11 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
|
||||
|
||||
checkpoint_id = kwargs.get("checkpoint_id", None)
|
||||
if not checkpoint_id and (serializer := storage_writer or storage_reader):
|
||||
# pyrefly: ignore # unbound-name
|
||||
checkpoint_id = getattr(serializer, "checkpoint_id", None)
|
||||
|
||||
msg_dict["checkpoint_id"] = (
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
str(checkpoint_id) if checkpoint_id is not None else checkpoint_id
|
||||
)
|
||||
|
||||
|
@ -29,6 +29,8 @@ from torch.distributed.checkpoint.planner_helpers import (
|
||||
_create_read_items,
|
||||
create_read_items_for_chunk_list,
|
||||
)
|
||||
|
||||
# pyrefly: ignore # deprecated
|
||||
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
|
||||
from torch.distributed.checkpoint.storage import StorageReader
|
||||
from torch.distributed.checkpoint.utils import (
|
||||
@ -157,6 +159,7 @@ def _get_state_dict_2d_layout(
|
||||
class _ReaderWithOffset(DefaultLoadPlanner):
|
||||
translation: dict[MetadataIndex, MetadataIndex]
|
||||
state_dict: STATE_DICT_TYPE
|
||||
# pyrefly: ignore # bad-override
|
||||
metadata: Metadata
|
||||
|
||||
def __init__(self, fqn_to_offset: dict[str, Sequence[int]]) -> None:
|
||||
|
@ -182,12 +182,14 @@ class DefaultStager(AsyncStager):
|
||||
self._staging_executor = None
|
||||
self._staging_stream = None
|
||||
if self._config.use_async_staging:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self._staging_executor = ThreadPoolExecutor(max_workers=1)
|
||||
if torch.accelerator.is_available():
|
||||
# Note: stream needs to be initialized on the main thread after default cuda
|
||||
# stream is setup/used to avoid the risk of accidentally reusing the main
|
||||
# compute stream or in other cases kernels actually launching from the
|
||||
# main thread.
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self._staging_stream = torch.Stream()
|
||||
|
||||
if self._config.use_non_blocking_copy:
|
||||
@ -348,6 +350,7 @@ class _ReplicationStager(AsyncStager):
|
||||
):
|
||||
self._pg = pg
|
||||
self._timeout = timeout
|
||||
# pyrefly: ignore # read-only
|
||||
self._device = device
|
||||
self._transport = PGTransport(pg, timeout, device, None)
|
||||
|
||||
|
@ -199,6 +199,7 @@ def _get_fqns(
|
||||
return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
|
||||
curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
|
||||
if curr_obj_name != FSDP_WRAPPED_MODULE:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
fqn_obj_names.append(curr_obj_name)
|
||||
curr_obj = getattr(curr_obj, curr_obj_name)
|
||||
elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
|
||||
@ -215,6 +216,7 @@ def _get_fqns(
|
||||
):
|
||||
if hasattr(curr_obj, removed_fqn):
|
||||
curr_obj = getattr(curr_obj, removed_fqn)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
fqn_obj_names.append(curr_obj_name)
|
||||
if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
|
||||
if i != len(obj_names) - 1:
|
||||
@ -1206,6 +1208,7 @@ def _unflatten_model_state_dict(
|
||||
if not state_dict:
|
||||
return {}
|
||||
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
if isinstance(next(iter(state_dict.keys())), nn.Module):
|
||||
warnings.warn(
|
||||
"Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``"
|
||||
|
@ -329,6 +329,7 @@ def async_save(
|
||||
upload_future: Future = upload_executor.execute_save(
|
||||
staging_future_or_state_dict,
|
||||
checkpoint_id=checkpoint_id,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
storage_writer=storage_writer,
|
||||
planner=planner,
|
||||
process_group=process_group,
|
||||
|
@ -254,6 +254,7 @@ class _DistWrapper:
|
||||
if len(node_failures) > 0:
|
||||
result = CheckpointException(step, node_failures)
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
final_result = self.broadcast_object(result)
|
||||
if isinstance(final_result, CheckpointException):
|
||||
raise final_result
|
||||
@ -302,6 +303,7 @@ class _DistWrapper:
|
||||
result = map_fun()
|
||||
except BaseException as e: # noqa: B036
|
||||
result = CheckpointException(step, {self.rank: _wrap_exception(e)})
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
final_result = self.broadcast_object(result)
|
||||
if isinstance(final_result, CheckpointException):
|
||||
raise final_result
|
||||
|
@ -114,6 +114,7 @@ def broadcast(
|
||||
error_msg += f": stage {sync_obj.stage_name}"
|
||||
if sync_obj.exception is not None:
|
||||
error_msg += f": exception {sync_obj.exception}"
|
||||
# pyrefly: ignore # invalid-inheritance
|
||||
raise RuntimeError(error_msg) from sync_obj.exception
|
||||
|
||||
return cast(T, sync_obj.payload)
|
||||
@ -183,13 +184,16 @@ def all_gather(
|
||||
|
||||
if len(exception_list) > 0:
|
||||
raise RuntimeError( # type: ignore[misc]
|
||||
error_msg, exception_list
|
||||
error_msg,
|
||||
exception_list,
|
||||
# pyrefly: ignore # invalid-inheritance
|
||||
) from exception_list[0]
|
||||
return ret_list
|
||||
else:
|
||||
if not sync_obj.success:
|
||||
raise RuntimeError(
|
||||
f"all_gather failed with exception {sync_obj.exception}",
|
||||
# pyrefly: ignore # invalid-inheritance
|
||||
) from sync_obj.exception
|
||||
return [sync_obj.payload] # type: ignore[list-item]
|
||||
|
||||
@ -266,10 +270,13 @@ def _summarize_ranks(ranks: Iterable[int]) -> str:
|
||||
result = []
|
||||
for r in ranges:
|
||||
if len(r) == 1:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
result.append(f"{r.start}")
|
||||
elif r.step == 1:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
result.append(f"{r.start}:{r.stop}")
|
||||
else:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
result.append(f"{r.start}:{r.stop}:{r.step}")
|
||||
return ",".join(result)
|
||||
|
||||
|
@ -482,6 +482,7 @@ else:
|
||||
self._init_process_groups(backend_override)
|
||||
|
||||
if is_initialized() and get_backend() == "threaded":
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self._thread_id = threading.get_ident()
|
||||
|
||||
if _rank is None:
|
||||
@ -650,6 +651,7 @@ else:
|
||||
# We temporarily revert the reuse subgroup, since it breaks two internal tests.
|
||||
# Temporarily reverting to resolve test timeout while root-causing.
|
||||
# TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
|
||||
# pyrefly: ignore # unbound-name
|
||||
if bound_device_id is None or not has_split_group:
|
||||
dim_group = new_group(
|
||||
ranks=subgroup_ranks,
|
||||
|
@ -372,6 +372,7 @@ class BackendConfig:
|
||||
def __init__(self, backend: Backend):
|
||||
"""Init."""
|
||||
self.device_backend_map: dict[str, Backend] = {}
|
||||
# pyrefly: ignore # bad-assignment
|
||||
backend = str(backend)
|
||||
|
||||
if backend == Backend.UNDEFINED:
|
||||
@ -392,6 +393,7 @@ class BackendConfig:
|
||||
# e.g. "nccl", "gloo", "ucc", "mpi"
|
||||
supported_devices = Backend.backend_capability[backend.lower()]
|
||||
backend_val = Backend(backend)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
|
||||
elif ":" in backend.lower():
|
||||
# Backend specified in "device:backend" format
|
||||
@ -410,6 +412,7 @@ class BackendConfig:
|
||||
f"Invalid device:backend pairing: \
|
||||
{device_backend_pair_str}. {backend_str_error_message}"
|
||||
)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
device, backend = device_backend_pair
|
||||
if device in self.device_backend_map:
|
||||
raise ValueError(
|
||||
@ -1182,6 +1185,7 @@ def _as_iterable(obj) -> collections.abc.Iterable:
|
||||
|
||||
def _ensure_all_tensors_same_dtype(*tensors) -> None:
|
||||
last_dtype = None
|
||||
# pyrefly: ignore # bad-assignment
|
||||
for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)):
|
||||
tensor_dtype = tensor.dtype
|
||||
# Mixing complex and its element type is allowed
|
||||
@ -1837,6 +1841,7 @@ def _get_split_source(pg):
|
||||
split_from = pg._get_backend(pg.bound_device_id)
|
||||
elif pg is _world.default_pg:
|
||||
try:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
split_from = pg._get_backend(torch.device("cuda"))
|
||||
except RuntimeError:
|
||||
# no cuda device associated with this backend
|
||||
@ -1997,7 +2002,12 @@ def _new_process_group_helper(
|
||||
if not is_gloo_available():
|
||||
raise RuntimeError("Distributed package doesn't have Gloo built in")
|
||||
backend_class = ProcessGroupGloo(
|
||||
backend_prefix_store, group_rank, group_size, timeout=timeout
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
backend_prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
timeout=timeout,
|
||||
)
|
||||
backend_class.options.global_ranks_in_group = global_ranks_in_group
|
||||
backend_class.options.group_name = group_name
|
||||
@ -2018,6 +2028,7 @@ def _new_process_group_helper(
|
||||
# default backend_options for NCCL
|
||||
backend_options = ProcessGroupNCCL.Options()
|
||||
backend_options.is_high_priority_stream = False
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
backend_options._timeout = timeout
|
||||
|
||||
if split_from:
|
||||
@ -2037,7 +2048,12 @@ def _new_process_group_helper(
|
||||
# RuntimeError if is_ucc_available() returns false.
|
||||
|
||||
backend_class = ProcessGroupUCC(
|
||||
backend_prefix_store, group_rank, group_size, timeout=timeout
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
backend_prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
timeout=timeout,
|
||||
)
|
||||
backend_type = ProcessGroup.BackendType.UCC
|
||||
elif backend_str == Backend.XCCL:
|
||||
@ -2046,6 +2062,7 @@ def _new_process_group_helper(
|
||||
backend_options = ProcessGroupXCCL.Options()
|
||||
backend_options.global_ranks_in_group = global_ranks_in_group
|
||||
backend_options.group_name = group_name
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
backend_options._timeout = timeout
|
||||
backend_class = ProcessGroupXCCL(
|
||||
backend_prefix_store, group_rank, group_size, backend_options
|
||||
@ -2070,6 +2087,7 @@ def _new_process_group_helper(
|
||||
dist_backend_opts.store = backend_prefix_store
|
||||
dist_backend_opts.group_rank = group_rank
|
||||
dist_backend_opts.group_size = group_size
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
dist_backend_opts.timeout = timeout
|
||||
dist_backend_opts.group_id = group_name
|
||||
dist_backend_opts.global_ranks_in_group = global_ranks_in_group
|
||||
@ -2113,6 +2131,7 @@ def _new_process_group_helper(
|
||||
store=backend_prefix_store,
|
||||
rank=group_rank,
|
||||
world_size=group_size,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
@ -3322,6 +3341,7 @@ def gather_object(
|
||||
return
|
||||
|
||||
assert object_gather_list is not None, "Must provide object_gather_list on dst rank"
|
||||
# pyrefly: ignore # unbound-name
|
||||
for i, tensor in enumerate(output_tensors):
|
||||
tensor = tensor.type(torch.uint8)
|
||||
tensor_size = object_size_list[i]
|
||||
@ -3698,8 +3718,10 @@ def broadcast_object_list(
|
||||
# has only one element, we can skip the copy.
|
||||
if my_group_rank == group_src:
|
||||
if len(tensor_list) == 1: # type: ignore[possibly-undefined]
|
||||
# pyrefly: ignore # unbound-name
|
||||
object_tensor = tensor_list[0]
|
||||
else:
|
||||
# pyrefly: ignore # unbound-name
|
||||
object_tensor = torch.cat(tensor_list)
|
||||
else:
|
||||
object_tensor = torch.empty( # type: ignore[call-overload]
|
||||
@ -3828,6 +3850,7 @@ def scatter_object_list(
|
||||
broadcast(max_tensor_size, group_src=group_src, group=group)
|
||||
|
||||
# Scatter actual serialized objects
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
output_tensor = torch.empty(
|
||||
max_tensor_size.item(), dtype=torch.uint8, device=pg_device
|
||||
)
|
||||
@ -4864,16 +4887,19 @@ def barrier(
|
||||
if isinstance(device_ids, list):
|
||||
opts.device_ids = device_ids
|
||||
# use only the first device id
|
||||
# pyrefly: ignore # read-only
|
||||
opts.device = torch.device(device.type, device_ids[0])
|
||||
elif getattr(group, "bound_device_id", None) is not None:
|
||||
# Use device id from `init_process_group(device_id=...)`
|
||||
opts.device = group.bound_device_id # type: ignore[assignment]
|
||||
elif device.type == "cpu" or _get_object_coll_device(group) == "cpu":
|
||||
# pyrefly: ignore # read-only
|
||||
opts.device = torch.device("cpu")
|
||||
else:
|
||||
# Use the current device set by the user. If user did not set any, this
|
||||
# may use default device 0, causing issues like hang or all processes
|
||||
# creating context on device 0.
|
||||
# pyrefly: ignore # read-only
|
||||
opts.device = device
|
||||
if group.rank() == 0:
|
||||
warnings.warn( # warn only once
|
||||
@ -5004,6 +5030,7 @@ def _hash_ranks_to_str(ranks: list[int]) -> str:
|
||||
# Takes a list of ranks and computes an integer color
|
||||
def _process_group_color(ranks: list[int]) -> int:
|
||||
# Convert list to tuple to make it hashable
|
||||
# pyrefly: ignore # bad-assignment
|
||||
ranks = tuple(ranks)
|
||||
hash_value = hash(ranks)
|
||||
# Split color must be:
|
||||
|
@ -333,8 +333,10 @@ class LocalElasticAgent(SimpleElasticAgent):
|
||||
rank=worker.global_rank,
|
||||
local_rank=local_rank,
|
||||
)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
log_line_prefixes[local_rank] = log_line_prefix
|
||||
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
envs[local_rank] = worker_env
|
||||
worker_args = list(spec.args)
|
||||
worker_args = macros.substitute(worker_args, str(local_rank))
|
||||
|
@ -54,6 +54,7 @@ class Event:
|
||||
if isinstance(data, str):
|
||||
data_dict = json.loads(data)
|
||||
data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined]
|
||||
# pyrefly: ignore # unbound-name
|
||||
return Event(**data_dict)
|
||||
|
||||
def serialize(self) -> str:
|
||||
@ -108,6 +109,7 @@ class RdzvEvent:
|
||||
if isinstance(data, str):
|
||||
data_dict = json.loads(data)
|
||||
data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined]
|
||||
# pyrefly: ignore # unbound-name
|
||||
return RdzvEvent(**data_dict)
|
||||
|
||||
def serialize(self) -> str:
|
||||
|
@ -142,7 +142,7 @@ Now all metrics in the group ``my_app`` will be printed to stdout as:
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .api import ( # noqa: F401
|
||||
from .api import ( # noqa: F401; pyrefly: ignore # deprecated; pyrefly: ignore # deprecated
|
||||
configure,
|
||||
ConsoleMetricHandler,
|
||||
get_elapsed_time_ms,
|
||||
|
@ -171,12 +171,15 @@ def profile(group=None):
|
||||
try:
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
publish_metric(group, f"{func.__name__}.success", 1)
|
||||
except Exception:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
publish_metric(group, f"{func.__name__}.failure", 1)
|
||||
raise
|
||||
finally:
|
||||
publish_metric(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
group,
|
||||
f"{func.__name__}.duration.ms",
|
||||
get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined]
|
||||
|
@ -97,6 +97,7 @@ class TailLog:
|
||||
n = len(log_files)
|
||||
self._threadpool = None
|
||||
if n > 0:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self._threadpool = ThreadPoolExecutor(
|
||||
max_workers=n,
|
||||
thread_name_prefix=f"{self.__class__.__qualname__}_{name}",
|
||||
|
@ -126,6 +126,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
|
||||
return tmp
|
||||
|
||||
def _decode_state(self, result: etcd.EtcdResult) -> tuple[bytes, Token]:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
base64_state = result.value.encode()
|
||||
|
||||
try:
|
||||
@ -135,6 +136,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
|
||||
"The state object is corrupt. See inner exception for details."
|
||||
) from exc
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
return state, result.modifiedIndex
|
||||
|
||||
|
||||
|
@ -53,6 +53,7 @@ class ElasticDistributedSampler(DistributedSampler[T]):
|
||||
raise TypeError("Dataset must be an instance of collections.abc.Sized")
|
||||
|
||||
# Cast to Sized for mypy
|
||||
# pyrefly: ignore # redundant-cast
|
||||
sized_dataset = cast(Sized, dataset)
|
||||
|
||||
if start_index >= len(sized_dataset):
|
||||
|
@ -65,6 +65,7 @@ class _FSDPDeviceHandle:
|
||||
if backend is None:
|
||||
try:
|
||||
self.__backend = getattr(torch, device.type)
|
||||
# pyrefly: ignore # read-only
|
||||
self.__device = device
|
||||
except AttributeError as exc:
|
||||
raise AttributeError(
|
||||
|
@ -539,6 +539,7 @@ class FlatParamHandle:
|
||||
# Only align addresses for `use_orig_params=True` (for now)
|
||||
align_addresses = use_orig_params
|
||||
self._init_get_unflat_views_fn(align_addresses)
|
||||
# pyrefly: ignore # read-only
|
||||
self.device = device
|
||||
self._device_handle = _FSDPDeviceHandle.from_device(self.device)
|
||||
self.process_group = process_group
|
||||
|
@ -220,6 +220,7 @@ def _move_states_to_device(
|
||||
the future.
|
||||
"""
|
||||
# Follow the logic in `nn.Module._apply`
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
for tensor in itertools.chain(params, buffers):
|
||||
if tensor.device == device or tensor.device.type == "meta":
|
||||
# Keep meta-device tensors on meta device for deferred init
|
||||
|
@ -232,6 +232,7 @@ class FSDPParam:
|
||||
self._module_info: ParamModuleInfo = module_info
|
||||
self.mesh_info = mesh_info
|
||||
self.post_forward_mesh_info = post_forward_mesh_info
|
||||
# pyrefly: ignore # read-only
|
||||
self.device = device
|
||||
self.mp_policy = mp_policy
|
||||
self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
|
||||
@ -554,6 +555,7 @@ class FSDPParam:
|
||||
f"world size ({shard_world_size})"
|
||||
)
|
||||
shard_rank = self.post_forward_mesh_info.shard_mesh_rank
|
||||
# pyrefly: ignore # unbound-name
|
||||
sharded_numel = numel // shard_world_size
|
||||
self._sharded_post_forward_param_data = (
|
||||
self.all_gather_outputs[0].narrow(
|
||||
@ -684,6 +686,7 @@ class FSDPParam:
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
pre_all_gather_signature = inspect.signature(
|
||||
# pyrefly: ignore # missing-attribute
|
||||
sharded_local_tensor.fsdp_pre_all_gather
|
||||
)
|
||||
num_fn_params = len(pre_all_gather_signature.parameters)
|
||||
@ -701,6 +704,7 @@ class FSDPParam:
|
||||
(
|
||||
all_gather_inputs,
|
||||
self._extensions_data.all_gather_metadata,
|
||||
# pyrefly: ignore # missing-attribute
|
||||
) = sharded_local_tensor.fsdp_pre_all_gather(
|
||||
self.shard_mesh_from_root
|
||||
)
|
||||
@ -708,6 +712,7 @@ class FSDPParam:
|
||||
(
|
||||
all_gather_inputs,
|
||||
self._extensions_data.all_gather_metadata,
|
||||
# pyrefly: ignore # missing-attribute
|
||||
) = sharded_local_tensor.fsdp_pre_all_gather(
|
||||
self.shard_mesh_from_root,
|
||||
self._orig_size,
|
||||
@ -829,6 +834,7 @@ class FSDPParam:
|
||||
f"instead of {self.sharded_param}"
|
||||
)
|
||||
self.sharded_param = new_param
|
||||
# pyrefly: ignore # missing-attribute
|
||||
local_tensor = new_param._local_tensor
|
||||
if local_tensor.is_meta:
|
||||
return
|
||||
|
@ -151,6 +151,7 @@ class FSDPParamGroup:
|
||||
]
|
||||
self.mesh_info = mesh_info
|
||||
self.post_forward_mesh_info = post_forward_mesh_info
|
||||
# pyrefly: ignore # read-only
|
||||
self.device = device
|
||||
self.device_handle = _get_device_handle(device.type)
|
||||
self.mp_policy = mp_policy
|
||||
@ -616,6 +617,7 @@ class FSDPParamGroup:
|
||||
# Prefetch naively using the reverse post-forward order, which may
|
||||
# have mistargeted prefetches if not all modules used in forward
|
||||
# are used in this backward
|
||||
# pyrefly: ignore # unbound-name
|
||||
target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
|
||||
self._prefetch_unshard(target_fsdp_param_group, "backward")
|
||||
|
||||
@ -852,6 +854,7 @@ compile the forward part if you want to use Traceable FSDP2."""
|
||||
raise RuntimeError(msg)
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
|
||||
# All tensors in `inputs` should require gradient
|
||||
RegisterPostBackwardFunction._assert_not_tracing_fsdp()
|
||||
|
@ -96,6 +96,7 @@ class FSDPState(_State):
|
||||
for module in modules:
|
||||
_insert_module_state(module, self)
|
||||
self._modules = modules
|
||||
# pyrefly: ignore # read-only
|
||||
self._device = device
|
||||
self._device_handle = _get_device_handle(device.type)
|
||||
self._mp_policy = mp_policy
|
||||
|
@ -50,6 +50,7 @@ def get_cls_to_fsdp_cls() -> dict[type, type]:
|
||||
|
||||
|
||||
@overload
|
||||
# pyrefly: ignore # inconsistent-overload
|
||||
def fully_shard(
|
||||
module: nn.Module,
|
||||
*,
|
||||
@ -63,6 +64,7 @@ def fully_shard(
|
||||
|
||||
|
||||
@overload
|
||||
# pyrefly: ignore # inconsistent-overload
|
||||
def fully_shard(
|
||||
module: list[nn.Module],
|
||||
*,
|
||||
|
@ -508,6 +508,7 @@ def _init_prefetching_state(
|
||||
|
||||
|
||||
@no_type_check
|
||||
# pyrefly: ignore # bad-function-definition
|
||||
def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
|
||||
# TODO: we need to add additional check once we support FSDP + PiPPy.
|
||||
# This check is currently sufficient, since we only support FSDP + TP.
|
||||
@ -904,7 +905,10 @@ def _materialize_meta_module(
|
||||
# As a contract to the user, only call `reset_parameters()` if
|
||||
# the module has directly managed parameters/buffers
|
||||
module_state_iter = itertools.chain(
|
||||
module.parameters(recurse=False), module.buffers(recurse=False)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
module.parameters(recurse=False),
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
module.buffers(recurse=False),
|
||||
)
|
||||
has_module_states = len(list(module_state_iter)) > 0
|
||||
if has_module_states:
|
||||
|
@ -603,6 +603,7 @@ def _flatten_optim_state(
|
||||
]
|
||||
# Check that the unflattened parameters have the same state names
|
||||
state_names = None
|
||||
# pyrefly: ignore # bad-assignment
|
||||
for unflat_param_state in unflat_param_states:
|
||||
if unflat_param_state is None:
|
||||
continue
|
||||
@ -918,6 +919,7 @@ def _rekey_sharded_optim_state_dict(
|
||||
flat_param_key = unflat_param_names_to_flat_param_key.get(
|
||||
key.unflat_param_names, key.unflat_param_names
|
||||
)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
rekeyed_osd_state[flat_param_key] = param_state
|
||||
|
||||
# Only process param_groups if it exists in sharded_osd
|
||||
@ -980,6 +982,7 @@ def _get_param_id_to_param_from_optim_input(
|
||||
if optim_input is None:
|
||||
return dict(enumerate(model.parameters()))
|
||||
try:
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
params = cast(list[nn.Parameter], list(optim_input))
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
|
@ -354,7 +354,9 @@ class _RemoteModule(nn.Module):
        _raise_not_supported(self.to.__name__)

    def register_backward_hook( # type: ignore[return]
        self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]]
        self,
        hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]],
        # pyrefly: ignore # bad-return
    ) -> RemovableHandle:
        _raise_not_supported(self.register_backward_hook.__name__)
@ -369,6 +371,7 @@ class _RemoteModule(nn.Module):
        ],
        prepend: bool = False,
        with_kwargs: bool = False,
        # pyrefly: ignore # bad-return
    ) -> RemovableHandle:
        _raise_not_supported(self.register_forward_pre_hook.__name__)
@ -380,6 +383,7 @@ class _RemoteModule(nn.Module):
        ],
        prepend: bool = False,
        with_kwargs: bool = False,
        # pyrefly: ignore # bad-return
    ) -> RemovableHandle:
        _raise_not_supported(self.register_forward_hook.__name__)
@ -400,7 +404,11 @@ class _RemoteModule(nn.Module):
        )

    def named_parameters( # type: ignore[return]
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
        self,
        prefix: str = "",
        recurse: bool = True,
        remove_duplicate: bool = True,
        # pyrefly: ignore # bad-return
    ) -> Iterator[tuple[str, Parameter]]:
        _raise_not_supported(self.named_parameters.__name__)
@ -408,7 +416,11 @@ class _RemoteModule(nn.Module):
        _raise_not_supported(self.buffers.__name__)

    def named_buffers( # type: ignore[return]
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
        self,
        prefix: str = "",
        recurse: bool = True,
        remove_duplicate: bool = True,
        # pyrefly: ignore # bad-return
    ) -> Iterator[tuple[str, Tensor]]:
        _raise_not_supported(self.named_buffers.__name__)
@ -572,23 +584,31 @@ class _RemoteModule(nn.Module):

        remote_module = object.__new__(RemoteModule)

        # pyrefly: ignore # missing-attribute
        enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device)

        if _module_interface_cls is not None:
            # Users reply on this field to know if this generated RemoteModule is TorchScript-able.
            # pyrefly: ignore # missing-attribute
            remote_module.is_scriptable = True

            # pyrefly: ignore # missing-attribute
            remote_module._init_template(
                _module_interface_cls, enable_moving_cpu_tensors_to_cuda
            )
        else:
            # pyrefly: ignore # missing-attribute
            remote_module.is_scriptable = False
            # pyrefly: ignore # missing-attribute
            remote_module.generated_methods = (
                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
            )
        # pyrefly: ignore # missing-attribute
        remote_module.module_rref = module_rref

        # pyrefly: ignore # missing-attribute
        remote_module._install_generated_methods()
        # pyrefly: ignore # missing-attribute
        remote_module._check_attribute_picklability()

        return remote_module
@ -691,9 +711,11 @@ def _remote_module_receiver(
    m.__dict__.update(serialized_remote_module._asdict())

    # Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method.
    # pyrefly: ignore # missing-attribute
    m.module_rref = rpc.PyRRef._deserialize(m.module_rref)

    # Install generated methods when unpickled.
    # pyrefly: ignore # missing-attribute
    for method in m.generated_methods:
        method_name = method.__name__
        method = torch.jit.export(method)
@ -225,6 +225,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD):

class _Broadcast(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, src, group, tensor):
        ctx.src = src
        ctx.group = group
@ -236,6 +237,7 @@ class _Broadcast(Function):
        return tensor

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output)
        if ctx.src != ctx.rank:
@ -245,6 +247,7 @@ class _Broadcast(Function):

class _Gather(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, dst, group, tensor):
        ctx.dst = dst
        ctx.group = group
@ -270,6 +273,7 @@ class _Gather(Function):

class _Scatter(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, src, group, *tensors):
        ctx.src = src
        ctx.group = group
@ -282,12 +286,14 @@ class _Scatter(Function):
        return output

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)


class _Reduce(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, src, op, group, tensor):
        ctx.src = src
        ctx.group = group
@ -296,12 +302,14 @@ class _Reduce(Function):
        return tensor

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)


class _Reduce_Scatter(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, op, group, tensor, *input_tensor_list):
        ctx.group = group
        # Need contiguous tensors for collectives.
@ -311,12 +319,14 @@ class _Reduce_Scatter(Function):
        return tensor

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        return (None, None, None) + _AllGather.apply(ctx.group, grad_output)


class _AllGather(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, group, tensor):
        # Need contiguous tensors for collectives.
        tensor = tensor.contiguous()
@ -346,12 +356,14 @@ class _AllGather(Function):

class _AllGatherBase(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, output_tensor, input_tensor, group):
        ctx.group = group
        dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
        return output_tensor

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            world_size = dist.get_world_size(group=ctx.group)
@ -373,6 +385,7 @@ class _AllGatherBase(Function):

class _AlltoAll(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, group, out_tensor_list, *tensors):
        ctx.group = group
        ctx.input_tensor_size_list = [
@ -408,6 +421,7 @@ class _AlltoAll(Function):

class _AlltoAllSingle(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
        ctx.group = group
        ctx.input_size = input.size()
@ -423,6 +437,7 @@ class _AlltoAllSingle(Function):
        return output

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        tensor = torch.empty(
            ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
@ -440,6 +455,7 @@ class _AlltoAllSingle(Function):

class _AllReduce(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, op, group, tensor):
        ctx.group = group
        ctx.op = op
@ -448,5 +464,6 @@ class _AllReduce(Function):
        return tensor

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)
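The `bad-override` suppressions in the section above all sit between `@staticmethod` and a `forward`/`backward` definition whose signature differs from the base `torch.autograd.Function`. A small sketch of the same placement on a toy Function (the `_Identity` class is illustrative only, not part of the diff):

import torch
from torch.autograd import Function


class _Identity(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, group, tensor):
        # Stash the extra, non-tensor argument the base signature does not declare.
        ctx.group = group
        return tensor

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        # One gradient per forward input; the non-tensor input gets None.
        return None, grad_output


out = _Identity.apply("hypothetical-group", torch.ones(2, requires_grad=True))
out.sum().backward()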
@ -100,15 +100,19 @@ def _broadcast_object(
        data = bytearray(buffer.getbuffer())
        length_tensor = torch.LongTensor([len(data)]).to(device)
        data_send_tensor = torch.ByteTensor(data).to(device)
        # pyrefly: ignore # bad-argument-type
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        # pyrefly: ignore # bad-argument-type
        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
    else:
        # Receive the object
        length_tensor = torch.LongTensor([0]).to(device)
        # pyrefly: ignore # bad-argument-type
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        data_recv_tensor = torch.empty(
            [int(length_tensor.item())], dtype=torch.uint8, device=device
        )
        # pyrefly: ignore # bad-argument-type
        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
        obj = torch.load(buffer, map_location=device, weights_only=False)
@ -167,6 +171,7 @@ class _DDPBucketAssignment:
        if len(self.parameters) == 0:
            raise ValueError("Empty bucket assignment")
        # DDP guarantees all parameters in the bucket have the same device
        # pyrefly: ignore # read-only
        self.device: torch.device = self.parameters[0].device
        self.tensor: Optional[torch.Tensor] = None

@ -415,7 +420,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
        self.world_size: int = dist.get_world_size(self.process_group)
        self.rank: int = dist.get_rank(self.process_group)
        self.global_rank: int = dist.distributed_c10d.get_global_rank(
            self.process_group, self.rank
            # pyrefly: ignore # bad-argument-type
            self.process_group,
            self.rank,
        )

        self._overlap_with_ddp: bool = overlap_with_ddp
@ -535,7 +542,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
        self._all_state_dicts = []
        for rank in range(self.world_size):
            global_rank = dist.distributed_c10d.get_global_rank(
                self.process_group, rank
                # pyrefly: ignore # bad-argument-type
                self.process_group,
                rank,
            )
            if self.rank == to:
                # Consolidate all local `state_dict`s on this rank, storing on
@ -767,7 +776,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
            for dev_i_buckets in self._buckets:
                bucket = dev_i_buckets[rank]
                global_rank = dist.distributed_c10d.get_global_rank(
                    self.process_group, rank
                    # pyrefly: ignore # bad-argument-type
                    self.process_group,
                    rank,
                )
                handles.append(
                    dist.broadcast(
@ -780,7 +791,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
        else:
            param_groups = self._partition_parameters()[rank]
            global_rank = dist.distributed_c10d.get_global_rank(
                self.process_group, rank
                # pyrefly: ignore # bad-argument-type
                self.process_group,
                rank,
            )
            for param_group in param_groups:
                handles.extend(
@ -979,11 +992,14 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                for param_index, param in enumerate(bucket_params):
                    param_numel = param.numel()
                    if (
                        # pyrefly: ignore # unbound-name
                        assignment_size + param_numel >= threshold
                        and param_index > bucket_offset
                    ):
                        assigned_rank = self._get_min_index(
                            size_per_rank, assigned_ranks_per_bucket[bucket_index]
                            # pyrefly: ignore # unbound-name
                            size_per_rank,
                            assigned_ranks_per_bucket[bucket_index],
                        )
                        # Include up to but not including the parameter that
                        # exceeded the threshold
@ -994,6 +1010,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                            assigned_rank,
                            assigned_ranks_per_bucket,
                        )
                        # pyrefly: ignore # unbound-name
                        size_per_rank[assigned_rank] += assignment_size
                        bucket_offset = param_index
                        assignment_size = 0
@ -1001,7 +1018,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                # Assign the remainder of the bucket so that no assignment
                # spans across two buckets
                assigned_rank = self._get_min_index(
                    size_per_rank, assigned_ranks_per_bucket[bucket_index]
                    # pyrefly: ignore # unbound-name
                    size_per_rank,
                    assigned_ranks_per_bucket[bucket_index],
                )
                self._assign_bucket_subset_to_rank(
                    bucket_index,
@ -1010,6 +1029,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                    assigned_rank,
                    assigned_ranks_per_bucket,
                )
                # pyrefly: ignore # unbound-name
                size_per_rank[assigned_rank] += assignment_size

        return self._bucket_assignments_per_rank_cache
@ -1088,6 +1108,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):

        return loss

    # pyrefly: ignore # bad-override
    def step(
        self,
        closure: Optional[Callable[[], float]] = None,
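The `unbound-name` suppressions in the ZeroRedundancyOptimizer hunks above cover names such as `assignment_size` and `size_per_rank` that are first bound inside a conditional branch and then used further down. A small illustration of why a checker reports that, with made-up names (this is not the optimizer's real logic):

def bucket_stats(param_numels: list[int]) -> tuple[int, int]:
    count = 0
    if param_numels:
        # Only bound on this branch, as far as a static checker can tell.
        assignment_size = 0
    for numel in param_numels:
        # pyrefly: ignore # unbound-name
        assignment_size += numel
        count += 1
    if count:
        # The loop only ran if the branch above ran, so this is safe at runtime.
        # pyrefly: ignore # unbound-name
        return assignment_size, count
    return 0, 0


print(bucket_stats([3, 5, 7]))  # (15, 3)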
@ -282,6 +282,7 @@ class LossWrapper(torch.nn.Module):


class TrivialLossWrapper(LossWrapper):
    # pyrefly: ignore # bad-override
    def forward(self, x, targets):
        model_out = self.module(x)
        return self.loss_fn(model_out, targets)
@ -245,6 +245,7 @@ def stage_backward_weight(
        if non_none_grads:
            summed_grad = sum(non_none_grads)
            valid_edges.append(GradientEdge(intermediate, 0))
            # pyrefly: ignore # bad-argument-type
            valid_grad_outputs.append(summed_grad)

    # Break a reference cycle caused inside stage_backward_input->get_hook->hook
@ -81,6 +81,7 @@ def get_schedule_ops(
        raise ValueError(f"Invalid schedule: {schedule_class}")

    # Instantiate the schedule class
    # pyrefly: ignore # bad-instantiation, bad-argument-type
    schedule_instance = schedule_class(stages, num_microbatches)
    assert schedule_instance.pipeline_order is not None
@ -279,6 +279,7 @@ def _shard_dict_of_args(
                f"Unsupported chunk spec: {spec} and value: {v} combination."
            )

        # pyrefly: ignore # no-matching-overload
        for _flat_split_result, _v_split in zip(
            flat_split_results, v_splits, strict=True
        ):
@ -246,6 +246,7 @@ def _format_pipeline_order(
        pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
    ]
    # Transpose the list of lists (rows to columns)
    # pyrefly: ignore # no-matching-overload
    transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
    # Generate column labels for ranks
    num_ranks = len(pipeline_order)
@ -155,6 +155,7 @@ class _PipelineStageBase(ABC):
        self.submod = submodule
        self.stage_index = stage_index
        self.num_stages = num_stages
        # pyrefly: ignore # read-only
        self.device = device
        self.group = group
@ -36,12 +36,14 @@ class _remote_device:
        elif isinstance(remote_device, str):
            fields = remote_device.split("/")
            if len(fields) == 2:
                # pyrefly: ignore # bad-assignment
                self._worker_name, self._device = fields
            elif len(fields) == 1:
                # Check if this is a valid device.
                if _remote_device._is_valid_local_device(fields[0]):
                    self._device = fields[0]
                else:
                    # pyrefly: ignore # bad-assignment
                    self._worker_name = fields[0]
                    self._device = "cpu"
            else:
@ -63,6 +65,7 @@ class _remote_device:
                # rank:<rank>/device format, extract rank
                if fields[0] == "rank" and fields[1].isdigit():
                    self._rank = int(fields[1]) # type: ignore[assignment]
                    # pyrefly: ignore # bad-assignment
                    self._worker_name = None
                else:
                    raise ValueError(PARSE_ERROR)
@ -93,6 +93,7 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa
        result = result._replace(
            query=f"{'&'.join([f'{k}={v}' for k, v in query_dict.items()])}"
        )
        # pyrefly: ignore # bad-assignment
        url = urlunparse(result)

    if result.scheme not in _rendezvous_handlers:
@ -110,6 +111,7 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
    if not isinstance(world_size, numbers.Integral):
        raise RuntimeError(f"`world_size` must be an integer. {world_size}")

    # pyrefly: ignore # bad-argument-type
    return _rendezvous_helper(url, rank, world_size, **kwargs)
@ -473,6 +473,7 @@ def _rref_typeof_on_user(


T = TypeVar("T")
# pyrefly: ignore # invalid-annotation
GenericWithOneTypeVar = Generic[T]


@ -719,6 +720,7 @@ def _invoke_rpc(
    is_async_exec = hasattr(func, "_wrapped_async_rpc_function")

    if is_async_exec:
        # pyrefly: ignore # missing-attribute
        wrapped = func._wrapped_async_rpc_function
        if isinstance(wrapped, torch.jit.ScriptFunction):
            func = wrapped
@ -95,6 +95,7 @@ def register_backend(
    BackendType.__repr__ = _backend_type_repr # type: ignore[assignment]
    if BackendType.__doc__:
        BackendType.__doc__ = _backend_type_doc
    # pyrefly: ignore # unsupported-operation
    return BackendType[backend_name]
@ -48,6 +48,7 @@ else:
    _TensorPipeRpcBackendOptionsBase = object # type: ignore[assignment, misc]


# pyrefly: ignore # invalid-inheritance
class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
    r"""
    The backend options for
@ -4,6 +4,8 @@
import itertools

import torch

# pyrefly: ignore # deprecated
from torch.autograd.profiler_legacy import profile

from . import (
@ -174,11 +176,13 @@ class _server_process_global_profile(profile):
        flattened_function_events = list(
            itertools.chain.from_iterable(process_global_function_events)
        )
        # pyrefly: ignore # bad-assignment
        self.function_events = torch.autograd.profiler_util.EventList(
            flattened_function_events,
            use_device="cuda" if self.use_cuda else None,
            profile_memory=self.profile_memory,
        )
        # pyrefly: ignore # missing-attribute
        self.function_events._build_tree()

        self.process_global_function_events = process_global_function_events
@ -840,6 +840,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
        ) from e

    logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
    # pyrefly: ignore # bad-instantiation
    logs_specs = logs_specs_cls(
        log_dir=args.log_dir,
        redirects=Std.from_str(args.redirects),
@ -189,7 +189,10 @@ class OpDispatcher:

        local_tensor_args = (
            pytree.tree_unflatten(
                cast(list[object], op_info.local_args), op_info.args_tree_spec
                # pyrefly: ignore # bad-argument-type
                cast(list[object], op_info.local_args),
                # pyrefly: ignore # bad-argument-type
                op_info.args_tree_spec,
            )
            if op_info.args_tree_spec
            else op_info.local_args
@ -361,7 +364,11 @@ class OpDispatcher:

        with redistribute_context:
            resharded_local_tensor = redistribute_local_tensor(
                local_tensor, arg_spec, reshard_arg_spec
                # pyrefly: ignore # bad-argument-type
                local_tensor,
                arg_spec,
                # pyrefly: ignore # bad-argument-type
                reshard_arg_spec,
            )
            new_local_args.append(resharded_local_tensor)
        else:
@ -431,7 +438,11 @@ class OpDispatcher:
                        op_call, args_list
                    )
                    kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
                        op_call, v, compute_mesh
                        # pyrefly: ignore # bad-argument-type
                        op_call,
                        v,
                        # pyrefly: ignore # bad-argument-type
                        compute_mesh,
                    )
                    local_kwargs[k] = v
                else:
@ -447,6 +458,7 @@ class OpDispatcher:
            OpSchema(
                op_call,
                (
                    # pyrefly: ignore # bad-argument-type
                    pytree.tree_unflatten(args_schema, args_spec)
                    if args_spec
                    else tuple(args_schema)
@ -171,6 +171,7 @@ def einop_rule(
                    global_shape, input_spec.mesh, input_spec.placements
                )
                cost += prod(local_shape) * input_spec.mesh.size(mesh_dim)
            # pyrefly: ignore # bad-argument-type
            costs.append(cost)
        d_to_keep_sharding = dims[costs.index(max(costs))]
        for d in dims:
@ -131,6 +131,7 @@ class _NormPartial(Partial):
        if self.reduce_op == "sum":
            assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}"
            if self.norm_type != 0 and self.norm_type != 1:
                # pyrefly: ignore # unsupported-operation
                return tensor**self.norm_type
        return tensor

@ -138,6 +139,7 @@ class _NormPartial(Partial):
        if self.reduce_op == "sum":
            assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}"
            if self.norm_type != 0 and self.norm_type != 1:
                # pyrefly: ignore # unsupported-operation
                return tensor ** (1.0 / self.norm_type)
        return tensor
@ -1057,7 +1057,9 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy:
        )
        return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype)

    # pyrefly: ignore # missing-attribute
    mat1_meta = local_meta(mat1_strategy.strategies[0], input_specs[0].placements)
    # pyrefly: ignore # missing-attribute
    mat2_meta = local_meta(mat2_strategy.strategies[0], input_specs[1].placements)

    def check_valid_strides(meta: TensorMeta) -> bool:
@ -336,6 +336,7 @@ def expand_to_full_mesh_op_strategy(
        for specs in zip(*strategy_comb):
            if specs[0] is not None:
                # TODO: we should fill in tensor_meta here. If nothing else, it helps the filter strategy callback
                # pyrefly: ignore # bad-argument-type
                spec_list.append(DTensorSpec(mesh, specs))
            else:
                spec_list.append(None)
@ -150,6 +150,7 @@ class _RNGStateTracker:
    """

    def __init__(self, device: torch.device):
        # pyrefly: ignore # read-only
        self._device = device
        self._device_handle = _get_device_handle(self._device.type)
        if not (self._device_handle and self._device_handle.is_available()):
@ -256,8 +256,10 @@ def convolution_backward_handler(
    kwargs: dict[str, object],
) -> object:
    # Redistribute grad_output tensor to the same placement as input tensor
    # pyrefly: ignore # bad-assignment
    args = list(args)
    assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor)
    # pyrefly: ignore # unsupported-operation
    args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements)
    args = tuple(args)
@ -594,6 +594,7 @@ class CommDebugMode(TorchDispatchMode):
        self.advanced_module_tracker.__enter__()
        return self

    # pyrefly: ignore # bad-override
    def __exit__(self, *args):
        self.advanced_module_tracker.__exit__()
        super().__exit__(*args)
@ -90,6 +90,7 @@ def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=Fals
    op_infos.sort(key=itemgetter(count_idx), reverse=True)

    headers = ["Operator", "Schema", "Total Count", "Supported"]
    # pyrefly: ignore # bad-argument-type
    print(tabulate(op_infos, headers=headers))

    if output_csv:
@ -101,4 +102,5 @@ def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=Fals
        csv_writer.writerow(headers)
        # Write each table row to the CSV file
        for row in op_infos:
            # pyrefly: ignore # bad-argument-type
            csv_writer.writerow(row)
@ -90,8 +90,12 @@ class LocalShardsWrapper(torch.Tensor):
        # TODO: we shall continually extend this function to support more ops if needed
        if func in supported_ops:
            res_shards_list = [
                func(shard, *args[1:], **kwargs) for shard in args[0].shards
                # pyrefly: ignore # index-error
                func(shard, *args[1:], **kwargs)
                # pyrefly: ignore # index-error
                for shard in args[0].shards
            ]
            # pyrefly: ignore # index-error
            return LocalShardsWrapper(res_shards_list, args[0].shard_offsets)
        else:
            raise NotImplementedError(
@ -141,6 +145,7 @@ def run_torchrec_row_wise_even_sharding_example(rank, world_size):
    local_tensor = torch.randn(local_shard_shape, device=device)
    # row-wise sharding: one shard per rank
    # create the local shards wrapper
    # pyrefly: ignore # no-matching-overload
    local_shards_wrapper = LocalShardsWrapper(
        local_shards=[local_tensor],
        offsets=[local_shard_offset],
@ -219,6 +224,7 @@ def run_torchrec_row_wise_uneven_sharding_example(rank, world_size):
    # local shards
    # row-wise sharding: one shard per rank
    # create the local shards wrapper
    # pyrefly: ignore # no-matching-overload
    local_shards_wrapper = LocalShardsWrapper(
        local_shards=[local_tensor],
        offsets=[local_shard_offset],
@ -297,6 +303,7 @@ def run_torchrec_table_wise_sharding_example(rank, world_size):
        local_shard_offset = torch.Size((0, 0))
        # wrap local shards into a wrapper
        local_shards_wrapper = (
            # pyrefly: ignore # no-matching-overload
            LocalShardsWrapper(
                local_shards=[local_tensor],
                offsets=[local_shard_offset],
@ -475,6 +475,7 @@ def _templated_ring_attention(
        )
        sdpa_merger.step(out, logsumexp, partial)

    # pyrefly: ignore # unbound-name
    return *sdpa_merger.results(), *rest


@ -641,6 +642,7 @@ def _templated_ring_attention_backward(
        grad_query,
        grad_key,
        grad_value,
        # pyrefly: ignore # unbound-name
        *rest,
    )

@ -988,7 +990,9 @@ def _distribute_function(

def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None:
    """Restore the function that is replaced by _distribute_function."""
    # pyrefly: ignore # unknown-name
    global _original_functions
    # pyrefly: ignore # unknown-name
    global _wrapper_functions

    if fn not in _replaced_functions:
@ -1021,6 +1025,7 @@ def _context_parallel_dispatcher(
    placement = [Shard(seq_dim)]
    all_args = []

    # pyrefly: ignore # bad-assignment, bad-argument-type
    for arg in itertools.chain(args, kwargs.values()):
        if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor):
            arg = DTensor.from_local(arg, mesh, placement, run_check=False)
@ -238,6 +238,7 @@ def _local_map_wrapped(

            flat_local_args.append(arg)

        # pyrefly: ignore # bad-argument-type
        local_args = pytree.tree_unflatten(flat_local_args, args_spec)

        out = func(*local_args, **kwargs)
@ -271,6 +272,7 @@ def _local_map_wrapped(

            flat_dist_out.append(out)

        # pyrefly: ignore # bad-argument-type
        return pytree.tree_unflatten(flat_dist_out, out_spec)
    else:
        return out
@ -237,8 +237,11 @@ def _mark_sharding(
        op_schema,
    )
    placement_strategies[node] = OpSpec(
        # pyrefly: ignore # bad-argument-type
        output_specs=_get_output_spec_from_output_sharding(output_sharding),
        # pyrefly: ignore # missing-attribute
        input_specs=output_sharding.redistribute_schema.args_spec
        # pyrefly: ignore # missing-attribute
        if output_sharding.redistribute_schema is not None
        else _get_input_node_specs(node, placement_strategies),
    )
@ -135,9 +135,11 @@ def _rewrite_spec_if_needed(
            break
    if rewrite:
        spec = copy.deepcopy(spec)
        # pyrefly: ignore # missing-attribute
        for i, placement in enumerate(spec.placements):
            placement = cast(_remote_device, placement)
            if placement.rank() == rank and placement.device() != tensor.device:
                # pyrefly: ignore # missing-attribute
                spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")

    return spec
|
||||
if weight is not None:
|
||||
new_shape = list(x.shape)
|
||||
new_shape[channel_dim] = -1
|
||||
# pyrefly: ignore # unbound-name
|
||||
w = w.expand(new_shape)
|
||||
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
|
||||
wsum = torch.where(target != ignore_index, wsum, 0)
|
||||
@ -308,7 +309,9 @@ def _nll_loss_forward_handler(
|
||||
output_placements = all_replicate_placements
|
||||
|
||||
# tensor inputs to _propagate_tensor_meta need to be DTensors
|
||||
# pyrefly: ignore # bad-assignment
|
||||
args = list(args)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
args[1], args[2] = target, weight
|
||||
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
|
||||
|
||||
@ -439,8 +442,11 @@ def _nll_loss_backward_handler(
|
||||
weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
|
||||
|
||||
# tensor inputs to _propagate_tensor_meta need to be DTensors
|
||||
# pyrefly: ignore # bad-assignment
|
||||
args = list(args)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
args[2], args[3] = target, weight
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
|
||||
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
|
||||
|
||||
|
@ -548,6 +548,7 @@ class PrepareModuleInput(ParallelStyle):
            assert self.desired_input_layouts is not None, (
                "desired module inputs should not be None!"
            )
            # pyrefly: ignore # no-matching-overload
            for inp, input_layout, desired_layout in zip(
                inputs, self.input_layouts, self.desired_input_layouts
            ):
@ -663,6 +664,7 @@ class PrepareModuleOutput(ParallelStyle):
                raise ValueError(
                    "module outputs and output_layouts should have same length!"
                )
            # pyrefly: ignore # no-matching-overload
            for out, out_layout, desired_out_layout in zip(
                outputs, self.output_layouts, self.desired_output_layouts
            ):
@ -59,6 +59,7 @@ def _cast_forward_inputs(
    def cast_fn(x: torch.Tensor) -> torch.Tensor:
        if not torch.is_floating_point(x) or x.dtype == dtype:
            return x
        # pyrefly: ignore # no-matching-overload
        return x.to(dtype)

    return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))
@ -133,12 +134,16 @@ def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
        from torch.nn.parallel.scatter_gather import _is_namedtuple

        if _is_namedtuple(obj):
            # pyrefly: ignore # no-matching-overload
            return [type(obj)(*args) for args in zip(*map(to_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            # pyrefly: ignore # no-matching-overload
            return list(zip(*map(to_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            # pyrefly: ignore # no-matching-overload
            return [list(i) for i in zip(*map(to_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            # pyrefly: ignore # no-matching-overload
            return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
        return [obj]
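As a recap of the comment format used throughout this change: the text after the second `#` names the diagnostic being silenced, and a single comment can list several comma-separated codes, as in the `bad-instantiation, bad-argument-type` and `bad-assignment, bad-argument-type` hunks above. A hypothetical sketch, not taken from the diff; whether pyrefly would report exactly these codes for this toy code is an assumption, only the comment format mirrors the real changes:

def double(x: int) -> int:
    return x * 2


value: object = 7

# pyrefly: ignore # bad-argument-type
print(double(value))

# A single comment can name every code expected on the next line.
# pyrefly: ignore # bad-assignment, bad-argument-type
label: str = double(value)
print(label)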