KeyError: 'question' about concrete (open)

khuangaf commented on May 30, 2024
KeyError: 'question'

khuangaf commented on May 30, 2024

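Hi, I realized we made a small modification to the original biencoder.py script. The updated function reads the query text from `sample['claim']` instead of `sample['question']` (and each context's text from its `'passage'` field), which is why the unmodified script raises `KeyError: 'question'` on concrete's data. Here is the updated function: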

```python
# Classmethod of the BiEncoder class in biencoder.py (adapted from DPR).
# Module-level imports it relies on; Tensorizer, BiEncoderBatch and
# normalize_question are provided by the DPR codebase.
import random
from typing import List

import numpy as np
import torch


@classmethod
def create_biencoder_input(cls,
                           samples: List,
                           tensorizer: Tensorizer,
                           insert_title: bool,
                           num_hard_negatives: int = 0,
                           num_other_negatives: int = 0,
                           shuffle: bool = True,
                           shuffle_positives: bool = False,
                           ) -> BiEncoderBatch:
    """
    Creates a batch of the biencoder training tuple.
    :param samples: list of data items (from json) to create the batch for
    :param tensorizer: components to create model input tensors from a text sequence
    :param insert_title: enables title insertion at the beginning of the context sequences
    :param num_hard_negatives: number of hard negatives per question (taken from samples' pools)
    :param num_other_negatives: number of other negatives per question (taken from samples' pools)
    :param shuffle: shuffles negative passages pools
    :param shuffle_positives: shuffles positive passages pools
    :return: BiEncoderBatch tuple
    """
    question_tensors = []
    ctx_tensors = []
    positive_ctx_indices = []
    hard_neg_ctx_indices = []

    for sample in samples:
        # ctx+ & [ctx-] composition
        # as of now, take the first (gold) ctx+ only
        if shuffle and shuffle_positives:
            positive_ctxs = sample['positive_ctxs']
            positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
        else:
            positive_ctx = sample['positive_ctxs'][0]

        # Contexts store their text under 'passage' (DPR's original format uses 'text')
        positive_ctx = positive_ctx['passage']
        # 'negative_ctxs' serve as "other" negatives only when the sample also has
        # 'hard_negative_ctxs'; otherwise their passages are used as the hard negatives.
        neg_ctxs = sample['negative_ctxs'] if 'hard_negative_ctxs' in sample else []
        hard_neg_ctxs = (sample['hard_negative_ctxs'] if 'hard_negative_ctxs' in sample
                         else [ctx['passage'] for ctx in sample['negative_ctxs']])
        # The query text is read from 'claim' instead of DPR's 'question'
        question = normalize_question(sample['claim'])

        if shuffle:
            random.shuffle(neg_ctxs)
            random.shuffle(hard_neg_ctxs)

        neg_ctxs = neg_ctxs[0:num_other_negatives]
        hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

        all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
        # Hard negatives follow the positive and the other negatives in all_ctxs
        hard_negatives_start_idx = 1 + len(neg_ctxs)
        hard_negatives_end_idx = hard_negatives_start_idx + len(hard_neg_ctxs)

        current_ctxs_len = len(ctx_tensors)

        # Titles are never inserted here (title=None), regardless of insert_title
        sample_ctxs_tensors = [tensorizer.text_to_tensor(ctx, title=None) for ctx in all_ctxs]

        ctx_tensors.extend(sample_ctxs_tensors)
        positive_ctx_indices.append(current_ctxs_len)
        hard_neg_ctx_indices.append(
            list(range(current_ctxs_len + hard_negatives_start_idx,
                       current_ctxs_len + hard_negatives_end_idx)))

        question_tensors.append(tensorizer.text_to_tensor(question))

    ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
    questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)

    ctx_segments = torch.zeros_like(ctxs_tensor)
    question_segments = torch.zeros_like(questions_tensor)

    return BiEncoderBatch(questions_tensor, question_segments, ctxs_tensor, ctx_segments, positive_ctx_indices,
                          hard_neg_ctx_indices)
```
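
Because the updated function indexes each sample directly (`sample['claim']`, `['passage']` on contexts, `sample['negative_ctxs']`), data laid out differently (for example with DPR's `question` and `text` keys) fails with a `KeyError` partway through batching. A minimal sketch for checking samples up front, assuming they are already loaded from JSON as Python dicts; `find_missing_keys` is a hypothetical helper, not part of DPR or concrete, and it only mirrors the lookups visible in the function:

```python
from typing import Any, Dict, List


def find_missing_keys(sample: Dict[str, Any]) -> List[str]:
    """List the fields that create_biencoder_input would fail to find.

    Hypothetical helper: it mirrors the lookups in the updated function,
    i.e. sample['claim'], sample['positive_ctxs'][i]['passage'],
    sample['negative_ctxs'] and, when 'hard_negative_ctxs' is absent,
    sample['negative_ctxs'][j]['passage'].
    """
    missing: List[str] = []
    if "claim" not in sample:
        missing.append("claim")

    positives = sample.get("positive_ctxs")
    if not positives:
        # Missing or empty: the function needs at least one positive context
        missing.append("positive_ctxs")
    else:
        missing += [f"positive_ctxs[{i}].passage"
                    for i, ctx in enumerate(positives) if "passage" not in ctx]

    negatives = sample.get("negative_ctxs")
    if negatives is None:
        # 'negative_ctxs' is read whether or not 'hard_negative_ctxs' exists
        missing.append("negative_ctxs")
    elif "hard_negative_ctxs" not in sample:
        # Only this branch unwraps negative contexts via ['passage']
        missing += [f"negative_ctxs[{j}].passage"
                    for j, ctx in enumerate(negatives) if "passage" not in ctx]
    return missing


# A sample in the original DPR layout ('question', contexts with 'text')
dpr_style = {"question": "q", "positive_ctxs": [{"title": "", "text": "t"}], "negative_ctxs": []}
print(find_missing_keys(dpr_style))  # ['claim', 'positive_ctxs[0].passage']
```

Looping this over the training file before building batches reports every offending sample at once instead of stopping at the first `KeyError`.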

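Each sample's contexts are appended to one flat list, so the values stored in `positive_ctx_indices` and `hard_neg_ctx_indices` are offsets into that list. Since `all_ctxs` is built as `[positive_ctx] + neg_ctxs + hard_neg_ctxs`, a sample's hard negatives start at offset `1 + len(neg_ctxs)` within that sample, which equals 1 only when `num_other_negatives` is 0 (the default). A minimal, torch-free sketch of the resulting layout; `ctx_index_layout` and its `(num_other, num_hard)` input are hypothetical, chosen for illustration:

```python
from typing import List, Tuple


def ctx_index_layout(
    per_sample_counts: List[Tuple[int, int]],
) -> Tuple[List[int], List[List[int]]]:
    """Compute positive and hard-negative indices into the flat context list.

    Each entry of per_sample_counts is (num_other_negs, num_hard_negs) for one
    sample, whose contexts are laid out as [positive] + other negs + hard negs.
    """
    positive_idx: List[int] = []
    hard_neg_idx: List[List[int]] = []
    offset = 0  # plays the role of current_ctxs_len
    for num_other, num_hard in per_sample_counts:
        positive_idx.append(offset)
        start = offset + 1 + num_other  # skip the positive and the other negatives
        hard_neg_idx.append(list(range(start, start + num_hard)))
        offset += 1 + num_other + num_hard
    return positive_idx, hard_neg_idx


print(ctx_index_layout([(0, 2), (0, 2)]))  # ([0, 3], [[1, 2], [4, 5]])
print(ctx_index_layout([(1, 2), (1, 2)]))  # ([0, 4], [[2, 3], [6, 7]])
```

In the second call, a start offset of 1 would have labelled each sample's "other" negative (flat indices 1 and 5) as a hard negative.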