**Summary**
Fix circular import in `torch/distributed/utils.py` found when running internal test, see D62901023. Curious why this wasn't causing any issue. Is this relevant code deprecated and no longer used?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136286
Approved by: https://github.com/Skylion007
resolve: https://github.com/pytorch/pytorch/pull/135029
when enabling mixed precision, FSDP cast input args to desired dtype by calling `_apply_to_tensors`. When input args has `dataclass(frozen=True)`, we hit following runtime error, because of using `setattr` in `_apply_to_tensors`
`dataclasses.FrozenInstanceError: cannot assign to field 'some_key'`. The fix is to use dataclasses api `dataclasses.replace`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135067
Approved by: https://github.com/awgu
This PR relaxes `@contract` to allow the 1st argument to be `Sequence[nn.Module]` instead of strictly `nn.Module`. This is required for the next PR, which allows `fully_shard` to take in `List[nn.Module]`.
---
**Changes for reland:**
- The previous PR assumed that any `func` decorated with `@contract` would return the same input `module` as output (which is true for PT-D composable APIs).
- However, TorchRec `shard` returns a different module as output (though that module _does_ satisfy the `@contract` FQN check).
- This PR removes the assumption and instead only enforces the FQN check following the input module order. In other words, if calling `func([x1, ..., xN])` for `N` modules `x1, ..., xN` that returns `[y1, ..., yM]` for `M` modules, we require that `N = M` and that FQNs are preserved coordinate-wise: `xi` and `yi` have same FQNs for all `i = 1, ..., N`.
Differential Revision: [D59863438](https://our.internmc.facebook.com/intern/diff/D59863438)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130947
Approved by: https://github.com/weifengpy, https://github.com/atalman
This PR relaxes `@contract` to allow the 1st argument to be `Sequence[nn.Module]` instead of strictly `nn.Module`. This is required for the next PR, which allows `fully_shard` to take in `List[nn.Module]`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127773
Approved by: https://github.com/weifengpy
Improve Dynamo to support the FSDP2 `use_training_state()` context manager.
Test command:
`
pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_dynamo_trace_use_training_state
`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127854
Approved by: https://github.com/yanboliang
Fixes#112639
```txt
torch/utils/_sympy/value_ranges.py
torch/utils/_sympy/value_ranges.py:60 in public class `ValueRanges`:
D101: Missing docstring in public class
torch/utils/_sympy/value_ranges.py:68 in public method `__init__`:
D107: Missing docstring in __init__
torch/utils/_sympy/value_ranges.py:81 in public method `__contains__`:
D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:86 in public method `tighten`:
D400: First line should end with a period (not 'n')
torch/utils/_sympy/value_ranges.py:90 in public method `__and__`:
D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:103 in public method `__or__`:
D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:113 in public method `is_singleton`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:118 in public method `unknown`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:122 in public method `wrap`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:129 in public method `increasing_map`:
D400: First line should end with a period (not ')')
torch/utils/_sympy/value_ranges.py:135 in public method `decreasing_map`:
D400: First line should end with a period (not ')')
torch/utils/_sympy/value_ranges.py:141 in public method `monotone_map`:
D400: First line should end with a period (not 'g')
torch/utils/_sympy/value_ranges.py:149 in public method `convex_min_zero_map`:
D400: First line should end with a period (not '0')
torch/utils/_sympy/value_ranges.py:149 in public method `convex_min_zero_map`:
D403: First word of the first line should be properly capitalized ('Fn', not 'fn')
torch/utils/_sympy/value_ranges.py:158 in public method `coordinatewise_increasing_map`:
D205: 1 blank line required between summary line and description (found 0)
torch/utils/_sympy/value_ranges.py:158 in public method `coordinatewise_increasing_map`:
D400: First line should end with a period (not ':')
torch/utils/_sympy/value_ranges.py:171 in public method `coordinatewise_monotone_map`:
D400: First line should end with a period (not 'e')
torch/utils/_sympy/value_ranges.py:180 in private class `SymPyValueRangeAnalysis`:
D205: 1 blank line required between summary line and description (found 0)
torch/utils/_sympy/value_ranges.py:180 in private class `SymPyValueRangeAnalysis`:
D400: First line should end with a period (not 's')
torch/utils/_sympy/value_ranges.py:386 in private method `reciprocal`:
D210: No whitespaces allowed surrounding docstring text
torch/utils/_sympy/value_ranges.py:386 in private method `reciprocal`:
D400: First line should end with a period (not 'n')
torch/utils/_sympy/value_ranges.py:488 in public class `ValueRangeAnalysis`:
D101: Missing docstring in public class
torch/utils/_sympy/value_ranges.py:489 in public method `__init__`:
D107: Missing docstring in __init__
torch/utils/_sympy/value_ranges.py:501 in public method `bool_handler`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:506 in public method `default_handler`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:511 in public method `load`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:514 in public method `store`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:517 in public method `reduction`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:520 in public method `index_expr`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:525 in public method `to_dtype`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:558 in public method `square`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:562 in public method `neg`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:566 in public method `truncdiv`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:577 in public method `sub`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:580 in public method `__getattr__`:
D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:585 in public function `bound_sympy`:
D103: Missing docstring in public function
36
torch/utils/_sympy/value_ranges.py:60 in public class `ValueRanges`:
D101: Missing docstring in public class
torch/utils/_sympy/value_ranges.py:68 in public method `__init__`:
D107: Missing docstring in __init__
torch/utils/_sympy/value_ranges.py:81 in public method `__contains__`:
D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:86 in public method `tighten`:
D400: First line should end with a period (not 'n')
torch/utils/_sympy/value_ranges.py:90 in public method `__and__`:
D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:103 in public method `__or__`:
D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:113 in public method `is_singleton`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:118 in public method `unknown`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:122 in public method `wrap`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:182 in private class `SymPyValueRangeAnalysis`:
D205: 1 blank line required between summary line and description (found 0)
torch/utils/_sympy/value_ranges.py:182 in private class `SymPyValueRangeAnalysis`:
D400: First line should end with a period (not 's')
torch/utils/_sympy/value_ranges.py:388 in private method `reciprocal`:
D210: No whitespaces allowed surrounding docstring text
torch/utils/_sympy/value_ranges.py:388 in private method `reciprocal`:
D400: First line should end with a period (not 'n')
torch/utils/_sympy/value_ranges.py:490 in public class `ValueRangeAnalysis`:
D101: Missing docstring in public class
torch/utils/_sympy/value_ranges.py:491 in public method `__init__`:
D107: Missing docstring in __init__
torch/utils/_sympy/value_ranges.py:503 in public method `bool_handler`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:508 in public method `default_handler`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:513 in public method `load`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:516 in public method `store`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:519 in public method `reduction`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:522 in public method `index_expr`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:527 in public method `to_dtype`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:560 in public method `square`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:564 in public method `neg`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:568 in public method `truncdiv`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:579 in public method `sub`:
D102: Missing docstring in public method
torch/utils/_sympy/value_ranges.py:582 in public method `__getattr__`:
D105: Missing docstring in magic method
torch/utils/_sympy/value_ranges.py:587 in public function `bound_sympy`:
D103: Missing docstring in public function
28
torch/utils/viz/_cycles.py
torch/utils/viz/_cycles.py:14 in public function `observe_garbage`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:207 in public function `object_annotation`:
D205: 1 blank line required between summary line and description (found 0)
torch/utils/viz/_cycles.py:207 in public function `object_annotation`:
D400: First line should end with a period (not 'g')
torch/utils/viz/_cycles.py:256 in public class `Node`:
D101: Missing docstring in public class
torch/utils/viz/_cycles.py:262 in public function `create_graph`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:308 in public function `escape`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:312 in public function `is_cuda_tensor`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:315 in public function `cuda_allocation_context`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:335 in public function `to_dot`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:406 in public function `to_html`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:416 in public function `observe_tensor_cycles`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:429 in public function `warn_tensor_cycles`:
D205: 1 blank line required between summary line and description (found 0)
torch/utils/viz/_cycles.py:429 in public function `warn_tensor_cycles`:
D400: First line should end with a period (not 'p')
torch/utils/viz/_cycles.py:429 in public function `warn_tensor_cycles`:
D401: First line should be in imperative mood; try rephrasing (found 'Reference')
14
torch/utils/viz/_cycles.py:14 in public function `observe_garbage`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:256 in public class `Node`:
D101: Missing docstring in public class
torch/utils/viz/_cycles.py:262 in public function `create_graph`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:308 in public function `escape`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:312 in public function `is_cuda_tensor`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:315 in public function `cuda_allocation_context`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:335 in public function `to_dot`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:406 in public function `to_html`:
D103: Missing docstring in public function
torch/utils/viz/_cycles.py:416 in public function `observe_tensor_cycles`:
D103: Missing docstring in public function
9
torch/distributed/argparse_util.py
torch/distributed/argparse_util.py:1 at module level:
D100: Missing docstring in public module
torch/distributed/argparse_util.py:13 in public class `env`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/argparse_util.py:13 in public class `env`:
D400: First line should end with a period (not 'g')
torch/distributed/argparse_util.py:13 in public class `env`:
D412: No blank lines allowed between a section header and its content ('Example')
torch/distributed/argparse_util.py:43 in public method `__init__`:
D107: Missing docstring in __init__
torch/distributed/argparse_util.py:56 in public method `__call__`:
D102: Missing docstring in public method
torch/distributed/argparse_util.py:61 in public class `check_env`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/argparse_util.py:61 in public class `check_env`:
D400: First line should end with a period (not 's')
torch/distributed/argparse_util.py:61 in public class `check_env`:
D412: No blank lines allowed between a section header and its content ('Example')
torch/distributed/argparse_util.py:97 in public method `__init__`:
D107: Missing docstring in __init__
torch/distributed/argparse_util.py:102 in public method `__call__`:
D102: Missing docstring in public method
11
torch/distributed/argparse_util.py:1 at module level:
D100: Missing docstring in public module
torch/distributed/argparse_util.py:43 in public method `__init__`:
D107: Missing docstring in __init__
torch/distributed/argparse_util.py:56 in public method `__call__`:
D102: Missing docstring in public method
torch/distributed/argparse_util.py:97 in public method `__init__`:
D107: Missing docstring in __init__
torch/distributed/argparse_util.py:102 in public method `__call__`:
D102: Missing docstring in public method
5
torch/distributed/_composable_state.py
torch/distributed/_composable_state.py:20 in private function `_get_module_state`:
D202: No blank lines allowed after function docstring (found 1)
torch/distributed/_composable_state.py:20 in private function `_get_module_state`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/_composable_state.py:20 in private function `_get_module_state`:
D400: First line should end with a period (not '`')
3
0
torch/distributed/launch.py
torch/distributed/launch.py:1 at module level:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/launch.py:1 at module level:
D400: First line should end with a period (not 'd')
torch/distributed/launch.py:156 in public function `parse_args`:
D103: Missing docstring in public function
torch/distributed/launch.py:171 in public function `launch`:
D103: Missing docstring in public function
torch/distributed/launch.py:180 in public function `main`:
D103: Missing docstring in public function
5
torch/distributed/launch.py:157 in public function `parse_args`:
D103: Missing docstring in public function
torch/distributed/launch.py:172 in public function `launch`:
D103: Missing docstring in public function
torch/distributed/launch.py:181 in public function `main`:
D103: Missing docstring in public function
3
torch/distributed/remote_device.py
torch/distributed/remote_device.py:1 at module level:
D100: Missing docstring in public module
torch/distributed/remote_device.py:81 in private method `worker_name`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/remote_device.py:81 in private method `worker_name`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/distributed/remote_device.py:88 in private method `rank`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/remote_device.py:88 in private method `rank`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/distributed/remote_device.py:95 in private method `device`:
D200: One-line docstring should fit on one line with quotes (found 3)
torch/distributed/remote_device.py:95 in private method `device`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
7
torch/distributed/remote_device.py:1 at module level:
D100: Missing docstring in public module
torch/distributed/remote_device.py:85 in private method `rank`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/remote_device.py:85 in private method `rank`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
3
torch/distributed/rendezvous.py
torch/distributed/rendezvous.py:1 at module level:
D100: Missing docstring in public module
torch/distributed/rendezvous.py:23 in public function `register_rendezvous_handler`:
D401: First line should be in imperative mood (perhaps 'Register', not 'Registers')
torch/distributed/rendezvous.py:88 in public function `rendezvous`:
D103: Missing docstring in public function
torch/distributed/rendezvous.py:147 in private function `_create_c10d_store`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/rendezvous.py:147 in private function `_create_c10d_store`:
D400: First line should end with a period (not 'r')
5
torch/distributed/rendezvous.py:1 at module level:
D100: Missing docstring in public module
torch/distributed/rendezvous.py:89 in public function `rendezvous`:
D103: Missing docstring in public function
2
torch/distributed/run.py
torch/distributed/run.py:9 at module level:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/run.py:9 at module level:
D400: First line should end with a period (not '`')
torch/distributed/run.py:393 in public function `get_args_parser`:
D202: No blank lines allowed after function docstring (found 1)
torch/distributed/run.py:393 in public function `get_args_parser`:
D401: First line should be in imperative mood; try rephrasing (found 'Helper')
torch/distributed/run.py:610 in public function `parse_args`:
D103: Missing docstring in public function
torch/distributed/run.py:615 in public function `parse_min_max_nnodes`:
D103: Missing docstring in public function
torch/distributed/run.py:629 in public function `determine_local_world_size`:
D103: Missing docstring in public function
torch/distributed/run.py:670 in public function `get_rdzv_endpoint`:
D103: Missing docstring in public function
torch/distributed/run.py:677 in public function `get_use_env`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/run.py:677 in public function `get_use_env`:
D401: First line should be in imperative mood (perhaps 'Retrieve', not 'Retrieves')
torch/distributed/run.py:689 in public function `config_from_args`:
D103: Missing docstring in public function
torch/distributed/run.py:770 in public function `run_script_path`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/run.py:770 in public function `run_script_path`:
D401: First line should be in imperative mood (perhaps 'Run', not 'Runs')
torch/distributed/run.py:781 in public function `run`:
D103: Missing docstring in public function
torch/distributed/run.py:804 in public function `main`:
D103: Missing docstring in public function
15
torch/distributed/run.py:611 in public function `parse_args`:
D103: Missing docstring in public function
torch/distributed/run.py:616 in public function `parse_min_max_nnodes`:
D103: Missing docstring in public function
torch/distributed/run.py:630 in public function `determine_local_world_size`:
D103: Missing docstring in public function
torch/distributed/run.py:671 in public function `get_rdzv_endpoint`:
D103: Missing docstring in public function
torch/distributed/run.py:691 in public function `config_from_args`:
D103: Missing docstring in public function
torch/distributed/run.py:784 in public function `run`:
D103: Missing docstring in public function
torch/distributed/run.py:807 in public function `main`:
D103: Missing docstring in public function
7
torch/distributed/__init__.py
torch/distributed/__init__.py:1 at module level:
D104: Missing docstring in public package
torch/distributed/__init__.py:8 in public function `is_available`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/__init__.py:8 in public function `is_available`:
D400: First line should end with a period (not ',')
torch/distributed/__init__.py:8 in public function `is_available`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
4
torch/distributed/__init__.py:1 at module level:
D104: Missing docstring in public package
1
torch/distributed/utils.py:1 at module level:
D100: Missing docstring in public module
torch/distributed/utils.py:16 in private function `_pack_kwargs`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/utils.py:16 in private function `_pack_kwargs`:
D400: First line should end with a period (not ')')
torch/distributed/utils.py:47 in private function `_cast_forward_inputs`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/utils.py:88 in private function `_recursive_to`:
D200: One-line docstring should fit on one line with quotes (found 3)
torch/distributed/utils.py:141 in private function `_p_assert`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/utils.py:141 in private function `_p_assert`:
D209: Multi-line docstring closing quotes should be on a separate line
torch/distributed/utils.py:141 in private function `_p_assert`:
D400: First line should end with a period (not 't')
torch/distributed/utils.py:141 in private function `_p_assert`:
D401: First line should be in imperative mood; try rephrasing (found 'This')
torch/distributed/utils.py:275 in private function `_sync_module_states`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/utils.py:275 in private function `_sync_module_states`:
D400: First line should end with a period (not 'n')
torch/distributed/utils.py:275 in private function `_sync_module_states`:
D401: First line should be in imperative mood (perhaps 'Sync', not 'Syncs')
torch/distributed/utils.py:300 in private function `_sync_params_and_buffers`:
D205: 1 blank line required between summary line and description (found 0)
torch/distributed/utils.py:300 in private function `_sync_params_and_buffers`:
D400: First line should end with a period (not 'y')
torch/distributed/utils.py:300 in private function `_sync_params_and_buffers`:
D401: First line should be in imperative mood (perhaps 'Synchronize', not 'Synchronizes')
15
torch/distributed/utils.py:1 at module level:
D100: Missing docstring in public module
1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112953
Approved by: https://github.com/weifengpy
Summary: Disable buffers sync in _sync_module_states(...) when broadcast_buffers is False. This change will memory usage when a model has huge buffers and does not need broadcast buffers.
Test Plan: .
Differential Revision: D45610709
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100729
Approved by: https://github.com/mrshenli
At present, DDP forward uses `_get_stream` to get a stream,which is cudaStream.
If the custom module already registered to torch, I can use `getattr` to get it and it's stream. Then, the custom stream is used to copy the tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98723
Approved by: https://github.com/ezyang
Before this PR, if a user runs DDP with `device_ids` specified and with a `PackedSequence` input, then the execution will error with something like:
```
raise ValueError(
ValueError: batch_sizes should always be on CPU. Instances of PackedSequence should never be created manually. They should be instantiated by
functions like pack_sequence and pack_padded_sequences in nn.utils.rnn. https://pytorch.org/docs/stable/nn.html...
```
This is because the DDP forward calls `_to_kwargs()`, which calls `_recursive_to()`, which moves the inputs to GPU. However, `_is_namedtuple(packed_sequence)` returns `True`, leading to the branch `return [type(obj)(*args) for args in zip(*map(to_map, obj))]`, which tries to construct a `PackedSequence` directly via `type(obj)(*args)`, leading to the error.
Repro for `_is_namedtuple(packed_sequence)` returning `True`:
```
import random
import torch
import torch.nn.utils.rnn as rnn_utils
from torch.nn.parallel.scatter_gather import _is_namedtuple
def _ordered_sequence(tensor_type):
seqs = [tensor_type(random.randint(1, 256))
for _ in range(32)]
seqs = [s.random_(-128, 128) for s in seqs]
ordered = sorted(seqs, key=len, reverse=True)
return ordered
def _padded_sequence(tensor_type):
ordered = _ordered_sequence(tensor_type)
lengths = [len(i) for i in ordered]
padded_tensor = rnn_utils.pad_sequence(ordered)
return padded_tensor, lengths
padded, lengths = _padded_sequence(torch.Tensor)
packed = rnn_utils.pack_padded_sequence(
padded, lengths, enforce_sorted=False)
print(type(packed), packed.data.device)
print(_is_namedtuple(packed))
```
Test Plan:
```
python test/distributed/test_c10d_nccl.py -k test_ddp_packed_sequence
```
Without the fix, the added unit test fails with the expected error.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86614
Approved by: https://github.com/rohan-varma
- Uses state dict / load state dict hooks to ensure that modules wrapped with `CheckpointWrapper` can be loaded into non-checkpointed wrapped module.
This is because a training run can use activation checkpointing, then we can recover `state_dict`, and a future run may not want to wrap modules with activation checkpointing or decide to change activation checkpoint wrapping structure. To support this, we add hooks to remove / add the relevant prefix as needed.
Tests are added to ensure we can load into CheckpointWrapper module as well as local module from CheckpointWrapper-wrapped module. state_dict with FSDP is also verified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62927
As part of the ShardedTensor work, we realized we do need some sort of
_RemoteDevice structure that deals with our format of "workername/device" so
that users don't have to worry about parsing this string directly.
Right now this structure is just the bare minimum and is mostly a container for
describing a remote device. It is currently only used in ShardedTensor,
ShardingSpec and RemoteModule.
Once we actually have a consolidated remote device proposal, this class can be
extended appropriately if needed.
ghstack-source-id: 135534086
Test Plan:
1) unit tests
2) waitforbuildbot
Reviewed By: SciPioneer
Differential Revision: D30170689
fbshipit-source-id: 1ac2e81c7a597dc40bf3fbf2c1168c382c66649f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55728
Full design: https://github.com/pytorch/pytorch/issues/55207
This PR introduces ChunkShardingSpec (SingleShardingSpec in the design). Used
the name ChunkShardingSpec since it is very similar to `torch.chunk` in terms
of how a Tensor is split up and feels more clear compared to SingleShardingSpec.
ghstack-source-id: 129603318
Test Plan: waitforbuildbot
Reviewed By: SciPioneer
Differential Revision: D27694108
fbshipit-source-id: c8764abe6a4d5fc56d023fda29b74b5af2a73b49