
fix metadata issues, and add more debug logs
zain-sohail committed Jan 10, 2025
1 parent 8d9bcd5 commit ee666f7
Showing 4 changed files with 66 additions and 58 deletions.
48 changes: 22 additions & 26 deletions src/sed/loader/flash/buffer_handler.py
@@ -2,6 +2,7 @@

import os
from pathlib import Path
import time

import dask.dataframe as dd
import pyarrow.parquet as pq
@@ -20,7 +21,7 @@

DF_TYP = ["electron", "timed"]

logger = setup_logging(__name__)
logger = setup_logging("flash_buffer_handler")


class BufferFilePaths:
@@ -134,16 +135,15 @@ def __init__(
def _schema_check(self, files: list[Path], expected_schema_set: set) -> None:
"""
Checks the schema of the Parquet files.
Raises:
ValueError: If the schema of the Parquet files does not match the configuration.
"""
logger.debug(f"Checking schema for {len(files)} files")
existing = [file for file in files if file.exists()]
parquet_schemas = [pq.read_schema(file) for file in existing]

for filename, schema in zip(existing, parquet_schemas):
schema_set = set(schema.names)
if schema_set != expected_schema_set:
logger.error(f"Schema mismatch in file: {filename}")
missing_in_parquet = expected_schema_set - schema_set
missing_in_config = schema_set - expected_schema_set

@@ -158,39 +158,35 @@ def _schema_check(self, files: list[Path], expected_schema_set: set) -> None:
f"{' '.join(errors)}. "
"Please check the configuration file or set force_recreate to True.",
)
logger.debug("Schema check passed successfully")

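For reference, a minimal standalone sketch of the set-based schema comparison performed above, using pyarrow's read_schema; the file path and expected column names are hypothetical placeholders, not values from this repository's configuration:

from pathlib import Path

import pyarrow.parquet as pq

# Hypothetical expected columns and buffer file path
expected_schema_set = {"trainId", "pulseId", "electronId", "dldPosX"}
files = [Path("buffer/run_001_electron")]

for file in (f for f in files if f.exists()):
    schema_set = set(pq.read_schema(file).names)
    if schema_set != expected_schema_set:
        missing_in_parquet = expected_schema_set - schema_set
        missing_in_config = schema_set - expected_schema_set
        raise ValueError(
            f"Schema mismatch in {file}: "
            f"missing in parquet: {missing_in_parquet}, missing in config: {missing_in_config}",
        )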
def _save_buffer_file(self, paths: dict[str, Path]) -> None:
"""
Creates the electron and timed buffer files from the raw H5 file.
First the dataframe is accessed and forward filled in the non-electron channels.
Then the data types are set. For the electron dataframe, all values not in the electron
channels are dropped. For the timed dataframe, only the train and pulse channels are taken
and it is pulse resolved (no longer electron resolved). Both are saved as parquet files.
Args:
paths (dict[str, Path]): Dictionary containing the paths to the H5 and buffer files.
"""
"""Creates the electron and timed buffer files from the raw H5 file."""
logger.debug(f"Processing file: {paths['raw'].stem}")
start_time = time.time()

# Create a DataFrameCreator instance and the h5 file
# Create DataFrameCreator and get dataframe
df = DataFrameCreator(config_dataframe=self._config, h5_path=paths["raw"]).df

# forward fill all the non-electron channels
# Forward fill non-electron channels
logger.debug(f"Forward filling {len(self.fill_channels)} channels")
df[self.fill_channels] = df[self.fill_channels].ffill()

# Reset the index of the DataFrame and save both the electron and timed dataframes
# electron resolved dataframe
# Save electron dataframe
electron_channels = get_channels(self._config, "per_electron")
dtypes = get_dtypes(self._config, df.columns.values)
df.dropna(subset=electron_channels).astype(dtypes).reset_index().to_parquet(
paths["electron"],
)
electron_df = df.dropna(subset=electron_channels).astype(dtypes).reset_index()
logger.debug(f"Saving electron buffer with shape: {electron_df.shape}")
electron_df.to_parquet(paths["electron"])

# timed dataframe
# drop the electron channels and only take rows with the first electronId
# Save timed dataframe
df_timed = df.dropna(subset=electron_channels)[self.fill_channels].loc[:, :, 0]
dtypes = get_dtypes(self._config, df_timed.columns.values)
df_timed.astype(dtypes).reset_index().to_parquet(paths["timed"])
logger.debug(f"Processed {paths['raw'].stem}")
timed_df = df_timed.astype(dtypes).reset_index()
logger.debug(f"Saving timed buffer with shape: {timed_df.shape}")
timed_df.to_parquet(paths["timed"])

logger.debug(f"Processed {paths['raw'].stem} in {time.time() - start_time:.2f}s")

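The timing added to _save_buffer_file follows a simple wrap-and-log pattern; a self-contained sketch using the standard logging module instead of sed.core.logging.setup_logging, with an illustrative step name:

import logging
import time

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("flash_buffer_handler")

def process_step(name: str) -> None:
    start_time = time.time()
    # ... the actual buffer creation would happen here ...
    logger.debug(f"Processed {name} in {time.time() - start_time:.2f}s")

process_step("run_001")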
def _save_buffer_files(self, force_recreate: bool, debug: bool) -> None:
"""
@@ -263,7 +259,6 @@ def _get_dataframes(self) -> None:
config=self._config,
)
self.metadata.update(meta)

def process_and_load_dataframe(
self,
h5_paths: list[Path],
@@ -310,3 +305,4 @@ def process_and_load_dataframe(
self._get_dataframes()

return self.df["electron"], self.df["timed"]

20 changes: 15 additions & 5 deletions src/sed/loader/flash/dataframe.py
@@ -14,6 +14,9 @@

from sed.loader.flash.utils import get_channels
from sed.loader.flash.utils import InvalidFileError
from sed.core.logging import setup_logging

logger = setup_logging("flash_dataframe_creator")


class DataFrameCreator:
@@ -34,6 +37,7 @@ def __init__(self, config_dataframe: dict, h5_path: Path) -> None:
config_dataframe (dict): The configuration dictionary with only the dataframe key.
h5_path (Path): Path to the h5 file.
"""
logger.debug(f"Initializing DataFrameCreator for file: {h5_path}")
self.h5_file = h5py.File(h5_path, "r")
self.multi_index = get_channels(index=True)
self._config = config_dataframe
@@ -76,6 +80,7 @@ def get_dataset_array(
tuple[pd.Index, np.ndarray | h5py.Dataset]: A tuple containing the train ID
pd.Index and the channel's data.
"""
logger.debug(f"Getting dataset array for channel: {channel}")
# Get the data from the necessary h5 file and channel
index_key, dataset_key = self.get_index_dataset_key(channel)

@@ -85,6 +90,7 @@
if slice_:
slice_index = self._config["channels"][channel].get("slice", None)
if slice_index is not None:
logger.debug(f"Slicing dataset with index: {slice_index}")
dataset = np.take(dataset, slice_index, axis=1)
# If the dataset has size zero, fill it with NaN values
# of the same shape as the index
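The slicing branch above reduces a multi-dimensional HDF5 dataset to a single sub-channel with np.take; a toy illustration with made-up values (the real slice index comes from the channel's "slice" entry in the config):

import numpy as np

dataset = np.arange(12).reshape(3, 4)  # rows: train IDs, columns: sub-channels
slice_index = 2                        # hypothetical "slice" value from the config
sliced = np.take(dataset, slice_index, axis=1)
print(sliced)  # [ 2  6 10]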
@@ -291,10 +297,14 @@ def df(self) -> pd.DataFrame:
Returns:
pd.DataFrame: The combined pandas DataFrame.
"""

logger.debug("Creating combined DataFrame")
self.validate_channel_keys()
# combining has been tested with merge, join and concat;
# concat offers the best performance, almost 3 times faster

df = pd.concat((self.df_electron, self.df_pulse, self.df_train), axis=1).sort_index()
# all the negative pulse values are dropped as they are invalid
return df[df.index.get_level_values("pulseId") >= 0]
logger.debug(f"Created DataFrame with shape: {df.shape}")

# Filter negative pulse values
df = df[df.index.get_level_values("pulseId") >= 0]
logger.debug(f"Filtered DataFrame shape: {df.shape}")

return df
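The combination and filtering done in df can be reproduced on toy data; the column names below (dldPosX, bam) are placeholders rather than the loader's actual channel names:

import pandas as pd

index = pd.MultiIndex.from_tuples(
    [(100, -1), (100, 0), (100, 1), (101, 0)],
    names=["trainId", "pulseId"],
)
df_a = pd.DataFrame({"dldPosX": [1.0, 2.0, 3.0, 4.0]}, index=index)
df_b = pd.DataFrame({"bam": [0.1, 0.2, 0.3, 0.4]}, index=index)

# Concatenate per-channel frames on the shared index and sort
df = pd.concat((df_a, df_b), axis=1).sort_index()
# Negative pulse IDs are invalid and dropped
df = df[df.index.get_level_values("pulseId") >= 0]
print(df.shape)  # (3, 2)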
40 changes: 21 additions & 19 deletions src/sed/loader/flash/metadata.py
@@ -7,6 +7,9 @@
import warnings

import requests
from sed.core.logging import setup_logging

logger = setup_logging("flash_metadata_retriever")


class MetadataRetriever:
@@ -15,19 +18,19 @@ class MetadataRetriever:
on beamtime and run IDs.
"""

def __init__(self, metadata_config: dict, scicat_token: str = None) -> None:
def __init__(self, metadata_config: dict, token: str = None) -> None:
"""
Initializes the MetadataRetriever class.
Args:
metadata_config (dict): Takes a dict containing
at least url, and optionally token for the scicat instance.
scicat_token (str, optional): The token to use for fetching metadata.
token (str, optional): The token to use for fetching metadata.
"""
self.token = metadata_config.get("scicat_token", None)
if scicat_token:
self.token = scicat_token
self.url = metadata_config.get("scicat_url", None)
self.token = metadata_config.get("token", None)
if token:
self.token = token
self.url = metadata_config.get("archiver_url", None)

if not self.token or not self.url:
raise ValueError("No URL or token provided for fetching metadata from scicat.")
@@ -36,7 +39,7 @@ def __init__(self, metadata_config: dict, scicat_token: str = None) -> None:
"Content-Type": "application/json",
"Accept": "application/json",
}
self.token = metadata_config["scicat_token"]
self.token = metadata_config["token"]

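With the renamed configuration keys, constructing the retriever looks as follows; the URL and tokens are placeholders (the same values the tests below use), and the import assumes the sed package is installed:

from sed.loader.flash.metadata import MetadataRetriever

metadata_config = {
    "archiver_url": "https://example.com",  # placeholder URL
    "token": "fake_token",                  # placeholder token
}
retriever = MetadataRetriever(metadata_config)
# A token can also be passed directly instead of via the config
retriever = MetadataRetriever(metadata_config, token="other_fake_token")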
def get_metadata(
self,
@@ -59,19 +62,18 @@
Raises:
Exception: If the request to retrieve metadata fails.
"""
# If metadata is not provided, initialize it as an empty dictionary
logger.debug(f"Fetching metadata for beamtime {beamtime_id}, runs: {runs}")

if metadata is None:
metadata = {}

# Iterate over the list of runs
for run in runs:
pid = f"{beamtime_id}/{run}"
# Retrieve metadata for each run and update the overall metadata dictionary
logger.debug(f"Retrieving metadata for PID: {pid}")
metadata_run = self._get_metadata_per_run(pid)
metadata.update(
metadata_run,
) # TODO: Not correct for multiple runs
metadata.update(metadata_run)

logger.debug(f"Retrieved metadata with {len(metadata)} entries")
return metadata

def _get_metadata_per_run(self, pid: str) -> dict:
@@ -91,26 +93,26 @@ def _get_metadata_per_run(self, pid: str) -> dict:
headers2["Authorization"] = f"Bearer {self.token}"

try:
logger.debug(f"Attempting to fetch metadata with new URL format for PID: {pid}")
dataset_response = requests.get(
self._create_new_dataset_url(pid),
headers=headers2,
timeout=10,
)
dataset_response.raise_for_status()
# Check if the response is empty, which indicates the wrong URL for the older implementation

if not dataset_response.content:
logger.debug("Empty response, trying old URL format")
dataset_response = requests.get(
self._create_old_dataset_url(pid),
headers=headers2,
timeout=10,
)
# If the dataset request is successful, return the retrieved metadata
# as a JSON object
return dataset_response.json()

except requests.exceptions.RequestException as exception:
# If the request fails, raise warning
print(warnings.warn(f"Failed to retrieve metadata for PID {pid}: {str(exception)}"))
return {} # Return an empty dictionary for this run
logger.warning(f"Failed to retrieve metadata for PID {pid}: {str(exception)}")
return {}

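The retrieval logic above amounts to one GET against a new-style URL with a fallback to the old-style URL on an empty response; a standalone sketch in which both URLs are hypothetical stand-ins for _create_new_dataset_url and _create_old_dataset_url:

import requests

def fetch_dataset(pid: str, token: str) -> dict:
    headers = {"Authorization": f"Bearer {token}", "Accept": "application/json"}
    new_url = f"https://example.com/api/v3/datasets/{pid}"  # hypothetical endpoint
    old_url = f"https://example.com/api/v3/Datasets/{pid}"  # hypothetical endpoint
    try:
        response = requests.get(new_url, headers=headers, timeout=10)
        response.raise_for_status()
        if not response.content:  # empty body: retry against the older endpoint
            response = requests.get(old_url, headers=headers, timeout=10)
        return response.json()
    except requests.exceptions.RequestException:
        return {}  # as above, fall back to empty metadata on request errors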
def _create_old_dataset_url(self, pid: str) -> str:
return "{burl}/{url}/%2F{npid}".format(
16 changes: 8 additions & 8 deletions tests/loader/flash/test_flash_metadata.py
@@ -16,8 +16,8 @@ def mock_requests(requests_mock) -> None:
# Test cases for MetadataRetriever
def test_get_metadata(mock_requests: None) -> None: # noqa: ARG001
metadata_config = {
"scicat_url": "https://example.com",
"scicat_token": "fake_token",
"archiver_url": "https://example.com",
"token": "fake_token",
}
retriever = MetadataRetriever(metadata_config)
metadata = retriever.get_metadata("11013410", ["43878"])
@@ -27,8 +27,8 @@ def test_get_metadata_with_existing_metadata(mock_requests: None) -> None: # noqa: ARG001

def test_get_metadata_with_existing_metadata(mock_requests: None) -> None: # noqa: ARG001
metadata_config = {
"scicat_url": "https://example.com",
"scicat_token": "fake_token",
"archiver_url": "https://example.com",
"token": "fake_token",
}
retriever = MetadataRetriever(metadata_config)
existing_metadata = {"existing": "metadata"}
@@ -39,8 +39,8 @@ def test_get_metadata_per_run(mock_requests: None) -> None: # noqa: ARG001

def test_get_metadata_per_run(mock_requests: None) -> None: # noqa: ARG001
metadata_config = {
"scicat_url": "https://example.com",
"scicat_token": "fake_token",
"archiver_url": "https://example.com",
"token": "fake_token",
}
retriever = MetadataRetriever(metadata_config)
metadata = retriever._get_metadata_per_run("11013410/43878")
@@ -50,8 +50,8 @@ def test_create_dataset_url_by_PID() -> None:

def test_create_dataset_url_by_PID() -> None:
metadata_config = {
"scicat_url": "https://example.com",
"scicat_token": "fake_token",
"archiver_url": "https://example.com",
"token": "fake_token",
}
retriever = MetadataRetriever(metadata_config)
# Assuming the dataset follows the new format

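The mock_requests fixture these tests depend on is collapsed at the top of the file; a plausible sketch of such a fixture based on the requests-mock pytest plugin, answering any GET against the placeholder host with a made-up JSON payload:

import re

import pytest

@pytest.fixture
def mock_requests(requests_mock) -> None:
    # Match any dataset URL under the placeholder host used in the tests
    requests_mock.get(
        re.compile(r"https://example\.com/.*"),
        json={"fake": "data"},
    )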