mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-27 17:54:55 +08:00
Update base for Update on "[inductor] do comm compute overlap at aten fx level"
This is first part of the stack that does comm/compute reordering, and then uses the exposure analysis to do bucketing. Subsequent prs will handle: - use of exposure analysis to do bucketing - make sure inductor respects comm/compute overlapping done at fx level - non-profiling mm estimation/rank broadcasting of profile results Other mis: - Validate accuracy of nccl estimations ( use ruisi's profiling instead ?) For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives. fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
This commit is contained in: