RotatE实战:用Python复现知识图谱中的旋转嵌入模型(附代码)

📅 发布时间:2026/7/5 1:50:26 👁️ 浏览次数:
RotatE实战:用Python复现知识图谱中的旋转嵌入模型(附代码)
RotatE实战用Python复现知识图谱中的旋转嵌入模型附代码知识图谱作为结构化的语义知识库在智能搜索、推荐系统等领域扮演着核心角色。然而现实中的知识图谱往往存在大量缺失的链接如何精准预测这些未知关系一直是工业界和学术界共同关注的难题。传统的基于规则或统计的方法往往力不从心而知识表示学习特别是嵌入模型为我们提供了一条优雅的解决路径。它将实体和关系映射到连续的向量空间通过向量运算来建模和推断潜在的事实。在众多嵌入模型中RotatE以其简洁而强大的数学形式脱颖而出。它将关系建模为复向量空间中的旋转这一灵感来源于欧拉公式不仅数学上优美更在理论上证明能够同时建模对称、反对称、逆反和组合等多种复杂的关系模式。对于希望将前沿算法落地应用的开发者而言理解其原理并亲手实现是掌握其精髓的最佳方式。本文正是为这样的实践者准备的。我们将完全从实战角度出发跳过冗长的理论推导直接进入代码层面。假设你已具备基本的Python编程能力和对深度学习框架如PyTorch的初步了解我们将一步步搭建完整的RotatE模型涵盖从环境搭建、数据预处理、模型定义、训练技巧到结果可视化的全流程。你会发现这个看似高深的模型其核心代码可能比你想象的要简洁得多。1. 环境准备与数据理解在开始编写模型之前确保有一个干净、可复现的开发环境是成功的第一步。我们推荐使用conda或venv创建独立的Python环境以避免包版本冲突。1.1 创建环境与安装依赖首先创建一个新的conda环境并激活它。这里我们使用Python 3.8这是一个在稳定性和兼容性之间取得良好平衡的版本。conda create -n rotate_env python3.8 -y conda activate rotate_env接下来安装核心的深度学习框架和科学计算库。PyTorch因其动态图特性在研究和原型开发中非常友好。根据你的CUDA版本如果有GPU去PyTorch官网获取对应的安装命令。以下以CPU版本为例pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install numpy pandas tqdm scikit-learn matplotlib提示如果拥有NVIDIA GPU并已安装CUDA强烈建议安装对应的PyTorch GPU版本以大幅加速训练过程。可以访问 PyTorch 官网获取准确的安装命令。1.2 认识知识图谱数据集知识图谱数据通常以三元组(头实体, 关系, 尾实体)的形式存储。我们将使用一个经典的基准数据集FB15k-237的子集作为示例。这个数据集来自Freebase包含了实体间的各种复杂关系。数据文件通常包含三个部分train.txt: 用于模型训练的三元组。valid.txt: 用于在训练过程中调整超参数、进行早停的验证集。test.txt: 用于最终评估模型泛化能力的测试集。每一行都是一个三元组格式为头实体\t关系\t尾实体。我们的首要任务是将这些文本形式的实体和关系映射为模型可以处理的整数索引ID。import codecs def load_data(file_path): 加载三元组数据并构建实体和关系的词汇表。 entities, relations set(), set() triples [] with codecs.open(file_path, r, encodingutf-8) as f: for line in f: h, r, t line.strip().split(\t) triples.append((h, r, t)) entities.update([h, t]) relations.add(r) # 创建从字符串到ID的映射字典 entity2id {e: i for i, e in enumerate(sorted(entities))} relation2id {r: i for i, r in enumerate(sorted(relations))} # 将三元组转换为ID形式 triples_id [(entity2id[h], relation2id[r], entity2id[t]) for h, r, t in triples] return triples_id, entity2id, relation2id # 示例加载训练集 train_triples, e2id, r2id load_data(data/FB15k-237/train.txt) print(f实体数量: {len(e2id)}) print(f关系数量: {len(r2id)}) print(f训练三元组数量: {len(train_triples)})运行上述代码后你将得到实体和关系的总数以及被转换为整数ID的三元组列表。这是所有后续操作的基础。2. RotatE模型的核心实现RotatE的核心思想非常直观在复向量空间中将头实体向量乘以关系向量其模长为1得到的结果应该接近尾实体向量。这个乘法操作在复数域中即表现为旋转。2.1 复数运算与距离函数在PyTorch中我们可以利用其内置的复数张量支持torch.complex64或torch.complex128来优雅地实现。但为了更清晰地展示原理我们也可以将复数实部和虚部分开存储手动实现运算。这里我们采用PyTorch原生复数张量它更高效且代码更简洁。首先我们需要定义RotatE的距离函数。对于一个正样本三元组(h, r, t)RotatE期望t ≈ h ◦ r其中◦是逐元素乘法Hadamard积h, r, t均为复数向量且关系向量r的每个分量的模|r_i| 1。距离函数通常采用模长之差d(h, r, t) || h ◦ r - t ||在实际的损失函数中我们使用L1或L2范数。以下是使用L2范数欧氏距离的实现import torch import torch.nn as nn import torch.nn.functional as F class RotatE(nn.Module): def __init__(self, num_entities, num_relations, embedding_dim, margin6.0): 初始化RotatE模型。 Args: num_entities: 实体总数 num_relations: 关系总数 embedding_dim: 嵌入向量的维度复数维度实际存储大小为2*embedding_dim margin: 间隔损失函数中的margin参数 super(RotatE, self).__init__() self.num_entities num_entities self.num_relations num_relations self.embedding_dim embedding_dim self.margin margin # 实体嵌入每个实体对应一个复数向量 # 使用实部虚部分开初始化然后组合成复数 self.entity_emb_real nn.Embedding(num_entities, embedding_dim) self.entity_emb_imag nn.Embedding(num_entities, embedding_dim) # 关系嵌入每个关系对应一个复数向量其模长被约束为1 # 我们存储关系旋转的角度θ关系向量 r e^(iθ) cosθ i sinθ self.relation_angle nn.Embedding(num_relations, embedding_dim) # 初始化参数 self._init_weights() def _init_weights(self): # Xavier初始化实体嵌入 nn.init.xavier_uniform_(self.entity_emb_real.weight) nn.init.xavier_uniform_(self.entity_emb_imag.weight) # 将关系角度初始化为[-pi, pi]之间的均匀分布 nn.init.uniform_(self.relation_angle.weight, -3.14, 3.14) def get_entity_embedding(self, idx): 获取复数形式的实体嵌入 real self.entity_emb_real(idx) imag self.entity_emb_imag(idx) return torch.complex(real, imag) def get_relation_embedding(self, idx): 获取复数形式的关系嵌入模长强制为1 angle self.relation_angle(idx) # r cosθ i sinθ return torch.complex(torch.cos(angle), torch.sin(angle)) def forward(self, head, relation, tail, modesingle): 计算三元组的得分距离。得分越低表示三元组成立的可能性越大。 Args: head: 头实体ID [batch_size] relation: 关系ID [batch_size] tail: 尾实体ID [batch_size] mode: single 计算给定三元组的距离head_batch 或 tail_batch 用于链接预测评估 Returns: score: 三元组的距离得分 [batch_size] h self.get_entity_embedding(head) # [batch_size, embed_dim] r self.get_relation_embedding(relation) # [batch_size, embed_dim] t self.get_entity_embedding(tail) # [batch_size, embed_dim] # RotatE的核心操作h ◦ r rotated_h h * r # 复数逐元素乘法 # 计算旋转后的头实体与尾实体之间的L2距离 distance torch.abs(rotated_h - t) # [batch_size, embed_dim] score torch.norm(distance, p2, dim-1) # [batch_size] return score在上面的代码中我们通过关系角度θ来参数化关系嵌入并利用cosθ i sinθ确保其模长恒为1这比直接优化一个复数向量并添加模长约束要简单稳定。2.2 自对抗负采样技术原始论文中一个关键的创新点是自对抗负采样。传统的负采样是均匀随机的但训练后期很多负样本过于“简单”得分很高对模型提升没有帮助。自对抗负采样根据当前模型为每个候选负样本生成的概率来采样更关注那些“难以区分”的负样本即模型当前认为可能成立但实际上不成立的三元组从而提供更有信息量的梯度。其采样概率公式为p((h, r, t) | (h, r, t)) ∝ exp(α * f(h, r, t))其中f是模型得分函数距离的负数α是温度参数。在实现时完全按照分布采样开销较大。论文采用了一种巧妙的加权损失方法仍然均匀采样一批负样本但在计算损失时用上述概率对每个负样本的损失进行加权。这样难以区分的负样本会获得更大的权重。class RotatEWithLoss(nn.Module): def __init__(self, model, adv_temperature1.0, neg_per_pos256): super().__init__() self.model model self.adv_temp adv_temperature self.neg_per_pos neg_per_pos # 每个正样本对应的负样本数 def get_negative_score(self, head, relation, tail, mode): 为给定的正样本生成负样本并计算得分 batch_size head.size(0) neg_score [] # 这里简化实现为每个正样本随机替换头实体或尾实体来生成负样本 # 更完整的实现会考虑“head_batch”和“tail_batch”两种腐蚀方式 for i in range(self.neg_per_pos): if mode head-batch: # 腐蚀头实体 neg_head torch.randint(0, self.model.num_entities, (batch_size,)).to(head.device) neg_score.append(self.model(neg_head, relation, tail, modehead_batch)) else: # tail-batch # 腐蚀尾实体 neg_tail torch.randint(0, self.model.num_entities, (batch_size,)).to(head.device) neg_score.append(self.model(head, relation, neg_tail, modetail_batch)) neg_score torch.stack(neg_score, dim1) # [batch_size, neg_per_pos] return neg_score def forward(self, pos_head, pos_relation, pos_tail): 计算自对抗负采样损失。 batch_size pos_head.size(0) device pos_head.device # 1. 计算正样本得分 pos_score self.model(pos_head, pos_relation, pos_tail) # [batch_size] # 2. 生成负样本并计算其得分这里以腐蚀尾实体为例 neg_score self.get_negative_score(pos_head, pos_relation, pos_tail, modetail-batch) # [batch_size, neg_per_pos] # 3. 计算自对抗权重 # 将得分取负号因为距离越小得分低表示越可能成立我们想要其权重高 # 使用softmax计算权重温度参数adv_temp控制分布的尖锐程度 weight F.softmax(neg_score * self.adv_temp, dim-1).detach() # [batch_size, neg_per_pos] # 4. 计算加权间隔损失 (Margin Ranking Loss) # 目标正样本得分 margin 负样本得分 loss 0 for i in range(self.neg_per_pos): loss (weight[:, i] * F.relu(pos_score self.model.margin - neg_score[:, i])).mean() loss loss / self.neg_per_pos return loss注意上述负采样实现是一个简化版本主要用于说明自对抗加权的思想。在实际完整的训练循环中需要更高效地批量生成和处理负样本并同时考虑腐蚀头实体和尾实体两种方式。3. 模型训练与调优实战有了模型和损失函数我们就可以开始训练了。训练过程涉及数据加载、优化器选择、学习率调度以及训练循环的构建。3.1 构建数据管道我们需要一个DataLoader来高效地提供训练数据。这里我们使用PyTorch的Dataset和DataLoader。from torch.utils.data import Dataset, DataLoader class KGDataset(Dataset): def __init__(self, triples): self.triples triples # list of (h_id, r_id, t_id) def __len__(self): return len(self.triples) def __getitem__(self, idx): h, r, t self.triples[idx] # 返回张量 return torch.LongTensor([h]), torch.LongTensor([r]), torch.LongTensor([t]) # 创建数据集和数据加载器 train_dataset KGDataset(train_triples) train_loader DataLoader(train_dataset, batch_size1024, shuffleTrue, num_workers4)3.2 训练循环与关键超参数训练循环是模型学习的核心。我们将设置优化器、学习率并监控损失的变化。def train_model(model, loss_model, train_loader, epochs500, lr0.0005): device torch.device(cuda if torch.cuda.is_available() else cpu) print(fUsing device: {device}) model.to(device) loss_model.to(device) optimizer torch.optim.Adam(model.parameters(), lrlr) # 使用学习率预热和余弦退火调度 scheduler torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_050, T_mult2) model.train() for epoch in range(epochs): total_loss 0.0 for batch_idx, (head, relation, tail) in enumerate(train_loader): head, relation, tail head.squeeze().to(device), relation.squeeze().to(device), tail.squeeze().to(device) optimizer.zero_grad() loss loss_model(head, relation, tail) loss.backward() # 梯度裁剪防止梯度爆炸 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm1.0) optimizer.step() total_loss loss.item() if batch_idx % 100 0: print(fEpoch [{epoch1}/{epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}) avg_loss total_loss / len(train_loader) current_lr scheduler.get_last_lr()[0] scheduler.step() print(fEpoch [{epoch1}/{epochs}] finished. Average Loss: {avg_loss:.4f}, LR: {current_lr:.6f}) print(Training finished.) return model超参数对模型性能有决定性影响。以下是一些关键参数及其典型取值范围你可以将其作为调优的起点超参数说明典型取值范围/建议embedding_dim实体和关系嵌入的维度复数维度500 - 2000 (维度越高模型容量越大但也更容易过拟合)margin间隔损失中的间隔值6.0 - 24.0 (需与距离函数的尺度匹配)batch_size训练批大小256 - 2048 (取决于GPU内存)learning_rate初始学习率0.0001 - 0.001adv_temperature自对抗负采样的温度参数α0.5 - 2.0 (值越大越关注困难负样本)neg_per_pos每个正样本对应的负样本数128 - 1024调优是一个迭代过程。一个实用的策略是先在一个较小的数据集或子集上用较少的epoch快速跑通流程然后固定其他参数每次只调整1-2个参数观察验证集指标如MRR Hits10的变化。4. 评估、可视化与进阶技巧模型训练完成后我们需要评估其链接预测的性能并可视化学习到的嵌入以更直观地理解模型学到了什么。4.1 链接预测评估标准知识图谱链接预测的评估通常采用排名指标。对于测试集中的每个三元组(h, r, t)我们进行以下操作损坏尾实体固定(h, r)用知识库中所有实体替换t计算得分然后将真实尾实体t的得分在所有候选实体中进行升序排序距离越小排名越靠前。损坏头实体固定(r, t)用所有实体替换h进行类似排序。记录真实实体在两次排名中的位置。常用的评估指标有Mean Rank (MR): 真实实体排名的平均值。越小越好。Mean Reciprocal Rank (MRR): 真实实体排名倒数的平均值。越大越好最大为1。Hitsk: 真实实体排名在前k名内的比例通常k1, 3, 10。越大越好。注意在计算排名时需要采用“过滤式”设置即从候选列表中移除那些在训练集、验证集或测试集中存在的其他正确三元组避免因它们的存在而错误地降低真实三元组的排名。4.2 嵌入空间可视化将高维嵌入降维到2D或3D进行可视化可以帮助我们定性地评估模型。例如我们可以检查逆关系关系r和其逆关系r_inv的嵌入是否大致满足r ≈ conjugate(r_inv)共轭即旋转角度相反对称关系对称关系的嵌入角度是否集中在0或π附近即cosθ ≈ ±1我们可以使用PCA或t-SNE对实体嵌入进行降维并用不同颜色或形状标记不同类型的实体例如根据其所属的类别。import matplotlib.pyplot as plt from sklearn.manifold import TSNE def visualize_embeddings(entity_emb, relation_angle, entity_labelsNone, top_n_relations5): 可视化实体嵌入和关系角度。 entity_emb: 实体嵌入张量 [num_entities, embed_dim*2] (实部虚部拼接) relation_angle: 关系角度张量 [num_relations, embed_dim] # 1. 实体嵌入可视化 (使用t-SNE降维) tsne TSNE(n_components2, perplexity30, random_state42) entity_2d tsne.fit_transform(entity_emb.detach().cpu().numpy()) plt.figure(figsize(15, 5)) plt.subplot(1, 2, 1) plt.scatter(entity_2d[:, 0], entity_2d[:, 1], alpha0.6, s5) plt.title(Entity Embeddings (t-SNE)) plt.xlabel(Dimension 1) plt.ylabel(Dimension 2) # 2. 关系角度可视化 (取前top_n个关系展示其平均角度或角度分布) plt.subplot(1, 2, 2) # 计算每个关系角度的平均弧度或主要成分 mean_angle relation_angle.mean(dim1).detach().cpu().numpy()[:top_n_relations] relations [fRel{i} for i in range(top_n_relations)] plt.bar(relations, mean_angle) plt.title(fMean Rotation Angle of Top-{top_n_relations} Relations) plt.ylabel(Angle (radians)) plt.xticks(rotation45) plt.tight_layout() plt.show() # 假设model是训练好的RotatE模型 # entity_emb torch.cat([model.entity_emb_real.weight.data, model.entity_emb_imag.weight.data], dim1) # relation_angle model.relation_angle.weight.data # visualize_embeddings(entity_emb, relation_angle)4.3 实战中的技巧与避坑指南在复现和调优RotatE模型时有几个细节至关重要初始化策略关系角度θ的初始化范围[-π, π]是合理的。实体嵌入的初始化则可以采用Xavier或Kaiming初始化确保初始尺度合适。梯度问题与归一化虽然关系嵌入通过cosθ i sinθ天然归一化但实体嵌入的模长可能在训练过程中变得非常大或非常小导致距离函数失去意义。一个常见的技巧是定期对实体嵌入进行归一化例如每训练几个epoch后将实体嵌入除以其L2范数或者在对距离函数计算前对h和t进行归一化。损失函数的选择除了间隔损失Margin Ranking Loss也可以尝试使用负对数似然损失NLL或交叉熵损失将链接预测视为一个分类问题。不同的损失函数可能适用于不同的数据集和场景。复杂度的考量RotatE的复杂度是O(d)其中d是嵌入维度这比许多模型如ConvE要低。但在处理超大规模知识图谱如Wikidata时即使O(d)的复杂度也可能带来挑战。此时可以考虑采用混合精度训练AMP来减少内存占用并加速或者使用分批负采样策略而不是为每个正样本生成数百个负样本。我在自己的实验中发现学习率调度和自对抗采样的温度参数α对最终结果影响非常显著。一个过高的α会使模型过于关注极少数最困难的负样本而忽略了其他有信息的样本导致训练不稳定。通常从α1.0开始根据验证集MRR进行调整是一个不错的起点。