数据集蒸馏 by Matching Training Trajectories
date
Nov 9, 2022
Last edited time
Mar 27, 2023 08:42 AM
status
Published
slug
数据集蒸馏_by_Matching_Training_Trajectories
tags
DL
DataCenric
summary
type
Post
origin
Field
Plat
IntroductionContributionsApproachExpert TrajectoriesLong-Range Parameter Matching ExperimentMemory ConstraintsExperimentsReferences
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F991c001e-18de-47e9-81d0-b859e74d9d1b%2Fd4b64e81-0725-4002-83df-e946b5884b07.jpeg?table=block&id=52f1e002-6c90-4935-947c-67d002503406&cache=v2)
Introduction
数据集蒸馏旨在构造一个合成数据集,其数据规模远小于原始数据集,但却能使在其上面训练的模型达到和原始数据集相似的精度。数据集蒸馏的核心思想如下所示:
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2Ff5346c70-0cef-40b0-869b-43fc70c986ab%2FUntitled.png?table=block&id=49137e03-f46d-4e54-8983-4d596eb99648&cache=v2)
合成数据集可视化:
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2Fd090dcd8-1a6d-4d50-a575-ca2cda7c14cd%2FUntitled.png?table=block&id=e9c534f9-5f20-4b7a-b4e6-8f74697c5efd&cache=v2)
现有的数据集蒸馏方法一些考虑使用端到端训练,但这通常需要大量计算和内存,并且会受到不精确的松弛或执行多次迭代导致训练不稳定的影响。 为了降低优化难度,另一些方法侧重于短程行为,聚焦于使在蒸馏数据上的单步训练匹配在真实数据上的。 但是,由于蒸馏数据会被多次迭代,导致在验证过程中错误可能会被累积。
Contributions
基于此,作者直接模仿在真实数据集上训练模型的长程训练动态。大量实验表明,所提方法优于现有的数据集蒸馏方法以及在标准数据集上进行核心子集选择的方法。
Approach
首先定义文章所用符号: 合成数据集: 真实训练集:。
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F41025139-fddc-4177-ac01-92ab1ff87fa7%2FUntitled.png?table=block&id=f6b63129-c0d7-4514-b556-7217847f5965&cache=v2)
上图阐述了本文数据集蒸馏的核心思想。
Expert Trajectories
本文核心在于引入了 expert trajectories 来指导合成数据集的蒸馏。本文通过训练大量的模型,并将每个模型每个 epoch 的模型参数保存下来,每个模型不同 epoch 组成一条 expert trajectory。作者称这些参数序列为 “expert trajectory”,因为它们代表了数据集蒸馏任务的理论上限。从相同的初始化模型参数开始,作者的目的是蒸馏数据集使其有与真实数据集上相似的轨迹,从而最终得到一个相似的模型。由于这些 expert trajectories 是预先计算好的,因此可以快速的进行蒸馏操作。
Long-Range Parameter Matching Experiment
本文所提数据集蒸馏方式从 expert trajectories 中学习学习参数,对于每一步, 先从 expert trajectories 中采样一条作为初始化学生参数 , 并且约束 使得 expert trajectory 的模型参数不会变太多。接着用合成数据集对学生参数进行 N 次梯度下降更新:
其中 𝒜 是可微分增强操作,α 是个可学习的学习率。然后计算更新后的学生参数和 expert trajectory 的模型参数的匹配损失:
其中 为初始化学生参数更新 M 次的参数。最后根据匹配损失 ℒ 更新 𝒟𝓈𝓎𝓃 和 α。详细算法如下表所示:
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2Fc7bd80ca-a6b7-45da-a168-e05eb9c27ec4%2FUntitled.png?table=block&id=d4ffd546-0798-4aee-9ef3-61aa2f3f6016&cache=v2)
Memory Constraints
回顾
可以发现由于 𝒟𝓈𝓎𝓃 每个类图片数量过多,一次性输入会存在内存占用过高问题,为了解决这个问题,本文将 𝒟𝓈𝓎𝓃 划分为多个 batch。此时上式变为:
Experiments
本文的实验在 CIFAR-10,CIFAR-100(32 × 32),Tiny ImageNet(64 × 64)和 ImageNet(128× 128)上进行。
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2Fbf17b3e3-8566-4c63-8075-b950bf31266b%2FUntitled.png?table=block&id=980c0839-dcc4-4245-a056-b42dd6c774ee&cache=v2)
上图展示了本文所提方法与核心子集选择方法和之前的数据集蒸馏的 baseline 比较。可以看出在数据集压缩率相同的条件下,本文所提方法性能明显优于其他方法。下图是在 CIFAR-10 上蒸馏得到的图像,上边是一类一张图像,下边是一类各十张:
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F8c17c286-3f7e-4b9d-9473-ce3d7bf8282c%2FUntitled.png?table=block&id=ea6f1577-d009-41bb-b4c3-c5d32c1d30eb&cache=v2)
接着作者又与一种最近的数据蒸馏方式 KIP[1] 比较,可以发现在相同模型宽度的情况下所提方法明显优于 KIP,甚至部分优于 KIP 使用更宽的模型。
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F3da9358f-7d50-4188-9adb-7ee0e9400dfd%2FUntitled.png?table=block&id=6cb29b90-e927-488d-8396-2fbd04f826bc&cache=v2)
由于所提方法是在一个特定模型上训练的,因此作者在不同模型结构上进行验证,可以发现也都优于 baseline,这说明了合成的数据集不是对训练模型 overfitting 的。
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F8032748d-6ad2-45d4-a66c-0e98fa51139e%2FUntitled.png?table=block&id=34f876ec-7f89-437b-9303-1c412d3244b5&cache=v2)
接下来作者探索了 long-range 匹配和 short-range 匹配的效果。从下图的左边可以看出 long-range 的性能明显优于 short-range(较小的 M 和 N 表示 short-range 行为)。右边则展示了 long-range 行为更好的估逼近了真实数据的训练(距离目标参数空间越近)。
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F8d273c48-1ed9-4dc5-8cff-cb2088d19cc5%2FUntitled.png?table=block&id=7cf87575-662a-4f20-8192-f47f430e6151&cache=v2)
在 64 × 64 的 Tiny ImageNet 上可视化效果(每类一张),可以看出尽管分辨率更高,所提方法仍然能够产生高保真图像,这十个类分别是: 第一行:African Elephant, Jellyfish, Kimono, Lamp-shade, Monarch. 第二行: Organ, Pizza, Pretzel, Teapot, Teddy.
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2Faf984a6a-3d9b-41c5-8361-bfbc9f3e1f8c%2FUntitled.png?table=block&id=1afa7eac-1289-470e-b9f0-f9484054b968&cache=v2)
接着作者又在 128 × 128 分辨率的 ImageNet 子集上进行了实验,下表展示了合成数据集所达到的精度。
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2F557f13f7-bf08-4a8b-a1e3-5849dfc66484%2FUntitled.png?table=block&id=d50581f3-4d84-4527-b4ae-1d0d18ea8fbd&cache=v2)
合成的效果如下图所示,对于所有类都有的任务类似的结构但独特的纹理(ImageSquawk)和颜色(ImageYellow)。
![notion image](https://www.notion.so/image/https%3A%2F%2Fs3-us-west-2.amazonaws.com%2Fsecure.notion-static.com%2Fcf53c9ce-d214-4537-8ec0-d04e030c4c1b%2FUntitled.png?table=block&id=b339cf27-0d2e-4e8c-a795-336d7c996f66&cache=v2)