diff --git a/JGNN/src/examples/classification/LogisticRegression.java b/JGNN/src/examples/classification/LogisticRegression.java
index 54a0b8e2..83571724 100644
--- a/JGNN/src/examples/classification/LogisticRegression.java
+++ b/JGNN/src/examples/classification/LogisticRegression.java
@@ -4,6 +4,7 @@
 import mklab.JGNN.adhoc.ModelBuilder;
 import mklab.JGNN.adhoc.ModelTraining;
 import mklab.JGNN.adhoc.datasets.Citeseer;
+import mklab.JGNN.adhoc.train.SampleClassification;
 import mklab.JGNN.core.Matrix;
 import mklab.JGNN.nn.Model;
 import mklab.JGNN.nn.loss.Accuracy;
@@ -44,7 +45,7 @@
 	public static void main(String[] args) {

 		long tic = System.currentTimeMillis();
-		Model model = new ModelTraining()
+		Model model = new SampleClassification()
 				.setOptimizer(new GradientDescent(0.01))
 				.setEpochs(600)
 				.setNumBatches(10)
diff --git a/JGNN/src/examples/classification/MLP.java b/JGNN/src/examples/classification/MLP.java
index 093f2fbf..55fd55f2 100644
--- a/JGNN/src/examples/classification/MLP.java
+++ b/JGNN/src/examples/classification/MLP.java
@@ -4,6 +4,7 @@
 import mklab.JGNN.adhoc.ModelBuilder;
 import mklab.JGNN.adhoc.ModelTraining;
 import mklab.JGNN.adhoc.datasets.Citeseer;
+import mklab.JGNN.adhoc.train.SampleClassification;
 import mklab.JGNN.core.Matrix;
 import mklab.JGNN.nn.Model;
 import mklab.JGNN.core.Slice;
@@ -11,6 +12,7 @@
 import mklab.JGNN.nn.initializers.XavierNormal;
 import mklab.JGNN.nn.loss.Accuracy;
 import mklab.JGNN.nn.loss.BinaryCrossEntropy;
+import mklab.JGNN.nn.loss.report.VerboseLoss;
 import mklab.JGNN.nn.optimizers.Adam;

 /**
@@ -42,20 +44,24 @@
 	public static void main(String[] args) {

 		Slice nodeIds = dataset.samples().getSlice().shuffle(100);
-		long tic = System.currentTimeMillis();
-		Model model = new ModelTraining()
+		Slice nodes = dataset.samples().getSlice().shuffle(100);
+		ModelTraining trainer = new SampleClassification()
+				.setFeatures(dataset.features())
+				.setOutputs(dataset.labels())
+				.setTrainingSamples(nodes.range(0, 0.6))
+				.setValidationSamples(nodes.range(0.6, 0.8))
 				.setOptimizer(new Adam(0.01))
 				.setEpochs(3000)
 				.setPatience(300)
 				.setNumBatches(20)
 				.setParallelizedStochasticGradientDescent(true)
 				.setLoss(new BinaryCrossEntropy())
-				.setVerbose(true)
-				.setValidationLoss(new Accuracy())
-				.train(new XavierNormal().apply(modelBuilder.getModel()),
-						dataset.features(),
-						dataset.labels(),
-						nodeIds.range(0, 0.7), nodeIds.range(0.7, 0.8));
+				.setValidationLoss(new VerboseLoss(new Accuracy()));
+
+		long tic = System.currentTimeMillis();
+		Model model = modelBuilder.getModel()
+				.init(new XavierNormal())
+				.train(trainer);
 		long toc = System.currentTimeMillis();

 		double acc = 0;
diff --git a/JGNN/src/examples/graphClassification/SortPooling.java b/JGNN/src/examples/graphClassification/SortPooling.java
index 2dfa8ae5..b9f23067 100644
--- a/JGNN/src/examples/graphClassification/SortPooling.java
+++ b/JGNN/src/examples/graphClassification/SortPooling.java
@@ -3,14 +3,18 @@
 import java.util.Arrays;

 import mklab.JGNN.adhoc.ModelBuilder;
+import mklab.JGNN.adhoc.ModelTraining;
 import mklab.JGNN.adhoc.parsers.LayeredBuilder;
+import mklab.JGNN.adhoc.train.AGFTraining;
 import mklab.JGNN.core.Matrix;
 import mklab.JGNN.core.Tensor;
 import mklab.JGNN.core.ThreadPool;
 import mklab.JGNN.nn.Loss;
 import mklab.JGNN.nn.Model;
 import mklab.JGNN.nn.initializers.XavierNormal;
+import mklab.JGNN.nn.loss.Accuracy;
 import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
+import mklab.JGNN.nn.loss.report.VerboseLoss;
 import mklab.JGNN.nn.optimizers.Adam;
 import mklab.JGNN.nn.optimizers.BatchOptimizer;

@@ -45,40 +49,30 @@
 	public static void main(String[] args){
 		TrajectoryData dtrain = new TrajectoryData(8000);
 		TrajectoryData dtest = new TrajectoryData(2000);
-		Model model = builder.getModel().init(new XavierNormal());
-		BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01));
-		Loss loss = new CategoricalCrossEntropy();
-		for(int epoch=0; epoch<600; epoch++) {
-			// gradient update over all graphs
-			for(int graphId=0; graphId
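
The excerpt ends mid-hunk in SortPooling.java, but the MLP.java hunks already show the full shape of the migration: data, splits, and hyperparameters move onto a reusable ModelTraining object, and the model is initialized and trained through chained calls. The fragment below restates that pattern using only calls that appear in the hunks above; it is a sketch meant to slot into MLP.java's main method (where dataset and modelBuilder are already defined), not new repository code.

// Sketch of the new training flow introduced by this diff; assumes the
// surrounding MLP.java example already defines dataset (a Citeseer
// instance) and modelBuilder (the example's ModelBuilder).
Slice nodes = dataset.samples().getSlice().shuffle(100);

// The trainer now owns the data and the train/validation split,
// which previously traveled as arguments to train(...).
ModelTraining trainer = new SampleClassification()
		.setFeatures(dataset.features())
		.setOutputs(dataset.labels())
		.setTrainingSamples(nodes.range(0, 0.6))
		.setValidationSamples(nodes.range(0.6, 0.8))
		.setOptimizer(new Adam(0.01))
		.setEpochs(3000)
		.setPatience(300)
		.setNumBatches(20)
		.setParallelizedStochasticGradientDescent(true)
		.setLoss(new BinaryCrossEntropy())
		// VerboseLoss wraps the validation metric to print progress,
		// taking over from the removed setVerbose(true) flag.
		.setValidationLoss(new VerboseLoss(new Accuracy()));

// Initialization is chained on the model, replacing the old
// new XavierNormal().apply(model), and train(...) takes the trainer.
Model model = modelBuilder.getModel()
		.init(new XavierNormal())
		.train(trainer);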