mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 14:15:01 +08:00
Summary: Closes https://github.com/caffe2/caffe2/pull/226 Differential Revision: D4793550 Pulled By: JoelMarcey fbshipit-source-id: cc33e58186304fa8dcac2ee9115dcc271d785b1e
440 lines
15 KiB
Python
440 lines
15 KiB
Python
## @package net_builder
|
|
# Module caffe2.python.net_builder
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import core, context
|
|
from caffe2.python.task import Task, TaskGroup
|
|
|
|
|
|
@context.define_context()
class NetBuilder(object):
    """
    Scope-driven mechanism for building nets, loops and conditional blocks.

    While a NetBuilder is the active context, operators issued through
    `ops` are appended to its current net. Its children (nets, nested
    builders, task groups, ...) are kept in order and returned by `get()`.
    The `context.define_context()` decorator presumably supplies
    `__enter__`, `current()` and the context push/pop — not shown here.

    Example:
        from caffe2.python.net_builder import NetBuilder, ops
        with NetBuilder() as nb:
            c = ops.Const(5)
            d = ops.Const(0)
            with ops.loop():
                ops.stop_if(ops.LE([c, ops.Const(0)]))
                ops.Add([c, ops.Const(-1)], [c])
                with ops.If(ops.GE([c, ops.Const(3)])):
                    ops.Add([d, ops.Const(10)], [d])
            ops.Print(c, [])
            ops.Print(d, [])
        step = core.to_execution_step(nb)
    """

    def __init__(self, name=None, _stop_blob_required=False,
                 _stop_blob=None, _fullname=None):
        """
        Args:
            name: short name, joined with '/' onto the enclosing builder's
                name (if any) to form `self.name`.
            _stop_blob_required: if True, leaving the context cleanly
                without a stop condition fails an assertion.
            _stop_blob: optional pre-existing stop blob to reuse.
            _fullname: complete name used verbatim; mutually exclusive
                with `name`.
        """
        parent = NetBuilder.current(required=False)
        assert not _fullname or not name, 'Cannot set both _fullname and name'
        if _fullname:
            self.name = _fullname
        else:
            parent_name = parent.name if parent else None
            # filter(None, ...) drops empty/None parts before joining.
            self.name = '/'.join(filter(None, (parent_name, name)))
        self._children = []
        self._current_net = None
        self._frozen = False
        self._stop_blob_required = _stop_blob_required
        self._stop_blob = _stop_blob

    def stop_blob(self):
        """
        Return the BlobReference of this builder's stop blob, creating it
        on first use.

        The new blob is assumed to be written immediately in the current
        net. It is therefore explicitly initialized to False only when the
        current net is not the builder's first child. In that case the
        initialization goes into a 'stop_blob_init' net that is prepended
        to the children.
        """
        if self._stop_blob is not None:
            return self._stop_blob
        net = self.current_net()
        self._stop_blob = core.BlobReference(
            net.NextName('stop_blob'), net=net)
        if self._current_net != self._children[0]:
            init_net = core.Net('stop_blob_init')
            self._children.insert(0, init_net)
            init_net.Const(False, blob_out=self._stop_blob)
        return self._stop_blob

    def stop_if(self, blob):
        """
        Copy `blob` into this builder's stop blob.

        The current net is then reset, so that subsequent operators start
        a new net after the stop check.
        """
        ops.Copy(blob, self.stop_blob())
        self._current_net = None

    def _assert_mutable(self):
        """Fail if this builder has already been frozen."""
        assert not self._frozen, (
            'This NetBuilder (%s) has been built already.' % self.name)

    def add(self, child):
        """
        Append `child` to this builder and return it.

        A core.Net child becomes the new current net. Any other child
        clears the current net, so the next operator starts a fresh one.
        """
        self._assert_mutable()
        self._children.append(child)
        # TODO: check that a core.Net child is not a DAG net.
        self._current_net = child if isinstance(child, core.Net) else None
        return child

    def current_net(self, name=None):
        """
        Return the current net.

        A new core.Net is added first when there is no current net or when
        an explicit `name` is requested.
        """
        self._assert_mutable()
        needs_new_net = self._current_net is None or name is not None
        if needs_new_net:
            self.add(core.Net(name))
        return self._current_net

    def freeze(self):
        """Recursively freeze freezable children and mark this builder done."""
        freezable = (c for c in self._children if hasattr(c, 'freeze'))
        for child in freezable:
            child.freeze()
        self._current_net = None
        self._frozen = True

    def get(self):
        """Freeze the builder and return its ordered list of children."""
        self.freeze()
        return self._children

    def __exit__(self, etype, *args):
        """Freeze the builder; on a clean exit, check the stop requirement."""
        self.freeze()
        if etype is not None:
            # Returning None lets the in-flight exception propagate.
            return
        has_stop_condition = self._stop_blob is not None
        assert has_stop_condition or not self._stop_blob_required, (
            'This NetBuilder (%s) requires a stop condition ' % self.name +
            'to be set with `stop` or `stop_if`')

    def __str__(self):
        return self.name or 'Un-named NetBuilder'
