论文精读-CMPNN-2020


Communicative Representation Learning on Attributed Molecular Graphs

  • 杂志: IJCAI-20
  • IF: None
  • 分区: None
  • github

Introduction

  1. 准确预测分子性质是药物研究中一个重要的课题,这可以为下游药物开发节省大量的资源和时间【Cherkasov et al.2014】。

  2. 性质预测的主要思路:先将药物分子embed为一个dense feature vector,然后再对这个feature进行预测。

  3. 早期定量进行embed,主要通过特征工程的手段(QSPR),如:

    • expert-crafted physicochemical descriptors 【Nettles et al.2007】
    • molecular fingerprints 【Rogers and Hahn2010】

    然而这些方法认为我们进行预测所需的信息都包括在这些特征中,这可能无法满足。

  4. 最近随着分子实验数据的增加,基于机器学习和深度学习的方法显示出明显的优势,其可以直接输入原始的、完整的分子表示(SMILES string或topological graph),从而学习到更加全面的信息。

    • 本质上,任何一个分子都可以描述为一个hydrogen-depleted topological graphs,其中节点是原子(atoms)、边是化学键(bonds)。
    • 【Duvenaud et al.2015】最早尝试使用GCN学习fingerprint。
    • 【Gilmer et al.2017】总结了众多可用的架构,称为message passing neural networks(MPNNs),其在化学性质预测上都有较好的表现,即:
      • message passing module:将每个节点的信息进行转换,并传送到它的邻接点上
      • updating module:对于每个节点,基于其接受的信息,更新该节点的特征

    但MPNNs类的方法没有考虑到edge上的信息。

  5. 为了解决上述edges的问题,【Yang et al.2019】提出了D-MPNN,其使用的是有向图

    主要贡献是避免了不必要的信息循环,从而得到没有冗余的信息。

    但其没有考虑到从bonds到atoms的消息传递,从而无法有效的捕获到更加全面的特征。

  6. 本研究提出了directed graph-based Communicative Message Passing Neural Networks(CMPNN):

    • 其可以同时更新node和edge的特征;
    • 为了避免冗余信息,精心设计了node interaction procedure;
    • 提出了messager booster来丰富message生成;
  1. Descriptor-based Representation

    最常用的descriptor就是分子指纹(chemical fingerprint),比如ECFP【Rogers and Hahn2010】。

    我的理解,类似one-hot向量,每个变量表示是否存在某种结构。

    一些基于fingerprint的DL方法,如【Dahl et al.2014】,表现出比传统ML方法更好的结果。

    但其带来了以下两个问题:

    • 数据过于稀疏,变量太多。
    • 解释性不好。
  2. Linear Notation-based Representation

    此类中最常见的是SMILES notation,其基于共同的chemical bonding rules将topological graph进行编码。

    一般处理这一类数据的NN都较为复杂:【[Zheng et al.2019b] [Jastrzebski et al.2016] [Zheng et al.2020] [Zheng et al.2019a]】

    但序列表示的可扩展性差、空间信息的丢失是无法解决的问题。

  3. Graph Structure-based Representation

    • 【Duvenaud et al.2015】最早,将molecular映射为neural fingerprint,并逐渐有一批改进,这些共同构成了MPNN类的方法,但这类方法只考虑到了atoms information。
    • 【Kearnes et al.2016】、【Gilmer et al.2017】和【Coley et al.2017】逐渐发展了将atoms和bonds都考虑其中的方法,这些方法因为是基于MPNN直接发展而来,所以在迭代过程中存在information redundancy。
    • DMPNN【Yang et al.2019】将graph看做是一个edge-oriented directed structure,从而避免了unnecessary loops,减轻了information redundancy。
    • 本研究基于DMPNN,将node-edge interaction module融入,从而充分利用atoms和bonds的信息。

Methods


Communicative Message Passing


以上是MPNN、DMPNN、CMPNN在消息传递中的不同,可以看到:

  • MPNN仅仅更新node特征;
  • DMPNN只更新edge特征;
  • CMPNN将node和edge特征都进行更新,并使用一个特殊的module进行操作。

算法如下所示:


可以看到:

  • 先更新node特征:node message \(\mathbf{m}^{k}(v)\)来自指向它的edge features \(\mathbf{h}^{k-1}(e_{u,v})\)的聚合(这是最主要的和MPNN的不同),然后使用communicate func聚合message和原始特征更新node特征。

  • edge特征:原始edge \(\mathbf{h}^0(e_{v,w})\)和传递过来的edge message \(\mathbf{m}^k(e_{v,w})\)的计算

    • 考虑原始edge特征\(\mathbf{h}^0(e_{v,w})\),相当于提供了一个skip connect【Yang et al.2019】
    • 非线性函数\(\sigma\)是relu。

    如果是DMPNN,这里的edge message是基于所有邻接边的特征\(\{\mathbf{h}^{k-1}(e_{u,v}),\forall u\in N(v)/w \}\)计算的。其没有考虑到其反向边特征\(\mathbf{h}^{k-1}(e_{w,v})\)

    也就是只考虑和目标edge有相同终点的edges。

    在CMPNN中,因为node features中已经编码上述邻接边的特征,所以我们可以直接利用它,然后再减去目标edge的反向特征\(\mathbf{h}^{k-1}(e_{w,v})\)来作为edge message。

