Optimize Net._get_next_net_name (#107479)

Summary: This is surprisingly expensive and can be easily optimized.

Differential Revision: D48440000

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107479
Approved by: https://github.com/kit1980
This commit is contained in:
Jeffrey Dunn
2023-08-22 19:15:11 +00:00
committed by PyTorch MergeBot
parent 24147a8e1c
commit 1e9b590df9

View File

@ -8,6 +8,7 @@
from collections import namedtuple, OrderedDict, defaultdict
from past.builtins import basestring
from itertools import chain
from typing import Dict
from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils, workspace
@ -1445,7 +1446,7 @@ def _recover_record_by_prefix(names, prefix=''):
class Net:
_net_names_used = set()
_net_names_used_counters: Dict[str, int] = {}
operator_registry_ = {}
@staticmethod
@ -1454,17 +1455,16 @@ class Net:
builder = NetBuilder.current(required=False)
return builder.name if builder else ''
@staticmethod
def _reset_used_names() -> None:
Net._net_names_used_counters = {}
@staticmethod
def _get_next_net_name(basename):
name = basename = '/'.join(
x for x in [Net.current_prefix(), basename] if x
)
next_idx = 1
while name in Net._net_names_used:
name = basename + '_' + str(next_idx)
next_idx += 1
Net._net_names_used |= set([name])
return name
basename = "/".join(x for x in [Net.current_prefix(), basename] if x)
next_idx = Net._net_names_used_counters.get(basename, 0)
Net._net_names_used_counters[basename] = next_idx + 1
return basename if next_idx == 0 else f"{basename}_{next_idx}"
def __init__(self, name_or_proto, inplace=False):
"""