dygrec / asrep

Released code of the SIGIR 2021 paper "Augmenting Sequential Recommendation with Pseudo-Prior Items via Reversely Pre-training Transformer".

Python 100.00%
augmented-sequences item-prediction transformer recommender-system sequential-recommendation

asrep's People

Contributors

jimliu96, zfan20

asrep's Issues

A doubt in DataProcessing.py

user_train_reverse = {}
user_valid_reverse = {}
user_test_reverse = {}
for user in User_forreversed:
    nfeedback = len(User_forreversed[user])
    if nfeedback < 3:
        user_train_reverse[user] = User_forreversed[user]
        user_valid_reverse[user] = []
        user_test_reverse[user] = []
    else:
        user_train_reverse[user] = User[user][:-2]
        user_valid_reverse[user] = []
        user_valid_reverse[user].append(User[user][-2])
        user_test_reverse[user] = []
        user_test_reverse[user].append(User[user][-1])

I downloaded and ran your code and found that train_reverse.txt is the same as train.txt.
In this loop, shouldn't the lists after else be taken from User_forreversed instead of User?
Looking forward to your reply.
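A minimal sketch of the split the issue appears to suggest, assuming User_forreversed maps each user id to that user's item list in reversed order (as its name implies). The helper name split_reverse is hypothetical; it mirrors the loop above but reads every slice, including the else branch, from the reversed sequences.

```python
def split_reverse(user_forreversed):
    """Hypothetical helper: leave-two-out split over the reversed sequences.

    Same logic as the loop in DataProcessing.py, except that the else
    branch slices user_forreversed instead of User.
    """
    train, valid, test = {}, {}, {}
    for user, seq in user_forreversed.items():
        if len(seq) < 3:
            # Too short to hold items out: everything goes to training.
            train[user] = list(seq)
            valid[user] = []
            test[user] = []
        else:
            train[user] = seq[:-2]
            valid[user] = [seq[-2]]
            test[user] = [seq[-1]]
    return train, valid, test


# User 1 interacted with items 1..5 in chronological order (stored reversed).
reversed_seqs = {1: [5, 4, 3, 2, 1], 2: [8, 7]}
train, valid, test = split_reverse(reversed_seqs)
print(train[1], valid[1], test[1])  # [5, 4, 3] [2] [1]
```

If train_reverse.txt were written from these dictionaries, it would differ from train.txt for every user with at least three interactions.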

A doubt about data augmentation

Thanks for your nice work, but the data augmentation step may have a leakage problem. More precisely, the pseudo-prior items may carry test information into training before inference.

def data_augment(model, dataset, args, sess, gen_num):

    [train, valid, test, original_train, usernum, itemnum] = copy.deepcopy(dataset)
    all_users = list(train.keys())

    cumulative_preds = defaultdict(list)
    for num_ind in range(gen_num):
        batch_seq = []
        batch_u = []
        batch_item_idx = []

        for u_ind, u in enumerate(all_users):
            u_data = train.get(u, []) + valid.get(u, []) + test.get(u, []) + cumulative_preds.get(u, [])

            if len(u_data) == 0 or len(u_data) >= args.M: continue

            seq = np.zeros([args.maxlen], dtype=np.int32)
            idx = args.maxlen - 1
            for i in reversed(u_data):
                if idx == -1: break
                seq[idx] = i
                idx -= 1
            rated = set(u_data)
            item_idx = list(set([i for i in range(itemnum)]) - rated) 

            batch_seq.append(seq)
            batch_item_idx.append(item_idx)
            batch_u.append(u)

The user data (i.e. ‘u_data = train.get(u, []) + valid.get(u, []) + test.get(u, []) + cumulative_preds.get(u, [])’) include the test data and are used to generate the prior items. The augmented data (i.e. prior items + train data + valid data) then train the left-to-right model in the fine-tuning stage, and that same model infers the recommendation results. So both the augmented data and the left-to-right model see the test data before inference, which leaks the test data.
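A minimal sketch of a generator input that avoids the leakage described above, assuming the same dictionaries as data_augment (user id to item list) and item ids 1..itemnum with 0 as padding, as in sample_function. The helper build_generation_input and its include_valid flag are hypothetical; it keeps the original concatenation order so that only the leakage question changes.

```python
import numpy as np


def build_generation_input(train, valid, cumulative_preds, u, maxlen, itemnum,
                           include_valid=False):
    """Hypothetical helper: one generator input row that excludes test items.

    Only training items (plus validation items when include_valid=True) and
    previously generated pseudo items are used, so the held-out test item
    never reaches the generator.
    """
    u_data = list(train.get(u, []))
    if include_valid:
        u_data += valid.get(u, [])
    u_data += cumulative_preds.get(u, [])

    # Left-pad with zeros and keep the most recent maxlen items,
    # following the loop in data_augment.
    seq = np.zeros([maxlen], dtype=np.int32)
    idx = maxlen - 1
    for i in reversed(u_data):
        if idx == -1:
            break
        seq[idx] = i
        idx -= 1

    # Candidates: item ids 1..itemnum (0 is padding) not already in the row.
    rated = set(u_data)
    item_idx = [i for i in range(1, itemnum + 1) if i not in rated]
    return seq, item_idx


train, valid, test = {1: [3, 4, 5]}, {1: [6]}, {1: [7]}
seq, cands = build_generation_input(train, valid, {}, 1, maxlen=5, itemnum=8)
print(seq.tolist())  # [0, 0, 3, 4, 5]; test item 7 is absent
```

Passing include_valid=True would still leak validation items, so it is only appropriate when the validation score is not used for model selection.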

Problems with item order in a sequence

When I pretrain your model to generate pseudo items, the item order in a sequence in your code seems inconsistent with the model proposed in your paper.
For example, assume a sequence with items [1,2,3,4,5,6]. When reverse-training the model, you should input [6,5,4,3,2] as the training user sequence and [5,4,3,2,1] as the positive user sequence. But when I examine the samples, I get very confusing sequences such as [5,4,2,1,3].
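The expected pair described above can be written as a small check. This is a sketch of the issue author's reading of the paper, not the repository's code; reverse_training_pair is a hypothetical name, and the input is assumed to be one user's sequence in chronological order.

```python
def reverse_training_pair(items):
    """Hypothetical helper: input/target pair for right-to-left training.

    The chronological sequence is reversed, so predicting the "next" item
    in the reversed list means predicting the item that came before it.
    """
    rev = items[::-1]      # [6, 5, 4, 3, 2, 1] for items 1..6
    inputs = rev[:-1]      # every item except the oldest one
    positives = rev[1:]    # each input's predecessor in time
    return inputs, positives


print(reverse_training_pair([1, 2, 3, 4, 5, 6]))
# ([6, 5, 4, 3, 2], [5, 4, 3, 2, 1])
```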

def sample_function(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED):
    def sample():
        user = np.random.randint(1, usernum + 1)
        while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)

        seq = np.zeros([maxlen], dtype=np.int32)
        pos = np.zeros([maxlen], dtype=np.int32)
        neg = np.zeros([maxlen], dtype=np.int32)
        nxt = user_train[user][-1]
        idx = maxlen - 1

        ts = set(user_train[user])
        for i in reversed(user_train[user][:-1]):
            seq[idx] = i
            pos[idx] = nxt
            if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts)
            nxt = i
            idx -= 1
            if idx == -1: break

        return (user, seq, pos, neg)

    np.random.seed(SEED)
    while True:
        one_batch = []
        for i in range(batch_size):
            one_batch.append(sample())
            print(one_batch[-1]) #something went wrong here
        result_queue.put(zip(*one_batch))
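One way to narrow the problem down is to run the seq/pos loop of sample() on a hand-made sequence. The sketch below copies that loop with negative sampling removed (so it does not need random_neq); seq_pos_for_user is a hypothetical name, and user_items stands in for user_train[user], which sample() only draws when it has more than one item. Because the loop never reorders items, a scrambled order like [5,4,2,1,3] would more likely come from the contents of user_train than from the sampler itself.

```python
import numpy as np


def seq_pos_for_user(user_items, maxlen):
    """Hypothetical stand-in for the body of sample(), without negatives.

    Positions are filled from the right; pos[k] is the item that follows
    seq[k] in user_items. Requires len(user_items) >= 2.
    """
    seq = np.zeros([maxlen], dtype=np.int32)
    pos = np.zeros([maxlen], dtype=np.int32)
    nxt = user_items[-1]
    idx = maxlen - 1
    for i in reversed(user_items[:-1]):
        seq[idx] = i
        pos[idx] = nxt
        nxt = i
        idx -= 1
        if idx == -1:
            break
    return seq, pos


seq, pos = seq_pos_for_user([6, 5, 4, 3, 2, 1], maxlen=8)
print(seq.tolist())  # [0, 0, 0, 6, 5, 4, 3, 2]
print(pos.tolist())  # [0, 0, 0, 5, 4, 3, 2, 1]
```

Fed an already-reversed list, the loop yields exactly the [6,5,4,3,2] / [5,4,3,2,1] pair from the issue, left-padded with zeros.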

What's more, the item order also looks suspicious when I use the trained model to reverse-predict pseudo items.
According to this code, the train, valid, test, and cumulatively predicted items are concatenated in that order and then fed into the Transformer model to generate the next pseudo item.
u_data = train.get(u, []) + valid.get(u, []) + test.get(u, []) + cumulative_preds.get(u, [])
But if you want to predict pseudo items, shouldn't you put the cumulative predictions at the beginning of the original sequence, so that the order becomes u_data = cumulative_preds.get(u, []) + train.get(u, []) + valid.get(u, []) + test.get(u, [])?
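Which concatenation is right depends on how the stored sequences are oriented. The sketch below makes that explicit under two stated assumptions that are not taken from the repository: original is chronological, and preds holds pseudo-prior items in generation order (preds[0] sits immediately before original[0], preds[1] before preds[0], and so on). The helper name reverse_model_input is hypothetical.

```python
def reverse_model_input(original, preds):
    """Hypothetical helper showing where generated items end up.

    Returns the chronological sequence and the reversed sequence that a
    right-to-left model would read, under the assumptions stated above.
    """
    # Chronological view: oldest pseudo item first, then the real items.
    chronological = preds[::-1] + original
    # The right-to-left model reads the chronological sequence reversed,
    # so its most recent input is the oldest pseudo item generated so far.
    model_input = chronological[::-1]
    return chronological, model_input


chron, model_in = reverse_model_input([1, 2, 3], [10, 11])
print(chron)     # [11, 10, 1, 2, 3]
print(model_in)  # [3, 2, 1, 10, 11]
```

Under these assumptions, prepending is correct for the chronological view, while the reversed view the model consumes equals original[::-1] + preds, i.e. the predictions are appended. Which form data_augment needs therefore hinges on whether train, valid, and test hold chronological or reversed sequences at that point.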
Looking forward to your reply.
Thanks.