import torch
import ELRec.Efficient_TT.efficient_tt as elrec
device_id = 2
device = torch.device('cuda:{}'.format(device_id))

if __name__ == '__main__':
    # TT-decomposed embedding configuration.
    tt_ranks = [8, 8]
    embedding_dim = 12
    num_embeddings = 1_000_000

    # 10 lookup indices, split into bags by `offsets`.
    indices = torch.LongTensor([0, 3, 23, 12, 3422, 75234, 2342, 12323, 342, 123]).to(device)
    # With batch_size=5, the 6 offsets describe 5 bags; the final entry is the
    # end-of-last-bag sentinel and must equal len(indices) == 10. The original
    # value 11 was out of range (off-by-one).
    # NOTE(review): assumes Eff_TTEmbedding uses EmbeddingBag-style offsets
    # with a terminal sentinel — confirm against the library's docs.
    offsets = torch.LongTensor([0, 3, 4, 5, 7, 10]).to(device)
    batch_size = 5

    el_emb = elrec.Eff_TTEmbedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        tt_ranks=tt_ranks,
        device=device_id,
        batch_size=batch_size,
    ).to(device)

    # Pooled bag embeddings: expected shape (batch_size, embedding_dim).
    outputs = el_emb(indices, offsets)
    print(outputs.shape)
# Expected result: torch.Size([5, 12]) (one pooled row per bag), but the
# actual output is torch.Size([10, 12]) — one row per index, i.e. the
# offsets appear to be ignored and no pooling is applied.