From c80eafbb8e1c8eeba43dcc548d98892ba4ac70d9 Mon Sep 17 00:00:00 2001 From: Johannes Sappl Date: Wed, 22 May 2024 18:47:26 +0200 Subject: [PATCH] feat: add spconv model #6 Same architecture that was used in first version of the draft --- uibk/deep_preconditioning/model.py | 49 ++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 uibk/deep_preconditioning/model.py diff --git a/uibk/deep_preconditioning/model.py b/uibk/deep_preconditioning/model.py new file mode 100644 index 0000000..4204214 --- /dev/null +++ b/uibk/deep_preconditioning/model.py @@ -0,0 +1,49 @@ +"""Define convolutional neural network architecture for preconditioning. + +Classes: + PreconditionerNet: CNN returns lower triangular matrices for preconditioning. +""" + +import spconv.pytorch as spconv +import torch +from torch import nn + + +class PreconditionerNet(nn.Module): + """CNN returns preconditioner for conjugate gradient solver.""" + + def __init__(self) -> None: + """Initialize the network architecture.""" + super().__init__() + + self.layers = spconv.SparseSequential( + spconv.SparseConv2d(1, 64, 1), + nn.PReLU(), + spconv.SparseConv2d(64, 256, 2, padding=(1, 0)), + nn.PReLU(), + spconv.SparseConv2d(256, 512, 2, padding=(1, 0)), + nn.PReLU(), + spconv.SparseConv2d(512, 256, 2, padding=(0, 1)), + nn.PReLU(), + spconv.SparseConv2d(256, 64, 2, padding=(0, 1)), + nn.PReLU(), + spconv.SparseConv2d(64, 1, 1), + ) + + def forward(self, input_: spconv.SparseConvTensor) -> spconv.SparseConvTensor: + """Return the `L` part of the `L @ L.T` preconditioner for the conjugate gradient solver. + + Args: + input_: Sparse batch tensor representing the linear system. + + Returns: + Sparse batch tensor of lower triangular matrices. + """ + interim = self.layers(input_) + + (filter, ) = torch.where(interim.indices[:, 1] < interim.indices[:, 2]) # (batch, row, col) + interim.features[filter] = 0 # make the matrix lower triangular + + # TODO: Check diagonal, maybe enforce positive values? + + return interim