Giter Site home page Giter Site logo

evolvegcn's Introduction

2024春社会计算实验

paper link: GAT

Dependency:

  • dgl
  • pandas
  • numpy
  • torch

Run

python train_model.py

Report

preprocess

根据观察, 我们将votes数据进行排序, 发现其和meetings id呈时间相关顺序, 故而根据每场meeting划分子图. 总共15个图, 前12个用于训练, 最后三个用于测试. 预测目标为投票倾向(N,V,NV)

Predictor

简单来说, 图上每个节点先经过embedding层创建embedding, 随后经过GATConv层汇聚邻居信息, 生成其表征representation. 对于每个预测对(member, bill), 将bill相关的sponsors和cosponsors取出, 与member的repr拼接后送入attention池化层, 得到池化后的隐层向量, 最后送入MLP进行分类.

Result

0.73 on testset.

预测算法

问题

给定一系列图 $G_i$, 这些图属于一个大图 $G$的子图, 有 $|G_i| = |G|, E_{G_i}\subseteq E_G$. 每个节点关联一个vote向量, 形如 $[B, ]$ ,表示对于每个bill的投票结果. $-inf$表示未投票. 每条边关联一个标量权重. 根据 $G_1, \cdots G_n$, 预测 $G_{n+1}...$的vote向量. 忽略 $-inf$项.

backbone算法

考虑到这是一个时序图 + 节点分类问题. 有两种解决方向: 一种是使用RNN等可动态更新权重的模块来演化网络权重, 另一种是使用一种归纳式(inductive)算法.

考虑到这些图都共享相同的节点(对应于现实中同一个人), 因此很自然的想法是对这些节点进行隐向量建模, 即, 每个节点使用一个长$\mathcal{E}$的向量进行表征. 在之前的图上训练的表征向量可以自然地迁移到下一个图上, 因为节点不变.

因此, 我们可以使用图注意力神经网络, GAT, 进行学习.

GAT使用标准的消息传递范式, 即下一层的节点表征由上一层的表征和邻居的表征变换而来:

$$ \begin{split}\begin{align} z_i^{(l)}&=W^{(l)}h_i^{(l)},& \\ e_{ij}^{(l)}&=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})),&\\ \alpha_{ij}^{(l)}&=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},&\\ h_i^{(l+1)}&=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),& \end{align}\end{split} $$

可以看到, 注意力机制被应用于最后的加权上.

特别地, 在边加权的情况下, 有公式

$$ h_{i}^{l+1} = \sigma\left(\sum_{j\in\mathcal{N}(i)} w_{ij}\alpha_{ij}^lz_j^{l}\right) $$

这里的权重是标准化之后的权重(我们使用了出度标准化).

在实现中, 我们使用了一个三层GATConv层网络, 输出一个恒等激活的隐藏层变量, 作为节点的最终表征.

池化 & 分类

具体到这个问题上来说, 一个bill的主要属性是其发起者和协助发起者. 因此, 一个自然的想法就是用发起者+协助发起者的表征作为bill的表征, 即

$$\mathcal{R}B = h{s}\oplus h_{c_1}\oplus h_{c_2}\cdots$$

为了综合节点属性,降低参数复杂度, 我们引入一个注意力池化块

$$ h_{iB}= \text{AttentionPooling}(h_i, \mathcal{R}_B) $$

该池化块有一组参数$W_q, W_k, W_v$, 计算方法为

$$ AP(Q, K) = Attention(W_qQ, W_kK, W_vK) $$

因为我们认为, 衡量一个节点是否会对特定提案投票, 取决于该节点的属性和该提案的契合程度, 所以Attention计算中, 节点属于Query, 而提案属性属于Key; Value则必须和Key一一对其, 所以只能从提案变换而来.

为了降低计算量, 我们观察到4位协助者就已经可以覆盖95%的提案, 所以至多计算4位协助.

分类器只是一个简单的单层MLP, 输入$h_{iB}$, 输出一个长$3$的logits.

效果

测试集报告: 准确率0.73

evolvegcn's People

Contributors

tomorrowdawn avatar maqy1995 avatar yyf233333 avatar ruibaixu avatar mottledpanpipe avatar wotaicaili avatar

Watchers

 avatar

evolvegcn's Issues

Graph Predictor Module

Component

这是一个专门为本次实验设计的预测器

以下内容已经过期. 需要重构.

预测任务

给定一组社交网络和提案信息, 预测每个成员对于提案的态度: YES, NO 或者 Not Vote

Excepted Input and output

Input: glist, a list of DGLGraph. Each graph contains two attributes: ndata['feat'] and edata['weight'].

A batched proposal dict

proposal: 一个字典,包含'id','sponsors'和'cosponsors'。
        分别代表提案id(一个[B, 1] tensor),发起者(一个[B, 1] tensor),协助发起者(一个[B, 4] tensor. 因为我们发现95%分位点就是4个协助发起者)
        如果不足4个协力, 则用-1填充。
        返回一个[B, N, 3]张量,表示每个节点yes/no/none的概率。

Output: A prediction tensor, shape of [B, N, 3]. Each representing yes/no/not vote possibility.

Mechanics

该Predictor由两个类构成, 分别是RepresentationModule和ClassifierModule.

RepresentationModule是一个EvolveGCN网络, 拆掉了分类头, 输出一个[N, H]的节点embedding.

