mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Optimize Net._get_next_net_name (#107479)
Summary: This is surprisingly expensive and can be easily optimized. Differential Revision: D48440000 Pull Request resolved: https://github.com/pytorch/pytorch/pull/107479 Approved by: https://github.com/kit1980
This commit is contained in:
committed by
PyTorch MergeBot
parent
24147a8e1c
commit
1e9b590df9
@ -8,6 +8,7 @@
|
||||
from collections import namedtuple, OrderedDict, defaultdict
|
||||
from past.builtins import basestring
|
||||
from itertools import chain
|
||||
from typing import Dict
|
||||
|
||||
from caffe2.proto import caffe2_pb2
|
||||
from caffe2.python import scope, utils, workspace
|
||||
@ -1445,7 +1446,7 @@ def _recover_record_by_prefix(names, prefix=''):
|
||||
|
||||
|
||||
class Net:
|
||||
_net_names_used = set()
|
||||
_net_names_used_counters: Dict[str, int] = {}
|
||||
operator_registry_ = {}
|
||||
|
||||
@staticmethod
|
||||
@ -1454,17 +1455,16 @@ class Net:
|
||||
builder = NetBuilder.current(required=False)
|
||||
return builder.name if builder else ''
|
||||
|
||||
@staticmethod
|
||||
def _reset_used_names() -> None:
|
||||
Net._net_names_used_counters = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_next_net_name(basename):
|
||||
name = basename = '/'.join(
|
||||
x for x in [Net.current_prefix(), basename] if x
|
||||
)
|
||||
next_idx = 1
|
||||
while name in Net._net_names_used:
|
||||
name = basename + '_' + str(next_idx)
|
||||
next_idx += 1
|
||||
Net._net_names_used |= set([name])
|
||||
return name
|
||||
basename = "/".join(x for x in [Net.current_prefix(), basename] if x)
|
||||
next_idx = Net._net_names_used_counters.get(basename, 0)
|
||||
Net._net_names_used_counters[basename] = next_idx + 1
|
||||
return basename if next_idx == 0 else f"{basename}_{next_idx}"
|
||||
|
||||
def __init__(self, name_or_proto, inplace=False):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user