mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/134438 Approved by: https://github.com/zou3519
56 lines
1.8 KiB
Python
56 lines
1.8 KiB
Python
# mypy: allow-untyped-defs
|
|
|
|
import torch._C._lazy
|
|
from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
|
from .closure import add_step_closure, run_step_closures
|
|
|
|
|
|
def mark_step(device: str = "", wait=False):
    """Trigger a mark step on the lazy backend, then run the step closures.

    A mark step amounts to:

    - gathering the group of 'live' lazy tensors, which is used to index into
      the compilation cache (their IR graphs are lowered/compiled when no
      cached entry exists)
    - launching execution of the compiled function
    - optionally (``wait=True``) blocking until cpu-side execution has
      finished; this does not sync the accelerator

    Args:
        device (str): forwarded unchanged as the first argument of the
            backend ``_mark_step`` call. Defaults to the empty string.
        wait: forwarded unchanged as the ``wait`` keyword of the backend call.
    """
    # TODO(whc) expand this to include backend hooks and align with XLA backend needs
    lazy_backend = torch._C._lazy
    lazy_backend._mark_step(device, [], wait=wait)

    # Step closures are run only after the backend mark step has been issued.
    run_step_closures()
|
|
|
|
|
|
def wait_device_ops(devices=None):
    """Block until every async operation on the given devices has completed.

    Args:
        devices (string..., optional): The devices whose async ops should be
            waited for. ``None`` is replaced by an empty list before the
            backend call; if empty, all the local devices will be waited for.
    """
    # Only None is normalised; any other value is passed through untouched.
    torch._C._lazy._wait_device_ops(devices=[] if devices is None else devices)
|
|
|
|
|
|
def sync_multi(tensors, devices):
    """
    Sync the given list of lazy tensors so that their IR gets lowered for the
    active backend and the compiled computation graph gets cached.

    Both arguments are handed to the backend ``_sync_multi`` call as-is, and
    its result is not returned.
    """
    lazy_backend = torch._C._lazy
    lazy_backend._sync_multi(tensors, devices)
|
|
|
|
|
|
def get_tensor_id(tensor):
    """Look up the unique id that LTC maintains for the given lazy tensor."""
    tensor_id = torch._C._lazy._get_tensor_id(tensor)
    return tensor_id
|
|
|
|
|
|
def to_cpu(tensors, devices=None):
    """Sync a pytree of tensors and return the same structure moved to CPU.

    Args:
        tensors: a pytree; it is flattened with ``tree_flatten`` and each leaf
            has ``.to("cpu")`` called on it (leaves are presumably lazy
            tensors -- confirm against callers).
        devices: passed to ``sync_multi``. Any falsy value (``None``, empty
            list, ...) is replaced by ``["lazy"]``.

    Returns:
        The result of ``tree_unflatten`` applied to the CPU leaves and the
        spec of the input, i.e. the input structure with converted leaves.
    """
    if not devices:
        devices = ["lazy"]

    leaves, tree_spec = tree_flatten(tensors)
    # Sync all leaves in one call before any of them is copied to CPU.
    sync_multi(leaves, devices)

    cpu_leaves = []
    for leaf in leaves:
        cpu_leaves.append(leaf.to("cpu"))
    return tree_unflatten(cpu_leaves, tree_spec)
|
|
|
|
|
|
def save(tensors, *args, **kwargs):
    """Convert ``tensors`` with ``to_cpu`` and pass the result to ``torch.save``.

    All extra positional and keyword arguments are forwarded to ``torch.save``
    unchanged. Nothing is returned.
    """
    cpu_tensors = to_cpu(tensors)
    torch.save(cpu_tensors, *args, **kwargs)
|