pytorch/torch/_dynamo/graph_utils.py
xinan.lin e93706c2c8 [Intel GPU][pre_compile] Add XPU toolkit version and hardware info in compiled model check. (#162951)
Following #162438, this PR generalizes the original CUDA-only check and adds an XPU check.

Fixes #162939, Fixes #162938, Fixes #163032, Fixes #163045

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162951
Approved by: https://github.com/EikanWang, https://github.com/jansel
2025-09-18 00:04:22 +00:00


from collections import deque
from typing import Any, Optional

import torch
from torch.fx import Graph, map_arg, Node
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_flatten


# flattens with support for slices
# Note: a better way to do this would be to register/unregister slices as
# pytree nodes, but there is no unregister API in the pytorch pytree impl
def _get_flat_args(
    node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> list[Node]:
    args = list[Any]()
    map_arg((node.args, node.kwargs), args.append)
    if node in node_to_additional_deps:
        args.extend(node_to_additional_deps[node])
    return args


def _get_flat_args_unique(
    node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> OrderedSet[Node]:
    args = OrderedSet[Node]()
    map_arg((node.args, node.kwargs), args.add)
    if node in node_to_additional_deps:
        args.update(node_to_additional_deps[node])
    return args
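

# Illustrative usage sketch (hypothetical helper, not part of the original
# module): edges that exist only in node_to_additional_deps are merged into
# a node's flattened args, so they participate in traversal like real edges.
def _example_flat_args() -> None:
    g = Graph()
    a = g.placeholder("a")
    b = g.placeholder("b")
    c = g.call_function(torch.add, (a, b))
    assert _get_flat_args_unique(c, {}) == OrderedSet([a, b])
    # An extra dependency c -> a shows up alongside the real args
    assert _get_flat_args(c, {c: OrderedSet([a])}) == [a, b, a]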


def _detect_cycles(
    graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> str:
    current_path: deque[Node] = deque()
    current_path_set: set[Node] = set()
    pending: deque[tuple[Node, Node]] = deque()

    def add_to_current_path(node: Node) -> None:
        current_path.append(node)
        current_path_set.add(node)

    def pop_current_path() -> None:
        node = current_path.pop()
        current_path_set.remove(node)

    def current_path_head() -> Node:
        return current_path[-1]

    # Iterative DFS from each output node; a node reappearing on the
    # current root-to-node path indicates a cycle.
    for origin in graph.find_nodes(op="output"):
        current_path.clear()
        current_path_set.clear()
        add_to_current_path(origin)
        for child in _get_flat_args_unique(origin, node_to_additional_deps):
            pending.append((child, origin))

        while pending:
            cur_node, parent = pending.pop()

            # handle backtracking: unwind the path until its head is the
            # parent of the node we are about to visit
            while current_path and current_path_head() != parent:
                pop_current_path()

            if not isinstance(cur_node, Node):
                continue

            if cur_node in current_path_set:
                current_path.append(cur_node)
                return f"cycle detected in path: {current_path}"

            add_to_current_path(cur_node)
            for child in _get_flat_args_unique(cur_node, node_to_additional_deps):
                pending.append((child, cur_node))

    return "no cycle detected"


def _graph_device_type(graph: Optional[Graph]) -> str:
    if graph is None:
        return "cpu"

    def _device_type(x: Any) -> str:
        if isinstance(x, torch.device):
            return x.type
        if isinstance(x, torch.Tensor):
            return x.device.type
        return "cpu"

    def _flatten_meta(node: Node, key: str) -> list[Any]:
        if key not in node.meta:
            return []
        flat, _ = tree_flatten(node.meta[key])
        return flat

    for node in graph.nodes:
        # Check node metadata; only short-circuit on a non-CPU device so
        # that a CPU tensor early in the graph does not mask a GPU one later
        for key in ("val", "example_value"):
            for obj in _flatten_meta(node, key):
                if (device_type := _device_type(obj)) != "cpu":
                    return device_type

        # Check for device conversions
        if node.op == "call_method":
            for gpu in ["cuda", "xpu"]:
                if node.target == gpu:
                    return gpu
                if node.target == "to" and gpu in node.args:
                    return gpu

        # Check args/kwargs for non-CPU device specs
        flat_args, _ = tree_flatten((node.args, node.kwargs))
        for obj in flat_args:
            if (device_type := _device_type(obj)) != "cpu":
                return device_type

    return "cpu"