Skip to content

Commit

Permalink
Merge branch 'dev' into feature-fuzzy-loss
Browse files Browse the repository at this point in the history
  • Loading branch information
sfluegel05 authored Dec 19, 2024
2 parents 62a49eb + 74f3ab9 commit 7639179
Show file tree
Hide file tree
Showing 87 changed files with 14,932 additions and 150 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/export_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import json

from chebai.preprocessing.reader import (
CLS_TOKEN,
EMBEDDING_OFFSET,
MASK_TOKEN_INDEX,
PADDING_TOKEN_INDEX,
)

# Define the constants you want to export
# Any changes in the key names here should also follow the same change in verify_constants.yml code
constants = {
"EMBEDDING_OFFSET": EMBEDDING_OFFSET,
"CLS_TOKEN": CLS_TOKEN,
"PADDING_TOKEN_INDEX": PADDING_TOKEN_INDEX,
"MASK_TOKEN_INDEX": MASK_TOKEN_INDEX,
}

if __name__ == "__main__":
# Write constants to a JSON file
with open("constants.json", "w") as f:
json.dump(constants, f)
27 changes: 27 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Unittests

on: [pull_request]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade pip setuptools wheel
python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
python -m pip install -e .
- name: Display Python version
run: python -m unittest discover -s tests/unit
128 changes: 128 additions & 0 deletions .github/workflows/token_consistency.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
name: Check consistency of tokens.txt file

# Define the file paths under `paths` to trigger this check only when specific files are modified.
# This script will then execute checks only on files that have changed, rather than all files listed in `paths`.

# **Note** : To add a new token file for checks, include its path in:
# - `on` -> `push` and `pull_request` sections
# - `jobs` -> `check_tokens` -> `steps` -> Set global variable for multiple tokens.txt paths -> `TOKENS_FILES`

on:
push:
paths:
- "chebai/preprocessing/bin/smiles_token/tokens.txt"
- "chebai/preprocessing/bin/smiles_token_unlabeled/tokens.txt"
- "chebai/preprocessing/bin/selfies/tokens.txt"
- "chebai/preprocessing/bin/protein_token/tokens.txt"
- "chebai/preprocessing/bin/graph_properties/tokens.txt"
- "chebai/preprocessing/bin/graph/tokens.txt"
- "chebai/preprocessing/bin/deepsmiles_token/tokens.txt"
- "chebai/preprocessing/bin/protein_token_3_gram/tokens.txt"
pull_request:
paths:
- "chebai/preprocessing/bin/smiles_token/tokens.txt"
- "chebai/preprocessing/bin/smiles_token_unlabeled/tokens.txt"
- "chebai/preprocessing/bin/selfies/tokens.txt"
- "chebai/preprocessing/bin/protein_token/tokens.txt"
- "chebai/preprocessing/bin/graph_properties/tokens.txt"
- "chebai/preprocessing/bin/graph/tokens.txt"
- "chebai/preprocessing/bin/deepsmiles_token/tokens.txt"
- "chebai/preprocessing/bin/protein_token_3_gram/tokens.txt"

