Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

Step 1: uncomment lines in the pyrefly.toml file
Step 2: run pyrefly check
Step 3: add suppressions, clean up unused suppressions

Before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d
After: 0 errors (1,970 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588
Approved by: https://github.com/oulgen
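For reference, the suppressions take the form of inline comments; the one added in this file reads:

    # pyrefly: ignore # unsupported-operation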
98 lines · 3.3 KiB · Python
from typing import Union

import torch
import torch.utils._pytree as pytree
from torch.export.exported_program import ExportedProgram


__all__ = ["move_to_device_pass"]


def move_to_device_pass(
    ep: ExportedProgram, location: Union[torch.device, str, dict[str, str]]
) -> ExportedProgram:
    """
    Move the exported program to the given device.

    Args:
        ep (ExportedProgram): The exported program to move.
        location (Union[torch.device, str, dict[str, str]]): The device to move
            the exported program to. If a string, it is interpreted as a device
            name. If a dict, it is interpreted as a mapping from the existing
            device to the intended one.

    Returns:
        ExportedProgram: The moved exported program.
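
    Example (illustrative, added for clarity)::

        ep = move_to_device_pass(ep, "cuda")
        ep = move_to_device_pass(ep, {"cuda:0": "cpu"})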
    """

    def _get_new_device(
        curr_device: torch.device,
        location: Union[torch.device, str, dict[str, str]],
    ) -> str:
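        # Note (added for clarity): a dict maps existing device strings to new
        # ones, e.g. {"cuda:0": "cpu"}; devices not present in the dict are
        # left unchanged. A plain device or string applies to every tensor.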
        if isinstance(location, dict):
            if str(curr_device) in location.keys():
                return location[str(curr_device)]
            else:
                return str(curr_device)
        else:
            return str(location)

    # move all the state_dict
    for k, v in ep.state_dict.items():
        if isinstance(v, torch.nn.Parameter):
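            # Tensor.to() returns a plain tensor, so rewrap it as a Parameter
            # to preserve requires_grad (comment added for clarity)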
            ep._state_dict[k] = torch.nn.Parameter(
                v.to(_get_new_device(v.device, location)),
                v.requires_grad,
            )
        else:
            ep._state_dict[k] = v.to(_get_new_device(v.device, location))

    # move all the constants
    for k, v in ep.constants.items():
        if isinstance(v, torch.Tensor):
            ep._constants[k] = v.to(_get_new_device(v.device, location))

    # move example_inputs if they exist
    if ep.example_inputs is not None:
        args, kwargs = ep.example_inputs
        moved_args = pytree.tree_map_only(
            torch.Tensor,
            lambda tensor: tensor.to(_get_new_device(tensor.device, location)),
            args,
        )
        moved_kwargs = pytree.tree_map_only(
            torch.Tensor,
            lambda tensor: tensor.to(_get_new_device(tensor.device, location)),
            kwargs,
        )
        ep._example_inputs = (moved_args, moved_kwargs)

    for m in ep.graph_module.modules():
        if isinstance(m, torch.fx.GraphModule):
            for node in m.graph.nodes:
                # move all the nodes kwargs with burnt-in device
                if "device" in node.kwargs:
                    kwargs = node.kwargs.copy()
                    kwargs["device"] = _get_new_device(kwargs["device"], location)
                    node.kwargs = kwargs

                if (
                    node.op == "call_function"
                    and node.target == torch.ops.aten.to.device
                ):
                    args = list(node.args)
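                    # for aten.to.device, args are (self, device, ...); the
                    # device argument sits at index 1 (comment added for clarity)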
                    # pyrefly: ignore # unsupported-operation
                    args[1] = _get_new_device(args[1], location)
                    node.args = tuple(args)

                # move all the tensor metadata
                node.meta["val"] = pytree.tree_map(
                    lambda v: v.to(_get_new_device(v.device, location))
                    if isinstance(v, torch.Tensor)
                    else v,
                    node.meta.get("val"),
                )

    ep.validate()
    return ep
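

# Usage sketch (illustrative, not part of the original module; ``MyModel`` is a
# hypothetical nn.Module):
#
#     ep = torch.export.export(MyModel(), (torch.randn(2, 2),))
#     ep = move_to_device_pass(ep, "cuda")  # move everything to CUDA
#     out = ep.module()(torch.randn(2, 2, device="cuda"))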