deepcompile: Create a full list of no-copy ops (#7562)

The list of torch no-copy ops is hard-coded and does not include all
operations that may alias inputs in their outputs.

Instead of using a fixed list, iterate over all ops under torch.ops.aten
and identify those with aliasing behavior by inspecting their schemas.
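
For reference, the aliasing information is exposed through each overload's
schema: a Tensor return whose alias_info is non-None aliases one of the
inputs. A minimal standalone sketch of that check (not part of this patch;
the two op names are only illustrative examples):

    import torch

    # view returns an alias of its input buffer; add allocates a fresh tensor.
    for op in (torch.ops.aten.view.default, torch.ops.aten.add.Tensor):
        ret = op._schema.returns[0]
        aliases = isinstance(ret.type, torch.TensorType) and ret.alias_info is not None
        print(f"{op}: aliases an input = {aliases}")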

With PyTorch 2.7.1, the default overloads of the ops identified by the
updated logic include:

  - _nested_view_from_buffer
  - _reshape_alias
  - alias
  - as_strided
  - conj
  - detach
  - diagonal
  - expand
  - imag
  - lift_fresh
  - narrow
  - permute
  - pin_memory
  - positive
  - real
  - reshape
  - squeeze
  - t
  - unfold
  - unsqueeze
  - view
  - view_as_complex
  - view_as_real
  - most operations whose name ends with an underscore

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
Author: Junjie Mao
Date: 2025-09-17 00:05:11 +08:00 (committed by GitHub)
Commit: 2d84be8159 (parent: e9d5d416cc)


@@ -76,13 +76,46 @@ def log_rank0(msg: str, enable: bool = False):
         print(msg)
 
 
 @functools.lru_cache
 def get_no_copy_ops():
     # Need to compile custom ops
     get_deepcompile_handle()
-    return {
-        torch.ops.aten.t.default, torch.ops.aten.view.default, torch.ops.aten.detach.default,
-        torch.ops.aten.permute.default, torch.ops.dc.wait_allgather.default
-    }
+
+    no_copy_ops = {torch.ops.dc.wait_allgather.default}
+
+    # All operations whose return value aliases any of their inputs are included
+    # in the returned list to ensure that the last user of a node is computed
+    # correctly.
+    #
+    # This can be overly conservative if not all input tensors are aliased in
+    # the output. While we can determine exactly which tensors are aliased, a
+    # finer-grained algorithm is required in get_last_uses() and get_real_uses()
+    # to utilize that information. This is left as future work when real needs
+    # arise.
+    warned = False
+    for op_name in torch.ops.aten:
+        packet = getattr(torch.ops.aten, op_name)
+        for overload_name in packet:
+            op = getattr(packet, overload_name)
+            try:
+                for return_info in op._schema.returns:
+                    if isinstance(return_info.type, torch.TensorType) and return_info.alias_info is not None:
+                        no_copy_ops.add(op)
+                        break
+            except AttributeError:
+                # In case no schema is available, conservatively assume the op
+                # may reuse tensor storage and print a one-time warning on its
+                # potential performance impact.
+                if not warned:
+                    log_rank0(
+                        f"WARNING: Schema is missing for some torch.ops.aten ops (e.g. {op_name}.{overload_name})."
+                        "We assume those ops may reuse tensor storage. This may impact performance of compiled models.",
+                        enable=True,
+                    )
+                    warned = True
+                no_copy_ops.add(op)
+
+    return no_copy_ops
 
 
 def get_input_nodes(graph: Graph) -> List[Node]:
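
Usage sketch (not part of this change; `gm` is a hypothetical
torch.fx.GraphModule): the returned set contains OpOverload objects, so FX
call_function targets can be tested for membership directly when computing
last uses.

    no_copy_ops = get_no_copy_ops()
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in no_copy_ops:
            # The output of this node aliases at least one of its inputs.
            print(f"{node.name}: no-copy op {node.target}")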