|
|
|
|
|
|
class Operations(object):
    """
    Operations to be used in the context of a NetBuilder.

    Any attribute that is not defined on this class is forwarded to the
    current net of the active NetBuilder, so `ops.Const(1)` adds a Const
    operator to that net. The methods below add control-flow structures
    (loops, conditionals, stop conditions) and task setup/teardown hooks.
    """

    def net(self, net=None, name=None):
        """
        Retrieve the current net, or add a new net to the active builder.

        Args:
            net: If provided, add the given net to the active builder and
                 return it. Otherwise, return the current Net, creating a
                 new one as needed.
            name: If provided, create a new Net with the given name and
                  make it the new current net of the active builder.
                  Cannot be provided together with `net`.
        """
        assert name is None or net is None, (
            'Cannot provide both `net` and `name`.')
        if net is not None:
            NetBuilder.current().add(net)
            return net
        return NetBuilder.current().current_net(name=name)

    def __getattr__(self, op_type):
        """
        Add an operator call to the currently active Net.

        Returns `getattr(current_net, op_type)`, i.e. the current net's
        attribute for that operator type.

        Raises:
            AttributeError: for dunder names (presumably so that
                special-method probing never creates operators), and when
                no NetBuilder is active, so that `hasattr(ops, ...)` returns
                False.
        """
        if op_type.startswith('__'):
            # Name the missing attribute so the error is diagnosable.
            raise AttributeError(op_type)
        # We want hasattr to work properly even if no context is active.
        if NetBuilder.current(required=False) is None:
            raise AttributeError('No active NetBuilder.')
        return getattr(self.net(), op_type)

    def task_group(self):
        """
        Create a local task group which will execute as the next step of
        the current NetBuilder.

        The TaskGroup is created inside a fresh `task.Cluster()` and a
        `task.Node('local')` context. It is then added as a child of the
        active builder.
        """
        from caffe2.python import task
        builder = NetBuilder.current()
        with task.Cluster():
            with task.Node('local'):
                tg = task.TaskGroup()
                builder.add(tg)
                return tg

    def stop(self):
        """
        Stop execution of the current execution step.

        Example:
            ops.Print(a, 0)
            ops.stop()
            ops.Print(b, 0)

        In the example, 'b' will never be printed.
        """
        return self.stop_if(self.Const(True))

    def stop_if(self, blob):
        """
        Stop execution of the current execution step if the
        condition `blob` is met.

        Example:
            ops.Print(a, 0)
            ops.stop_if(ops.LE([x, ops.Const(0)]))
            ops.Print(b, 0)

        In the example, 'b' will only be printed if the value of scalar
        tensor 'x' is greater than 0.
        """
        return NetBuilder.current().stop_if(blob)

    def loop(self, iters=None, name=None):
        """
        Create a NetBuilder that will execute in a loop as the next step of
        the current NetBuilder.

        If `iters` is provided, the loop will execute for `iters`
        iterations and then stop. `iters` can be a constant or a
        BlobReference. If `iters` is not provided, the loop will execute
        until `ops.stop` or `ops.stop_if` is called.

        Examples:
            a = ops.Const(5)
            with ops.loop():
                ops.stop_if(ops.LE([a, ops.Const(0)]))
                ops.Print(a, 0)
                ops.Add([a, ops.Const(-1)], [a])

        Above, 'a' will be printed 5 times, with values 5 to 1.

            with ops.loop(10) as loop:
                ops.LogInfo(loop.iter())

        This will print the numbers from 0 to 9.

            x = ops.Add([ops.Const(10), ops.Const(10)])
            with ops.loop(x) as loop:
                ops.LogInfo(loop.iter())

        This will print the numbers from 0 to 19.
        """
        return NetBuilder.current().add(_Loop(iters, name=name))

    def stop_guard(self, has_stopped_blob=None, name=None):
        """
        Create a NetBuilder that will execute once as the next step of the
        current NetBuilder.

        After execution, a bool tensor will indicate whether the inner
        execution was halted with `stop` or `stop_if`.

        Example:
            a = ops.Const(True)
            with ops.stop_guard() as sg1:
                ops.stop_if(a)
                ops.Print(ops.Const('did not stop'))
            b = ops.Const(False)
            with ops.stop_guard() as sg2:
                ops.stop_if(b)
                ops.Print(ops.Const('did not stop'))
            ops.Print(sg1.has_stopped(), [])
            ops.Print(sg2.has_stopped(), [])

        In the example, 'did not stop' will be printed once,
        followed by True and False.
        """
        return NetBuilder.current().add(
            _StopGuard(has_stopped_blob=has_stopped_blob, name=name))

    def If(self, cond, name=None):
        """
        Create a NetBuilder that will execute once as the next step of the
        current NetBuilder if the blob `cond` is True.

        Example:
            with ops.If(ops.Const(True)):
                ops.Print(ops.Const('Will print'))
            with ops.If(ops.Const(False)):
                ops.Print(ops.Const('Wont print'))

        The example will print 'Will print' once.
        """
        return NetBuilder.current().add(_RunIf(cond, name=name))

    def _add_setup(self, setup_type, attribute):
        # Shared by task_init/task_exit/local_init/local_exit: create a
        # _SetupBuilder of `setup_type` and attach it to the current net
        # under `attribute`.
        setup = _SetupBuilder(setup_type)
        self.net().add_attribute(attribute, setup)
        return setup

    def task_init(self):
        """
        Define operations that will be executed once at task startup.

        Useful when implementing processors that don't have access to the
        Task top-level structure.

        Example:
            def my_processor(rec):
                with ops.task_init():
                    one = ops.Const(1)
                    two = ops.Const(2)
                return Tuple(
                    ops.Add(rec[0](), one), ops.Add(rec[1](), two))
        """
        return self._add_setup(_SetupBuilder.INIT, Task.TASK_SETUP)

    def task_exit(self):
        """
        Define operations to be executed at task shutdown.

        Useful when implementing processors that don't have access to the
        Task top-level structure.

        Example:
            def read_queue(queue):
                with ops.task_exit():
                    queue.close(ops.net())
                return queue.read(ops.net())
        """
        return self._add_setup(_SetupBuilder.EXIT, Task.TASK_SETUP)

    def local_init(self):
        """
        Similar to `task_init`, but executes at TaskGroup's startup instead,
        before any task of the group starts executing.
        """
        return self._add_setup(_SetupBuilder.INIT, TaskGroup.LOCAL_SETUP)

    def local_exit(self):
        """
        Similar to `task_exit`, but executes at TaskGroup's exit instead,
        after all tasks of the group have finished execution.
        """
        return self._add_setup(_SetupBuilder.EXIT, TaskGroup.LOCAL_SETUP)

    def task_reporter(self, interval_ms=1000, name=None):
        """
        Define operations to be executed at every time interval from
        task start-up to finish.

        These operations are guaranteed to execute at least once after all
        other operations of the task are finished.

        Example:
            with ops.task_reporter(interval_ms=10000):
                ops.LogInfo('10s elapsed')
        """
        return _ReporterBuilder(interval_ms, net=self.net(), name=name)

    def local_reporter(self, interval_ms=1000, name=None):
        """
        Similar to `task_reporter`, but operations defined within this block
        will run repeatedly for as long as any of the tasks in the current
        TaskGroup have not finished.
        """
        return _ReporterBuilder(interval_ms, name=name)
