PyTorch实战:从混淆矩阵到F1,手把手实现多分类评估

📅 发布时间:2026/7/5 17:20:55 👁️ 浏览次数:
PyTorch实战:从混淆矩阵到F1,手把手实现多分类评估
1. 为什么我们需要这些“花里胡哨”的评估指标如果你刚开始接触机器学习分类任务可能会觉得模型预测完了算个准确率Accuracy不就完事了吗比如100张图片模型对了90张准确率90%听起来很不错。我以前也是这么想的直到在一个实际项目里踩了坑。那个项目是识别工厂流水线上的产品缺陷总共有10种缺陷类型。我训练了一个模型整体准确率高达95%兴冲冲地拿去给工程师看。结果人家现场一测发现模型把最严重、但出现频率很低的一种裂纹缺陷A类几乎全部漏检了因为A类缺陷在数据集中只占2%模型为了追求高准确率干脆倾向于不预测它反正猜错了也只影响2%的准确率。这时候95%的准确率就是个“虚荣指标”完全掩盖了模型在关键问题上的无能。这就是只依赖准确率的陷阱。在类别不平衡某些类别的样本特别少或者不同类别的错误代价不同比如医疗诊断中把有病判成没病比把没病判成有病严重得多的场景下准确率会严重失真。我们需要更精细的“手术刀”来解剖模型的性能。这套工具就是混淆矩阵Confusion Matrix以及由它衍生出的精确率Precision、召回率Recall和F1分数F1-Score。你可以这样理解它们的分工混淆矩阵是“体检报告”的原始数据它清清楚楚地告诉你模型在每个类别上预测对了多少预测错了多少以及错成了什么样子。而精确率、召回率和F1分数则是医生根据这份报告给你计算的几个关键健康指数。精确率Precision“宁缺毋滥”的挑剔鬼。对于模型预测为“A类”的所有样本精确率关心的是其中有多少是真正的“A类”。它衡量的是模型预测的精准度。精确率高意味着模型一旦说某个东西是A那它大概率真的是A很少“冤枉好人”。在我们缺陷检测的例子中对于A类缺陷我们希望模型预测的精确率越高越好否则会浪费大量人力去复查那些被误报的正常产品。召回率Recall“宁可错杀不可放过”的捕快。对于所有真正的“A类”样本召回率关心的是模型成功找出了多少。它衡量的是模型发现的覆盖率。召回率高意味着真实的A类缺陷绝大部分都被模型抓出来了很少“漏网之鱼”。在缺陷检测中召回率低是致命的意味着有缺陷的产品可能流向下游。F1分数F1-Score“端水大师”。精确率和召回率经常像跷跷板一个高了另一个就容易低。F1分数是它们的调和平均数目的是找到一个平衡点。当精确率和召回率都重要且你需要一个单一指标来综合评判时F1分数就非常有用。所以别再只盯着准确率了。接下来我就手把手带你在PyTorch里从最基础的张量操作开始一步步构建出这套完整的评估体系让你彻底搞懂每个数字是怎么算出来的。2. 基石亲手用PyTorch构建混淆矩阵一切评估的开始都源于混淆矩阵。很多教程会直接让你调用sklearn.metrics.confusion_matrix这当然方便但就像开车只懂踩油门和刹车不懂发动机原理遇到复杂路况就容易懵。我们今天要做的是用PyTorch最核心的张量操作之一——scatter函数来“徒手”打造一个动态生成混淆矩阵的函数。理解了这一步后面所有的计算你都会觉得顺理成章。2.1 理解混淆矩阵一张“对错账单”假设我们在做一个3分类任务猫、狗、鸟。模型对5个样本的预测结果和真实标签如下真实标签targets:[猫, 狗, 鸟, 狗, 猫]模型预测predictions:[猫, 猫, 鸟, 狗, 鸟]对应的混淆矩阵是一个3x3的表格真实 \ 预测预测为猫预测为狗预测为鸟真实为猫1(猫-猫对)01 (猫-鸟错)真实为狗1 (狗-猫错)1(狗-狗对)0真实为鸟001(鸟-鸟对)看懂了吗矩阵的行代表真实的类别列代表模型预测的类别。对角线上的数字加粗部分就是预测正确的样本数。其他位置则是各种错误的情况。比如第一行第三列的“1”就表示有1个真实是“猫”的样本被模型错误地预测成了“鸟”。我们的目标就是写一个函数输入predictions和targets这两个一维张量输出这样一个n_classes x n_classes的矩阵。2.2 核心武器深入理解scatter_函数这是实现的关键也是很多初学者觉得抽象的地方。官方文档的说法比较学术我用自己的经验给你翻译一下。scatter_(dim, index, src)这个操作可以想象成一次“定点填充”self: 待填充的目标张量通常是全零张量。dim: 指定在哪个维度上进行“索引定位”。index: 一个和src形状相同的张量里面的每个值都指明了在dim这个维度上要把src里对应的值放到self的哪个位置。src: 源数据可以是张量也可以是一个标量值。最常用的模式是dim1。我们来看一个我调试过无数次的例子保证你过目不忘import torch # 假设我们有4个样本3个类别 batch_size 4 n_classes 3 # 模型预测的类别索引0, 1, 2 predicted torch.tensor([0, 2, 1, 0]) # 形状: (4,) # 我们要生成一个4x3的one-hot编码矩阵 # 第一步创建一个全零的“画布” pre_mask torch.zeros(batch_size, n_classes) # 形状: (4, 3) print(初始画布 pre_mask:\n, pre_mask) # 第二步准备索引。scatter要求index和src形状匹配。 # predicted.view(-1, 1) 把 [0, 2, 1, 0] 变成 [[0], [2], [1], [0]]形状(4,1) index predicted.view(-1, 1) print(索引 index:\n, index) # 第三步执行填充。dim1表示在第二个维度列方向上根据index的值确定位置。 # src1 表示在所有指定位置填充数字1。 pre_mask.scatter_(dim1, indexindex, src1) print(填充后的 pre_mask (one-hot):\n, pre_mask)运行这段代码你会看到初始画布 pre_mask: tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]) 索引 index: tensor([[0], [2], [1], [0]]) 填充后的 pre_mask (one-hot): tensor([[1., 0., 0.], # 第0个样本在第0列填1 [0., 0., 1.], # 第1个样本在第2列填1 [0., 1., 0.], # 第2个样本在第1列填1 [1., 0., 0.]]) # 第3个样本在第0列填1看明白了吗dim1时规则是对于pre_mask中的第i行去index[i]找到列坐标然后在该位置填上src的值。index[i]的值是0就在第i行第0列填1值是2就在第i行第2列填1。这样我们就把类别索引高效地转换成了one-hot编码矩阵。这个操作是向量化的速度极快比用for循环一个个赋值优雅多了。2.3 实现混淆矩阵函数理解了scatter_构建混淆矩阵就水到渠成了。混淆矩阵的每个位置(i, j)统计的是真实类别为i且预测类别为j的样本数。我们可以利用scatter_进行两次“升维”统计。def confusion_matrix(predictions, targets, num_classes): 计算混淆矩阵 Args: predictions (Tensor): 模型预测的类别索引形状 [batch_size] targets (Tensor): 真实的类别索引形状 [batch_size] num_classes (int): 类别总数 Returns: conf_matrix (Tensor): 混淆矩阵形状 [num_classes, num_classes] # 确保输入是一维的 predictions predictions.flatten() targets targets.flatten() # 初始化一个全零的混淆矩阵 conf_matrix torch.zeros(num_classes, num_classes, dtypetorch.long) # 核心思想将每个样本的 (真实类别, 预测类别) 对视为混淆矩阵中的一个坐标点。 # 我们需要把这个坐标点“投射”到二维矩阵上并给这个位置的值1。 # 我们可以把整个batch的样本一次性“画”到矩阵上。 # 方法使用 scatter_add_或两次 scatter_ # 思路一利用线性索引更高效但稍复杂 # 思路二利用 one-hot 编码相乘直观易于理解 # 这里采用更直观的第二种方法它和我们后续计算其他指标的逻辑一脉相承。 # 1. 生成预测的one-hot矩阵 [batch_size, num_classes] pred_one_hot torch.zeros(predictions.size(0), num_classes, dtypetorch.long) pred_one_hot.scatter_(1, predictions.view(-1, 1), 1) # dim1, 按列索引填充1 # 2. 生成真实的one-hot矩阵 [batch_size, num_classes] targ_one_hot torch.zeros(targets.size(0), num_classes, dtypetorch.long) targ_one_hot.scatter_(1, targets.view(-1, 1), 1) # dim1, 按列索引填充1 # 3. 关键步骤计算外积的“批量”版本 # 其实更简单混淆矩阵的 (i, j) 元素 sum( (真实类别i) 且 (预测类别j) ) # 这等价于conf_matrix[i, j] (targ_one_hot[:, i] * pred_one_hot[:, j]).sum() # 我们可以用矩阵乘法优雅地实现 # (targ_one_hot.T) pred_one_hot # 因为 targ_one_hot[:, i] 是第i列转置后是第i行。 # 矩阵乘法的第i行第j列元素就是 targ_one_hot[:, i] 与 pred_one_hot[:, j] 的内积即我们要的求和。 conf_matrix targ_one_hot.T pred_one_hot return conf_matrix让我们用之前的猫狗鸟例子测试一下targets torch.tensor([0, 1, 2, 1, 0]) # 猫(0),狗(1),鸟(2),狗(1),猫(0) preds torch.tensor([0, 0, 2, 1, 2]) # 猫(0),猫(0),鸟(2),狗(1),鸟(2) cm confusion_matrix(preds, targets, num_classes3) print(混淆矩阵:\n, cm)输出应该就是我们在2.1节画出来的那个矩阵。有了这个基石我们就可以在上面轻松地计算出所有评估指标了。3. 从混淆矩阵到精确率、召回率与F1现在我们手里有了一张清晰的“对错账单”混淆矩阵就可以开始算账了。记住对于多分类问题我们通常是为每个类别单独计算精确率、召回率和F1这样就能一眼看出模型在哪个类别上表现好哪个类别上拉胯。3.1 拆解计算公式用生活例子理解我们继续用3分类的混淆矩阵CM为例假设类别索引是0, 1, 2。对于类别0真正例TP真实是0预测也是0。就是CM[0, 0]。假正例FP真实不是0但预测是0。就是CM[1, 0] CM[2, 0]第0列除了对角线以外的和。假反例FN真实是0但预测不是0。就是CM[0, 1] CM[0, 2]第0行除了对角线以外的和。所以类别0的精确率 P0 TP0 / (TP0 FP0) CM[0,0] / (CM[:,0].sum())类别0的召回率 R0 TP0 / (TP0 FN0) CM[0,0] / (CM[0,:].sum())看到了吗计算某个类别的精确率是看矩阵的“列”计算召回率是看矩阵的“行”。这个对应关系非常直观建议你背下来。3.2 批量计算所有类别的指标在代码里我们当然不会用一个for循环去算每个类别那样太低效了。PyTorch的广播和向量化操作可以让我们一次性算完。def calculate_metrics(conf_matrix): 根据混淆矩阵计算各类别的精确率、召回率、F1及整体准确率 Args: conf_matrix (Tensor): 混淆矩阵形状 [n, n] Returns: dict: 包含各类指标字典 n_classes conf_matrix.size(0) # 确保是浮点数方便除法 conf_matrix conf_matrix.float() # 1. 计算每个类别的TP, FP, FN # TP: 对角线元素 tp_per_class conf_matrix.diag() # 形状 [n_classes] # FP: 各列求和减去TP fp_per_class conf_matrix.sum(dim0) - tp_per_class # 形状 [n_classes] # FN: 各行求和减去TP fn_per_class conf_matrix.sum(dim1) - tp_per_class # 形状 [n_classes] # 2. 计算精确率和召回率 (处理除零情况) # 添加一个极小值epsilon防止除以0 epsilon 1e-7 precision_per_class tp_per_class / (tp_per_class fp_per_class epsilon) recall_per_class tp_per_class / (tp_per_class fn_per_class epsilon) # 3. 计算F1分数 f1_per_class 2 * (precision_per_class * recall_per_class) / (precision_per_class recall_per_class epsilon) # 4. 计算整体准确率 total_correct tp_per_class.sum() total_samples conf_matrix.sum() accuracy total_correct / total_samples # 5. 计算宏平均Macro-average和微平均Micro-average # 宏平均先对每个类别计算指标再求平均。平等看待每个类别。 macro_precision precision_per_class.mean() macro_recall recall_per_class.mean() macro_f1 f1_per_class.mean() # 微平均先汇总所有类别的TP, FP, FN再计算一个总的指标。受大类别影响大。 total_tp tp_per_class.sum() total_fp fp_per_class.sum() total_fn fn_per_class.sum() micro_precision total_tp / (total_tp total_fp epsilon) micro_recall total_tp / (total_tp total_fn epsilon) micro_f1 2 * (micro_precision * micro_recall) / (micro_precision micro_recall epsilon) metrics { confusion_matrix: conf_matrix, class_precision: precision_per_class, class_recall: recall_per_class, class_f1: f1_per_class, accuracy: accuracy, macro_precision: macro_precision, macro_recall: macro_recall, macro_f1: macro_f1, micro_precision: micro_precision, micro_recall: micro_recall, micro_f1: micro_f1, } return metrics这段代码有几个我踩过坑才学到的细节除零保护epsilon是必须的。当一个类别在验证集中没有出现真实样本数为0或者模型从未预测过这个类别预测样本数为0时分母会为0导致NaN。加一个极小值可以稳定计算。宏平均 vs 微平均这是多分类评估中非常重要的概念。在类别不平衡时它们给出的信号可能截然不同。宏平均把所有类别一视同仁。即使某个小类别只有10个样本它在平均值中的权重也和有1000个样本的大类别一样。这能反映模型在所有类别上的整体表现适合你关心每个类别性能的场景。微平均汇总所有类别的统计量再计算。由于大类别的样本数多它们的统计量会主导最终结果。这个指标更接近整体准确率的视角。 选择哪个取决于你的任务。比如在缺陷检测中我们极度关心那个只占2%的A类缺陷那么宏平均F1就更重要。4. 实战整合一个完整的PyTorch验证循环示例理论说再多不如一行代码。下面我把上面的所有片段整合到一个完整的模型验证函数里并加上详细的注释和我在实际项目中总结的最佳实践。import torch from torch.utils.data import DataLoader def evaluate_model(model, dataloader, device, num_classes): 在给定数据集上评估模型返回详细的分类指标。 Args: model: 待评估的PyTorch模型 dataloader: 数据加载器 (验证集或测试集) device: 计算设备 (cuda 或 cpu) num_classes: 分类类别数 Returns: dict: 包含所有评估指标的字典 model.eval() # 将模型设置为评估模式关闭Dropout等 # 初始化一个全零的混淆矩阵用于累积整个数据集的结果 conf_matrix torch.zeros(num_classes, num_classes, dtypetorch.long).to(device) with torch.no_grad(): # 关闭梯度计算节省内存和计算 for batch_idx, (inputs, targets) in enumerate(dataloader): inputs, targets inputs.to(device), targets.to(device) # 前向传播 outputs model(inputs) # 获取预测类别最大概率的索引 _, predictions torch.max(outputs, dim1) # --- 核心批量更新混淆矩阵 (高效版本) --- # 我们不再为每个batch生成one-hot而是直接累加。 # 利用线性索引将二维坐标 (target, prediction) 映射到一维。 # 公式index target * num_classes prediction # 然后使用 torch.bincount 统计每个索引出现的次数再重塑为矩阵。 # 但更简单的方法是直接使用我们之前写的函数但需要处理device。 # 这里展示另一种更PyTorch风格的方式直接操作张量 # 确保 predictions 和 targets 在同一设备上 # 构建一个 (batch_size, 2) 的索引对 indices torch.stack([targets, predictions], dim1) # shape: [batch_size, 2] # 将这些索引对加到混淆矩阵中 # 注意由于可能有重复的(target, pred)对我们需要用index_add_或者用bincount # 这里使用 for 循环简单明了对于不是特别大的batch和类别数可以接受 for t, p in indices: conf_matrix[t.long(), p.long()] 1 # 将累积的混淆矩阵移回CPU进行计算如果之前用了GPU conf_matrix conf_matrix.cpu() # 调用我们之前写好的函数计算所有指标 metrics calculate_metrics(conf_matrix) # 打印一份美观的报表 print(\n *60) print(模型评估报告) print(*60) print(f整体准确率: {metrics[accuracy]:.4f}) print(f宏平均精确率: {metrics[macro_precision]:.4f}) print(f宏平均召回率: {metrics[macro_recall]:.4f}) print(f宏平均F1分数: {metrics[macro_f1]:.4f}) print(f微平均F1分数: {metrics[micro_f1]:.4f}) print(-*60) print(各类别详细指标:) for i in range(num_classes): print(f 类别 {i}: fPrecision{metrics[class_precision][i]:.4f}, fRecall{metrics[class_recall][i]:.4f}, fF1{metrics[class_f1][i]:.4f}) print(*60) # 可以选择打印混淆矩阵当类别数不多时 if num_classes 10: print(\n混淆矩阵 (行: 真实标签, 列: 预测标签):) print(metrics[confusion_matrix].int().numpy()) return metrics # 假设你已经有 model, val_loader, device, num_classes # metrics evaluate_model(model, val_loader, device, num_classes)这个evaluate_model函数可以直接嵌入你的项目。它会在每个batch中动态更新混淆矩阵最后一次性计算出所有指标。打印出来的报告非常清晰能帮你快速定位模型的问题。比如如果发现某个类别的召回率特别低你就知道下一步该去增加这个类别的训练数据或者尝试一些针对类别不平衡的损失函数如Focal Loss。5. 避坑指南与高级技巧自己实现评估流程最大的好处是灵活和深度可控。这里分享几个我趟过的雷和对应的解决方案。坑1指标波动与不可比性在训练过程中你可能会在验证集上计算这些指标。但要注意如果你的验证集采用随机采样且样本不多每次验证的指标可能会有较大波动。这不一定代表模型性能不稳定可能是数据分布的小幅随机性导致的。解决方案固定验证集的随机种子或者使用交叉验证来获得更稳定的评估。坑2多标签与多分类的混淆我们讨论的是单标签多分类即一个样本只属于一个类别。如果你的任务是多标签分类一个样本可以属于多个类别比如一张图片同时有“猫”和“沙发”两个标签那么混淆矩阵和这里的计算方式完全不适用。多标签任务通常使用基于阈值的指标如平均精度Average Precision。坑3scatter_函数索引越界这是新手常犯的错误。scatter_(dim, index, src)中的index张量其值必须在目标张量self在dim维度的大小范围内。比如self形状是(4, 10)dim1那么index里的每个值必须在0到9之间。如果预测的类别索引是10就会报错。解决方案在数据预处理和模型输出层确保类别索引从0开始且连续。高级技巧使用torchmetrics库虽然我们从头实现了一遍但在生产环境或快速原型中我强烈推荐使用torchmetrics库。它封装了各种指标的计算支持分布式训练并且自动处理设备移动和数值稳定性。使用起来非常简单from torchmetrics import Accuracy, Precision, Recall, F1Score, ConfusionMatrix # 为多分类任务初始化指标 num_classes 10 accuracy Accuracy(taskmulticlass, num_classesnum_classes) precision Precision(taskmulticlass, num_classesnum_classes, averagemacro) recall Recall(taskmulticlass, num_classesnum_classes, averagemacro) f1 F1Score(taskmulticlass, num_classesnum_classes, averagemacro) confmat ConfusionMatrix(taskmulticlass, num_classesnum_classes) # 在验证循环中更新 for inputs, targets in val_loader: outputs model(inputs) preds outputs.argmax(dim1) accuracy.update(preds, targets) precision.update(preds, targets) recall.update(preds, targets) f1.update(preds, targets) confmat.update(preds, targets) # 计算最终结果 final_acc accuracy.compute() final_precision precision.compute() # ... 其他指标torchmetrics的好处是代码简洁且内部优化得很好。但了解其背后的原理能让你在它出问题比如对某个特殊任务支持不好时有能力自己动手写一个或者更好地理解它输出的结果。这就是为什么我坚持要带你先从底层实现走一遍的原因。