Easy-to-use pytorch-based framework for RecSys models

License: MIT License

DeepRec-torch

DeepRec-torch is a framework based on PyTorch. The project is intended more as a tutorial for learning recommender system models than as a tool for direct use. Analyses of the implemented models are available on the author's GitHub Pages site, zeroized.github.io, or the corresponding blog, zeroized.xyz; both are written in Simplified Chinese.

Dependency

  • torch 1.2.0
  • numpy 1.17.3
  • pandas 0.25.3
  • scikit-learn 0.21.3
  • tensorboard 2.2.1 (For loss and metrics visualization)
  • lightgbm 2.3.0 (For building high-order feature interaction with GBDT)

Quick Start

1. Load and preprocess data

from example.loader.criteo_loader import load_data, missing_values_process
# Load 10,000 samples from the criteo-1m dataset
data = load_data('/path/to/data', n_samples=10000)
data = missing_values_process(data)
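The README does not show what missing_values_process does internally. As a rough illustration only, one common convention for Criteo-style data is to fill missing continuous values with 0 and missing categorical values with a placeholder token. The sketch below assumes that convention; fill_missing, the placeholder string, and the column names I1 and C1 are hypothetical and are not part of DeepRec-torch.

```python
from typing import Dict, List, Optional, Sequence

Row = Dict[str, Optional[object]]


def fill_missing(rows: List[Row],
                 continuous: Sequence[str],
                 categorical: Sequence[str],
                 cat_placeholder: str = "<missing>") -> List[Row]:
    """Return copies of the rows with None replaced by a per-type default.

    Assumed convention (not taken from the project): 0.0 for continuous
    columns, a placeholder token for categorical columns.
    """
    filled = []
    for row in rows:
        new_row = dict(row)  # copy so the input rows stay untouched
        for col in continuous:
            if new_row.get(col) is None:
                new_row[col] = 0.0
        for col in categorical:
            if new_row.get(col) is None:
                new_row[col] = cat_placeholder
        filled.append(new_row)
    return filled


# Toy rows standing in for a loaded data frame
rows = [
    {"I1": 3.0, "C1": "a9f"},
    {"I1": None, "C1": None},
]
clean = fill_missing(rows, continuous=["I1"], categorical=["C1"])
```

Treating "missing" as its own category lets an embedding-based model learn a vector for it instead of dropping the sample.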

2. Describe the columns with FeatureMeta

from feature.feature_meta import FeatureMeta
from example.loader.criteo_loader import continuous_columns, category_columns

feature_meta = FeatureMeta()
for column in continuous_columns:
    # By default, the continuous feature will not be discretized.
    feature_meta.add_continuous_feat(column)
for column in category_columns:
    feature_meta.add_categorical_feat(column)
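To make the role of such a metadata object concrete, here is a minimal sketch of a registry that hands out global feature indices. It is a hypothetical stand-in, not the project's FeatureMeta: the class name ToyFeatureMeta is invented, and unlike the call above, its add_categorical_feat takes the category values explicitly so the example stays self-contained.

```python
from typing import Dict, Iterable


class ToyFeatureMeta:
    """Hypothetical registry that assigns global indices to features."""

    def __init__(self) -> None:
        # column name -> its single global index
        self.continuous: Dict[str, int] = {}
        # column name -> {category value -> global index}
        self.categorical: Dict[str, Dict[str, int]] = {}
        self._next_index = 0

    def add_continuous_feat(self, name: str) -> None:
        # A continuous column that is not discretized needs one index only.
        self.continuous[name] = self._next_index
        self._next_index += 1

    def add_categorical_feat(self, name: str, categories: Iterable[str]) -> None:
        # Every distinct category value gets its own index.
        mapping: Dict[str, int] = {}
        for cat in categories:
            if cat not in mapping:
                mapping[cat] = self._next_index
                self._next_index += 1
        self.categorical[name] = mapping

    def get_num_feats(self) -> int:
        # Total number of indices handed out so far.
        return self._next_index


meta = ToyFeatureMeta()
meta.add_continuous_feat("I1")
meta.add_categorical_feat("C1", ["a", "b", "a"])
num_feats = meta.get_num_feats()
```

A count like num_feats is what an embedding-based model needs in order to size its embedding table.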

3. Transform the data into the required format (usually feat_index and feat_value)

from preprocess.feat_engineering import preprocess_features

x_idx, x_value = preprocess_features(feature_meta, data)

label = data.y
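The feat_index / feat_value layout is a common way to feed sparse CTR data to FM-style models: every sample becomes a fixed-length row with one slot per field, where the index identifies the feature and the value is the raw number for a continuous field or 1.0 for a categorical one. The sketch below illustrates that layout on toy rows. Whether preprocess_features produces exactly this is an assumption, and the encode helper is hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def encode(rows: List[Dict[str, object]],
           continuous: Sequence[str],
           categorical: Sequence[str]) -> Tuple[List[List[int]], List[List[float]]]:
    """Turn rows into parallel (index, value) matrices, one slot per field."""
    index_of: Dict[Tuple[str, object], int] = {}

    def lookup(key: Tuple[str, object]) -> int:
        # Assign the next free global index the first time a key is seen.
        if key not in index_of:
            index_of[key] = len(index_of)
        return index_of[key]

    x_idx: List[List[int]] = []
    x_value: List[List[float]] = []
    for row in rows:
        idx_row, val_row = [], []
        for col in continuous:
            # One shared index per continuous column; the value carries the number.
            idx_row.append(lookup((col, None)))
            val_row.append(float(row[col]))
        for col in categorical:
            # One index per (column, category) pair; the value is always 1.0.
            idx_row.append(lookup((col, row[col])))
            val_row.append(1.0)
        x_idx.append(idx_row)
        x_value.append(val_row)
    return x_idx, x_value


rows = [
    {"I1": 0.5, "C1": "a"},
    {"I1": 2.0, "C1": "b"},
    {"I1": 1.5, "C1": "a"},
]
x_idx, x_value = encode(rows, continuous=["I1"], categorical=["C1"])
# x_idx   -> [[0, 1], [0, 2], [0, 1]]
# x_value -> [[0.5, 1.0], [2.0, 1.0], [1.5, 1.0]]
```

Because every row has the same length, both matrices convert directly to dense tensors even though the underlying one-hot representation is sparse.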

4. Prepare for training

import torch
from torch.utils.data import TensorDataset

# Assign the device for training
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load data into the assigned device
X_idx_tensor_gpu = torch.LongTensor(x_idx).to(device)
X_value_tensor_gpu = torch.Tensor(x_value).to(device)
y_tensor_gpu = torch.Tensor(label).to(device)

# Note that a binary classifier requires labels with shape (n_samples, 1)
y_tensor_gpu = y_tensor_gpu.reshape(-1, 1)

# Form a dataset for torch's DataLoader
X_cuda = TensorDataset(X_idx_tensor_gpu, X_value_tensor_gpu, y_tensor_gpu)

5. Load a model and set its parameters (pre-defined models for the CTR prediction task are in the model.ctr package)

from model.ctr.fm import FM

# Create an FM model with embedding size of 5 and binary output, and load it into the assigned device
fm_model = FM(emb_dim=5, num_feats=feature_meta.get_num_feats(), out_type='binary').to(device)

# Assign an optimizer for the model
optimizer = torch.optim.Adam(fm_model.parameters(), lr=1e-4)
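For readers new to factorization machines, the following dependency-free sketch shows what an FM with binary output computes for one sample given in index/value form. It follows the standard formulation from the Factorization Machines paper (bias, linear term, and pairwise interactions computed with the linear-time identity) and is not the project's model.ctr.fm code. The function names and the random parameters are illustrative assumptions, and applying a sigmoid for the binary output is assumed because BCELoss expects probabilities.

```python
import math
import random
from typing import List, Sequence


def fm_logit(idx: Sequence[int], value: Sequence[float],
             w0: float, w: List[float], v: List[List[float]]) -> float:
    """FM score using 0.5 * sum_f[(sum_i v_if x_i)^2 - sum_i (v_if x_i)^2]."""
    linear = sum(w[i] * x for i, x in zip(idx, value))
    pairwise = 0.0
    for f in range(len(v[0])):
        s = sum(v[i][f] * x for i, x in zip(idx, value))
        s_sq = sum((v[i][f] * x) ** 2 for i, x in zip(idx, value))
        pairwise += 0.5 * (s * s - s_sq)
    return w0 + linear + pairwise


def fm_logit_naive(idx: Sequence[int], value: Sequence[float],
                   w0: float, w: List[float], v: List[List[float]]) -> float:
    """Same score with an explicit loop over all feature pairs (for checking)."""
    linear = sum(w[i] * x for i, x in zip(idx, value))
    pairwise = 0.0
    for a in range(len(idx)):
        for b in range(a + 1, len(idx)):
            dot = sum(p * q for p, q in zip(v[idx[a]], v[idx[b]]))
            pairwise += dot * value[a] * value[b]
    return w0 + linear + pairwise


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


# Illustrative parameters: 6 features, embedding size 5 (as in the call above)
rng = random.Random(0)
num_feats, emb_dim = 6, 5
w0 = 0.1
w = [rng.uniform(-0.1, 0.1) for _ in range(num_feats)]
v = [[rng.uniform(-0.1, 0.1) for _ in range(emb_dim)] for _ in range(num_feats)]

# One sample: active feature indices and their values
idx, value = [0, 2, 5], [0.5, 1.0, 1.0]
fast = fm_logit(idx, value, w0, w, v)
slow = fm_logit_naive(idx, value, w0, w, v)
prob = sigmoid(fast)  # predicted click probability for the sample
```

The two functions return the same score. The first needs time linear in the number of active features per embedding dimension, which is why FM implementations typically use it.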

6. Train the model with a trainer

import torch.nn as nn
from util.train import train_model_hold_out

# Train the model with hold-out model selection
train_model_hold_out(job_name='fm-binary-cls', device=device,
                     model=fm_model, dataset=X_cuda,
                     loss_func=nn.BCELoss(), optimizer=optimizer,
                     epochs=20, batch_size=256)
# Checkpoint saving is enabled by default in the trainers.
# For custom settings, create a dict as follows:
ckpt_settings = {'save_ckpt': True, 'ckpt_dir': 'path/to/ckpt_dir', 'ckpt_interval': 3}
# Then pass it as keyword arguments
train_model_hold_out(..., **ckpt_settings)
# Settings for the log file path, model saving path and tensorboard file path are similar; see util/train.py

The trainer is as much a log writer as a training loop: besides fitting the model, it handles log files, checkpoints, saved models and TensorBoard output.
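As a rough picture of the bookkeeping a hold-out trainer does around the optimization loop, the sketch below splits sample indices into training and validation sets and decides at which epochs a checkpoint would be written. It is an assumption-laden illustration, not the code in util.train: hold_out_split, should_save_ckpt, valid_ratio and seed are hypothetical names, and saving on every ckpt_interval-th epoch is only one plausible reading of that setting.

```python
import random
from typing import List, Tuple


def hold_out_split(n_samples: int, valid_ratio: float = 0.2,
                   seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle sample indices and reserve a fraction for validation."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_valid = int(n_samples * valid_ratio)
    return indices[n_valid:], indices[:n_valid]


def should_save_ckpt(epoch: int, save_ckpt: bool, ckpt_interval: int) -> bool:
    """Assumed rule: with 1-based epochs, save on every ckpt_interval-th epoch."""
    return save_ckpt and epoch % ckpt_interval == 0


train_idx, valid_idx = hold_out_split(10, valid_ratio=0.2)

# With epochs=20 and ckpt_interval=3, as in the settings above
saved_epochs = [e for e in range(1, 21) if should_save_ckpt(e, True, 3)]
# saved_epochs -> [3, 6, 9, 12, 15, 18]
```

Validation metrics computed on the held-out indices after each epoch are what a trainer of this kind would typically log and use to select a model.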

For more examples:

  • Model usage examples are available in the example.model package.

  • Data loader examples are available in the example.loader package.

  • Dataset EDA examples are available in the example.eda package as Jupyter notebooks.

Change Log

See changelog.md

Model list

Click Through Rate Prediction

Model | Paper
LR: Logistic Regression | Simple and Scalable Response Prediction for Display Advertising
FM: Factorization Machine | [ICDM 2010] Factorization Machines
GBDT+LR: Gradient Boosting Tree with Logistic Regression | Practical Lessons from Predicting Clicks on Ads at Facebook
FNN: Factorization-supported Neural Network | [ECIR 2016] Deep Learning over Multi-field Categorical Data: A Case Study on User Response Prediction
PNN: Product-based Neural Network | [ICDM 2016] Product-based Neural Networks for User Response Prediction
Wide and Deep | [DLRS 2016] Wide & Deep Learning for Recommender Systems
DeepFM | [IJCAI 2017] DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
AFM: Attentional Factorization Machine | [IJCAI 2017] Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks
NFM: Neural Factorization Machine | [SIGIR 2017] Neural Factorization Machines for Sparse Predictive Analytics
DCN: Deep & Cross Network | [ADKDD 2017] Deep & Cross Network for Ad Click Predictions
AutoInt | [CIKM 2019] AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks
FLEN | [AAAI 2020] FLEN: Leveraging Field for Scalable CTR Prediction

Sequential Recommendation

Model/keywords | Paper
DIN: Deep Interest Network | [KDD 2018] Deep Interest Network for Click-Through Rate Prediction
DIEN: Deep Interest Evolution Network | [AAAI 2019] Deep Interest Evolution Network for Click-Through Rate Prediction