|
|
|
|
|
|
# Module-level singleton through which operators and control-flow blocks are
# added to the active NetBuilder, e.g. `ops.Const(1)` or
# `with ops.loop(): ...`.
ops = Operations()
|
|
|
|
|
|
class _ReporterBuilder(NetBuilder):
    """
    NetBuilder whose contents become a periodically executed report step.

    On a clean exit, the collected children are converted into an
    execution step that is set to run every `interval_ms` milliseconds.
    If `net` was given, the step is attached to it as its
    `Task.REPORT_STEP` attribute. Otherwise it is registered on the
    currently active TaskGroup via `report_step`.
    """

    def __init__(self, interval_ms, net=None, name=None):
        NetBuilder.__init__(self, name)
        self._net = net
        self.interval_ms = interval_ms

    def __exit__(self, etype, *args):
        if etype is None:
            step = core.to_execution_step(self)
            step.RunEveryMillis(self.interval_ms)
            # Explicit None check: `net` defaults to None, and an attached
            # net must not be skipped because of its boolean value.
            if self._net is not None:
                self._net.add_attribute(Task.REPORT_STEP, step)
            else:
                TaskGroup.current().report_step(
                    step, interval_ms=self.interval_ms)
        NetBuilder.__exit__(self, etype, *args)
|
|
|
|
|
|
class _SetupBuilder(NetBuilder):
    """
    NetBuilder holding operations for a single setup phase.

    `type` selects the phase: `_SetupBuilder.INIT` or `_SetupBuilder.EXIT`.
    `setup(net)` yields an execution step only for INIT builders, and
    `exit(net)` only for EXIT builders. The non-matching hook returns None.
    Presumably these hooks are called by the Task/TaskGroup code that reads
    the attribute this builder is registered under (not shown here).
    """

    INIT = 'init'
    EXIT = 'exit'

    def __init__(self, type, name=None):
        NetBuilder.__init__(self, name)
        self.type = type

    def _step_for_phase(self, phase):
        # Only the builder registered for `phase` contributes a step.
        if self.type != phase:
            return None
        return core.to_execution_step(self)

    def setup(self, net):
        # `net` is not used here; kept for the hook signature.
        return self._step_for_phase(_SetupBuilder.INIT)

    def exit(self, net):
        # `net` is not used here; kept for the hook signature.
        return self._step_for_phase(_SetupBuilder.EXIT)
