keras对神经网络的抽象,都在Layer中,Model也是一种特殊的Layer。今天开始,我们就来看看Layer的源代码。

版本

keras 2.3.1

Layer类所在目录:keras.engine.base_layer.py

阅读策略

一行一行读,力争搞清楚每一行的含义

大略观察

一共有1474行代码,如果每天读上300行,几天就可以全部读完了。

第1-2行

"""Contains the base Layer class, from which all layers inherit. """

包含基本的Layer class。所有的layer都要继承它。

Note:1-2行,表示左右均包含,下同。

第3-20行

一些import,扫了一眼,主要有以下值得注意的地方:

K

这个代表后端,现在只支持tensorflow了。其所在目录是:

keras.backend

从该包的__init__.py文件中,可以看到K支持的操作。

e5dcdb11274dbc789a8258117f09e58e.png

大概有150多个支持的操作。以后写代码时,想要实现某个操作,又不知道用哪个函数时,应该先到这里来找找看,应该不会失望。

initializers

各种初始化器

其他的就是一些小方法了,比如

count_params

has_arg

to_list

这些就留到下面使用时再看,应该更容易理解

25-35

总体来说,这是一个disable_tracking注解。我理解,当不需要追踪某个函数的执行状况时,可以用这个注解。代码也比较好理解:

_DISABLE_TRACKING = threading.local()
_DISABLE_TRACKING.value = False


def disable_tracking(func):
    def wrapped_fn(*args, **kwargs):
        global _DISABLE_TRACKING
        prev_value = _DISABLE_TRACKING.value
        _DISABLE_TRACKING.value = True
        out = func(*args, **kwargs)
        _DISABLE_TRACKING.value = prev_value
        return out
    return wrapped_fn

36-107

主要是Layer类的注释。

input/output

为了理解这两个属性,我跳转到了input所在代码行:

abbac48c1f3e1e0a3f98ed6e56ddfaeb.png

原来主要是返回该层的输入tensor。但是有一个前提,就是这个Layer只有一个输入layer,如果有多个输入layer,也就是这个layer被重用了,就会报如下的错:

6b72319b117c708927499bb5f654f6af.png

当然,如果压根就没有输入Layer,也会报错:

57c14e261b4f8723f2f9f87322b4c214.png

贴出这两张图,主要是让自己能记住这些报错信息,以后遇到时,能快速判断出哪里出了问题。

尽快把自己训练成强大的人肉debug机器!

一切正常的话,就会调用这个方法:

self._get_node_attribute_at_index(0, 'input_tensors','input')

为了理解这个方法,我又跳转到了这个方法所在的代码:

13b2919b7e2f832d003ac526471df8aa.png

看说明,这个方法很强了,许多方法都是用它实现的。

简单讲,它是用来获得一个node的各个属性的。通过对代码实现的考察,发现这里的node,指的是layer的输入layers。对于所有该layer的输入layer,都会存到 _inbound_nodes这个数组中。

04d82ff20d4ae93bc697e06a3f8a892e.png

这里也指明了,是通过self._add_inbound_node()方法添加的。

这个方法的名字我觉得需要格外记一下,因为有过太多次,碰到layer属性相关的报错了。

这个方法有三个参数:

第一个参数:node_index,指明从哪个node中提取属性。这个node_index就是数组_inbound_nodes的索引。

第二个参数:属性的准确名称,这是给程序用的

第三个参数:给人读的属性名称,主要用来展示错误信息

572ae0e2be541ce900e06e1b03c8c6da.png

这里首先处理了两种错误,分别是该layer没有_inbound_nodes数组,和数组越界。

如果没有错误,就会取属性了。

59b7e2b1b0022be8ef48d77752a4fe87.png

getattr方法,很显然就是取出属性。

unpack_singleton方法,我查了下源码:

def unpack_singleton(x):
    """Gets the first element if the iterable has only one value.
    Otherwise return the iterable.
    # Argument
        x: A list or tuple.
    # Returns
        The same iterable or the first element.
    """
    if len(x) == 1:
        return x[0]
    return x

一目了然。

input还有一点,就是如果layer被重用了,那它就有多个input node。这是获取input会报错。正确的做法就是 使用

layer.get_input_at(node_index)

output和input类似

input_mask和output_mask与output和input类似。input_shape也是类似。

input_spec

看了说明,这个参数是对该layer的输入tensor进行限制的。

主要是维度和数据类型。

具体的是通过InputSpec来说明的。

于是,我又转到了InputSpec的源码。

class InputSpec(object):
    """Specifies the ndim, dtype and shape of every input to a layer.
    Every layer should expose (if appropriate) an `input_spec` attribute:
    a list of instances of InputSpec (one per input tensor).
    A None entry in a shape is compatible with any dimension,
    a None shape is compatible with any shape.
    # Arguments
        dtype: Expected datatype of the input.
        shape: Shape tuple, expected shape of the input
            (may include None for unchecked axes).
        ndim: Integer, expected rank of the input.
        max_ndim: Integer, maximum rank of the input.
        min_ndim: Integer, minimum rank of the input.
        axes: Dictionary mapping integer axes to
            a specific dimension value.
    """

    def __init__(self, dtype=None,
                 shape=None,
                 ndim=None,
                 max_ndim=None,
                 min_ndim=None,
                 axes=None):
        self.dtype = dtype
        self.shape = shape
        if shape is not None:
            self.ndim = len(shape)
        else:
            self.ndim = ndim
        self.max_ndim = max_ndim
        self.min_ndim = min_ndim
        self.axes = axes or {}

第一句话已经说的很清楚了。需要注意的是,注释中说每个Layer都应该暴露一个input_spec属性。

说明这个属性是给需要调用该Layer的人看的,告诉他,该Layer需要什么样的输入。

这种设计方式值得学习,比如我们自己的接口设计时,是不是也可以设计这么一个属性,帮助使用者了解,我们需要 什么样的输入。

还有一点,None在这里表示的意思是unchecked,就是需要在运行时确定,用户可以自己确定。

其他需要关注的属性

  • name,没啥说的,就是层的名字,建议起个自己好记的名字
  • stateful  是否有不是weight的状态,rnn就是典型的应用。TODO:读一下lstm的源码。
  • supports_masking 是否支持mask,keras中支持Layer不多,不是逼不得已,尽量不用
  • trainable  weight是否需要做梯度更新
  • uses_learning_phase  是否使用了K.in_training_phase()和K.in_test_phase()具体看源码咋用的
  • trainable_weights,non_trainable_weights,weights,就是名字的含义
  • dtype 数据类型

至此,将Layer的属性已经基本搞清楚了。其他行的代码是对方法的简单说明,等进一步读源码时再具体理解。此时,心中有个疑问,Node和Layer究竟是什么关系?下篇文章,就先解决这个问题。再继续。

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