The dataset used is DuIE2.0, the largest Chinese relation-extraction dataset in the industry. Its schema extends the traditional simple relation types with multi-element complex relation types, and its corpus is drawn from Baidu Baike, Baidu Feed, and Baidu Tieba text, covering both written and colloquial expression, so it is a thorough test of relation extraction under real business conditions.
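For orientation, each line of the DuIE2.0 files is one JSON object. Below is a minimal sketch of a sample restricted to the fields the loader in this post actually reads (text, spo_list, subject, predicate, object['@value']); the values are made up, and real samples carry extra fields such as subject_type and object_type:

{"text": "周杰伦毕业于淡江中学", "spo_list": [{"predicate": "毕业院校", "subject": "周杰伦", "object": {"@value": "淡江中学"}}]}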
1. Add configuration items
# config.py
REL_PATH = './data/input/duie/rel.csv'  # relation table path (assumed; get_rel below reads it, adjust to your layout)
TRAIN_JSON_PATH = './data/input/duie/duie_train.json'
TEST_JSON_PATH = './data/input/duie/duie_test.json'
DEV_JSON_PATH = './data/input/duie/duie_dev.json'
BERT_MODEL_NAME = 'bert-base-chinese'
2. Create a new file
# utils.py
import torch.utils.data as data
import pandas as pd
import random
from config import *
import json
from transformers import BertTokenizerFast
3. Load the relation table
def get_rel():
    df = pd.read_csv(REL_PATH, names=['rel', 'id'])
    return df['rel'].tolist(), dict(df.values)

id2rel, rel2id = get_rel()
print(id2rel)  # the position in the list already serves as the id
print(rel2id)
exit()
['毕业院校', '嘉宾', '配音', '主题曲', '代言人', '所属专辑'.....
{'毕业院校': 0, '嘉宾': 1, '配音': 2, '主题曲': 3, '代言人': 4, .....
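The relation table itself is assumed here to be a plain two-column CSV (relation, id) without a header row; a hypothetical excerpt consistent with the output above:

毕业院校,0
嘉宾,1
配音,2
.....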
4. Dataset initialization
class Dataset(data.Dataset):
    def __init__(self, type='train'):  # type selects which split file to load
        super().__init__()
        _, self.rel2id = get_rel()
        # load the split file
        if type == 'train':
            file_path = TRAIN_JSON_PATH
        elif type == 'test':
            file_path = TEST_JSON_PATH
        elif type == 'dev':
            file_path = DEV_JSON_PATH
        with open(file_path, encoding='utf-8') as f:
            self.lines = f.readlines()  # read line by line; len(lines) is the dataset size
        # load the BERT tokenizer
        self.tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, index):
        line = self.lines[index]
        info = json.loads(line)
        # return_offsets_mapping records each token's character span,
        # which keeps mixed Chinese/English text aligned correctly
        tokenized = self.tokenizer(info['text'], return_offsets_mapping=True)
        info['input_ids'] = tokenized['input_ids']
        info['offset_mapping'] = tokenized['offset_mapping']
        print(info)
        exit()
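To see what the offset mapping looks like before wiring it into the dataset, here is a small standalone sketch (it downloads bert-base-chinese on first run; the example string is arbitrary):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
tokenized = tokenizer('BERT抽取三元组', return_offsets_mapping=True)
# each token is paired with its (start, end) character span in the text;
# special tokens such as [CLS] and [SEP] get the span (0, 0)
for token, span in zip(tokenizer.convert_ids_to_tokens(tokenized['input_ids']),
                       tokenized['offset_mapping']):
    print(token, span)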
5. Try loading the dataset
if __name__ == '__main__':
    dataset = Dataset()
    loader = data.DataLoader(dataset)
    print(next(iter(loader)))  # advance the iterator once and take the first sample
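Note that once __getitem__ returns a dict of variable-length lists, PyTorch's default collation will fail for batch_size > 1. A common workaround, shown here as an assumption rather than part of the original code, is an identity collate_fn that hands back the raw list of samples:

loader = data.DataLoader(dataset, batch_size=2, collate_fn=lambda batch: batch)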
6. Parse the sample JSON

A parse_json method unpacks one line of the dataset into the fields the model will need. Start with the text-level triples and leave the position logic for the next step:

def parse_json(self, info):
    text = info['text']
    input_ids = info['input_ids']
    dct = {
        'text': text,
        'input_ids': input_ids,
        'offset_mapping': info['offset_mapping'],
        'sub_head_ids': [],
        'sub_tail_ids': [],
        'triple_list': [],
        'triple_id_list': []
    }
    for spo in info['spo_list']:
        subject = spo['subject']
        object = spo['object']['@value']
        predicate = spo['predicate']
        dct['triple_list'].append((subject, predicate, object))
        # @todo
    print(dct)
    exit()
    return dct
Now fill in the # @todo inside the loop: locate the subject and object token spans and assemble the id-level triple.

        # compute the subject entity position
        tokenized = self.tokenizer(subject, add_special_tokens=False)
        sub_token = tokenized['input_ids']
        sub_pos_id = self.get_pos_id(input_ids, sub_token)
        if not sub_pos_id:
            continue
        sub_head_id, sub_tail_id = sub_pos_id
        # compute the object entity position
        tokenized = self.tokenizer(object, add_special_tokens=False)
        obj_token = tokenized['input_ids']
        obj_pos_id = self.get_pos_id(input_ids, obj_token)
        if not obj_pos_id:
            continue
        obj_head_id, obj_tail_id = obj_pos_id
        # assemble the outputs
        dct['sub_head_ids'].append(sub_head_id)
        dct['sub_tail_ids'].append(sub_tail_id)
        dct['triple_id_list'].append((
            [sub_head_id, sub_tail_id],
            self.rel2id[predicate],
            [obj_head_id, obj_tail_id],
        ))
Here source is the token-id sequence of the full text and elem is the token ids of the current entity. get_pos_id slides a window over source: at each start position it compares a slice of len(elem) ids against elem. If the window runs past the end of source, the slice comes back shorter and simply never matches.
def get_pos_id(self, source, elem):
    for head_id in range(len(source)):
        tail_id = head_id + len(elem)
        if source[head_id:tail_id] == elem:
            return head_id, tail_id - 1
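As a standalone sanity check of the sliding window, with plain lists standing in for token ids (a module-level copy of the same logic):

def get_pos_id(source, elem):
    # compare each len(elem)-sized slice of source against elem
    for head_id in range(len(source)):
        tail_id = head_id + len(elem)
        if source[head_id:tail_id] == elem:
            return head_id, tail_id - 1

print(get_pos_id([101, 2769, 4263, 704, 1744, 102], [704, 1744]))  # (3, 4): inclusive head/tail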
Putting the pieces together, the complete parse_json and get_pos_id (with the debug exit removed so the method actually returns):

def parse_json(self, info):
    text = info['text']
    input_ids = info['input_ids']
    dct = {
        'text': text,
        'input_ids': input_ids,
        'offset_mapping': info['offset_mapping'],
        'sub_head_ids': [],
        'sub_tail_ids': [],
        'triple_list': [],
        'triple_id_list': []
    }
    for spo in info['spo_list']:
        subject = spo['subject']
        object = spo['object']['@value']
        predicate = spo['predicate']
        dct['triple_list'].append((subject, predicate, object))
        # compute the subject entity position
        tokenized = self.tokenizer(subject, add_special_tokens=False)
        sub_token = tokenized['input_ids']
        sub_pos_id = self.get_pos_id(input_ids, sub_token)
        if not sub_pos_id:
            continue
        sub_head_id, sub_tail_id = sub_pos_id
        # compute the object entity position
        tokenized = self.tokenizer(object, add_special_tokens=False)
        obj_token = tokenized['input_ids']
        obj_pos_id = self.get_pos_id(input_ids, obj_token)
        if not obj_pos_id:
            continue
        obj_head_id, obj_tail_id = obj_pos_id
        # assemble the outputs
        dct['sub_head_ids'].append(sub_head_id)
        dct['sub_tail_ids'].append(sub_tail_id)
        dct['triple_id_list'].append((
            [sub_head_id, sub_tail_id],
            self.rel2id[predicate],
            [obj_head_id, obj_tail_id],
        ))
    return dct

def get_pos_id(self, source, elem):
    for head_id in range(len(source)):
        tail_id = head_id + len(elem)
        if source[head_id:tail_id] == elem:
            return head_id, tail_id - 1
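With parse_json in place, __getitem__ can hand back the parsed dict instead of printing and exiting; a sketch of the presumed final wiring:

def __getitem__(self, index):
    info = json.loads(self.lines[index])
    tokenized = self.tokenizer(info['text'], return_offsets_mapping=True)
    info['input_ids'] = tokenized['input_ids']
    info['offset_mapping'] = tokenized['offset_mapping']
    return self.parse_json(info)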