Implement NDCG (tensorrec, closed)

jfkirk commented on June 4, 2024
Implement NDCG

from tensorrec.

Comments (6)

jcauteru commented on June 4, 2024

WIP

jcauteru commented on June 4, 2024

@jfkirk I finished the dense version. It is going to take some tweaking to work with a sparse interaction matrix as input. I have it on a branch for now.
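
For illustration only, here is one way the DCG part could avoid densifying the test interactions. This is a sketch, not the code on the branch mentioned above: the function name `dcg_at_k_sparse` is hypothetical, and it assumes `predicted_ranks` is a dense array of 1-based ranks with one row per user.

```python
import numpy as np
import scipy.sparse as sp


def dcg_at_k_sparse(predicted_ranks, test_interactions, k=10):
    """Per-user DCG@k that touches only the stored test interactions.

    predicted_ranks: dense (n_users, n_items) array of 1-based ranks (assumed).
    test_interactions: scipy sparse matrix of relevance scores.
    """
    coo = sp.coo_matrix(test_interactions)
    keep = coo.data > 0                      # positive interactions only
    rows, cols, rel = coo.row[keep], coo.col[keep], coo.data[keep]

    ranks = predicted_ranks[rows, cols]      # rank of each relevant item
    in_top_k = ranks <= k
    gains = (2.0 ** rel[in_top_k] - 1) / np.log2(ranks[in_top_k] + 1.0)

    # Sum the discounted gains back into one value per user
    return np.bincount(rows[in_top_k], weights=gains,
                       minlength=coo.shape[0])
```

Only the relevant (user, item) pairs are looked up, so the cost scales with the number of stored interactions rather than with the full matrix.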

jcauteru commented on June 4, 2024

@jfkirk

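import numpy as np
import scipy.sparse as sp


def _idcg(hits, k=10, ctype="binary"):
    # Ideal DCG at k for a single user.
    # "binary": hits is the number of relevant items, each with gain 1.
    # "scalar": hits is a 1-D array of graded relevance scores.
    if ctype == "binary":
        n_ideal = int(min(hits, k))  # at most k items fit in the top k
        idcg = np.sum(1.0/np.log2(np.arange(n_ideal) + 2))  # arange index from 0
    elif ctype == "scalar":
        top_hits = hits[np.argsort(-hits)][:k]  # k highest scores, descending
        # Use the same gain, 2**rel - 1, as the DCG numerator
        idcg = np.sum((2.0**top_hits - 1)/np.log2(np.arange(len(top_hits)) + 2))
    else:
        raise ValueError("Invalid IDCG calculation type")

    return idcg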


def normalized_discounted_cumulative_gain(model, test_interactions, k=10,
                                          user_features=None,
                                          item_features=None,
                                          preserve_rows=False):

    predicted_ranks = model.predict_rank(user_features=user_features,
                                         item_features=item_features)

    # Dense version: densify the sparse test interactions once
    dense_interactions = test_interactions.toarray()
    positive_test_interactions = dense_interactions > 0

    # Ranks start at 1, so both CSR matrices below store exactly the positive
    # test interactions and their .data arrays line up entry by entry
    ranks_of_relevant = sp.csr_matrix(predicted_ranks *
                                      positive_test_interactions)
    relevance = sp.csr_matrix(dense_interactions *
                              positive_test_interactions)

    # Relevant items ranked below k contribute nothing: gain 2**0 - 1 == 0
    k_mask = np.less(ranks_of_relevant.data, k + 1)
    ror_at_k = np.maximum(np.multiply(ranks_of_relevant.data, k_mask), 1)
    relevance_at_k = (2**np.multiply(relevance.data, k_mask)) - 1
    ranks_of_relevant.data = relevance_at_k/np.log2(ror_at_k + 1)  # ranks at 1

    dcg = np.asarray(ranks_of_relevant.sum(axis=1)).ravel()

    if dense_interactions.max() == 1:
        # If the data is binary, we can save the ideal ranking sort
        hits_per_user = positive_test_interactions.sum(axis=1)
        idcg = np.array([_idcg(hits, k=k, ctype="binary")
                         for hits in hits_per_user])
    else:
        # Graded data: sort each user's row of relevance scores
        idcg = np.apply_along_axis(_idcg, 1, relevance.toarray(),
                                   k=k, ctype="scalar")

    # Users with no relevant test items have IDCG 0, so their NDCG is NaN
    with np.errstate(divide="ignore", invalid="ignore"):
        ndcg = dcg/idcg

    if preserve_rows:
        return ndcg
    else:
        return np.nanmean(ndcg)  # NaN rows are skipped in the mean

jcauteru commented on June 4, 2024

Some trickery is required for the "no relevant examples" case. I don't see a way around the sort for the scalar case. Note that np.argsort returns the indices that sort the original array, NOT each element's position in the sorted result. I split out the binary and scalar cases to save the sort in the binary case.
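
A small example of the np.argsort distinction described above, using made-up scores: the first call returns indices into the original array, and applying argsort a second time turns those indices into per-element positions.

```python
import numpy as np

scores = np.array([0.2, 3.0, 1.0])

order = np.argsort(-scores)    # indices that sort the scores descending
ranked = scores[order]         # scores from highest to lowest

# Applying argsort again gives each element's 0-based position instead
positions = np.argsort(order)
```

Here `order` is what the IDCG sort needs, while `positions` is the kind of value a rank matrix holds (0-based in this example).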

jcauteru commented on June 4, 2024

Still need to add tests for the scalar case somehow.
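
One possible shape for such a test, sketched under assumptions: `reference_ndcg` is a hypothetical brute-force helper, not part of tensorrec, that computes NDCG@k for one user with the same 2**rel - 1 gain. A real test could compare the metric's per-user output against it on a few small hand-checked cases.

```python
import numpy as np


def reference_ndcg(relevance, ranks, k):
    """Brute-force NDCG@k for one user (hypothetical test helper).

    relevance: graded relevance per item; ranks: 1-based predicted ranks.
    """
    relevance = np.asarray(relevance, dtype=float)
    ranks = np.asarray(ranks)
    in_top_k = ranks <= k
    dcg = np.sum((2 ** relevance[in_top_k] - 1) / np.log2(ranks[in_top_k] + 1))
    ideal = np.sort(relevance)[::-1][:k]
    idcg = np.sum((2 ** ideal - 1) / np.log2(np.arange(len(ideal)) + 2))
    return dcg / idcg


def test_scalar_ndcg():
    # A perfect ordering of graded items scores exactly 1
    assert np.isclose(reference_ndcg([3, 2, 1], [1, 2, 3], k=3), 1.0)
    # Reversing the ordering must score strictly lower
    assert reference_ndcg([3, 2, 1], [3, 2, 1], k=3) < 1.0
    # Hand-computed: relevance 2 at rank 2, relevance 1 at rank 1, k=2
    expected = (1 + 3 / np.log2(3)) / (3 + 1 / np.log2(3))
    assert np.isclose(reference_ndcg([2, 1], [2, 1], k=2), expected)
```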

jcauteru commented on June 4, 2024

#40