jobs:
check_tokens:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Get list of changed files
id: changed_files
run: |
git fetch origin dev
# Get the list of changed files compared to origin/dev and save them to a file
git diff --name-only origin/dev > changed_files.txt
# Print the names of changed files on separate lines
echo "Changed files:"
while read -r line; do
echo "Changed File name : $line"
done < changed_files.txt
- name: Set global variable for multiple tokens.txt paths
run: |
# All token files that needs to checked must be included here too, same as in `paths`.
TOKENS_FILES=(
"chebai/preprocessing/bin/smiles_token/tokens.txt"
"chebai/preprocessing/bin/smiles_token_unlabeled/tokens.txt"
"chebai/preprocessing/bin/selfies/tokens.txt"
"chebai/preprocessing/bin/protein_token/tokens.txt"
"chebai/preprocessing/bin/graph_properties/tokens.txt"
"chebai/preprocessing/bin/graph/tokens.txt"
"chebai/preprocessing/bin/deepsmiles_token/tokens.txt"
"chebai/preprocessing/bin/protein_token_3_gram/tokens.txt"
)
echo "TOKENS_FILES=${TOKENS_FILES[*]}" >> $GITHUB_ENV
- name: Process only changed tokens.txt files
run: |
# Convert the TOKENS_FILES environment variable into an array
TOKENS_FILES=(${TOKENS_FILES})
# Iterate over each token file path
for TOKENS_FILE_PATH in "${TOKENS_FILES[@]}"; do
# Check if the current token file path is in the list of changed files
if grep -q "$TOKENS_FILE_PATH" changed_files.txt; then
echo "----------------------- Processing $TOKENS_FILE_PATH -----------------------"
# Get previous tokens.txt version
git fetch origin dev
git diff origin/dev -- $TOKENS_FILE_PATH > tokens_diff.txt || echo "No previous tokens.txt found for $TOKENS_FILE_PATH"
# Check for deleted or added lines in tokens.txt
if [ -f tokens_diff.txt ]; then
# Check for deleted lines (lines starting with '-')
deleted_lines=$(grep '^-' tokens_diff.txt | grep -v '^---' | sed 's/^-//' || true)
if [ -n "$deleted_lines" ]; then
echo "Error: Lines have been deleted from $TOKENS_FILE_PATH."
echo -e "Deleted Lines: \n$deleted_lines"
exit 1
fi
# Check for added lines (lines starting with '+')
added_lines=$(grep '^+' tokens_diff.txt | grep -v '^+++' | sed 's/^+//' || true)
if [ -n "$added_lines" ]; then
# Count how many lines have been added
num_added_lines=$(echo "$added_lines" | wc -l)
# Get last `n` lines (equal to num_added_lines) of tokens.txt
last_lines=$(tail -n "$num_added_lines" $TOKENS_FILE_PATH)
# Check if the added lines are at the end of the file
if [ "$added_lines" != "$last_lines" ]; then
# Find lines that were added but not appended at the end of the file
non_appended_lines=$(diff <(echo "$added_lines") <(echo "$last_lines") | grep '^<' | sed 's/^< //')
echo "Error: New lines have been added to $TOKENS_FILE_PATH, but they are not at the end of the file."
echo -e "Added lines that are not at the end of the file: \n$non_appended_lines"
exit 1
fi
fi
if [ "$added_lines" == "" ]; then
echo "$TOKENS_FILE_PATH validation successful: No lines were deleted, and no new lines were added."
else
echo "$TOKENS_FILE_PATH validation successful: No lines were deleted, and new lines were correctly appended at the end."
fi
else
echo "No previous version of $TOKENS_FILE_PATH found."
fi
else
echo "$TOKENS_FILE_PATH was not changed, skipping."
fi
done
116 changes: 116 additions & 0 deletions .github/workflows/verify_constants.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
name: Verify Constants

# Define the file paths under `paths` to trigger this check only when specific files are modified.
# This script will then execute checks only on files that have changed, rather than all files listed in `paths`.

# **Note** : To add a new file for checks, include its path in:
# - `on` -> `push` and `pull_request` sections
# - `jobs` -> `verify-constants` -> `steps` -> Verify constants -> Add a new if else for your file, with check logic inside it.


on:
push:
paths:
- "chebai/preprocessing/reader.py"
pull_request:
paths:
- "chebai/preprocessing/reader.py"

jobs:
verify-constants:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [
# Only use 3.10 as of now
# "3.9",
"3.10",
# "3.11"
]

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set PYTHONPATH
run: echo "PYTHONPATH=$PWD" >> $GITHUB_ENV

- name: Get list of changed files
id: changed_files
run: |
git fetch origin dev
# Get the list of changed files compared to origin/dev and save them to a file
git diff --name-only origin/dev > changed_files.txt
# Print the names of changed files on separate lines
echo "Changed files:"
while read -r line; do
echo "Changed File name : $line"
done < changed_files.txt
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
# Setting a fix version for torch due to an error with latest version (2.5.1)
# ImportError: cannot import name 'T_co' from 'torch.utils.data.dataset'
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade pip setuptools wheel
python -m pip install torch==2.4.1 --index-url https://download.pytorch.org/whl/cpu
python -m pip install -e .
- name: Export constants
run: python .github/workflows/export_constants.py

