input 右边开始输入_keras进阶,我从Layer开始
❝keras对神经网络的抽象,都在Layer中,Model也是一种特殊的Layer。今天开始,我们就来看看Layer的源代码。❞版本keras 2.3.1Layer类所在目录:keras.engine.base_layer.py阅读策略一行一行读,力争搞清楚每一行的含义大略观察一共有1474行代码,如果每天读上300行,几天就可以全部读完了。第1-2行"""Contains the ba...
❝keras对神经网络的抽象,都在Layer中,Model也是一种特殊的Layer。今天开始,我们就来看看Layer的源代码。
❞
版本
keras 2.3.1
Layer类所在目录:keras.engine.base_layer.py
阅读策略
一行一行读,力争搞清楚每一行的含义
大略观察
一共有1474行代码,如果每天读上300行,几天就可以全部读完了。
第1-2行
"""Contains the base Layer class, from which all layers inherit. """
包含基本的Layer class。所有的layer都要继承它。
Note:1-2行,表示左右均包含,下同。
第3-20行
一些import,扫了一眼,主要有以下值得注意的地方:
K
这个代表后端,现在只支持tensorflow了。其所在目录是:
keras.backend
从该包的__init__.py文件中,可以看到K支持的操作。
大概有150多个支持的操作。以后写代码时,想要实现某个操作,又不知道用哪个函数时,应该先到这里来找找看,应该不会失望。
initializers
各种初始化器
其他的就是一些小方法了,比如
count_params
has_arg
to_list
这些就留到下面使用时再看,应该更容易理解
25-35
总体来说,这是一个disable_tracking注解。我理解,当不需要追踪某个函数的执行状况时,可以用这个注解。代码也比较好理解:
_DISABLE_TRACKING = threading.local()
_DISABLE_TRACKING.value = False
def disable_tracking(func):
def wrapped_fn(*args, **kwargs):
global _DISABLE_TRACKING
prev_value = _DISABLE_TRACKING.value
_DISABLE_TRACKING.value = True
out = func(*args, **kwargs)
_DISABLE_TRACKING.value = prev_value
return out
return wrapped_fn
36-107
主要是Layer类的注释。
input/output
为了理解这两个属性,我跳转到了input所在代码行:
原来主要是返回该层的输入tensor。但是有一个前提,就是这个Layer只有一个输入layer,如果有多个输入layer,也就是这个layer被重用了,就会报如下的错:
当然,如果压根就没有输入Layer,也会报错:
贴出这两张图,主要是让自己能记住这些报错信息,以后遇到时,能快速判断出哪里出了问题。
尽快把自己训练成强大的人肉debug机器!
一切正常的话,就会调用这个方法:
self._get_node_attribute_at_index(0, 'input_tensors','input')
为了理解这个方法,我又跳转到了这个方法所在的代码:
看说明,这个方法很强了,许多方法都是用它实现的。
简单讲,它是用来获得一个node的各个属性的。通过对代码实现的考察,发现这里的node,指的是layer的输入layers。对于所有该layer的输入layer,都会存到 _inbound_nodes这个数组中。
这里也指明了,是通过self._add_inbound_node()方法添加的。
这个方法的名字我觉得需要格外记一下,因为有过太多次,碰到layer属性相关的报错了。
这个方法有三个参数:
第一个参数:node_index,指明从哪个node中提取属性。这个node_index就是数组_inbound_nodes的索引。
第二个参数:属性的准确名称,这是给程序用的
第三个参数:给人读的属性名称,主要用来展示错误信息
这里首先处理了两种错误,分别是该layer没有_inbound_nodes数组,和数组越界。
如果没有错误,就会取属性了。
getattr方法,很显然就是取出属性。
unpack_singleton方法,我查了下源码:
def unpack_singleton(x):
"""Gets the first element if the iterable has only one value.
Otherwise return the iterable.
# Argument
x: A list or tuple.
# Returns
The same iterable or the first element.
"""
if len(x) == 1:
return x[0]
return x
一目了然。
input还有一点,就是如果layer被重用了,那它就有多个input node。这是获取input会报错。正确的做法就是 使用
layer.get_input_at(node_index)
output和input类似
input_mask和output_mask与output和input类似。input_shape也是类似。
input_spec
看了说明,这个参数是对该layer的输入tensor进行限制的。
主要是维度和数据类型。
具体的是通过InputSpec来说明的。
于是,我又转到了InputSpec的源码。
class InputSpec(object):
"""Specifies the ndim, dtype and shape of every input to a layer.
Every layer should expose (if appropriate) an `input_spec` attribute:
a list of instances of InputSpec (one per input tensor).
A None entry in a shape is compatible with any dimension,
a None shape is compatible with any shape.
# Arguments
dtype: Expected datatype of the input.
shape: Shape tuple, expected shape of the input
(may include None for unchecked axes).
ndim: Integer, expected rank of the input.
max_ndim: Integer, maximum rank of the input.
min_ndim: Integer, minimum rank of the input.
axes: Dictionary mapping integer axes to
a specific dimension value.
"""
def __init__(self, dtype=None,
shape=None,
ndim=None,
max_ndim=None,
min_ndim=None,
axes=None):
self.dtype = dtype
self.shape = shape
if shape is not None:
self.ndim = len(shape)
else:
self.ndim = ndim
self.max_ndim = max_ndim
self.min_ndim = min_ndim
self.axes = axes or {}
第一句话已经说的很清楚了。需要注意的是,注释中说每个Layer都应该暴露一个input_spec属性。
说明这个属性是给需要调用该Layer的人看的,告诉他,该Layer需要什么样的输入。
这种设计方式值得学习,比如我们自己的接口设计时,是不是也可以设计这么一个属性,帮助使用者了解,我们需要 什么样的输入。
还有一点,None在这里表示的意思是unchecked,就是需要在运行时确定,用户可以自己确定。
其他需要关注的属性
- name,没啥说的,就是层的名字,建议起个自己好记的名字
- stateful 是否有不是weight的状态,rnn就是典型的应用。TODO:读一下lstm的源码。
- supports_masking 是否支持mask,keras中支持Layer不多,不是逼不得已,尽量不用
- trainable weight是否需要做梯度更新
- uses_learning_phase 是否使用了K.in_training_phase()和K.in_test_phase()具体看源码咋用的
- trainable_weights,non_trainable_weights,weights,就是名字的含义
- dtype 数据类型
至此,将Layer的属性已经基本搞清楚了。其他行的代码是对方法的简单说明,等进一步读源码时再具体理解。此时,心中有个疑问,Node和Layer究竟是什么关系?下篇文章,就先解决这个问题。再继续。
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)