Comments (3)
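After the competition ended, the test-set labels were released to us, so we wrote a short script to compute the scores ourselves. The code is below for reference: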
import re


def get_entities(lines):
    """Extract (label, start, end) tuples from lines annotated as {{LABEL::word}}.

    Offsets index into the text with the markup stripped and keep running
    across lines, so the same span on two different lines never collides.
    """
    result = []
    cur = 0
    for line in lines:
        entities = []
        last_end = 0
        for m in re.finditer(r"{{(.*?)::?(.*?)}}", line):
            label = m.group(1).upper()
            word = m.group(2)
            # Advance past the plain text between the previous match and this one.
            cur += m.start() - last_end
            last_end = m.end()
            if word:
                entities.append((label, cur, cur + len(word)))
            cur += len(word)
        # Advance past the plain text after the last match on this line.
        cur += len(line) - last_end
        result.extend(entities)
    return result


def getlines(path):
    """Read a file and return its non-empty lines."""
    with open(path) as f:
        lines = f.read().split("\n")
    lines = list(filter(lambda line: line, lines))
    return lines


def main():
    ground_truth_path = 'zs100w_0921_wyq_up.txt'
    pred_path = 'result.txt'
    ground_truth_lines = getlines(ground_truth_path)
    pred_lines = getlines(pred_path)

    # Both files must contain the same text once the labels are stripped.
    for i, (gtl, pl) in enumerate(zip(ground_truth_lines, pred_lines)):
        gtl_no_label = re.sub(r"{{(.*?)::?(.*?)}}", r'\2', gtl)
        pl_no_label = re.sub(r"{{(.*?)::?(.*?)}}", r'\2', pl)
        assert gtl_no_label == pl_no_label, f"Different data in row {i}: \n{gtl} \n{pl}"

    # Overall (micro-averaged) precision, recall and F1.
    true_entities = set(get_entities(ground_truth_lines))
    pred_entities = set(get_entities(pred_lines))
    nb_correct = len(true_entities & pred_entities)
    nb_pred = len(pred_entities)
    nb_true = len(true_entities)
    p = nb_correct / nb_pred if nb_pred > 0 else 0
    r = nb_correct / nb_true if nb_true > 0 else 0
    score = 2 * p * r / (p + r) if p + r > 0 else 0
    print('P', p)
    print('R', r)
    print('F1', score)

    # Per-label scores: group the spans by label first.
    true_entities_dict = {}
    for t, start, end in true_entities:
        true_entities_dict.setdefault(t, set()).add((start, end))
    pred_entities_dict = {}
    for t, start, end in pred_entities:
        pred_entities_dict.setdefault(t, set()).add((start, end))

    # A label that was never predicted gets an empty set instead of a KeyError.
    nb_correct_dict = {k: len(true_entities_dict[k] & pred_entities_dict.get(k, set()))
                       for k in true_entities_dict}
    nb_pred_dict = {k: len(pred_entities_dict.get(k, set())) for k in true_entities_dict}
    nb_true_dict = {k: len(true_entities_dict[k]) for k in true_entities_dict}
    p_dict = {k: nb_correct_dict[k] / nb_pred_dict[k] if nb_pred_dict[k] > 0 else 0
              for k in true_entities_dict}
    r_dict = {k: nb_correct_dict[k] / nb_true_dict[k] if nb_true_dict[k] > 0 else 0
              for k in true_entities_dict}
    score_dict = {k: 2 * p_dict[k] * r_dict[k] / (p_dict[k] + r_dict[k])
                  if p_dict[k] + r_dict[k] > 0 else 0
                  for k in true_entities_dict}
    print('P', p_dict)
    print('R', r_dict)
    print('F1', score_dict)


if __name__ == '__main__':
    main()
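To make the annotation format and the scoring concrete, here is a small self-contained sketch of the same span-matching idea on two in-memory lines. The labels `PER` and `LOC`, the sample sentences, and the helper name `micro_prf` are invented for illustration and are not taken from the competition data or the script above.

```python
import re

# Matches {{LABEL::word}}; the second colon is optional, as in the script above.
PATTERN = re.compile(r"\{\{(.*?)::?(.*?)\}\}")


def get_entities(lines):
    """Return (label, start, end) tuples with offsets into the unlabelled text."""
    result = []
    cur = 0  # running offset over all lines, markup excluded
    for line in lines:
        last_end = 0
        for m in PATTERN.finditer(line):
            cur += m.start() - last_end  # plain text before this entity
            last_end = m.end()
            word = m.group(2)
            if word:
                result.append((m.group(1).upper(), cur, cur + len(word)))
            cur += len(word)
        cur += len(line) - last_end  # plain text after the last entity
    return result


def micro_prf(true_entities, pred_entities):
    """Exact-match precision, recall and F1 over sets of entity tuples."""
    true_set, pred_set = set(true_entities), set(pred_entities)
    correct = len(true_set & pred_set)
    p = correct / len(pred_set) if pred_set else 0
    r = correct / len(true_set) if true_set else 0
    f1 = 2 * p * r / (p + r) if p + r > 0 else 0
    return p, r, f1


# Hypothetical sample: the prediction finds the person but misses the place.
truth_lines = ["{{PER::孔子}}遊於{{LOC::泰山}}"]
pred_lines = ["{{PER::孔子}}遊於泰山"]

truth = get_entities(truth_lines)  # [('PER', 0, 2), ('LOC', 4, 6)]
pred = get_entities(pred_lines)    # [('PER', 0, 2)]
p, r, f1 = micro_prf(truth, pred)
print(truth, pred)
print("P", p, "R", r, "F1", f1)
```

An entity only counts as correct when its label, start and end all match, so a prediction with the right span but the wrong label scores as both a false positive and a false negative.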
from guwenbert.
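Great, thank you so much! One more question: the vocabulary size is 23292, but the last layer of the network outputs a dimension of 768, so computing the cross-entropy loss fails with an error saying the labels are out of range for that dimension. For example, with logits.shape = [32,204,768] and labels.shape = [32,204], the values in labels will certainly exceed 768. How can I solve this? And why isn't the output of the last layer set to 23292?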
from guwenbert.
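The model I uploaded is a transformers.RobertaForMaskedLM. If you load it with transformers.RobertaModel, the lm_head layer is discarded, so the output is the hidden states, whose last dimension is the hidden size. Load it with transformers.RobertaForMaskedLM instead, and the final layer will project to the vocabulary dimension.
For details, see the Hugging Face documentation:
https://huggingface.co/transformers/model_doc/roberta.html#transformers.RobertaForMaskedLM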
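The shape difference can be checked without downloading any weights. The sketch below builds both classes from a randomly initialised `RobertaConfig`, assuming a recent `transformers` release (v4 or later, where outputs are returned as objects with named fields) and PyTorch. Only the vocabulary size (23292), the hidden size (768) and the sequence length (204) come from this thread; the layer count, head count, intermediate size and batch size are illustrative assumptions, and the random inputs stand in for real tokenised text.

```python
import torch
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaModel

# vocab_size and hidden_size come from the thread; the remaining settings
# are illustrative assumptions chosen to keep the sketch small.
config = RobertaConfig(
    vocab_size=23292,
    hidden_size=768,
    num_hidden_layers=1,
    num_attention_heads=12,
    intermediate_size=3072,
)

batch_size, seq_len = 2, 204
# Random token ids stand in for real tokenised input.
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
labels = input_ids.clone()

encoder = RobertaModel(config).eval()    # no lm_head
mlm = RobertaForMaskedLM(config).eval()  # encoder plus lm_head

with torch.no_grad():
    hidden = encoder(input_ids).last_hidden_state
    mlm_out = mlm(input_ids, labels=labels)

print(hidden.shape)          # last dimension is hidden_size
print(mlm_out.logits.shape)  # last dimension is vocab_size
print(mlm_out.loss)          # cross-entropy computed over the vocabulary
```

With a released checkpoint, the same two classes would be loaded through `from_pretrained` instead of being constructed from a config; the output shapes follow the same pattern.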
from guwenbert.