diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index 8559926bf..9ac6ba904 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -76,13 +76,46 @@ def log_rank0(msg: str, enable: bool = False): print(msg) +@functools.lru_cache def get_no_copy_ops(): # Need to compile custom ops get_deepcompile_handle() - return { - torch.ops.aten.t.default, torch.ops.aten.view.default, torch.ops.aten.detach.default, - torch.ops.aten.permute.default, torch.ops.dc.wait_allgather.default - } + + no_copy_ops = {torch.ops.dc.wait_allgather.default} + + # All operations whose return value aliases any of their inputs are included + # in the returned list to ensure that the last user of a node is computed + # correctly. + # + # This can be overly conservative if not all input tensors are aliased in + # the output. While we can determine exactly which tensors are aliased, a + # finer-grained algorithm is required in get_last_uses() and get_real_uses() + # to utilize that information. This is left as future work when real needs + # arise. + warned = False + for op_name in torch.ops.aten: + packet = getattr(torch.ops.aten, op_name) + for overload_name in packet: + op = getattr(packet, overload_name) + try: + for return_info in op._schema.returns: + if isinstance(return_info.type, torch.TensorType) and return_info.alias_info is not None: + no_copy_ops.add(op) + break + except AttributeError: + # In case no schema is available, conservatively assume the op + # may reuse tensor storage and print a one-time warning on its + # potential performance impact. + if not warned: + log_rank0( + f"WARNING: Schema is missing for some torch.ops.aten ops (e.g. {op_name}.{overload_name})." + "We assume those ops may reuse tensor storage. This may impact performance of compiled models.", + enable=True, + ) + warned = True + no_copy_ops.add(op) + + return no_copy_ops def get_input_nodes(graph: Graph) -> List[Node]: