Skip to content

Commit

Permalink
Project/dtdg learning (#26)
Browse files Browse the repository at this point in the history
add dtdg learning with pytorch backend.

Co-authored-by: Yuecai Zhu <[email protected]>
Co-authored-by: yuecazhu <[email protected]>
Co-authored-by: Leo <[email protected]>
  • Loading branch information
4 people authored Jan 28, 2022
1 parent 19be7c4 commit 22290a0
Show file tree
Hide file tree
Showing 36 changed files with 2,991 additions and 451 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ jobs:
# The type of runner that the job will run on
strategy:
matrix:
python-versions: [3.6, 3.7, 3.8, 3.9]
os: [ubuntu-18.04, macos-latest, windows-latest]
python-versions: [3.7, 3.8, 3.9]
os: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.os }}

# Steps represent a sequence of tasks that will be executed as part of the job
Expand Down
4 changes: 2 additions & 2 deletions docs/api.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# nineturn.dtdg.types module
::: nineturn.dtdg.types
# nineturn.core.config module
::: nineturn.core.config
80 changes: 80 additions & 0 deletions examples/torch/citation_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging

from nineturn.core.backends import PYTORCH
from nineturn.core.config import set_backend
set_backend(PYTORCH)
from nineturn.dtdg.dataloader import ogb_dataset, supported_ogb_datasets
from nineturn.dtdg.models.encoder.implicitTimeEncoder.torch.staticGraphEncoder import GCN, GAT, SGCN, GraphSage
from nineturn.dtdg.models.decoder.torch.sequentialDecoder.rnnFamily import LSTM, GRU,RNN
from nineturn.dtdg.models.decoder.torch.simpleDecoder import MLP


def assembler(encoder, decoder, device=None):
    """Chain an encoder and a decoder into one model and move it to a device.

    Args:
        encoder: module applied first.
        decoder: module applied to the encoder's output.
        device: target device for the assembled model. When omitted, CUDA is
            used if it is available and the CPU otherwise, which is the same
            choice the script makes for its own ``device`` variable.

    Returns:
        An ``nn.Sequential`` holding ``encoder`` then ``decoder``, already
        moved to the selected device.
    """
    if device is None:
        # Resolve the device locally instead of relying on a module global
        # that only exists once the ``__main__`` block has executed.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return nn.Sequential(encoder, decoder).to(device)

