diff --git a/flyem_snapshot/inputs/rois.py b/flyem_snapshot/inputs/rois.py index 3539d54..ffb604e 100644 --- a/flyem_snapshot/inputs/rois.py +++ b/flyem_snapshot/inputs/rois.py @@ -206,6 +206,14 @@ def load_point_rois(cfg, point_df, roiset_names): assert isinstance(roi_ids, dict) roi_ids = _apply_roi_renames(point_df, roiset_name, roi_ids, roiset_cfg['rename-rois']) + + if '' in point_df[roiset_name].dtype.categories: + assert (u := roi_ids.get('')) == 0, \ + f"Non-zero '' label: {u}" + else: + assert '' not in roi_ids, \ + "Did not expect to see '' in roi_ids, since it is not in the categories." + roisets[roiset_name] = roi_ids _check_duplicate_rois(roisets) @@ -297,7 +305,7 @@ def _load_columns_for_roiset(roiset_name, roiset_cfg, point_df, dvid_cfg, proces if bool(roiset_cfg['labelmap']) != (roiset_cfg['source'] == 'dvid-labelmap'): raise RuntimeError("Please supply a labelmap for your ROIs IFF you selected source: dvid-labelmap") - roi_vol, roi_box, roi_ids = _load_roi_vol(roiset_name, roi_ids, roiset_cfg['labelmap'], dvid_cfg, processes) + roi_vol, roi_box, roi_ids = _load_roi_vol_from_cache_or_dvid(roiset_name, roi_ids, roiset_cfg['labelmap'], dvid_cfg, processes) extract_labels_from_volume(point_df, roi_vol, roi_box, 5, roi_ids, roiset_name, skip_index_check=False) return roi_ids @@ -326,7 +334,6 @@ def _load_roi_unions(roiset_name, src_roiset_name, src_lists, point_df): Returns: dict {roi_name: roi_label_int} for the new roiset. - By convention, "" (label 0) is not included in dict. Note: The sub-rois must all come from the same source roiset, and it @@ -367,14 +374,16 @@ def _load_roi_unions(roiset_name, src_roiset_name, src_lists, point_df): # Apply mapping to obtain the new ROIs (union of subrois) point_df[roiset_name] = point_df[src_roiset_name].map(union_mapping).fillna('') assert point_df[roiset_name].dtype == 'category' + point_df[roiset_name] = point_df[roiset_name].cat.remove_unused_categories() # Produce integers from names to create the _label column - roi_ids = {roi: i for i, roi in enumerate(union_mapping.dtype.categories)} + union_categories = point_df[roiset_name].dtype.categories + start = 0 if '' in union_categories else 1 + roi_ids = {roi: i for i, roi in enumerate(union_categories, start)} label_dtype = np.min_scalar_type(max(roi_ids.values())) point_df[f'{roiset_name}_label'] = point_df[roiset_name].map(roi_ids).astype(label_dtype) - assert roi_ids[''] == 0 - del roi_ids[''] + assert ('' not in roi_ids) or (roi_ids[''] == 0) return roi_ids @@ -388,6 +397,25 @@ def _load_roi_col(roiset_name, roi_ids, point_df): - primary_roi_label (integers) If it contains only one or the other, then we calculate the missing one using roi_ids. + + Args: + roiset_name: + string + + roi_ids: + Optional. + If given, must be dict of {name: int}. + The {'': 0} entry may be omitted, in which case we'll insert it if necessary. + No other ROI is permitted to have label 0. + If roi_ids is None, we'll create it. + + point_df: + DataFrame, modified IN-PLACE. + + Returns: + roi_ids + - Possibly created from scratch + - Possibly updated with '': 0 """ if not roi_ids and roiset_name not in point_df: raise RuntimeError( @@ -406,15 +434,32 @@ def _load_roi_col(roiset_name, roi_ids, point_df): eval(f'f"{roi_ids}"', None, {'x': x}): x # pylint: disable=eval-used for x in unique_ids if x != 0 } + if 0 in unique_ids: + roi_ids[''] = 0 if roiset_name in point_df.columns: point_df[roiset_name] = point_df[roiset_name].fillna("") - empirical_names = set(point_df[roiset_name].unique()) + + # If it's present, is always listed first. + empirical_names = sorted(point_df[roiset_name].unique()) + if '' in empirical_names: + empirical_names.remove('') + empirical_names = ['', *empirical_names] + if not roi_ids: - roi_ids = {n: i for i, n in enumerate(sorted(empirical_names), start=1)} - if '' in empirical_names and '' not in roi_ids: - roi_ids[''] = 1 + max(roi_ids.values()) - if unlisted_rois := empirical_names - set(roi_ids.keys()): + start = 0 if '' in empirical_names else 1 + roi_ids = {n: i for i, n in enumerate(empirical_names, start)} + + if (u := roi_ids.get('', 0)) != 0: + raise RuntimeError(f"roiset {roiset_name}: must have label 0, not {u}") + + if '' in empirical_names: + roi_ids[''] = 0 + + if zero_keys := {k for k,v in roi_ids.items() if v == 0 and k != ""}: + raise RuntimeError(f"roiset '{roiset_name}' uses ID 0 for ROIs other than '': {zero_keys}") + + if unlisted_rois := set(empirical_names) - set(roi_ids.keys()): raise RuntimeError( f"Your config for ROI column '{roiset_name}' explicitly lists ROI names, but that list is incomplete.\n" f"The following ROIs were found in the data but not in the config:\n" @@ -438,6 +483,8 @@ def _load_roi_col(roiset_name, roi_ids, point_df): if roiset_name in point_df.columns and not isinstance(point_df[roiset_name].dtype, pd.CategoricalDtype): point_df[roiset_name] = point_df[roiset_name].astype('category') + assert roi_ids.get('') == 0 or ('' not in point_df[roiset_name].dtype.categories) + if expected_cols <= {*point_df.columns}: # Necessary columns are already present return roi_ids @@ -457,7 +504,7 @@ def _load_roi_col(roiset_name, roi_ids, point_df): return roi_ids -def _load_roi_vol(roiset_name, roi_ids, roi_labelmap_name, dvid_cfg, processes): +def _load_roi_vol_from_cache_or_dvid(roiset_name, roi_ids, roi_labelmap_name, dvid_cfg, processes): """ Load an ROI volume, either from a list of DVID 'roi' instances, or from a single low-res DVID labelmap instance. @@ -495,6 +542,8 @@ def _load_roi_vol(roiset_name, roi_ids, roi_labelmap_name, dvid_cfg, processes): eval(f'f"{roi_ids}"', None, {'x': x}): x # pylint: disable=eval-used for x in unique_ids if x != 0 } + if 0 in unique_ids: + roi_ids[''] = 0 return roi_vol, roi_box, roi_ids @@ -547,14 +596,18 @@ def _apply_roi_renames(point_df, roiset_name, roi_ids, renames): def _check_duplicate_rois(roisets): - all_rois = pd.Series([ + """ + Raise an error if the same ROI name is present + in more than one roiset (except '') + """ + all_rois = [ k for d in roisets.values() for k in d.keys() if k != '' - ]) + ] - vc = all_rois.value_counts() + vc = pd.Series(all_rois).value_counts() duplicate_rois = vc[vc > 1].index.tolist() if duplicate_rois: raise RuntimeError(f"ROIs duplicated in multiple roisets: {duplicate_rois}") diff --git a/flyem_snapshot/outputs/neuprint/neuprint.py b/flyem_snapshot/outputs/neuprint/neuprint.py index 0ffe13b..b6c4a95 100644 --- a/flyem_snapshot/outputs/neuprint/neuprint.py +++ b/flyem_snapshot/outputs/neuprint/neuprint.py @@ -311,6 +311,15 @@ def export_neuprint(cfg, point_df, partner_df, element_tables, ann, body_sizes, point_df = point_df.loc[point_df['body'] != 0] partner_df = partner_df.loc[(partner_df['body_pre'] != 0) & (partner_df['body_post'] != 0)] + # We don't store neuprint properties for the "" ROI. + syn_roisets = copy.deepcopy(syn_roisets) + for roi_ids in syn_roisets.values(): + roi_ids.pop('', None) + + element_roisets = copy.deepcopy(element_roisets) + for roi_ids in element_roisets.values(): + roi_ids.pop('', None) + neuprint_ann = neuprint_segment_annotations(cfg, ann) point_df, partner_df = restrict_synapses_for_setting( @@ -347,7 +356,7 @@ def export_neuprint(cfg, point_df, partner_df, element_tables, ann, body_sizes, connectome = export_neuprint_segment_connections(cfg, partner_df) - # FIXME: It would be good to verify that there are no duplicated Element points (including Synapses) + # TODO: It would be good to verify that there are no duplicated Element IDs (including Synapses) export_neuprint_elementsets(cfg, element_tables, connectome) export_neuprint_elements(cfg, element_tables, element_roisets) export_neuprint_elements_closeto(element_tables)