[Dynamo] Add XPU API to trace_rules (#155788)

# Motivation
- Add binding and non-binding APIs to the trace rules for XPU (see the inspection sketch below);
- Add some XPU APIs to the constant-fold functions for Dynamo capture.
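Concretely, these rules tell Dynamo how to classify each `torch.xpu` call site during capture. A quick way to inspect the classification; note that `torch._dynamo.trace_rules` is a private module, so this sketch may break across PyTorch releases:

```python
import torch
import torch._dynamo.trace_rules as trace_rules

# lookup() returns the VariableTracker class Dynamo will use for a callable.
print(trace_rules.lookup(torch.xpu.get_rng_state))   # SkipFunctionVariable
print(trace_rules.lookup(torch.xpu.current_stream))  # TorchInGraphFunctionVariable
```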

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155788
Approved by: https://github.com/jansel, https://github.com/EikanWang
ghstack dependencies: #155787
Author: Yu, Guangye
Date: 2025-06-13 14:15:23 +00:00
Committed by: PyTorch MergeBot
Parent: 69acba2b19
Commit: b51d803785
2 changed files with 59 additions and 0 deletions

torch/_dynamo/trace_rules.py

@@ -350,6 +350,8 @@ manual_torch_name_rule_map: dict[str, Any] = {
"torch.sparse_csr_tensor": SkipFunctionVariable,
"torch.sparse_compressed_tensor": SkipFunctionVariable,
"torch._C._autograd._unsafe_set_version_counter": TorchInGraphFunctionVariable,
"torch.xpu.get_rng_state": SkipFunctionVariable,
"torch.xpu.set_rng_state": SkipFunctionVariable,
# avoid skipping user defined modules in distributed unit tests
"torch/testing/_internal/common_fsdp.py#forward": UserFunctionVariable,
f"torch/testing/_internal/common_fsdp.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable,
@@ -1343,6 +1345,21 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._warn",
"torch._C._will_engine_execute_node",
"torch._C._wrap_tensor_impl",
"torch._C._xpu_emptyCache",
"torch._C._xpu_getArchFlags",
"torch._C._xpu_getCurrentStream",
"torch._C._xpu_getCurrentRawStream",
"torch._C._xpu_getDeviceCount",
"torch._C._xpu_getDevice",
"torch._C._xpu_getMemoryInfo",
"torch._C._xpu_getStreamFromExternal",
"torch._C._xpu_isInBadFork",
"torch._C._xpu_init",
"torch._C._xpu_memoryStats",
"torch._C._xpu_resetAccumulatedMemoryStats",
"torch._C._xpu_resetPeakMemoryStats",
"torch._C._xpu_setStream",
"torch._C._xpu_synchronize",
"torch._C.fork",
"torch._C.get_autocast_cpu_dtype",
"torch._C.get_autocast_dtype",
@@ -2265,6 +2282,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch.slice_inverse",
"torch._assert_scalar",
"torch._functional_assert_scalar",
"torch.xpu._get_device_properties",
],
TorchInGraphFunctionVariable,
)
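Registering these C bindings as in-graph functions lets Dynamo keep calls such as `torch._C._xpu_getCurrentRawStream` in the captured graph instead of breaking on them, mirroring how the CUDA backend treats `torch._C._cuda_getCurrentRawStream`. A hedged sketch; the signature is assumed to mirror the CUDA counterpart (taking a device index), and `x` is assumed to live on an XPU device:

```python
import torch

@torch.compile
def f(x):
    # Classified as a C-binding in-graph function, so this call can be
    # captured rather than forcing a graph break.
    stream_ptr = torch._C._xpu_getCurrentRawStream(x.device.index)
    return x + 1
```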
@@ -2880,6 +2898,43 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.tensordot",
"torch.unique_consecutive",
"torch.use_deterministic_algorithms",
"torch.xpu._get_device",
"torch.xpu._get_generator",
"torch.xpu._get_rng_state_offset",
"torch.xpu._is_compiled",
"torch.xpu._lazy_call",
"torch.xpu._lazy_init",
"torch.xpu._set_rng_state_offset",
"torch.xpu._set_stream_by_id",
"torch.xpu._utils._get_device_index",
"torch.xpu.current_device",
"torch.xpu.current_stream",
"torch.xpu.device_count",
"torch.xpu.get_arch_list",
"torch.xpu.get_device_capability",
"torch.xpu.get_device_name",
"torch.xpu.get_device_properties",
"torch.xpu.get_gencode_flags",
"torch.xpu.get_stream_from_external",
"torch.xpu.init",
"torch.xpu.is_available",
"torch.xpu.is_bf16_supported",
"torch.xpu.is_initialized",
"torch.xpu.memory.empty_cache",
"torch.xpu.memory.max_memory_allocated",
"torch.xpu.memory.max_memory_reserved",
"torch.xpu.memory.mem_get_info",
"torch.xpu.memory.memory_allocated",
"torch.xpu.memory.memory_reserved",
"torch.xpu.memory.memory_stats_as_nested_dict",
"torch.xpu.memory.memory_stats",
"torch.xpu.memory.reset_accumulated_memory_stats",
"torch.xpu.memory.reset_peak_memory_stats",
"torch.xpu.random.initial_seed",
"torch.xpu.random.seed_all",
"torch.xpu.random.seed",
"torch.xpu.set_stream",
"torch.xpu.synchronize",
],
TorchInGraphFunctionVariable,
)
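These Python-level `torch.xpu` helpers are likewise traced in-graph rather than skipped, so device queries and stream/memory utilities no longer interrupt capture. A minimal sketch, again assuming an XPU-enabled build:

```python
import torch

@torch.compile
def scale_by_device_count(x):
    # is_available() and device_count() are in-graph functions now (and
    # is_available() is also constant-folded, see the next file).
    if torch.xpu.is_available():
        return x * torch.xpu.device_count()
    return x
```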

torch/_dynamo/variables/torch.py

@@ -134,6 +134,8 @@ REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys(
constant_fold_functions_need_guards = [
torch.cuda.current_device,
torch.cuda.is_initialized,
torch.xpu.current_device,
torch.xpu.is_initialized,
]
constant_fold_functions = [
@@ -156,6 +158,8 @@ constant_fold_functions = [
torch.promote_types,
torch._C._get_privateuse1_backend_name,
torch.autograd._is_checkpoint_valid,
torch.xpu.get_device_properties,
torch.xpu.is_available,
] + constant_fold_functions_need_guards
if torch.distributed.is_available():
constant_fold_functions.extend(
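Functions in `constant_fold_functions` are evaluated once at trace time and baked into the graph as Python constants; the `constant_fold_functions_need_guards` entries additionally install guards so the graph recompiles if the folded value changes (e.g. after a device switch). A minimal sketch, assuming an XPU-enabled build:

```python
import torch

@torch.compile
def f(x):
    # Constant-folded during capture: the branch is resolved at trace time,
    # and only the taken path appears in the compiled graph.
    if torch.xpu.is_available():
        return x + 1
    return x - 1
```

Because `torch.xpu.current_device` is in the need-guards list, folding it also specializes the graph on the current device index rather than silently caching a stale value.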