This relands the troubling part of #95009 that caused a regression.
BC: This changes the signature and semantics of `DeviceMesh::all_reduce`.
`DeviceMesh::all_reduce` now uses a functional collective under the hood, which makes it more easily traceable.
You no longer need to use `CommTensor` to get a trace.
`all_reduce` is now async-only and uses `AsyncCollectiveTensor` to ensure proper stream synchronization.
Signature change: the `async_op` param is removed, and the return type changes from `Optional[Work]` to `torch.Tensor`.
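As a rough sketch of how call sites change (the mesh construction is illustrative, assumes a 4-rank job with a process group already set up, and the import path may vary across versions):

```python
import torch
from torch.distributed._tensor import DeviceMesh  # import path may vary by version

# Illustrative only: assumes a 4-rank distributed job is already initialized.
mesh = DeviceMesh("cuda", [0, 1, 2, 3])
t = torch.ones(4, device="cuda")

# Before: work = mesh.all_reduce(t, async_op=True) returned Optional[Work].
# After: the call always returns a torch.Tensor (backed by AsyncCollectiveTensor),
# and synchronization with the collective stream happens when the result is used.
reduced = mesh.all_reduce(t)
print(reduced.sum())  # first use of the result triggers the wait
```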
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95804
Approved by: https://github.com/fegin
BC: This changes the signature and semantics of `DeviceMesh::all_reduce`.
`DeviceMesh::all_reduce` now uses a functional collective under the hood, which makes it more easily traceable.
You no longer need to use `CommTensor` to get a trace.
`all_reduce` is now async-only and uses `AsyncCollectiveTensor` to ensure proper stream synchronization.
Signature change: the `async_op` param is removed, and the return type changes from `Optional[Work]` to `torch.Tensor`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95009
Approved by: https://github.com/wanchaol
Applies the remaining flake8-comprehensions fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into just the `set` call.
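For illustration, the two kinds of rewrites involved (representative examples, not lines taken from the diff):

```python
b = [1, 2, 2, 3]

# An unnecessary generator inside a set() call becomes a set comprehension:
squares = set(x * x for x in range(10))  # before
squares = {x * x for x in range(10)}     # after

# A useless generator that merely re-iterates collapses into the bare call:
unique = set(a for a in b)  # before
unique = set(b)             # after
```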
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
This PR changes the op registration to a better mechanism: we now
require registration against the OpOverload directly instead of the
op-key string. This has several benefits (a sketch follows the list):
1. We ensure that the registration targets the correct op, which
means it fails if the registration is wrong (this PR already fixes
several op registration errors uncovered by switching to direct
OpOverload registration).
2. If an overload name gets changed or deleted, we know immediately as
soon as the source is loaded, which is safer.
3. It also keeps the op registration mechanism consistent with
other tensor subclasses within PyTorch.
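As a hedged sketch of the difference (the dispatch tables and handler below are hypothetical stand-ins, not DTensor's actual internals):

```python
import torch

def my_add_handler(*args, **kwargs):
    """Hypothetical handler standing in for a real per-op rule."""

# Before: keyed by an op-key string; a typo or a renamed overload only
# surfaces at runtime, when the lookup silently misses.
_ops_by_str = {"aten::add.Tensor": my_add_handler}

# After: keyed by the OpOverload object itself; a wrong or deleted
# overload name raises immediately when this module is imported.
_ops_by_overload = {torch.ops.aten.add.Tensor: my_add_handler}
```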
Differential Revision: [D42876250](https://our.internmc.facebook.com/intern/diff/D42876250)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90735
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
This PR refactors the threaded PG logic to enable creating multiple sub-PGs
under the world threaded PG, and to allow calling collectives together on
different sub-PGs.
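For context, a rough sketch of the usage pattern this enables, written against the public API (the threaded-backend test setup itself is omitted; a 4-rank world is assumed):

```python
import torch
import torch.distributed as dist

# Assumes a 4-rank world PG is already initialized (in tests, with the
# threaded backend). Two disjoint sub-PGs are created under the world PG.
pg_a = dist.new_group(ranks=[0, 1])
pg_b = dist.new_group(ranks=[2, 3])

# Each rank runs a collective on its own sub-PG; the two collectives
# can now proceed together.
t = torch.ones(2)
if dist.get_rank() in (0, 1):
    dist.all_reduce(t, group=pg_a)
else:
    dist.all_reduce(t, group=pg_b)
```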
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91649
Approved by: https://github.com/XilunWu
Continuation of https://github.com/pytorch/pytorch/pull/90163.
Here is the script I used to find all the non-existent arguments in the docstrings (the script can give false positives in the presence of *args/**kwargs or decorators):
_Edit:_
I've realized that the indentation of the last `break` in the script as originally posted was wrong, so it only gave output for a function if the first docstring argument was wrong; the version below carries the corrected indentation. I'll create a separate PR if the corrected script finds more issues.
```python
import ast
import os

import docstring_parser  # third-party: pip install docstring_parser

for root, dirs, files in os.walk('.'):
    for name in files:
        # Skip VCS internals and vendored code.
        if root.startswith("./.git/") or root.startswith("./third_party/"):
            continue
        if name.endswith(".py"):
            full_name = os.path.join(root, name)
            with open(full_name, "r") as source:
                tree = ast.parse(source.read())
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef):
                    # Collect every parameter name the function actually
                    # declares (copy the list so the AST is not mutated).
                    all_node_args = list(node.args.args)
                    if node.args.vararg is not None:
                        all_node_args.append(node.args.vararg)
                    if node.args.kwarg is not None:
                        all_node_args.append(node.args.kwarg)
                    all_node_args.extend(node.args.posonlyargs)
                    all_node_args.extend(node.args.kwonlyargs)
                    args = [a.arg for a in all_node_args]
                    # `or ""` guards functions without a docstring.
                    docstring = docstring_parser.parse(ast.get_docstring(node) or "")
                    doc_args = [a.arg_name for a in docstring.params]
                    # Keep only the identifier part of each documented name.
                    clean_doc_args = []
                    for a in doc_args:
                        clean_a = ""
                        for c in a.split()[0]:
                            if c.isalnum() or c == '_':
                                clean_a += c
                        if clean_a:
                            clean_doc_args.append(clean_a)
                    doc_args = clean_doc_args
                    for a in doc_args:
                        if a not in args:
                            print(full_name, node.lineno, args, doc_args)
                            # Corrected indentation per the edit note above:
                            # report once per function, on the first mismatch.
                            break
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
Observed by @aazzolini: some ops might have `Optional[Tensor]` returns
where they return None (e.g. native_layer_norm_backward). It's a mismatch
between the C++ ATen op signature and Python's None, but we need to handle
it on the Python side (a sketch follows).
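A minimal sketch of the Python-side handling; the wrapper names below are hypothetical stand-ins, not the PR's actual code:

```python
from typing import Any, Optional
import torch

def make_dtensor(t: torch.Tensor, spec: Any) -> torch.Tensor:
    """Hypothetical stand-in for the real tensor-subclass wrapping step."""
    return t

def wrap_output(out: Optional[torch.Tensor], spec: Any) -> Optional[torch.Tensor]:
    # The C++ schema may declare a return as Tensor? (i.e. Optional[Tensor]);
    # ops like native_layer_norm_backward return None for outputs that were
    # not requested, so None must be passed through rather than wrapped.
    if out is None:
        return None
    return make_dtensor(out, spec)
```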
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90241
Approved by: https://github.com/aazzolini
This PR gets rid of torchgen FunctionSchema parsing and parses the schema
manually; it should resolve the torchgen packaging issue and also
provide some perf wins when running DTensor eagerly.
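To illustrate the idea (a much-simplified sketch of manual schema parsing, not the PR's actual code):

```python
def parse_schema_min(schema: str):
    # Split "ns::name.overload(args) -> returns" into its pieces without
    # depending on torchgen's FunctionSchema.
    sig, _, returns = schema.partition("->")
    name_part, _, args = sig.partition("(")
    ns_name, _, overload = name_part.strip().partition(".")
    ns, _, name = ns_name.partition("::")
    return ns, name, overload, args.rstrip(") "), returns.strip()

print(parse_schema_min(
    "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
))
# -> ('aten', 'add', 'Tensor', 'Tensor self, Tensor other, *, Scalar alpha=1', 'Tensor')
```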
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90106
Approved by: https://github.com/awgu