[源码笔记]keras源码分析之Layer、Tensor和Node

Keras架构的主要逻辑实现在/keras/engine/topology.py中,主要有两个基类Node()Layer(),一个重要函数Input()。具体地,

  • Layer()是一个计算层的抽象,完成网络中对Tensor的计算过程;
  • Node()描述两个层之间连接关系的抽象,配合Layer()构建DAG;
  • Input()实例化一个特殊的Layer(InputLayer),将backend(TensorFlow或Theano)建立的Tensor对象转化为Keras Tensor对象。

Keras Tensor: 增强版Tensor

相比原始的TensorFlow或者Theano的张量对象,Keras Tensor加入了如下两个属性,以使Tensor中包含了自己的来源和规模信息:

  • _Keras_history: 保存了最近一个应用于这个Tensor的Layer
  • _keras_shape: 标准化的Keras shape接口

当使用Keras建立深度网络时,传入的数据首先要经过Input()函数。在Input()函数中,实例化一个InputLayer()对象,并将此Layer()对象作为第一个应用于传入张量的Layer,置于_keras_history属性中。此外,InputLayer()Input()还会对传入的数据进行规模检查和变换等,使之符合后续操作的要求。

代码上实现如下:

def Input():
  input_layer = InputLayer()
  outputs = InputLayer.inbound_nodes[0].output_tensor
  return outputs

class InputLayer():
  def __init__():
    input_tensor._keras_history = (self, 0, 0)
    Node(self, ...)

在下面我们将看到,加入的_keras_history属性在计算图的构建上所起的作用是关键的。仅通过输入和输出的Tensor,我们可以构建出整张计算图。但这样的代价是Tensor对象太重了,包含了Layer的信息。

Node对象:层与层之间链接的抽象

若考虑Layer对象抽象的是完成计算的神经元胞体,则Node对象是对神经元树突结构的抽象。其内聚的主要信息是:

class Node():
  def __init__(self, outbound_layer,
              inbound_layers, node_indices, tensor_indices,
              input_tensors, output_tensors, ...)

其中outbound_layer是施加计算(使input_tensors变为output_tensors)的层,inbound_layers对应了input_tensors来源的层,而node_indicestensor_indices则记录了NodeLayer之间的标定信息。

Node对象总在outbound_layer被执行时创建,并加入outbound_layerinbound_nodes属性中。在Node对象的表述下,A和B两个层产生连接关系时,Node对象被建立,并被加入A.outbound_nodesB.inbound_nodes

Layer对象:计算层的抽象

Layer对象是对网络中神经元计算层的抽象,实例化需要如下参数:

allowed_kwargs = {'input_shape',
                  'batch_input_shape',
                  'batch_size',
                  'dtype',
                  'name',
                  'trainable',
                  'weights',
                  'input_dtype',  # legacy
                  }

大部分与传入数据的类型和规模相关,trainable表征该层是否需要更新权重。此外,还有inbound_nodesoutbound_nodes属性来标定与Node对象的链接。

Layer对象最重要的方法是__call__(),主要完成如下三件事情:

  1. 验证传入数据的合法性,通过调用内部方法实现:self.assert_input_compatibility(inputs)

  2. 进行计算outputs = self.call(inputs, ...),被其子类具体实现,如Linear, Dropout

  3. 更新Tensor中的_keras_history属性,记录该次计算操作,通过内部方法_add_inbound_nodes()实现

方法_add_inbound_nodes()对Tensor的更新是构建Layer之间关系的关键操作,其主要代码如下:

for x in input_tensors:
    if hasattr(x, '_keras_history'):
        inbound_layer, node_index, tensor_index = x._keras_history
        inbound_layers.append(inbound_layer)
        node_indices.append(node_index)
        tensor_indices.append(tensor_index)

# Node对象的建立过程中将更新self的inbound_nodes属性
Node(self,
    inbound_layers=inbound_layers,
    node_indices=node_indices,
    tensor_indices=tensor_indices,
    ...)

for i in range(len(output_tensors)):
     output_tensors[i]._keras_history = (self, len(self.inbound_nodes) - 1, i)

上段代码取出input_tensor_keras_history属性,建立新的Node,并将当前Layer的信息更新到计算得到的output_tensor中。

实例:Node,TensorLayer间连接关系的表征

下面通过代码来说明三者之间的关系,来自于测试代码:

# 建立新的keras Tensor
a = Input(shape=(32,), name='input_a')
b = Input(shape=(32,), name='input_b')

a_layer, a_node_index, a_tensor_index = a._keras_history
assert len(a_layer.inbound_nodes) == 1
assert a_tensor_index is 0

# node和layer之间的关系
node = a_layer.inbound_nodes[a_node_index]
assert node.outbound_layer == a_layer

# 建立连接层,将Tensor传入
dense = Dense(16, name='dense_1')
a_2 = dense(a)
b_2 = dense(b)

assert len(dense.inbound_nodes) == 2
assert len(dense.outbound_nodes) == 0

# 与张量a关联的Node
assert dense.inbound_nodes[0].inbound_layers == [a_layer]
assert dense.inbound_nodes[0].outbound_layer == dense
assert dense.inbound_nodes[0].input_tensors == [a]

# 与张量b关联的Node
assert dense.inbound_nodes[1].inbound_layers == [b_layer]
assert dense.inbound_nodes[1].outbound_layer == dense
assert dense.inbound_nodes[1].input_tensors == [b]

总结

keras利用Node对象描述Layer之间的连接关系,并在Tensor中记录其来源信息。在下篇中,我们将看到keras如何利用这些抽象和增强属性构建DAG,并实现前向传播和反向训练的。

@ddlee

Creative Commons License
本文章遵从署名-相同方式共享4.0国际协议(CC BY-SA 4.0)
这意味着您可以署名转载本文章,并附上此协议。
我每周会分享一些有趣实用的英文文章,欢迎关注ddlee每周分享
这里可以找到我推荐的服务、应用程序、书籍和电影。

本文链接:https://blog.ddlee.cn/posts/4943e1b8/
分享文章:

相关文章

深度学习中的权重衰减 Tensorflow最佳实践:试验管理

评论

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×