Neural Network Example: Handwritten Digit Recognition
This example is adapted from exercise ex3 of the Stanford machine learning course.
This example classifies a dataset of handwritten digits (0-9). The end result is the same as in the previous example; only the implementation differs.
In the previous part of this exercise, you implemented multi-class logistic regression to recognize handwritten digits. However, logistic regression cannot form more complex hypotheses as it is only a linear classifier. You could add more features (such as polynomial features) to logistic regression, but that can be very expensive to train.
1. Model Representation
The handwritten images in this example are $20 \times 20$ pixels, so each image has $20\times20=400$ features. The input layer therefore has 400 activation units, and adding the extra bias unit gives 401 in total. The model of the whole neural network is shown below:
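To make the layer sizes concrete, here is a minimal Octave sketch of how one image becomes the 401 input-layer units. The image `img` is random placeholder data rather than a sample from the dataset, and the variable names are illustrative assumptions.

```octave
% Hypothetical stand-in for one 20x20 grayscale image (random values, not real data)
img = rand(20, 20);

% Unroll the image into a column vector of 400 features
x = img(:);          % 400 by 1

% Prepend the bias unit to form the input-layer activations
a1 = [1; x];         % 401 by 1

printf('features: %d, input units including bias: %d\n', numel(x), numel(a1));
```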
2. Predict
In this example, the weights of the first and second layers, i.e. the parameters $\Theta_1$ ($25\times401$) and $\Theta_2$ ($10\times26$), are already provided. We only need to carry out the computation by forward propagation.
The training data $X$ is a $5000 \times 400$ matrix. To make the computation easier to follow, first take an arbitrary row $x_{(1\times400)}$ as an example. From the mathematical definition of a neural network (point 3), we have:
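\begin{align*}
X & = [\mathrm{ones}(m, 1)\ X]; \\
a_1 & = x'; \\
z_2 & = \Theta_1 a_1; \\
a_2 & = \mathrm{sigmoid}(z_2); \\
a_2 & = [1; a_2]; \quad \text{add the bias unit for } a_2 \text{; } \color{red}{\text{note: this is neither } a(0,0)=1 \text{ nor } a(1,1)=0} \\
z_3 & = \Theta_2 a_2; \\
a_3 & = \mathrm{sigmoid}(z_3);
\end{align*}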
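At this point a single image has passed through the three-layer network, which yields the ten output-layer probabilities for that image. Choosing the output with the largest probability tells us which handwritten digit the image corresponds to.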
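The single-example forward pass can be sketched in Octave as follows. The source uses `sigmoid` without listing its definition, so the anonymous function here is an assumption (the standard logistic function). The weights and the input are random placeholders with the shapes given in the text, not the trained parameters.

```octave
% Element-wise logistic function (assumed definition of sigmoid)
sigmoid = @(z) 1 ./ (1 + exp(-z));

% Placeholder weights with the shapes from the text (random, NOT the trained parameters)
Theta1 = 0.1 * randn(25, 401);
Theta2 = 0.1 * randn(10, 26);

% One placeholder example as a row vector of 400 features
x = rand(1, 400);

a1 = [1, x]';             % add the bias unit, 401 by 1
z2 = Theta1 * a1;         % 25 by 1
a2 = [1; sigmoid(z2)];    % add the bias unit, 26 by 1
z3 = Theta2 * a2;         % 10 by 1
a3 = sigmoid(z3);         % 10 by 1, every entry lies between 0 and 1

[prob, label] = max(a3);  % label is the index (1..10) of the largest output
```

With random weights the predicted `label` is meaningless; the sketch only shows that the dimensions line up at each step.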
The following illustrates the computation more concretely:
As in the earlier one-vs-all example, $g(z_i^3)$ is the probability that the handwritten image corresponds to digit $i$ (with the digit 0 mapped to label 10).
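% Loop over all m = 5000 examples to predict the digit for every image.
% X is assumed to already contain the bias column: X = [ones(m, 1) X]
p = zeros(m, 1);          % preallocate the predictions
for i = 1:m
  a1 = X(i,:)';           % 401 by 1
  z2 = Theta1 * a1;       % 25 by 401 * 401 by 1
  a2 = sigmoid(z2);       % 25 by 1
  a2 = [1; a2];           % add the bias unit, column vector, 26 by 1
  z3 = Theta2 * a2;       % 10 by 26 * 26 by 1
  a3 = sigmoid(z3);       % 10 by 1
  [temp, p(i)] = max(a3); % temp: largest output, p(i): its index
end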
Here, in [temp, p(i)] = max(a3), temp holds the largest probability value and p(i) holds the index of that value, which is the predicted digit.
For example, if $a^3$ takes the value
$$a^3=[0.21,0.11,0.04,0.51,0.34,0.66,0.71,0.88,0.17,0.32]^T,$$
then $temp = 0.88$ and $p(i)=8$.
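The same example can be checked directly in Octave. The final line is an added illustration, not part of the source code: since the digit 0 is stored as label 10, `mod(p, 10)` is one way to map a label back to the digit it stands for.

```octave
% Output-layer values from the example above
a3 = [0.21; 0.11; 0.04; 0.51; 0.34; 0.66; 0.71; 0.88; 0.17; 0.32];

% With two outputs, max returns the largest value and its index
[temp, p] = max(a3);   % temp = 0.88, p = 8

% Map the label back to a digit: label 10 stands for the digit 0
digit = mod(p, 10);    % 8 here; a label of 10 would give 0
```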
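test = X(3454,:);
[temp, pp] = max(predict(Theta1, Theta2, test))
y(3454,1) % compare against the known label
%% Output
temp =
6
pp =
1 % predict returns a single label here, so temp = 6 is the predicted digit and pp = 1 is only its index in that one-element result, not a probability
ans =
6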
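% Vectorized form: all examples are computed at once, without a loop
a1 = X';                 % 401 by m (X already includes the bias column)
z2 = Theta1 * a1;        % 25 by m
a2 = sigmoid(z2);        % 25 by m
a2 = [ones(1,m); a2];    % add the bias row, 26 by m
z3 = Theta2 * a2;        % 10 by m
a3 = sigmoid(z3);        % 10 by m
[temp, p] = max(a3);     % column-wise maximum, 1 by m
p = p(:);                % predictions as an m by 1 column vector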
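As a sanity check, the loop and vectorized versions can be run on the same inputs and compared. This is a minimal sketch under stated assumptions: the data and weights are small random placeholders (6 examples, not the course dataset), `sigmoid` is defined as the standard logistic function, and the names `p_loop` and `p_vec` are introduced only for this illustration.

```octave
% Element-wise logistic function (assumed definition of sigmoid)
sigmoid = @(z) 1 ./ (1 + exp(-z));

% Small placeholder problem (random data and weights, not the course dataset)
m = 6;
X = [ones(m, 1), rand(m, 400)];   % bias column already added, m by 401
Theta1 = 0.1 * randn(25, 401);
Theta2 = 0.1 * randn(10, 26);

% Loop version: one example at a time
p_loop = zeros(m, 1);
for i = 1:m
  a1 = X(i, :)';                   % 401 by 1
  a2 = [1; sigmoid(Theta1 * a1)];  % 26 by 1
  a3 = sigmoid(Theta2 * a2);       % 10 by 1
  [~, p_loop(i)] = max(a3);
end

% Vectorized version: all examples at once, one column per example
A2 = [ones(1, m); sigmoid(Theta1 * X')];  % 26 by m
A3 = sigmoid(Theta2 * A2);                % 10 by m
[~, p_vec] = max(A3, [], 1);              % column-wise maximum, 1 by m
p_vec = p_vec(:);                         % m by 1

% Displays 1 when both versions produce the same predictions
disp(isequal(p_loop, p_vec));
```

Passing the dimension explicitly, as in `max(A3, [], 1)`, makes the column-wise intent explicit instead of relying on the default behaviour of `max`.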