[源码笔记]keras源码分析之Model

本篇是keras源码笔记系列的第三篇。在前两篇中,我们分析了keras对Tensor和Layer等概念的处理,并说明了它们是如何作用别弄个构成有向无环图的。本篇着眼于多层网络模型层面的抽象,即与用户距离最近的接口,源代码文件是/keras/engine/training.py/keras/model.py,要观察的类是ModelSequential

本系列第一篇:【源码笔记】keras源码分析之Tensor, Node和Layer
第二篇:【源码笔记】keras源码分析之Container

Model:添加了训练信息的Container

Model.compile()主要完成了配置optimizer, loss, metrics等操作,而要执行的fit, evaluate等则不在compile过程中配置。

def compile(self, optimizer, loss, metrics=None, loss_weights=None,
            sample_weight_mode=None, **kwargs):
    loss = loss or {}
    self.optimizer = optimizers.get(optimizer)
    self.sample_weight_mode = sample_weight_mode
    self.loss = loss
    self.loss_weights = loss_weights

    loss_function = losses.get(loss)
    loss_functions = [loss_function for _ in range(len(self.outputs))]
    self.loss_functions = loss_functions

    # Prepare targets of model.
    self.targets = []
    self._feed_targets = []
    for i in range(len(self.outputs)):
        shape = self.internal_output_shapes[i]
        name = self.output_names[i]
        target = K.placeholder(ndim=len(shape),
                               name=name + '_target',
                               sparse=K.is_sparse(self.outputs[i]),
                               dtype=K.dtype(self.outputs[i]))
        self.targets.append(target)
        self._feed_targets.append(target)

    # Prepare metrics.
    self.metrics = metrics
    self.metrics_names = ['loss']
    self.metrics_tensors = []

    # Compute total loss.
    total_loss = None
    for i in range(len(self.outputs)):
        y_true = self.targets[i]
        y_pred = self.outputs[i]
        loss_weight = loss_weights_list[i]
        if total_loss is None:
            total_loss = loss_weight * output_loss
        else:
            total_loss += loss_weight * output_loss

    for loss_tensor in self.losses:
        total_loss += loss_tensor

    self.total_loss = total_loss
    self.sample_weights = sample_weights

Model对象的fit()方法封装了_fit_loop()内部方法,而_fit_loop()方法的关键步骤由_make_train_function()方法完成,返回history对象,用于回调函数的处理。

def fit(self, x=None, y=None, ...):
      self._make_train_function()
      f = self.train_function
      return self._fit_loop(f, ins, ...)

_fit_loop()方法中,回调函数完成了对训练过程的监控记录等任务,train_function也被应用于传入的数据:

def _fit_loop(self, f, ins, out_labels=None, batch_size=32,
              epochs=100, verbose=1, callbacks=None,
              val_f=None, val_ins=None, shuffle=True,
              callback_metrics=None, initial_epoch=0):
    self.history = cbks.History()
    callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history]
    callbacks = cbks.CallbackList(callbacks)
    out_labels = out_labels or []
    callbacks.set_model(callback_model)
    callbacks.set_params({
        'batch_size': batch_size,
        'epochs': epochs,
        'samples': num_train_samples,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics or [],
    })
    callbacks.on_train_begin()
    callback_model.stop_training = False

    for epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(epoch)
        batches = _make_batches(num_train_samples, batch_size)
        epoch_logs = {}
        for batch_index, (batch_start, batch_end) in enumerate(batches):
            batch_ids = index_array[batch_start:batch_end]
            batch_logs = {}
            batch_logs['batch'] = batch_index
            batch_logs['size'] = len(batch_ids)
            callbacks.on_batch_begin(batch_index, batch_logs)
            # 应用传入的train_function
            outs = f(ins_batch)
            callbacks.on_batch_end(batch_index, batch_logs)
        callbacks.on_epoch_end(epoch, epoch_logs)
    callbacks.on_train_end()
    return self.history

_make_train_function()方法从optimizer获取要更新的参数信息,并传入来自backendfunction对象:

def _make_train_function(self):
    if self.train_function is None:
        inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
        training_updates = self.optimizer.get_updates(
            self._collected_trainable_weights,
            self.constraints,
            self.total_loss)
        updates = self.updates + training_updates
        # Gets loss and metrics. Updates weights at each call.
        self.train_function = K.function(inputs,
                                         [self.total_loss] + self.metrics_tensors,
                                         updates=updates,
                                         name='train_function',
                                         **self._function_kwargs)

Model的其他方法evaluate()等,与fit()的结构类似。

Sequential:构建模型的外层接口

Sequential对象是Model对象的进一步封装,也是用户直接面对的接口,其compile(), fit(), predict()等方法与Model几乎一致,所不同的是添加了add()方法,也是我们用于构建网络的最基本操作。

Sequential.add()方法的源码如下:

def add(self, layer):
    # 第一层必须是InputLayer对象
    if not self.outputs:
        if not layer.inbound_nodes:
            x = Input(batch_shape=layer.batch_input_shape,
                      dtype=layer.dtype, name=layer.name + '_input')
            layer(x)

        self.outputs = [layer.inbound_nodes[0].output_tensors[0]]
        self.inputs = topology.get_source_inputs(self.outputs[0])

        topology.Node(outbound_layer=self, ...)
    else:
        output_tensor = layer(self.outputs[0])
        self.outputs = [output_tensor]
        self.inbound_nodes[0].output_tensors = self.outputs

    self.layers.append(layer)

可以看到,add()方法总是确保网络的第一层为InputLayer对象,并将新加入的层应用于outputs,使之更新。因此,从本质上讲,在Model中添加新层还是在更新模型的outputs

@ddlee

Creative Commons License
本文章遵从Creative Commons Attribution-ShareAlike 4.0 International License
这意味着您可以署名转载本文章,并附上此协议。
如果您想定期获得关于我的博客文章的更新,欢迎邮件订阅东东月报

本文链接:https://blog.ddlee.cn/posts/ddc5b1bd/

相关文章

[论文笔记](R-CNN)Rich feature hierarchies for accurate object detection and semantic segmentation [源码笔记]keras源码分析之Container
Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×