|
|
|
|
|
|
class _RunOnce(NetBuilder):
    """
    NetBuilder whose body is meant to execute a single time.

    If a stop blob was created inside the body (through `stop`/`stop_if`),
    an unconditional `ops.stop()` is appended on clean exit. Presumably
    this halts the step after one pass instead of letting it repeat.
    """

    def __init__(self, name=None):
        NetBuilder.__init__(self, name)

    def __exit__(self, etype, *args):
        exited_cleanly = etype is None
        uses_stop_blob = self._stop_blob is not None
        if exited_cleanly and uses_stop_blob:
            ops.stop()
        NetBuilder.__exit__(self, etype, *args)
|
|
|
|
|
|
class _StopGuard(_RunOnce):
    """
    Run-once builder that records whether its body was halted early.

    On entry, the stopped flag is set to True. On clean exit, a final
    operation resets it to False. If the body is halted by `stop`/`stop_if`
    before that operation, the flag keeps its True value.
    """

    def __init__(self, has_stopped_blob=None, name=None):
        _RunOnce.__init__(self, name)
        self._ran = False
        # Optional caller-supplied blob to write the stopped flag into.
        self._stopped = has_stopped_blob

    def __enter__(self):
        entered = _RunOnce.__enter__(self)
        # Pessimistic default: assume the body will be halted.
        self._stopped = ops.Const(True, blob_out=self._stopped)
        return entered

    def __exit__(self, etype, *args):
        exited_cleanly = etype is None
        if exited_cleanly:
            self._ran = True
            # Emitted last in the body; reached only if it was not halted.
            ops.Const(False, blob_out=self._stopped)
        _RunOnce.__exit__(self, etype, *args)

    def has_stopped(self):
        """
        Return the blob that will hold scalar bool `True` after this
        builder runs, iff it was halted early.
        """
        assert self._ran, 'Context not used yet.'
        return self._stopped
