Skip to content

Commit

Permalink
[NN] Add MXNet impl for TAGCN module. (#799)
Browse files Browse the repository at this point in the history
* upd

* fig edgebatch edges

* add test

* trigger

* Update README.md for pytorch PinSage example.

Add noting that the PinSage model example under
example/pytorch/recommendation only work with Python 3.6+
as its dataset loader depends on stanfordnlp package
which work only with Python 3.6+.

* Provid a frame agnostic API to test nn modules on both CPU and CUDA side.

1. make dgl.nn.xxx frame agnostic
2. make test.backend include dgl.nn modules
3. modify test_edge_softmax of test/mxnet/test_nn.py and
    test/pytorch/test_nn.py work on both CPU and GPU

* Fix style

* Delete unused code

* Make agnostic test only related to tests/backend

1. clear all agnostic related code in dgl.nn
2. make test_graph_conv agnostic to cpu/gpu

* Fix code style

* fix

* doc

* Make all test code under tests.mxnet/pytorch.test_nn.py
work on both CPU and GPU.

* Fix syntex

* Remove rand

* Add TAGCN nn.module and example

* Now tagcn can run on CPU.

* Add unitest for TGConv

* Fix style

* For pubmed dataset, using --lr=0.005 can achieve better acc

* Fix style

* Fix some descriptions

* trigger

* Fix doc

* Add nn.TGConv and example

* Fix bug

* Update data in mxnet.tagcn test acc.

* Fix some comments and code

* delete useless code

* Fix namming

* Fix bug

* Fix bug

* Add test code for mxnet TAGCov

* Update some docs

* Fix some code

* Update docs dgl.nn.mxnet

* Update weight init

* Fix
  • Loading branch information
classicsong authored and yzh119 committed Aug 28, 2019
1 parent 14bffe9 commit e17add5
Show file tree
Hide file tree
Showing 9 changed files with 337 additions and 9 deletions.
4 changes: 4 additions & 0 deletions docs/source/api/python/nn.mxnet.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ dgl.nn.mxnet.conv
:members: forward
:show-inheritance:

.. autoclass:: dgl.nn.mxnet.conv.TAGConv
:members: forward
:show-inheritance:

dgl.nn.mxnet.glob
-----------------

Expand Down
25 changes: 25 additions & 0 deletions examples/mxnet/tagcn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Topology Adaptive Graph Convolutional networks (TAGCN)
============

- Paper link: [https://arxiv.org/abs/1710.10370](https://arxiv.org/abs/1710.10370)

Dependencies
------------
- MXNet nightly build
- requests

``bash
pip install mxnet --pre
pip install requests
``

Results
-------
Run with following (available dataset: "cora", "citeseer", "pubmed")
```bash
DGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0 --self-loop
```

* cora: ~0.820 (paper: 0.833)
* citeseer: ~0.702 (paper: 0.714)
* pubmed: ~0.798 (paper: 0.811)
39 changes: 39 additions & 0 deletions examples/mxnet/tagcn/tagcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""TAGCN using DGL nn package
References:
- Topology Adaptive Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1710.10370
"""
import mxnet as mx
from mxnet import gluon
import dgl
from dgl.nn.mxnet import TAGConv

class TAGCN(gluon.Block):
def __init__(self,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(TAGCN, self).__init__()
self.g = g
self.layers = gluon.nn.Sequential()
# input layer
self.layers.add(TAGConv(in_feats, n_hidden, activation=activation))
# hidden layers
for i in range(n_layers - 1):
self.layers.add(TAGConv(n_hidden, n_hidden, activation=activation))
# output layer
self.layers.add(TAGConv(n_hidden, n_classes)) #activation=None
self.dropout = gluon.nn.Dropout(rate=dropout)

def forward(self, features):
h = features
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(self.g, h)
return h
126 changes: 126 additions & 0 deletions examples/mxnet/tagcn/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import argparse, time
import numpy as np
import mxnet as mx
from mxnet import gluon

from dgl import DGLGraph
from dgl.data import register_data_args, load_data

from tagcn import TAGCN

def evaluate(model, features, labels, mask):
pred = model(features).argmax(axis=1)
accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()
return accuracy.asscalar()

def main(args):
# load and preprocess dataset
data = load_data(args)
features = mx.nd.array(data.features)
labels = mx.nd.array(data.labels)
train_mask = mx.nd.array(data.train_mask)
val_mask = mx.nd.array(data.val_mask)
test_mask = mx.nd.array(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
train_mask.sum().asscalar(),
val_mask.sum().asscalar(),
test_mask.sum().asscalar()))

if args.gpu < 0:
cuda = False
ctx = mx.cpu(0)
else:
cuda = True
ctx = mx.gpu(args.gpu)

features = features.as_in_context(ctx)
labels = labels.as_in_context(ctx)
train_mask = train_mask.as_in_context(ctx)
val_mask = val_mask.as_in_context(ctx)
test_mask = test_mask.as_in_context(ctx)

# graph preprocess and calculate normalization factor
g = data.graph
# add self loop
if args.self_loop:
g.remove_edges_from(g.selfloop_edges())
g.add_edges_from(zip(g.nodes(), g.nodes()))
g = DGLGraph(g)

# create TAGCN model
model = TAGCN(g,
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
mx.nd.relu,
args.dropout)

model.initialize(ctx=ctx)
n_train_samples = train_mask.sum().asscalar()
loss_fcn = gluon.loss.SoftmaxCELoss()

# use optimizer
print(model.collect_params())
trainer = gluon.Trainer(model.collect_params(), 'adam',
{'learning_rate': args.lr, 'wd': args.weight_decay})

# initialize graph
dur = []
for epoch in range(args.n_epochs):
if epoch >= 3:
t0 = time.time()
# forward
with mx.autograd.record():
pred = model(features)
loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))
loss = loss.sum() / n_train_samples

loss.backward()
trainer.step(batch_size=1)

if epoch >= 3:
loss.asscalar()
dur.append(time.time() - t0)
acc = evaluate(model, features, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}". format(
epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000))

print()
acc = evaluate(model, features, labels, val_mask)
print("Test accuracy {:.2%}".format(acc))

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='TAGCN')
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden tagcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden tagcn layers")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Weight for L2 loss")
parser.add_argument("--self-loop", action='store_true',
help="graph self-loop (default=False)")
parser.set_defaults(self_loop=False)
args = parser.parse_args()
print(args)

main(args)
7 changes: 3 additions & 4 deletions examples/pytorch/tagcn/tagcn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""GCN using DGL nn package
"""TAGCN using DGL nn package
References:
- Semi-Supervised Classification with Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1609.02907
- Code: https://github.com/tkipf/gcn
- Topology Adaptive Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1710.10370
"""
import torch
import torch.nn as nn
Expand Down
101 changes: 98 additions & 3 deletions python/dgl/nn/mxnet/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from . import utils
from ... import function as fn

__all__ = ['GraphConv', 'RelGraphConv']
__all__ = ['GraphConv', 'TAGConv', 'RelGraphConv']

class GraphConv(gluon.Block):
r"""Apply graph convolution over an input signal.
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(self,

with self.name_scope():
self.weight = self.params.get('weight', shape=(in_feats, out_feats),
init=mx.init.Xavier())
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
if bias:
self.bias = self.params.get('bias', shape=(out_feats,),
init=mx.init.Zero())
Expand Down Expand Up @@ -108,7 +108,7 @@ def forward(self, graph, feat):
graph = graph.local_var()
if self._norm:
degs = graph.in_degrees().astype('float32')
norm = mx.nd.power(degs, -0.5)
norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
shp = norm.shape + (1,) * (feat.ndim - 1)
norm = norm.reshape(shp).as_in_context(feat.context)
feat = feat * norm
Expand Down Expand Up @@ -147,6 +147,101 @@ def __repr__(self):
summary += '\n)'
return summary

class TAGConv(gluon.Block):
r"""Apply Topology Adaptive Graph Convolutional Network
.. math::
\mathbf{X}^{\prime} = \sum_{k=0}^K \mathbf{D}^{-1/2} \mathbf{A}
\mathbf{D}^{-1/2}\mathbf{X} \mathbf{\Theta}_{k},
where :math:`\mathbf{A}` denotes the adjacency matrix and
:math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix.
Parameters
----------
in_feats : int
Number of input features.
out_feats : int
Number of output features.
k: int, optional
Number of hops :math: `k`. (default: 2)
bias: bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
activation: callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
Attributes
----------
lin : mxnet.gluon.parameter.Parameter
The learnable weight tensor.
bias : mxnet.gluon.parameter.Parameter
The learnable bias tensor.
"""
def __init__(self,
in_feats,
out_feats,
k=2,
bias=True,
activation=None):
super(TAGConv, self).__init__()
self.out_feats = out_feats
self.k = k
self.bias = bias
self.activation = activation
self.in_feats = in_feats

self.lin = self.params.get(
'weight', shape=(self.in_feats * (self.k + 1), self.out_feats),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
if self.bias:
self.h_bias = self.params.get('bias', shape=(out_feats,),
init=mx.init.Zero())

def forward(self, graph, feat):
r"""Compute graph convolution
Parameters
----------
graph : DGLGraph
The graph.
feat : mxnet.NDArray
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
Returns
-------
mxnet.NDArray
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
graph = graph.local_var()

degs = graph.in_degrees().astype('float32')
norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
shp = norm.shape + (1,) * (feat.ndim - 1)
norm = norm.reshape(shp).as_in_context(feat.context)

rst = feat
for _ in range(self.k):
rst = rst * norm
graph.ndata['h'] = rst

graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.ndata['h']
rst = rst * norm
feat = mx.nd.concat(feat, rst, dim=-1)

rst = mx.nd.dot(feat, self.lin.data(feat.context))
if self.bias is not None:
rst = rst + self.h_bias.data(rst.context)

if self.activation is not None:
rst = self.activation(rst)

return rst

class RelGraphConv(gluon.Block):
r"""Relational graph convolution layer.
Expand Down
3 changes: 1 addition & 2 deletions python/dgl/nn/pytorch/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def extra_repr(self):
summary += ', activation={_activation}'
return summary.format(**self.__dict__)


class GATConv(nn.Module):
r"""Apply `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__
over an input signal.
Expand Down Expand Up @@ -305,7 +304,7 @@ class TAGConv(nn.Module):
out_feats : int
Output feature size.
k: int, optional
Number of hops :math: `k`. (default: 3)
Number of hops :math: `k`. (default: 2)
bias: bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
activation: callable activation function/layer or None, optional
Expand Down
Loading

0 comments on commit e17add5

Please sign in to comment.