
Comments (3)

Ethan-yt commented on July 16, 2024

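After the competition ended, the test-set labels were released to us, and we wrote a script to compute the scores. The code is below for reference: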

import re


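def get_entities(lines):
    """Return (LABEL, start, end) for every tagged span.

    Offsets are measured in the tag-free text and keep counting across lines,
    so two files with identical text yield directly comparable offsets.
    """
    result = []
    cur = 0  # current offset in the tag-free text
    for line in lines:
        last_end = 0  # end of the previous tag in the raw line
        for m in re.finditer(r"{{(.*?)::?(.*?)}}", line):
            label = m.group(1).upper()
            word = m.group(2)
            cur += m.start() - last_end  # untagged text before this tag
            last_end = m.end()
            if word:
                result.append((label, cur, cur + len(word)))
                cur += len(word)
        cur += len(line) - last_end  # untagged text after the last tag
    return result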


def getlines(path):
    # Read the file as UTF-8 (the data is Chinese text) and drop empty lines.
    with open(path, encoding="utf-8") as f:
        return [line for line in f.read().split("\n") if line]


def main():
    ground_truth_path = 'zs100w_0921_wyq_up.txt'
    pred_path = 'result.txt'

    ground_truth_lines = getlines(ground_truth_path)
    pred_lines = getlines(pred_path)

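    # Both files must have the same number of non-empty lines; zip() would
    # otherwise silently ignore the extra rows.
    assert len(ground_truth_lines) == len(pred_lines), \
        f"Line count differs: {len(ground_truth_lines)} vs {len(pred_lines)}"

    # Sanity check: with the {{label::text}} tags stripped, every prediction
    # line must match its ground-truth line exactly.
    for i, (gtl, pl) in enumerate(zip(ground_truth_lines, pred_lines)):
        gtl_no_label = re.sub(r"{{(.*?)::?(.*?)}}", r'\2', gtl)
        pl_no_label = re.sub(r"{{(.*?)::?(.*?)}}", r'\2', pl)
        assert gtl_no_label == pl_no_label, f"Different data in row {i}: \n{gtl} \n{pl}"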

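    # Micro-averaged scores: an entity counts as correct only if its label,
    # start and end all match exactly.
    true_entities = set(get_entities(ground_truth_lines))
    pred_entities = set(get_entities(pred_lines))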

    nb_correct = len(true_entities & pred_entities)
    nb_pred = len(pred_entities)
    nb_true = len(true_entities)
    p = nb_correct / nb_pred if nb_pred > 0 else 0
    r = nb_correct / nb_true if nb_true > 0 else 0
    score = 2 * p * r / (p + r) if p + r > 0 else 0
    print('P', p)
    print('R', r)
    print('F1', score)

    # Per-label scores: group the (start, end) spans by entity label.
    true_entities_dict = {}
    for t, start, end in true_entities:
        true_entities_dict.setdefault(t, set()).add((start, end))

    pred_entities_dict = {}
    for t, start, end in pred_entities:
        pred_entities_dict.setdefault(t, set()).add((start, end))

    # .get() treats a label that never occurs in the predictions as having zero
    # predicted spans, instead of raising KeyError.
    nb_correct_dict = {k: len(true_entities_dict[k] & pred_entities_dict.get(k, set())) for k in true_entities_dict}
    nb_pred_dict = {k: len(pred_entities_dict.get(k, set())) for k in true_entities_dict}
    nb_true_dict = {k: len(true_entities_dict[k]) for k in true_entities_dict}

    p_dict = {k: nb_correct_dict[k] / nb_pred_dict[k] if nb_pred_dict[k] > 0 else 0 for k in true_entities_dict}
    r_dict = {k: nb_correct_dict[k] / nb_true_dict[k] if nb_true_dict[k] > 0 else 0 for k in true_entities_dict}
    score_dict = {k: 2 * p_dict[k] * r_dict[k] / (p_dict[k] + r_dict[k]) if p_dict[k] + r_dict[k] > 0 else 0 for k in
                  true_entities_dict}
    print('P', p_dict)
    print('R', r_dict)
    print('F1', score_dict)


if __name__ == '__main__':
    main()
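To make the span bookkeeping concrete, here is a small self-contained sketch that reuses the same extraction logic on two hypothetical lines and computes micro precision, recall and F1. It assumes the `{{label::text}}` markup implied by the regex; the labels `per`/`loc`, the sentences, and the helper name `prf` are invented for illustration and are not from the competition data. Because a match requires label, start and end to agree, the mislabeled span below counts against both precision and recall.

```python
import re

# Assumed markup (inferred from the regex in the script): {{label::text}},
# where a single ':' separator is also accepted.
TAG = re.compile(r"{{(.*?)::?(.*?)}}")


def get_entities(lines):
    """Return (LABEL, start, end) spans, with offsets in the tag-free text."""
    result = []
    cur = 0  # running offset across all lines, markup excluded
    for line in lines:
        last_end = 0
        for m in TAG.finditer(line):
            label, word = m.group(1).upper(), m.group(2)
            cur += m.start() - last_end  # untagged text before this tag
            last_end = m.end()
            if word:
                result.append((label, cur, cur + len(word)))
                cur += len(word)
        cur += len(line) - last_end  # untagged text after the last tag
    return result


def prf(true_entities, pred_entities):
    """Micro precision, recall and F1 over exact (label, start, end) matches."""
    true_set, pred_set = set(true_entities), set(pred_entities)
    nb_correct = len(true_set & pred_set)
    p = nb_correct / len(pred_set) if pred_set else 0
    r = nb_correct / len(true_set) if true_set else 0
    f1 = 2 * p * r / (p + r) if p + r > 0 else 0
    return p, r, f1


# Hypothetical toy data; "per" and "loc" are made-up labels.
gold = ["{{per::孔子}}適{{loc::周}}", "問禮於{{per::老子}}"]
pred = ["{{per::孔子}}適周", "問禮於{{loc::老子}}"]

true_entities = get_entities(gold)  # [('PER', 0, 2), ('LOC', 3, 4), ('PER', 7, 9)]
pred_entities = get_entities(pred)  # [('PER', 0, 2), ('LOC', 7, 9)]
p, r, f1 = prf(true_entities, pred_entities)
print(f"P={p:.3f} R={r:.3f} F1={f1:.3f}")  # P=0.500 R=0.333 F1=0.400
```

Since offsets accumulate across lines, the ground-truth and prediction files must contain exactly the same text once the tags are removed, which is what the row-by-row assertion in `main()` checks.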

from guwenbert.

hanyc0914 commented on July 16, 2024

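Great, thank you very much! I have one more question. The vocabulary size is 23292, but the output dimension of the network's last layer is 768, so computing the cross-entropy loss raises an error saying the labels are out of range of the actual dimension. For example, with logits.shape = [32, 204, 768] and labels.shape = [32, 204], many values in labels are bound to be 768 or larger. How can I fix this? And why isn't the output size of the last layer set to 23292?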

Ethan-yt commented on July 16, 2024

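The model I uploaded is a transformers.RobertaForMaskedLM. If you load it with transformers.RobertaModel, the lm_head layer is discarded, so the model outputs hidden states of size hidden_size (768) directly. Use transformers.RobertaForMaskedLM instead; its final layer maps the output to the vocabulary size.

For details, see the Hugging Face documentation: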

https://huggingface.co/transformers/model_doc/roberta.html#transformers.RobertaForMaskedLM
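The shape error in the question can be reproduced without the real model. The pure-Python sketch below uses toy sizes (H = 4 standing in for hidden_size 768, V = 10 for the 23292-token vocabulary), and a single random linear layer is a simplified stand-in for lm_head; the actual RoBERTa LM head also applies a dense layer, GELU and layer normalization before the vocabulary projection. Cross-entropy treats each label as an index into the last dimension of the logits, so any token id ≥ H fails until the hidden state is projected to V scores. The helper names `cross_entropy` and `project` are hypothetical.

```python
import math
import random


def cross_entropy(logits, label):
    """Cross-entropy for one position: -log softmax(logits)[label]."""
    if not 0 <= label < len(logits):
        # PyTorch's cross_entropy likewise rejects targets outside [0, C).
        raise IndexError(f"label {label} out of range for {len(logits)} classes")
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def project(hidden, weight):
    """Linear map from a size-H hidden state to V logits (weight has V rows of H)."""
    return [sum(h * w for h, w in zip(hidden, row)) for row in weight]


H, V = 4, 10  # toy stand-ins for hidden_size=768 and vocab_size=23292
random.seed(0)
hidden = [random.gauss(0.0, 1.0) for _ in range(H)]  # one token's hidden state
label = 7  # a token id >= H, like most ids in the real vocabulary

# Like RobertaModel: using the hidden state directly as logits fails.
try:
    cross_entropy(hidden, label)
except IndexError as err:
    print(err)  # label 7 out of range for 4 classes

# Like RobertaForMaskedLM: project to the vocabulary size, then compute the loss.
weight = [[random.gauss(0.0, 1.0) for _ in range(H)] for _ in range(V)]
logits = project(hidden, weight)
loss = cross_entropy(logits, label)
print(f"logits size = {len(logits)}, loss = {loss:.4f}")
```

With the real checkpoint, the equivalent fix is loading it as transformers.RobertaForMaskedLM, whose logits have shape [batch, seq_len, vocab_size], e.g. [32, 204, 23292] for the batch in the question.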
