OpenGait(步态识别框架)的配置项说明

📅 发布时间:2026/7/5 11:09:08 👁️ 浏览次数:
OpenGait(步态识别框架)的配置项说明
一、核心配置模块解读1. data_cfg数据配置核心作用定义数据集来源、加载方式、测试集等基础数据参数。表格参数说明示例dataset_name训练数据集名称仅支持 CASIA-B/OUMVLPCASIA-Bdataset_root数据集存储路径/data/CASIA-Bdataset_partition数据集划分文件划分训练 / 测试集./datasets/CASIA-B/CASIA-B.jsonnum_workers数据加载的线程数8根据 CPU 核心数调整cache是否将数据全加载到内存加速训练需大内存True/Falsetest_dataset_name测试数据集名称CASIA-B2. loss_cfg损失函数配置核心作用定义训练使用的损失函数支持 TripletLoss/CrossEntropyLoss支持多损失加权。表格参数说明示例type损失函数类型TripletLoss/CrossEntropyLossloss_term_weight损失权重多损失时调整各损失占比1.0TripletLoss、0.1CrossEntropyLosslog_prefix损失日志前缀便于区分不同损失triplet/softmaxmarginTripletLoss 专属三元组损失的边际值0.2scaleCrossEntropyLoss 专属交叉熵损失的缩放系数163. optimizer_cfg优化器配置核心作用定义优化器类型及参数对齐 PyTorch 原生优化器。表格参数说明示例solver优化器类型SGD/Adamlr学习率0.1SGD、1e-4Adammomentum动量SGD 专属0.9weight_decay权重衰减防止过拟合0.00054. scheduler_cfg学习率调度器核心作用定义学习率衰减策略对齐 PyTorch 原生调度器。表格参数说明示例scheduler调度器类型MultiStepLR/CosineAnnealingLRmilestonesMultiStepLR 专属学习率衰减的迭代节点[20000, 40000]gamma衰减系数0.1每次衰减为原学习率的 10%5. model_cfg模型配置核心作用定义模型结构需参考框架的 Model Library。表格核心参数说明示例model模型名称如 Baseline、GaitSetBaselinebackbone_cfg骨干网络配置通道数、层结构in_channels:1, layers_cfg: [BC-64, M, ...]bin_num特征分箱数步态特征编码[16,8,4,2,1]6. evaluator_cfg评估器配置核心作用定义模型评估的规则推理方式、指标、 checkpoint 加载等。表格关键参数说明示例restore_hint加载的 checkpoint 迭代数 / 路径60000加载第 6 万迭代的权重save_name实验名称用于输出目录Baseline_CASIA-Beval_func评估函数CASIA-B 用 identificationevaluate_indoor_datasetsampler推理采样器配置type: InferenceSampler, sample_type: all_orderedmetric距离计算方式euc 欧氏距离 /cos 余弦距离euctransform数据预处理切黑边 / 不切BaseSilCuttingTransform切黑边7. trainer_cfg训练器配置核心作用定义训练流程迭代数、采样器、 checkpoint 保存、BN 同步等。表格关键参数说明示例total_iter总训练迭代数60000log_iter日志打印间隔100每 100 迭代打印一次save_iter权重保存间隔10000每 1 万迭代保存一次sampler训练采样器TripletSamplerbatch_size: [8,16]8 个身份每个身份 16 个序列sample_type训练帧采样方式fixed_unordered固定帧数随机选帧sync_BN多卡同步 BNTrue多卡训练建议开启with_test训练中是否穿插测试False默认关闭避免拖慢训练二、关键参数重点说明1. 采样器sampler核心参数训练 / 评估的采样器是步态识别的核心需重点理解表格场景sample_type 取值含义训练fixed_unordered固定帧数如 30 帧随机选取无序训练unfixed_ordered帧数在 [min,max] 间随机按自然顺序选帧评估all_ordered用完整序列按自然顺序输入保证测试一致性训练 batch_size 格式为[P,K]P一个 batch 中的身份数如 8K每个身份的序列数如 16需结合硬件显存调整P×K 越大显存占用越高。2. 输出目录规则输出目录 output/${dataset_name}/${model}/${save_name}例如output/CASIA-B/Baseline/Baseline_CASIA-B包含log训练日志checkpoint模型权重summary可视化 / 评估结果。3. 优先级规则自定义配置会覆盖default.yaml中的默认配置需注意若未定义某参数自动使用default.yaml的默认值自定义参数与默认参数冲突时以自定义为准。三、配置文件编写规范1. 基础结构所有配置需按模块分层编写data_cfg/loss_cfg/...示例框架yamldata_cfg: dataset_name: CASIA-B dataset_root: /path/to/CASIA-B num_workers: 8 dataset_partition: ./datasets/CASIA-B/CASIA-B.json loss_cfg: - type: TripletLoss loss_term_weight: 1.0 margin: 0.2 log_prefix: triplet - type: CrossEntropyLoss loss_term_weight: 0.1 scale: 16 log_prefix: softmax # 其他模块optimizer/scheduler/model/evaluator/trainer按上述规则补充2. 适配不同数据集的注意点表格数据集关键调整项CASIA-Beval_func: identificationmetric: eucOUMVLP增大 batch_size如 [16,16]调整 total_iter如 1200003. 多卡训练配置需开启trainer_cfg.sync_BN: True并调整num_workers建议为卡数 ×4示例yamltrainer_cfg: sync_BN: True enable_float16: True # 混合精度训练节省显存 data_cfg: num_workers: 16 # 2卡×8四、常见问题与调优建议显存不足降低trainer_cfg.sampler.batch_size如从 [8,16] 改为 [4,8]开启enable_float16: True混合精度减小frames_num_fixed训练帧数如从 30 改为 20。训练不收敛调整 TripletLoss 的 margin如 0.2→0.1增大学习率SGD 从 0.1→0.2或调整权重衰减检查数据集划分文件是否正确dataset_partition。评估精度低评估时用sample_type: all_ordered完整序列切换 metriceuc/cosCASIA-B 优先 euc确保模型权重加载正确restore_hint 路径 / 迭代数无误。