Comments (7)

msgi commented on August 26, 2024

Of course it can.

from tez.

hemanthh17 commented on August 26, 2024

Then there shouldn't be an assertion error, I suppose. If you have found anything wrong, please let me know.

msgi commented on August 26, 2024

Could you paste some of the code? Maybe there is a config problem, e.g. the device config may be wrong.

abhishekkrthakur commented on August 26, 2024

As @msgi mentioned, it should work on both CPU and GPU. Some code would be useful :)

hemanthh17 commented on August 26, 2024

OK, this is the code. I was following the recommender system video by @abhishekkrthakur:

import pandas as pd
import tez
import torch
from sklearn.model_selection import train_test_split
import torch.nn as nn
from sklearn import metrics,preprocessing
import numpy as np

class MovieDataset:
    def __init__(self,users,movies,ratings):
        self.users=users
        self.movies=movies
        self.ratings=ratings

    def __len__(self):
        return len(self.users)
    
    def __getitem__(self,item):
        user=self.users[item]
        movie=self.movies[item]
        rating=self.ratings[item]

        return {
            "users":torch.tensor(user,dtype=torch.long),
            "movies":torch.tensor(movie,dtype=torch.long),
            "ratings":torch.tensor(rating,dtype=torch.float)
        }


class RecSysModel(tez.Model):
    def __init__(self,num_users,num_movies):
        super().__init__()
        self.user_embed=nn.Embedding(num_users,32)
        self.movie_embed=nn.Embedding(num_movies,32)
        self.out=nn.Linear(64,1)
        self.step_scheduler_after='epoch'

    def fetch_optimizer(self):
        opt=torch.optim.Adam(self.parameters(),lr=1e-4)
        return opt
    
    def fetch_scheduler(self):
        sch= torch.optim.lr_scheduler.StepLR(self.optimizer,step_size=3,gamma=0.7)
        return sch

    def monitor_metrics(self,output,rating):
        output=output.detach().cpu().numpy()
        rating=rating.detach().cpu().numpy()
        return {
            'rmse':np.sqrt(metrics.mean_squared_error(rating,output))
        }
        
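    def forward(self,users,movies,ratings=None):
        user_embeds=self.user_embed(users)
        movie_embeds=self.movie_embed(movies)
        # concatenate the two 32-dim embeddings -> 64 features for the linear layer
        output= torch.cat([user_embeds,movie_embeds],dim=1)
        output=self.out(output)

        # at prediction time no ratings are passed, so skip loss and metrics
        if ratings is None:
            return output,None,None

        loss=nn.MSELoss()(output,ratings.view(-1,1))
        calc_metrics =self.monitor_metrics(output,ratings.view(-1,1))
        return output,loss,calc_metrics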


def train():
    df= pd.read_csv('train_v2.csv')
    lbl_user=preprocessing.LabelEncoder()
    lbl_movie=preprocessing.LabelEncoder()
    df.user=lbl_user.fit_transform(df.user.values)
    df.movie=lbl_movie.fit_transform(df.movie.values)

    df_train,df_valid=train_test_split(df,test_size=0.2,random_state=42,stratify=df.rating.values)
    train_dataset=MovieDataset(users=df_train.user.values,movies=df_train.movie.values,ratings=df_train.rating.values)
    valid_dataset=MovieDataset(users=df_valid.user.values,movies=df_valid.movie.values,ratings=df_valid.rating.values)
    model=RecSysModel(num_users=len(lbl_user.classes_), num_movies=len(lbl_movie.classes_))
    model.fit(
        train_dataset,valid_dataset,train_bs=1024,
        valid_bs=1024, fp16=True
    )

if __name__=="__main__":
    train()
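
The embedding tables are sized with `len(lbl_user.classes_)` and `len(lbl_movie.classes_)`. That works because `LabelEncoder` maps every raw ID to a contiguous index in `0..n-1`, which is exactly the valid index range of `nn.Embedding(n, 32)`. Below is a minimal pure-Python sketch of that mapping, for illustration only; `encode_ids` is a hypothetical helper, not part of scikit-learn or tez.

```python
def encode_ids(raw_ids):
    """Map arbitrary IDs to contiguous indices 0..n-1.

    Classes are sorted, mirroring how LabelEncoder orders classes_.
    """
    classes = sorted(set(raw_ids))                 # analogous to LabelEncoder.classes_
    index = {c: i for i, c in enumerate(classes)}  # raw ID -> embedding row number
    return [index[r] for r in raw_ids], classes


encoded, classes = encode_ids([1003, 17, 1003, 512])
# every encoded value is a valid row of an embedding table with len(classes) rows
assert all(0 <= e < len(classes) for e in encoded)
print(encoded, classes)  # [2, 0, 2, 1] [17, 512, 1003]
```

Since the encoders in `train()` are fit on the whole dataframe before `train_test_split`, IDs that only appear in the validation split still get their own embedding row.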

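The scheduler in `fetch_scheduler()` is `StepLR(step_size=3, gamma=0.7)`, and `step_scheduler_after='epoch'` asks tez to step it once per epoch. Assuming exactly one step per epoch, the learning rate follows the closed form `1e-4 * 0.7 ** (epochs_done // 3)`, which this small sketch computes (`steplr_lr` is a hypothetical helper, not a PyTorch function):

```python
def steplr_lr(base_lr, epochs_done, step_size=3, gamma=0.7):
    """Learning rate StepLR yields after `epochs_done` scheduler steps.

    The rate is multiplied by `gamma` once every `step_size` steps.
    """
    return base_lr * gamma ** (epochs_done // step_size)


for epoch in range(7):
    # epochs 0-2 use 1e-4, epochs 3-5 use ~7e-5, epoch 6 uses ~4.9e-5
    print(epoch, steplr_lr(1e-4, epoch))
```

Starting from lr=1e-4, the rate is roughly halved after 6 epochs (0.7 ** 2 = 0.49).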
abhishekkrthakur commented on August 26, 2024

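Use the following (fit() defaults to device="cuda", so on a CPU-only machine you have to pass device="cpu" explicitly):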

    model.fit(
        train_dataset,valid_dataset,train_bs=1024,
        valid_bs=1024, fp16=True, device="cpu"
    )
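
Instead of hard-coding `device="cpu"`, the script can pick its device at runtime so the same code runs on both CPU and GPU machines. A minimal sketch; `pick_device` is a hypothetical helper, and switching `fp16` off on CPU is an assumption on our part (mixed precision is mainly a GPU feature).

```python
def pick_device(cuda_available):
    """Return (device, fp16) settings for model.fit().

    Uses the GPU with mixed precision when CUDA is available; otherwise
    falls back to the CPU with fp16 disabled (assumption: fp16 is only
    worth enabling on a CUDA device).
    """
    if cuda_available:
        return "cuda", True
    return "cpu", False


# In train(), this would be called with torch.cuda.is_available(), e.g.:
#   device, fp16 = pick_device(torch.cuda.is_available())
#   model.fit(train_dataset, valid_dataset, train_bs=1024,
#             valid_bs=1024, fp16=fp16, device=device)
print(pick_device(False))  # ('cpu', False)
```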

hemanthh17 commented on August 26, 2024

Oh wow, I had tried the above snippet before and it did not work... Now it is functioning well 👍🏼
