
Commit

Merge branch 'rvankoert:master' into master
rvankoert authored Oct 26, 2023
2 parents e1b853c + 26aa3d9 commit 7ebd7ef
Showing 8 changed files with 222 additions and 221 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -305,7 +305,6 @@ GUNICORN_ACCESSLOG # Default: "-": Access log settings.
```bash
LOGHI_MODEL_PATH # Path to the model.
LOGHI_CHARLIST_PATH # Path to the character list.
LOGHI_MODEL_CHANNELS # Number of channels in the model.
LOGHI_BATCH_SIZE # Default: "256": Batch size for processing.
LOGHI_OUTPUT_PATH # Directory where predictions are saved.
LOGHI_MAX_QUEUE_SIZE # Default: "10000": Maximum size of the processing queue.
43 changes: 43 additions & 0 deletions src/api/app_utils.py
@@ -2,9 +2,14 @@

# > Standard library
import logging
from multiprocessing import Process, Manager
import os
from typing import Tuple

# > Local dependencies
from batch_predictor import batch_prediction_worker
from image_preparator import image_preparation_worker

# > Third-party dependencies
from flask import request

@@ -143,3 +148,41 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:

logger.debug(f"Environment variable {var_name} set to {value}")
return value
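
As a side note on how this helper connects to the `LOGHI_*` variables documented in the README change above, here is a minimal usage sketch. It is not part of this diff; `LOGHI_GPUS` and the integer conversions are assumptions for illustration.

```python
# Hypothetical wiring: read the documented environment variables with the
# get_env_variable() helper defined above, falling back to the README defaults.
batch_size = int(get_env_variable("LOGHI_BATCH_SIZE", "256"))
max_queue_size = int(get_env_variable("LOGHI_MAX_QUEUE_SIZE", "10000"))
model_path = get_env_variable("LOGHI_MODEL_PATH")
charlist_path = get_env_variable("LOGHI_CHARLIST_PATH")
output_path = get_env_variable("LOGHI_OUTPUT_PATH")
gpus = get_env_variable("LOGHI_GPUS", "0")  # variable name assumed
```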


def start_processes(batch_size: int, max_queue_size: int, model_path: str,
charlist_path: str, output_path: str, gpus: str):
logger = logging.getLogger(__name__)

# Create a thread-safe Queue
logger.info("Initializing request queue")
manager = Manager()
request_queue = manager.JoinableQueue(maxsize=max_queue_size//2)

# Max size of prepared queue is half of the max size of request queue
# expressed in number of batches
max_prepared_queue_size = max_queue_size // 2 // batch_size
prepared_queue = manager.JoinableQueue(maxsize=max_prepared_queue_size)

# Start the image preparation process
logger.info("Starting image preparation process")
preparation_process = Process(
target=image_preparation_worker,
args=(batch_size, request_queue,
prepared_queue, model_path),
name="Image Preparation Process")
preparation_process.daemon = True
preparation_process.start()

# Start the batch prediction process
logger.info("Starting batch prediction process")
prediction_process = Process(
target=batch_prediction_worker,
args=(prepared_queue, model_path,
charlist_path, output_path,
gpus),
name="Batch Prediction Process")
prediction_process.daemon = True
prediction_process.start()

return request_queue, preparation_process, prediction_process
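
A brief usage sketch of `start_processes` (not part of this commit): both workers run as daemon processes, so the parent should keep the returned handles and drain the queue before exiting. The request payload layout is an assumption here; it is defined by `image_preparation_worker`, whose diff is not shown on this page.

```python
# Hypothetical startup wiring; the paths and the (image bytes, group,
# identifier) payload layout are assumptions for illustration only.
request_queue, prep_proc, pred_proc = start_processes(
    batch_size=256,
    max_queue_size=10000,
    model_path="/path/to/model",
    charlist_path="/path/to/charlist.txt",
    output_path="/path/to/output",
    gpus="0",
)

with open("line.png", "rb") as image_file:
    request_queue.put((image_file.read(), "group-1", "line-0001"))

# Daemon processes are killed when the parent exits, so block until queued
# requests are done (assumes the preparation worker calls task_done() per item).
request_queue.join()
```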
133 changes: 24 additions & 109 deletions src/api/batch_predictor.py
@@ -3,7 +3,6 @@
# > Standard library
import logging
import multiprocessing
from multiprocessing.queues import Empty
import os
import sys
from typing import Callable, List, Tuple
@@ -17,12 +16,10 @@
from tensorflow.keras import mixed_precision


def batch_prediction_worker(batch_size: int,
prepared_queue: multiprocessing.JoinableQueue,
def batch_prediction_worker(prepared_queue: multiprocessing.JoinableQueue,
model_path: str,
charlist_path: str,
output_path: str,
num_channels: int,
gpus: str = '0'):
"""
Worker process for batch prediction on images.
@@ -34,8 +31,6 @@ def batch_prediction_worker(batch_size: int,
Parameters
----------
batch_size : int
Number of images to process in a batch.
prepared_queue : multiprocessing.JoinableQueue
Queue from which preprocessed images are fetched.
model_path : str
@@ -44,10 +39,6 @@
Path to the character list file.
output_path : str
Path where predictions should be saved.
num_channels : int
Number of channels desired for the input images (e.g., 1 for grayscale,
3 for RGB). This is used to verify that the preparation uses the
correct format.
gpus : str, optional
IDs of GPUs to be used (comma-separated). Default is '0'.
@@ -109,85 +100,38 @@ def batch_prediction_worker(batch_size: int,

try:
with strategy.scope():
model, utils = create_model(
model_path, charlist_path, num_channels)
model, utils = create_model(model_path, charlist_path)
logger.info("Model created and utilities initialized")
except Exception as e:
logger.error(e)
logger.error("Error creating model. Exiting...")
return
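
The code that builds `strategy` from the `gpus` argument is folded out of this view. Purely as a hypothetical sketch of one common way to do it (this is not the code in the commit; it relies on the module's existing TensorFlow import):

```python
def select_strategy(gpus: str) -> tf.distribute.Strategy:
    """Pick a distribution strategy from a comma-separated GPU id string."""
    if gpus in ("", "-1"):  # convention assumed: empty or -1 means CPU only
        return tf.distribute.OneDeviceStrategy(device="/cpu:0")
    devices = [f"/gpu:{gpu_id}" for gpu_id in gpus.split(",")]
    if len(devices) == 1:
        return tf.distribute.OneDeviceStrategy(device=devices[0])
    return tf.distribute.MirroredStrategy(devices=devices)
```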

TIMEOUT_DURATION = 1
MAX_WAIT_COUNT = 3

total_predictions = 0

try:
while True:
# The goal is to accumulate a batch of images for processing.
# However, if there's a delay in receiving images, we don't want to
# wait indefinitely. So, we'll use a combination of timeouts and
# counters to decide when to process whatever images we have, even
# if we don't have a full batch.

batch_images = []
wait_count = 0
batch_images, batch_groups, batch_identifiers = \
prepared_queue.get()
logger.debug(f"Retrieved batch of size {len(batch_images)} from "
"prepared_queue")

logger.debug("Waiting to accumulate images for processing")

# Wait until we have enough images for a batch
while len(batch_images) < batch_size:
logger.debug(
f"Waiting for {batch_size - len(batch_images)} more images"
" to be available")

try:
# Wait for TIMEOUT_DURATION seconds for an image to be
# available in the queue
prepared_data = prepared_queue.get(
timeout=TIMEOUT_DURATION)
batch_images.append(prepared_data)
prepared_queue.task_done()
wait_count = 0
except Empty:
wait_count += 1
logger.debug("Time without new images: "
f"{wait_count * TIMEOUT_DURATION} seconds")

# If we've waited more than the maximum allowed time
# (MAX_WAIT_COUNT * TIMEOUT_DURATION) and we have some
# images in the batch, then process those images.
if wait_count > MAX_WAIT_COUNT and len(batch_images) > 0:

# Grab any remaining images in the queue up to the
# batch size
while not prepared_queue.empty()\
and len(batch_images) < batch_size:
prepared_data = prepared_queue.get()
batch_images.append(prepared_data)
prepared_queue.task_done()

# Reset the wait_count and break out of the loop
wait_count = 0
break

logger.info(
f"Retrieved batch of size {len(batch_images)}")
logger.debug(
f"There are {prepared_queue.qsize()} images waiting on "
"prediction")
batch_info = list(zip(batch_groups, batch_identifiers))

# Here, make the batch prediction
# TODO: if OOM, split the batch into halves and try again for each
# half
try:
predictions = batch_predict(
model, batch_images, utils, decode_batch_predictions,
output_path, normalize_confidence)
model, batch_images, batch_info, utils,
decode_batch_predictions, output_path,
normalize_confidence)
except Exception as e:
logger.error(e)
logger.error("Error making predictions. Skipping batch.")
logger.error("Failed batch:")
for image in batch_images:
logger.error(image[2])
for id in batch_identifiers:
logger.error(id)
predictions = []

# Update the total number of predictions made
@@ -201,7 +145,7 @@ def batch_prediction_worker(batch_size: int,
f"Made {len(predictions)} predictions")
logger.info(f"Total predictions: {total_predictions}")
logger.info(
f"{prepared_queue.qsize()} images waiting on prediction")
f"{prepared_queue.qsize()} batches waiting on prediction")

# Clear the batch images to free up memory
logger.debug("Clearing batch images and predictions")
@@ -215,8 +159,7 @@ def batch_prediction_worker(batch_size: int,


def create_model(model_path: str,
charlist_path: str,
num_channels: int) -> Tuple[tf.keras.Model, object]:
charlist_path: str) -> Tuple[tf.keras.Model, object]:
"""
Load a pre-trained model and create utility methods.
@@ -226,9 +169,6 @@ def create_model(model_path: str,
Path to the pre-trained model file.
charlist_path : str
Path to the character list file.
num_channels : int
Number of channels desired for the input images (e.g., 1 for grayscale,
3 for RGB).
Returns
-------
@@ -244,6 +184,7 @@ def create_model(model_path: str,
- Logs various messages regarding the model and utility initialization.
"""

from custom_layers import ResidualBlock
from model import CERMetric, WERMetric, CTCLoss
from utils import Utils

@@ -254,21 +195,16 @@ def create_model(model_path: str,
'CERMetric': CERMetric,
'WERMetric': WERMetric,
'CTCLoss': CTCLoss,
'ResidualBlock': ResidualBlock
})
logger.debug("Custom objects registered")

logger.info("Loading model...")
model = tf.keras.saving.load_model(model_path)
logger.info("Model loaded successfully")

model_channels = model.input_shape[3]
if model_channels != num_channels:
raise ValueError(
f"Model expects {model_channels} channels, but {num_channels} "
"were provided")

if logger.isEnabledFor(logging.DEBUG):
logger.debug(model.summary())
model.summary()

with open(charlist_path) as file:
charlist = list(char for char in file.read())
@@ -279,7 +215,8 @@ def create_model(model_path: str,


def batch_predict(model: tf.keras.Model,
batch: List[Tuple[tf.Tensor, str, str]],
images: List[Tuple[tf.Tensor, str, str]],
batch_info: List[Tuple[str, str]],
utils: object,
decoder: Callable,
output_path: str,
@@ -316,37 +253,15 @@

logger = logging.getLogger(__name__)

logger.debug(f"Initial batch size: {len(batch)}")
logger.debug(f"Initial batch size: {len(images)}")

# Unpack the batch
images, groups, identifiers = map(list, zip(*batch))

# Determine the maximum width of the images in the batch
max_width = 0
for image in images:
if image.shape[0] > max_width:
max_width = image.shape[0]
logger.debug(f"Determined max width: {max_width}")

# Pad the images to the maximum width
for i in range(len(images)):
images[i] = tf.image.resize_with_pad(images[i], max_width, 64)
images = tf.stack(images)

logger.debug(f"Batch shape after padding: {images.shape}")

batch = tf.convert_to_tensor(images)
logger.debug("Converted batch to tensor")
groups, identifiers = zip(*batch_info)

logger.info("Making predictions...")
encoded_predictions = model(images)
encoded_predictions = model.predict_on_batch(images)
logger.debug("Predictions made")

# Clear the session to free up memory
logger.debug("Clearing session...")
tf.keras.backend.clear_session()
logger.debug("Session cleared")

logger.debug("Decoding predictions...")
decoded_predictions = decoder(encoded_predictions, utils)[0]
logger.debug("Predictions decoded")
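
The padding and stacking removed from `batch_predict` above presumably now happens in `image_preparation_worker` before batches are queued (that file's diff is not shown on this page). As an assumption of what that step looks like, here is a minimal sketch mirroring the removed code:

```python
import tensorflow as tf


def pad_and_stack(images):
    """Pad a list of image tensors to a common size and stack into a batch.

    Preserves the behaviour of the removed code: the largest shape[0] across
    the batch is passed as the first size argument of
    tf.image.resize_with_pad, with 64 fixed as the second.
    """
    max_width = max(int(image.shape[0]) for image in images)
    padded = [tf.image.resize_with_pad(image, max_width, 64)
              for image in images]
    return tf.stack(padded)
```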