[mathscr{Lorain~y~w~la~Lora~blea.} newcommand{DS}[0]{displaystyle} % operators alias newcommand{opn}[1]{operatorname{#1}} newcommand{card}[0]{opn{card}} newcommand{lcm}[0]{opn{lcm}} newcommand{char}[0]{opn{char}} newcommand{Char}[0]{opn{Char}} newcommand{Min}[0]{opn{Min}} newcommand{rank}[0]{opn{rank}} newcommand{Hom}[0]{opn{Hom}} newcommand{End}[0]{opn{End}} newcommand{im}[0]{opn{im}} newcommand{tr}[0]{opn{tr}} newcommand{diag}[0]{opn{diag}} newcommand{coker}[0]{opn{coker}} newcommand{id}[0]{opn{id}} newcommand{sgn}[0]{opn{sgn}} newcommand{Res}[0]{opn{Res}} newcommand{Ad}[0]{opn{Ad}} newcommand{ord}[0]{opn{ord}} newcommand{Stab}[0]{opn{Stab}} newcommand{conjeq}[0]{sim_{u{conj}}} newcommand{cent}[0]{u{degree C}} newcommand{Sym}[0]{opn{Sym}} newcommand{wg}[0]{wedge} newcommand{Wg}[0]{bigwedge} % symbols alias newcommand{E}[0]{exist} newcommand{A}[0]{forall} newcommand{l}[0]{left} newcommand{r}[0]{right} newcommand{ox}[0]{otimes} newcommand{lra}[0]{leftrightarrow} newcommand{llra}[0]{longleftrightarrow} newcommand{iso}[1]{overset{sim}{#1}} newcommand{eps}[0]{varepsilon} newcommand{Ra}[0]{Rightarrow} newcommand{Eq}[0]{Leftrightarrow} newcommand{d}[0]{mathrm{d}} newcommand{e}[0]{mathrm{e}} newcommand{i}[0]{mathrm{i}} newcommand{j}[0]{mathrm{j}} newcommand{k}[0]{mathrm{k}} newcommand{Ex}[0]{mathbb{E}} newcommand{D}[0]{mathbb{D}} newcommand{oo}[0]{infty} newcommand{tto}[0]{rightrightarrows} newcommand{mmap}[0]{hookrightarrow} newcommand{emap}[0]{twoheadrightarrow} newcommand{actl}[0]{curvearrowright} newcommand{actr}[0]{curvearrowleft} newcommand{nsubg}[0]{triangleleft} newcommand{nsupg}[0]{triangleright} newcommand{lin}[0]{lim_{ntooo}} newcommand{linf}[0]{liminf_{ntooo}} newcommand{lsup}[0]{limsup_{ntooo}} newcommand{ser}[0]{sum_{n=1}^oo} newcommand{serz}[0]{sum_{n=0}^oo} newcommand{isoto}[0]{oversetsimto} newcommand{F}[0]{mathbb F} newcommand{x}[0]{times} newcommand{M}[0]{mathbf{M}} newcommand{T}[0]{intercal} % symbols with parameters newcommand{der}[1]{frac{d}{d #1}} newcommand{ul}[1]{underline{#1}} newcommand{ol}[1]{overline{#1}} newcommand{wt}[1]{widetilde{#1}} newcommand{br}[1]{l(#1r)} newcommand{bk}[1]{l[#1r]} newcommand{ev}[1]{l.#1r|} newcommand{abs}[1]{l|#1r|} newcommand{bs}[1]{boldsymbol{#1}} newcommand{env}[2]{begin{#1}#2end{#1}} % why not? newcommand{ALI}[1]{env{aligned}{#1}} newcommand{CAS}[1]{env{cases}{#1}} newcommand{pmat}[1]{env{pmatrix}{#1}} newcommand{dary}[2]{l|begin{array}{#1}#2end{array}r|} newcommand{pary}[2]{l(begin{array}{#1}#2end{array}r)} newcommand{pblk}[4]{l(begin{array}{c|c}{#1}&{#2}\hline{#3}&{#4}end{array}r)} newcommand{u}[1]{mathrm{#1}} newcommand{lix}[1]{lim_{xto #1}} newcommand{ops}[1]{#1cdots #1} newcommand{seq}[3]{{#1}_{#2}ops,{#1}_{#3}} newcommand{dedu}[2]{u{(#1)}Rau{(#2)}} % SPECIAL newcommand{dat}[1]{bs{mathrm{#1}}} % font for data point / data set ]
限于笔者水平, 本文或仅适合 AEVB 及 VAE 的基础学习. 如果希望更深入地了解 VAE, 推荐阅读参考资料 ([1]) 及相关文献.
对于数学水平要求, 本文仅假设读者掌握朴素概率论和入门的分析学.
(1) 数学基础
(1.1) KL 散度
The Kullback–Leibler divergence (also called relative entropy and I-divergence ), denoted (D_{u{KL}}(Pparallel Q)) , is a type of statistical distance : a measure of how much a model probability distribution (Q) is different from a true probability distribution (P) .
定量地, 离散条件下的 KL 散度定义为
[ALI{ D_{u{KL}}(Pparallel Q) &:= sum_{dat x}P(dat x)logfrac{P(dat x)}{Q(dat x)}\ &= -sum_{dat x}P(dat x)log Q(dat x)+sum_{dat x}P(dat x)log P(dat x). } ]
从信息熵 (也即 "相对熵" 这个名字) 的角度容易理解. 我们尝试用 (Q) 的最优编码方式 (即事件 (dat x) 使用 (-log Q(dat x)) 个 bit 的编码) 来编码 (P) , (D_{u{KL}}(Pparallel Q)) 给出的就是这种编码所用 bit 数与直接最优编码 (P) 本身的 bit 数 (即 (P) 本身的熵) 的差值, 这一差值反应了把编码从 (Q) 直接迁移到 (P) 的 "某种代价". 在这样的直观理解下, 如果二者是同分布的, 这一差值显然是 (0) ; 而对于一般的 (P) 和 (Q) , 也不难看出 (D_{u{KL}}(Pparallel Q)ge 0) .
(1.2) Evidence Lower BOund (ELBO)
这里我们着重研究形如 (D_{u{KL}}(Q(dat z)parallel P(dat zmiddat x))) 的 KL 散度, 其中 (dat x) 是某一特定事件, (P(dat zmiddat x)) 给出此时 (dat z) 的条件分布. 推导:
[ALI{ D_{u{KL}}(Qparallel P) &= sum_{dat z}Q(dat z)logfrac{Q(dat z)P(dat x)}{P(dat xdat z)} \ &= sum_{dat z}Q(dat z)br{logfrac{Q(dat z)}{P(dat xdat z)}+log P(dat x)}\ &= sum_{dat z}Q(dat z)(log Q(dat z)-log P(dat xdat z))+underbrace{sum_{dat z}Q(dat z)}_{=1}log P(dat x) \ &= sum_{dat z}Q(dat z)(log Q(dat z)-log P(dat xdat z))+log P(dat x). } ]
对分布 (Q(dat z)) , 记 (Ex_Qf(dat z):=sum_{dat z}Q(dat z)f(dat z)) , 则
[D_{u{KL}}(Q(dat z)parallel P(dat zmiddat x))=Ex_Q(log Q(dat z)-log P(dat xdat z))+log P(dat x). ]
[ALI{ implies log P(dat x) &= D_{u{KL}}(Q(dat z)parallel P(dat zmiddat x))-Ex_Q(log Q(dat z)-log P(dat xdat z))\ &=: D_{u{KL}}(Q(dat z)parallel P(dat zmiddat x))+mathcal L(Q). }tag 1 ]
由于 (D_{u{KL}}(Qparallel P)ge 0) , 有
[log P(dat x)gemathcal L(Q).tag 2 ]
即 (mathcal L(Q)) 可以作为 (log P(dat x)) 的下界估计.
(2) 模型结构
(2.1) 基本假设
设数据集 (dat X={dat x^{(i)}}_{i=1}^N) 由 (N) 个独立同分布的数据点构成. 我们假设它由以下过程采样而来:
从某个先验分布 (p_{dattheta^*}(dat z)) 采样 (dat z^{(i)}) ;
从某个条件分布 (p_{dattheta^*}(dat xmiddat z=dat z^{(i)})) 采样 (dat x^{(i)}) .
其中 (p_{dattheta^*}(dat z)) 和 (p_{dattheta^*}(dat xmid dat z)) 来自一族参数化分布 (p_{dattheta}(dat z)) 和 (p_{dattheta}(dat xmiddat z)) , 且它们的概率密度函数对 (dattheta) 和 (dat z) 几乎处处可微.
现在, 数据集 (dat X) 是已知的, 但我们不知道隐变量 (dat z^{(i)}) 和具体的分布参数 (dattheta^*) . 因此, 我们尝试引入一个识别模型 (q_{datphi}(dat zmid dat x)) 用来估计真实的后验分布 (p_{dattheta}(dat zmiddat x)) , 并尝试一起学习 (datphi) 和 (dattheta) .
我们将在后验分布 (p_{dattheta}(dat zmiddat x)) ((q_{datphi}(dat zmiddat x)) ) 上采样 (dat z) 的行为视作对数据 (dat x) 的编码, 在条件分布 (p_{dattheta}(dat xmiddat z)) 上采样 (dat x) 的行为视作对编码 (dat z) 的解码, 这就是所谓的 encode 和 decode 过程.
(2.2) Marginal Likelyhood
为了学到最优的 (dattheta^*) , 我们势必需要引入一个评估分布参数优劣的值. 模仿最大似然的手法, 我们仍然研究数据集 (dat X) 被模型生成的概率. 则对某个数据点 (dat x) 和待评估的参数 (dattheta) , 有
[p_{dattheta}(dat x)=int p_{dattheta}(dat xmiddat z)p_{dattheta}(z)d z. ]
(这里忽略了超参数 (alpha) . 为了让式子更完整, 可以在所有概率中 condition on (alpha) .) 而
[log p_{dattheta}(dat X)=sum_{i=1}^Nlog p_{dattheta}(dat x^{(i)}). ]
利用识别模型 (q_{datphi}) 估计后验分布, 套用 ((1)) , 我们知道
[log p_{dattheta}(dat x^{(i)})=D_{u{KL}}(q_{datphi}(dat zmiddat x^{(i)})parallel p_{dattheta}(dat zmiddat x^{(i)}))+mathcal L(dattheta,datphi;dat x^{(i)}). ]
同时由 ((2)) ,
[ALI{ log_{dattheta}(dat x^{(i)}) &ge mathcal L(dattheta,datphi;dat x^{(i)})\ &= Ex_{q_{datphi}(dat zmid dat x^{(i)})}(-log q_{datphi}(dat zmiddat x^{(i)})+log p_{dattheta}(dat x^{(i)}dat z))&(3)\ &= Ex_{q_{datphi}(dat zmid dat x^{(i)})}(-log q_{datphi}(dat zmiddat x^{(i)})+log p_{dattheta}(dat z)+log p_{dattheta}(dat x^{(i)}middat z))\ &= -D_{u{KL}}(q_{datphi}(dat zmid dat x^{(i)})parallel p_{dat theta}(dat z))+Ex_{q_{datphi}(dat zmid dat x^{(i)})}(log p_{dattheta}(dat x^{(i)}mid dat z)).&(4) } \ ]
我们希望通过对 (mathcal L(dattheta,datphi;dat x^{(i)})) 梯度下降来学出优秀的 (dattheta) 和 (datphi) .
(2.3) 重参数化 (reparameterization) 与 AEVB 算法
然而 ([1]) 中指出, (mathcal L(dattheta,datphi;dat x^{(i)})) 对 (datphi) 的梯度的方差很大, 不适用于数值计算. (不过对此论断, ([2]) 的评论区中有不同的分析, 可自行了解.) 这里, 我们采用重参数化技巧: 对 (dat zsim q_{datphi}(dat zmiddat x)) , 假定 (dat z=g_{datphi}(datepsilon,dat x)) 可微, (datphi) 是参数, (datepsilonsim p(datepsilon)) 是噪声. 以此为条件, 根据概率密度的定义:
[q_{datphi}(dat zmiddat x)ddat z=p(datepsilon)ddatepsilon. ]
进而
[ALI{ Ex_{q_{datphi}(dat zmiddat x^{(i)})}f(dat z) &= int q_{datphi}(dat zmid dat x^{(i)})f(dat z)ddat z\ &= int p(datepsilon)f(g_{datphi}(datepsilon,dat x^{(i)}))ddatepsilon\ &approx frac{1}{L}sum_{ell=1}^L f(underbrace{g_{datphi}(datepsilon^{(ell)},dat x^{(i)})}_{=:dat z^{(i,ell)}}),quad datepsilon^{(ell)}sim p(datepsilon). } ]
以此估计 ((3)) , 给出
[ALI{ mathcal L(dattheta,datphi;dat x^{(i)}) &approx wt{mathcal L}^A(dattheta,datphi;dat x^{(i)})\ &:= frac{1}{L}sum_{ell=1}^L(-log q_{datphi}(dat z^{(i,ell)}middat x^{(i)})+log p_{dattheta}(dat x^{(i)}dat z^{(i,ell)})). } ]
或者, 以此估计 ((4)) , 给出
[ALI{ mathcal L(dattheta,datphi;dat x^{(i)}) &approx wt{mathcal L}^B(dattheta,datphi;dat x^{(i)})\ &:= -D_{u{KL}}(q_{datphi}(dat zmiddat x^{(i)})parallel p_{dattheta}(dat z))+frac{1}{L}sum_{ell=1}^Llog p_{dattheta}(dat x^{(i)}middat z^{(i,ell)}). } ]
前一项散度据 ([1]) 称通常可以解析地求出.
接着, 在数据集 (dat X) 上采样一个大小为 (M) 的 minibatch 来估计给定参数的 marginal likelyhood, 有
[ALI{ mathcal L(dattheta,datphi;dat X) &approx wt{mathcal L}^M(dattheta,datphi;dat X)\ &:= frac{N}{M}sum_{i=1}^Mwt{L}(dattheta,datphi;dat x^{(i)}). } ]
(这里的 (M) 和单个数据点的采样数量 (L) 间可以 trade-off. ([1]) 指出当 (M=100) 时 (L=1) 的表现已经出色.)
最终, 嵌套地使用 (wt{mathcal L}^M) 和 [(wt{mathcal L}^A) 或 (wt{mathcal L}^B) ] 两次估计, 我们就能对 marginal likelyhood 的下界 ELBO 进行调优了. 这朴素地推导出 Auto-Encoding VB (AEVB) 算法:
[begin{array}{r|l} & text{Minibatch version of the Auto-Encoding VB algorithm}\ hline 0 & M,Lgets 100,1\ 1 & p(datepsilon),p_{dattheta}(dat xmiddat z),q_{datphi}(dat zmid dat x),p_{dattheta}(dat z) gets text{chosen distri. forms}\ 2 & dattheta,datvarphi gets text{initial parameters}\ 3 & textbf{repeat}\ 4 & qquad dat X^M gets text{minibatch sampled from }dat X\ 5 & qquad dat epsilon gets text{noise sampled from }p(datepsilon)\ 6 & qquad dat g gets nabla_{dattheta,datphi}wt{mathcal L}^M(dattheta,datphi;dat X^M,datepsilon)\ 7 & qquad dattheta,datphi gets text{parameters optimized by }dat g\ 8 & textbf{until}~text{convergence of }(dattheta,datphi)\ 9 & textbf{return}~dattheta,datphi end{array} ]
(2.4) 实例: VAE 算法
在 AEVB 的框架下, 不平凡的工作是指定分布 (p(datepsilon),p_{dattheta}(dat xmiddat z),q_{datphi}(dat zmid dat x),p_{dattheta}(dat z)) 的形式. 在 Variational Auto-Encoder (VAE) 中, 我们取
[ALI{ p(datepsilon) &= mathcal N(datepsilon;bs 0,bs 1),\ q_{datphi}(dat zmiddat x^{(i)}) &= mathcal N(dat z;datmu^{(i)},(datsigma^2)^{(i)}bs 1),\ p_{dattheta}(dat z) &= mathcal N(dat z;bs 0,bs 1),\ g_{datphi}(datepsilon^{(ell)},dat x^{(i)}) &= datmu^{(i)}+datsigma^{(i)}odotdatepsilon^{(ell)}. } ]
其中 (bs 1) 是适合尺寸的单位矩阵. ((datsigma^2)^{(i)}bs 1) 给出的是对角协方差阵, 即每个 (z_jsimmathcal N(mu^{(i)}_j,(sigma_j^{(i)})^2)) , 互相独立. (但个人感觉这个记号本身有些奇怪.)
而对于 (p_{dattheta}(dat xmiddat z)) , 可以根据数据类型选择:
对于二元数据, (p_{dattheta}(x_imiddat z)=mathcal B(x_i;1,y_i)) , 其中 (dat y) 由模型给出;
对于实值数据, (p_{dattheta}(x_imiddat z)=mathcal N(x_i;mu'_i,sigma_i'^2)) , 其中 (datmu') 和 (datsigma') 由模型给出.
这里给出实值数据下 VAE 一次 encode-decode 的示意. 其中 (dat xinR^5) , (dat zinR^3) , 蓝色点云表示概率密度:
接下来还需要验证 (wt{mathcal L}) 的形式. 这里采用 (wt{mathcal L}^B) 的估计, 需要计算 (-D_{u{KL}}(q_{datphi}(dat zmiddat x^{(i)})parallel p_{dattheta}(dat z))+frac{1}{L}sum_{ell=1}^Llog p_{dattheta}(dat x^{(i)}middat z^{(i,ell)})) . 对于前一项, (q_{datphi}(dat zmiddat x^{(i)})) 简记作 (q_{datphi}(dat z)) , 设向量维度为 (J) , 根据定义 (这里就是把离散情况的求和对应地变为分布函数上的 Lebesgue 积分, 我们在上文已经假设了这些分布良好的分析性质):
[ALI{ -D_{u{KL}}(q_{datphi}(dat z)parallel p_{dattheta}(dat z)) &= int q_{datphi}(dat z)log p_{dattheta}(dat z)ddat z-int q_{datphi}(dat z)log q_{datphi}(dat z)ddat z\ &= intmathcal N(dat z;datmu,datsigma^2)log mathcal N(dat z;bs 0,bs 1)ddat z-intmathcal N(dat z;datmu,datsigma^2)logmathcal N(dat z;datmu,datsigma^2)ddat z\ &=: I_1-I_2. } ]
容易计算:
[ALI{ I_1 &= intbr{prod_{j=1}^Jmathcal N(z_j;mu_j,sigma_j^2)}sum_{j=1}^Jlog mathcal N(z_j;0,1)ddat z\ &= sum_{i=1}^Jintmathcal N(z_i;mu_i,sigma_i^2)logmathcal N(z_i;0,1)cdotprod_{jneq i}mathcal N(z_j;mu_j,sigma_j^2)ddat z\ &= sum_{i=1}^Jintmathcal N(z_i;mu_i,sigma_i^2)logmathcal N(z_i;0,1)d z_icdotunderbrace{prod_{jneq i}intmathcal N(z_j;mu_j,sigma_j^2)d z_j}_{=1}\ &= -frac{1}{2}sum_{i=1}^Jintfrac{1}{sqrt{2pi}sigma_i}e^{-frac{(z_i-mu_i)^2}{2sigma_i^2}}br{log(2pi)+z_i^2}d z_i\ &= -frac{1}{2}sum_{i=1}^Jbr{log(2pi)+frac{1}{sqrt{2pi}sigma_i}int_{-oo}^{+oo}e^{-frac{x^2}{2sigma_i^2}}(x^2+2mu_ix+mu_i^2)d x}\ &= -frac{1}{2}sum_{i=1}^Jbr{log(2pi)+mu_i^2+frac{1}{sqrt{2pi}sigma_i}int_{-oo}^{+oo}x^2e^{-frac{x^2}{2sigma_i^2}}d x} } ]
回忆 Gauss 积分 (DSint_{-oo}^{+oo}x^2e^{-ax^2}d x=frac{1}{2}sqrt{frac{pi}{a^3}}) , 代入化简得
[ALI{ I_1 &= -frac{J}{2}log(2pi)-frac{1}{2}sum_{i=1}^Jbr{mu_i^2+frac{1}{sqrt{2pi}sigma_i}cdotfrac{1}{2}sqrt{8sigma_i^6pi}}\ &= -frac{J}{2}log(2pi)-frac{1}{2}sum_{i=1}^J(mu_i^2+sigma_i^2). } ]
同理
[I_2=-frac{J}{2}log(2pi)-frac{1}{2}sum_{i=1}^J(1+logsigma_j^2). ]
所以
[-D_{u{KL}}(q_{datphi}(dat z)parallel p_{dattheta}(dat z))=frac{1}{2}sum_{i=1}^J(1+logsigma_i^2-mu_i^2-sigma_i^2). ]
最终
[ALI{ mathcal L(dattheta,datphi;dat x^{(i)}) &approx wt{mathcal L}^B(dattheta,datphi;dat x^{(i)})\ &= frac{1}{2}sum_{i=1}^J(1+logsigma_i^2-mu_i^2-sigma_i^2)+frac{1}{L}sum_{ell=1}^Llog p_{dattheta}(dat x^{(i)}mid dat z^{(i,ell)}). } ]
这样的良好形式已然可以启动训练了. 在这一表达式中, 前一项即 (负) KL 散度, 后一项一般称为重构损失 (reconstruction loss).
(3) MNIST 实战
由于 VAE 和最常见的 "将 batch 输入模型 - 比对模型输出与 ground truth 计算 loss - 反向传播" 的训练方式有些差异, 实现起来可能有些难度. 所以这里以 MNIST 为例实现完整的 VAE, 并通过一些数据实验加深对 VAE 的理解.
(注: 文末提供了本节的完整代码.)
(3.1) 数据准备
无需多言. (Tips: MNIST 单图的初始形态为 ((1,28,28)) ; ToTensor() 后灰度值在 ([0,1]) 中.)
import torch import torch.nn as nn import torch.nn.functional as F import torchvision import matplotlib matplotlib.use("Agg") # 笔者使用的 WSL import matplotlib.pyplot as plt device = torch.device("cuda" if torch.cuda.is_available() else "cpu") train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True) train_dataset.transform = torchvision.transforms.ToTensor() # 注意这里的 100 对应了训练量时 M 的值 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
(3.2) 分布选取与框架代码
实践上, 在 decode 时直接采用独立 Bernoulli 分布是一个高质且高效的选择. 这时
[log p_{dattheta}(dat x^{(i)}middat z^{(i,ell)})=sum_{j}br{x^{(i)}_jlogmu'_j+(1-x^{(i)}_j)log(1-mu'_j)}, ]
其中 (dat mu'=dat mu'(dat z^{(i,ell)})) 即 decode 样本点 (dat z^{(i,ell)}) 的模型输出 (不必再如上文图中输出一个 (datsigma') ).
Q1: 灰度值是一个实值量, 为什么不如上文所说地使用正态分布来 decode?
A1: 用正态分布的最大问题是范围不匹配. 正态分布会给出 (R) 上的采样, 如果不在训练过程中强制截断, 会导致重构损失非常巨大 (实测 (10^9) 倍于 KL 散度) 而难以训练; 而强制截断则会导致边界概率密度的不合理分配.
Q2: 图像灰度值分布的 ground truth 总该是 ([0,1]) 上的连续分布, 我们用离散的 Bernoulli 分布去拟合合理吗?
A2: 的确, Bernoulli 分布无法建模中间灰度, 理论上有偏差. 如果希望更精确地拟合, 可以采用独立 Beta 分布等分布模型. Bernoulli 分布的优势在于其模型简单, 训练高效且稳定.
给出框架代码:
class Encoder(nn.Module): def __init__(self, LATENT_DIM): super(Encoder, self).__init__() self.W_h = nn.Linear(784, 256) self.b_h = nn.Parameter(torch.zeros(256)) self.W_mu = nn.Linear(256, LATENT_DIM) self.b_mu = nn.Parameter(torch.zeros(LATENT_DIM)) self.W_sgm = nn.Linear(256, LATENT_DIM) self.b_sgm = nn.Parameter(torch.zeros(LATENT_DIM)) def forward(self, x): x = x.view((-1, 784)) h = F.relu(self.W_h(x) + self.b_h) # 也可以用 tanh 等激活 mu = self.W_mu(h) + self.b_mu sgm = self.W_sgm(h) + self.b_sgm # sigma 可能 <0, 其行为和 >0 一致 return mu, sgm class Decoder(nn.Module): def __init__(self, LATENT_DIM): super(Decoder, self).__init__() self.W_h = nn.Linear(LATENT_DIM, 256) self.b_h = nn.Parameter(torch.zeros(256)) self.W_mu = nn.Linear(256, 784) self.b_mu = nn.Parameter(torch.zeros(784)) def forward(self, z): h = F.relu(self.W_h(z) + self.b_h) mu_re = F.sigmoid(self.W_mu(h) + self.b_mu) return mu_re # 使用 Bernoulli 分布, 只输出 mu' class VAE(nn.Module): def __init__(self, LATENT_DIM): super(VAE, self).__init__() self.LATENT_DIM = LATENT_DIM self.encoder = Encoder(LATENT_DIM) self.decoder = Decoder(LATENT_DIM) def generate(self, num=1): # [用于测试] 在隐空间随机采样重构 imgs = None with torch.no_grad(): z = torch.randn((num, self.LATENT_DIM)).to(device) mu_re = self.decoder(z) imgs = mu_re.view(-1, 1, 28, 28) return imgs.cpu() def reconstruct(self, X): # [用于测试] 模拟 encode-decode (如上文图过程) mu, sgm = self.encoder(X) eps = torch.randn_like(sgm).to(device) z = mu + sgm * eps mu_re = self.decoder(z) return mu_re.view(-1, 1, 28, 28).cpu() # 没有必要实现 forward 方法 class ELBO_Estimator(nn.Module): def __init__(self): super(ELBO_Estimator, self).__init__() self.L = 1 # 估算积分时的采样次数 self.FIX_EPS = 1e-8 # /0, log0 修正 def forward(self, X_M): mu, sgm = model.encoder(X_M) kl_div = -0.5 * torch.sum(1 + torch.log(sgm**2 + self.FIX_EPS) - mu**2 - sgm**2) re_loss = 0 for _ in range(self.L): e_l = torch.randn_like(sgm).to(device) # 批量采样 epsilon z_l = mu + sgm * e_l mu_re = model.decoder(z_l) re_loss += torch.sum(X_M * torch.log(mu_re + self.FIX_EPS)) re_loss += torch.sum((1 - X_M) * torch.log(1 - mu_re + self.FIX_EPS)) re_loss /= self.L elbo = -(re_loss - kl_div) # 负的 ELBO (调优时最小化之), 忽略了常数因子 return elbo, kl_div, re_loss # 后两项用于输出时观察
(3.3) 训练
无需多言.
model = VAE(2).to(device) # 这里 2 是隐空间维度, 可以自由调节 criterion = ELBO_Estimator().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # 随手写的学习率 def train_vae(model, train_loader, optimizer, epochs=10): model.train() for epoch in range(epochs): total_loss = 0 for batch_idx, (data, _) in enumerate(train_loader): data = data.view((-1, 784)).float().to(device) optimizer.zero_grad() loss, kl_div, re_loss = criterion(data) # 直接算 criterion, 不必 model.forward loss.backward() optimizer.step() total_loss += loss.item() if batch_idx % 100 == 0: print(f'batch {batch_idx + 1}/{len(train_loader)} | loss: {loss.item():.2f}', f'| kl_div: {kl_div.item():.2f} | re_loss: {re_loss.item():.2f}') print(f'---epoch {epoch + 1}/{epochs} | loss: {total_loss / len(train_loader):.2f}---n') # 启动训练 train_vae(model, train_loader, optimizer, epochs=10) torch.save(model.state_dict(), f'vae.pth')
(3.4) 实验
先来观察直接在整个隐空间采样 (dat z) 并重构的效果.
def generate_grid(model): model.eval() with torch.no_grad(): imgs = model.generate(16) grid = torchvision.utils.make_grid(imgs, nrow=4, padding=2) plt.figure(figsize=(10, 10)) plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap='gray') plt.axis('off') plt.savefig('grid.png', format='png') plt.close() generate_grid(model)
结果 (LATENT_DIM=10):
每个 "数字" 看上去是若干个标准数字的模糊叠加. 直接这样生成数字虽然勉强能看, 但的确不够理想.