Diffusion Models
1. Forward Diffusion Training Process (Forward Diffusion / Noise Addition)
1.1 Formulas
The forward diffusion process:
\[q(x_t | x_0) = \mathcal{N}\Big(x_t; \sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t) I \Big) \]
where:
\[\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s, \quad \alpha_s = 1 - \beta_s\]
The training objective is to predict the noise \(\epsilon\):
\[\text{loss} = \| \epsilon - \epsilon_\theta(x_t, t) \|^2\]
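To make the closed form concrete, here is a minimal numeric sketch. It assumes DDPM's usual linear schedule (\(\beta\) from 1e-4 to 0.02, \(T = 1000\); the code below leaves these as constructor arguments): the weight on \(x_0\) decays toward 0 as \(t\) grows, so \(x_T\) is almost pure noise.

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)            # assumed linear beta schedule
alphas_bar = torch.cumprod(1.0 - betas, dim=0)   # ᾱ_t = ∏_{s<=t} α_s

for t in [0, 499, 999]:
    print(t,
          alphas_bar[t].sqrt().item(),           # coefficient of x_0
          (1 - alphas_bar[t]).sqrt().item())     # coefficient of ε
# At t = 999 the x_0 coefficient is ~0.006: x_T is essentially pure noise.
```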
1.2 Pseudocode

```text
Input:  original image x0, total steps T, beta schedule (β1 ... βT)
Output: MSE loss

1. Sample a timestep t ∈ [0, T-1] independently for each sample
2. Draw noise ~ N(0, I)
3. Construct x_t from the closed form:
   x_t = sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise
4. Predict the noise: epsilon_pred = model(x_t, t)
5. Compute loss = MSE(noise, epsilon_pred)
6. Return loss
```
1.3 Python Code

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def extract(v, t, x_shape):
    """
    Gather the per-timestep coefficients from v at indices t and reshape
    them for broadcasting.

    Args:
        v: [T] (or [T, ...]) tensor of per-timestep coefficients,
           e.g. sqrt_alphas_bar
        t: [batch_size] tensor of timestep indices, one per sample
        x_shape: shape of the data tensor, used to set up broadcasting
    Returns:
        [batch_size, 1, 1, ..., 1] tensor that can be multiplied
        directly with x_t
    """
    device = t.device
    # Index v at t, e.g. t = [2, 5, 7] gives out = [v[2], v[5], v[7]]
    out = torch.gather(v, index=t, dim=0).float().to(device)
    # Reshape to [batch_size, 1, 1, ...] so it broadcasts over [B, C, H, W]
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))


class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        # UNet noise predictor
        self.model = model
        # Total number of timesteps
        self.T = T
        # betas: length-T schedule, linearly spaced from beta_1 to beta_T
        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)  # ᾱ_t = ∏_{s=1}^{t} α_s
        # Coefficients needed by the forward process q(x_t | x_0)
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))  # sqrt(ᾱ_t)
        self.register_buffer(
            'sqrt_one_minus_alphas_bar',
            torch.sqrt(1. - alphas_bar))  # sqrt(1 - ᾱ_t)

    def forward(self, x_0):
        """
        Algorithm 1 (training). Takes a clean sample x_0 and returns the
        training loss.
        """
        # Sample a random timestep t for each sample in the batch
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        noise = torch.randn_like(x_0)  # ε ~ N(0, I)
        # Closed form: x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        # Training objective: predict ε with an MSE loss
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
        return loss
```
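A quick usage sketch. `DummyEpsModel` is a hypothetical stand-in invented here so the snippet runs without a real UNet; the actual model would embed the timestep:

```python
import torch
import torch.nn as nn

class DummyEpsModel(nn.Module):
    """Hypothetical noise predictor: any module mapping (x_t, t) to a
    tensor shaped like x_t can play the UNet's role here."""
    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x_t, t):
        return self.net(x_t)  # ignores t; a real UNet embeds the timestep

trainer = GaussianDiffusionTrainer(
    DummyEpsModel(), beta_1=1e-4, beta_T=0.02, T=1000)
x_0 = torch.randn(4, 3, 32, 32)   # stand-in batch; real data is scaled to [-1, 1]
loss = trainer(x_0).mean()        # reduction='none', so reduce explicitly
loss.backward()
```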
2. Reverse Sampling Process (Reverse Sampling / Denoising)
2.1 Formulas
DDPM reverse sampling:
\[p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 I)\]
where:
\[\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t)\Big) \]
\[\sigma_t^2 = \beta_t \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\]
The sampling update:
\[x_{t-1} = \mu_\theta(x_t, t) + \sqrt{\sigma_t^2} \cdot \text{noise}, \quad \text{noise} \sim \mathcal{N}(0, I)\]
At the final step (\(t = 0\)), no noise is added.
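For reference, \(\mu_\theta\) comes from the mean of the true posterior \(q(x_{t-1} \mid x_t, x_0)\), a standard DDPM identity:
\[\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t} x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} x_t\]
Substituting the \(x_0\) estimate implied by the predicted noise, \(\hat{x}_0 = \big(x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)\big) / \sqrt{\bar{\alpha}_t}\), and simplifying recovers the \(\mu_\theta\) formula above.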
2.2 Pseudocode

```text
Input:  noise image x_T, total steps T, beta schedule
Output: generated image x0

1. x_t = x_T
2. For time_step from T-1 down to 0:
       t = [time_step] * batch_size
       eps = model(x_t, t)   # predict the noise
       mean = (1/sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1-alpha_bar_t)) * eps)
       var = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
       if time_step > 0:
           noise = N(0, I)
       else:
           noise = 0
       x_t = mean + sqrt(var) * noise
3. Return clip(x_t, -1, 1)
```
2.3 Python Code

```python
class GaussianDiffusionSampler(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        self.model = model
        self.T = T
        # Linear beta schedule
        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # ᾱ_{t-1}: shift right by one with ᾱ_0 = 1, for the posterior variance
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
        # coeff1, coeff2: coefficients of the mean μ of x_{t-1} given x_t
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer(
            'coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
        # Posterior variance σ²_t = β_t (1 - ᾱ_{t-1}) / (1 - ᾱ_t)
        self.register_buffer(
            'posterior_var',
            self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        """
        Predict the mean of x_{t-1} via
        μ = 1/√α_t * (x_t - (1-α_t)/√(1-ᾱ_t) * ε).
        """
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        """
        Return μ and σ² for one reverse step.
        """
        # Fixed variance: β_t for t >= 1; the t = 0 entry is replaced by
        # posterior_var[1] since posterior_var[0] = 0 (it is unused anyway,
        # because no noise is added at t = 0)
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)
        eps = self.model(x_t, t)
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
        return xt_prev_mean, var

    def forward(self, x_T):
        """
        Algorithm 2 (sampling): start from x_T ~ N(0, I) and denoise
        step by step down to x_0.
        """
        x_t = x_T
        for time_step in reversed(range(self.T)):
            print(time_step)  # progress logging, one line per step
            # Per-sample timestep tensor t
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            mean, var = self.p_mean_variance(x_t=x_t, t=t)
            # Add noise for t > 0; the final step (t = 0) is deterministic
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            # Sampling update: x_{t-1} = μ + sqrt(σ²) * ε
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
```
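And the matching sampling call, again with the hypothetical `DummyEpsModel` from the training sketch (with a trained UNet this loop costs T forward passes):

```python
sampler = GaussianDiffusionSampler(
    DummyEpsModel(), beta_1=1e-4, beta_T=0.02, T=1000)
with torch.no_grad():
    x_T = torch.randn(4, 3, 32, 32)  # start from pure Gaussian noise
    x_0 = sampler(x_T)               # T reverse steps, clipped to [-1, 1]
print(x_0.shape)
```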