Source code for funROI.parcels

import os
import json
from typing import Union, Tuple, Optional
from nibabel.nifti1 import Nifti1Image
from nilearn.image import load_img, math_img
from pathlib import Path
import numpy as np
from . import get_analysis_output_folder
from .utils import ensure_paths
import pandas as pd

def _get_parcels_folder():
    """
    Return the folder in which parcels are stored.

    :return: The ``parcels`` sub-folder of the analysis output folder
        returned by ``get_analysis_output_folder()``.
    """
    # A named function (instead of a lambda bound to a name) keeps a useful
    # ``__name__`` in tracebacks; the folder is still resolved on every call.
    return get_analysis_output_folder() / "parcels"


class ParcelsConfig(dict):
    """
    Configuration for parcels.

    The two paths are stored both as dictionary items and as instance
    attributes of the same name.

    :param parcels_path: Path to the parcels image.
    :type parcels_path: Union[str, Path]
    :param labels_path: Path to the labels file. The labels file can be a
        JSON file mapping numerical labels to label names, or a text file
        with one label name per line.
    :type labels_path: Optional[Union[str, Path]]
    """

    # NOTE(review): ``ensure_paths`` presumably converts the named arguments
    # to ``Path`` objects before ``__init__`` runs -- confirm in ``.utils``.
    @ensure_paths("parcels_path", "labels_path")
    def __init__(
        self,
        parcels_path: Union[str, Path],
        labels_path: Optional[Union[str, Path]] = None,
    ):
        dict.__init__(self, parcels_path=parcels_path, labels_path=labels_path)
        self.parcels_path = parcels_path
        self.labels_path = labels_path

    def __repr__(self):
        fields = (
            f"parcels_path={self.parcels_path}, "
            f"labels_path={self.labels_path}"
        )
        return f"ParcelsConfig({fields})"

    def __eq__(self, other):
        # Only another ParcelsConfig can compare equal; both paths must match.
        if isinstance(other, ParcelsConfig):
            same_parcels = self.parcels_path == other.parcels_path
            return same_parcels and self.labels_path == other.labels_path
        return False

    @staticmethod
    def from_analysis_output(
        name: str,
        smoothing_kernel_size: int,
        overlap_thr_vox: float,
        overlap_thr_roi: float,
        min_voxel_size: int,
        use_spm_smooth: bool = True,
    ):
        """
        Create a ParcelsConfig object from the analysis output folder.

        The image is looked up in the ``parcels-<name>`` sub-folder of the
        parcels folder; every argument is embedded in the file name. A
        ``.json`` file with the same stem is used as the labels file when it
        exists on disk, otherwise no labels file is configured.

        :param name: Parcels name (``parcels-<name>`` folder and file prefix).
        :type name: str
        :param smoothing_kernel_size: Value of the ``sm-`` file name field.
        :type smoothing_kernel_size: int
        :param overlap_thr_vox: Value of the ``voxthres-`` file name field.
        :type overlap_thr_vox: float
        :param overlap_thr_roi: Value of the ``roithres-`` file name field.
        :type overlap_thr_roi: float
        :param min_voxel_size: Value of the ``sz-`` file name field.
        :type min_voxel_size: int
        :param use_spm_smooth: Value of the ``spmsmooth-`` file name field.
        :type use_spm_smooth: bool
        :return: Configuration pointing at the parcels image and, when
            present, its JSON labels file.
        :rtype: ParcelsConfig
        """
        # Image and labels share the same stem and differ only in extension.
        file_stem = (
            f"parcels-{name}_sm-{smoothing_kernel_size}"
            f"_spmsmooth-{use_spm_smooth}"
            f"_voxthres-{overlap_thr_vox}"
            f"_roithres-{overlap_thr_roi}"
            f"_sz-{min_voxel_size}"
        )
        image_path = (
            _get_parcels_folder() / f"parcels-{name}" / f"{file_stem}.nii.gz"
        )
        labels_candidate = (
            _get_parcels_folder() / f"parcels-{name}" / f"{file_stem}.json"
        )
        if os.path.exists(labels_candidate):
            found_labels = labels_candidate
        else:
            found_labels = None
        return ParcelsConfig(image_path, found_labels)
