欢迎访问个人网络日志🌹🌹知行空间🌹🌹

论文: https://arxiv.org/abs/1805.06725

代码: https://github.com/samet-akcay/ganomaly

1.介绍

GANomaly是英国杜伦大学(Durham University,QS前100)Samet Akcay等发表在ACCV2018上的会议论文。

这篇文章是期望提出一种可以只在正常数据上训练,但却能识别异常图像的方法。其提出了由对抗训练框架组成的通用异常检测模型。本文作者在编码-解码-编码式的架构中使用了对抗自编码器,获取训练数据在图像和隐式向量中的分布。

这篇文章主要贡献:

  • 半监督异常检测,提出了基于编码-解码再编码架构的对抗式自编码器,获取训练数据在图像和隐式向量空间的分布,取得比其他基于GAN网络和自编码器异常检测方法更好的效果

  • 代码开源

2.GANomaly网络组成

2.1 GAN简介

生成对抗网络Generative Adversarial Networks (GAN)是蒙特利尔大学Université de MontréalIan Goodfellow 2014年发表的论文(作者Ian Goodfellow最近因不让居家办公了从Apple公司离职媒体正宣传的热闹),GAN属于无监督机器学习算法,原来GAN模型的目标是为了生成原始数据, 其结构包括在训练过程中对抗的两部分,生成器和对抗器,生成器负责生成与source data尽可能相似的数据,判别器负责尽可能的找出生成器生成的fake data。关于GAN的更多介绍可参考:(一)深度卷积对抗网络DCGAN

2.2 问题定义

异常检测问题的正式定义:

  • 数据集,训练数据集 D \mathcal{D} D为只能包含 M M M个正常类别的训练数据 D = { X 1 , X 2 , . . . , X M } \mathcal{D}=\{X_1,X_2,...,X_M\} D={X1,X2,...,XM},测试数据集 D ^ \hat{\mathcal{D}} D^为包含 N N N个正常和异常数据的集合, N N N通常比 M M M小很多。 D ^ = { ( X 1 ^ , y 1 ) , . . . , ( X 2 ^ , y 2 ) ) \hat{\mathcal{D}}=\{(\hat{X_1},y_1),...,(\hat{X_2}, y_2)) D^={(X1^,y1),...,(X2^,y2))
  • 目标模型学习数据 D \mathcal{D} D中的大多数公共性质,训练后在推理阶段检测测试数据集 D ^ \hat{\mathcal{D}} D^中的异常数据,模型 f f f学习正常数据的分布并最小化正常数据输入时模型的异常数据评分输出 A ( x ) \mathcal{A}(x) A(x),对于一个测试数据 x ^ \hat{x} x^,模型输出的异常评分 A ( x ^ ) \mathcal{A}(\hat{x}) A(x^)越高,表示输入时异常数据的可能性越大。设置阈值 ( ϕ ) (\phi) (ϕ),当 A ( x ^ ) > ϕ \mathcal{A}(\hat{x})\gt\phi A(x^)>ϕ时即认为是异常数据输入。

2.3网络结构

在这里插入图片描述

网络结构如上图,GANormaly主要包括3部分,以个自编码器,一个编码器和一个生成器。第一部分是一个蝴蝶结形的自编码网络作为模型的生成器,生成器学习输入数据表征,通过编码和解码网络重建输入数据。生成器中编码网络的输出 z z z也被成为生成器的瓶颈特征,并被认为其代表了包含输入数据最好表征的最小维度。第二部分是编码网络E将生成器的输出 x ^ \hat{x} x^压缩成低维的 z ^ \hat{z} z^, E E E和生成器 G G G中的编码网络 G E G_E GE有着相同的结构,但参数不同,因此 z z z z ^ 维 度 大 小 相 同 \hat{z}维度大小相同 z^。以往的方法都是通过瓶颈特征来最小化隐式向量,GANomaly通过增加一个编码网络,显式的学习最小化特征距离。第3部分是判别器网路D,其目标是判别输入 x x x和生成器的输出 x ^ \hat{x} x^real还是fake

3.模型训练

因训练时只使用了正常类别的数据,可以假设即使生成器的编码器可以将输入数据 X X X映射到隐向量 z z z,判别器却不能够判别异常。因此生成器的输出 X ^ \hat{X} X^将会是去掉异常特征后的图像数据,再通过编码器 E E E,将 X ^ \hat{X} X^映射到特征向量 z ^ \hat{z} z^上,此时生成器中的编码器 G E G_E GE输出的隐向量 z z z z ^ \hat{z} z^因一个包含图像异常,一个不包含,因此对于异常数据两者将有较大的差异,故可以识别出输入的异常数据。

GANomaly模型包含 3 3 3部分,其损失函数也包含3部分,每一部分损失分别对应网络的相应结构。

3.1对抗损失

Adversarial Loss
其使用的是特征对齐的损失函数,而非基于判别器输出, f f f是一个函数, 可以根据输入 x x x选择判别器的中间层来计算生成器对应层输出的 L 2 L_2 L2距离。

L a d v = E x ∼ p x ∣ ∣ f ( x ) − E x ∼ p x f ( G ( x ) ) ∣ ∣ 2 L_{adv} =\mathop{E}\limits_{x\sim px}||f(x) - \mathop{E}\limits_{x\sim px}f(G(x))||_2 Ladv=xpxEf(x)xpxEf(G(x))2

