deepcompile: Create a full list of no-copy ops (#7562)
The list of torch no-copy ops is hard-coded and does not include all operations that may alias their inputs in their outputs. Instead of using a fixed list, iterate over all ops under torch.ops.aten and identify those with aliasing behavior by inspecting their schemas. With PyTorch 2.7.1, the default overloads of the ops identified by the updated logic include:

- _nested_view_from_buffer
- _reshape_alias
- alias
- as_strided
- conj
- detach
- diagonal
- expand
- imag
- lift_fresh
- narrow
- permute
- pin_memory
- positive
- real
- reshape
- squeeze
- t
- unfold
- unsqueeze
- view
- view_as_complex
- view_as_real
- most operations whose name ends with an underscore

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
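For reference, a minimal sketch of the schema check described above (assuming PyTorch 2.x; the two ops are illustrative, not an exhaustive test):

    import torch

    # t() is a view: its schema "t(Tensor(a) self) -> Tensor(a)" annotates the
    # return with alias set (a), so alias_info is not None.
    ret = torch.ops.aten.t.default._schema.returns[0]
    print(isinstance(ret.type, torch.TensorType), ret.alias_info)  # True, not None

    # add() materializes a new tensor: its return carries no alias annotation.
    ret = torch.ops.aten.add.Tensor._schema.returns[0]
    print(isinstance(ret.type, torch.TensorType), ret.alias_info)  # True, None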
@@ -76,13 +76,46 @@ def log_rank0(msg: str, enable: bool = False):
         print(msg)
 
 
+@functools.lru_cache
 def get_no_copy_ops():
     # Need to compile custom ops
     get_deepcompile_handle()
-    return {
-        torch.ops.aten.t.default, torch.ops.aten.view.default, torch.ops.aten.detach.default,
-        torch.ops.aten.permute.default, torch.ops.dc.wait_allgather.default
-    }
+    no_copy_ops = {torch.ops.dc.wait_allgather.default}
+
+    # All operations whose return value aliases any of their inputs are included
+    # in the returned list to ensure that the last user of a node is computed
+    # correctly.
+    #
+    # This can be overly conservative if not all input tensors are aliased in
+    # the output. While we can determine exactly which tensors are aliased, a
+    # finer-grained algorithm is required in get_last_uses() and get_real_uses()
+    # to utilize that information. This is left as future work when real needs
+    # arise.
+
+    warned = False
+    for op_name in torch.ops.aten:
+        packet = getattr(torch.ops.aten, op_name)
+        for overload_name in packet:
+            op = getattr(packet, overload_name)
+            try:
+                for return_info in op._schema.returns:
+                    if isinstance(return_info.type, torch.TensorType) and return_info.alias_info is not None:
+                        no_copy_ops.add(op)
+                        break
+            except AttributeError:
+                # In case no schema is available, conservatively assume the op
+                # may reuse tensor storage and print a one-time warning on its
+                # potential performance impact.
+                if not warned:
+                    log_rank0(
+                        f"WARNING: Schema is missing for some torch.ops.aten ops (e.g. {op_name}.{overload_name})."
+                        "We assume those ops may reuse tensor storage. This may impact performance of compiled models.",
+                        enable=True,
+                    )
+                    warned = True
+                no_copy_ops.add(op)
+
+    return no_copy_ops
+
+
 def get_input_nodes(graph: Graph) -> List[Node]:
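As a usage sketch (not part of the commit): once the set of aliasing ops is known, a pass over an aten-level FX graph can tell which call_function nodes merely alias their inputs. make_fx and the aliases_input helper below are illustrative; the actual consumers in DeepCompile are the get_last_uses() and get_real_uses() passes mentioned in the comment above.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def aliases_input(op) -> bool:
        # Same schema test as the commit: any tensor return with alias_info set.
        try:
            return any(
                isinstance(r.type, torch.TensorType) and r.alias_info is not None
                for r in op._schema.returns)
        except AttributeError:
            return True  # no schema: conservatively assume the op may alias

    def f(x):
        y = x.t()     # lowers to torch.ops.aten.t.default (aliasing view)
        return y + 1  # lowers to torch.ops.aten.add.Tensor (fresh tensor)

    gm = make_fx(f)(torch.randn(3, 3))
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print(node.target, aliases_input(node.target))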