MAML
1.论文地址和代码
https://arxiv.org/abs/1703.03400
https://github.com/dragen1860/MAML-Pytorch
2.基本概念
Meta Learning是模仿人类的学习过程,先学习一个先验知识,并且利用这些知识,在新的问题上学的更快更好。通过学习到的模型,来学习新的模型。用任务task生成模型model。
2.1 举例
Meta learning是找一个
F
F
F,通过一堆任务训练,得到参数
θ
\theta
θ,然后对新的分类任务时,让
θ
\theta
θ自己调整适合新任务的、分类模型的、最优
θ
∗
\theta^*
θ∗,这个调整只需要几步。
所以,怎么找这个参数 θ \theta θ就成了我们的目标。
2.1.1 元学习的目标
元学习的目标是在接触到新任务或者迁移到新环境中时,可以根据之前的经验和少量的样本快速学习去应对。
2.1.2 元学习有三种常见的实现方法
- 1)学习有效的距离度量方式(基于度量的方法);
- 2)使用带有显式或隐式记忆储存的(循环)神经网络(基于模型的方法);
- 3)训练以快速学习为目标的模型(基于优化的方法)
三种方式分别举例:
- 在没有猫的训练集上训练出来一个图片分类器,这个分类器需要在看过少数几张猫的照片后分辨出测试集的照片中有没有猫。
- 训练一个玩游戏的AI,这个AI需要快速学会如何玩一个从来没玩过的游戏。
- 一个仅在平地上训练过的机器人,需要在山坡上完成给定的任务。
3.MAML
3.1 MAML就是用于找参数 θ \theta θ的方法
它是基于初始化的方法。
3.2 MAML的基本思路
1.假设有一个适用于所有task的初始化参数
θ
\theta
θ。
2.
θ
\theta
θ经过每一个
t
a
s
k
i
,
i
=
1
,
2
,
⋯
,
n
task_i,i=1,2,\cdots,n
taski,i=1,2,⋯,n都会得到对应的专属参数
θ
1
,
θ
2
,
⋯
,
θ
n
\theta_1,\theta_2,\cdots,\theta_n
θ1,θ2,⋯,θn。
3.如果能找到一个很好的初始化参数
θ
\theta
θ,对于新任务,它只需要经过很少的步数就能获得新任务的
θ
i
\theta_i
θi,并且在新任务上表现的很好。
评价初始 θ \theta θ的指标
主要看它被更新为 θ i \theta_i θi之后,在对应的任务 T i T_i Ti表现情况,对表现的评估用loss来评估。
4.每次输入一个batch的tasks,然后利用 θ \theta θ在这个batch中所有的tasks上更新,更新完毕后把这些Task上的测试损失值求和,得到一个大Loss(这个batch中所有任务的损失之和),然后用这个Loss来评价 θ \theta θ的好坏,然后通过优化方法找到最好的初始化参数 θ \theta θ,这样就找到一个很好的初始化参数 θ \theta θ,对于新任务,它只需要经过很少的步数就能获得新任务的 θ i \theta_i θi,并且在新任务上表现的很好。
3.3 基本术语
task任务:把已有的数据集切分转化为多个任务。
举例:以Mini-ImageNet数据集为例,如下图:
其中,每张图片剪裁成84*84的大小.
将训练集变成训练任务,这些训练任务就变成了N-way K-shot任务(或者是N-way K-shot Q-Query)。
N-way :从数据集上随机抽取N个类别
K-shot:每个类别下面抽取K个样本用于任务的训练。
N-way K-shot组成的数据集被称为Support set。
Q-Query:对N个类别每一个类别中再次抽取Q个样本,作为Query集,用于测试。
实际每个类别要抽取K+Q个样本,这样就得到了N-way K-shot任务。
下面这个例子,就是5-way 5-shot 1-Query任务:
3.4 从数据集中抽样的方法
举例,如对Mini-ImageNet做5-way 1-shot 15-Query
每个文件夹对应一个类别,每次抽取5个文件夹(相当于抽取了5个类)
Random.sample(folders, self.n_way)
然后每个文件夹抽取16张图片
labels_and_images = get_images(sampled_folders,
range(self.n_way),
nb_samples=self.n_img,
shuffle=False)
重复以上两步,200000次,就会得到200000个Tasks,作为训练任务集。
验证集和测试集也是同样的操作。
接下来:
假设batchsize为4,每一轮执行4个task,在meta-training时,按barch输入成tensor,4个任务,每个任务80张图片,图片大小为84x84x3,label为类别,5个类别,即得:
support_x : [4, 1x5, 84x84x3]
query_x : [4, 15x5, 84x84x3]
support_y : [4, 1x5, 5]
query_y : [4, 15x5, 5]
每次都拿1个batch的task来喂到模型里去训练,然后优化参数。
3.4 MAML的具体算法流程
1.首先,模型需要给定的task的分布distribution,就是之前的随机抽样的若干任务。
这里有一个问题要组合形成那么多的task,岂不是不同task之间会存在样本的重复?或者某些task的query set会成为其他task的support set?
MAML的目的,就在于fast adaptation,即通过对大量的task的学习,获得足够强的泛化能力,从而面对新的、从未见过的task时,通过fine-tuning就能快速适应,task之间,只要存在一定的差异就可以了。
每个task相当于普通深度学习模型训练的一条训练数据。
2.然后指定 α , β \alpha,\beta α,β两个超参学习率,分别用于Support-Query这是内循环和Meta这是外循环这两部分的梯度迭代的步长。
3.抽取出1个batch的任务,假设1个batch有5个任务,遍历5个任务,每个任务 T i T_i Ti都是由K个样本和标签label组成的。根据开源代码,作者直接用任务的Support set拿出来计算loss,也就是训练误差,即 L T i ( f θ ) \mathcal{L}_{T_i}(f_\theta) LTi(fθ):参数为 θ \theta θ时, T i T_i Ti中Support-set的loss. 即该任务上的训练误差。
4.然后梯度下降,第7步更新参数 θ \theta θ得 θ i ′ \theta_i' θi′,这里更新的次数不是1次,而是几次,代码中是5次,不能更新的多,要不然太慢了,如果迭代多了,容易过拟合(有实验证明,即使过拟合,也不影响结果),这里的参数更新,不影响 θ \theta θ的值,元模型的 θ \theta θ,在这里,会先复制一份 θ c o p y \theta_{copy} θcopy,然后在复制的参数上做反向传播更新参数,得到第一次参数更新的结果 θ i ′ \theta_i' θi′。
5.第8步元更新,这里的损失函数是用Query set的loss, L T i ( f θ i ′ ) \mathcal{L}_{T_i}(f_{\theta_i'}) LTi(fθi′):参数更新成 θ i ′ \theta_i' θi′后, T i T_i Ti中Query set的loss即该任务上的测试误差,通过测试损失对初始化参数 θ \theta θ做梯度下降优化,第二次参数更新时,更新的是元模型的 θ \theta θ,而不是更新的 θ i ′ \theta_i' θi′。
6.进入下一个batch,重复以上。
tips:
第四点中的copy的参数是每一个task都会copy一份,如10个task会copy10个临时参数
θ
c
o
p
y
\theta_{copy}
θcopy,在10个临时模型上,在各自的task上独立做梯度下降参数更新,然后整合起来,也就是更新元模型的
θ
\theta
θ。
为什么这样做?
因为每个task都会更新一次参数,如果用元模型的参数,会导致使用了上一个task更新过的参数
总结:第一次梯度更新,不作用于元模型,第二次梯度更新用于元模型。
其中:
∇
θ
L
(
θ
)
=
∇
θ
∑
T
i
∼
p
(
T
)
l
T
i
(
f
θ
i
′
)
=
∑
T
i
∼
p
(
T
)
∇
θ
l
T
i
(
f
θ
i
′
)
\nabla_\theta L(\theta)=\nabla_\theta \sum_{T_i\sim p(T)}l_{T_i}(f_{\theta_i'})=\sum_{T_i\sim p(T)} \nabla_\theta l_{T_i}(f_{\theta_i'})
∇θL(θ)=∇θ∑Ti∼p(T)lTi(fθi′)=∑Ti∼p(T)∇θlTi(fθi′)
整个式子是对
θ
\theta
θ求的梯度,而参数却是
θ
i
′
\theta_i'
θi′,下面就是推导过程:
把meta的
θ
\theta
θ记为
ϕ
\phi
ϕ,得:
∇
ϕ
L
(
ϕ
)
=
∑
T
i
∼
p
(
T
)
∇
ϕ
l
T
i
(
f
θ
i
′
)
\nabla_\phi L(\phi)=\sum_{T_i\sim p(T)} \nabla_\phi l_{T_i}(f_{\theta_i'})
∇ϕL(ϕ)=∑Ti∼p(T)∇ϕlTi(fθi′)
对于某一个 T i T_i Ti的 θ i ′ \theta_i' θi′,这里先去掉下标
可得: ∇ ϕ l ( f θ ′ ) = [ ∂ l ( f θ ′ ) ∂ ϕ ( 1 ) ∂ l ( f θ ′ ) ∂ ϕ ( 2 ) ⋯ ∂ l ( f θ ′ ) ∂ ϕ ( k ) ⋯ ∂ l ( f θ ′ ) ∂ ϕ ( 1 ) ] \nabla_\phi l(f_{\theta'})=\begin{bmatrix}\frac{\partial l(f_{\theta'})}{\partial \phi_{(1)}}\\ \frac{\partial l(f_{\theta'})}{\partial \phi_{(2)}}\\\cdots \\\frac{\partial l(f_{\theta'})}{\partial \phi_{(k)}}\\\cdots \\ \frac{\partial l(f_{\theta'})}{\partial \phi_{(1)}}\end{bmatrix} ∇ϕl(fθ′)= ∂ϕ(1)∂l(fθ′)∂ϕ(2)∂l(fθ′)⋯∂ϕ(k)∂l(fθ′)⋯∂ϕ(1)∂l(fθ′)
,其中 ϕ ( k ) \phi_{(k)} ϕ(k)是指 ϕ \phi ϕ中第k个参数。
回归算法流程,
θ
i
′
\theta_i'
θi′是由
ϕ
\phi
ϕ通过任务,更新,优化得到的。
所以对于第j维的参数,也是对应更新的。
作者认为在实际运算中,求这个二阶导非常麻烦,于是把二阶导去掉了,所以只剩下0和1了。
总结:所以,先使用1个batch中task的Support set来优化参数
θ
\theta
θ更新得到
θ
′
\theta'
θ′,然后基于这一次优化的
θ
‘
\theta‘
θ‘,用Query set计算task的loss,再对着干loss求梯度(对的
θ
′
\theta'
θ′),使用该梯度更新元网络的参数
θ
\theta
θ得到新的
θ
\theta
θ;然后再使用1个batch的task的Support set来优化新的参数
θ
\theta
θ,得到。。。
然后通过以上得到的meta网络的参数
θ
\theta
θ,在测试任务中使用测试任务的Support set对meta网络的
θ
\theta
θ进行fine tuning,最终再使用测试任务中的Query Set评估效果。
2.7 MAML二分类过程
MAML针对n张84选4的RGB图片进行二分类过程:
前四层都是卷积层,(Conv,bias)+relu+maxpool的搭配
返回第五层和第四层输出结果的內积+偏置,即分类预测标签(1*800)*(800*2) =(1*2)
2.8 MAML的效果
MAML在MiniImagenet上的表现: