Skip to content

Commit

Permalink
feat: add spconv model #6
Browse files Browse the repository at this point in the history
Same architecture that was used in first version of the draft
  • Loading branch information
jsappl committed May 22, 2024
1 parent 7ec7e41 commit c80eafb
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions uibk/deep_preconditioning/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Define convolutional neural network architecture for preconditioning.
Classes:
PreconditionerNet: CNN returns lower triangular matrices for preconditioning.
"""

import spconv.pytorch as spconv
import torch
from torch import nn


class PreconditionerNet(nn.Module):
"""CNN returns preconditioner for conjugate gradient solver."""

def __init__(self) -> None:
"""Initialize the network architecture."""
super().__init__()

self.layers = spconv.SparseSequential(
spconv.SparseConv2d(1, 64, 1),
nn.PReLU(),
spconv.SparseConv2d(64, 256, 2, padding=(1, 0)),
nn.PReLU(),
spconv.SparseConv2d(256, 512, 2, padding=(1, 0)),
nn.PReLU(),
spconv.SparseConv2d(512, 256, 2, padding=(0, 1)),
nn.PReLU(),
spconv.SparseConv2d(256, 64, 2, padding=(0, 1)),
nn.PReLU(),
spconv.SparseConv2d(64, 1, 1),
)

def forward(self, input_: spconv.SparseConvTensor) -> spconv.SparseConvTensor:
"""Return the `L` part of the `L @ L.T` preconditioner for the conjugate gradient solver.
Args:
input_: Sparse batch tensor representing the linear system.
Returns:
Sparse batch tensor of lower triangular matrices.
"""
interim = self.layers(input_)

(filter, ) = torch.where(interim.indices[:, 1] < interim.indices[:, 2]) # (batch, row, col)
interim.features[filter] = 0 # make the matrix lower triangular

# TODO: Check diagonal, maybe enforce positive values?

return interim

0 comments on commit c80eafb

Please sign in to comment.