Summary: This code is a bit intricate, so I refactored it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/16995

Differential Revision: D14050667

Pulled By: ifedan

fbshipit-source-id: 55452339c6518166f3d4bc9898b1fe2f28601dc4