
hysonlab / ligand_generation


Target-aware Variational Auto-encoders for Ligand Generation with Multimodal Protein Representation Learning

Home Page: https://doi.org/10.1088/2632-2153/ad3ee4

Python 100.00%
generative-ai geometric-deep-learning graph-neural-networks multimodal-deep-learning protein-ligand variational-autoencoder

ligand_generation's People

Contributors

hytruongson, nnkhang19


ligand_generation's Issues

process_protein_3d.py

Hello, I downloaded the source code and extracted the dataset PDB files, but running the script produced the following error. How can I fix it?
0%| | 0/229 [00:02<?, ?it/s]
Traceback (most recent call last):
File "/home/lq/Ligand_Generation-main/process_protein_3d.py", line 30, in
torch.save(protein_graph, os.path.join(root, "res_graph", f"{folder}.pt"))
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/torch/serialization.py", line 422, in save
with _open_zipfile_writer(f) as opened_zipfile:
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/torch/serialization.py", line 309, in _open_zipfile_writer
return container(name_or_buffer)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/torch/serialization.py", line 287, in init
super(_open_zipfile_writer_file, self).init(torch._C.PyTorchFileWriter(str(name)))
RuntimeError: Parent directory data/kiba/prot_3d_for_KIBA/res_graph does not exist.

Process finished with exit code 1
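A likely workaround (an assumption, not a maintainer-confirmed fix): torch.save does not create parent directories, so the script can create data/kiba/prot_3d_for_KIBA/res_graph before saving. A minimal sketch around line 30 of process_protein_3d.py, using the names root, protein_graph, and folder that already appear in the traceback:

import os
import torch

# Create the output directory first; `root`, `protein_graph`, and `folder`
# are the variables already used by process_protein_3d.py (per the traceback).
out_dir = os.path.join(root, "res_graph")
os.makedirs(out_dir, exist_ok=True)  # no-op if the directory already exists
torch.save(protein_graph, os.path.join(out_dir, f"{folder}.pt"))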

davis

Hi, sorry to bother you. While running the program I found that the KIBA dataset works fine, but with the Davis dataset I get the following error. There is no ADCK entry among the original prot PDB files, yet Davis has 442 PDB files and should be a complete dataset. How should I deal with this?
Epoch 0: 75%|▊| 11344/15101 [31:02<10:16, 6.09it/s, loss=0.801, v_num=59, loss
Traceback (most recent call last):
File "/home/lq/Ligand_Generation-main/train_binding_affinity.py", line 109, in
trainer.fit(model, train_loader, val_loader)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 609, in fit
self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
self._run(model, ckpt_path=self.ckpt_path)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run
results = self._run_stage()
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage
self._run_train()
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1214, in _run_train
self.fit_loop.run()
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
self._outputs = self.epoch_loop.run(self._data_fetcher)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 187, in advance
batch = next(data_fetcher)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/utilities/fetching.py", line 184, in next
return self.fetching_function()
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/utilities/fetching.py", line 265, in fetching_function
self._fetch_next_batch(self.dataloader_iter)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/utilities/fetching.py", line 280, in _fetch_next_batch
batch = next(iterator)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/trainer/supporters.py", line 571, in next
return self.request_next_batch(self.loader_iters)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/pytorch_lightning/trainer/supporters.py", line 583, in request_next_batch
return apply_to_collection(loader_iters, Iterator, next)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/lightning_utilities/core/apply_func.py", line 51, in apply_to_collection
return function(data, *args, **kwargs)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 628, in next
data = self._next_data()
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 58, in
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/lq/Ligand_Generation-main/binding_data.py", line 33, in getitem
protein_graph = torch.load(protein_file, map_location="cpu")
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/torch/serialization.py", line 771, in load
with _open_file_like(f, 'rb') as opened_file:
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/torch/serialization.py", line 270, in _open_file_like
return _open_file(name_or_buffer, mode)
File "/root/anaconda3/envs/cold/lib/python3.7/site-packages/torch/serialization.py", line 251, in init
super(_open_file, self).init(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: 'data/davis/prot_3d_for_Davis/res_graph/ADCK.pdb.pt'
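One possible workaround (an assumption, not the authors' fix) is to filter the Davis samples down to those whose residue-graph .pt file actually exists before building the DataLoader, so __getitem__ never requests a missing file. A minimal sketch, with the pair format and naming assumed for illustration:

import os
import torch
from torch.utils.data import Dataset

GRAPH_DIR = "data/davis/prot_3d_for_Davis/res_graph"

class FilteredBindingDataset(Dataset):
    # Hypothetical wrapper: keeps only samples whose protein graph file
    # (e.g. ADCK.pdb.pt) was actually written to disk during preprocessing.
    def __init__(self, pairs):
        # `pairs` is assumed to be a list of (protein_key, ligand, label)
        self.pairs = [p for p in pairs
                      if os.path.exists(os.path.join(GRAPH_DIR, f"{p[0]}.pt"))]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        protein_key, ligand, label = self.pairs[idx]
        graph = torch.load(os.path.join(GRAPH_DIR, f"{protein_key}.pt"),
                           map_location="cpu")
        return graph, ligand, label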

affinity(model)

Hello, in the drug-target affinity prediction model you provide both TransformerGVP (nn.Module) and ThreeD (nn.Module): what is the difference between them? I see the GVP model used in the code, but it applies the transformer first and then does three rounds of message-passing aggregation, which does not seem to match the description in the paper.

import numpy as np
import torch
import torch.nn as nn
from . import GVP, GVPConvLayer, LayerNorm, tuple_index
from torch.distributions import Categorical
from torch_scatter import scatter_mean
from performer_pytorch import Performer, PerformerLM
import torch_geometric
from torch_geometric.utils import to_dense_batch
from linear_attention_transformer import LinearAttentionTransformerLM, LinformerSettings

class TransformerGVP(nn.Module):
    def __init__(self, node_in_dim, node_h_dim,
                 edge_in_dim, edge_h_dim,
                 seq_in=False, num_layers=3, drop_rate=0.1, attention_type="performer"):

        super().__init__()

        if seq_in:
            self.W_s = nn.Embedding(20, 64)
            node_in_dim = (node_in_dim[0], node_in_dim[1])

        self.W_v = nn.Sequential(
            # Layer-normalize the scalar channels and divide the vector channels
            # by their L2 norm; this keeps the scalar distribution stable while
            # normalizing the vector channels.
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None))
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
        )

        self.W_in = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (node_h_dim[0], 0), vector_gate=True)
        )

        self.layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
            for _ in range(num_layers))

        ns, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0)))

        self.attention_type = attention_type

        if attention_type == "performer":
            self.transformer = Performer(
                dim=ns,
                depth=2,
                heads=4,
                dim_head=ns // 4,
                causal=False
            )
        else:
            layer = nn.TransformerEncoderLayer(ns, 4, ns * 2, batch_first=True)
            self.transformer = nn.TransformerEncoder(layer, 2)

        self.final_readout = nn.Sequential(
            nn.Linear(ns + ns, 128), nn.ReLU(), nn.Linear(128, 128)
        )
        self.seq_transformer = LinearAttentionTransformerLM(
            num_tokens=20,
            dim=128,
            heads=8,
            depth=2,
            max_seq_len=640,
            return_embeddings=True,
            linformer_settings=LinformerSettings(256))

    def forward(self, h_V, edge_index, h_E, seq=None, batch=None):
        '''
        :param h_V: tuple (s, V) of node embeddings
        :param edge_index: `torch.Tensor` of shape [2, num_edges]
        :param h_E: tuple (s, V) of edge embeddings
        :param seq: if not `None`, int `torch.Tensor` of shape [num_nodes]
                    to be embedded and appended to `h_V`
        '''
        if seq is not None:
            # seq = self.W_s(seq)
            seq, mask = to_dense_batch(seq, batch, max_num_nodes=640)
            seq_emb = self.seq_transformer(seq)
            seq_rep = torch.sum(seq_emb, dim=1)

        h_V = self.W_v(h_V)  # h_V is a tuple of node scalar (s) and vector (V) features
        h_E = self.W_e(h_E)  # h_E is a tuple of edge scalar and vector features

        h_t = self.W_in(h_V)
        h_t, mask = to_dense_batch(h_t, batch)
        h_t = self.transformer(h_t)
        h_t = h_t[mask]

        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)

        node_rep = torch.cat([h_t, out], dim=-1)
        node_rep = self.final_readout(node_rep)

        geo_rep = scatter_mean(node_rep, batch, dim=0)
        # note: seq_rep is only defined when `seq` is given, so this forward
        # pass requires a sequence input
        return torch.cat([geo_rep, seq_rep], dim=-1)

class ThreeD_Protein_Model(nn.Module):
    def __init__(self, node_in_dim, node_h_dim,
                 edge_in_dim, edge_h_dim,
                 seq_in=False, num_layers=3, drop_rate=0.5, attention_type="performer"):

        super().__init__()

        if seq_in:
            self.W_s = nn.Embedding(20, 20)
            node_in_dim = (node_in_dim[0], node_in_dim[1])

        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None))
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
        )

        self.layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
            for _ in range(num_layers))

        ns, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0), vector_gate=True))

        self.attention_type = attention_type
        if attention_type == "performer":
            self.transformer = Performer(
                dim=ns,
                depth=2,
                heads=4,
                dim_head=ns // 4,
                causal=False
            )
        else:
            layer = nn.TransformerEncoderLayer(ns, 4, ns * 2, batch_first=True)
            self.transformer = nn.TransformerEncoder(layer, 2)

        self.seq_transformer = LinearAttentionTransformerLM(
            num_tokens=20,
            dim=128,
            heads=4,
            depth=4,
            dim_head=128 // 4,
            max_seq_len=640,
            return_embeddings=True,
            linformer_settings=LinformerSettings(256),
            ff_dropout=drop_rate,
            attn_dropout=drop_rate,
            attn_layer_dropout=drop_rate)

        self.skip_connection = nn.Sequential(nn.Linear(ns * 2, ns), nn.ReLU(), nn.Linear(ns, ns))

    def forward(self, h_V, edge_index, h_E, seq=None, batch=None):
        '''
        :param h_V: tuple (s, V) of node embeddings
        :param edge_index: `torch.Tensor` of shape [2, num_edges]
        :param h_E: tuple (s, V) of edge embeddings
        :param seq: if not `None`, int `torch.Tensor` of shape [num_nodes]
                    to be embedded and appended to `h_V`
        '''
        if seq is not None:
            seq, mask = to_dense_batch(seq, batch, max_num_nodes=640)
            seq_emb = self.seq_transformer(seq)
            seq_rep = torch.mean(seq_emb, dim=1)

        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)

        x, mask = to_dense_batch(out, batch)
        x_o = self.transformer(x)
        x = torch.cat([x, x_o], dim=-1)
        x = self.skip_connection(x)
        geo_rep = x.mean(dim=1)
        if seq is not None:
            z = torch.cat([geo_rep, seq_rep], dim=-1)
            return z
        return geo_rep
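For comparison, both classes take the same constructor arguments, so a minimal instantiation sketch shows how they are called; the dimensions below are illustrative assumptions, not the repository's actual configuration. The visible difference in the code above is that TransformerGVP runs its dense transformer on a projection (W_in) of the input node embeddings before the GVPConvLayer message passing, while ThreeD_Protein_Model applies the transformer to the GVP output and merges it back through a skip connection.

# Illustrative (scalar, vector) channel sizes; assumptions, not repo defaults.
node_in_dim = (6, 3)
node_h_dim = (128, 16)
edge_in_dim = (32, 1)
edge_h_dim = (32, 1)

model_a = TransformerGVP(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim,
                         seq_in=True, num_layers=3)
model_b = ThreeD_Protein_Model(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim,
                               seq_in=True, num_layers=3)

# Both are called as model(h_V, edge_index, h_E, seq=seq, batch=batch) and
# return a pooled protein representation (geometric + sequence parts).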

pre-trained model

I'm very interested in your work.
I downloaded the source code, but I cannot download the pre-trained unconditional VAE model.
The file does not exist. Could you please upload it again?

Alternatively, I would like to sample ligands for a specific PDB file; could you provide a trained model?
