Commit 22290a0 (parent: 19be7c4)

add dtdg learning with pytorch backend.

Co-authored-by: Yuecai Zhu <[email protected]>
Co-authored-by: yuecazhu <[email protected]>
Co-authored-by: Leo <[email protected]>

Showing 36 changed files with 2,991 additions and 451 deletions.
@@ -1,2 +1,2 @@
-# nineturn.dtdg.types module
-::: nineturn.dtdg.types
+# nineturn.core.config module
+::: nineturn.core.config
@@ -0,0 +1,80 @@
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from nineturn.core.backends import PYTORCH
from nineturn.core.config import set_backend

set_backend(PYTORCH)  # must run before importing the dtdg modules below
from nineturn.dtdg.dataloader import ogb_dataset, supported_ogb_datasets
from nineturn.dtdg.models.encoder.implicitTimeEncoder.torch.staticGraphEncoder import GCN, GAT, SGCN, GraphSage
from nineturn.dtdg.models.decoder.torch.sequentialDecoder.rnnFamily import LSTM, GRU, RNN
from nineturn.dtdg.models.decoder.torch.simpleDecoder import MLP


def assembler(encoder, decoder):
    # `device` is a global defined in __main__ below; it is resolved when
    # assembler() is called, not when it is defined.
    return nn.Sequential(encoder, decoder).to(device)


# RMSLE-like alternative loss (absolute rather than squared log-difference),
# kept for reference:
# def loss_fn(predict, label):
#     return torch.sqrt(torch.mean(torch.abs(torch.log1p(predict) - torch.log1p(label))))


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # ----------------------------------------------------------------
    # set up logger
    this_logger = logging.getLogger('citation_prediction_pipeline')
    this_logger.setLevel(logging.INFO)
    # create file handler which logs even debug messages
    fh = logging.FileHandler('test2.log')
    fh.setLevel(logging.DEBUG)
    this_logger.addHandler(fh)
    # --------------------------------------------------------
    data_to_test = supported_ogb_datasets()[1]
    this_graph = ogb_dataset(data_to_test)
    n_snapshot = len(this_graph)
    this_logger.info(f"number of snapshots: {n_snapshot}")
    n_nodes = this_graph.dispatcher(n_snapshot - 1).observation.num_nodes()
    this_logger.info(f"number of nodes: {n_nodes}")
    this_snapshot = this_graph.dispatcher(20)
    in_dim = this_snapshot.num_node_features()
    hidden_dim = 32
    num_GNN_layers = 2
    num_RNN_layers = 2
    # Alternative encoders:
    # gnn = GCN(num_GNN_layers, in_dim, hidden_dim, activation=F.leaky_relu, allow_zero_in_degree=True, dropout=0.2).to(device)
    # gnn = SGCN(num_GNN_layers, in_dim, hidden_dim, allow_zero_in_degree=True).to(device)
    # gnn = GAT([1], in_dim, hidden_dim, activation=F.leaky_relu, allow_zero_in_degree=True).to(device)
    gnn = GraphSage('gcn', in_dim, hidden_dim, activation=F.leaky_relu)
    output_decoder = MLP(10, [10, 20, 10, 5])
    # Alternative decoders:
    # decoder = LSTM(hidden_dim, 10, n_nodes, 3, output_decoder, device)
    # decoder = GRU(hidden_dim, 10, n_nodes, 3, output_decoder, device)
    decoder = RNN(hidden_dim, 10, n_nodes, 3, output_decoder, device)
    # this_model = LSTM(in_dim, 10, n_nodes, device).to(device)
    this_model = assembler(gnn, decoder).to(device)
    loss_fn = torch.nn.MSELoss().to(device)
    optimizer = torch.optim.Adam([{"params": this_model.parameters()}], lr=1e-3)
    loss_list = []
    all_predictions = []
    for epoch in range(20):
        this_model[1].reset_memory_state()
        for t in range(5, n_snapshot - 2):
            this_model.train()
            optimizer.zero_grad()
            this_snapshot = this_graph.dispatcher(t)
            next_snapshot = this_graph.dispatcher(t + 1)
            node_samples = torch.arange(this_snapshot.num_nodes())
            predict = this_model.forward((this_snapshot, node_samples))
            # label: last node feature of the same nodes in the next snapshot
            label = next_snapshot.node_feature()[:this_snapshot.num_nodes(), -1].float()
            all_predictions.append(predict.squeeze().clone())
            loss = loss_fn(predict.squeeze(), label)
            loss.backward()
            optimizer.step()
        loss_list.append(loss.item())
        print(loss_list[-1])
        print(all_predictions[-1][:20])
        print(label[:20])
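The commented-out alternatives above show that any encoder/decoder pair with the same calling convention can be dropped into assembler. The convention that lets nn.Sequential work here is that each stage accepts and returns a (data, node_ids) tuple. A minimal, self-contained sketch of that contract with hypothetical stand-in modules (the real nineturn encoders consume a graph snapshot and the decoders keep per-node memory; this only illustrates the tuple plumbing):

import torch
import torch.nn as nn


class TupleEncoder(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, hidden_dim)

    def forward(self, inputs):
        features, node_ids = inputs
        h = torch.relu(self.lin(features))
        return h, node_ids  # pass the node ids through to the decoder


class TupleDecoder(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, inputs):
        h, node_ids = inputs
        return self.out(h[node_ids])


model = nn.Sequential(TupleEncoder(8, 16), TupleDecoder(16))
pred = model((torch.randn(100, 8), torch.arange(100)))  # shape: (100, 1)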
@@ -0,0 +1,135 @@
import json
import logging
import logging.config
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from nineturn.core.backends import PYTORCH
from nineturn.core.config import set_backend

set_backend(PYTORCH)  # must run before importing the dtdg modules below
import dgl
from nineturn.automl.torch.model_assembler import assembler
from nineturn.dtdg.dataloader import ogb_dataset, supported_ogb_datasets
from nineturn.dtdg.models.decoder.torch.sequentialDecoder import LSTM
from nineturn.dtdg.models.encoder.implicitTimeEncoder.torch.staticGraphEncoder import GCN, GAT, SGCN, GraphSage
from nineturn.dtdg.types import BatchedSnapshot


# Mean-absolute-error alternative, kept for reference:
# def loss_fn(predict, label):
#     return torch.mean(torch.abs(predict - label))

loss_fn = torch.nn.MSELoss()

# full-neighborhood sampling for a 3-layer GNN
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)


if __name__ == '__main__':
    device = 'cpu'
    # ----------------------------------------------------------------
    # set up logger
    this_logger = logging.getLogger('citation_prediction_pipeline')
    this_logger.setLevel(logging.INFO)
    # create file handler which logs even debug messages
    fh = logging.FileHandler('node_wise_batch.log')
    fh.setLevel(logging.DEBUG)
    this_logger.addHandler(fh)
    # --------------------------------------------------------

    # -------------------------------------------------------
    # load data
    data_to_test = supported_ogb_datasets()[1]
    this_graph = ogb_dataset(data_to_test)
    # -------------------------------------------------------

    # -------------------------------------------------------
    # create learning model
    n_snapshot = len(this_graph)
    n_nodes = this_graph.dispatcher(n_snapshot - 1).observation.num_nodes()
    this_snapshot = this_graph.dispatcher(20)
    in_dim = this_snapshot.num_node_features()
    hidden_dim = 200
    num_GNN_layers = 3
    num_RNN_layers = 2
    gnn = GCN(num_GNN_layers, in_dim, hidden_dim, activation=F.leaky_relu, allow_zero_in_degree=True, dropout=0.2).to(device)
    # gnn = GAT([1], in_dim, hidden_dim, activation=F.leaky_relu, allow_zero_in_degree=True).to(device)
    decoder = LSTM(hidden_dim, 20, n_nodes)
    # this_model = LSTM(in_dim, 20, n_nodes, device).to(device)
    this_model = assembler(gnn, decoder).to(device)
    # ----------------------------------------------------------

    # ---------------------------------------------------------
    # configure training
    optimizer = torch.optim.Adam([{"params": this_model.parameters()}], lr=1e-3)
    # scheduler = ReduceLROnPlateau(optimizer, 'min')
    # ---------------------------------------------------------

    test_loss_list = []
    eval_loss_list = []
    all_predictions = []
    for epoch in range(400):
        this_model[1].memory.reset_state()
        # this_model[0].set_mini_batch(True)
        this_model[1].set_mini_batch(True)
        for t in range(2, n_snapshot - 3):
            this_snapshot = this_graph.dispatcher(t, True)
            next_snapshot = this_graph.dispatcher(t + 1, True)
            node_samples = torch.arange(this_snapshot.num_nodes())
            # ------------------------
            # batch creation
            # ------------------------
            collator = dgl.dataloading.NodeCollator(this_snapshot.observation, node_samples, sampler)  # unused; kept from the original
            dataloader = dgl.dataloading.NodeDataLoader(
                this_snapshot.observation, node_samples, sampler,
                batch_size=500,
                shuffle=True,
                drop_last=False,
                num_workers=1)
            for in_nodes, out_nodes, blocks in dataloader:
                this_model.train()
                optimizer.zero_grad()
                sample = BatchedSnapshot(blocks, this_snapshot.node_feature()[in_nodes], this_snapshot.t)
                # ---------------------
                _in = (sample, out_nodes)
                predict = this_model.forward(_in).to(device)
                label = next_snapshot.node_feature()[out_nodes, -1].float()
                loss = loss_fn(predict.squeeze(), label).to(device)
                loss.backward()
                optimizer.step()

        # ---------------------------------------------------------
        # turn model to inference mode
        this_model[0].set_mini_batch(False)
        this_model[1].set_mini_batch(False)
        this_model.eval()
        # ---------------------------------------------------------

        # test on the second-to-last transition
        this_snapshot = this_graph.dispatcher(n_snapshot - 3, True)
        next_snapshot = this_graph.dispatcher(n_snapshot - 2, True)
        node_samples = torch.arange(this_snapshot.num_nodes())
        predict = this_model.forward((this_snapshot, node_samples)).to(device)
        label = next_snapshot.node_feature()[:this_snapshot.num_nodes(), -1].float()
        loss = loss_fn(predict.squeeze(), label).to(device)
        # scheduler.step(loss)
        test_loss_list.append(loss.item())
        # evaluate on the last transition
        this_snapshot = this_graph.dispatcher(n_snapshot - 2, True)
        next_snapshot = this_graph.dispatcher(n_snapshot - 1, True)
        node_samples = torch.arange(this_snapshot.num_nodes())
        predict = this_model.forward((this_snapshot, node_samples)).to(device)
        label = next_snapshot.node_feature()[:this_snapshot.num_nodes(), -1].float()
        loss = loss_fn(predict.squeeze(), label).to(device)
        eval_loss_list.append(loss.item())
        print(test_loss_list[-1])
        print(eval_loss_list[-1])
        this_logger.info(predict.squeeze()[:20])
        this_logger.info(label[:20])

    this_logger.info(test_loss_list)
    this_logger.info(eval_loss_list)
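The inner loop above is standard DGL node-wise mini-batching: a sampler builds per-layer message-flow blocks, and the loader yields (input nodes, output nodes, blocks) triples. A self-contained sketch on a toy random graph (assumes a DGL version around 0.6-0.8, where NodeDataLoader is available; later releases replaced it with dgl.dataloading.DataLoader):

import dgl
import torch

g = dgl.rand_graph(100, 500)  # toy graph: 100 nodes, 500 random edges
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)  # for a 2-layer GNN
loader = dgl.dataloading.NodeDataLoader(
    g, torch.arange(g.num_nodes()), sampler,
    batch_size=25, shuffle=True, drop_last=False, num_workers=0)

for in_nodes, out_nodes, blocks in loader:
    # blocks[i] is the bipartite message-flow graph for layer i; in_nodes
    # index the features to gather, out_nodes are the prediction targets.
    print(len(in_nodes), len(out_nodes), len(blocks))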
@@ -12,6 +12,7 @@ lint:
	mypy $(sources) tests

unittest:
	poetry install -v
	pytest

coverage:
@@ -0,0 +1,21 @@
# Copyright 2022 The Nine Turn Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model assembler for pytorch."""
import torch.nn as nn


def assembler(encoder, decoder):
    """Combine the input encoder and decoder into a single model."""
    return nn.Sequential(encoder, decoder)
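As the example scripts above show, the assembled model is a plain nn.Sequential, so this_model[0] and this_model[1] address the encoder and decoder, and a single tuple input flows through both stages. A brief usage sketch (gnn, decoder, this_snapshot, node_samples as constructed in the scripts above):

from nineturn.automl.torch.model_assembler import assembler

this_model = assembler(gnn, decoder)  # equivalent to nn.Sequential(gnn, decoder)
predict = this_model((this_snapshot, node_samples))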
@@ -6,6 +6,4 @@
__email__ = '[email protected]'
__version__ = '0.0.0'

__all__ = ["utils"]

from nineturn.core.utils import *
import nineturn.core.config
@@ -0,0 +1,36 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Supported backends.

This module lists the backends that we support.

Example:
    >>> from nineturn.core import backends
    >>> print(backends.supported_backends())
"""

from typing import List

TENSORFLOW = "tensorflow"
PYTORCH = "pytorch"


def supported_backends() -> List[str]:
    """Return the list of backends that Nine Turn supports.

    Returns:
        A list of supported backend names as strings.
    """
    return [TENSORFLOW, PYTORCH]
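A minimal sketch of backend selection at program start, matching the usage in the example scripts earlier in this commit (set_backend must be called before any nineturn.dtdg module is imported):

from nineturn.core.backends import PYTORCH, supported_backends
from nineturn.core.config import set_backend

print(supported_backends())  # ['tensorflow', 'pytorch']
set_backend(PYTORCH)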
@@ -0,0 +1,31 @@
# Copyright 2022 The Nine Turn Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dynamically import common functions based on the backend."""
# flake8: noqa
# Dynamic import, no need for lint
from nineturn.core.backends import PYTORCH, TENSORFLOW
from nineturn.core.errors import BackendNotSupportedError
from nineturn.core.utils import _get_backend

this_backend = _get_backend()

if this_backend == TENSORFLOW:
    from nineturn.core.tf_functions import _to_tensor as to_tensor

elif this_backend == PYTORCH:
    from nineturn.core.torch_functions import _to_tensor as to_tensor

else:
    raise BackendNotSupportedError("Backend %s not supported." % this_backend)
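The payoff of this dynamic import is that downstream code can call a backend-neutral to_tensor without branching. A generic, self-contained illustration of the same pattern, using hypothetical stand-ins (BACKEND replaces _get_backend(), and the torch.from_numpy body stands in for the nineturn tf_functions/torch_functions implementations, which are not shown in this excerpt):

import numpy as np

BACKEND = "pytorch"  # hypothetical stand-in for _get_backend()

if BACKEND == "pytorch":
    import torch

    def to_tensor(arr: np.ndarray) -> "torch.Tensor":
        # zero-copy wrap of a NumPy array as a PyTorch tensor
        return torch.from_numpy(arr)
else:
    raise RuntimeError("Backend %s not supported." % BACKEND)

print(to_tensor(np.ones((2, 2), dtype=np.float32)))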