okhat commented on September 24, 2024

Good catch! Can you cast the set to a list? That sounds like it'll fix this. If you make a pull request, I'll merge it.
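As an illustration of what that cast does and does not do, here is a minimal sketch with made-up passage ids (the names `retrieved`, `as_list`, and `ordered` are hypothetical and not taken from the Baleen code). Casting a set to a list keeps the deduplication but gives no guarantee about rank order; an order-preserving deduplication has to be done separately.

```python
# Made-up pids in ranked order, with repeats (illustrative data only).
retrieved = [42, 7, 19, 7, 42]

as_set = set(retrieved)    # deduplicated, but sets carry no rank order
as_list = list(as_set)     # now indexable; element order is arbitrary

# Order-preserving alternative: dict keys keep first-seen order (Python 3.7+).
ordered = list(dict.fromkeys(retrieved))
print(ordered)  # [42, 7, 19]
```

Whether the plain cast is enough depends on whether downstream code relies on the ranking order.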

from baleen.

MFajcik commented on September 24, 2024

@okhat Wouldn't it be better to return N deduplicated lists (where N is the number of hops) from the ColBERT engine, so that each hop's retrieval results keep their order?

I would submit the pull request for ColBERT, but I am not sure whether this would cause problems with some of the scripts you have.

Edit: code-wise, something like this:

from baleen.utils.loaders import *
from baleen.condenser.condense import Condenser


class Baleen:
    def __init__(self, collectionX_path: str, searcher, condenser: Condenser):
        self.collectionX = load_collectionX(collectionX_path)
        self.searcher = searcher
        self.condenser = condenser

    def search(self, query, num_hops, depth=100, verbose=False):
        assert depth % num_hops == 0, f"depth={depth} must be divisible by num_hops={num_hops}."
        k = depth // num_hops

        searcher = self.searcher
        condenser = self.condenser
        collectionX = self.collectionX

        facts = []
        stage1_preds = None
        context = None

        # One ordered list of pids per hop, so each hop's ranking order is kept.
        pids_bag = [[] for _ in range(num_hops)]

        for hop_idx in range(0, num_hops):
            ranking = list(zip(*searcher.search(query, context=context, k=depth)))
            ranking_ = []

            facts_pids = {pid for pid, _ in facts}

            for pid, rank, score in ranking:
                # print(f'[{score}] \t\t {searcher.collection[pid]}')
                if len(ranking_) < k and pid not in facts_pids:
                    ranking_.append(pid)

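                # Collect up to k pids for this hop, skipping any pid already
                # stored for this or an earlier hop (deduplication across hops).
                if len(pids_bag[hop_idx]) < k:
                    if all(pid not in pids_bag[hi] for hi in range(num_hops)):
                        pids_bag[hop_idx].append(pid)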

            stage1_preds, facts, stage2_L3x = condenser.condense(query, backs=facts, ranking=ranking_)
            context = ' [SEP] '.join([collectionX.get((pid, sid), '') for pid, sid in facts])

        assert sum(len(pids_per_hop) for pids_per_hop in pids_bag) == depth  # edit: fixed assert

        return stage2_L3x, pids_bag, stage1_preds
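
The per-hop deduplication in the snippet above can be tried in isolation. The following is a minimal standalone sketch, not the Baleen or ColBERT implementation: the helper name `dedup_per_hop` and the toy rankings are hypothetical, and a `seen` set stands in for the `all(...)` scan over the earlier hops, which should give the same result.

```python
from typing import Iterable, List


def dedup_per_hop(rankings: Iterable[List[int]], k: int) -> List[List[int]]:
    """Keep up to k not-yet-seen pids per hop, preserving each hop's rank order."""
    seen = set()    # pids already assigned to some hop
    pids_bag = []   # one ordered list of pids per hop
    for ranking in rankings:
        hop_pids = []
        for pid in ranking:
            if len(hop_pids) == k:
                break
            if pid not in seen:
                seen.add(pid)
                hop_pids.append(pid)
        pids_bag.append(hop_pids)
    return pids_bag


# Toy rankings for two hops (made-up pids, best first).
toy_rankings = [
    [7, 3, 9, 1],
    [3, 7, 5, 2],
]
per_hop = dedup_per_hop(toy_rankings, k=2)
print(per_hop)  # [[7, 3], [5, 2]]
```

The second hop skips pids 3 and 7 because the first hop already claimed them, and each list keeps the order in which the searcher ranked its passages.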
