61 Commits

Author SHA1 Message Date
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
94dc3253a0 [BE][Easy] enable UFMT for torch/distributed/ (#128870)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870
Approved by: https://github.com/fegin, https://github.com/wconstab
2024-06-22 18:53:28 +00:00
9c929f6ce9 Revert "[BE][Easy] enable UFMT for torch/distributed/ (#128870)"
This reverts commit a0e1e20c4157bb3e537fc784a51d7aef1e754157.

Reverted https://github.com/pytorch/pytorch/pull/128870 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128870#issuecomment-2181780356))
2024-06-21 00:38:28 +00:00
a0e1e20c41 [BE][Easy] enable UFMT for torch/distributed/ (#128870)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870
Approved by: https://github.com/fegin
ghstack dependencies: #128868, #128869
2024-06-18 21:49:08 +00:00
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
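
For illustration, a minimal sketch of the two patterns described above; the function names are hypothetical and not code touched by this PR:
```
import warnings

from typing_extensions import deprecated


@deprecated("old_fn() is deprecated, use new_fn() instead.", category=FutureWarning)
def old_fn():
    ...


def legacy_fn():
    # Fallback pattern when the decorator can't be applied: pass the category explicitly.
    warnings.warn("legacy_fn() is deprecated, use new_fn() instead.", FutureWarning, stacklevel=2)
```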

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0a8325cbad4734a563aa459ca611991.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
64670e414e [reland] Create torch.distributed._shard package. (#72141)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72141

We have many sharding components currently:
torch.distributed._sharded_tensor, torch.distributed._sharding_spec,
torch.distributed._sharded_optimizer and more coming.

As a result, we're organizing all of this under the `torch.distributed._shard`
package. For BC reasons, I'm still keeping the old packages and have them just
reference the new package.
ghstack-source-id: 148150861

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33904585

fbshipit-source-id: 057e847eb7521b536a3ee4e0f94871aacc752062
(cherry picked from commit 29a70dd7afde6083bab942081020a13278f38e52)
2022-02-02 06:58:20 +00:00
34494e6252 Back out "Create torch.distributed.shard package." (#72062)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72062

Original commit changeset: dc692b31e260

Original Phabricator Diff: D33755913 (87bbcf70f7)

Test Plan: CI

Reviewed By: pbelevich

Differential Revision: D33891115

fbshipit-source-id: 37286e03d743d8691319f07c95e9561d54f3d6d0
(cherry picked from commit 0c1b3fe00848a275d44d8c91fba91d3df6d4927f)
2022-01-31 18:29:27 +00:00
87bbcf70f7 Create torch.distributed.shard package. (#71742)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71742

We have many sharding components currently:
torch.distributed._sharded_tensor, torch.distributed._sharding_spec,
torch.distributed._sharded_optimizer and more coming.

As a result, we're organizing all of this under the `torch.distributed.shard`
package. For BC reasons, I'm still keeping the old packages and have them just
reference the new package.
ghstack-source-id: 147899768

Test Plan: waitforbuildbot

Reviewed By: fduwjj, wanchaol

Differential Revision: D33755913

fbshipit-source-id: dc692b31e2607063d55dfcb3db33ec53961d5a5b
(cherry picked from commit 5b6885f3587786217f8ce143f2329ceec618404e)
2022-01-29 00:48:06 +00:00
53b3904115 Fix memory leak in ShardedTensor. (#71445)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71445

A reference to the ShardedTensor was always added to the global map
`_sharded_tensor_map`, which never got cleaned up since the map always held a
strong reference to the ShardedTensor.

A couple of fixes for this:
1) Add to the global map only for `init_rrefs=True`, since only that codepath
requires it.
2) Store a `weakref` in the global map to avoid holding on to the
ShardedTensor forever, as sketched below.
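
A minimal sketch of the weakref pattern in fix 2); the map and helper names below are illustrative, not the exact code from this PR:
```
import weakref
from typing import Dict, Optional

_sharded_tensor_map: Dict[int, "weakref.ReferenceType"] = {}


def _register_sharded_tensor(key, sharded_tensor):
    # Hold only a weak reference so the ShardedTensor can be garbage
    # collected once all strong references outside the map are gone.
    _sharded_tensor_map[key] = weakref.ref(sharded_tensor)


def _lookup_sharded_tensor(key) -> Optional[object]:
    ref = _sharded_tensor_map.get(key)
    return ref() if ref is not None else None
```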
ghstack-source-id: 147299580

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33641013

fbshipit-source-id: c552fa3359186514445fd5715bec93f67dc2262d
(cherry picked from commit d25f1a645313dcbf8c37158d80c42c983262cec2)
2022-01-20 19:38:41 +00:00
f5b19ba683 Additional unit test for sharded linear. (#70476)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70476

1) Support a single dimension for inputs
2) Test several error cases

Partially addresses https://github.com/pytorch/pytorch/issues/65638
ghstack-source-id: 146307607

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33344357

fbshipit-source-id: 4de7a7177452951dbcce76f27441703447609e6f
(cherry picked from commit 96dfded5697e451b54f113f99b6d0da6f6af500d)
2022-01-20 01:23:44 +00:00
b56ba296b1 Support multiple input dims for sharded linear. (#70266)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70266

Addresses some of the issues mentioned in
https://github.com/pytorch/pytorch/issues/65638. The ShardedLinear implementation
only supports 2D inputs.

On the other hand `nn.Linear` supports arbitrary dimensions for inputs and
outputs. As a result, in this PR I've added support to ensure that
ShardedLinear supports arbitrary input dims as well.
ghstack-source-id: 147206607

Test Plan: waitforbuildbot

Reviewed By: wanchaol

Differential Revision: D33267630

fbshipit-source-id: 0460994c3aa33348b80547d9274206ef90cb29b6
(cherry picked from commit 7c289e1dbf491008e091ed0a49f98f2ebcfb4175)
2022-01-19 08:07:14 +00:00
2378421340 Implement torch.allclose for sharded tensor. (#70331)
Summary:
Implement torch.allclose op for sharded tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70331

Test Plan:
Automated test added.
pritamdamania87
Fixes https://github.com/pytorch/pytorch/issues/67112

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Reviewed By: pritamdamania87

Differential Revision: D33339137

Pulled By: kumpera

fbshipit-source-id: 4263e468eaa117317b190f69877bf3f8bbac5658
2022-01-07 08:37:04 -08:00
0544f975e1 [reland] Support torch.equal for ShardedTensor. (#70145)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70145

Added support for torch.equal to ShardedTensor. This is really
helpful in terms of comparing two ShardedTensors.
ghstack-source-id: 146066939

Test Plan: waitforbuildbot

Reviewed By: wanchaol

Differential Revision: D33201714

fbshipit-source-id: 56adfc36e345d512c9901c56c07759bf658c745b
2021-12-21 13:22:52 -08:00
3e8ef9a272 Add return type annotation for ShardedTensor (#69945)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69945

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D32502393

fbshipit-source-id: 7bea08762446b211d8ea028d024d2acdabe45479
2021-12-20 17:15:44 -08:00
92320dfe6e [shard] remove set device for nccl (#69946)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69946

This PR remove the implicit set_device for nccl pg according to the proposal of https://github.com/pytorch/pytorch/issues/69731
ghstack-source-id: 145847504

Test Plan: wait for ci

Reviewed By: pritamdamania87

Differential Revision: D33099095

fbshipit-source-id: 3fe9f6a0facf5ea513c267e9f32c6a7fd56cc8a2
2021-12-16 17:16:42 -08:00
b199e3c842 Provide functionality to write custom ShardedTensor ops. (#69874)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69874

We have a handful of ops supported for ShardedTensor via
``__torch_function__`` dispatch. However, we currently can't cover all torch
operators and having a way for users to extend this functionality will make
this functionality much more general.

In this PR, I've introduced a custom_sharded_op decorator which can be used to
register a custom sharded op implementation.
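
A hedged sketch of the registration pattern this enables; the decorator name comes from the description above, but the signature and registry below are assumptions rather than the exact PyTorch API:
```
from typing import Callable, Dict

_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {}


def custom_sharded_op(op: Callable):
    # Return a decorator that records `func` as the ShardedTensor
    # implementation to dispatch to for `op` via __torch_function__.
    def decorator(func: Callable) -> Callable:
        _CUSTOM_SHARDED_OPS[op] = func
        return func

    return decorator


# Hypothetical registration:
# @custom_sharded_op(torch.nn.functional.linear)
# def my_sharded_linear(types, args=(), kwargs=None, process_group=None):
#     ...
```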
ghstack-source-id: 145841141

Test Plan: waitforbuildbot

Reviewed By: wanchaol

Differential Revision: D33078587

fbshipit-source-id: 5936b7ac25582e613653c19afa559219719ee54b
2021-12-16 12:40:13 -08:00
dc18048dd8 [PT-D][Fix] Broken sharded embedding and embedding bag test fix (#69725)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69725

We added a `no_grad` context manager in the tensor sharding to ensure that the local_shard is the root node. But it turns out that for embedding and embedding_bag, when `max_norm` is specified, row-wise sharding complains, since we use the original `max_norm` of the operators.

Error traces:
```
  File "/data/sandcastle/boxes/fbsource/fbcode/buck-out/dev/gen/caffe2/test/distributed/_sharded_tensor/sharded_embedding#binary,link-tree/torch/overrides.py", line 1389, in handle_torch_function
    result = torch_func_method(public_api, types, args, kwargs)
  File "/data/sandcastle/boxes/fbsource/fbcode/buck-out/dev/gen/caffe2/test/distributed/_sharded_tensor/sharded_embedding#binary,link-tree/torch/distributed/_sharded_tensor/api.py", line 554, in __torch_function__
    return sharded_embedding(types, args, kwargs, self._process_group)
  File "/data/sandcastle/boxes/fbsource/fbcode/buck-out/dev/gen/caffe2/test/distributed/_sharded_tensor/sharded_embedding#binary,link-tree/torch/distributed/_sharded_tensor/ops/embedding.py", line 115, in sharded_embedding
    return _handle_row_wise_sharding(
  File "/data/sandcastle/boxes/fbsource/fbcode/buck-out/dev/gen/caffe2/test/distributed/_sharded_tensor/sharded_embedding#binary,link-tree/torch/distributed/_sharded_tensor/ops/embedding.py", line 309, in _handle_row_wise_sharding
    gathered_input_embeddings = torch.nn.functional.embedding(
  File "/data/sandcastle/boxes/fbsource/fbcode/buck-out/dev/gen/caffe2/test/distributed/_sharded_tensor/sharded_embedding#binary,link-tree/torch/nn/functional.py", line 2153, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: A view was created in no_grad mode and its base or another view of its base has been modified inplace with grad mode enabled. Given that this use case is ambiguous and error-prone, it is forbidden. You can clarify your code by moving both the view and the inplace either both inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want the inplace to be tracked).
 exiting process 2 with exit code: 10
```

As a fix, we clone and detach the local shard from the narrow result without using the context manager.
ghstack-source-id: 145773748

Test Plan: CI + Unit test.

Reviewed By: pritamdamania87, wanchaol

Differential Revision: D33000927

fbshipit-source-id: 4d5a93120675e90d4d6d6225a51c4a481d18d159
2021-12-15 17:53:49 -08:00
a406a427ae Revert D33004315: Support torch.equal for ShardedTensor.
Test Plan: revert-hammer

Differential Revision:
D33004315 (1c4c81622c)

Original commit changeset: 786fe26baf82

Original Phabricator Diff: D33004315 (1c4c81622c)

fbshipit-source-id: e1dda70fea656834fdf0f2a9f874415f7b460c6e
2021-12-15 14:14:06 -08:00
1c4c81622c Support torch.equal for ShardedTensor. (#69734)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69734

Added support for `torch.equal` to ShardedTensor. This is really
helpful in terms of comparing two ShardedTensors.

Will implement `allclose` in a follow-up PR.
ghstack-source-id: 145301451

Test Plan: waitforbuildbot

Reviewed By: fduwjj, wanchaol

Differential Revision: D33004315

fbshipit-source-id: 786fe26baf82e1bb4fecfdbfc9ad4b64e704877f
2021-12-15 13:07:36 -08:00
800a457b6f [shard] add ShardedOptimizer (#68607)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68607

This PR adds ShardedOptimizer and an API to get module parameters along with ShardedTensor params; it allows users to use this optimizer wrapper to construct an optimizer that involves ShardedTensors.

The state_dict support will be a follow-up diff.
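
A hedged usage sketch based on the description above; the import path and the `named_params_with_sharded_tensor` helper name are assumptions, not taken from this diff:
```
import torch
import torch.nn as nn
from torch.distributed._sharded_optim import (  # import path is an assumption
    ShardedOptimizer,
    named_params_with_sharded_tensor,
)

# Stand-in for a module whose parameters have been sharded into ShardedTensors.
module = nn.Linear(8, 8)

# Collect regular and ShardedTensor params, then wrap a normal optimizer class.
named_params = dict(named_params_with_sharded_tensor(module))
optim = ShardedOptimizer(named_params, torch.optim.SGD, lr=0.1)
optim.zero_grad()
# ... forward/backward ...
optim.step()
```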
ghstack-source-id: 145532834

Test Plan: python test_sharded_optim.py

Reviewed By: pritamdamania87

Differential Revision: D32539994

fbshipit-source-id: a3313c6870d1f1817fc3e08dc2fc27dc43bef743
2021-12-14 12:15:20 -08:00
603a1de871 Fix inefficient recursive update in ShardedTensor.state_dict hook (#68806)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/68805

The bug is described in the linked issue. This PR is an attempt to make the functions `_recurse_update_dict` and `_recurse_update_module` more efficient in how they iterate over the submodules. The previous implementation was suboptimal, as it recursively called the update method on the submodules returned by `module.named_modules()`, while `module.named_modules()` already returned all submodules including nested ones.
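
A minimal sketch of the shape of the fix, with hypothetical helper names: since `module.named_modules()` already yields every nested submodule, one flat loop replaces the per-submodule recursion:
```
import torch.nn as nn


def _update_module(module: nn.Module, prefix: str = ""):
    # Single pass over all submodules (named_modules() is already recursive),
    # instead of recursing into _update_module for each submodule again.
    for name, submodule in module.named_modules():
        _update_single(submodule, f"{prefix}{name}")


def _update_single(submodule: nn.Module, fqn: str):
    # Placeholder for the per-module work done by the real state_dict hook.
    pass
```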

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68806

Reviewed By: pritamdamania87

Differential Revision: D33053940

Pulled By: wanchaol

fbshipit-source-id: 3e72822f65a641939fec40daef29c806af725df6
2021-12-13 19:22:55 -08:00
5374d5d8c9 [shard] fix with_comms wrapper (#69493)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69493

When we added support for calling the `with_comms` decorator with arguments, we introduced a `with_comms_decorator` inner function, so `with_comms` (without parentheses) referred to a function object and the parentheses became mandatory in test cases.

This PR fixes the `with_comms` wrapper behavior to allow specifying it both with and without arguments in test cases:
```
@with_comms
def test_case():
    ...
```
or
```
@with_comms(backend="gloo")
def test_case():
    ...
```
ghstack-source-id: 145327066

Test Plan: test_sharded_tensor

Reviewed By: pritamdamania87

Differential Revision: D32897555

fbshipit-source-id: 2f3504630df4f6ad1ea73b8084fb781f21604110
2021-12-10 10:25:54 -08:00
65b0f389d2 [PyTorch][Distributed] Use autograd-enabled collectives for the sharded linear op to enable backward grad calculation (#68096)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68096

We replace all c10d APIs with autograd-enabled collectives in the sharded linear op, so that we can enable backward propagation (grad calculation for sharded linear).
ghstack-source-id: 144882914

Test Plan: Unit test + CI

Reviewed By: pritamdamania87

Differential Revision: D32177341

fbshipit-source-id: 1919e8ca877bdc79f4cdb0dc2a82ddaf6881b9f1
2021-12-06 15:17:08 -08:00
c60232d89a [shard] add back init_from_local_shards_and_global_metadata API (#69226)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69226

This adds back the previous init_from_local_shards API, renamed to init_from_local_shards_and_global_metadata. It's a partial revert of D32147888 (35712a8eb4). We now provide two APIs:
1. `init_from_local_shards`: users don't need to provide global metadata; we do an all_gather under the hood to construct it.
2. `init_from_local_shards_and_global_metadata`: users need to explicitly construct a ShardedTensorMetadata to use this API and must ensure correctness on all ranks, as there's no cross-rank communication/validation.

Both APIs stay private until they stabilize and the UX is proven. The second one can only be called on the `ShardedTensor` class directly; it is not included as a package API for now.

Test Plan:
test_init_from_local_shards_and_global_metadata
test_init_from_local_shards_and_global_metadata_invalid_shards

Reviewed By: dstaay-fb, pritamdamania87

Differential Revision: D32746882

fbshipit-source-id: bafd26ce16c02e2095907f9e59984a5d775c7df5
2021-12-02 01:02:56 -08:00
84047ff342 Add API usage logging to ShardedTensor and fix a few tests. (#68771)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68771

ghstack-source-id: 143974518

Test Plan: waitforbuildbot

Reviewed By: fduwjj, wanchaol

Differential Revision: D32601562

fbshipit-source-id: ed624137efab94fbe556609bb40cca14e69d9bac
2021-11-23 13:30:59 -08:00
d6a68e0b8d [PyTorch][3/N] Enable the rest forward spec options for ShardedEmbedding and ShardedEmbeddingBag. (#67799)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67799

We have enabled the sharded embedding and embedding bag in https://github.com/pytorch/pytorch/pull/67188 and https://github.com/pytorch/pytorch/pull/66604. We now want to enable as many parameters as defined in the docs as possible: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding_bag.html, https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html.

For the ones that we don't support, we just throw an exception.

Last but not least, we use `.get()` to retrieve params instead of indexing with the key directly.
ghstack-source-id: 143987066

Test Plan: Unit test & CI

Reviewed By: pritamdamania87

Differential Revision: D31985333

fbshipit-source-id: 3794241b81eecc815bc4390679d0bb0323f4ae72
2021-11-22 20:33:03 -08:00
c41d8290b3 Rename shard_lengths to shard_sizes to be more inline with Tensor sizes. (#66464)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66464

Dimension sizes are referred to as `size` in general in PyTorch and
hence rename shard_lengths to shard_sizes.

#Closes: https://github.com/pytorch/pytorch/issues/65794
ghstack-source-id: 143866449

Test Plan: waitforbuildbot

Reviewed By: fduwjj, wanchaol

Differential Revision: D31564153

fbshipit-source-id: 6273426c4b0e079358806070d0d9644740adb257
2021-11-19 16:30:00 -08:00
35712a8eb4 [reland] simplify init_from_local_shards API (#68021)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68021

reland PR of https://github.com/pytorch/pytorch/pull/64481, as the previous one had some internal failures that weren't caught when it first landed.

This simplifies the `init_from_local_shards` API in sharded tensor to only require users to pass in a list of `Shard`s and an `overall_size`, instead of a ShardedTensorMetadata. We do an all_gather inside to form a valid ShardedTensorMetadata instead.

TODO: add more test cases to improve coverage.
ghstack-source-id: 143661119

Test Plan: TestShardedTensorFromLocalShards

Reviewed By: pritamdamania87

Differential Revision: D32147888

fbshipit-source-id: 897128b75224f4b9644471a04a64079f51e0d5fe
2021-11-17 23:20:37 -08:00
2766662ca9 [PyTorch][2/N] Basic implementation of ShardedEmbeddingBag using ShardedTensor. (#67188)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67188

This diff/PR is trying to implement the ShardedEmbeddingBag using the ShardedTensor.

We support both row-wise and column-wise sharding of the embedding bag. The detailed logic can be found in the comment.

Several caveats:
1. Only the sharding of one weight is supported now.
2. We support limited input params for the op; support for more params is on the way.
3. We only support chunk sharding for now.
4. We only support a single local shard per rank for now.

Some other changes include:
1. Refactor the ShardedEmbedding code so that the common logic can be reused.
2. Fix tiny typos and corner cases in the `get_chunked_dim_size` API, where it would return -1 if we set dim_size = 5, split_size = 2, idx = 3. (This is a valid case because when chunks = 4 and dim_size = 5, the split_size is 2.)
ghstack-source-id: 142325915

Test Plan: Unit test and CI

Reviewed By: pritamdamania87

Differential Revision: D31749458

fbshipit-source-id: ed77e05e4ec94ef1a01b1feda8bbf32dc5d5da1b
2021-11-03 17:39:18 -07:00
ba74b03b0d Back out "[sharded_tensor] simplify init_from_local_shards API"
Summary: Original commit changeset: 6e97d95ffafd

Test Plan: unit test

Reviewed By: wanchaol

Differential Revision: D32023341

fbshipit-source-id: 2a9f7b637c0ff18700bcc3e44466fffcff861698
2021-10-29 14:01:07 -07:00
fc664ac272 [sharded_tensor] easier initialization for Shard (#66351)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66351

This adds the ability for users to just provide shard_offsets and optionally a rank to construct a local shard, instead of having to build the shard metadata themselves. Under the hood, we construct the metadata by inferring shard_lengths and device from the local tensor, as illustrated below.
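
A minimal illustration of the inference described above, using a standalone helper rather than the actual Shard API (field names are approximate):
```
import torch


def infer_shard_metadata(local_tensor, shard_offsets, rank=0):
    # Only the offsets (and optionally the rank) come from the caller; the
    # shard lengths and placement device are inferred from the local tensor.
    return {
        "shard_offsets": list(shard_offsets),
        "shard_lengths": list(local_tensor.size()),
        "placement": f"rank:{rank}/{local_tensor.device}",
    }


meta = infer_shard_metadata(torch.rand(5, 10), shard_offsets=[0, 0], rank=0)
```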
ghstack-source-id: 141742410

Test Plan: test_local_shards

Reviewed By: pritamdamania87

Differential Revision: D31519919

fbshipit-source-id: 8f3b4682ffc74b79b41076f3f4b832f4cacda49d
2021-10-27 22:20:37 -07:00
71a67d0ce9 [sharded_tensor] simplify init_from_local_shards API (#64481)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64481

This simplifies the `init_from_local_shards` API in sharded tensor to only require users to pass in a list of `Shard`s and an `overall_size`, instead of a ShardedTensorMetadata. We do an all_gather inside to form a valid ShardedTensorMetadata instead.

TODO: add more test cases to improve coverage.
ghstack-source-id: 141742350

Test Plan: TestShardedTensorFromLocalShards

Reviewed By: pritamdamania87

Differential Revision: D30748504

fbshipit-source-id: 6e97d95ffafde6b5f3970e2c2ba33b76cabd8d8a
2021-10-27 22:19:20 -07:00
31bcfa3760 [sharded_tensor] refactor sharded_tensor file structure (#67199)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67199

This PR refactors the _sharded_tensor package so that it is split out of api.py, adding different components to make it more modular. This also helps us resolve circular dependencies caused by the increasing code size and better organize the package:

* api.py: sharded tensor APIs
* metadata.py: Metadata definition for ShardedTensors
* shard.py: Shard definition for ShardedTensor
* utils.py: utility functions for validation, etc.
ghstack-source-id: 141533618

Test Plan: test_sharded_tensor.py

Reviewed By: pritamdamania87

Differential Revision: D31904249

fbshipit-source-id: c747d96e131a1d4731991ec4ac090f639dcb369b
2021-10-26 00:36:23 -07:00
3596e13d45 Add torch.nn.init.normal_ and torch.nn.init.kaiming_uniform_ ops to ShardedTensor (#67057)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67057

Extend ShardedTensor with the torch.nn.init.[normal_, kaiming_uniform_] ops
Follow up from https://github.com/pytorch/pytorch/pull/63997

Test Plan:
a) Unit Test
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_init.py TestShardedTensorNNInit --v

or b) Manual run: Instruction here: https://docs.google.com/document/d/1_m1Hdo5w51-hhPlZ_F8Y6PIWrN7UgJZqiSpARYvhsaE/edit#
s/uniform_/normal_ or kaiming_uniform_

Imported from OSS

Reviewed By: pritamdamania87

Differential Revision: D31845654

fbshipit-source-id: e7aedc0972539da59f7b84bbbf617caf6b206d52
2021-10-25 19:14:30 -07:00
b6df043f1f Add torch.nn.init.uniform_ operator to ShardedTensor. (#63997)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63997

Use `__torch_function__` to extend torch.nn.init.uniform_.
The init is done in SPMD fashion. Note that ideally we would aggregate sharded tensors into a global tensor, init it, and reshard. It's fine to run it SPMD since uniform init is i.i.d. (independent and identically distributed).
Also enable the unit test in test_linear.py for OSS testing.

Test Plan:
a) Unit Test
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_init.py TestShardedTensorNNInit --v
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_linear.py --v (before runs this command is no-op)

or b) Manual run: Instruction here: https://docs.google.com/document/d/1_m1Hdo5w51-hhPlZ_F8Y6PIWrN7UgJZqiSpARYvhsaE/edit#

Imported from OSS

Reviewed By: pritamdamania87, anjali411

Differential Revision: D30563017

fbshipit-source-id: d1859f7682235bcb44515efc69ca92bc5e34fce1
2021-10-21 00:17:13 -07:00
08cb31a03e [PyTorch][1/N] Basic implementation of ShardedEmbedding using ShardedTensor. (#66604)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66604

This diff/PR implements ShardedEmbedding using ShardedTensor.

Several caveats:
1. We support limited input params for the op; support for more params is on the way.
2. We only support chunk sharding for now.
3. We only support a single local shard per rank for now.

ghstack-source-id: 141056130

Test Plan: Unit test and CI

Reviewed By: pritamdamania87

Differential Revision: D31544556

fbshipit-source-id: cc867dcba8c11e6f4c7c3722488908f5108cc67f
2021-10-20 15:16:49 -07:00
14ee608791 [PyTorch] Make rearrangement in sharded linear work as expected. (#66603)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66603

Found the issue here: https://github.com/pytorch/pytorch/issues/66281 by making the test cases more complicated.

On closely reading the code again, it turns out my original understanding was also wrong. Let's use the example mentioned in the issue to explain:

If the placement is like:
```
"rank:3/cuda:3",
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
```

First, we split the column or row by the order of [3, 0, 1, 2].

In the case of column-wise sharding:
We need to rearrange the results gathered from all ranks.
Step 1: we split the output based on the original sharding strategy, i.e., rank3 gets the 1st shard, rank0 gets the 2nd shard, etc.
Step 2: we need to rearrange the per-rank results by ordering them following [3, 0, 1, 2], i.e., the result from rank3 needs to be put in the front, and so forth.
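
A minimal sketch of the column-wise reordering in Step 2 above, assuming the placement order [3, 0, 1, 2]; the tensors are stand-ins for the per-rank results:
```
import torch

placement_ranks = [3, 0, 1, 2]  # rank that owns the 1st, 2nd, ... shard
# Stand-in per-rank results, indexed by rank.
results_by_rank = [torch.full((2, 2), float(r)) for r in range(4)]

# Concatenate following the placement order: rank 3's result goes first,
# then rank 0, rank 1, and rank 2.
reordered = torch.cat([results_by_rank[r] for r in placement_ranks], dim=-1)
```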

In the case of row-wise sharding:
We need to rearrange the input being sent to each rank.
Step 1: we reorder the input following the map [3, 0, 1, 2]. For example, the first shard goes to rank 3 so we need to put it in the 3rd part; the second shard goes to rank 0, so we put it in the 2nd part, and so on.
Step 2: the size of the sharding for each rank is decided by the original placement [3, 0, 1, 2], i.e., rank 3 gets the first shard and its size, etc.

Update the unit test to reflect this change.

Also, correct some formatting and comments in the sharded linear.
ghstack-source-id: 141055689

Test Plan: unit test and wait for CI.

Reviewed By: pritamdamania87, bowangbj

Differential Revision: D31634590

fbshipit-source-id: 677a9c2b42da1e2c63220523ed2c004565bbecc7
2021-10-19 23:16:38 -07:00
bc6935ddf5 [PyTorch][Distributed][Easy] Make ShardedTensor.size() equivalent to torch.Tensor.size() (#65087) (#66012)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66012

Test Plan: Imported from OSS

Reviewed By: pritamdamania87

Differential Revision: D31345161

Pulled By: fduwjj

fbshipit-source-id: 10d6b65780ab0c6934babcc7c36a181cb66f0b7c
2021-10-12 22:26:22 -07:00
3cc40253d9 add gather to ShardedTensor (#65671)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65671

Tentative implementation that uses dist.gather_object to collect shards from all ranks and then "merge" them. The merge is done on dst_rank through padding the sharded tensors to the size of the full tensor based on their metadata (offsets, lengths) first, and then summing these padded tensors together.
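
A minimal sketch of the padding-and-summing merge described above, for a 1-D tensor sharded along dim 0; offsets and values are illustrative only:
```
import torch

full_size = 8
# (offset, local shard) pairs as they might arrive via dist.gather_object.
gathered = [(0, torch.tensor([1.0, 2.0, 3.0])), (3, torch.tensor([4.0, 5.0, 6.0, 7.0, 8.0]))]

merged = torch.zeros(full_size)
for offset, shard in gathered:
    padded = torch.zeros(full_size)
    padded[offset : offset + shard.numel()] = shard
    merged += padded  # shards don't overlap, so summing reconstructs the full tensor
```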

Also considered concatenating sharded tensors without padding to minimize memory footprint (assuming padding will increase memory). But it may not be flexible enough for arbitrary sharding (e.g., sharding along multiple dimensions).

Another way would be constructing the padded tensor on each rank and reducing to rank0. I feel this is the easiest implementation, but it incurs higher memory usage and communication payload. Please let me know if this alternative is preferred.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang gcramer23

Test Plan:
Imported from OSS

  python test/distributed/_sharded_tensor/test_sharded_tensor.py -v -k test_gather

did not manage to test on OSS, but tested in fbcode by reserving an on-demand GPU

  arc patch D31197611

modify the test with 2 gpus as on-demand gpu only has 2 cores (D31227986)

   buck test -c fbcode.enable_gpu_sections=true mode/dev-nosan caffe2/test/distributed/_sharded_tensor:sharded_tensor -- test_gather

   buck-out/gen/caffe2/test/distributed/_sharded_tensor/sharded_tensor#binary.par  test_sharded_tensor.TestShardedTensorChunked.test_gather


Reviewed By: dagitses, pritamdamania87

Differential Revision: D31197611

Pulled By: dracifer

fbshipit-source-id: cf98b4a2d7838b11b9582eb23f826bb0fa38a7f4
2021-10-07 13:01:12 -07:00
768cfaa8f8 fix typo in _sharded_tensor (#65511)
Summary:
per title

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang gcramer23

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65511

Reviewed By: albanD

Differential Revision: D31239269

Pulled By: cbalioglu

fbshipit-source-id: 602c0bf7ef96a930606d68b15a5b3cadda9d9437
2021-09-29 18:00:47 -07:00
0dc98728bc Basic implementation of ShardedLinear using ShardedTensor. (#64128)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64128

This PR implements a sharded nn.Linear layer using ShardedTensors with
the following limitations:

1) Works only for ChunkShardingSpec.
2) Implementation is only aimed to demonstrate functionality and is most likely
not performant at all.

The PR also introduces a `shard_parameter` API to easily shard parameters of
`nn.Modules`. This also has the following limitations:

1) Works only for ChunkShardingSpec.
2) Is not performant since it uses broadcast instead of scatter, because
ProcessGroupNCCL doesn't yet support scatter.

Overall user API for running a sharded linear would be something like this:

```
# SPMD programming paradigm running same code on all nodes.
fc = nn.Linear(10, 10)

# Setup sharding.
sharding_spec=ChunkShardingSpec(...)
shard_parameter(fc, 'weight', sharding_spec, src_rank=0)

# Run as a normal linear layer.
inp = torch.rand(10, 10)
output = fc(inp)
```
ghstack-source-id: 138500985

Test Plan:
1) unit tests.
2) waitforbuildbot

Reviewed By: wanchaol, bowangbj

Differential Revision: D30621215

fbshipit-source-id: 1aa7478568c18a4572f6c3462fdf24a4cbde01d6
2021-09-20 18:31:11 -07:00
473e55d5b2 Use classmethods for overrides (#64841)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64841

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D30991424

Pulled By: albanD

fbshipit-source-id: 551e2119768f3a4292713f3bfa83930f5506adbd
2021-09-17 08:32:49 -07:00
a87808de93 Fix bug in ShardedTensorMetadata serde. (#63902)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63902

The 'memory_format' field was not being serialized correctly and used
the same encoding for different fields.
ghstack-source-id: 137142406

Test Plan: waitforbuildbot

Reviewed By: bowangbj

Differential Revision: D30527324

fbshipit-source-id: f0f223e2d660ef6e4abae9649d9992acc36e1278
2021-08-31 20:31:14 -07:00
49353e319c More sharded_tensor creation ops: sharded_tensor.zeros, sharded_tensor.full, sharded_tensor.rand (#63732)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63732

Test Plan:
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py  --v

$ python test/distributed/_sharded_tensor/test_sharded_tensor.py TestCreateTensorFromParams --v
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py TestShardedTensorChunked --v

Imported from OSS

Differential Revision: D30472621

Reviewed By: pritamdamania87

Pulled By: bowangbj

fbshipit-source-id: fd8ebf9b815fdc292ad1aad521f9f4f454163d0e
2021-08-26 16:01:38 -07:00
835dac0869 Merge common fields from TensorInitParams and ShardedTensorMetadata into TensorProperties (#63731)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63731
1) Follow up [PR/63378 last comment](https://github.com/pytorch/pytorch/pull/63378#discussion_r693143053)
2) Also updated the caller side (usage of ShardedTensorMetadata) in fbcode

Ref: [landing workflow 3](https://www.internalfb.com/intern/wiki/PyTorch/PyTorchDev/Workflow/Landing/#landing-your-prs-from-gi-1)

Test Plan:
Imported from OSS

OSS: (pytorch).. $ python test/distributed/_sharded_tensor/test_sharded_tensor.py --v
FB:  fbcode $ buck test mode/dev //aiplatform/modelstore/checkpointing/pyper/tests:checkpoint_utils_test

Reviewed By: wanchaol, heitorschueroff

Differential Revision: D30472281

fbshipit-source-id: 727fb0e7f10eab4eb7a10476194e9008f2ac1fb5
2021-08-24 11:49:06 -07:00
8871ff29b7 [sharded_tensor] add readonly tensor properties (#63679)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63679

This PR adds read-only tensor properties to sharded tensor, to match torch.Tensor behavior.

Test Plan: test_sharded_tensor_metadata

Reviewed By: pritamdamania87

Differential Revision: D30459343

fbshipit-source-id: 9aec8ecfe76479eed25f3b843495e5719ed2956d
2021-08-20 22:17:11 -07:00
3ee1f81dce Extend _sharded_tensor constructor to support other ops like torch.ones (#63378)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63378

a) Introduce InitCommonParams to wrap tensor creation params
b) Factor local tensor initialization into common_params so that the tensor value is not hard-coded in the ShardedTensor constructor
c) Add _sharded_tensor.ones(...) to exemplify; note the memory_format arg is not provided, to be consistent with torch.ones (see the sketch below)
d) Follow-up: more ops like torch.full, torch.zeros, torch.rand
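
A hedged usage sketch of (c); the import paths and the exact `ones` signature are assumptions based on the tests listed below, and running it requires an initialized process group:
```
import torch.distributed._sharded_tensor as sharded_tensor
from torch.distributed._sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)
# Global size (10, 20), filled with ones; as with torch.ones, no memory_format
# argument is exposed.
st = sharded_tensor.ones(spec, 10, 20)
```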

Test:
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py TestCreateTensorFromParams --v
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py TestShardedTensorChunked.test_create_sharded_tensor_with_ones --v
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py TestShardedTensorEnumerable.test_create_sharded_tensor_with_ones --v

Test Plan: Imported from OSS

Reviewed By: pritamdamania87, wanchaol

Differential Revision: D30359245

Pulled By: bowangbj

fbshipit-source-id: 85768fcb36e9d9d40213036884b1266930a91701
2021-08-20 17:11:34 -07:00
b8e6144e0a Add a _RemoteDevice structure for ShardedTensor/ShardingSpec. (#62927)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62927

As part of the ShardedTensor work, we realized we do need some sort of
_RemoteDevice structure that deals with our format of "workername/device" so
that users don't have to worry about parsing this string directly.

Right now this structure is just the bare minimum and is mostly a container for
describing a remote device. It is currently only used in ShardedTensor,
ShardingSpec and RemoteModule.

Once we actually have a consolidated remote device proposal, this class can be
extended appropriately if needed.
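
A minimal illustration of the "workername/device" parsing this structure encapsulates; the helper below is a stand-in, not the actual _RemoteDevice implementation:
```
import torch


def parse_remote_device(remote_device: str):
    # Split e.g. "trainer0/cuda:0" (or "rank:0/cpu") into its two halves.
    worker, sep, device_str = remote_device.partition("/")
    return worker, torch.device(device_str if sep else "cpu")


worker, device = parse_remote_device("trainer0/cuda:0")
```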
ghstack-source-id: 135534086

Test Plan:
1) unit tests
2) waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D30170689

fbshipit-source-id: 1ac2e81c7a597dc40bf3fbf2c1168c382c66649f
2021-08-11 11:27:32 -07:00