经过\(K\)层的特征特征更新后,再利用edge feature得到node message,然后使用communicate func将node message、node feature和原始node feature进行整合,得到最终的node features。

最后,使用一个readout函数(这里是GRU【Cho et al. (2014)】),将所有node features变换为一个向量\(\mathbf{z}\)。之后再接fc进行分类即可。

Message Booster

这里介绍上面的aggregate func。

  • 【Hamilton et al.2017】提到了两种常用的aggregate func,即LSTM和max pooling。
  • 【Xu et al.2018】则认为sum aggregattion要优于max/mean pooling。
  • 【Yang et al.2019】也使用到了加和的方式来聚合邻接点特征形成message。

但以上的方式都没有考虑到edges间的关系,所以本研究提出了message booster的方式来进行aggregate得到message,如下图所示:


公式为:


这个公式可能写错了,实际上就是将所有node features,做了一次sum和max,然后乘起来即可。下面的公式可能更加正确: \[[\sum_{u\in N(v)}{\mathbf{h}^{k-1}(e_{u,v})}]\odot [\max_{u\in N(v)}{\mathbf{h}^{k-1}(e_{u,v})}]\]

实际上就是sum和max的结合。mean pooling是不可取的,max可能是最好的,但会丢失一些信息,sum可以进行一定的弥补。

Node-Edge Message Communication

这里介绍communicate func,在MPNN和DMPNN中称为updating step。

这里有3种备选方案:

  1. Inner product kernel:

    \[\mathbf{h}^k(v)=\mathbf{m}^k(v)\odot\mathbf{h}^{k-1}(v)\]

  2. Gated graph kernel【Li et al.2015】:

    \[\mathbf{h}^k(v)=GRU(\mathbf{h}^{k-1}(v), \mathbf{m}^k(v))\]

    GRU有更强的拟合能力,但其不是symmetric的,这和graph的性质略有不符。

  3. Multilayer Perception:

    \[\mathbf{h}^k(v)=\sigma(W \cdot Concat(\mathbf{h}^{k-1}(v), \mathbf{m}^k(v)))\]

Results

实验设置


一共有6个benchmark datasets:

  • BBBP,血脑屏障穿透数据集,记录的是化合物的穿透性。
  • Tox21,预测12个和药物毒性有关的靶点。
  • Sider,已上市药物的27个器官的毒性反应。
  • ClinTox,包括了FDA审批通过的药物和因为药物毒性没有通过的药物。
  • ESOL,化合物的水溶性。
  • FreeSolv,水化自由能。

其中Tox21、Sider、ClinTox是多任务学习。

比较的方法有9种(+CMPNN):

  • binary Morgan fingerprints + RF
  • binary Morgan fingerprints + FNN(MLP)
  • GCN【Kipf and Welling2016】
  • Weave【Kearnes et al.2016】
  • N-Gram【Liu et al.2019】(unsupervised)
  • RGAT【Ryu et al.2018】
  • MPNN
  • DMPNN

使用5-CV,分割策略有random和scaffold-based两种。评价指标是AUC和RMSE。使用的节点特征使用开源库RDKit计算得到。超参数使用Bayesian Optimization搜索得到。

scaffold-based split是一种更加接近现实、更加有挑战性的数据集分割方式,可以通过python module RDKit实现。

实验结果


可以看到,CMPNN取得了非常好的结果。

因为Tox21的不平衡性,这里使用scaffold-based split进行进一步的验证:


消融实验


证明了message booster和communicate func的重要性,也和attention-based booster进行了比较。

特征可视化


分子特性通常和其结构密切相关,为了证明CMPNN能够学习到更好的特征,这里对学习到的特征进行可视化。

  • 匹配PAINS database(包含有400个毒性子结构),从Tox21中选择出100个毒性分子,然后对应的选择出100个非毒性分子。
  • 将每个分子(graph)中的atoms(node)对应的标记成toxic和non-toxic。
  • 利用上述方法学习这些atoms的特征,然后使用t-SNE进行可视化。

从上图可以看出,CMPNN有更好的效果。

Conclusion

本研究通过增强bonds和atoms间的信息交互,从而提高了预测性能。


Questions


文章作者: Luyiyun
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Luyiyun !
评论
评论
  目录