## What's the problem?
The popular `fx.node.map_arg()` and `fx.node.map_aggregate()` apply operations recursively on `dict`s, `tuples`, `list`s, etc, and return a new collection of the same type.
Unfortunately, their base input type is `Argument`, which is [very unspecific indeed](5d55a6585d/torch/fx/node.py (L48-L58)): most type information is just thrown away at the call site of either of these functions, as far as the type checker goes.
As `torch` moves to a more typed code base, this would force innocent, unsuspecting developers to add logically unnecessary casts or `# type: ignore` statements.
## What's the solution?
Making these two `node.map_*` functions generic on the first argument and return type means that type information is preserved for the type checker. (The signature of the other parameter, the function that visits the nodes and subnodes, has not changed, nor should it.)
## Won't it break everything?
It doesn't break the type checker - one place needed an extra hint.
There have been code breakages, resolved one, at least one new one... we'll see!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146248
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007