MiniMind-O 逐行代码解读 · 第四章:训练代码
📁 文件:
trainer/train_sft_omni.py(263行)+trainer/trainer_utils.py(200行) 🔖 训练是把数据"喂"给模型、让模型学习的过程——本章解读训练的每一个步骤
📚 本章导读
训练代码分为两部分:
- train_sft_omni.py — 训练主流程,包含命令行参数、训练循环、loss 计算
- trainer_utils.py — 工具函数,包含模型初始化、学习率调度、checkpoint 保存
Part A: trainer_utils.py 工具函数
1️⃣ is_main_process / Logger(第16-21行)
def is_main_process():
return not dist.is_initialized() or dist.get_rank() == 0
def Logger(content):
if is_main_process():
print(content)
分布式训练中:多张 GPU 各自运行一份代码,但日志只需要主 GPU(rank=0)打印,避免重复。
2️⃣ get_lr 学习率调度(第25-27行)
def get_lr(current_step, total_steps, lr):
return lr * (0.1 + 0.45 * (1 + math.cos(math.pi * current_step / total_steps)))
余弦退火调度:
lr
1.0 * base ┤████████╗
│ ╚════╗
│ ╚════╗
│ ╚════╗
0.1 * base ┤ ╚════════
└──────────────────────────────────── step
0 total_steps
- 起始 lr =
1.0 × base_lr - 结束 lr =
0.1 × base_lr - 平滑过渡,避免训练后期震荡
公式推导:
cos(0) = 1→0.1 + 0.45 * (1 + 1) = 0.1 + 0.9 = 1.0→ 起始倍率cos(π) = -1→0.1 + 0.45 * (1 - 1) = 0.1→ 结束倍率
3️⃣ init_distributed_mode 分布式初始化(第30-37行)
def init_distributed_mode():
if int(os.environ.get("RANK", -1)) == -1:
return 0 # 没有 RANK 环境变量 → 单卡模式,local_rank=0
dist.init_process_group(backend="nccl") # 初始化 NCCL 通信后端
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank) # 绑定当前进程使用的 GPU
return local_rank
torchrun 设置的环境变量:
| 变量 | 说明 |
|---|---|
RANK | 全局进程编号(0, 1, 2, ...) |
LOCAL_RANK | 当前节点内的 GPU 编号 |
WORLD_SIZE | 总进程数 |
4️⃣ setup_seed 随机种子(第40-47行)
def setup_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True # cuDNN 确定性模式
torch.backends.cudnn.benchmark = False # 关闭自动调优
为什么关闭 benchmark? benchmark=True 时 cuDNN 会尝试不同的卷积算法选最快的,但不同算法可能产生不同结果。关闭后保证每次运行结果一致(可复现性),但可能慢一点。
5️⃣ log_model_params 模型参数统计(第50-62行)
def log_model_params(model, ignore_patterns=['audio_encoder', 'vision_encoder']):
def should_count(n): return not any(p in n for p in ignore_patterns)
total = sum(p.numel() for n, p in model.named_parameters() if should_count(n)) / 1e6 # 总参数量(M)
cfg = model.config
n_routed = getattr(cfg, 'n_routed_experts', getattr(cfg, 'num_experts', 0)) # 路由专家数
n_active = getattr(cfg, 'num_experts_per_tok', 0) # 每次激活的专家数
n_shared = getattr(cfg, 'n_shared_experts', 0) # 共享专家数
expert = sum(p.numel() for n, p in model.named_parameters()
if 'mlp.experts.0.' in n and should_count(n)) / 1e6 # 单个专家参数量
shared_expert = sum(p.numel() for n, p in model.named_parameters()
if 'mlp.shared_experts.0.' in n and should_count(n)) / 1e6
base = total - (expert * n_routed) - (shared_expert * n_shared) # 非专家部分
active = base + (expert * n_active) + (shared_expert * n_shared) # 每次实际激活的参数量
if active < total:
Logger(f'Model Params: {total:.2f}M-A{active:.2f}M') # 如 113.13M-A56.8M
else:
Logger(f'Model Params: {total:.2f}M')
"-A" 的含义:MoE 模型虽然总参数量大,但每个 token 只激活一部分专家。113.13M-A56.8M 表示总参数113M,每次推理只用56.8M。
6️⃣ init_omni_model 模型初始化(第65-104行)⭐
def init_omni_model(omni_config, from_weight='full_sft', tokenizer_path='../model',
audio_encoder_path='../model/SenseVoiceSmall',
vision_model_path='../model/siglip2-base-p32-256-ve',
save_dir='../out', device='cuda', freeze_backbone='none', from_resume=0):
Step 1: 加载 tokenizer 和模型骨架
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = MiniMindOmni(omni_config, audio_encoder_path=audio_encoder_path,
vision_model_path=vision_model_path)
Step 2: 加载预训练权重
if from_weight != 'none':
moe_suffix = '_moe' if omni_config.use_moe else ''
weight_path = f'{save_dir}/{from_weight}_{omni_config.hidden_size}{moe_suffix}.pth'
# 例如: out/llm_768.pth 或 out/sft_omni_768.pth
if os.path.exists(weight_path):
weights = torch.load(weight_path, map_location=device)
# 检查 shape 不匹配的权重
param_shapes = {k: v.shape for k, v in model.named_parameters()}
incompatible = {k for k, v in weights.items()
if k in param_shapes and v.shape != param_shapes[k]}
if incompatible:
Logger(f'跳过shape不匹配的权重: {incompatible}')
weights = {k: v for k, v in weights.items() if k not in incompatible}
model.load_state_dict(weights, strict=False)
Logger(f'已加载权重: {weight_path}')
strict=False:允许权重文件中有模型不存在的 key,也允许模型有权重文件中没有的 key。这对增量训练非常重要——新的 Talker 层在旧权重中不存在,不会报错。
Step 3: Talker 层初始化(关键!)
if from_resume == 0 and omni_config.talker_hidden_size == omni_config.hidden_size:
n_talker = omni_config.num_talker_hidden_layers # 4
n_thinker = len(model.thinker.layers) # 8
has_talker = any(k.startswith('talker.layers.') for k in weights)
if not has_talker and n_talker > 0:
# 从 Thinker 的最后 4 层复制到 Talker 的 4 层
for i in range(n_talker):
src = n_thinker - n_talker + i # src = 4, 5, 6, 7
model.talker.layers[i].load_state_dict(
model.thinker.layers[src].state_dict()
)
Logger(f'Talker层初始化: 复制thinker layers[{n_thinker-n_talker}:{n_thinker}] → talker layers[0:{n_talker}]')
为什么这样做?
- Talker 和 Thinker 使用相同的
MiniMindBlock结构 - Talker 是新加的层,没有预训练权重
- 从 Thinker 的后几层复制权重,比随机初始化好得多
- Thinker 后几层已经学会了"理解语义",Talker 需要的就是理解语义来生成语音
Step 4: 冻结策略
if freeze_backbone == 'all':
# 冻结整个 Thinker 主干
for param in model.model.parameters():
param.requires_grad = False
elif freeze_backbone == 'last1':
# 冻结除最后一层之外的所有层
for param in model.model.parameters():
param.requires_grad = False
if hasattr(model.model, 'layers') and len(model.model.layers) > 0:
for param in model.model.layers[-1].parameters():
param.requires_grad = True
return model.to(device), tokenizer
| freeze_backbone | 效果 | 使用场景 |
|---|---|---|
none | 全部可训练 | Step 1/3 全量微调 |
all | 只训练 Talker + 投影层 | 极小显存训练 |
last1 | 只训练最后1层 + Talker + 投影层 | 折中方案 |
7️⃣ omni_checkpoint Checkpoint 管理(第107-162行)
def omni_checkpoint(omni_config, weight='pretrain_omni', model=None, optimizer=None,
epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
保存模式(model 不为 None)
if model is not None:
# 1. 取出原始模型(去掉 DDP 和 compile 包装)
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model) # torch.compile 包装
# 2. 清理不需要保存的权重
clean_state_dict = {
k: v for k, v in raw_model.state_dict().items()
if not k.startswith('audio_encoder.') and not k.startswith('vision_encoder.')
}
# 不保存冻结的编码器!它们可以从原始路径重新加载
# 这样 checkpoint 从 1.1GB 降到 226MB
state_dict = {k: v.half().cpu() for k, v in clean_state_dict.items()} # FP16 + 移到CPU
# 3. 原子写入(防崩溃损坏)
ckp_tmp = ckp_path + '.tmp'
torch.save(state_dict, ckp_tmp)
os.replace(ckp_tmp, ckp_path) # 原子操作:先写临时文件,再重命名
os.replace 是原子的:如果写入过程中断电,临时文件可能损坏,但原始文件完好。重命名是瞬间的原子操作。
# 4. 保存 resume checkpoint(含 optimizer 状态)
resume_data = {
'model': state_dict,
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'step': step,
'world_size': dist.get_world_size() if dist.is_initialized() else 1,
'wandb_id': wandb_id,
}
# 也用原子写入
加载模式(model 为 None)
else:
if os.path.exists(resume_path):
ckp_data = torch.load(resume_path, map_location='cpu')
saved_ws = ckp_data.get('world_size', 1)
current_ws = dist.get_world_size() if dist.is_initialized() else 1
if saved_ws != current_ws:
# GPU数量变化时,调整 step 数
ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
Logger(f'GPU数量变化({saved_ws}→{current_ws}),step已自动转换为{ckp_data["step"]}')
return ckp_data
return None
GPU 数量变化的处理:如果之前用4卡训练保存了 step=1000,现在改用1卡继续,step 需要调整。因为4卡时每个 step 处理 4×batch_size 的数据,1卡只处理 1×batch_size,所以 step 数要乘4才能处理等量的数据。
8️⃣ SkipBatchSampler(第176-199行)
class SkipBatchSampler(Sampler):
def __init__(self, sampler, batch_size, skip_batches=0):
self.sampler = sampler
self.batch_size = batch_size
self.skip_batches = skip_batches # 续训时跳过的 batch 数
def __iter__(self):
batch = []
skipped = 0
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
if skipped < self.skip_batches:
skipped += 1 # 跳过已训练的 batch
batch = []
continue
yield batch
batch = []
if len(batch) > 0 and skipped >= self.skip_batches:
yield batch # 最后不完整的 batch
用途:续训时,从上次中断的 step 继续,不重复训练已经处理过的数据。
Part B: train_sft_omni.py 训练主脚本
1️⃣ omni_collate_fn 自定义批处理函数(第24-48行)
def omni_collate_fn(batch):
input_ids, labels, audio_labels, audio_inputs, audio_lens, pixel_values, spk_emb = zip(*batch)
为什么需要自定义 collate? 因为不同样本的 audio_inputs 长度不同,pixel_values 可能是字典或 tensor,PyTorch 默认的 collate 无法处理。
# 文本标签和音频标签:直接 stack
input_ids = torch.stack(input_ids) # (B, 9, T)
labels = torch.stack(labels) # (B, T)
audio_labels = torch.stack(audio_labels) # (B, 8, T)
audio_lens = torch.tensor(audio_lens, dtype=torch.long) # (B,)
# 音频输入:不同长度,需要 pad 到最长
valid_audios = [a for a in audio_inputs if a is not None]
if valid_audios:
max_t = max(a.size(1) for a in valid_audios) # 找最长的时间帧数
padded = [a if a.size(1) == max_t
else torch.nn.functional.pad(a, (0, 0, 0, max_t - a.size(1)))
for a in valid_audios] # 短的填充零
audio_inputs = torch.cat(padded, dim=0) # (valid_B, max_T, 560)
else:
audio_inputs = None
# 图像输入:字典格式需要按键合并
valid_images = [p for p in pixel_values if p is not None]
if valid_images:
if hasattr(valid_images[0], 'keys'): # 字典格式
keys = set.intersection(*[set(d.keys()) for d in valid_images])
pixel_values = {k: torch.cat([d[k] for d in valid_images], dim=0) for k in keys}
else: # tensor 格式
pixel_values = torch.cat(valid_images, dim=0)
else:
pixel_values = None
spk_emb = torch.stack(spk_emb) # (B, 192)
return input_ids, labels, audio_labels, audio_inputs, audio_lens, pixel_values, spk_emb
2️⃣ train_epoch 训练一个 epoch(第51-136行)⭐ 核心训练循环
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
start_time = time.time()
last_step = start_step
训练循环主体
for step, (input_ids, labels, audio_labels, audio_inputs, audio_lens, pixel_values, spk_emb) in enumerate(loader, start=start_step + 1):
Step A: 数据移到 GPU
input_ids = input_ids.to(args.device)
labels = labels.to(args.device)
audio_labels = audio_labels.to(args.device)
audio_lens = audio_lens.to(args.device)
if audio_inputs is not None:
audio_inputs = audio_inputs.to(args.device)
if pixel_values is not None:
if hasattr(pixel_values, 'keys'): # 字典
pixel_values = {k: v.to(args.device) for k, v in pixel_values.items()}
else:
pixel_values = pixel_values.to(args.device)
spk_emb = spk_emb.to(args.device)
Step B: 更新学习率
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
Step C: 前向传播(在混合精度上下文中)
with autocast_ctx:
res = model(input_ids, audio_inputs=audio_inputs, audio_lens=audio_lens,
pixel_values=pixel_values, spk_emb=spk_emb)
Step D: 计算 Loss
loss_fct = nn.CrossEntropyLoss(reduction='none') # 不自动求均值
# === 文本损失 ===
text_loss_raw = loss_fct(res.logits.view(-1, res.logits.size(-1)), labels.view(-1))
text_mask = (labels.view(-1) != -100).float() # 只计算有效位置
text_loss = (text_loss_raw * text_mask).sum() / (text_mask.sum() + 1e-9)
# === 音频损失(8层独立计算)===
audio_loss = res.audio_logits[0].sum() * 0 # 初始化为 0(保持计算图)
for i, al in enumerate(res.audio_logits):
al_flat = al.view(-1, al.size(-1)) # (B*T, 2112)
target_flat = audio_labels[:, i, :].reshape(-1) # (B*T,)
layer_loss = loss_fct(al_flat, target_flat) # 逐位置损失
valid_mask = (target_flat != -100).float() # 有效位置
stop_mask = (target_flat == 2050).float() # stop token 位置
weighted_loss = layer_loss * valid_mask * (1 + stop_mask * 9) # stop 权重 10x!
msum = valid_mask.sum()
if msum > 0:
audio_loss = audio_loss + weighted_loss.sum() / (msum + 1e-9)
audio_loss = audio_loss / 8 # 8 层取平均
stop token 10 倍权重:
普通 token 权重 = 1 + 0 * 9 = 1
stop token 权重 = 1 + 1 * 9 = 10
- stop token 是音频生成的终止信号
- 如果模型错过 stop → 无限生成噪音
- 所以需要重点训练模型识别 stop
# === 总损失 ===
loss = (text_loss + audio_loss + res.aux_loss) / args.accumulation_steps
- 除以
accumulation_steps是为了梯度累积 - 等效 batch_size = batch_size × accumulation_steps
Step E: 反向传播 + 梯度累积
scaler.scale(loss).backward() # 放大梯度(float16 用)
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer) # 缩放回来
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) # 梯度裁剪
scaler.step(optimizer) # 更新参数
scaler.update() # 更新 scaler
optimizer.zero_grad(set_to_none=True) # 清零梯度(set_to_none=True 更快)
set_to_none=True:不把梯度设为零张量,而是设为 None。这样 PyTorch 可以跳过一些不必要的计算。
Step F: 日志记录
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps # 恢复真实 loss
text_loss_val = text_loss.item()
audio_loss_val = audio_loss.item()
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60 # 预估剩余时间
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}), '
f'loss: {current_loss:.4f}, text: {text_loss_val:.4f}, '
f'audio: {audio_loss_val:.4f}, lr: {current_lr:.08f}, epoch_time: {eta_min:.1f}min')
Step G: 保存 Checkpoint
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
# ... 保存权重 ...
model.train()
Step H: 释放显存
del input_ids, labels, audio_labels, audio_inputs, audio_lens, pixel_values, spk_emb, res, loss
- 手动删除不再需要的张量
- 帮助 Python GC 和 PyTorch 缓存分配器回收显存
3️⃣ 命令行参数(第139-167行)
完整参数表:
| 参数 | 默认值 | 说明 |
|---|---|---|
--save_dir | ../out | 模型保存目录 |
--save_weight | sft_omni | 保存权重前缀 |
--epochs | 15 | 训练轮数 |
--batch_size | 32 | batch 大小 |
--learning_rate | 5e-4 | 学习率 |
--device | cuda:0 | 训练设备 |
--dtype | bfloat16 | 混合精度类型 |
--num_workers | 4 | DataLoader 线程数 |
--accumulation_steps | 1 | 梯度累积步数 |
--grad_clip | 1.0 | 梯度裁剪阈值 |
--log_interval | 100 | 日志间隔 |
--save_interval | 1000 | 保存间隔 |
--hidden_size | 768 | 隐藏维度 |
--num_hidden_layers | 8 | 层数 |
--max_seq_len | 512 | 最大序列长度 |
--use_moe | 0 | 是否 MoE |
--data_path | ../dataset/sft_t2a_mini.parquet | 数据路径 |
--from_weight | llm | 初始权重名 |
--from_resume | 0 | 是否续训 |
--freeze_backbone | none | 冻结策略 |
--mode | all | 训练模式 |
--use_compile | 0 | 是否 torch.compile |
--mode 的三个选项:
| 模式 | 冻结 | 只训练 | 用途 |
|---|---|---|---|
all | 无 | 全部 | Step 1/3 全量微调 |
audio_proj | 全部 | audio_proj | Step 2 音频对齐 |
vision_proj | 全部 | vision_proj | 视觉对齐(未用) |
4️⃣ 主流程(第169-262行)
# ========== 1. 初始化 ==========
local_rank = init_distributed_mode()
setup_seed(42 + (...))
# ========== 2. 配置模型 ==========
omni_config = OmniConfig(hidden_size=768, num_hidden_layers=8, use_moe=False)
ckp_data = omni_checkpoint(...) if args.from_resume else None # 续训时加载
# ========== 3. 混合精度 ==========
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if "cpu" in args.device else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Wandb ==========
if args.use_wandb and is_main_process():
import swanlab as wandb # 用 swanlab 替代 wandb
wandb.init(...)
# ========== 5. 模型 ==========
model, tokenizer = init_omni_model(omni_config, from_weight=args.from_weight, ...)
if args.use_compile == 1:
model = torch.compile(model) # 编译加速
# 手动移动冻结的编码器到 GPU
if model.audio_encoder is not None: model.audio_encoder.to(args.device)
if model.vision_encoder is not None: model.vision_encoder.to(args.device)
# mode 特殊处理
if args.mode == 'audio_proj':
for p in model.parameters(): p.requires_grad = False
for p in model.audio_proj.parameters(): p.requires_grad = True
# ========== 6. 数据集 ==========
train_ds = OmniDataset(args.data_path, tokenizer, ...)
# ========== 7. 优化器 ==========
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 8. 续训恢复 ==========
if ckp_data:
model.load_state_dict(ckp_data['model'], strict=False)
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch, start_step = ckp_data['epoch'], ckp_data.get('step', 0)
# ========== 9. DDP 包装 ==========
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 10. 训练循环 ==========
for epoch in range(start_epoch, args.epochs):
setup_seed(42 + epoch)
batch_sampler = SkipBatchSampler(...)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, collate_fn=omni_collate_fn,
num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), start_step, wandb)
📋 本章关键概念速查
三阶段训练流程
Step 1: sft_t2a (文本→音频对齐)
数据: sft_t2a_mini.parquet
模式: all (全量训练)
权重: llm_768.pth → sft_zero_768.pth
目标: 学会"说话"(文本→音频码的映射)
Step 2: audio_proj (音频理解对齐)
数据: sft_a2a_mini.parquet
模式: audio_proj (只训练投影层)
权重: sft_zero_768.pth → sft_zero_768.pth (覆盖)
目标: 学会"听"(音频特征→LLM空间的对齐)
Step 3: 全参数微调
数据: sft_a2a_mini.parquet
模式: all (全量训练)
学习率: 2e-5 (比Step1小25倍)
目标: 联合优化"听"和"说"
损失函数组成
total_loss = text_loss + audio_loss + aux_loss
↓ ↓ ↓
文本预测 8层音频预测 MoE路由均衡
(CrossEntropy) (stop token 10x权重) (负载均衡)
✅ 第四章完成 | 下一篇:推理 + WebUI 代码
转载自 CSDN-专业IT技术社区
原文链接:https://blog.csdn.net/liukanghao/article/details/162153106