上式中 p x px px x x x的分布

3.2 上下文损失

Contextual Loss
为了学习输入数据中的上下文信息,增加衡量输出数据 x x x和生成器重建数据 x ^ \hat{x} x^误差的上下文损失 L c o n L_{con} Lcon

L c o n = E x ∼ p x ∣ ∣ x − G ( x ) ∣ ∣ 1 L_{con} = \mathop{E}\limits_{x\sim px}||x-G(x)||_1 Lcon=xpxExG(x)1

3.3 编码器损失

Encoder Loss,前面两个损失函数不仅可以让生成的数据尽量真实,还能保存数据的上下文信息。引入Encoder Loss是为了使 G E G_E GE输出的隐向量 z z z E E E生成的特征向量 z ^ \hat{z} z^的距离最小。

L e n c = E x ∼ p x ∣ ∣ G E ( x ) − E ( G ( X ) ) ∣ ∣ 2 L_{enc} = \mathop{E}\limits_{x\sim px}||G_E(x) - E(G(X))||_2 Lenc=xpxEGE(x)E(G(X))2

最终,GANomaly的损失函数为:

L = ω a d v L a d v + ω c o n L c o n + ω e n c L e n c L = \omega_{adv}L_{adv}+\omega_{con}L_{con} + \omega_{enc}L_{enc} L=ωadvLadv+ωconLcon+ωencLenc

4.模型测试

测试阶段模型使用 L e n c L_enc Lenc作为一个输入图像异常程度的评分。因为通过训练阶段最小化 L e n c L_{enc} Lenc,则对于异常图像 z z z z ^ \hat{z} z^差异会比较大。

A ( x ) = ∣ ∣ G E ( x ^ ) − E ( G ( x ^ ) ) ∣ ∣ 1 \mathcal{A(x)} = ||G_E(\hat{x})-E(G(\hat{x}))||_1 A(x)=GE(x^)E(G(x^))1

为了评估整体的异常性能,对测试数据集 D ^ \hat{\mathcal{D}} D^中的每个 x ^ \hat{x} x^计算其异常评分 A ( x ) \mathcal{A(x)} A(x),得测试数据集上每个数据对应的评分集合 S = { s i : A ( x i ^ ) , x i ^ ∈ D ^ } \mathcal{S}=\{s_i:\mathcal{A(\hat{x_i})}, \hat{x_i}\in\hat{\mathcal{D}}\} S={si:A(xi^),xi^D^},对 S \mathcal{S} S中的元素缩放到 [ 0 , 1 ] [0,1] [0,1]

s i ′ = s i − m i n ( S ) m a x ( S ) − m i n ( S ) s'_i = \frac{s_i-min(S)}{max(S)-min(S)} si=max(S)min(S)simin(S)

5.代码分析

GANomaly的代码基于pytorch实现,代码使用方法说明的很清晰。

5.1数据加载

GANomaly数据加载使用的torchvision提供的ImageFolder类,只需按

Custom Dataset
├── test
│   ├── 0.normal
│   │   └── normal_tst_img_0.png
│   │   └── normal_tst_img_1.png
│   │   ...
│   │   └── normal_tst_img_n.png
│   ├── 1.abnormal
│   │   └── abnormal_tst_img_0.png
│   │   └── abnormal_tst_img_1.png
│   │   ...
│   │   └── abnormal_tst_img_m.png
├── train
│   ├── 0.normal
│   │   └── normal_tst_img_0.png
│   │   └── normal_tst_img_1.png
│   │   ...
│   │   └── normal_tst_img_t.png

这样的格式将数据存放好即可。

"""
加载自定义数据的代码
"""
splits = ['train', 'test']
drop_last_batch = {'train': True, 'test': False}
shuffle = {'train': True, 'test': True}
transform = transforms.Compose([transforms.Resize(opt.isize),
                              transforms.CenterCrop(opt.isize),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])

dataset = {x: ImageFolder(os.path.join(opt.dataroot, x), transform) for x in splits}
dataloader = {x: torch.utils.data.DataLoader(dataset=dataset[x],
                                             batch_size=opt.batchsize,
                                             shuffle=shuffle[x],
                                             num_workers=int(opt.workers),
                                             drop_last=drop_last_batch[x],
                                             worker_init_fn=(None if opt.manualseed == -1
                                             else lambda x: np.random.seed(opt.manualseed)))
               for x in splits}
return dataloader

5.2 损失定义

""" Backpropagate through netG
"""
self.err_g_adv = self.l_adv(self.netd(self.input)[1], self.netd(self.fake)[1])
self.err_g_con = self.l_con(self.fake, self.input)
self.err_g_enc = self.l_enc(self.latent_o, self.latent_i)
self.err_g = self.err_g_adv * self.opt.w_adv + \
               self.err_g_con * self.opt.w_con + \
               self.err_g_enc * self.opt.w_enc
self.err_g.backward(retain_graph=True)

损失函数的使用如上述代码

6.测试效果

  • 数据量
NomalyAbnomaly
TES290747
TRAIN291
  • 测试结果

在这里插入图片描述

  • 可以看到准确率只有91%,效果在自定义的数据集上还不太好,不容易应用

注,上图分类评估指标可参考(二)sklearn.metrics.classification_report中的Micro/Macro/Weighted Average指标求得。

参考资料


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