"""
def loss_fn(predict, label):
return torch.sqrt(torch.mean(torch.abs(torch.log1p(predict) - torch.log1p(label))))
"""
# Script entry point: builds a GraphSage encoder + RNN decoder, then fits the
# last node-feature column of snapshot t+1 from snapshot t with an MSE loss.
# NOTE(review): leading indentation was lost in this view, so the nesting of the
# loops and of the trailing print calls cannot be confirmed from here.
if __name__ == '__main__':
# Prefer the GPU when one is available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#----------------------------------------------------------------
#set up logger
# NOTE(review): the logger name is spelled 'predictoin'; it is a runtime string
# and is therefore left unchanged.
this_logger = logging.getLogger('citation_predictoin_pipeline')
this_logger.setLevel(logging.INFO)
# Create a file handler that accepts DEBUG and above. The logger itself is set
# to INFO above, so DEBUG records are filtered out before reaching this handler.
fh = logging.FileHandler('test2.log')
fh.setLevel(logging.DEBUG)
this_logger.addHandler(fh)
#--------------------------------------------------------
# Use the second entry of the supported OGB dataset list.
data_to_test = supported_ogb_datasets()[1]
this_graph = ogb_dataset(data_to_test)
n_snapshot = len(this_graph)
this_logger.info(f"number of snapshots: {n_snapshot}")
# Node count is read from the observation of the last snapshot.
n_nodes = this_graph.dispatcher(n_snapshot -1).observation.num_nodes()
this_logger.info(f"number of nodes: {n_nodes}")
# Snapshot 20 is only used here to read the node feature dimension.
this_snapshot = this_graph.dispatcher(20)
in_dim = this_snapshot.num_node_features()
hidden_dim = 32
# num_GNN_layers is only referenced by the commented-out encoders below;
# num_RNN_layers is not referenced again in the visible code.
num_GNN_layers = 2
num_RNN_layers = 2
#gnn = GCN(num_GNN_layers, in_dim, hidden_dim, activation=F.leaky_relu,allow_zero_in_degree=True, dropout=0.2).to(device)
#gnn = SGCN(num_GNN_layers, in_dim, hidden_dim ,allow_zero_in_degree=True).to(device)
#gnn = GAT([1], in_dim, hidden_dim, activation=F.leaky_relu,allow_zero_in_degree=True).to(device)
# presumably 'gcn' selects the aggregator type — TODO confirm against GraphSage.
gnn = GraphSage('gcn', in_dim, hidden_dim, activation=F.leaky_relu)
# presumably MLP(input_dim, hidden_layer_sizes) — TODO confirm against MLP.
output_decoder = MLP(10, [10,20,10,5])
#decoder = LSTM( hidden_dim, 10,n_nodes,3,output_decoder, device)
#decoder = GRU( hidden_dim, 10,n_nodes,3,output_decoder, device)
# NOTE(review): the meaning of the positional arguments (10, n_nodes, 3) is not
# visible here; verify against the RNN decoder's signature.
decoder = RNN( hidden_dim, 10,n_nodes,3,output_decoder, device)
#this_model = LSTM( in_dim, 10,n_nodes,device).to(device)
# assembler() already moves the model to `device`; the extra .to(device) here
# repeats that move.
this_model = assembler(gnn, decoder).to(device)
loss_fn = torch.nn.MSELoss().to(device)
optimizer = torch.optim.Adam(
[{"params": this_model.parameters()}], lr=1e-3
)
# Per-step training losses and predictions collected across all epochs.
loss_list=[]
all_predictions=[]
for epoch in range(20):
# this_model[1] is the decoder (second stage of the assembled model).
this_model[1].reset_memory_state()
# Snapshot pairs (t, t+1) for t = 5 .. n_snapshot-3.
for t in range(5,n_snapshot-2):
this_model.train()
optimizer.zero_grad()
this_snapshot = this_graph.dispatcher(t)
next_snapshot = this_graph.dispatcher(t+1)
# Predict for every node of the current snapshot.
node_samples = torch.arange(this_snapshot.num_nodes())
predict = this_model.forward((this_snapshot, node_samples))
# Target: last feature column of the next snapshot, restricted to the first
# this_snapshot.num_nodes() rows.
label = next_snapshot.node_feature()[:this_snapshot.num_nodes(), -1].float()
# NOTE(review): clone() does not detach, so each stored prediction keeps its
# autograd history; consider .detach() if only the values are needed.
all_predictions.append(predict.squeeze().clone())
loss = loss_fn(predict.squeeze(), label)
loss.backward()
optimizer.step()
loss_list.append(loss.item())
# Print the latest loss and the first 20 predictions/labels.
print(loss_list[-1])
print(all_predictions[-1][:20])
print(label[:20])


135 changes: 135 additions & 0 deletions examples/torch/citation_prediction_batch_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim.lr_scheduler import ReduceLROnPlateau
import sys
import json
import logging
import logging.config
from nineturn.core.backends import PYTORCH
from nineturn.core.config import set_backend
set_backend(PYTORCH)
from nineturn.dtdg.types import BatchedSnapshot
from nineturn.dtdg.dataloader import ogb_dataset, supported_ogb_datasets
from nineturn.dtdg.models.encoder.implicitTimeEncoder.torch.staticGraphEncoder import GCN, GAT, SGCN, GraphSage
from nineturn.dtdg.models.decoder.torch.sequentialDecoder import LSTM
from nineturn.automl.torch.model_assembler import assembler
import dgl


#def loss_fn(predict, label):
# return torch.mean(torch.abs(predict - label))

# Neighbor sampler that takes all neighbors over 3 message-passing layers.
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)

# Mean-squared-error criterion used by the training and evaluation passes.
loss_fn = nn.MSELoss()


# Script entry point: node-wise mini-batch training of a GCN encoder + LSTM
# decoder on consecutive snapshots, followed by two full-graph evaluations on
# the last snapshot pairs.
# NOTE(review): leading indentation was lost in this view, so the nesting of the
# loops and of the evaluation section cannot be confirmed from here.
if __name__ == '__main__':
# Training runs on the CPU only.
device = 'cpu'
#----------------------------------------------------------------
#set up logger
# NOTE(review): the logger name is spelled 'predictoin'; it is a runtime string
# and is therefore left unchanged.
this_logger = logging.getLogger('citation_predictoin_pipeline')
this_logger.setLevel(logging.INFO)
# Create a file handler that accepts DEBUG and above. The logger itself is set
# to INFO above, so DEBUG records are filtered out before reaching this handler.
fh = logging.FileHandler('node_wise_batch.log')
fh.setLevel(logging.DEBUG)
this_logger.addHandler(fh)
#--------------------------------------------------------

