Skip to content

Commit

Permalink
Option to get train timestamps as Python datetime objects
Browse files Browse the repository at this point in the history
  • Loading branch information
takluyver committed Sep 9, 2024
1 parent 8d90317 commit 7fad10c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
17 changes: 14 additions & 3 deletions extra_data/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,7 @@ def train_info(self, train_id):
print('\tControls')
[print('\t-', d) for d in sorted(ctrl)] or print('\t-')

def train_timestamps(self, labelled=False):
def train_timestamps(self, labelled=False, *, pydatetime=False):
"""Get approximate timestamps for each train
Timestamps are stored and returned in UTC (not local time).
Expand All @@ -1614,8 +1614,10 @@ def train_timestamps(self, labelled=False):
(Not a Time).
If *labelled* is True, they are returned in a pandas series, labelled
with train IDs. If False (default), they are returned in a NumPy array
of the same length as data.train_ids.
with train IDs. If *pydatetime* is True, a list of Python datetime
objects (truncated to microseconds) is returned, the same length as
data.train_ids. Otherwise (by default), timestamps are returned as a
NumPy array with datetime64 dtype.
"""
arr = np.zeros(len(self.train_ids), dtype=np.uint64)
id_to_ix = {tid: i for (i, tid) in enumerate(self.train_ids)}
Expand Down Expand Up @@ -1643,6 +1645,15 @@ def train_timestamps(self, labelled=False):
if labelled:
import pandas as pd
return pd.Series(arr, index=self.train_ids).dt.tz_localize('UTC')
elif pydatetime:
from datetime import datetime, timezone
res = []
for npdt in arr:
pydt = npdt.astype('datetime64[ms]').item()
if pydt is not None: # Numpy NaT becomes None
pydt = pydt.replace(tzinfo=timezone.utc)
res.append(pydt)
return res
return arr

def run_metadata(self) -> dict:
Expand Down
15 changes: 12 additions & 3 deletions extra_data/tests/test_reader_mockdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,10 +743,12 @@ def test_train_timestamps(mock_scs_run):
assert np.all(np.diff(tss).astype(np.uint64) > 0)

# Convert numpy datetime64[ns] to Python datetime (dropping some precision)
dt0 = tss[0].astype('datetime64[ms]').item().replace(tzinfo=timezone.utc)
tss_l = run.train_timestamps(pydatetime=True)
assert len(tss_l) == len(run.train_ids)
now = datetime.now(timezone.utc)
assert dt0 > (now - timedelta(days=1)) # assuming tests take < 1 day to run
assert dt0 < now
assert tss_l[0] > (now - timedelta(days=1)) # assuming tests take < 1 day to run
assert tss_l[0] < now
assert tss_l[0].tzinfo is timezone.utc

tss_ser = run.train_timestamps(labelled=True)
assert isinstance(tss_ser, pd.Series)
Expand All @@ -764,6 +766,13 @@ def test_train_timestamps_nat(mock_fxe_control_data):
else:
assert not np.any(np.isnat(tss))

tss_l = f.train_timestamps(pydatetime=True)
assert len(tss_l) == len(f.train_ids)
if f.files[0].format_version == '0.5':
assert all(t is None for t in tss_l)
else:
assert not any(t is None for t in tss_l)


def test_union(mock_fxe_raw_run):
run = RunDirectory(mock_fxe_raw_run)
Expand Down

0 comments on commit 7fad10c

Please sign in to comment.