HiSD阅读
date
May 9, 2022
Last edited time
May 10, 2022 02:45 PM
status
Published
slug
HiSD阅读
tags
DL
GAN
summary
type
Post
Field
Plat
IntroduceMethod主要流程损失函数Feature-based Local TranslatorTag-irrelevant Conditional DiscriminatorExperiments代码分析L_g lossL_d loss
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F52f274d8-5e17-4c1b-83cf-c3f5326bbf4f%2FUntitled.png?table=block&id=7fff682f-5de0-4116-bd86-623dfecaaf9d&cache=v2)
Introduce
由于标签的独立性和排他性尚未得到研究,导致 GAN 生成不可控的结果。例如,对于人脸属性篡改任务,我们想要给人脸加上刘海,可是却改变了发色或是背景,再例如,我们想要给人脸加上眼睛,结果竟然性别和年龄也改变了。
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F0656ac18-5b1f-4729-9096-f2c7588dab4c%2FUntitled.png?table=block&id=1072a9a4-3cec-4a1c-a845-252e030d4b60&cache=v2)
本文提出了 HiSD 来解决这个问题。具体来说,我们将标签组织成一个分层的树状结构,在这个结构中,独立的标签、独占的属性和分解的样式从上到下被分配。
- 目前图片生成的方式
- 直接将目标标签(域) + 样式代码 注入生成器来学习风格。
- 将与目标标签(域)映射的样式代码 注入生成器来学习混合样式。
这里提到的两个样式代码不是一个意思,第一个样式代码 其实是生成的 latent code (或者随机噪声 ),而第二个样式代码是真正提供给生成器的 。
- 限制
由于目标标签(域)中的所有图片,共享一个标签,因此在图像转换的时候经常会涉及到不必要的操作,例如更改面部身份和影响背景。此外,这两种方式无法独立的学习刘海、眼镜和头发颜色等正交属性的变换。
- 解决方式 - 层次风格解耦
本文提出了一个新的框架,称为层次风格解耦。由于大多数标签之间的独立性和排他性。例如,在 CelebA 中,原始二元标签“With Bangs”和“With Glasses”是独立的,而“Blond Hair”和“Black Hair”是专有的。
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2Fbb75027a-c4e3-4b03-88b4-c65c25a5469e%2FUntitled.png?table=block&id=e55cc1de-c86c-4e64-9a27-ccb7ce3292df&cache=v2)
因此,multi-label (多个域)问题可以分为两个子任务:multi-attribute任务,multi-tag任务,如上图。HiSD网络实际上训练了很多个网络模型,其中有些地方权重共享,来控制分别多个Tag的Attribute变化。
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F6e7fbbd8-a15b-47f2-96f9-126fdd9d24c2%2FUntitled.png?table=block&id=5eb676f5-3969-4c84-9504-d377ce681bbe&cache=v2)
Method
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F73c80222-a3c2-46e9-8a7d-e6b4bbae0ea3%2FUntitled.png?table=block&id=dc624fac-24b5-46e6-b3ae-4484e5cf41ee&cache=v2)
指具有 Tag 为 , attribute 为 的图片,对于 tag 为 ,attribute 为 的标签相关风格为 。注意,一张图片可能具有多个标签相关风格。
主要流程
- 对于给定图片 ,使用编码器 提取出对应特征 。
- 给定潜码 ,使用映射模块 得到指定 Tag 与 Attribute 的标签相关风格代码 。
- 标签相关风格代码同样也能通过风格提取器 获得,对于指定的Tag ,,提取出输入图片 在Tag 的标签相关风格代码 。
- 使用 translator 模块 ,对提取出的特征 进行修改,即 。通过生成器 来获得最终的图像 。
这里实际上是训练了多个网络 ,然后通过Tag来选择需要使用的网络 。
- 使用判别器 来确定给定标签和属性的图像是否是真实。
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F2d5b7724-0ab7-4841-8663-e6349c29d1d3%2FUntitled.png?table=block&id=dfb1725e-655c-4c2e-92ee-b21c3e2bc5ac&cache=v2)
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2Fe01060c9-459c-42b6-895b-f6dda955c973%2FUntitled.png?table=block&id=21473724-ae52-4844-9dd3-86804c121724&cache=v2)
损失函数
- 重建损失
- 首先是最简单的重建:
- 使用本图像风格的重建
- 循环一致的重建
因此,重建损失为以上的加权和:
- GAN损失
- 风格损失
变换图像的提取出的样式代码应该等于生成的样式代码,因此引入风格损失。
总体损失如下:
Feature-based Local Translator
为了抑制区域上的过度篡改,引入了已经被广泛使用的无监督掩膜思想,唯一不同的是我们的掩膜是作用在特征图上的,而不是图像本身上,也因此加入了Channel-wise的注意力。
translator 模块 的公式如下,变换后的特征与原来特征的维度相同。
为提取的图像特征。 为 sigmoid 函数, 为掩码,则 为注意力掩码,
Tag-irrelevant Conditional Discriminator
由于数据集中本身对于各个属性就是不解耦的(戴眼镜的有83%的男性,而不戴的只有36%),性别和年龄在极度不平衡的数据集的对抗过程中,仍然被不可避免的篡改了。因此,这里使用了Tag无关条件鉴别器,来缓解很多对抗过程中数据集本身不平衡的问题。
这里将GAN损失修改为如下形式:
其中 为图片 与 Tag 无关的属性,如年龄、性别。
因此,鉴别器不仅仅能够看到图片,还能看到类别不平衡的无关属性,从而来让翻译前后保持这些不平衡标签不变。意味着鉴别器可以进一步区分什么样的图像才符合我们的目标(例如不是男的就更像戴眼镜,而是眼镜这个特征本身让图像更像戴眼镜),也就可以促使生成器的解耦了。
Experiments
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2Fe2ef18cb-efbe-4511-a2f8-944a920e1b10%2FUntitled.png?table=block&id=637f577d-f058-413a-93d8-86108d1e2f68&cache=v2)
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F2349b94e-26ba-4aee-be8e-6f415a32f97e%2FUntitled.png?table=block&id=f376a575-91e5-4b31-b347-6c933a4827c2&cache=v2)
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2Fcbdd917e-6a65-4c26-88b0-468d13f95a1e%2FUntitled.png?table=block&id=91e898a1-dabf-4921-8819-6393ced20cb9&cache=v2)
代码分析
L_g loss
loss 与论文中使用的不同,应用了论文 的方法,即把 引入判别器,帮助判别器学习到图片真假与风格无关。
上面的
calc_gen_loss_real
部分就是用来训练 ,来最小化判别器 的输出,让判别器接收到真样本 的时候,输出为假样本。结合 部分:其中也有
calc_dis_loss_real
训练 在 的干扰下进行正确的分类,让判别器 的输出与风格无关。这样能够缓解样本不平衡带来的问题。比如样本中带胡子的很少,那么只输入照片的判别器可能直接将带胡子的判断为生成的假样本。而使用了 ALI 则让判别器学习到是否为真样本,与带不带胡子无关。
L_d loss
此外,判别器的loss改为 Hinge loss,让训练更稳定。
Hinge loss 在我的GAN入门内有讲过,当然,ALI也有。