#-------------------------------------------------------
#load data
# Use the second entry of the supported OGB dataset list.
data_to_test = supported_ogb_datasets()[1]
this_graph = ogb_dataset(data_to_test)
#-------------------------------------------------------

#-------------------------------------------------------
#create learning model
n_snapshot = len(this_graph)
# Node count is read from the observation of the last snapshot.
n_nodes = this_graph.dispatcher(n_snapshot -1).observation.num_nodes()
# Snapshot 20 is only used here to read the node feature dimension.
this_snapshot = this_graph.dispatcher(20)
in_dim = this_snapshot.num_node_features()
hidden_dim = 200
# The module-level sampler is also built with 3 layers; presumably the two
# values must match — TODO confirm.
num_GNN_layers = 3
# num_RNN_layers is not referenced again in the visible code.
num_RNN_layers = 2
gnn = GCN(num_GNN_layers, in_dim, hidden_dim, activation=F.leaky_relu,allow_zero_in_degree=True, dropout=0.2).to(device)
#gnn = GAT([1], in_dim, hidden_dim, activation=F.leaky_relu,allow_zero_in_degree=True).to(device)
# NOTE(review): the meaning of the positional arguments (20, n_nodes) is not
# visible here; verify against the LSTM decoder's signature.
decoder = LSTM( hidden_dim, 20,n_nodes)
#this_model = LSTM( in_dim, 20,n_nodes,device).to(device)
this_model = assembler(gnn, decoder).to(device)
#----------------------------------------------------------

#---------------------------------------------------------
#configure training
optimizer = torch.optim.Adam(
[{"params": this_model.parameters()}], lr=1e-3
)
#scheduler = ReduceLROnPlateau(optimizer, 'min')
#---------------------------------------------------------

# One entry per evaluation pass: test_loss_list holds the loss on the pair
# (n_snapshot-3, n_snapshot-2), eval_loss_list on (n_snapshot-2, n_snapshot-1).
test_loss_list =[]
eval_loss_list = []
# all_predictions is not written to in the visible code.
all_predictions = []
for epoch in range(400):
# this_model[1] is the decoder (second stage of the assembled model).
this_model[1].memory.reset_state()
#this_model[0].set_mini_batch(True)
# NOTE(review): only the decoder is switched to mini-batch mode here, while
# both stages are switched back with set_mini_batch(False) before inference.
this_model[1].set_mini_batch(True)
# Snapshot pairs (t, t+1) for t = 2 .. n_snapshot-4.
for t in range(2,n_snapshot-3):
# NOTE(review): the meaning of the second argument (True) to dispatcher is not
# visible here — verify against its definition.
this_snapshot = this_graph.dispatcher(t, True)
next_snapshot = this_graph.dispatcher(t+1, True)
node_samples = torch.arange(this_snapshot.num_nodes())
#------------------------
# batch creation
#------------------------
# NOTE(review): collator is created but not used anywhere in the visible code.
collator = dgl.dataloading.NodeCollator(this_snapshot.observation, node_samples, sampler)
dataloader = dgl.dataloading.NodeDataLoader(
this_snapshot.observation, node_samples, sampler,
batch_size=500,
shuffle=True,
drop_last=False,
num_workers=1)
for in_nodes, out_nodes, blocks in dataloader:
this_model.train()
optimizer.zero_grad()
# Wrap the sampled blocks with the features of their input nodes and the
# snapshot's time value.
sample = BatchedSnapshot(blocks, this_snapshot.node_feature()[in_nodes],this_snapshot.t)
#---------------------
_in = (sample, out_nodes)
predict = this_model.forward(_in).to(device)
# Target: last feature column of the next snapshot for the batch's output nodes.
label = next_snapshot.node_feature()[out_nodes, -1].float()
loss = loss_fn(predict.squeeze(), label).to(device)
loss.backward()
optimizer.step()

