Skip to content

Commit

Permalink
generalized extract()
Browse files Browse the repository at this point in the history
moved extraction start and stop time outside stats container and only provide the extraction start, i.e. validity start. The returning of the caller is a dict with key the validity start and items the StatisticsContainer.
  • Loading branch information
TjarkMiener committed Jul 17, 2024
1 parent 49afd88 commit 452ec1b
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 73 deletions.
65 changes: 34 additions & 31 deletions src/ctapipe/calib/camera/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

class StatisticsExtractor(TelescopeComponent):
"""
Base component to handle the extraction of the statistics from a dl1 table
containing charges, peak times and/or charge variances (images).
Base component to handle the extraction of the statistics from a table
containing e.g. charges, peak times and/or charge variances (images).
"""

chunk_size = Int(
Expand All @@ -31,13 +31,15 @@ class StatisticsExtractor(TelescopeComponent):

def __call__(
self,
dl1_table,
table,
masked_pixels_of_sample=None,
chunk_shift=None,
col_name="image",
) -> list:
"""
Divides the input DL1 table into overlapping or non-overlapping chunks of size `chunk_size`
Divide table into chunks and extract the statistical values.
This function divides the input table into overlapping or non-overlapping chunks of size `chunk_size`
and call the relevant function of the particular extractor to extract the statistical values.
The chunks are generated in a way that ensures they do not overflow the bounds of the table.
- If `chunk_shift` is None, extraction chunks will not overlap, but the last chunk is ensured to be
Expand All @@ -48,56 +50,61 @@ def __call__(
Parameters
----------
dl1_table : astropy.table.Table
DL1 table with images of shape (n_images, n_channels, n_pix)
table : astropy.table.Table
table with images of shape (n_images, n_channels, n_pix)
and timestamps of shape (n_images, ) stored in an astropy Table
masked_pixels_of_sample : ndarray, optional
boolean array of masked pixels of shape (n_pix, ) that are not available for processing
chunk_shift : int, optional
number of samples to shift between the start of consecutive extraction chunks
col_name : string
column name in the DL1 table
column name in the table
Returns
-------
List StatisticsContainer:
List of extracted statistics and extraction chunks
"""

# Check if the statistics of the dl1 table is sufficient to extract at least one chunk.
if len(dl1_table) < self.chunk_size:
# Check if the statistics of the table is sufficient to extract at least one chunk.
if len(table) < self.chunk_size:
raise ValueError(
f"The length of the provided table ({len(table)}) is insufficient to meet the required statistics for a single extraction chunk of size ({self.chunk_size})."
)
# Check if the chunk_shift is smaller than the chunk_size
if chunk_shift is None and chunk_shift > self.chunk_size:
raise ValueError(
f"The length of the provided DL1 table ({len(dl1_table)}) is insufficient to meet the required statistics for a single extraction chunk of size ({self.chunk_size})."
f"The chunk_shift ({chunk_shift}) must be smaller than the chunk_size ({self.chunk_size})."
)

# Function to split the dl1 table into appropriated chunks
def _get_chunks(dl1_table, chunk_shift):
# Function to split the table into appropriated chunks
def _get_chunks(table, chunk_shift):
# Calculate the range step: Use chunk_shift if provided, otherwise use chunk_size
step = chunk_shift or self.chunk_size

# Generate chunks that do not overflow
for i in range(0, len(dl1_table) - self.chunk_size + 1, step):
yield dl1_table[i : i + self.chunk_size]
for i in range(0, len(table) - self.chunk_size + 1, step):
yield table[i : i + self.chunk_size]

# If chunk_shift is None, ensure the last chunk is of size chunk_size, if needed
if chunk_shift is None and len(dl1_table) % self.chunk_size != 0:
yield dl1_table[-self.chunk_size :]
if chunk_shift is None and len(table) % self.chunk_size != 0:
yield table[-self.chunk_size :]

# Get the chunks of the dl1 table
dl1_chunks = _get_chunks(dl1_table, chunk_shift)
# Get the chunks of the table
chunks = _get_chunks(table, chunk_shift)

# Calculate the statistics from a chunk of images
stats_list = [
self.extract(
chunk[col_name].data, chunk["time_mono"], masked_pixels_of_sample
chunk_stats = {
chunk["time_mono"][0]: self.extract(
chunk[col_name].data, masked_pixels_of_sample
)
for chunk in dl1_chunks
]
for chunk in chunks
}

return stats_list
return chunk_stats

@abstractmethod
def extract(self, images, times, masked_pixels_of_sample) -> StatisticsContainer:
def extract(self, images, masked_pixels_of_sample) -> StatisticsContainer:
pass


Expand All @@ -106,7 +113,7 @@ class PlainExtractor(StatisticsExtractor):
Extract the statistics from a chunk of images using numpy functions
"""

def extract(self, images, times, masked_pixels_of_sample) -> StatisticsContainer:
def extract(self, images, masked_pixels_of_sample) -> StatisticsContainer:
# Mask broken pixels
masked_images = np.ma.array(images, mask=masked_pixels_of_sample)

Expand All @@ -116,8 +123,6 @@ def extract(self, images, times, masked_pixels_of_sample) -> StatisticsContainer
pixel_std = np.ma.std(masked_images, axis=0)

return StatisticsContainer(
extraction_start=times[0],
extraction_stop=times[-1],
mean=pixel_mean.filled(np.nan),
median=pixel_median.filled(np.nan),
std=pixel_std.filled(np.nan),
Expand All @@ -138,7 +143,7 @@ class SigmaClippingExtractor(StatisticsExtractor):
help="Number of iterations for the sigma clipping outlier removal",
).tag(config=True)

def extract(self, images, times, masked_pixels_of_sample) -> StatisticsContainer:
def extract(self, images, masked_pixels_of_sample) -> StatisticsContainer:
# Mask broken pixels
masked_images = np.ma.array(images, mask=masked_pixels_of_sample)

Expand All @@ -152,8 +157,6 @@ def extract(self, images, times, masked_pixels_of_sample) -> StatisticsContainer
)

return StatisticsContainer(
extraction_start=times[0],
extraction_stop=times[-1],
mean=pixel_mean,
median=pixel_median,
std=pixel_std,
Expand Down
77 changes: 37 additions & 40 deletions src/ctapipe/calib/camera/tests/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,48 +16,49 @@ def test_extractors(example_subarray):
times = Time(
np.linspace(60117.911, 60117.9258, num=5000), scale="tai", format="mjd"
)
ped_dl1_data = np.random.normal(2.0, 5.0, size=(5000, 2, 1855))
ff_charge_dl1_data = np.random.normal(77.0, 10.0, size=(5000, 2, 1855))
ff_time_dl1_data = np.random.normal(18.0, 5.0, size=(5000, 2, 1855))
# Create dl1 tables
ped_dl1_table = Table(
[times, ped_dl1_data],
ped_data = np.random.normal(2.0, 5.0, size=(5000, 2, 1855))
charge_data = np.random.normal(77.0, 10.0, size=(5000, 2, 1855))
time_data = np.random.normal(18.0, 5.0, size=(5000, 2, 1855))
# Create tables
ped_table = Table(
[times, ped_data],
names=("time_mono", "image"),
)
ff_charge_dl1_table = Table(
[times, ff_charge_dl1_data],
charge_table = Table(
[times, charge_data],
names=("time_mono", "image"),
)
ff_time_dl1_table = Table(
[times, ff_time_dl1_data],
time_table = Table(
[times, time_data],
names=("time_mono", "peak_time"),
)
# Initialize the extractors
ped_extractor = SigmaClippingExtractor(subarray=example_subarray, chunk_size=2500)
chunk_size = 2500
ped_extractor = SigmaClippingExtractor(
subarray=example_subarray, chunk_size=chunk_size
)
ff_charge_extractor = SigmaClippingExtractor(
subarray=example_subarray, chunk_size=2500
subarray=example_subarray, chunk_size=chunk_size
)
ff_time_extractor = PlainExtractor(subarray=example_subarray, chunk_size=2500)
ff_time_extractor = PlainExtractor(subarray=example_subarray, chunk_size=chunk_size)

# Extract the statistical values
ped_stats_list = ped_extractor(dl1_table=ped_dl1_table)
ff_charge_stats_list = ff_charge_extractor(dl1_table=ff_charge_dl1_table)
ff_time_stats_list = ff_time_extractor(
dl1_table=ff_time_dl1_table, col_name="peak_time"
)
ped_stats = ped_extractor(table=ped_table)
charge_stats = ff_charge_extractor(table=charge_table)
time_stats = ff_time_extractor(table=time_table, col_name="peak_time")
# Check if the calculated statistical values are reasonable
# for a camera with two gain channels
assert not np.any(np.abs(ped_stats_list[0].mean - 2.0) > 1.5)
assert not np.any(np.abs(ff_charge_stats_list[0].mean - 77.0) > 1.5)
assert not np.any(np.abs(ff_time_stats_list[0].mean - 18.0) > 1.5)
assert not np.any(np.abs(ped_stats[times[0]].mean - 2.0) > 1.5)
assert not np.any(np.abs(charge_stats[times[0]].mean - 77.0) > 1.5)
assert not np.any(np.abs(time_stats[times[0]].mean - 18.0) > 1.5)

assert not np.any(np.abs(ped_stats_list[1].median - 2.0) > 1.5)
assert not np.any(np.abs(ff_charge_stats_list[1].median - 77.0) > 1.5)
assert not np.any(np.abs(ff_time_stats_list[1].median - 18.0) > 1.5)
assert not np.any(np.abs(ped_stats[times[chunk_size]].median - 2.0) > 1.5)
assert not np.any(np.abs(charge_stats[times[chunk_size]].median - 77.0) > 1.5)
assert not np.any(np.abs(time_stats[times[chunk_size]].median - 18.0) > 1.5)

assert not np.any(np.abs(ped_stats_list[0].std - 5.0) > 1.5)
assert not np.any(np.abs(ff_charge_stats_list[0].std - 10.0) > 1.5)
assert not np.any(np.abs(ff_time_stats_list[0].std - 5.0) > 1.5)
assert not np.any(np.abs(ped_stats[times[0]].std - 5.0) > 1.5)
assert not np.any(np.abs(charge_stats[times[0]].std - 10.0) > 1.5)
assert not np.any(np.abs(time_stats[times[0]].std - 5.0) > 1.5)


def test_check_chunk_shift(example_subarray):
Expand All @@ -67,22 +68,18 @@ def test_check_chunk_shift(example_subarray):
times = Time(
np.linspace(60117.911, 60117.9258, num=5500), scale="tai", format="mjd"
)
ff_dl1_data = np.random.normal(77.0, 10.0, size=(5500, 2, 1855))
# Create dl1 table
ff_dl1_table = Table(
[times, ff_dl1_data],
charge_data = np.random.normal(77.0, 10.0, size=(5500, 2, 1855))
# Create table
charge_table = Table(
[times, charge_data],
names=("time_mono", "image"),
)
# Initialize the extractor
ff_charge_extractor = SigmaClippingExtractor(
subarray=example_subarray, chunk_size=2500
)
extractor = SigmaClippingExtractor(subarray=example_subarray, chunk_size=2500)
# Extract the statistical values
stats_list = ff_charge_extractor(dl1_table=ff_dl1_table)
stats_list_chunk_shift = ff_charge_extractor(
dl1_table=ff_dl1_table, chunk_shift=2000
)
chunk_stats = extractor(table=charge_table)
chunk_stats_shift = extractor(table=charge_table, chunk_shift=2000)
# Check if three chunks are used for the extraction as the last chunk overflows
assert len(stats_list) == 3
assert len(chunk_stats) == 3
# Check if two chunks are used for the extraction as the last chunk is dropped
assert len(stats_list_chunk_shift) == 2
assert len(chunk_stats_shift) == 2
2 changes: 0 additions & 2 deletions src/ctapipe/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,6 @@ class MorphologyContainer(Container):
class StatisticsContainer(Container):
"""Store descriptive statistics of a chunk of images"""

extraction_start = Field(NAN_TIME, "start of the extraction chunk")
extraction_stop = Field(NAN_TIME, "stop of the extraction chunk")
mean = Field(
None,
"mean of a pixel-wise quantity for each channel"
Expand Down

0 comments on commit 452ec1b

Please sign in to comment.