1. https://github.com/pytorch/pytorch/pull/164111/ adds the support of splitting BlockMask. But BlockMask actually has B=1 case that the BlockMask will be broadcast. This PR adds the support of B=1 case.
2. The original split_args_kwargs_into_chunks doesn't initialize the default specs correctly. Since we now use tree_flatten and tree_unflatten to do split, we should also use tree_map to initialize the default spec. This will actually support the case when the values are not torch.Tensor, which were only supported if users explicitly provide the shard spec.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165306
Approved by: https://github.com/H-Huang
BlockMask has batch dimension information. So PP has to split it as well just like all other tensors. All the tensors in BlockMask have the batch dimension, so we can just split it without too many issues. However, `mask_mod` requires the batch index as the input, which the value is going to be changed after the split. So we have to wrap it inside a closure to modify the batch index.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164111
Approved by: https://github.com/H-Huang
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Part of #134054.
This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So landing these 'type: ignore' for pytorch in advance of them actually being needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
Changed the API of `pipeline()` to take microbatch instead of full batch as example args.
Main purpose is to:
- make this API more atomic;
- decouple tracing frontend from runtime info like `num_chunks`.
Side effects:
- Creates opportunity for varying `num_chunks` of schedules with the same `pipe` object.
- User has to create example microbatch input.
- Chunk spec stuff are now all moved to runtime side.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128163
Approved by: https://github.com/H-Huang