#---------------------------------------------------------
#turn model to inference mode
this_model[0].set_mini_batch(False)
this_model[1].set_mini_batch(False)
this_model.eval()
#---------------------------------------------------------

# NOTE(review): the two full-graph passes below are not wrapped in
# torch.no_grad(), so autograd history is still recorded for them.
# First evaluation pair: (n_snapshot-3, n_snapshot-2).
this_snapshot = this_graph.dispatcher(n_snapshot-3, True)
next_snapshot = this_graph.dispatcher(n_snapshot-2, True)
node_samples = torch.arange(this_snapshot.num_nodes())
predict = this_model.forward((this_snapshot, node_samples)).to(device)
label = next_snapshot.node_feature()[:this_snapshot.num_nodes(), -1].float()
loss = loss_fn(predict.squeeze(), label).to(device)
#scheduler.step(loss)
test_loss_list.append(loss.item())
# Second evaluation pair: (n_snapshot-2, n_snapshot-1), the last two snapshots.
this_snapshot = this_graph.dispatcher(n_snapshot-2, True)
next_snapshot = this_graph.dispatcher(n_snapshot-1, True)
node_samples = torch.arange(this_snapshot.num_nodes())
predict = this_model.forward((this_snapshot, node_samples)).to(device)
label = next_snapshot.node_feature()[:this_snapshot.num_nodes(), -1].float()
loss = loss_fn(predict.squeeze(), label).to(device)
eval_loss_list.append(loss.item())
print(test_loss_list[-1])
print(eval_loss_list[-1])
# Log the first 20 predictions and labels of the second evaluation pair.
this_logger.info(predict.squeeze()[:20])
this_logger.info(label[:20])

# Log the full loss histories.
this_logger.info(test_loss_list)
this_logger.info(eval_loss_list)


1 change: 1 addition & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ lint:
mypy $(sources) tests

unittest:
poetry install -v
pytest

coverage:
Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ nav:
- Installation: installation.md
- Usage: usage.md
- Modules:
- Utils: api.md
- nineturn.core.config: api.md
- Contributing: contributing.md
- Changelog: changelog.md
theme:
Expand Down
21 changes: 21 additions & 0 deletions nineturn/automl/torch/model_assembler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2022 The Nine Turn Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model assembler for pytorch."""
import torch.nn as nn


def assembler(encoder, decoder):
    """Build one sequential model that runs ``encoder`` and then ``decoder``."""
    stages = (encoder, decoder)
    return nn.Sequential(*stages)
4 changes: 1 addition & 3 deletions nineturn/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,4 @@
__email__ = '[email protected]'
__version__ = '0.0.0'

__all__ = ["utils"]

from nineturn.core.utils import *
import nineturn.core.config
36 changes: 36 additions & 0 deletions nineturn/core/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Supporting backends.
This module lists the backends that we support.
Example:
>>> from nineturn.core import backends
>>> print(backends.supported_backends())
"""

from typing import List

# Canonical names of the deep-learning backends known to Nine Turn.
TENSORFLOW = "tensorflow"
PYTORCH = "pytorch"


def supported_backends() -> List[str]:
    """Return the names of the backends that Nine Turn supports.

    Returns:
        a new list of supported backend names as strings
    """
    backend_names = [TENSORFLOW]
    backend_names.append(PYTORCH)
    return backend_names
31 changes: 31 additions & 0 deletions nineturn/core/commonF.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2022 The Nine Turn Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dynamic import common functions based on backend."""
# flake8: noqa
# Dynamic import, no need for lint
from nineturn.core.backends import PYTORCH, TENSORFLOW
from nineturn.core.errors import BackendNotSupportedError
from nineturn.core.utils import _get_backend

# Resolve the configured backend once, at import time.
this_backend = _get_backend()

# Reject anything outside the two known backends before importing from them.
if this_backend not in (TENSORFLOW, PYTORCH):
    raise BackendNotSupportedError("Backend %s not supported." % (this_backend))

# Re-export the backend-specific tensor conversion helper as ``to_tensor``.
if this_backend == PYTORCH:
    from nineturn.core.torch_functions import _to_tensor as to_tensor
else:
    from nineturn.core.tf_functions import _to_tensor as to_tensor
Loading

0 comments on commit 22290a0

Please sign in to comment.