From e17add560279b87bc5c5cd49679dbdd854cbfc0f Mon Sep 17 00:00:00 2001
From: "xiang song(charlie.song)"
Date: Wed, 28 Aug 2019 13:19:17 +0800
Subject: [PATCH] [NN] Add MXNet impl for TAGCN module. (#799)

* upd

* fix edgebatch edges

* add test

* trigger

* Update README.md for the PyTorch PinSage example.

  Add a note that the PinSage model example under example/pytorch/recommendation
  only works with Python 3.6+, as its dataset loader depends on the stanfordnlp
  package, which works only with Python 3.6+.

* Provide a frame-agnostic API to test nn modules on both CPU and CUDA.

  1. make dgl.nn.xxx frame agnostic
  2. make tests.backend include dgl.nn modules
  3. modify test_edge_softmax in tests/mxnet/test_nn.py and
     tests/pytorch/test_nn.py to work on both CPU and GPU

* Fix style

* Delete unused code

* Make agnostic tests depend only on tests/backend

  1. remove all agnostic-related code from dgl.nn
  2. make test_graph_conv agnostic to CPU/GPU

* Fix code style

* fix

* doc

* Make all test code under tests/mxnet/test_nn.py and tests/pytorch/test_nn.py
  work on both CPU and GPU.

* Fix syntax

* Remove rand

* Add TAGCN nn.module and example

* Now tagcn can run on CPU.

* Add unit test for TGConv

* Fix style

* For the pubmed dataset, --lr=0.005 achieves better accuracy

* Fix style

* Fix some descriptions

* trigger

* Fix doc

* Add nn.TGConv and example

* Fix bug

* Update accuracy numbers in the mxnet tagcn test.

* Fix some comments and code

* Delete useless code

* Fix naming

* Fix bug

* Fix bug

* Add test code for mxnet TAGConv

* Update some docs

* Fix some code

* Update dgl.nn.mxnet docs

* Update weight init

* Fix
---
 docs/source/api/python/nn.mxnet.rst |   4 +
 examples/mxnet/tagcn/README.md      |  25 ++++
 examples/mxnet/tagcn/tagcn.py       |  39 +++++++
 examples/mxnet/tagcn/train.py       | 126 ++++++++++++++++++++++
 examples/pytorch/tagcn/tagcn.py     |   7 +-
 python/dgl/nn/mxnet/conv.py         | 101 +++++++++++++++-
 python/dgl/nn/pytorch/conv.py       |   3 +-
 tests/mxnet/test_nn.py              |  40 +++++++
 tests/pytorch/test_nn.py            |   1 +
 9 files changed, 337 insertions(+), 9 deletions(-)
 create mode 100644 examples/mxnet/tagcn/README.md
 create mode 100644 examples/mxnet/tagcn/tagcn.py
 create mode 100644 examples/mxnet/tagcn/train.py

diff --git a/docs/source/api/python/nn.mxnet.rst b/docs/source/api/python/nn.mxnet.rst
index fba29c80c025..8f9727294489 100644
--- a/docs/source/api/python/nn.mxnet.rst
+++ b/docs/source/api/python/nn.mxnet.rst
@@ -16,6 +16,10 @@ dgl.nn.mxnet.conv
     :members: forward
     :show-inheritance:

+.. autoclass:: dgl.nn.mxnet.conv.TAGConv
+    :members: forward
+    :show-inheritance:
+
 dgl.nn.mxnet.glob
 -----------------
diff --git a/examples/mxnet/tagcn/README.md b/examples/mxnet/tagcn/README.md
new file mode 100644
index 000000000000..9eaac4df587f
--- /dev/null
+++ b/examples/mxnet/tagcn/README.md
@@ -0,0 +1,25 @@
+Topology Adaptive Graph Convolutional Networks (TAGCN)
+============
+
+- Paper link: [https://arxiv.org/abs/1710.10370](https://arxiv.org/abs/1710.10370)
+
+Dependencies
+------------
+- MXNet nightly build
+- requests
+
+```bash
+pip install mxnet --pre
+pip install requests
+```
+
+Results
+-------
+Run with the following (available datasets: "cora", "citeseer", "pubmed")
+```bash
+DGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0 --self-loop
+```
+
+* cora: ~0.820 (paper: 0.833)
+* citeseer: ~0.702 (paper: 0.714)
+* pubmed: ~0.798 (paper: 0.811)
\ No newline at end of file
diff --git a/examples/mxnet/tagcn/tagcn.py b/examples/mxnet/tagcn/tagcn.py
new file mode 100644
index 000000000000..c88970bd05cd
--- /dev/null
+++ b/examples/mxnet/tagcn/tagcn.py
@@ -0,0 +1,39 @@
+"""TAGCN using DGL nn package
+
+References:
+- Topology Adaptive Graph Convolutional Networks
+- Paper: https://arxiv.org/abs/1710.10370
+"""
+import mxnet as mx
+from mxnet import gluon
+import dgl
+from dgl.nn.mxnet import TAGConv
+
+class TAGCN(gluon.Block):
+    def __init__(self,
+                 g,
+                 in_feats,
+                 n_hidden,
+                 n_classes,
+                 n_layers,
+                 activation,
+                 dropout):
+        super(TAGCN, self).__init__()
+        self.g = g
+        self.layers = gluon.nn.Sequential()
+        # input layer
+        self.layers.add(TAGConv(in_feats, n_hidden, activation=activation))
+        # hidden layers
+        for i in range(n_layers - 1):
+            self.layers.add(TAGConv(n_hidden, n_hidden, activation=activation))
+        # output layer
+        self.layers.add(TAGConv(n_hidden, n_classes))  # activation=None
+        self.dropout = gluon.nn.Dropout(rate=dropout)
+
+    def forward(self, features):
+        h = features
+        for i, layer in enumerate(self.layers):
+            if i != 0:
+                h = self.dropout(h)
+            h = layer(self.g, h)
+        return h
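Reviewer note (not part of the patch): a minimal smoke test of the example model above. It assumes the DGL/MXNet APIs this PR targets; the graph size and feature dimension are made up for illustration.

```python
import mxnet as mx
import networkx as nx
import dgl
from tagcn import TAGCN  # the example module added above

# tiny 3-node path graph with hypothetical 5-dim node features
g = dgl.DGLGraph(nx.path_graph(3))
model = TAGCN(g, in_feats=5, n_hidden=16, n_classes=2,
              n_layers=1, activation=mx.nd.relu, dropout=0.5)
model.initialize(ctx=mx.cpu(0))

logits = model(mx.nd.ones((3, 5)))
print(logits.shape)  # expected: (3, 2)
```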
diff --git a/examples/mxnet/tagcn/train.py b/examples/mxnet/tagcn/train.py
new file mode 100644
index 000000000000..85295fbf92bb
--- /dev/null
+++ b/examples/mxnet/tagcn/train.py
@@ -0,0 +1,126 @@
+import argparse, time
+import numpy as np
+import mxnet as mx
+from mxnet import gluon
+
+from dgl import DGLGraph
+from dgl.data import register_data_args, load_data
+
+from tagcn import TAGCN
+
+def evaluate(model, features, labels, mask):
+    pred = model(features).argmax(axis=1)
+    accuracy = ((pred == labels) * mask).sum() / mask.sum()
+    return accuracy.asscalar()
+
+def main(args):
+    # load and preprocess dataset
+    data = load_data(args)
+    features = mx.nd.array(data.features)
+    labels = mx.nd.array(data.labels)
+    train_mask = mx.nd.array(data.train_mask)
+    val_mask = mx.nd.array(data.val_mask)
+    test_mask = mx.nd.array(data.test_mask)
+    in_feats = features.shape[1]
+    n_classes = data.num_labels
+    n_edges = data.graph.number_of_edges()
+    print("""----Data statistics------
+      #Edges %d
+      #Classes %d
+      #Train samples %d
+      #Val samples %d
+      #Test samples %d""" %
+          (n_edges, n_classes,
+           train_mask.sum().asscalar(),
+           val_mask.sum().asscalar(),
+           test_mask.sum().asscalar()))
+
+    if args.gpu < 0:
+        cuda = False
+        ctx = mx.cpu(0)
+    else:
+        cuda = True
+        ctx = mx.gpu(args.gpu)
+
+    features = features.as_in_context(ctx)
+    labels = labels.as_in_context(ctx)
+    train_mask = train_mask.as_in_context(ctx)
+    val_mask = val_mask.as_in_context(ctx)
+    test_mask = test_mask.as_in_context(ctx)
+
+    # graph preprocess and calculate normalization factor
+    g = data.graph
+    # add self loop
+    if args.self_loop:
+        g.remove_edges_from(g.selfloop_edges())
+        g.add_edges_from(zip(g.nodes(), g.nodes()))
+    g = DGLGraph(g)
+
+    # create TAGCN model
+    model = TAGCN(g,
+                  in_feats,
+                  args.n_hidden,
+                  n_classes,
+                  args.n_layers,
+                  mx.nd.relu,
+                  args.dropout)
+
+    model.initialize(ctx=ctx)
+    n_train_samples = train_mask.sum().asscalar()
+    loss_fcn = gluon.loss.SoftmaxCELoss()
+
+    # use optimizer
+    print(model.collect_params())
+    trainer = gluon.Trainer(model.collect_params(), 'adam',
+                            {'learning_rate': args.lr, 'wd': args.weight_decay})
+
+    # initialize graph
+    dur = []
+    for epoch in range(args.n_epochs):
+        if epoch >= 3:
+            t0 = time.time()
+        # forward
+        with mx.autograd.record():
+            pred = model(features)
+            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))
+            loss = loss.sum() / n_train_samples
+
+        loss.backward()
+        trainer.step(batch_size=1)
+
+        if epoch >= 3:
+            loss.asscalar()
+            dur.append(time.time() - t0)
+            acc = evaluate(model, features, labels, val_mask)
+            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
+                  "ETputs(KTEPS) {:.2f}".format(
+                      epoch, np.mean(dur), loss.asscalar(), acc,
+                      n_edges / np.mean(dur) / 1000))
+
+    print()
+    # evaluate on the test split (not the validation split used during training)
+    acc = evaluate(model, features, labels, test_mask)
+    print("Test accuracy {:.2%}".format(acc))
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='TAGCN')
+    register_data_args(parser)
+    parser.add_argument("--dropout", type=float, default=0.5,
+                        help="dropout probability")
+    parser.add_argument("--gpu", type=int, default=-1,
+                        help="gpu")
+    parser.add_argument("--lr", type=float, default=1e-2,
+                        help="learning rate")
+    parser.add_argument("--n-epochs", type=int, default=200,
+                        help="number of training epochs")
+    parser.add_argument("--n-hidden", type=int, default=16,
+                        help="number of hidden tagcn units")
+    parser.add_argument("--n-layers", type=int, default=1,
+                        help="number of hidden tagcn layers")
+    parser.add_argument("--weight-decay", type=float, default=5e-4,
+                        help="weight for L2 loss")
+    parser.add_argument("--self-loop", action='store_true',
+                        help="graph self-loop (default=False)")
+    parser.set_defaults(self_loop=False)
+    args = parser.parse_args()
+    print(args)
+
+    main(args)
diff --git a/examples/pytorch/tagcn/tagcn.py b/examples/pytorch/tagcn/tagcn.py
index 804b91daee1e..e35ef12697c7 100644
--- a/examples/pytorch/tagcn/tagcn.py
+++ b/examples/pytorch/tagcn/tagcn.py
@@ -1,9 +1,8 @@
-"""GCN using DGL nn package
+"""TAGCN using DGL nn package

 References:
-- Semi-Supervised Classification with Graph Convolutional Networks
-- Paper: https://arxiv.org/abs/1609.02907
-- Code: https://github.com/tkipf/gcn
+- Topology Adaptive Graph Convolutional Networks
+- Paper: https://arxiv.org/abs/1710.10370
 """
 import torch
 import torch.nn as nn
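Reviewer note (not part of the patch): the training loop above masks the loss via `SoftmaxCELoss`'s `sample_weight` argument instead of slicing out the training nodes. A small standalone sketch of that pattern, with made-up values:

```python
import mxnet as mx
from mxnet import gluon

loss_fcn = gluon.loss.SoftmaxCELoss()

pred = mx.nd.random.uniform(shape=(4, 3))  # 4 nodes, 3 classes
labels = mx.nd.array([0, 2, 1, 1])
train_mask = mx.nd.array([1, 1, 0, 0])     # only the first two nodes train

# sample_weight zeroes out non-training nodes; normalize by their count,
# exactly as train.py does with n_train_samples
loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))
loss = loss.sum() / train_mask.sum().asscalar()
print(loss.asscalar())
```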
diff --git a/python/dgl/nn/mxnet/conv.py b/python/dgl/nn/mxnet/conv.py
index f792af5331f5..15293d5fe068 100644
--- a/python/dgl/nn/mxnet/conv.py
+++ b/python/dgl/nn/mxnet/conv.py
@@ -9,7 +9,7 @@
 from . import utils
 from ... import function as fn

-__all__ = ['GraphConv', 'RelGraphConv']
+__all__ = ['GraphConv', 'TAGConv', 'RelGraphConv']

 class GraphConv(gluon.Block):
     r"""Apply graph convolution over an input signal.
@@ -74,7 +74,7 @@ def __init__(self,
         with self.name_scope():
             self.weight = self.params.get('weight', shape=(in_feats, out_feats),
-                                          init=mx.init.Xavier())
+                                          init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
             if bias:
                 self.bias = self.params.get('bias', shape=(out_feats,),
                                             init=mx.init.Zero())
@@ -108,7 +108,7 @@ def forward(self, graph, feat):
         graph = graph.local_var()
         if self._norm:
             degs = graph.in_degrees().astype('float32')
-            norm = mx.nd.power(degs, -0.5)
+            norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
             shp = norm.shape + (1,) * (feat.ndim - 1)
             norm = norm.reshape(shp).as_in_context(feat.context)
             feat = feat * norm
@@ -147,6 +147,101 @@ def __repr__(self):
         summary += '\n)'
         return summary

+class TAGConv(gluon.Block):
+    r"""Apply a Topology Adaptive Graph Convolutional (TAGCN) layer over an input signal.
+
+    .. math::
+        \mathbf{X}^{\prime} = \sum_{k=0}^{K} \left(\mathbf{D}^{-1/2} \mathbf{A}
+        \mathbf{D}^{-1/2}\right)^{k} \mathbf{X} \mathbf{\Theta}_{k},
+
+    where :math:`\mathbf{A}` denotes the adjacency matrix and
+    :math:`D_{ii} = \sum_{j} A_{ij}` its diagonal degree matrix.
+
+    Parameters
+    ----------
+    in_feats : int
+        Number of input features.
+    out_feats : int
+        Number of output features.
+    k : int, optional
+        Number of hops :math:`k`. (default: 2)
+    bias : bool, optional
+        If True, adds a learnable bias to the output. Default: ``True``.
+    activation : callable activation function/layer or None, optional
+        If not None, applies an activation function to the updated node features.
+        Default: ``None``.
+
+    Attributes
+    ----------
+    lin : mxnet.gluon.parameter.Parameter
+        The learnable weight tensor.
+    h_bias : mxnet.gluon.parameter.Parameter
+        The learnable bias tensor.
+    """
+    def __init__(self,
+                 in_feats,
+                 out_feats,
+                 k=2,
+                 bias=True,
+                 activation=None):
+        super(TAGConv, self).__init__()
+        self.out_feats = out_feats
+        self.k = k
+        self.bias = bias
+        self.activation = activation
+        self.in_feats = in_feats
+
+        # one weight block per hop, stacked: shape (in_feats * (k + 1), out_feats)
+        self.lin = self.params.get(
+            'weight', shape=(self.in_feats * (self.k + 1), self.out_feats),
+            init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
+        if self.bias:
+            self.h_bias = self.params.get('bias', shape=(out_feats,),
+                                          init=mx.init.Zero())
+
+    def forward(self, graph, feat):
+        r"""Compute topology adaptive graph convolution.
+
+        Parameters
+        ----------
+        graph : DGLGraph
+            The graph.
+        feat : mxnet.NDArray
+            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
+            is the size of the input feature and :math:`N` is the number of nodes.
+
+        Returns
+        -------
+        mxnet.NDArray
+            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
+            is the size of the output feature.
+        """
+        graph = graph.local_var()
+
+        degs = graph.in_degrees().astype('float32')
+        # clip zero degrees to 1 to avoid inf when computing D^{-1/2}
+        norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
+        shp = norm.shape + (1,) * (feat.ndim - 1)
+        norm = norm.reshape(shp).as_in_context(feat.context)
+
+        rst = feat
+        for _ in range(self.k):
+            # one hop of D^{-1/2} A D^{-1/2}; concat accumulates [X, AX, ..., A^k X]
+            rst = rst * norm
+            graph.ndata['h'] = rst
+
+            graph.update_all(fn.copy_src(src='h', out='m'),
+                             fn.sum(msg='m', out='h'))
+            rst = graph.ndata['h']
+            rst = rst * norm
+            feat = mx.nd.concat(feat, rst, dim=-1)
+
+        rst = mx.nd.dot(feat, self.lin.data(feat.context))
+        if self.bias:
+            rst = rst + self.h_bias.data(rst.context)
+
+        if self.activation is not None:
+            rst = self.activation(rst)
+
+        return rst
+
 class RelGraphConv(gluon.Block):
     r"""Relational graph convolution layer.
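Reviewer note (not part of the patch): a quick way to exercise the new module in isolation, mirroring the unit test below; the shapes are illustrative.

```python
import mxnet as mx
import networkx as nx
import dgl
from dgl.nn.mxnet import TAGConv

g = dgl.DGLGraph(nx.path_graph(3))
conv = TAGConv(5, 2, k=2)  # aggregates [X, AX, A^2 X] before the linear layer
conv.initialize(ctx=mx.cpu(0))

h = conv(g, mx.nd.ones((3, 5)))
print(h.shape)  # expected: (3, 2)
```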
diff --git a/python/dgl/nn/pytorch/conv.py b/python/dgl/nn/pytorch/conv.py
index 0d335cce529f..6a3f9bf9c575 100644
--- a/python/dgl/nn/pytorch/conv.py
+++ b/python/dgl/nn/pytorch/conv.py
@@ -171,7 +171,6 @@ def extra_repr(self):
         summary += ', activation={_activation}'
         return summary.format(**self.__dict__)

-
 class GATConv(nn.Module):
     r"""Apply `Graph Attention Network <https://arxiv.org/abs/1710.10903>`__
     over an input signal.
@@ -305,7 +304,7 @@ class TAGConv(nn.Module):
     out_feats : int
         Output feature size.
     k: int, optional
-        Number of hops :math: `k`. (default: 3)
+        Number of hops :math:`k`. (default: 2)
     bias: bool, optional
         If True, adds a learnable bias to the output. Default: ``True``.
     activation: callable activation function/layer or None, optional
diff --git a/tests/mxnet/test_nn.py b/tests/mxnet/test_nn.py
index bd2b088156bd..5292bfbab959 100644
--- a/tests/mxnet/test_nn.py
+++ b/tests/mxnet/test_nn.py
@@ -72,6 +72,46 @@ def test_graph_conv():
     assert "h" in g.ndata
     check_close(g.ndata['h'], 2 * F.ones((3, 1)))

+def _S2AXWb(A, N, X, W, b):
+    X1 = X * N
+    X1 = mx.nd.dot(A, X1.reshape(X1.shape[0], -1))
+    X1 = X1 * N
+    X2 = X1 * N
+    X2 = mx.nd.dot(A, X2.reshape(X2.shape[0], -1))
+    X2 = X2 * N
+    X = mx.nd.concat(X, X1, X2, dim=-1)
+    Y = mx.nd.dot(X, W)
+
+    return Y + b
+
+def test_tagconv():
+    g = dgl.DGLGraph(nx.path_graph(3))
+    ctx = F.ctx()
+    adj = g.adjacency_matrix(ctx=ctx)
+    norm = mx.nd.power(g.in_degrees().astype('float32'), -0.5)
+
+    conv = nn.TAGConv(5, 2, bias=True)
+    conv.initialize(ctx=ctx)
+    print(conv)
+
+    # test#1: basic
+    h0 = F.ones((3, 5))
+    h1 = conv(g, h0)
+    assert len(g.ndata) == 0
+    assert len(g.edata) == 0
+    shp = norm.shape + (1,) * (h0.ndim - 1)
+    norm = norm.reshape(shp).as_in_context(h0.context)
+
+    assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.data(ctx), conv.h_bias.data(ctx)))
+
+    conv = nn.TAGConv(5, 2)
+    conv.initialize(ctx=ctx)
+
+    # test#2: basic
+    h0 = F.ones((3, 5))
+    h1 = conv(g, h0)
+    assert h1.shape[-1] == 2
+
 def test_set2set():
     g = dgl.DGLGraph(nx.path_graph(10))
     ctx = F.ctx()
diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py
index 7c9ae57a4e7b..3f61782aa02e 100644
--- a/tests/pytorch/test_nn.py
+++ b/tests/pytorch/test_nn.py
@@ -105,6 +105,7 @@ def test_tagconv():
     conv = nn.TAGConv(5, 2)
     if F.gpu_ctx():
         conv = conv.to(ctx)
+    # test#2: basic
     h0 = F.ones((3, 5))
     h1 = conv(g, h0)
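Reviewer note (not part of the patch): for intuition, the propagation loop in `TAGConv.forward` matches the following dense computation, a generalization of the `_S2AXWb` helper above to arbitrary `k`. The function name `dense_tagconv` is chosen here for illustration, and elementwise broadcasting between `(N, 1)` and `(N, D)` NDArrays is assumed, as the patch itself relies on it (`feat * norm`).

```python
import mxnet as mx

def dense_tagconv(A, X, W, b, k=2):
    """Dense reference: concat([X, AX, ..., A^k X]) . W + b,
    where A here denotes the normalized adjacency D^{-1/2} A D^{-1/2}."""
    deg = A.sum(axis=1)
    # same zero-degree clipping as the module
    norm = mx.nd.power(mx.nd.clip(deg, a_min=1, a_max=float("inf")), -0.5)
    hops = [X]
    h = X
    for _ in range(k):
        # one hop: D^{-1/2} A D^{-1/2} h
        h = norm.reshape(-1, 1) * mx.nd.dot(A, norm.reshape(-1, 1) * h)
        hops.append(h)
    return mx.nd.dot(mx.nd.concat(*hops, dim=-1), W) + b
```

With `k=2` this should reproduce `_S2AXWb` from the test.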