Source code for barrista.solver

# -*- coding: utf-8 -*-
"""Exposes the caffe solvers."""
# pylint: disable=E1101, F0401, C0103, R0913, R0914, W0212, E1121, E0611, W0406
# pylint: disable=duplicate-code, too-many-lines
from __future__ import print_function

from . import monitoring as _monitoring
from . import parallel as _parallel
# CAREFUL! This must be imported before any caffe-related import!
from .tools import pbufToPyEnum as _pbufToPyEnum

import time as _time
import logging as _logging
import hashlib
import copy
from tempfile import NamedTemporaryFile as _NamedTemporaryFile
import numpy as _np
import google.protobuf.text_format as _gprototext

import caffe as _caffe
import caffe.proto.caffe_pb2 as _caffe_pb2

#: Describes the type of the solver used. All solver types supported by caffe
#: are available.
SolverType = _pbufToPyEnum(_caffe_pb2.SolverParameter.SolverType)

#: Describes the Phase used. All phases supported by caffe
#: are available.
_Phase = _pbufToPyEnum(_caffe_pb2.Phase)

_HAS_ITER_SIZE = hasattr(_caffe_pb2.SolverParameter, 'iter_size')
try:
    _ADAM_SOLVER_CLASS = _caffe.AdamSolver
    _ADAM_SOLVER_ENUM = SolverType.ADAM
except AttributeError:  # pragma: no cover
    _ADAM_SOLVER_CLASS = None
    _ADAM_SOLVER_ENUM = None
try:
    _ADADELTA_SOLVER_CLASS = _caffe.AdaDeltaSolver
    _ADADELTA_SOLVER_ENUM = SolverType.ADADELTA
except AttributeError:  # pragma: no cover
    _ADADELTA_SOLVER_CLASS = None
    _ADADELTA_SOLVER_ENUM = None
try:
    _ADAGRAD_SOLVER_CLASS = _caffe.AdaGradSolver
    _ADAGRAD_SOLVER_ENUM = SolverType.ADAGRAD
except AttributeError:  # pragma: no cover
    _ADAGRAD_SOLVER_CLASS = None
    _ADAGRAD_SOLVER_ENUM = None
try:
    _RMSPROP_SOLVER_CLASS = _caffe.RMSPropSolver
    _RMSPROP_SOLVER_ENUM = SolverType.RMSPROP
except AttributeError:  # pragma: no cover
    _RMSPROP_SOLVER_CLASS = None
    _RMSPROP_SOLVER_ENUM = None

_LOGGER = _logging.getLogger(__name__)


# pylint: disable=too-many-instance-attributes
class Solver(object):

    """Describes the Solver concept."""

    _solver_types = {}
    _caffe_solver_type = None
    _solver_type = None

    def __init__(self, **kwargs):
        r"""
        Constructor.

        :param iter_size: int>0.
          The number of batches the gradient is accumulated over (not
          available in older caffe versions).

        :param lr_policy: string in ['fixed', 'step', ...].
          The policy to use to adjust the learning rate during fitting.
          Taken from ``solver.cpp``:

          * fixed: always return base_lr.
          * step: return base_lr \* gamma ^ (floor(iter / stepsize))
          * exp: return base_lr \* gamma ^ iter
          * inv: return base_lr \* (1 + gamma \* iter) ^ (- power)
          * multistep: similar to step, but allows non-uniform steps
            defined by stepvalue
          * poly: the effective learning rate follows a polynomial decay,
            to be zero by max_iter.
            return base_lr (1 - iter/max_iter) ^ (power)
          * sigmoid: the effective learning rate follows a sigmoid decay.
            return base_lr ( 1/(1 + exp(-gamma \* (iter - stepsize))))

        :param base_lr: float or None.
          The base learning rate to use.

        :param gamma: float or None.

        :param power: float or None.

        :param weight_decay: float or None.
          Use weight decay to reduce the weights at each step.

        :param regularization_type: string in ['L1', 'L2'].
          Specifies how the ``weight_decay`` is applied.

        :param stepsize: float or None.
          The stepsize for the step policy.

        :param stepvalue: list(int) or None.
          The stepvalue parameter for the multistep policy.

        :param clip_gradients: float or None.
          Clips the gradients to the specified value.

        :param random_seed: int>0 or None.
          If specified, seeds the solver for reproducible results.
          Otherwise, a time dependent seed is used.

        :param snapshot_prefix: string or None.
          If the ``Checkpointer`` monitor is used, this prefix is used to
          create the snapshots.

        :param debug_info: bool.
          If set to ``True``, gives additional output in the logs.
        """
        self._net = None
        self._parameter_hash = None
        self._parameter_dict = dict()
        self.update_parameters(**kwargs)
        # some default internal parameters
        self._parameter_dict['snapshot_after_train'] = False
        self._parameter_dict['solver_type'] = self._caffe_solver_type
        # every solver can append its own assertions or overwrite the
        # given ones
        self._asserts = []
        if _HAS_ITER_SIZE:
            self._asserts.append(self.Assert_iter_size)
        self._asserts.append(self.Assert_regularization_types)
        self._asserts.append(self.Assert_policy)
        self._solver = None
        self._print_warning = False
        self._train_net_dummy = None
        self._test_net_dummy = None
        self._parallel_train_filler = None
        self._parallel_test_filler = None
        self._parallel_batch_res_train = None
        self._parallel_batch_res_test = None

    def restore(self, filename, net=None):
        """Restore the solverstate from a file."""
        if self._net is None:
            assert net is not None, ('you must specify a net on which the '
                                     'restored solver will be used!')
        if net is not None:
            # The method self._Update_net must not be used here, since it
            # is allowed to use a new net.
            self._net = net
        self._Update_solver()
        self._solver.restore(filename)

    @classmethod
    def Get_required_arguments(cls):
        """Get the names of the required parameters."""
        return ['base_lr']

    @classmethod
    def Get_optional_arguments(cls):
        """
        Get the optional parameters.

        Some of them default to None; not all combinations are possible.
        This is enforced by various asserts when calling
        Get_parameter_dict().
        """
        ret_dict = {'debug_info': False,
                    'weight_decay': None,
                    'lr_policy': 'fixed',
                    'regularization_type': 'L2',
                    'power': None,
                    'gamma': None,
                    'snapshot_prefix': None,
                    'stepsize': None,
                    'stepvalue': None,
                    'clip_gradients': None,
                    'random_seed': None,
                    'net': None}
        if _HAS_ITER_SIZE:
            ret_dict['iter_size'] = 1
        return ret_dict

    def fit(self,  # pylint: disable=too-many-statements, too-many-branches
            iterations,
            X=None,
            X_val=None,
            input_processing_flags=None,
            test_iterations=0,
            test_interval=0,
            test_initialization=False,
            train_callbacks=None,
            test_callbacks=None,
            net=None,
            read_input_batch_size_from_blob_name=None,
            use_fit_phase_for_validation=False,
            allow_test_phase_for_train=False,
            shuffle=False):
        r"""
        Fit the network to specific data.

        Use monitors from the module :py:mod:`barrista.monitoring` as
        callbacks to monitor the state of the net and create checkpoints.
        This method offers the following kwargs to monitors (\* indicates
        that the value is only available at test time, \- indicates that
        the value is not necessarily available):

        * max_iter,
        * iter,
        * batch_size,
        * net,
        * testnet\[only if there is a test phase, i.e., X_val is set]
        * solver,
        * callback_signal\[is automatically set by the fit function],
        * X\-[only if provided by the user],
        * X_val\-[only if provided by the user],
        * [the following fields are only set if the corresponding
          loss/accuracy layer exists for the train and/or test phase.
          They can also be set by providing a custom ResultExtractor]
        * loss\-,
        * test_loss\*,
        * accuracy\-,
        * test_accuracy\*-,

        :param iterations: int.
          The number of training iterations to do. This is the plain
          number of iterations, completely disregarding the batch size,
          i.e., for ``iterations`` being 10 and ``batch_size`` being 10,
          just one batch is forward propagated.

        :param X: dict of numpy.ndarray or None.
          If specified, is used as input data. It is used sequentially, so
          shuffle it beforehand, if required. The keys of the dict must
          have a corresponding layer name in the net.

        :param X_val: dict of numpy.ndarray or None.
          If specified and ``test_interval>0``, it is used as input data.
          It is used sequentially, so shuffle it beforehand, if required.
          The keys of the dict must have a corresponding layer name in
          the net.

        :param input_processing_flags: dict(string, string) or None.
          See ``CyclingDataMonitor.__init__`` for the
          ``input_processing_flags`` parameter. In short, if you specify
          your samples via a list, you may specify for each blob whether
          it should be padded 'p' or resized 'r' to match the network
          input size. If the samples fit perfectly, you may specify 'n'
          or omit the parameter and use ``None``.

        :param test_iterations: int.
          The number of test iterations to determine the validation
          score, if ``test_interval>0``.

        :param test_interval: int.
          The number of iterations between runs on the validation set.
          Is specified in plain iterations, disregarding batch size.
          Hence, it must be a multiple of the batch size.

        :param test_initialization: bool.
          Whether to do a run on the validation set before the training
          is started to get an initial score.

        :param train_callbacks: list(barrista.monitoring.Monitor).
          List of callback callables. Will be called before and after
          each training batch is processed. This list will be processed
          sequentially, meaning that monitors in the sequence can provide
          information for later monitors as done with ResultExtractor.

        :param test_callbacks: list(callable).
          List of callback callables. Will be called before and after
          testing, and before and after each batch of testing is
          processed. This list will be processed sequentially, meaning
          that monitors in the sequence can provide information for later
          monitors as done with ResultExtractor.

        :param read_input_batch_size_from_blob_name: string.
          The name of the layer to take the input batch size from (as the
          first dimension of its first blob). Must be specified if the
          network does not have explicit inputs (e.g., when trained from
          an LMDB).

        :param use_fit_phase_for_validation: bool.
          If set to True, do not change the phase of the net for running
          a validation step during training. This can be helpful to
          reduce memory consumption. It ignores the TEST phase of the net
          completely, but is not necessary if the data is provided by
          Python layers.

        :param allow_test_phase_for_train: bool.
          If set to True, allow using a network in its TEST phase to be
          trained. May make sense in exotic settings, but usually points
          to a bug. If not set to True, an AssertionError is raised in
          this scenario. Why is this so important? The ``DropoutLayer``
          and ``PoolLayer`` (in the case of stochastic pooling) are
          sensitive to this parameter and results are very different for
          the two settings.

        :param shuffle: bool.
          If set to True, shuffle the training data every epoch. The test
          data is not shuffled. Default: False.
        """
        if net is not None:
            from barrista import net as _net
            assert isinstance(net, _net.Net), (
                'net must be an instance of barrista.net.Net')
            self._Update_net(net)

        assert self._net is not None, (
            'neither the solver was initialized with a net nor '
            'the fit function was called with one')
        assert self._net._mode == _Phase.TRAIN or allow_test_phase_for_train, (
            'The network must be in TRAIN phase for fitting! If you really '
            'want to, you can override this requirement by setting '
            'the optional parameter `allow_test_phase_for_train` to True.'
        )
        train_callbacks = self._Assert_callbacks(self._net,
                                                 train_callbacks,
                                                 'train')
        test_callbacks = self._Assert_callbacks(self._net,
                                                test_callbacks,
                                                'test')
        testnet = self._Init_testnet(test_interval,
                                     use_fit_phase_for_validation)
        batch_size, test_iterations = self._Get_batch_size(
            self._net,
            testnet,
            test_interval,
            test_iterations,
            X_val,
            read_input_batch_size_from_blob_name)
        self._Assert_iterations(
            batch_size,
            iterations,
            test_interval,
            test_iterations,
            self._parameter_dict.get('stepvalue'))
        if self._parameter_dict.get('stepvalue') is not None:
            self._parameter_dict['stepvalue'] = [
                val / batch_size
                for val in self._parameter_dict['stepvalue']]
        self._Init_cycling_monitor(X,
                                   X_val,
                                   input_processing_flags,
                                   batch_size,
                                   test_interval,
                                   train_callbacks,
                                   test_callbacks,
                                   shuffle)
        run_pre = True
        iteration = 0
        cbparams = dict()
        cbparams['max_iter'] = iterations
        cbparams['batch_size'] = batch_size
        cbparams['iter'] = 0
        cbparams['net'] = self._net
        cbparams['testnet'] = testnet
        cbparams['solver'] = self
        cbparams['X'] = X
        cbparams['X_val'] = X_val
        cbparams['test_iterations'] = test_iterations
        cbparams['test_interval'] = test_interval
        cbparams['train_callbacks'] = train_callbacks
        cbparams['test_callbacks'] = test_callbacks
        cbparams['callback_signal'] = 'initialize_train'
        for cb in train_callbacks:
            cb(cbparams)
        if test_interval > 0:
            cbparams['callback_signal'] = 'initialize_test'
            for cb in test_callbacks:
                cb(cbparams)

        try:
            _parallel.init_prebatch(self,
                                    self._net,
                                    train_callbacks,
                                    True)
            if test_interval > 0:
                _parallel.init_prebatch(self,
                                        testnet,
                                        test_callbacks,
                                        False)
            while iteration <= iterations:
                cbparams['iter'] = iteration
                # Check whether to test the net.
                if ((test_interval > 0 and
                     iteration % test_interval == 0 and
                     iteration > 0) or
                        (iteration == 0 and test_initialization) or
                        (test_interval > 0 and
                         iteration + batch_size > iterations)):
                    ###########################################################
                    # testing loop
                    ###########################################################
                    test_iter = 0
                    run_pre = True
                    # Pretest gets called if necessary in `run_prebatch`.
                    while test_iter < test_iterations:
                        cbparams['callback_signal'] = 'pre_test_batch'
                        _parallel.run_prebatch(self,
                                               test_callbacks,
                                               cbparams,
                                               False,
                                               cbparams['iter'],
                                               run_pre)
                        # pylint: disable=W0212
                        testnet._forward(0, len(testnet.layers) - 1)
                        cbparams['callback_signal'] = 'post_test_batch'
                        for cb in test_callbacks:
                            cb(cbparams)
                        test_iter += batch_size
                        run_pre = False
                    cbparams['callback_signal'] = 'post_test'
                    for cb in test_callbacks:
                        cb(cbparams)
                    run_pre = True
                if iteration == iterations:
                    break

                ###############################################################
                # training loop
                ###############################################################
                # `pre_fit` gets called if necessary in `run_prebatch`.
                PRETRBATCH_BEGINPOINT = _time.time()
                cbparams['callback_signal'] = 'pre_train_batch'
                _parallel.run_prebatch(self,
                                       train_callbacks,
                                       cbparams,
                                       True,
                                       cbparams['iter'] + batch_size,
                                       run_pre)
                run_pre = False
                PRETRBATCH_DURATION = _time.time() - PRETRBATCH_BEGINPOINT
                _LOGGER.debug("Pre-batch preparation time: %03.3fs.",
                              PRETRBATCH_DURATION)
                TRBATCH_BEGINPOINT = _time.time()
                self.step(1)
                TRBATCH_DURATION = _time.time() - TRBATCH_BEGINPOINT
                _LOGGER.debug("Batch processing time: %03.3fs.",
                              TRBATCH_DURATION)
                POSTTRBATCH_BEGINPOINT = _time.time()
                cbparams['callback_signal'] = 'post_train_batch'
                for cb in train_callbacks:
                    cb(cbparams)
                POSTTRBATCH_DURATION = _time.time() - POSTTRBATCH_BEGINPOINT
                _LOGGER.debug("Post-batch processing time: %03.3fs.",
                              POSTTRBATCH_DURATION)
                iteration += batch_size
        finally:
            for cb in set(train_callbacks + test_callbacks):
                if not isinstance(cb, _monitoring.ParallelMonitor):
                    cb.finalize(cbparams)
            _parallel.finalize_prebatch(self, cbparams)
            if self._parameter_dict.get('stepvalue') is not None:
                self._parameter_dict['stepvalue'] = [
                    val * batch_size
                    for val in self._parameter_dict['stepvalue']]

    def step(self, number_of_batches):
        """Run ``number_of_batches`` solver steps."""
        tmp_hash = self.Get_parameter_hash(self.Get_parameter_dict())
        if self._parameter_hash != tmp_hash:
            if self._print_warning:  # pragma: no cover
                _LOGGER.warn('WARNING: ----------------------------------')
                _LOGGER.warn('you are re-initializing a new solver which '
                             'will delete')
                _LOGGER.warn('the weight history of the solver.')
                _LOGGER.warn('Only use this option if you know what you '
                             'are doing!')
                self._print_warning = False
            self._Update_solver()
        return self._solver.step(number_of_batches)

    def Get_parameter_dict(self):
        """Get the solver describing parameters in a dictionary."""
        # work through our stack of assertions, then return a shallow copy
        # of the dict
        for Tmp_assert in self._asserts:
            assert Tmp_assert()
        return copy.copy(self._parameter_dict)

    def Assert_iter_size(self):
        """Enforce the parameter constraints."""
        return self._parameter_dict['iter_size'] > 0

    def Assert_regularization_types(self):
        """Enforce the parameter constraints."""
        return self._parameter_dict['regularization_type'] in ['L1', 'L2']

    def Assert_policy(self):  # pylint: disable=R0911
        """Enforce the parameter constraints."""
        # Although redundant, this allows a quick check of what is really
        # required without loading the actual net, which might take a bit
        # of time.
        if self._parameter_dict['lr_policy'] == 'fixed':
            return 'base_lr' in self._parameter_dict
        if self._parameter_dict['lr_policy'] == 'step':
            return 'gamma' in self._parameter_dict
        if self._parameter_dict['lr_policy'] == 'exp':
            return 'gamma' in self._parameter_dict
        if self._parameter_dict['lr_policy'] == 'inv':
            return ('gamma' in self._parameter_dict and
                    'power' in self._parameter_dict)
        if self._parameter_dict['lr_policy'] == 'multistep':
            return ('stepvalue' in self._parameter_dict and
                    'base_lr' in self._parameter_dict and
                    'gamma' in self._parameter_dict)
        if self._parameter_dict['lr_policy'] == 'poly':
            return 'power' in self._parameter_dict
        if self._parameter_dict['lr_policy'] == 'sigmoid':
            return 'stepsize' in self._parameter_dict
        return False

    @classmethod
    def Get_parameter_hash(cls, solver_parameter_dict):
        """Get a hash of the parameter dict."""
        hash_obj = hashlib.md5()
        for key in sorted(solver_parameter_dict.keys()):
            hash_obj.update(str(key).encode('utf-8'))
            hash_obj.update(str(solver_parameter_dict[key]).encode('utf-8'))
        return str(hash_obj.hexdigest())

    @classmethod
    def Get_caffe_solver_instance(cls, solver_parameter_dict, net):
        """Get a caffe solver object."""
        # now we actually create an instance of the solver
        solver_message = _caffe_pb2.SolverParameter(**solver_parameter_dict)
        messagestr = _gprototext.MessageToString(solver_message)
        with _NamedTemporaryFile(mode='w+b', suffix='.prototxt') as tmpfile:
            tmpfile.write(bytes(messagestr.encode('utf-8')))
            tmpfile.flush()
            try:
                # Newer version of caffe with full solver init support.
                return cls.Get_caffe_solver_class(
                    solver_parameter_dict['solver_type'])._caffe_solver_class(
                        tmpfile.name, net, _caffe._caffe.NetVec(), True)
            except TypeError:
                # Fallback for older, patched versions.
                return cls.Get_caffe_solver_class(
                    solver_parameter_dict['solver_type'])._caffe_solver_class(
                        tmpfile.name, net)
        raise Exception('could not initialize solver class')

    @classmethod
    def Get_solver_class(cls, solver_type):
        """Get the solver class for the given string identifier."""
        return cls._solver_types[solver_type]

    @classmethod
    def Get_caffe_solver_class(cls, caffe_solver_type):
        """Get the solver class for the given ``caffe_solver_type``."""
        return cls._solver_types[caffe_solver_type]

    @classmethod
    def Register_solver(cls, solver_class):
        """Register a solver class."""
        assert issubclass(solver_class, Solver)
        if solver_class._solver_type in cls._solver_types:
            raise Exception(' '.join(
                ['solver', str(solver_class._solver_type),
                 'already defined']))
        if solver_class._caffe_solver_type in cls._solver_types:
            raise Exception(' '.join(
                ['solver', str(solver_class._caffe_solver_type),
                 'already defined']))
        # we register with both access types
        cls._solver_types[solver_class._caffe_solver_type] = solver_class
        cls._solver_types[solver_class._solver_type] = solver_class

    def _Update_solver(self):
        """Re-initialize the solver."""
        # we (re-)initialize the solver
        self._solver = self.Get_caffe_solver_instance(
            self.Get_parameter_dict(),
            self._net)
        self._parameter_hash = self.Get_parameter_hash(
            self.Get_parameter_dict())
        # we only want to see the warning once
        self._print_warning = True

    def update_parameters(self, **kwargs):
        """Update the solver parameters."""
        # add the default keys if they are not yet set
        for argument, default in list(self.Get_optional_arguments().items()):
            if argument not in self._parameter_dict and default is not None:
                self._parameter_dict[argument] = default

        # first add all parameters which are actually required
        for arg_key, arg_value in list(kwargs.items()):
            if arg_key in self.Get_required_arguments():
                self._parameter_dict[arg_key] = arg_value

        # make sure that all required arguments are set
        tmp_required_arguments = set(self.Get_required_arguments())
        intersection = tmp_required_arguments.intersection(set(kwargs.keys()))
        if intersection != tmp_required_arguments:
            raise Exception(' '.join(
                ['we are missing required arguments',
                 str(list(kwargs.keys())),
                 'vs',
                 str(self.Get_required_arguments())]))

        for arg_key, arg_value in list(kwargs.items()):
            # The very special case of passing the net: this will not be
            # added to the parameter dict, but we will ensure that the net
            # is always the same as the one used for initialization.
            if arg_key == 'net':
                self._Update_net(arg_value)
                continue
            if arg_key in list(self.Get_optional_arguments().keys()):
                self._parameter_dict[arg_key] = arg_value

        # we make sure that there is no spelling mistake in the kwargs
        total_arguments = set(self.Get_required_arguments())
        total_arguments = total_arguments.union(
            list(self.Get_optional_arguments().keys()))
        for argument in list(kwargs.keys()):
            if argument not in total_arguments:
                raise Exception(' '.join(
                    ['argument', argument, 'is not supported']))

    def _Update_net(self, net):
        """Check that the net remains the same."""
        # Since the user could potentially provide two different nets to
        # the solver, which is not supported, we check that the net has
        # not changed.
        if net is None:
            return
        if self._net is not None:
            if id(self._net) != id(net):
                raise Exception(' '.join(
                    ['a solver works only with one network;',
                     'the network has to remain the same']))
        self._net = net

    def _Get_batch_size(self,  # pylint: disable=R0201
                        net,
                        testnet,
                        test_interval,
                        test_iterations,
                        X_val,
                        read_input_batch_size_from_blob_name):
        """Get the batch size and the test iterations."""
        if len(net.inputs) > 0:
            # Otherwise, a DB backend is used.
            batch_size = net.blobs[net.inputs[0]].data.shape[0]
            if testnet is not None:
                assert (testnet.blobs[net.inputs[0]].data.shape[0] ==
                        batch_size), ("Validation and fit network batch "
                                      "size must agree!")
            if (test_interval != 0 and
                    test_iterations == 0 and
                    X_val is not None):
                if isinstance(X_val, dict):
                    if len(list(X_val.values())[0]) % batch_size != 0:
                        _LOGGER.warn(
                            "The number of test samples is not a multiple "
                            "of the batch size. Test performance estimates "
                            "will be slightly off.")
                    test_iterations = _np.ceil(
                        float(len(list(X_val.values())[0])) /
                        float(batch_size)) * batch_size
                else:
                    if len(X_val) % batch_size != 0:
                        _LOGGER.warn(
                            "The number of test samples is not a multiple "
                            "of the batch size. Test performance estimates "
                            "will be slightly off.")
                    test_iterations = _np.ceil(
                        float(len(X_val)) / float(batch_size)) * batch_size
            if read_input_batch_size_from_blob_name is not None:
                tmp_batch_size = net.blobs[
                    read_input_batch_size_from_blob_name].data.shape[0]
                assert (tmp_batch_size == batch_size), (
                    "The input size and the first dimension of "
                    "the blob to read the batch size from don't "
                    "match: {}, {}.".format(tmp_batch_size, batch_size))
            return batch_size, test_iterations

        # some kind of DB backend is used
        assert read_input_batch_size_from_blob_name is not None, (
            'no inputs, thus the batch_size must be determined from a blob')
        batch_size = net.blobs[
            read_input_batch_size_from_blob_name].data.shape[0]
        return batch_size, test_iterations

    @classmethod
    def _Assert_iterations(cls,
                           batch_size,
                           iterations,
                           test_interval,
                           test_iterations,
                           multistep_stepvalue):
        """Make sure iterations follow all of our rules."""
        # namely being a multiple of the batch_size
        assert iterations % batch_size == 0, (
            'Error: iterations do not match {} {}'.format(iterations,
                                                          batch_size))
        if test_interval > 0:
            assert test_iterations > 0, (
                'Test iterations must be > 0 but is {}'.format(
                    test_iterations))
        assert test_iterations >= 0, (
            'Test iterations must be >= 0 but is {}'.format(
                test_iterations))
        assert test_interval >= 0, (
            'Test interval must be >= 0 but is {}'.format(
                test_interval))
        assert test_interval % batch_size == 0, (
            'The test interval must be a multiple of the batch size: '
            '{}, {}'.format(test_interval, batch_size))
        if multistep_stepvalue is not None:
            for val in multistep_stepvalue:
                assert val % batch_size == 0, (
                    "The step values must be multiples of the batch size "
                    "(is given in sample iterations)! Is %d, batch size "
                    "%d." % (val, batch_size))

    @classmethod
    def _Assert_callbacks(cls, net, callbacks, phase):
        """Assert the callbacks work properly."""
        if callbacks is None:
            callbacks = []
        assert isinstance(callbacks, list), (
            'callbacks have to be in a list {} {}'.format(
                str(callbacks), type(callbacks)))
        for callback in callbacks:
            assert isinstance(callback, _monitoring.Monitor), (
                'a callback has to derive from monitoring.Monitor')
        if 'loss' in list(net.blobs.keys()):
            callbacks.insert(0, _monitoring.ResultExtractor(
                phase + '_loss', 'loss'))
        if 'accuracy' in list(net.blobs.keys()):
            callbacks.insert(0, _monitoring.ResultExtractor(
                phase + '_accuracy', 'accuracy'))
        return callbacks

    @classmethod
    def _Init_cycling_monitor(cls,
                              X,
                              X_val,
                              input_processing_flags,
                              batch_size,
                              test_interval,
                              train_callbacks,
                              test_callbacks,
                              shuffle):
        """
        Convenience initialization function.

        ...such that the user can simply provide X, X_val dicts and we
        internally create the CyclingDataMonitors.
        """
        if X is not None:
            assert len(list(X.values())[0]) >= batch_size
            # safety measure: we do not want to have two different data
            # monitors in the same callback list
            for callback in train_callbacks:
                assert not isinstance(callback, _monitoring.DataMonitor), (
                    'if we use X we cannot use a data monitor')
            tmp_data_monitor = _monitoring.CyclingDataMonitor(
                X=X,
                input_processing_flags=input_processing_flags,
                shuffle=shuffle)
            train_callbacks.insert(0, tmp_data_monitor)
        if test_interval > 0 and X_val is not None:
            assert X_val is not None
            assert len(list(X_val.values())) == len(list(X.values()))
            # safety measure: we do not want to have two different data
            # monitors in the same callback list
            for callback in test_callbacks:
                assert not isinstance(callback, _monitoring.DataMonitor), (
                    'if we use X_val we cannot use a data monitor')
            tmp_data_monitor = _monitoring.CyclingDataMonitor(
                X=X_val,
                input_processing_flags=input_processing_flags)
            test_callbacks.insert(0, tmp_data_monitor)

    def _Init_testnet(self, test_interval, use_fit_phase_for_validation):
        """Initialize the test phase network."""
        testnet = None
        if test_interval > 0:
            if use_fit_phase_for_validation:
                testnet = self._net
            else:
                # Setup the test net.
                test_netspec = self._net._specification.copy()
                test_netspec.phase = _Phase.TEST
                test_netspec.predict_inputs = None
                test_netspec.predict_input_shapes = None
                testnet = test_netspec.instantiate()
                testnet.share_with(self._net)
        return testnet
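

# Illustrative sketch (not part of the original module): the parameter
# bookkeeping of the base class can be exercised without touching caffe.
# ``step`` re-initializes the underlying caffe solver whenever this hash
# changes, which is why changing, e.g., the learning rate mid-training
# discards the solver's weight history. The concrete values are examples.
def _example_parameter_hash():
    # Normally one of the concrete subclasses below is used; the base class
    # suffices here because only the parameter dict is exercised.
    solver = Solver(base_lr=0.01)
    hash_before = Solver.Get_parameter_hash(solver.Get_parameter_dict())
    solver.update_parameters(base_lr=0.001)
    hash_after = Solver.Get_parameter_hash(solver.Get_parameter_dict())
    # A changed parameter yields a different hash, triggering
    # re-initialization on the next `step` call.
    assert hash_before != hash_after
    return hash_before, hash_after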


class SGDSolver(Solver):

    r"""
    Thin wrapper for the vanilla SGD solver provided by the caffe framework.

    :param momentum: float or None.
      The momentum to use. Multiplies the former gradient with this factor
      and adds it to the gradient in the following step.
    """

    _solver_type = 'sgd'
    _caffe_solver_type = SolverType.SGD
    _caffe_solver_class = _caffe.SGDSolver

    def __init__(self, **kwargs):
        """Constructor."""
        Solver.__init__(self, **kwargs)

    @classmethod
    def Get_required_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        return Solver.Get_required_arguments()

    @classmethod
    def Get_optional_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        optional_arguments = Solver.Get_optional_arguments()
        optional_arguments['momentum'] = 0.0
        return optional_arguments
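

# Illustrative sketch (not part of the original module): a typical training
# run with the SGDSolver. `train_net` is assumed to be an instantiated
# ``barrista.net.Net`` in TRAIN phase; the blob names 'data' and
# 'annotations' as well as the numpy arrays are hypothetical placeholders.
# Only the keyword arguments themselves are taken from the ``fit``
# documentation above; `iterations` and `test_interval` must be multiples
# of the net's batch size.
def _example_sgd_fit(train_net, data, labels, val_data, val_labels):
    solver = SGDSolver(base_lr=0.01,
                       momentum=0.9,
                       lr_policy='step',
                       gamma=0.1,
                       stepsize=1000)
    solver.fit(iterations=2000,
               X={'data': data, 'annotations': labels},
               X_val={'data': val_data, 'annotations': val_labels},
               test_interval=500,
               test_initialization=True,
               net=train_net)
    return solver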


class AdagradSolver(Solver):

    r"""
    Thin wrapper for the Adagrad solver provided by the caffe framework.

    To understand how this solver works, please inspect the C++
    implementation in solver.cpp. The corresponding publication is
    'Adaptive Subgradient Methods for Online Learning and Stochastic
    Optimization' by John Duchi, Elad Hazan, and Yoram Singer.

    :param delta: float.
      Numerical stability constant (required).
    """

    _solver_type = 'adagrad'
    _caffe_solver_type = _ADAGRAD_SOLVER_ENUM
    _caffe_solver_class = _ADAGRAD_SOLVER_CLASS

    def __init__(self, **kwargs):
        """See :py:class:`barrista.solver.Solver`."""
        Solver.__init__(self, **kwargs)

    @classmethod
    def Get_required_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        required_arguments = Solver.Get_required_arguments()
        required_arguments.append('delta')
        return required_arguments

    @classmethod
    def Get_optional_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        return Solver.Get_optional_arguments()
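

# Illustrative sketch (not part of the original module): Adagrad requires
# the numerical stability constant 'delta' in addition to 'base_lr';
# omitting it raises an exception in ``update_parameters``. The values are
# example choices.
def _example_adagrad_arguments():
    assert 'delta' in AdagradSolver.Get_required_arguments()
    return AdagradSolver(base_lr=0.01, delta=1e-8)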


class NesterovSolver(Solver):

    r"""
    Thin wrapper for the Nesterov solver provided by the caffe framework.

    To understand how this solver works, please inspect the C++
    implementation in solver.cpp.

    :param momentum: float or None.
      The momentum to use. Multiplies the former gradient with this factor
      and adds it to the gradient in the following step.
    """

    _solver_type = 'nesterov'
    _caffe_solver_type = SolverType.NESTEROV
    _caffe_solver_class = _caffe.NesterovSolver

    def __init__(self, **kwargs):
        """See :py:class:`barrista.solver.Solver`."""
        Solver.__init__(self, **kwargs)

    @classmethod
    def Get_required_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        return Solver.Get_required_arguments()

    @classmethod
    def Get_optional_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        optional_arguments = Solver.Get_optional_arguments()
        optional_arguments['momentum'] = 0.0
        return optional_arguments


class RMSPropSolver(Solver):

    r"""
    Thin wrapper for the RMSProp solver provided by the caffe framework.

    To understand how this solver works, please inspect the C++
    implementation in solver.cpp. This solver has been discussed in a
    lecture given by Hinton:
    www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf

    :param rms_decay: float.
      MeanSquare(t) = rms_decay*MeanSquare(t-1) +
      (1-rms_decay)*SquareGradient(t)

    :param delta: float.
      Numerical stability [useful choice 1E-8].
    """

    _solver_type = 'rmsprop'
    _caffe_solver_type = _RMSPROP_SOLVER_ENUM
    _caffe_solver_class = _RMSPROP_SOLVER_CLASS

    def __init__(self, **kwargs):
        """See :py:class:`barrista.solver.Solver`."""
        Solver.__init__(self, **kwargs)

    @classmethod
    def Get_required_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        required_arguments = Solver.Get_required_arguments()
        required_arguments.append('rms_decay')
        required_arguments.append('delta')
        return required_arguments

    @classmethod
    def Get_optional_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        return Solver.Get_optional_arguments()
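

# Illustrative sketch (not part of the original module): RMSProp requires
# both 'rms_decay' and 'delta' in addition to 'base_lr'. The concrete
# values are example choices, with delta as suggested in the docstring.
def _example_rmsprop_arguments():
    if _RMSPROP_SOLVER_CLASS is None:
        return None  # caffe build without RMSProp support
    return RMSPropSolver(base_lr=0.01, rms_decay=0.98, delta=1e-8)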


class AdaDeltaSolver(Solver):

    r"""
    Thin wrapper for the AdaDelta solver provided by the caffe framework.

    To understand how this solver works, please inspect the C++
    implementation in solver.cpp. The corresponding arXiv paper is
    'ADADELTA: An Adaptive Learning Rate Method' by Matthew D. Zeiler.

    :param delta: float.
      Numerical stability [useful choice 1E-8].

    :param momentum: float or None.
      The momentum to use. Multiplies the former gradient with this factor
      and adds it to the gradient in the following step.
    """

    _solver_type = 'adadelta'
    _caffe_solver_type = _ADADELTA_SOLVER_ENUM
    _caffe_solver_class = _ADADELTA_SOLVER_CLASS

    def __init__(self, **kwargs):
        """See :py:class:`barrista.solver.Solver`."""
        Solver.__init__(self, **kwargs)

    @classmethod
    def Get_required_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        required_arguments = Solver.Get_required_arguments()
        required_arguments.append('momentum')
        return required_arguments

    @classmethod
    def Get_optional_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        optional_arguments = Solver.Get_optional_arguments()
        # epsilon
        optional_arguments['delta'] = 1E-8
        return optional_arguments
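

# Illustrative sketch (not part of the original module): AdaDelta requires
# 'momentum' in addition to 'base_lr'; 'delta' is optional and defaults to
# 1E-8 (see ``Get_optional_arguments`` above). The concrete values are
# example choices.
def _example_adadelta_arguments():
    if _ADADELTA_SOLVER_CLASS is None:
        return None  # caffe build without AdaDelta support
    solver = AdaDeltaSolver(base_lr=1.0, momentum=0.95)
    assert solver.Get_parameter_dict()['delta'] == 1E-8
    return solver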


class AdamSolver(Solver):

    r"""
    Thin wrapper for the Adam solver provided by the caffe framework.

    To understand how this solver works, please inspect the C++
    implementation in solver.cpp. The corresponding arXiv paper is
    'Adam: A Method for Stochastic Optimization' by Diederik Kingma and
    Jimmy Ba.

    :param base_lr: float.
      [useful choice 0.001]

    :param momentum: float.
      beta 1; useful default 0.9.

    :param momentum2: float.
      beta 2; useful default 0.999.

    :param delta: float.
      Numerical stability [useful choice 1E-8].
    """

    _solver_type = 'adam'
    _caffe_solver_type = _ADAM_SOLVER_ENUM
    _caffe_solver_class = _ADAM_SOLVER_CLASS

    def __init__(self, **kwargs):
        """See :py:class:`barrista.solver.Solver`."""
        Solver.__init__(self, **kwargs)

    @classmethod
    def Get_required_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        return Solver.Get_required_arguments()

    @classmethod
    def Get_optional_arguments(cls):
        """See :py:class:`barrista.solver.Solver`."""
        optional_arguments = Solver.Get_optional_arguments()
        # beta 1
        optional_arguments['momentum'] = 0.9
        # beta 2
        optional_arguments['momentum2'] = 0.999
        # epsilon
        optional_arguments['delta'] = 1E-8
        return optional_arguments


# register the locally specified solvers
Solver.Register_solver(SGDSolver)
Solver.Register_solver(AdagradSolver)
Solver.Register_solver(NesterovSolver)
if _RMSPROP_SOLVER_CLASS is not None:
    Solver.Register_solver(RMSPropSolver)
if _ADADELTA_SOLVER_CLASS is not None:
    Solver.Register_solver(AdaDeltaSolver)
if _ADAM_SOLVER_CLASS is not None:
    Solver.Register_solver(AdamSolver)

Get_solver_class = Solver.Get_solver_class
Get_caffe_solver_class = Solver.Get_caffe_solver_class
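

# Illustrative sketch (not part of the original module): the registry can
# be queried with either the string identifier or the caffe enum value.
# 'adam' is only an example; it is available only if the underlying caffe
# build provides the AdamSolver.
def _example_solver_lookup():
    if _ADAM_SOLVER_CLASS is None:
        return None  # caffe build without Adam support
    solver_cls = Get_solver_class('adam')
    assert solver_cls is AdamSolver
    # The same class is registered under the caffe enum value as well.
    assert Get_caffe_solver_class(SolverType.ADAM) is AdamSolver
    return solver_cls(base_lr=0.001)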