以下是关于 PyTorch 中生成模型(Generative Models)的详细介绍,包括核心概念、常见模型类型、实现方法及代码示例。
一、生成模型简介
生成模型(Generative Models) 是一类用于学习数据分布并生成新样本的机器学习模型。其目标是捕捉输入数据的潜在分布 P(X),从而能够生成与训练数据相似的新样本。与判别模型(Discriminative Models)不同,生成模型更关注数据本身的分布特性。
核心任务
-
生成新样本:从学习到的分布中采样生成新数据。
-
密度估计:计算给定样本的概率。
-
隐变量推断:学习数据的潜在表示(Latent Variables)。
二、常见生成模型类型
PyTorch 中常用的生成模型包括:
- 生成对抗网络(GAN, Generative Adversarial Networks)
原理:通过生成器(Generator)和判别器(Discriminator)的对抗训练,生成器尝试生成逼真样本,判别器尝试区分真实样本和生成样本。
-
优点:生成样本质量高。
-
缺点:训练不稳定,模式坍塌(Mode Collapse)。
PyTorch 实现示例
import torch
import torch.nn as nn
# 定义生成器
class Generator(nn.Module):
def __init__(self, latent_dim, img_shape):
super().__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 784),
nn.Tanh()
)
self.img_shape = img_shape
def forward(self, z):
img = self.model(z)
return img.view(img.size(0), *self.img_shape)
# 定义判别器
class Discriminator(nn.Module):
def __init__(self, img_shape):
super().__init__()
self.model = nn.Sequential(
nn.Linear(784, 256),
nn.LeakyReLU(