变分自编码器
变分自编码器(vae)这个东西知道很久了,不过一直理解不是很深刻,现在总结一下查阅到的文档,同时记录一下自己的一些问题。
1. pytorch实现
class Encoder(torch.nn.Module):
def __init__(self, D_in, H, D_out):
super(Encoder, self).__init__()
self.linear1 = torch.nn.Linear(D_in, H)
self.linear2 = torch.nn.Linear(H, D_out)
def forward(self, x):
x = F.relu(self.linear1(x))
return F.relu(self.linear2(x))
class Decoder(torch.nn.Module):
def __init__(self, D_in, H, D_out):
super(Decoder, self).__init__()
self.linear1 = torch.nn.Linear(D_in, H)
self.linear2 = torch.nn.Linear(H, D_out)
def forward(self, x):
x = F.relu(self.linear1(x))
return F.relu(self.linear2(x))
class VAE(torch.nn.Module):
latent_dim = 8
def __init__(self, encoder, decoder):
super(VAE, self).__init__()
self.encoder = encoder
self.decoder = decoder
self._enc_mu = torch.nn.Linear(100, 8)
self._enc_log_sigma = torch.nn.Linear(100, 8)
def _sample_latent(self, h_enc):
"""
Return the latent normal sample z ~ N(mu, sigma^2)
"""
mu = self._enc_mu(h_enc)
log_sigma = self._enc_log_sigma(h_enc)
sigma = torch.exp(log_sigma)
std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()
self.z_mean = mu
self.z_sigma = sigma
return mu + sigma * Variable(std_z, requires_grad=False) # Reparameterization trick
def forward(self, state):
h_enc = self.encoder(state)
z = self._sample_latent(h_enc)
return self.decoder(z)
def latent_loss(z_mean, z_stddev):
mean_sq = z_mean * z_mean
stddev_sq = z_stddev * z_stddev
return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)
criterion = nn.MSELoss()
dec = vae(inputs)
loss = criterion(dec, inputs) + latent_loss(vae.z_mean, vae.z_sigma) # 重建损失与kl距离
2. 记录点
X(样本空间),Z(latent space)
2.1 保证Z满足标准高斯分布(独立,多元),如何保证呢?
只要保证$p(Z|X)$满足$\mathcal{N}(0,I)$即可,这就是latent_loss为何要这么设计。 \(p(Z)=\sum_X p(Z|X)p(X)=\sum_X \mathcal{N}(0,I)p(X)=\mathcal{N}(0,I) \sum_X p(X) = \mathcal{N}(0,I)\)
2.2 latent loss的推导
由于我们考虑的是各分量独立的多元正态分布,因此只需要推导一元正态分布的情形即可, \(\begin{aligned}&KL\Big(N(\mu,\sigma^2)\Big\Vert N(0,1)\Big)\\ =&\int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-(x-\mu)^2/2\sigma^2} \left(\log \frac{e^{-(x-\mu)^2/2\sigma^2}/\sqrt{2\pi\sigma^2}}{e^{-x^2/2}/\sqrt{2\pi}}\right)dx\\ =&\int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-(x-\mu)^2/2\sigma^2} \log \left\{\frac{1}{\sqrt{\sigma^2}}\exp\left\{\frac{1}{2}\big[x^2-(x-\mu)^2/\sigma^2\big]\right\} \right\}dx\\ =&\frac{1}{2}\int \frac{1}{\sqrt{2\pi\sigma^2}}e^{-(x-\mu)^2/2\sigma^2} \Big[-\log \sigma^2+x^2-(x-\mu)^2/\sigma^2 \Big] dx\end{aligned}\) 整个结果分为三项积分,第一项实际上就是$-\log \sigma^2$乘以概率密度的积分(也就是1),所以结果是$-\log \sigma^2$;第二项实际是正态分布的二阶矩,熟悉正态分布的朋友应该都清楚正态分布的二阶矩为$\mu^2+\sigma^2$;而根据定义,第三项实际上就是“-方差除以方差=-1”。所以总结果就是: \(KL\Big(N(\mu,\sigma^2)\Big\Vert N(0,1)\Big)=\frac{1}{2}\Big(-\log \sigma^2+\mu^2+\sigma^2-1\Big)\)
2.3 Evidence Lower Bound(ELBO)推导
2.2部分是kl散度的推导,不过VAE的整个损失并不是只有这个,其损失函数是被称为ELBO的一个东西。因为我们想将Z变成一个$\mathcal{N}(0,I)$的分布,而我们又只有X,那么我们要做的就是使得$KL\Big(Q(Z)\Big\Vert P(Z|X)\Big)$最小化。 \(\begin{aligned}&KL\Big(Q(Z)\Big\Vert P(Z|X)\Big)\\=&E_{Z\sim Q}\Big(logQ(Z)-logP(Z|X)\Big)\\=&E_{Z\sim Q}\Big(logQ(Z)-logP(X|Z)-logP(Z)\Big)+logP(X)\end{aligned}\)
移项整理得到: \(logP(X)-KL\Big(Q(Z)\big \Vert P(Z|X)\Big)=E_{Z\sim Q}\Big(logP(X|Z)\Big)-KL\Big(Q(Z)\Big \Vert P(Z) \Big)\) 将$Q(Z)$替换为$Q(Z|X)$得到: \(logP(X)-KL\Big(Q(Z|X)\big \Vert P(Z|X)\Big)=E_{Z\sim Q}\Big(logP(X|Z)\Big)-KL\Big(Q(Z|X)\Big \Vert P(Z) \Big)\) 显然,左边两项就是我们要优化的项,左边两项越大越好。而右边两项则是可以计算的,右边第一项相当于一个decoder,而第二项相当于一个encoder,它也对应于2.2部分的推导,右边部分即被称为Evidence Lower Bound,一般在讨论VAE的时候我们用ELBO来指代它的cost function。
2.4 KL vanish
KL vanish的出现可以从公式4得知,如果Z和X相互独立,即X完全不依赖于Z,那么右边第二项KL损失就可以被优化为0,而仅仅第一项起作用,这时候KL就发生了vanish。
参考资料: