Just image Transformer (JiT) 结果展示

Back to Basics: Let Denoising Generative Models Denoise 的 PyTorch/GPU 复现结果展示页
本页侧重:可视化样例、FID/IS 指标,以及与传统扩散模型的差异说明。

1. 方法概览

本节结合源码与原论文 《Back to Basics: Let Denoising Generative Models Denoise》(Li & He, 2025)梳理 JiT 的总体设计思路。

2. 训练过程数据流(源码 + 公式对照)

本节对照代码与论文,逐步梳理 JiT 的训练数据流。每个阶段都给出“代码片段 → 数学公式 → 数据流动”的映射,方便交叉验证。

2.1 数据载入与增强(main_jit.py

# main_jit.py
transform_train = transforms.Compose([
    transforms.Lambda(lambda img: center_crop_arr(img, args.img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.PILToTensor()
])
dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'),
                                     transform=transform_train)

数据流:ImageNet 原图 → 中心裁剪到 img_size → 随机水平翻转 → 转为整型 Tensor(范围 [0,255])。此时尚未归一化。

2.2 归一化并送入 Denoiser(engine_jit.py::train_one_epoch

# engine_jit.py
x = x.to(device, non_blocking=True).to(torch.float32).div_(255)
x = x * 2.0 - 1.0
labels = labels.to(device, non_blocking=True)

对应公式:

$$ x_{\text{norm}} = 2 \cdot \frac{x_{\text{uint8}}}{255} - 1 \in [-1,1] $$

即把像素映射到 [-1,1],并与标签一起传入 Denoiser(分布式 DDP 包装后的模型)。

2.3 噪声混合与速度目标(denoiser.py::forward

t = \texttt{sample\_t}(B)                     # sigmoid(𝒩)
e = torch.randn_like(x) * self.noise_scale   # 噪声
z = t * x + (1 - t) * e                      # 线性插值
v = (x - z) / (1 - t).clamp_min(self.t_eps)  # 目标速度
x_pred = self.net(z, t.flatten(), labels_dropped)
v_pred = (x_pred - z) / (1 - t).clamp_min(self.t_eps)

公式与数据流:

2.4 损失函数与公式

# denoiser.py
loss = (v - v_pred) ** 2
loss = loss.mean(dim=(1, 2, 3)).mean()

数学描述:

$$ \mathcal{L}_{\text{JiT}} = \mathbb{E}_{x, e, t}\left[\left\|v - v_\theta(z,t,y)\right\|_2^2\right], \quad v = \frac{x - (t x + (1-t)e)}{1 - t} $$

该损失与论文中提出的“仅回归速度”的思想对应,训练与采样共用 $v$。

2.5 反向传播、优化与 EMA

# engine_jit.py
optimizer.zero_grad()
loss.backward()
optimizer.step()
model_without_ddp.update_ema()

优化器使用 AdamW(main_jit.py 中设置),EMA 更新在 denoiser.py::update_ema 中按 $m \leftarrow \alpha m + (1-\alpha)\theta$ 形式维护两个时间常数,供采样时切换。

2.6 训练数据流总览

  1. 数据加载:ImageNet → 中心裁剪 + 随机翻转 + Tensor(main_jit.py)。
  2. 归一化:$x_{\text{uint8}} \to x_{\text{norm}} \in [-1,1]$(engine_jit.py)。
  3. 噪声注入:采样 $t$、生成噪声 $e$,构造 $z = t x + (1-t)e$(denoiser.py)。
  4. 网络前向:JiT Transformer 接收 $(z,t,y)$,输出 $x_\theta$,进而得到 $v_\theta$。
  5. 损失与优化:最小化 $\|v - v_\theta\|^2$,反向传播,更新参数 + EMA。

3. 推理生成数据流(源码 + 公式对照)

下面以 denoiser.py::generate 及其辅助函数为线索,讲解“如何从纯噪声生成高质量图像”,并与公式对应。

3.1 初始化噪声与时间网格

# denoiser.py
z = self.noise_scale * torch.randn(bsz, 3, H, W, device=device)
timesteps = torch.linspace(0.0, 1.0, self.steps + 1, device=device) \
                .view(-1, *([1] * z.ndim)).expand(-1, bsz, -1, -1, -1)

3.2 单步前向 + CFG(_forward_sample

# denoiser.py
x_cond   = self.net(z, t.flatten(), labels)
v_cond   = (x_cond - z) / (1.0 - t)
x_uncond = self.net(z, t.flatten(), torch.full_like(labels, self.num_classes))
v_uncond = (x_uncond - z) / (1.0 - t)
interval_mask = (t < high) & ((low == 0) | (t > low))
cfg_scale = torch.where(interval_mask, self.cfg_scale, 1.0)
v_guided = v_uncond + cfg_scale * (v_cond - v_uncond)

对应关系:

3.3 ODE 步进(Euler / Heun)

# Euler
z_{k+1} = z_k + (t_{k+1} - t_k) * v_guided(z_k, t_k)

# Heun
z^{(euler)}_{k+1} = z_k + (t_{k+1} - t_k) * v_guided(z_k, t_k)
z_{k+1} = z_k + (t_{k+1} - t_k) * 0.5 * (v_guided(z_k, t_k) + v_guided(z^{(euler)}_{k+1}, t_{k+1}))

Heun 相当于显式两次评估的二阶 Runge-Kutta,能在保持速度场一致的同时,提高收敛稳定性。

3.4 最终图像重建

# engine_jit.py::evaluate
sampled_images = model_without_ddp.generate(labels)   # ∈ [-1,1]
sampled_images = (sampled_images + 1) / 2             # 归一化到 [0,1]

在评估/采样脚本中,会把结果映射回 [0,1] 再乘 255 保存成 PNG,同时确保各类别数量一致(详见 engine_jit.py::evaluate)。

3.5 推理数据流总览

  1. 准备:加载 EMA 权重,设定采样步长、CFG 区间、标签列表。
  2. 初始化:生成噪声图 $z_0$,创建线性时间网格 $t_0 \dots t_K$。
  3. 循环:对每个 $k$ 执行 CFG 速度预测、Euler/Heun 步进,得到 $z_{k+1}$。
  4. 输出:最后一步 $z_K$ 即为 [-1,1] 范围的生成图;映射回 [0,1] 后写入磁盘或传给后续评估。

4. 生成样例

4.1 ImageNet 256×256

JiT ImageNet 256x256 Samples
图 1:JiT-B/16 在 ImageNet 256×256 上的生成样例。

4.2 架构示意

JiT Architecture
图 2:Just image Transformer (JiT) 模型结构概览。

5. 定量评估

填写你实际测得的 FID / IS:

5.1 ImageNet 256×256 – JiT-B/16

指标数值备注
FID ↓xx.x与 jit_in256_stats.npz 对比
Inception Score ↑xx.x使用 torch-fidelity
采样步数50Heun ODE
CFG2.9 (区间 0.1–1.0)仅中后期启用

5.2 ImageNet 512×512 – JiT-B/32

指标数值备注
FID ↓xx.x与 jit_in512_stats.npz 对比
Inception Score ↑xx.x
采样步数50Heun ODE
CFG2.9 (区间 0.1–1.0)

6. 与传统扩散模型的差异(事无巨细)

下列对比基于 denoiser.pymodel_jit.pyengine_jit.py 以及论文 Li & He (2025) 中的描述整理。

6.1 噪声建模与时间参数化

传统扩散:DDPM/EDM/DiT 通常需要显式的 $\beta_t$ 或 $\sigma_t$ 调度表,定义 $x_t = \alpha_t x_0 + \sigma_t \epsilon$,并在训练/推理中都依赖该 schedule。

JiT:直接采样 $t \sim \mathrm{sigmoid}(\mathcal{N}(P_{\text{mean}}, P_{\text{std}}^2))$(见 denoiser.py::sample_t),用插值 $$ z = t x + (1-t) e $$ 构造“噪声图像”。这种折线式的时间参数化摒弃了复杂的噪声表,呼应论文中“回归到最简单的去噪建模”的观点。

6.2 训练目标:速度 v 的统一监督

传统扩散:大多预测噪声 $\epsilon$、干净图 $x_0$ 或 EDM 风格的 $v$,但仍以预定义的噪声 schedule 为前提。

JiT:直接以 $$ v = \frac{x - z}{1 - t} $$ 作为唯一监督量(denoiser.py::forward),并要求网络回归 $v_\theta(z,t,y)$。原论文强调该“速度视角”可以让推理过程完全复用训练时的物理量,从而更稳定。

6.3 采样策略:确定性 ODE vs 随机逆扩散

传统扩散:逆扩散链往往需要在每一步添加噪声(DDPM)或使用复杂的 $\sigma$ 网格进行 SDE/ODE 求解(EDM、DDIM)。

JiT:在 $t \in [0,1]$ 上等步长积分 $$ z_{k+1} = z_k + (t_{k+1} - t_k)\, v_\theta(z_k, t_k) $$ 可选 Euler 或 Heun(_euler_step / _heun_step),完全确定性,不再向链中重新注入噪声。如此一来,生成轨迹与论文中的“ODE 形式的纯去噪过程”保持一致。

6.4 网络主干:Just image Transformer vs U-Net

传统扩散:多采用带跳连的 U-Net 主干,或在 DiT 中只在高分辨率中段使用 Transformer。

JiT:通篇使用 ViT 风格的 Transformer(见 model_jit.py::JiT),特点包括:

6.5 条件控制:in-context tokens + 区间化 CFG

传统扩散:常见做法是把 class embedding 加到所有 token,或引入单个 class token,CFG 对所有时间步使用统一倍率。

JiT:

6.6 训练与评估流水线

传统扩散:往往需要配套的噪声 schedule 配置器、EMA 切换脚本、独立的采样器模块等。

JiT:训练和评估逻辑被整合在极少数文件中:

7. 训练 / 评估配置

# 训练示例
torchrun --nproc_per_node=8 \
  main_jit.py \
  --model JiT-B/16 \
  --P_mean -0.8 --P_std 0.8 \
  --img_size 256 --noise_scale 1.0 \
  --batch_size 128 --blr 5e-5 \
  --epochs 600 --warmup_epochs 5 \
  --gen_bsz 128 --num_images 50000 \
  --cfg 2.9 --interval_min 0.1 --interval_max 1.0 \
  --data_path ${IMAGENET_PATH} --online_eval
# 评估示例
torchrun --nproc_per_node=8 \
  main_jit.py \
  --model JiT-B/16 \
  --img_size 256 --noise_scale 1.0 \
  --gen_bsz 128 --num_images 50000 \
  --cfg 2.9 --interval_min 0.1 --interval_max 1.0 \
  --resume ${CKPT_DIR} --evaluate_gen

8. 结语

本页面可作为 JiT 复现实验的展示页:提供可视化、定量指标以及与传统扩散的差异说明。
建议:将当前 HTML 放在仓库里(例如 docs/results.html),并在 README 中加链接,方便快速查看。