ClassifierModule是一个注意力汇聚网络, 接收任意多个sponsor embeddings和待预测的成员的embedding, 生成一个[1, 3] tensor.

测试项目

从训练集中留出10%(相对于初始数据集)数据, 在训练后进行检测,

检测指标: Accuracy, Loss

测试项目

model相关

测试model/文件夹下面的类, AttentionPooling, GATPredictor, WeightedGAT.
检查给定预期输入后是否产生预期的输出, 维度是否正确.

方法: def get_embeddings(embeddings, *indices): 检查是否按照indices取出embedding, indices=-1时是否取出的是0 embedding.

create subgraph fix

def create_subgraph(self):
        subgraphs = []  # 存储每次会议的子图

        m_copy = self.meeting_bill_member_tensor_voteResult
        #mask = m_copy == -np.inf
        #m_copy[mask] = 0

        # 对每次会议创建子图
        # for meeting_index in range(self.meeting_num):
        for meeting_index in tqdm(range(self.meeting_num), desc="Generating subgraphs"):
            start_time = time.time()

            # 提取每次会议的投票tensor
            votes_slice = deepcopy(m_copy[meeting_index, :, :])
            votes_slice[votes_slice==-np.inf] = 0
            
                        
            # 计算相似度矩阵(其尺寸应该是 num_members x num_members)
            # similarity_matrix = squareform(pdist(votes_slice.T, lambda u, v: self.similarity_measure(u, v, meeting_index)))
            
            max_vote_count = self.meeting_bill_tensor_maxVoteCount[meeting_index, :] # 每个议案的最大投票次数
            normalized_max_vote_count = np.where(max_vote_count == 0, 1, max_vote_count) # 防止分母为零
            # normalized_max_vote_count 从(8849,) 重塑为 (8849,1)
            normalized_max_vote_count = normalized_max_vote_count[:, np.newaxis]

            # 这样就可以保持行对行的除法操作,每个议题对每个成员进行归一化,  标准化议案贡献,使得每个议案对于相似度贡献相同 (members x bills)
            normalized_votes = votes_slice / normalized_max_vote_count

            # 利用矩阵乘法计算成员间的相似度 (members x members)
            # np.dot 对二维数组执行矩阵乘法,对于一维数组执行内积
            similarity_matrix = np.dot(normalized_votes, normalized_votes.T)
            
            # 根据相似度矩阵创建图
            g = dgl.DGLGraph()
            
            # 使用triu_indices函数获取上三角矩阵中的索引
            src_list, dst_list = np.triu_indices(self.member_num, k=1)  # k=1表示不包括对角线
            src_list = src_list.astype(np.int64)
            dst_list = dst_list.astype(np.int64)
            # print("src_list.shape before: ", src_list.shape)
            # print("dst_list.shape before: ", dst_list.shape)

            # 从这些索引中得到所有的边的权重,并过滤掉无穷大的权重
            edge_weights = similarity_matrix[src_list, dst_list]
            # print("edge_weights.shape: ", edge_weights.shape)
            finite_edges = ~np.isinf(edge_weights)
            # print("finite_edges.shape: ", finite_edges.shape)
            
            # 只保留有限权重的边
            src_list = src_list[finite_edges]
            dst_list = dst_list[finite_edges]
            edge_weights = edge_weights[finite_edges]
            # print("src_list.shape: ", src_list.shape)
            # print("dst_list.shape: ", dst_list.shape)
            # print("edge_weights.shape: ", edge_weights.shape)

            # 如果有边可以添加,那么转换权重到适当的类型并添加这些边
            if len(src_list) > 0:
                # 将权重从NumPy数组转换为PyTorch张量
                edge_weights_tensor = torch.from_numpy(edge_weights).float()

                # 一次性添加所有的边和它们的权重
                g.add_edges(src_list, dst_list, {'weight': edge_weights_tensor})
            
            # # 将NumPy数组转换为PyTorch张量
            votes_slice = m_copy[meeting_index, :, :]
            vote_data = torch.from_numpy(votes_slice.T).float()

            # # 输出有多少个节点
            # print("g.num_nodes(): ", g.num_nodes())
            # # 输出有多少个边
            # print("g.num_edges(): ", g.num_edges())
            # # 输出vote_data 的形状
            # print("vote_data.shape: ", vote_data.shape)
            
            # # 把每个成员的投票数据设置为节点的'data'特征
            g.ndata['vote'] = vote_data

            # 为了把每个成员的投票数据设置为节点的'data'特征, 遍历已有节点并设置'data'特征
            
            subgraphs.append(g)
        
        self.subgraph = subgraphs
        pass

Futhermore the generator class should be able to switch to different similarity matrix computing methods.

RoadMap

目前需要完成的代码部分:

  1. load_dataset(可以直接使用pd.read_json)
  2. generator. 需要完成similarity的定义, __init__内部的准备和实时生成. 应当将数据集改造成按照时间排序的(u, v, weight)格式, 然后根据时间窗口获得edge mask, 通过subgraph构造图.
  3. model. Model需要将EvolveGCN的分类头拆掉, 暴露出内部的hidden states, 然后加上AttentionCls(因为提案者不固定)
  4. Inspecter: 给定输入图, 检测该图的特定结构. 用于测试 & 报告. 需要统计尽可能多的指标.

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.