谷歌的论文,基于seq2seq+VAE编码并生成手绘序列
https://arxiv.org/pdf/1704.03477.pdf
本文主要是论文的概述翻译,记录
1.Introduction
- 生成模型的发展:GANs(Generative Adversarial Networks)、VI(Variational Inference)、AR(Autoregressive)
- 当前多用于处理图像像素数据(pixel images),而人的理解是矢量序列的,本文即提出用于矢量图像的生成模型(手绘草图)
- 文本贡献:适用于线序列的条件/非条件生成模型框架;提出的sketch-RNN可以生成矢量格式的有意义的图像;开发了一种可以使训练更鲁棒的方法;将矢量图映射到了潜在空间;最后讨论了本文可能的应用领域
2.Related Work
- 对于模仿绘画,有通过既定的文件执行绘画的机器人、和基于强化学习的方法,并非生成
- 神经网络用于生成的多是栅格图像;早期对于线的有HMM模型方法;最近有基于RNN的Mixture Densify Network及其改进的方法用于生成连续数据点、和汉字
- 最近有用Sequence-to-Sequence 模型结合VAE(变分自编码器)的用于英语语言编码到潜在空间的研究
- 最后提了一些现有的公开数据集
3.方法
3.1 数据集
- 取自QuickDraw应用,有20s以内绘制出的草图,上百类,每类有70k的训练样本,及2.5k验证,2.5k测试
- 数据序列组织为[dx,dy,p1,p2,p3],分别表示x、y方向的变化,p1、p2、p3表示继续绘制、结束子序列、结束绘图三个状态
3.2 Sketch-RNN
- 基本结构为Sequence-to-Sequence 变分自编码器
- 其中编码器用双向RNN,以草图作为输入,潜在空间向量作为输出(用常规的VAE,先出 μ \mu μ和 σ \sigma σ,再按正态分布算 z z z)。
h → = e n c o d e → ( S ) , h ← = e n c o d e ← ( S r e v e r s e ) , h = [ h → ; h ← ] h_\rightarrow=encode_\rightarrow(S), h_\leftarrow=encode_\leftarrow(S_{reverse}), h=[h_\rightarrow;h_\leftarrow] h→=encode→(S),h←=encode←(Sreverse),h=[h→;h←]
μ = W μ h + b μ , σ ^ = W σ h + b σ , σ = e x p ( σ ^ 2 ) , z = μ + σ ⊙ N ( 0 , 1 ) \mu=W_{\mu}h+b_{\mu}, \hat{\sigma}=W_{\sigma}h+b_{\sigma}, \sigma=exp(\frac{\hat{\sigma}}{2}), z=\mu+\sigma \odot \mathcal{N}(0,1) μ=Wμh+bμ,σ^=Wσh+bσ,σ=exp(2σ^),z=μ+σ⊙N(0,1)
-
解码器用自回归RNN,以序列后一点作为当前点的输出;由于之前是双向RNN编码,所以z先过一个tanh得到解码器的初始状态
[ h 0 ; c 0 ] = t a n h ( W z z + b z ) [h_0;c_0]=tanh(W_zz+b_z) [h0;c0]=tanh(Wzz+bz) -
S 0 S_0 S0定义为(0,0,1,0,0)
-
(dx,dy)通过M元正态分布的高斯混合模型(GMM)计算概率,(q1,q2,q3)作为类别来计算,M也是一个类别分布,是GMM的混合权重
p ( △ x , △ y ) = ∑ j = 1 M N ( △ x , △ y ∣ μ x , j , μ y , j , σ x , j , σ y , j , ρ x y , j ) , w h e r e ∑ j = 1 M Π j = 1 p(\triangle x,\triangle y)=\sum_{j=1}^{M} \mathcal{N}(\triangle x,\triangle y | \mu_{x,j},\mu_{y,j},\sigma_{x,j},\sigma_{y,j},\rho_{xy,j}),where \ \sum_{j=1}^{M}\Pi_j=1 p(△x,△y)=