def get_parcels( parcels: Union[str, ParcelsConfig] ) -> Tuple[Nifti1Image, dict]: """ Get parcels image and labels. """ if isinstance(parcels, str): parcels_img, label_dict = _get_saved_parcels(parcels) if parcels_img is None: parcels_img, label_dict = _get_external_parcels( ParcelsConfig(parcels_path=parcels) ) else: parcels_img, label_dict = _get_external_parcels(parcels) return parcels_img, label_dict def _get_saved_parcels(parcels_label: str) -> Tuple[Nifti1Image, dict]: """ Get parcels image and labels from a saved parcels file. """ parcels_path = ( _get_parcels_folder() / f"parcels-{parcels_label}_mask.nii.gz" ) parcels_labels_path = None return _get_external_parcels( ParcelsConfig( parcels_path=parcels_path, labels_path=parcels_labels_path ) ) def _get_external_parcels(parcels: ParcelsConfig) -> Tuple[Nifti1Image, dict]: """ Get parcels image and labels from externally specified paths. """ if parcels.parcels_path is None or not parcels.parcels_path.exists(): return None, None parcels_img = load_img(parcels.parcels_path) parcels_img = math_img("np.round(img)", img=parcels_img) if parcels.labels_path is not None and parcels.labels_path.exists(): if parcels.labels_path.name.endswith("json"): # If JSON file, label dict is a dictionary from numerical labels to # label names label_dict = json.load(open(parcels.labels_path)) label_dict = {int(k): v for k, v in label_dict.items()} elif parcels.labels_path.name.endswith("txt"): # If txt file, one label name per line label_dict = {} with open(parcels.labels_path, "r") as f: for i, line in enumerate(f): label_dict[i + 1] = line.strip() else: # Default: no text labels label_dict = {} for label in np.unique(parcels_img.get_fdata()): if label != 0: label_dict[int(label)] = int(label) return parcels_img, label_dict def label_parcel( parcels_img: Nifti1Image, label_dict: dict, label: int ) -> Tuple[Nifti1Image, str]: """ Label a parcel. 
""" if label not in label_dict: raise ValueError(f"Label {label} not found in label dictionary.") label_name = label_dict[label] return math_img("img == {}".format(label), img=parcels_img), label_name def merge_parcels( parcels_img: Nifti1Image, label_dict: dict, label1: Union[int, str], label2: Union[int, str], new_label: Optional[str] = None, ) -> Tuple[Nifti1Image, dict]: """ Merge two parcels. """ if new_label in label_dict.values(): raise ValueError( f"New label {new_label} already exists in label dictionary." ) if isinstance(label1, str): label1 = {v: k for k, v in label_dict.items()}[label1] if isinstance(label2, str): label2 = {v: k for k, v in label_dict.items()}[label2] parcels_data = _merge_parcels(parcels_img.get_fdata(), label1, label2) parcels_img = Nifti1Image( parcels_data, parcels_img.affine, parcels_img.header ) label_dict.pop(label1, None) label_dict.pop(label2, None) if new_label: label_dict[new_label] = new_label return parcels_img, label_dict def _merge_parcels(data: np.ndarray, x: int, y: int) -> np.ndarray: if len(data.shape) != 3: raise ValueError("Data must be 3D.") if x == y: return data neighbors26 = np.zeros((26, data.shape[0], data.shape[1], data.shape[2])) ni = 0 for dx in range(-1, 2): for dy in range(-1, 2): for dz in range(-1, 2): if dx == 0 and dy == 0 and dz == 0: continue neighbors26[ni] = np.roll(data, dx, axis=0) neighbors26[ni] = np.roll(neighbors26[ni], dy, axis=1) neighbors26[ni] = np.roll(neighbors26[ni], dz, axis=2) ni += 1 mask = ( np.all(np.isin(neighbors26, [0, x, y]), axis=0) & np.any(neighbors26 == x, axis=0) & np.any(neighbors26 == y, axis=0) ) data[mask] = x data[data == y] = x return data def save_parcels(parcels_img: Nifti1Image, label_dict: dict, name: str): """ Save parcels image and labels. 
""" parcels_path = _get_parcels_folder() / f"{name}.nii.gz" parcels_labels_path = _get_parcels_folder() / f"{name}.json" parcels_img.to_filename(parcels_path) with open(parcels_labels_path, "w") as f: json.dump(label_dict, f)