|
|
|
|
|
|
class _Loop(NetBuilder):
    """
    NetBuilder whose body is executed repeatedly until it stops.

    With `iters`, an iteration counter (starting at 0) and its limit are
    created at construction time. The body then starts with a stop check
    (counter >= limit) and ends by incrementing the counter. Without
    `iters`, the body must call `ops.stop`/`ops.stop_if`; this is enforced
    on exit because `_stop_blob_required` is set.
    """

    def __init__(self, iters=None, name=None):
        NetBuilder.__init__(self, name, _stop_blob_required=True)
        # Always define the counter attributes, even for count-less loops.
        self._inc = None
        self._iter = None
        self._num_iters = None
        if iters is not None:
            # When built through `ops.loop`, this loop is not active yet, so
            # these Consts land in the enclosing builder's current net.
            self._inc = ops.Const(1)
            self._iter = ops.Const(0)
            self._num_iters = (
                iters if isinstance(iters, core.BlobReference)
                else ops.Const(iters))

    def iter(self):
        """
        Return the blob holding the 0-based iteration counter.

        Only available for loops created with `iters`.
        """
        assert self._num_iters is not None, (
            'This loop does not have a number of iterations.')
        return self._iter

    def __enter__(self):
        builder = NetBuilder.__enter__(self)
        if self._num_iters is not None:
            # First thing in the body: stop once the counter hits the limit.
            ops.stop_if(ops.GE([self._iter, self._num_iters]))
        return builder

    def __exit__(self, etype, *args):
        if etype is None and self._num_iters is not None:
            # Last thing in the body: advance the counter in place.
            self.current_net().Add([self._iter, self._inc], [self._iter])
        NetBuilder.__exit__(self, etype, *args)
|
|
|
|
|
|
class _RunIf(_RunOnce):
    """
    Run-once builder executed only if its condition holds (If/Elif/Else).

    All branches of a chain share one `_already_ran` flag blob. On entry,
    each branch stops when its "else" condition holds: an earlier branch
    already ran, or (for If/Elif) its own condition is False. Otherwise it
    sets the flag before its body executes.
    """

    def __init__(self, cond_blob=None, name=None, _already_ran=None):
        _RunOnce.__init__(self, name)
        assert cond_blob is not None or _already_ran is not None, (
            'Either cond_blob or _already_ran must be provided.')
        # A branch without its own condition is an Else.
        self._is_else = cond_blob is None
        if _already_ran is None:
            # First branch of a chain: skip when the condition is False.
            self._else_blob = ops.Not(cond_blob)
            self._already_ran = ops.Const(False)
        else:
            self._already_ran = _already_ran
            # Later branch: skip when an earlier branch ran or, for an
            # Elif, when its own condition is False.
            self._else_blob = _already_ran if cond_blob is None else (
                ops.Or([_already_ran, ops.Not(cond_blob)]))

    def __enter__(self):
        r = _RunOnce.__enter__(self)
        ops.stop_if(self._else_blob)
        ops.Const(True, blob_out=self._already_ran)
        return r

    # NOTE(review): `self.name` already includes the parent's name (see
    # NetBuilder.__init__), and passing it as `name` makes NetBuilder.__init__
    # prefix the current builder's name again. A named parent may therefore
    # appear twice in the Elif/Else name. Confirm whether this is intended.
    def Elif(self, cond, name=None):
        """Add an Elif branch sharing this chain's already-ran flag."""
        assert not self._is_else, 'Elif not allowed for an Else.'
        return NetBuilder.current().add(_RunIf(
            cond, name=name or self.name, _already_ran=self._already_ran))

    def Else(self, name=None):
        """Add an Else branch sharing this chain's already-ran flag."""
        assert not self._is_else, 'Else not allowed for an Else.'
        return NetBuilder.current().add(
            _RunIf(name=name or self.name, _already_ran=self._already_ran))
|