from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import json
import warnings
import numpy as np
import pandas as pd
from nilearn.surface import SurfaceImage, load_surf_mesh
from .._surface import (
SURFACE_HEMIS,
SURFACE_PARTS,
flatten_image_data,
load_surface_numeric_data,
save_surface_image,
)
from ..contrast import _get_contrast_runs, _get_surface_contrast_path
from ..first_level.nilearn import _find_surface_mesh_paths, _smooth_surface_array
from ..froi import _create_p_map_mask
from ..settings import (
get_analysis_output_folder,
get_bids_data_folder,
get_bids_preprocessed_folder,
get_bids_preprocessed_folder_relative,
)
from ..utils import validate_arguments
class SurfaceParcelsGenerator:
    """Generate group surface parcels from first-level surface contrast maps.

    Workflow: call :meth:`add_subjects` one or more times to collect
    thresholded first-level p-maps (one binary map per run label and
    subject), then :meth:`run` to build the group overlap map, smooth and
    threshold it, and split it into parcels with a surface watershed.
    :meth:`filter` can afterwards apply stricter ROI-overlap or size
    thresholds to the generated parcels.

    Outputs (config JSON, overlap map, parcel info CSV and parcel GIFTIs)
    are written to ``get_analysis_output_folder() / "parcels" /
    "parcels-<parcels_name>"``.
    """

    def __init__(
        self,
        parcels_name: str,
        space: str = "fsLR32k",
        smoothing_kernel_size: Optional[Union[float, List[float]]] = 8,
        overlap_thr_vox: Optional[float] = 0.1,
        min_voxel_size: Optional[int] = 0,
        overlap_thr_roi: Optional[float] = 0,
    ):
        """Configure the generator; no data is loaded here.

        Args:
            parcels_name: Name used for the output folder and file names.
            space: Surface space used to locate subject meshes; also part
                of the output file names.
            smoothing_kernel_size: Kernel size used to smooth the group
                overlap map (passed as ``fwhm`` to the surface smoothing
                helper). A single-element list is accepted.
            overlap_thr_vox: Vertices whose smoothed overlap is below this
                value are excluded before the watershed.
            min_voxel_size: Parcels with fewer vertices than this are
                removed. ``None`` is treated as 0 (no size filtering).
            overlap_thr_roi: Parcels whose fraction of covering subjects is
                below this value are removed. ``None`` is treated as 0.
        """
        self.parcels_name = parcels_name
        self.space = space
        self.smoothing_kernel_size = smoothing_kernel_size
        self.overlap_thr_vox = overlap_thr_vox
        # None would otherwise break the numeric comparisons in run()/filter().
        self.min_voxel_size = 0 if min_voxel_size is None else min_voxel_size
        self.overlap_thr_roi = 0 if overlap_thr_roi is None else overlap_thr_roi
        # One entry per add_subjects() call: {"subjects": [...], "surface": {...}}.
        self.configs = []
        # One float array per subject, one row per run label (presumably
        # shape (n_run_labels, n_vertices) — depends on _create_p_map_mask).
        self._data = []
        # {"left": mesh, "right": mesh} shared by all subjects.
        self.mesh = None
        # {"L": path, "R": path} of the files the shared mesh was loaded from.
        self.mesh_paths = None
        # Vertex count per hemisphere and the matching slices into arrays
        # that hold the left hemisphere followed by the right one.
        self.hemi_sizes = None
        self.hemi_slices = None
        # Mean of the per-subject binary maps (set by run() / _run_fast()).
        self.overlap_map = None
        # Parcel label per vertex; 0 means "not in any parcel".
        self.parcels = None
        # One row per parcel: "id", "size" and, when built by run(),
        # "roi_overlap".
        self.parcel_info = None

    @validate_arguments(
        p_threshold_type={"none", "bonferroni", "fdr", "n", "percent"},
        conjunction_type={"min", "max", "sum", "prod", "and", "or"},
    )
    def add_subjects(
        self,
        subjects: List[str],
        task: str,
        contrasts: List[str],
        p_threshold_type: str,
        p_threshold_value: float = 0.05,
        conjunction_type: Optional[str] = "and",
    ):
        """Load and threshold the first-level surface p-maps of subjects.

        For every subject, the p-maps of all ``contrasts`` are loaded for
        each run label (both hemispheres concatenated, left first), combined
        and thresholded into one binary map per run label, and stored for
        :meth:`run`. Adding subjects discards previously generated parcels
        so that the next :meth:`run` includes the new data.

        Args:
            subjects: Subject labels to add; must not have been added yet
                and must not repeat.
            task: Task label of the first-level contrasts.
            contrasts: Contrast names; only runs available for all of them
                are used.
            p_threshold_type: Thresholding method for the p-maps.
            p_threshold_value: Threshold value for ``p_threshold_type``.
            conjunction_type: How the contrasts are combined.

        Raises:
            ValueError: If a subject was already added, is listed more than
                once, has no usable runs, or its mesh / maps do not match
                the shared mesh.
            FileNotFoundError: If a subject's surface meshes or a surface
                contrast file cannot be found.
        """
        subjects = list(subjects)
        existing_subjects = {
            subject
            for config in self.configs
            for subject in config["subjects"]
        }
        subjects_redundant = set(subjects).intersection(existing_subjects)
        if subjects_redundant:
            raise ValueError(
                f"Subjects {subjects_redundant} are already added."
            )
        # A subject listed twice would be counted twice in the group overlap.
        subjects_duplicated = {s for s in subjects if subjects.count(s) > 1}
        if subjects_duplicated:
            raise ValueError(
                f"Subjects {subjects_duplicated} are listed more than once."
            )
        config = {
            "task": task,
            "contrasts": list(contrasts),
            "threshold_type": p_threshold_type,
            "threshold_value": p_threshold_value,
            "conjunction_type": conjunction_type,
            "space": self.space,
        }
        # Load everything before touching _data/configs so that a failing
        # subject does not leave a partially added batch behind.
        new_data = [
            self._load_subject_masks(
                subject,
                task,
                contrasts,
                p_threshold_type,
                p_threshold_value,
                conjunction_type,
            )
            for subject in subjects
        ]
        self._data.extend(new_data)
        self.configs.append({"subjects": subjects, "surface": config})
        # Results computed before these subjects were added are stale.
        self.overlap_map = None
        self.parcels = None
        self.parcel_info = None

    def _load_subject_masks(
        self,
        subject: str,
        task: str,
        contrasts: List[str],
        p_threshold_type: str,
        p_threshold_value: float,
        conjunction_type: Optional[str],
    ) -> np.ndarray:
        """Return one thresholded map per run label of a subject, as rows.

        Also registers (first subject) or validates the subject's mesh
        against the shared mesh.

        Raises:
            ValueError: If no run label yields data.
        """
        subject_runs = self._get_subject_runs(subject, task, contrasts)
        subject_mesh, subject_mesh_paths = self._get_subject_meshes(subject)
        self._set_or_validate_mesh(subject_mesh, subject_mesh_paths)
        run_masks = []
        for run_label in self._get_orthogonalized_labels(subject_runs):
            run_p_maps = []
            for contrast in contrasts:
                hemi_data = self._get_surface_contrast_data(
                    subject, task, run_label, contrast, "p"
                )
                run_p_maps.append(
                    np.concatenate([hemi_data["L"], hemi_data["R"]], axis=0)
                )
            # Layout handed to _create_p_map_mask: (n_contrasts, 1, n_vertices).
            data = np.asarray(run_p_maps, dtype=float)[:, None, :]
            mask = _create_p_map_mask(
                data,
                conjunction_type,
                p_threshold_type,
                p_threshold_value,
            )
            run_masks.append(np.asarray(mask).reshape(-1).astype(float))
        if not run_masks:
            raise ValueError(
                f"No data found for subject {subject} and task {task}."
            )
        return np.asarray(run_masks, dtype=float)

    @classmethod
    def _run_fast(
        cls,
        parcels_name: str,
        space: str = "fsLR32k",
        smoothing_kernel_size: Optional[Union[float, List[float]]] = 8,
        overlap_thr_vox: Optional[float] = 0.1,
    ) -> "SurfaceParcelsGenerator":
        """Rebuild a generator from outputs previously saved for ``parcels_name``.

        Uses the saved config (for the mesh paths) and, if present, the
        saved parcel info and unfiltered (``roithres-0_sz-0``) parcels.
        Otherwise the saved overlap map is re-smoothed, thresholded and
        parcellated with the given settings and the result is saved. Use
        this method only if you are sure what you are doing: no subject
        data is available, so parcels built here have no ``roi_overlap``
        and cannot be filtered by ROI overlap.

        Raises:
            FileNotFoundError: If the config or the overlap map is missing.
            ValueError: If the config lacks a mesh path for a hemisphere.
        """
        parcel_gen = cls(
            parcels_name=parcels_name,
            space=space,
            smoothing_kernel_size=smoothing_kernel_size,
            overlap_thr_vox=overlap_thr_vox,
        )
        base = cls._get_analysis_parcels_folder(parcels_name)
        config_path = base / f"parcels-{parcels_name}_config.json"
        if not config_path.exists():
            raise FileNotFoundError(
                "Surface parcels config not found. Please double-check the "
                "parcels name and configurations."
            )
        with open(config_path, "r") as f:
            config = json.load(f)
        saved_mesh_paths = config.get("mesh_paths", {})
        if not all(hemi in saved_mesh_paths for hemi in SURFACE_HEMIS):
            raise ValueError(
                "Surface parcels config does not include both hemisphere meshes."
            )
        mesh_paths = {
            hemi: Path(saved_mesh_paths[hemi]) for hemi in SURFACE_HEMIS
        }
        # The fresh instance has no mesh yet, so this sets mesh, mesh_paths,
        # hemi_sizes and hemi_slices.
        parcel_gen._set_or_validate_mesh(
            {
                "left": load_surf_mesh(mesh_paths["L"]),
                "right": load_surf_mesh(mesh_paths["R"]),
            },
            mesh_paths,
        )

        # Reuse previously generated, unfiltered parcels when available.
        parcel_info_path = parcel_gen._parcel_info_path(base)
        parcels_paths = cls._hemi_paths(
            base, parcel_gen._output_stem(overlap_thr_roi=0, min_voxel_size=0)
        )
        if parcel_info_path.exists() and all(
            path.exists() for path in parcels_paths.values()
        ):
            parcel_gen.parcel_info = pd.read_csv(parcel_info_path)
            parcel_gen.parcels = cls._load_hemi_pair(parcels_paths)
            return parcel_gen

        overlap_paths = cls._hemi_paths(base, parcel_gen._overlap_stem())
        if not all(path.exists() for path in overlap_paths.values()):
            raise FileNotFoundError(
                "Overlapping map not found. Please double-check the parcels "
                "name and configurations."
            )
        parcel_gen.overlap_map = cls._load_hemi_pair(overlap_paths)
        # The saved overlap map already is the group mean, so it is passed
        # as the only "mask" and its mean is itself.
        _, parcel_gen.parcels = cls._run(
            [parcel_gen.overlap_map],
            parcel_gen.mesh,
            parcel_gen.hemi_slices,
            smoothing_kernel_size,
            overlap_thr_vox,
        )
        parcel_gen.parcel_info = cls._summarize_parcels(parcel_gen.parcels)
        parcel_gen._save()
        return parcel_gen

    def run(self) -> SurfaceImage:
        """Generate the group parcels (if needed) and apply the filters.

        On first call (or after new subjects were added) each subject's
        run maps are collapsed to one binary map (vertex active in more
        than half of its run labels), the maps are averaged into the
        overlap map, smoothed, thresholded with ``overlap_thr_vox`` and
        parcellated. The unfiltered result is saved under the
        ``roithres-0_sz-0`` name; parcels are then filtered with
        ``overlap_thr_roi`` / ``min_voxel_size`` and saved again.

        Returns:
            A surface image of parcel labels (0 = unassigned).

        Raises:
            RuntimeError: If no mesh or no subject data has been loaded.
        """
        if self.mesh is None:
            raise RuntimeError("No surface mesh loaded. Add subjects first.")
        if self.parcel_info is None:
            # The mesh can be set by a subject whose loading failed later on.
            if not self._data:
                raise RuntimeError(
                    "No subject data loaded. Add subjects first."
                )
            binary_masks = [np.mean(dat, axis=0) > 0.5 for dat in self._data]
            self.overlap_map, self.parcels = self._run(
                binary_masks,
                self.mesh,
                self.hemi_slices,
                self.smoothing_kernel_size,
                self.overlap_thr_vox,
            )
            self.parcel_info = self._summarize_parcels(
                self.parcels, self._data
            )
            # Unfiltered parcels go under the roithres-0/sz-0 name, which is
            # the file _run_fast() looks up.
            self._save(
                stem=self._output_stem(overlap_thr_roi=0, min_voxel_size=0)
            )
        if self.min_voxel_size != 0 or self.overlap_thr_roi != 0:
            self.parcels = self._filter(
                self.parcels,
                self.parcel_info,
                self.overlap_thr_roi,
                self.min_voxel_size,
            )
            self._save()
        return self._to_surface_image(self.parcels)

    def filter(
        self,
        overlap_thr_roi: Optional[float] = 0,
        min_voxel_size: Optional[int] = 0,
    ) -> SurfaceImage:
        """Apply stricter ROI-overlap and/or size thresholds to the parcels.

        Filtering is cumulative: a threshold that is not stricter than the
        current one cannot change the parcels, so it is ignored with a
        warning. ``0`` or ``None`` means "leave this threshold unchanged".

        Returns:
            A surface image of the (filtered) parcel labels.

        Raises:
            RuntimeError: If no parcels have been generated yet.
            ValueError: If ROI-overlap filtering is requested but the parcel
                info has no ``roi_overlap`` column.
        """
        if self.parcels is None:
            raise RuntimeError(
                "No parcels to filter. Run the parcels generation first."
            )
        overlap_thr_roi = 0 if overlap_thr_roi is None else overlap_thr_roi
        min_voxel_size = 0 if min_voxel_size is None else min_voxel_size
        if overlap_thr_roi != 0 and overlap_thr_roi <= self.overlap_thr_roi:
            warnings.warn(
                "The new overlap_thr_roi is lower than the current setup. "
                "The filtering will not be applied."
            )
            overlap_thr_roi = 0.0
        if min_voxel_size != 0 and min_voxel_size <= self.min_voxel_size:
            warnings.warn(
                "The new min_voxel_size is lower than the current setup. "
                "The filtering will not be applied."
            )
            min_voxel_size = 0
        if overlap_thr_roi != 0 or min_voxel_size != 0:
            self.parcels = self._filter(
                self.parcels,
                self.parcel_info,
                overlap_thr_roi,
                min_voxel_size,
            )
            if overlap_thr_roi != 0:
                self.overlap_thr_roi = overlap_thr_roi
            if min_voxel_size != 0:
                self.min_voxel_size = min_voxel_size
            self._save()
        return self._to_surface_image(self.parcels)

    def _get_subject_runs(
        self, subject: str, task: str, contrasts: List[str]
    ) -> List[str]:
        """Return the sorted, de-duplicated runs available for all contrasts.

        Raises:
            ValueError: If no run is shared by all contrasts.
        """
        common_runs = None
        for contrast in contrasts:
            contrast_runs = set(
                _get_contrast_runs(subject, task, contrast) or []
            )
            common_runs = (
                contrast_runs
                if common_runs is None
                else common_runs & contrast_runs
            )
        runs = sorted(common_runs or [])
        if len(runs) == 0:
            raise ValueError(
                f"No surface contrast runs found for subject {subject} and "
                f"task {task}."
            )
        return runs

    @staticmethod
    def _get_orthogonalized_labels(runs: List[str]) -> List[str]:
        """Map each run to the run label whose contrast maps are loaded for it.

        With exactly two runs, each run is paired with the other one;
        otherwise run ``r`` maps to ``"orth<r>"``. NOTE(review): presumably
        these name contrasts estimated from the complementary runs — confirm
        against ``_get_surface_contrast_path``.
        """
        if len(runs) == 2:
            return [runs[1], runs[0]]
        return [f"orth{run}" for run in runs]

    @staticmethod
    def _get_derivatives_root() -> Path:
        """Return the folder holding the preprocessed derivatives.

        Prefers ``<bids data folder>/<relative derivatives folder>`` and falls
        back to the absolute preprocessed folder when the BIDS data folder
        settings raise ``ValueError``/``RuntimeError``.
        """
        try:
            bids_data_folder = Path(get_bids_data_folder())
            derivatives_folder = get_bids_preprocessed_folder_relative()
            if derivatives_folder == ".":
                return bids_data_folder
            return bids_data_folder / derivatives_folder
        except (ValueError, RuntimeError):
            return Path(get_bids_preprocessed_folder())

    def _get_subject_meshes(
        self, subject: str
    ) -> Tuple[Dict[str, object], Dict[str, Path]]:
        """Load a subject's left/right surface meshes in ``self.space``.

        Returns:
            ``({"left": mesh, "right": mesh}, {"L": path, "R": path})``.

        Raises:
            FileNotFoundError: If the meshes cannot be located.
        """
        derivatives_root = self._get_derivatives_root()
        mesh_paths = _find_surface_mesh_paths(
            derivatives_root, subject, self.space
        )
        if mesh_paths is None:
            raise FileNotFoundError(
                f"Could not find surface meshes for subject {subject} in "
                f"space '{self.space}'."
            )
        return (
            {
                "left": load_surf_mesh(mesh_paths["L"]),
                "right": load_surf_mesh(mesh_paths["R"]),
            },
            mesh_paths,
        )

    def _set_or_validate_mesh(
        self, subject_mesh: Dict[str, object], mesh_paths: Dict[str, Path]
    ) -> None:
        """Adopt the first mesh seen as shared mesh; check later ones against it.

        Raises:
            ValueError: If a hemisphere's vertex count or faces differ from
                the shared mesh.
        """
        hemi_sizes = {
            hemi: int(len(np.asarray(subject_mesh[hemi].coordinates)))
            for hemi in ["left", "right"]
        }
        if self.mesh is None:
            self.mesh = subject_mesh
            self.mesh_paths = mesh_paths
            self.hemi_sizes = hemi_sizes
            left_size = hemi_sizes["left"]
            # Concatenated vertex arrays hold the left hemisphere first.
            self.hemi_slices = {
                "left": slice(0, left_size),
                "right": slice(left_size, left_size + hemi_sizes["right"]),
            }
            return
        for hemi in ["left", "right"]:
            reference_mesh = self.mesh[hemi]
            candidate_mesh = subject_mesh[hemi]
            if len(reference_mesh.coordinates) != len(candidate_mesh.coordinates):
                raise ValueError(
                    "All surface meshes must have the same number of vertices. "
                    f"Mismatch found in hemisphere {hemi}."
                )
            if not np.array_equal(reference_mesh.faces, candidate_mesh.faces):
                raise ValueError(
                    "All surface meshes must share the same topology. "
                    f"Mismatch found in hemisphere {hemi}."
                )

    def _get_surface_contrast_data(
        self,
        subject: str,
        task: str,
        run_label: str,
        contrast: str,
        image_type: str,
    ) -> Dict[str, np.ndarray]:
        """Load one contrast image per hemisphere as flat arrays keyed "L"/"R".

        For p-maps (``image_type == "p"``) exact zeros are replaced by NaN.

        Raises:
            FileNotFoundError: If a hemisphere file is missing.
            ValueError: If a hemisphere's length differs from the shared mesh.
        """
        hemi_data = {}
        for hemi in SURFACE_HEMIS:
            path = _get_surface_contrast_path(
                subject, task, run_label, contrast, image_type, hemi
            )
            if not path.exists():
                raise FileNotFoundError(
                    f"Surface contrast file not found: {path}"
                )
            hemi_data[hemi] = load_surface_numeric_data(path).reshape(-1)
        if image_type == "p":
            for hemi in SURFACE_HEMIS:
                hemi_data[hemi] = hemi_data[hemi].astype(float, copy=True)
                # NOTE(review): zeros presumably mark vertices without a
                # valid p-value — confirm with the first-level writer.
                hemi_data[hemi][hemi_data[hemi] == 0] = np.nan
        expected_sizes = {
            "L": self.hemi_sizes["left"] if self.hemi_sizes else None,
            "R": self.hemi_sizes["right"] if self.hemi_sizes else None,
        }
        for hemi in SURFACE_HEMIS:
            expected_size = expected_sizes[hemi]
            if expected_size is not None and len(hemi_data[hemi]) != expected_size:
                raise ValueError(
                    "All surface contrast maps must match the shared mesh "
                    f"vertex count. Hemisphere {hemi} does not match."
                )
        return hemi_data

    @classmethod
    def _run(
        cls,
        binary_masks: List[np.ndarray],
        mesh: Dict[str, object],
        hemi_slices: Dict[str, slice],
        smoothing_kernel_size: Union[float, List[float]],
        overlap_thr_vox: float,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Average masks, smooth and threshold per hemisphere, then parcellate.

        Returns:
            ``(overlap_map, parcels)``: the unsmoothed mean of the masks and
            integer labels per vertex, with right-hemisphere labels offset
            past the left ones so labels are unique across hemispheres.

        Raises:
            ValueError: If a kernel-size list does not have exactly one item.
        """
        overlap_map = np.mean(binary_masks, axis=0)
        smoothed_map = np.array(overlap_map, dtype=float)
        fwhm = smoothing_kernel_size
        if isinstance(fwhm, (list, tuple)):
            if len(fwhm) != 1:
                raise ValueError(
                    "Surface smoothing expects a scalar kernel size."
                )
            fwhm = fwhm[0]
        for hemi in ["left", "right"]:
            hemi_slice = hemi_slices[hemi]
            smoothed_map[hemi_slice] = _smooth_surface_array(
                smoothed_map[hemi_slice],
                mesh[hemi],
                fwhm,
            )
        # NaN vertices are ignored by the watershed.
        smoothed_map[smoothed_map < overlap_thr_vox] = np.nan
        parcels = np.zeros_like(overlap_map, dtype=int)
        label_offset = 0
        for hemi in ["left", "right"]:
            hemi_slice = hemi_slices[hemi]
            hemi_labels = cls._watershed_surface(
                smoothed_map[hemi_slice],
                mesh[hemi],
            ).astype(int)
            hemi_labels[hemi_labels > 0] += label_offset
            parcels[hemi_slice] = hemi_labels
            label_offset = int(np.max(parcels))
        return overlap_map, parcels

    @classmethod
    def _watershed_surface(cls, values: np.ndarray, mesh) -> np.ndarray:
        """Label mesh vertices by region growing in descending value order.

        NaN vertices stay 0. Each visited vertex joins the single label
        found among its already-labelled neighbours, stays 0 when
        neighbours carry several labels (a boundary), or starts a new label
        when no neighbour is labelled yet.
        """
        values = np.asarray(values, dtype=float).reshape(-1)
        adjacency = cls._mesh_adjacency(mesh)
        labels = np.zeros(values.shape[0], dtype=int)
        valid = np.flatnonzero(~np.isnan(values))
        if valid.size == 0:
            return labels
        order = valid[np.argsort(values[valid], kind="stable")[::-1]]
        next_label = 1
        for vertex in order:
            neighbor_labels = labels[adjacency[vertex]]
            neighbor_labels = np.unique(neighbor_labels[neighbor_labels > 0])
            if neighbor_labels.size == 0:
                labels[vertex] = next_label
                next_label += 1
            elif neighbor_labels.size == 1:
                labels[vertex] = int(neighbor_labels[0])
        return labels

    @staticmethod
    def _mesh_adjacency(mesh) -> List[np.ndarray]:
        """Return, per vertex, the sorted indices of vertices sharing a face."""
        n_vertices = len(mesh.coordinates)
        neighbors = [set() for _ in range(n_vertices)]
        for face in np.asarray(mesh.faces, dtype=int):
            a, b, c = face.tolist()
            neighbors[a].update([b, c])
            neighbors[b].update([a, c])
            neighbors[c].update([a, b])
        return [
            np.asarray(sorted(vertex_neighbors), dtype=int)
            for vertex_neighbors in neighbors
        ]

    @classmethod
    def _summarize_parcels(
        cls,
        parcels: np.ndarray,
        subject_data: Optional[List[np.ndarray]] = None,
    ) -> pd.DataFrame:
        """Build the parcel info table (one row per non-zero label).

        Columns are ``id`` and ``size`` (vertex count); with
        ``subject_data`` also ``roi_overlap``: the fraction of subjects for
        which every run label has a positive sum of mask values inside the
        parcel (harmonic mean of the per-run sums > 0).
        """
        columns = ["id", "size"]
        if subject_data is not None:
            columns.append("roi_overlap")
        rows = []
        for parcel in np.unique(parcels):
            if parcel == 0:
                continue
            parcel_mask = parcels == parcel
            row = [int(parcel), int(np.sum(parcel_mask))]
            if subject_data is not None:
                coverage = [
                    cls._harmonic_mean(np.sum(data[:, parcel_mask], axis=1)) > 0
                    for data in subject_data
                ]
                row.append(float(np.mean(coverage)))
            rows.append(row)
        return pd.DataFrame(rows, columns=columns)

    @classmethod
    def _filter(
        cls,
        parcels: np.ndarray,
        parcel_info: pd.DataFrame,
        overlap_thr_roi: float,
        min_voxel_size: int,
    ) -> np.ndarray:
        """Return a copy of ``parcels`` with rejected parcels set to 0.

        A parcel is rejected when its ``size`` is below ``min_voxel_size``
        or, for a positive ``overlap_thr_roi``, its ``roi_overlap`` is below
        that threshold.

        Raises:
            ValueError: If ROI-overlap filtering is requested but
                ``parcel_info`` has no ``roi_overlap`` column, or a parcel
                label is missing from ``parcel_info``.
        """
        overlap_thr_roi = 0 if overlap_thr_roi is None else overlap_thr_roi
        min_voxel_size = 0 if min_voxel_size is None else min_voxel_size
        # roi_overlap is a fraction in [0, 1], so a non-positive threshold
        # can never reject a parcel; skipping the check lets size-only
        # filtering work on tables without that column (e.g. _run_fast()).
        check_overlap = overlap_thr_roi > 0
        if check_overlap and "roi_overlap" not in parcel_info.columns:
            raise ValueError(
                "Parcel info has no 'roi_overlap' column; filtering by ROI "
                "overlap requires parcels generated from subject data."
            )
        filtered_parcels = parcels.copy()
        for parcel in np.unique(parcels):
            if parcel == 0:
                continue
            info_row = parcel_info.loc[parcel_info["id"] == parcel]
            if info_row.empty:
                raise ValueError(
                    f"Parcel {parcel} is missing from the parcel info table."
                )
            reject = info_row["size"].iloc[0] < min_voxel_size
            if check_overlap:
                reject = reject or (
                    info_row["roi_overlap"].iloc[0] < overlap_thr_roi
                )
            if reject:
                filtered_parcels[parcels == parcel] = 0
        return filtered_parcels

    def _to_surface_image(self, data: np.ndarray) -> SurfaceImage:
        """Split a concatenated vertex array into a two-hemisphere image."""
        return SurfaceImage(
            mesh=self.mesh,
            data={
                "left": np.asarray(data[self.hemi_slices["left"]]),
                "right": np.asarray(data[self.hemi_slices["right"]]),
            },
        )

    @staticmethod
    def _harmonic_mean(data: np.ndarray) -> float:
        """Harmonic mean of the non-NaN values; 0.0 if none or any is <= 0."""
        data = np.asarray(data).flatten().astype(float)
        data = data[~np.isnan(data)]
        if data.size == 0:
            return 0.0
        if np.any(data <= 0):
            return 0.0
        return data.size / np.sum(1.0 / data)

    @staticmethod
    def _get_analysis_parcels_folder(parcels_name: str) -> Path:
        """Return the output folder for ``parcels_name``."""
        return get_analysis_output_folder() / "parcels" / f"parcels-{parcels_name}"

    @staticmethod
    def _hemi_paths(base: Path, stem: str) -> Dict[str, Path]:
        """Return ``{hemi: base / "<stem>_hemi-<hemi>.func.gii"}``."""
        return {
            hemi: base / f"{stem}_hemi-{hemi}.func.gii"
            for hemi in SURFACE_HEMIS
        }

    @staticmethod
    def _load_hemi_pair(paths: Dict[str, Path]) -> np.ndarray:
        """Load the "L" and "R" files and concatenate them (left first)."""
        return np.concatenate(
            [
                load_surface_numeric_data(paths["L"]).reshape(-1),
                load_surface_numeric_data(paths["R"]).reshape(-1),
            ]
        )

    def _overlap_stem(self) -> str:
        """File stem of the saved overlap map."""
        return f"parcels-{self.parcels_name}_space-{self.space}_overlap"

    def _parcel_info_path(self, base: Path) -> Path:
        """Path of the parcel info CSV (independent of the ROI filters)."""
        return base / (
            f"parcels-{self.parcels_name}_space-{self.space}"
            f"_sm-{self.smoothing_kernel_size}"
            f"_voxthres-{self.overlap_thr_vox}_info.csv"
        )

    def _output_stem(
        self,
        overlap_thr_roi: Optional[float] = None,
        min_voxel_size: Optional[int] = None,
    ) -> str:
        """File stem of the parcel images.

        The ROI-overlap and size parts default to the current thresholds;
        pass explicit values (e.g. 0, 0 for the unfiltered parcels) to
        override them.
        """
        if overlap_thr_roi is None:
            overlap_thr_roi = self.overlap_thr_roi
        if min_voxel_size is None:
            min_voxel_size = self.min_voxel_size
        return (
            f"parcels-{self.parcels_name}_space-{self.space}"
            f"_sm-{self.smoothing_kernel_size}"
            f"_voxthres-{self.overlap_thr_vox}"
            f"_roithres-{overlap_thr_roi}"
            f"_sz-{min_voxel_size}"
        )

    def _save(self, stem: Optional[str] = None) -> None:
        """Write config, overlap map, parcel info and parcels to disk.

        Args:
            stem: File stem for the parcel images; defaults to
                :meth:`_output_stem` with the current thresholds.
        """
        base = self._get_analysis_parcels_folder(self.parcels_name)
        base.mkdir(parents=True, exist_ok=True)
        config_path = base / f"parcels-{self.parcels_name}_config.json"
        # Keep the config in sync with the outputs written next to it. An
        # instance without subject configs (e.g. from _run_fast()) never
        # overwrites an existing config.
        if self.configs or not config_path.exists():
            with open(config_path, "w") as f:
                json.dump(
                    {
                        "configs": self.configs,
                        "space": self.space,
                        "mesh_paths": {
                            hemi: str(path)
                            for hemi, path in (self.mesh_paths or {}).items()
                        },
                    },
                    f,
                )
        if self.overlap_map is not None:
            save_surface_image(
                self._to_surface_image(self.overlap_map),
                self._hemi_paths(base, self._overlap_stem()),
            )
        self.parcel_info.to_csv(self._parcel_info_path(base), index=False)
        save_surface_image(
            self._to_surface_image(self.parcels),
            self._hemi_paths(
                base, self._output_stem() if stem is None else stem
            ),
        )