Commit afb2110

fixed parallelization issues (the variable setter needed to be different per thread)

maniospas committed Aug 27, 2024
1 parent b6bd5fc commit afb2110
Showing 6 changed files with 110 additions and 51 deletions.
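The gist of the fix: under parallelized stochastic gradient descent, several batch threads run the same Model concurrently, so a Variable that stores its value in a single shared field lets one thread silently overwrite another thread's input. The commit replaces that field with a map keyed by a per-thread id. A minimal sketch of the idea in plain Java (illustrative stand-ins only; the real change is in Variable.java below and keys by JGNN's ThreadPool ids, not java.lang.Thread ids):

```java
import java.util.HashMap;

// Minimal sketch of per-thread value slots. Not JGNN source: the actual
// fix is Variable.setTo/forward further down this page.
class PerThreadSlot<T> {
    private final HashMap<Long, T> slots = new HashMap<>();

    public void set(T value) {
        synchronized (slots) { // plain HashMap needs external locking
            slots.put(Thread.currentThread().getId(), value);
        }
    }

    public T get() {
        synchronized (slots) {
            return slots.get(Thread.currentThread().getId());
        }
    }
}
```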
4 changes: 2 additions & 2 deletions JGNN/src/examples/graphClassification/SortPooling.java
@@ -57,8 +57,8 @@ public static void main(String[] args){
      .setEpochs(300)
      .setOptimizer(new Adam(0.001))
      .setLoss(new CategoricalCrossEntropy())
-     //.setNumBatches(10)
-     //.setParallelizedStochasticGradientDescent(true)
+     .setNumBatches(10)
+     .setParallelizedStochasticGradientDescent(true)
      .setValidationLoss(new VerboseLoss(new CategoricalCrossEntropy(), new Accuracy()));

  Model model = builder.getModel()
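Re-enabling setNumBatches(10) and setParallelizedStochasticGradientDescent(true) makes the example exercise the code paths repaired below; the two lines had presumably been commented out while parallel training still raced on shared Variable values.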
8 changes: 4 additions & 4 deletions JGNN/src/main/java/mklab/JGNN/adhoc/ModelTraining.java
@@ -102,7 +102,7 @@ public ModelTraining setValidationLoss(Loss loss) {
   *
   * @param optimizer The desired optimizer.
   * @return <code>this</code> model training instance.
-  * @see #train(Model, Matrix, Matrix, Slice, Slice)
+  * @see #train(Model)
   */
  public ModelTraining setOptimizer(Optimizer optimizer) {
      if (optimizer instanceof BatchOptimizer)
@@ -259,7 +259,7 @@ public Model train(Model model) {
  Runnable batchCode = new Runnable() {
      @Override
      public void run() {
-         for (BatchData batchData : getBatchData(batchId, epochId))
+         for (BatchData batchData : getBatchData(batchId, epochId))
              model.train(loss, optimizer, batchData.getInputs(), batchData.getOutputs());
          if (stochasticGradientDescent)
              optimizer.updateAll();
@@ -275,7 +275,8 @@ public void run() {
  ThreadPool.getInstance().waitForConclusion();
  if (!stochasticGradientDescent)
      optimizer.updateAll();

+ loss.onEndEpoch();

  Memory.scope().enter();
  double totalLoss = 0;
  List<BatchData> allValidationData = getValidationData(epoch);
@@ -296,7 +297,6 @@ public void run() {

  if (verbose)
      System.out.println("Epoch " + epoch + " with loss " + totalLoss);
- loss.onEndEpoch();
  validLoss.onEndEpoch();
  currentPatience -= 1;
  if (currentPatience == 0)
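Net effect of the two hunks above: loss.onEndEpoch() moves from after validation to right after all batch threads join (and, when gradients are accumulated across the epoch, after the single optimizer.updateAll()). A rough sketch of the resulting epoch loop; submitBatch and validate are hypothetical stand-ins for surrounding code the diff does not show:

```java
// Paraphrased epoch structure after this commit (not verbatim JGNN code).
for (int epoch = 0; epoch < epochs && currentPatience > 0; epoch++) {
    for (int batch = 0; batch < numBatches; batch++)
        submitBatch(batch, epoch);                 // hypothetical: schedules batchCode
    ThreadPool.getInstance().waitForConclusion();  // join all batch threads
    if (!stochasticGradientDescent)
        optimizer.updateAll();                     // apply gradients accumulated over the epoch
    loss.onEndEpoch();                             // moved: close the training loss before validation
    double totalLoss = validate(epoch);            // hypothetical: the Memory.scope() block above
    if (verbose)
        System.out.println("Epoch " + epoch + " with loss " + totalLoss);
    validLoss.onEndEpoch();
}
```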
2 changes: 1 addition & 1 deletion JGNN/src/main/java/mklab/JGNN/core/Matrix.java
@@ -487,7 +487,7 @@ protected boolean isMatching(Tensor other) {
  @Override
  public String describe() {
      return getClass().getSimpleName() + " (" + (rowName == null ? "" : (rowName + " ")) + rows + ","
-             + (colName == null ? "" : (" " + colName + " ")) + cols + ")" + " extending "+super.describe();
+             + (colName == null ? "" : (" " + colName + " ")) + cols + ")";// + " extending "+super.describe();
  }

  /**
2 changes: 2 additions & 0 deletions JGNN/src/main/java/mklab/JGNN/core/ThreadPool.java
@@ -59,6 +59,7 @@ public void run() {
  int threadId = getUnusedId();
  if (threadId == -1)
      throw new RuntimeException("Tried to instantiate thread without an available id");
+ //System.out.println("Starting thread #"+threadId);
  threadIds.put(Thread.currentThread(), threadId);
  usedIds.add(threadId);
  }
@@ -67,6 +68,7 @@ public void run() {
  int threadId = getCurrentThreadId();
  threadIds.remove(this);
  usedIds.remove(threadId);
+ //System.out.println("Ending thread #"+threadId);
  }
  }
  };
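The change here is only commented-out tracing, but it marks the bookkeeping the rest of the commit leans on: each pooled thread claims a small integer id when it starts and releases it when it ends, so per-thread state elsewhere can live in a HashMap<Integer, ...> keyed by ThreadPool.getCurrentThreadId(). A self-contained sketch of that pattern, assuming (not quoting) a contract where the smallest unused id is handed out:

```java
import java.util.HashMap;
import java.util.HashSet;

// Simplified stand-in for JGNN's thread-id bookkeeping, not its source.
class ThreadIdRegistry {
    private final HashMap<Thread, Integer> threadIds = new HashMap<>();
    private final HashSet<Integer> usedIds = new HashSet<>();

    public synchronized int acquire() {
        int id = 0;
        while (usedIds.contains(id)) // assumed: smallest free id wins
            id++;
        threadIds.put(Thread.currentThread(), id);
        usedIds.add(id);
        return id;
    }

    public synchronized void release() {
        Integer id = threadIds.remove(Thread.currentThread());
        if (id != null)
            usedIds.remove(id);
    }

    public synchronized int current() {
        return threadIds.getOrDefault(Thread.currentThread(), -1);
    }
}
```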
111 changes: 73 additions & 38 deletions JGNN/src/main/java/mklab/JGNN/nn/NNOperation.java
@@ -42,6 +42,24 @@ protected static class ThreadData {
  public Tensor lastOutput;
  public Tensor tapeError;
  public int countTapeSources;
+ private int isLocked = -1;
+ private int threadId;
+ public int getThreadId() {
+     return threadId;
+ }
+ public synchronized void lock() {
+     if(isLocked!=-1)
+         throw new RuntimeException("Locked by thread #"+isLocked);
+     threadId = ThreadPool.getCurrentThreadId();
+     isLocked = threadId;
+ }
+ public synchronized Tensor unlock() {
+     if(isLocked!=ThreadPool.getCurrentThreadId())
+         throw new RuntimeException("Trying to unlock a different thread");
+     Tensor ret = lastOutput;
+     isLocked = -1;
+     return ret;
+ }
  }

  private HashMap<Integer, ThreadData> data = new HashMap<Integer, ThreadData>();
@@ -51,11 +69,8 @@ protected ThreadData data() {
  ThreadData ret;
  synchronized (data) {
      ret = data.get(threadId);
- }
- if (ret == null) {
-     synchronized (data) {
      if (ret == null)
          data.put(threadId, ret = new ThreadData());
-     }
  }
  return ret;
  }
@@ -69,17 +84,18 @@ public String getDescription() {
  }

  /**
-  * Retrieves an concise description of the operation that shows metadata and
+  * Retrieves a concise description of the operation that shows metadata and
   * potential data descriptions processed by the current thread.
   *
   * @return A <code>String</code> description.
   * @see #setDescription(String)
   * @see #view()
   */
  public String describe() {
+     ThreadData data = data();
      return this.getClass() + ": " + (description != null ? description : ("#" + this.hashCode())) + " = "
-             + (data().lastOutput != null ? data().lastOutput.describe() : "NA")
-             + (isConstant() ? " (constant)" : "");
+             + (data.lastOutput != null ? data.lastOutput.describe() : "NA")
+             + (isConstant() ? " (constant)" : "")+" in thread #"+data.getThreadId();
  }

  /**
@@ -183,9 +199,11 @@ public double getNonLinearity(int inputId, double inputMass, double outputNonLin

  public final void clearPrediction() {
      ThreadData data = data();
-     if (data.lastOutput == null)
-         return;
-     data.lastOutput = null;
+     synchronized(data) {
+         if (data.lastOutput == null)
+             return;
+         data.lastOutput = null;
+     }
      for (NNOperation input : inputs)
          input.clearPrediction();
  }
@@ -217,8 +235,9 @@ protected boolean isInputNeededForDerivative(int inputId) {
  public final Tensor runPrediction() {
      try {
          ThreadData data = data();
+         data.lock();
          if (data.lastOutput != null)
-             return data.lastOutput;
+             return data.unlock();
          ArrayList<Tensor> lastInputs = new ArrayList<Tensor>(inputs.size());
          for (NNOperation input : inputs)
              lastInputs.add(input.runPrediction());
@@ -230,9 +249,11 @@ public final Tensor runPrediction() {
   * inputs.get(inputId).data().lastOutput = null;
   */
  if (debugging) {
-     System.out.println("Predicting " + describe() + " for inputs:");
-     for (Tensor input : lastInputs)
-         System.out.println("\t" + input.describe());
+     synchronized(System.err) {
+         System.out.println("Thread "+data.getThreadId()+" Predicting " + describe() + " for inputs:");
+         for (Tensor input : lastInputs)
+             System.out.println("\t" + input.describe());
+     }
      /*
       * if(data()!=data) System.out.println(data+" -> "+data());
       */
@@ -252,13 +273,16 @@ public final Tensor runPrediction() {
          constantCache = data.lastOutput;
      if (debugging)
          System.out.println("\t=> " + describe());
-     return data.lastOutput;
+     return data.unlock();
  } catch (Exception e) {
-     System.err.println(e.toString());
-     System.err.println("During the forward pass of " + describe() + " with the following inputs:");
-     for (NNOperation input : inputs)
-         System.err.println("\t" + input.describe());
-     e.printStackTrace();
+     synchronized(System.err) {
+         System.err.println(e.toString());
+         System.err.println("In thread #"+ThreadPool.getCurrentThreadId());
+         System.err.println("During the forward pass of " + describe() + " with the following inputs:");
+         for (NNOperation input : inputs)
+             System.err.println("\t" + input.describe());
+         e.printStackTrace();
+     }
      System.exit(1);
      return null;
  }
@@ -298,18 +322,24 @@ final void backpropagate(Optimizer optimizer, Tensor error) {
          if (!inputs.get(i).isConstant())
              inputs.get(i).backpropagate(optimizer, partial(i, lastInputs, data.lastOutput, data.tapeError));
      trainParameters(optimizer, data.tapeError);
-     if (debugging)
-         System.out.println(
-                 "Finished backpropagation on " + describe() + " on thread " + ThreadPool.getCurrentThreadId());
+     if (debugging) {
+         synchronized(System.err) {
+             System.out.println(
+                     "Finished backpropagation on " + describe() + " on thread " + ThreadPool.getCurrentThreadId());
+         }
+     }
      data.tapeError = null;
  } catch (Exception e) {
-     System.err.println(e.toString());
-     System.err.println("During the backward pass of " + describe() + " with derivative:");
-     System.err.println("\t " + (error == null ? "null" : error.describe()));
-     System.err.println("and the following inputs:");
-     for (NNOperation input : inputs)
-         System.err.println("\t" + input.describe());
-     e.printStackTrace();
+     synchronized(System.err) {
+         System.err.println(e.toString());
+         System.err.println("In thread #"+ThreadPool.getCurrentThreadId());
+         System.err.println("During the backward pass of " + describe() + " with derivative:");
+         System.err.println("\t " + (error == null ? "null" : error.describe()));
+         System.err.println("and the following inputs:");
+         for (NNOperation input : inputs)
+             System.err.println("\t" + input.describe());
+         e.printStackTrace();
+     }
      System.exit(1);
  }
  }
@@ -368,10 +398,11 @@ public String getSimpleDescription() {
  }

  public Tensor runPredictionAndAutosize() {
+     ThreadData data = data();
      try {
-         ThreadData data = data();
+         data.lock();
          if (data.lastOutput != null)
-             return data.lastOutput;
+             return data.unlock();
          ArrayList<Tensor> lastInputs = new ArrayList<Tensor>(inputs.size());
          for (NNOperation input : inputs)
              lastInputs.add(input.runPredictionAndAutosize());
@@ -398,13 +429,17 @@ public Tensor runPredictionAndAutosize() {
          constantCache = data.lastOutput;
      if (debugging)
          System.out.println("\t=> " + describe());
-     return data.lastOutput;
+     return data.unlock();
  } catch (Exception e) {
-     System.err.println(e.toString());
-     System.err.println("During the forward pass of " + describe() + " with the following inputs:");
-     for (NNOperation input : inputs)
-         System.err.println("\t" + input.describe());
-     e.printStackTrace();
+     data.unlock();
+     synchronized(System.err) {
+         System.err.println(e.toString());
+         System.err.println("In thread #"+ThreadPool.getCurrentThreadId());
+         System.err.println("During the forward pass of " + describe() + " with the following inputs:");
+         for (NNOperation input : inputs)
+             System.err.println("\t" + input.describe());
+         e.printStackTrace();
+     }
      System.exit(1);
      return null;
  }
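The recurring pattern in this file is the new lock()/unlock() pair on ThreadData: every forward pass claims the per-thread slot, short-circuits through unlock() when lastOutput is already cached, and releases ownership on each return path. A stripped-down, self-contained analogue of that ownership check (plain Object instead of Tensor, JDK thread ids instead of ThreadPool ids):

```java
// Non-reentrant ownership check mirroring ThreadData.lock()/unlock():
// claiming an already-claimed slot fails loudly, and unlock() both
// releases the slot and hands back the cached value.
class OwnedOutputSlot {
    Object lastOutput;        // stands in for the cached forward result
    private long owner = -1;  // -1 means unclaimed

    public synchronized void lock() {
        if (owner != -1)
            throw new RuntimeException("Locked by thread #" + owner);
        owner = Thread.currentThread().getId();
    }

    public synchronized Object unlock() {
        if (owner != Thread.currentThread().getId())
            throw new RuntimeException("Trying to unlock a different thread");
        Object ret = lastOutput;
        owner = -1;
        return ret;
    }
}
```

Note that an exception thrown between lock() and unlock() leaves the slot claimed, which is presumably why runPredictionAndAutosize() adds an explicit data.unlock() at the top of its catch block (the runPrediction() catch block exits the process instead).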
34 changes: 28 additions & 6 deletions JGNN/src/main/java/mklab/JGNN/nn/inputs/Variable.java
@@ -2,30 +2,52 @@

  import mklab.JGNN.nn.NNOperation;
  import mklab.JGNN.nn.Optimizer;
+
+ import java.util.HashMap;
+ import java.util.List;

  import mklab.JGNN.core.Tensor;
+ import mklab.JGNN.core.ThreadPool;

  /**
   * Implements a {@link NNOperation} that represents {@link mklab.JGNN.nn.Model} inputs.
   * Its values can be set using the {@link #setTo(Tensor)} method.
   *
   * @author Emmanouil Krasanakis
   */
- public class Variable extends Parameter {
+ public class Variable extends NNOperation {
+     private HashMap<Integer, Tensor> threadData = new HashMap<Integer, Tensor>();
      public Variable() {
-         super(null);
      }

      @Override
      protected void trainParameters(Optimizer optimizer, Tensor error) {
      }

+     public void setTo(Tensor value) {
+         synchronized(threadData) {
+             threadData.put(ThreadPool.getCurrentThreadId(), value);
+         }
+     }
+     @Override
+     protected Tensor forward(List<Tensor> inputs) {
+         Tensor ret;
+         synchronized(threadData) {
+             ret = threadData.get(ThreadPool.getCurrentThreadId());
+         }
+         return ret;
+     }
+     @Override
+     protected Tensor partial(int inputId, List<Tensor> inputs, Tensor output, Tensor error) {
+         return null;
+     }

      @Override
      public boolean isConstant() {
-         return true;
+         return false;
      }
+     @Override
+     public boolean isCachable() {
+         return false;
+     }
-     public void setTo(Tensor value) {
-         this.tensor = value;
-     }
  }
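This is the heart of the commit: Variable no longer extends Parameter with one shared tensor field; setTo files the value under the calling thread's id and forward reads back the same slot, while isConstant() and isCachable() switching to false keep other caching layers from leaking one thread's input into another's pass. A hypothetical, self-contained demo of why the per-thread slot matters, using plain Java stand-ins rather than JGNN classes:

```java
import java.util.concurrent.ConcurrentHashMap;

// With a single shared field, the last writer would win and one worker
// could train on the other's batch; with a per-thread slot, each worker
// reads back exactly the value it set.
public class PerThreadSetterDemo {
    static final ConcurrentHashMap<Long, String> slot = new ConcurrentHashMap<>();

    static void setTo(String value) { slot.put(Thread.currentThread().getId(), value); }
    static String forward()         { return slot.get(Thread.currentThread().getId()); }

    public static void main(String[] args) throws InterruptedException {
        Runnable worker = () -> {
            String myBatch = "batch-for-" + Thread.currentThread().getId();
            setTo(myBatch);
            Thread.yield(); // invite the other worker to interleave
            if (!myBatch.equals(forward()))
                throw new AssertionError("saw another thread's batch");
        };
        Thread a = new Thread(worker), b = new Thread(worker);
        a.start(); b.start();
        a.join(); b.join();
        System.out.println("each worker saw its own value");
    }
}
```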