- name: Load constants into environment variables
id: load_constants
# "E_" is appended as suffix to every constant, to protect overwriting other sys env variables with same name
run: |
constants=$(cat constants.json)
echo "$constants" | jq -r 'to_entries|map("E_\(.key)=\(.value|tostring)")|.[]' >> $GITHUB_ENV
- name: Print all environment variables
run: printenv

- name: Verify constants
run: |
file_name="chebai/preprocessing/reader.py"
if grep -q "$file_name" changed_files.txt; then
echo "----------------------- Checking file : $file_name ----------------------- "
# Define expected values for constants
exp_embedding_offset="10"
exp_cls_token="2"
exp_padding_token_index="0"
exp_mask_token_index="1"
# Debugging output to check environment variables
echo "Current Environment Variables:"
echo "E_EMBEDDING_OFFSET = $E_EMBEDDING_OFFSET"
echo "Expected: $exp_embedding_offset"
# Verify constants match expected values
if [ "$E_EMBEDDING_OFFSET" != "$exp_embedding_offset" ]; then
echo "EMBEDDING_OFFSET ($E_EMBEDDING_OFFSET) does not match expected value ($exp_embedding_offset)!"
exit 1
fi
if [ "$E_CLS_TOKEN" != "$exp_cls_token" ]; then
echo "CLS_TOKEN ($E_CLS_TOKEN) does not match expected value ($exp_cls_token)!"
exit 1
fi
if [ "$E_PADDING_TOKEN_INDEX" != "$exp_padding_token_index" ]; then
echo "PADDING_TOKEN_INDEX ($E_PADDING_TOKEN_INDEX) does not match expected value ($exp_padding_token_index)!"
exit 1
fi
if [ "$E_MASK_TOKEN_INDEX" != "$exp_mask_token_index" ]; then
echo "MASK_TOKEN_INDEX ($E_MASK_TOKEN_INDEX) does not match expected value ($exp_mask_token_index)!"
exit 1
fi
else
echo "$file_name not found in changed_files.txt; skipping check."
fi
8 changes: 4 additions & 4 deletions chebai/loss/bce_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
and self.data_extractor is not None
and all(
os.path.exists(
os.path.join(self.data_extractor.processed_dir_main, raw_file)
os.path.join(self.data_extractor.processed_dir_main, file_name)
)
for raw_file in self.data_extractor.raw_file_names
for file_name in self.data_extractor.processed_main_file_names
)
and self.pos_weight is None
):
Expand All @@ -65,12 +65,12 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
open(
os.path.join(
self.data_extractor.processed_dir_main,
raw_file_name,
file_name,
),
"rb",
)
)
for raw_file_name in self.data_extractor.raw_file_names
for file_name in self.data_extractor.processed_main_file_names
]
)
value_counts = []
Expand Down
13 changes: 11 additions & 2 deletions chebai/models/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, Iterable

import torch
from lightning.pytorch.core.module import LightningModule
Expand Down Expand Up @@ -41,12 +41,21 @@ def __init__(
test_metrics: Optional[torch.nn.Module] = None,
pass_loss_kwargs: bool = True,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
**kwargs,
):
super().__init__()
if exclude_hyperparameter_logging is None:
exclude_hyperparameter_logging = tuple()
self.criterion = criterion
self.save_hyperparameters(
ignore=["criterion", "train_metrics", "val_metrics", "test_metrics"]
ignore=[
"criterion",
"train_metrics",
"val_metrics",
"test_metrics",
*exclude_hyperparameter_logging,
]
)
self.out_dim = out_dim
if optimizer_kwargs:
Expand Down
Loading

0 comments on commit 7639179

Please sign in to comment.