from __future__ import annotations
import json
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
if TYPE_CHECKING:
from pathlib import Path
from numpy.typing import NDArray
from dynsight.trajectory import Insight, Trj
from tropea_clustering._internal.onion_smooth.first_classes import (
StateMulti,
StateUni,
)
import dynsight
from dynsight.logs import logger
# Dataset rank that routes the Onion*Insight plotting methods to their
# univariate (``*_uni``) plotters; any other ``dataset.ndim`` is sent to
# the multivariate (``*_multi``) plotters.
UNIVAR_DIM = 2
@dataclass(frozen=True)
class ClusterInsight:
    """Immutable container for the result of a clustering analysis.

    Attributes:
        labels: The labels assigned by the clustering algorithm.
    """

    labels: NDArray[np.int64]

    def dump_to_json(self, file_path: Path) -> None:
        """Save the ClusterInsight to a JSON file and .npy file.

        The labels array is written with ``np.save`` next to ``file_path``
        (same path with its last suffix replaced by ``.npy``). The JSON
        file only records that file's name under the ``"labels_file"`` key.

        Parameters:
            file_path:
                Destination of the JSON file.
        """
        labels_target = file_path.with_suffix(".npy")
        np.save(labels_target, self.labels)

        with file_path.open("w") as out:
            json.dump({"labels_file": labels_target.name}, out, indent=4)

        logger.log(
            f"ClusterInsight saved to {file_path} and labels to "
            f"{labels_target}."
        )

    @classmethod
    def load_from_json(
        cls,
        file_path: Path,
        mmap_mode: Literal["r", "r+", "w+", "c"] | None = None,
    ) -> ClusterInsight:
        """Rebuild a ClusterInsight from its JSON file and .npy labels.

        Parameters:
            file_path:
                Path to the .json file.
            mmap_mode:
                If given, used as np.load(..., mmap_mode=mmap_mode) for memory
                mapping.

        Raises:
            ValueError: if the ``"labels_file"`` entry is missing or empty.
        """
        with file_path.open("r") as source:
            payload = json.load(source)

        labels_name = payload.get("labels_file")
        if not labels_name:
            msg = "'labels_file' key not found in JSON file."
            logger.log(msg)
            raise ValueError(msg)

        # The .npy file is resolved relative to the JSON file's directory.
        labels_location = file_path.with_name(labels_name)
        loaded_labels = np.load(labels_location, mmap_mode=mmap_mode)
        logger.log(
            f"ClusterInsight loaded from {file_path}, "
            f"labels from {labels_location}."
        )
        return cls(labels=loaded_labels)
@dataclass(frozen=True)
class OnionInsight(ClusterInsight):
    """Contains an onion-clustering analysis.

    Attributes:
        labels: The labels assigned by the clustering algorithm.
        state_list: List of the onion-clustering Gaussian states.
        reshaped_data: The input data reshaped for onion-clustering.
        meta: A dictionary containing the relevant parameters. The
            multivariate plotting paths and ``plot_state_populations``
            read ``meta["delta_t"]``.
    """

    state_list: list[StateUni] | list[StateMulti]
    reshaped_data: NDArray[np.float64]
    meta: dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _numpy_to_builtin(obj: Any) -> Any:
        """``json`` ``default`` hook turning NumPy objects into builtins.

        ``json`` cannot encode ``np.ndarray`` or NumPy scalars such as
        ``np.int64``; they are converted with ``tolist()``. Anything else
        raises ``TypeError``, which is what ``json`` expects from a hook.
        """
        if isinstance(obj, (np.ndarray, np.generic)):
            return obj.tolist()
        msg = (
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )
        raise TypeError(msg)

    @staticmethod
    def _serialize_states(
        states: list[StateUni] | list[StateMulti],
    ) -> list[dict[str, Any]]:
        """Turn each state dataclass into a dict of all its fields.

        ``np.ndarray`` values are converted to nested lists for JSON.
        """
        serialized = []
        for state in states:
            state_dict = {}
            for f in fields(state):
                val = getattr(state, f.name)
                state_dict[f.name] = (
                    val.tolist() if isinstance(val, np.ndarray) else val
                )
            serialized.append(state_dict)
        return serialized

    @staticmethod
    def _deserialize_states(
        raw_list: list[dict[str, Any]],
    ) -> list[StateUni] | list[StateMulti]:
        """Rebuild state objects from dicts written by ``_serialize_states``.

        ``dataclasses.fields`` (used when dumping) also reports fields
        declared with ``field(init=False)``, but the constructor rejects
        them as keyword arguments. Such values are assigned after
        construction instead, replacing whatever the constructor set.
        """
        state_list = []
        for entry in raw_list:
            # Infer the class: a list-valued "mean" is taken to denote a
            # multivariate state (assumes StateUni stores a scalar mean).
            state_cls = (
                StateMulti if isinstance(entry.get("mean"), list) else StateUni
            )
            # Convert lists back to ndarray
            values = {
                k: np.array(v) if isinstance(v, list) else v
                for k, v in entry.items()
            }
            post_init = {f.name for f in fields(state_cls) if not f.init}
            # Unknown keys are still passed on, so the constructor keeps
            # rejecting them exactly as before.
            state = state_cls(
                **{k: v for k, v in values.items() if k not in post_init}
            )
            for name, value in values.items():
                if name in post_init:
                    # object.__setattr__ also works for frozen dataclasses.
                    object.__setattr__(state, name, value)
            state_list.append(state)
        return state_list

    def dump_to_json(self, file_path: Path) -> None:
        """Save the OnionInsight to a JSON file and .npy files.

        ``labels`` and ``reshaped_data`` are written with ``np.save`` to
        ``<stem>_labels.npy`` and ``<stem>_reshaped.npy`` in the directory
        of ``file_path`` (``<stem>`` is ``file_path`` without its last
        suffix). The JSON file stores those file names, the serialized
        ``state_list`` and ``meta``; NumPy arrays/scalars nested anywhere in
        that payload are converted to builtins.

        Parameters:
            file_path:
                Destination of the JSON file.

        Raises:
            TypeError: if ``meta`` or a state holds a value that is neither
                JSON-serializable nor a NumPy array/scalar.
        """
        base = file_path.with_suffix("")
        labels_path = base.with_name(base.name + "_labels.npy")
        reshaped_path = base.with_name(base.name + "_reshaped.npy")
        # Large arrays go to binary .npy files, not into the JSON.
        np.save(labels_path, self.labels)
        np.save(reshaped_path, self.reshaped_data)
        data = {
            "labels_file": labels_path.name,
            "reshaped_data_file": reshaped_path.name,
            "state_list": self._serialize_states(self.state_list),
            "meta": self.meta,
        }
        # Encode fully before opening the file, so a serialization error
        # cannot leave a truncated JSON file behind.
        text = json.dumps(data, indent=4, default=self._numpy_to_builtin)
        with file_path.open("w") as file:
            file.write(text)
        logger.log(
            f"OnionInsight saved to {file_path}, labels to {labels_path}, "
            f"reshaped data to {reshaped_path}."
        )

    @classmethod
    def load_from_json(
        cls,
        file_path: Path,
        mmap_mode: Literal["r", "r+", "w+", "c"] | None = None,
    ) -> OnionInsight:
        """Load the OnionInsight object from JSON and .npy files.

        The .npy files are resolved relative to the JSON file's directory.
        ``meta`` defaults to an empty dict when absent from the JSON.

        Parameters:
            file_path:
                Path to the .json file.
            mmap_mode:
                If given, used as np.load(..., mmap_mode=mmap_mode) for memory
                mapping.

        Raises:
            ValueError: if required keys are missing.
        """
        with file_path.open("r") as file:
            data = json.load(file)
        # Validate presence of keys
        required_keys = ["labels_file", "reshaped_data_file", "state_list"]
        for key in required_keys:
            if key not in data:
                msg = f"'{key}' key not found in JSON file."
                logger.log(msg)
                raise ValueError(msg)
        state_list = cls._deserialize_states(data["state_list"])
        base_dir = file_path.parent
        labels = np.load(base_dir / data["labels_file"], mmap_mode=mmap_mode)
        reshaped = np.load(
            base_dir / data["reshaped_data_file"], mmap_mode=mmap_mode
        )
        logger.log(f"OnionInsight loaded from {file_path}.")
        return cls(
            labels=labels,
            reshaped_data=reshaped,
            state_list=state_list,
            meta=data.get("meta", {}),
        )

    def plot_output(self, file_path: Path, data_insight: Insight) -> None:
        """Plot the overall onion clustering result.

        A dataset with ``ndim == UNIVAR_DIM`` uses the univariate plotter
        (fed ``reshaped_data`` and ``dataset.shape[0]``); otherwise the
        multivariate plotter is fed the dataset, labels and
        ``meta["delta_t"]`` (``KeyError`` if that key is missing).
        """
        if data_insight.dataset.ndim == UNIVAR_DIM:
            dynsight.onion.plot.plot_output_uni(
                file_path,
                self.reshaped_data,
                data_insight.dataset.shape[0],
                self.state_list,
            )
        else:
            dynsight.onion.plot.plot_output_multi(
                file_path,
                data_insight.dataset,
                self.state_list,
                self.labels,
                self.meta["delta_t"],
            )
        attr_dict = {"file_path": file_path}
        logger.log(f"plot_output() with args {attr_dict}.")

    def plot_one_trj(
        self,
        file_path: Path,
        data_insight: Insight,
        particle_id: int,
    ) -> None:
        """Plot one particle's trajectory colored according to clustering.

        The multivariate path reads ``meta["delta_t"]`` (``KeyError`` if
        missing).
        """
        if data_insight.dataset.ndim == UNIVAR_DIM:
            dynsight.onion.plot.plot_one_trj_uni(
                file_path,
                particle_id,
                self.reshaped_data,
                data_insight.dataset.shape[0],
                self.labels,
            )
        else:
            dynsight.onion.plot.plot_one_trj_multi(
                file_path,
                particle_id,
                self.meta["delta_t"],
                data_insight.dataset,
                self.labels,
            )
        attr_dict = {"file_path": file_path, "particle_id": particle_id}
        logger.log(f"plot_one_trj() with args {attr_dict}.")

    def plot_medoids(self, file_path: Path, data_insight: Insight) -> None:
        """Plot the average sequence of each onion cluster.

        The multivariate path reads ``meta["delta_t"]`` (``KeyError`` if
        missing).
        """
        if data_insight.dataset.ndim == UNIVAR_DIM:
            dynsight.onion.plot.plot_medoids_uni(
                file_path,
                self.reshaped_data,
                self.labels,
            )
        else:
            dynsight.onion.plot.plot_medoids_multi(
                file_path,
                self.meta["delta_t"],
                data_insight.dataset,
                self.labels,
            )
        attr_dict = {"file_path": file_path}
        logger.log(f"plot_medoids() with args {attr_dict}.")

    def plot_state_populations(
        self,
        file_path: Path,
        data_insight: Insight,
    ) -> None:
        """Plot each state's population along the trajectory.

        Reads ``meta["delta_t"]`` (``KeyError`` if missing).
        """
        dynsight.onion.plot.plot_state_populations(
            file_path,
            data_insight.dataset.shape[0],
            self.meta["delta_t"],
            self.labels,
        )
        attr_dict = {"file_path": file_path}
        logger.log(f"plot_state_populations() with args {attr_dict}.")

    def plot_sankey(
        self,
        file_path: Path,
        data_insight: Insight,
        frame_list: list[int],
    ) -> None:
        """Plot the Sankey diagram of the onion clustering.

        Parameters:
            file_path: Destination of the figure.
            data_insight: Source dataset; only ``dataset.shape[0]`` is used.
            frame_list: Frames forwarded to the Sankey plotter.
        """
        dynsight.onion.plot.plot_sankey(
            file_path,
            self.labels,
            data_insight.dataset.shape[0],
            frame_list,
        )
        attr_dict = {"file_path": file_path, "frame_list": frame_list}
        logger.log(f"plot_sankey() with args {attr_dict}.")
@dataclass(frozen=True)
class OnionSmoothInsight(ClusterInsight):
    """Contains a smooth onion-clustering analysis.

    Attributes:
        labels: The labels assigned by the clustering algorithm.
        state_list: List of the onion-clustering Gaussian states.
        meta: A dictionary containing the relevant parameters.
    """

    state_list: list[StateUni] | list[StateMulti]
    meta: dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _numpy_to_builtin(obj: Any) -> Any:
        """``json`` ``default`` hook turning NumPy objects into builtins.

        ``json`` cannot encode ``np.ndarray`` or NumPy scalars such as
        ``np.int64``; they are converted with ``tolist()``. Anything else
        raises ``TypeError``, which is what ``json`` expects from a hook.
        """
        if isinstance(obj, (np.ndarray, np.generic)):
            return obj.tolist()
        msg = (
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )
        raise TypeError(msg)

    @staticmethod
    def _serialize_states(
        states: list[StateUni] | list[StateMulti],
    ) -> list[dict[str, Any]]:
        """Turn each state dataclass into a dict of all its fields.

        ``np.ndarray`` values are converted to nested lists for JSON.
        """
        serialized = []
        for state in states:
            state_dict = {}
            for f in fields(state):
                val = getattr(state, f.name)
                state_dict[f.name] = (
                    val.tolist() if isinstance(val, np.ndarray) else val
                )
            serialized.append(state_dict)
        return serialized

    @staticmethod
    def _deserialize_states(
        raw_list: list[dict[str, Any]],
    ) -> list[StateUni] | list[StateMulti]:
        """Rebuild state objects from dicts written by ``_serialize_states``.

        ``dataclasses.fields`` (used when dumping) also reports fields
        declared with ``field(init=False)``, but the constructor rejects
        them as keyword arguments. Such values are assigned after
        construction instead, replacing whatever the constructor set.
        """
        state_list = []
        for entry in raw_list:
            # Decide which class to use: a list-valued "mean" is taken to
            # denote a multivariate state (assumes StateUni stores a scalar
            # mean).
            state_cls = (
                StateMulti if isinstance(entry.get("mean"), list) else StateUni
            )
            # Rebuild kwargs (convert lists back to np.ndarrays)
            values = {
                k: np.array(v) if isinstance(v, list) else v
                for k, v in entry.items()
            }
            post_init = {f.name for f in fields(state_cls) if not f.init}
            # Unknown keys are still passed on, so the constructor keeps
            # rejecting them exactly as before.
            state = state_cls(
                **{k: v for k, v in values.items() if k not in post_init}
            )
            for name, value in values.items():
                if name in post_init:
                    # object.__setattr__ also works for frozen dataclasses.
                    object.__setattr__(state, name, value)
            state_list.append(state)
        return state_list

    def dump_to_json(self, file_path: Path) -> None:
        """Save the OnionSmoothInsight object to JSON and .npy for labels.

        ``labels`` is written with ``np.save`` to ``<stem>_labels.npy`` in
        the directory of ``file_path`` (``<stem>`` is ``file_path`` without
        its last suffix). The JSON file stores that file name, the
        serialized ``state_list`` and ``meta``; NumPy arrays/scalars nested
        anywhere in that payload are converted to builtins.

        Parameters:
            file_path:
                Destination of the JSON file.

        Raises:
            TypeError: if ``meta`` or a state holds a value that is neither
                JSON-serializable nor a NumPy array/scalar.
        """
        base = file_path.with_suffix("")
        labels_path = base.with_name(base.name + "_labels.npy")
        np.save(labels_path, self.labels)
        data = {
            "labels_file": labels_path.name,
            "state_list": self._serialize_states(self.state_list),
            "meta": self.meta,
        }
        # Encode fully before opening the file, so a serialization error
        # cannot leave a truncated JSON file behind.
        text = json.dumps(data, indent=4, default=self._numpy_to_builtin)
        with file_path.open("w") as file:
            file.write(text)
        logger.log(
            f"OnionSmoothInsight saved to {file_path}, "
            f"labels to {labels_path}."
        )

    @classmethod
    def load_from_json(
        cls,
        file_path: Path,
        mmap_mode: Literal["r", "r+", "w+", "c"] | None = None,
    ) -> OnionSmoothInsight:
        """Load the OnionSmoothInsight from JSON and associated .npy file.

        The .npy file is resolved relative to the JSON file's directory.
        ``meta`` defaults to an empty dict when absent from the JSON.

        Parameters:
            file_path:
                Path to the .json file.
            mmap_mode:
                If given, used as np.load(..., mmap_mode=mmap_mode) for memory
                mapping.

        Raises:
            ValueError: if required keys are missing.
        """
        with file_path.open("r") as file:
            data = json.load(file)
        if "labels_file" not in data or "state_list" not in data:
            msg = "'labels_file' or 'state_list' key not found in JSON file."
            logger.log(msg)
            raise ValueError(msg)
        state_list = cls._deserialize_states(data["state_list"])
        labels_path = file_path.parent / data["labels_file"]
        labels = np.load(labels_path, mmap_mode=mmap_mode)
        logger.log(
            f"OnionSmoothInsight loaded from {file_path}, "
            f"labels from {labels_path}."
        )
        return cls(
            labels=labels,
            state_list=state_list,
            meta=data.get("meta", {}),
        )

    def plot_output(self, file_path: Path, data_insight: Insight) -> None:
        """Plot the overall onion clustering result.

        A dataset with ``ndim == UNIVAR_DIM`` uses the univariate plotter;
        any other rank uses the multivariate one, which also receives the
        labels.
        """
        if data_insight.dataset.ndim == UNIVAR_DIM:
            dynsight.onion.plot_smooth.plot_output_uni(
                file_path,
                data_insight.dataset,
                self.state_list,
            )
        else:
            dynsight.onion.plot_smooth.plot_output_multi(
                file_path,
                data_insight.dataset,
                self.state_list,
                self.labels,
            )
        attr_dict = {"file_path": file_path}
        logger.log(f"plot_output() with args {attr_dict}.")

    def plot_one_trj(
        self,
        file_path: Path,
        data_insight: Insight,
        particle_id: int,
    ) -> None:
        """Plot one particle's trajectory colored according to clustering."""
        if data_insight.dataset.ndim == UNIVAR_DIM:
            dynsight.onion.plot_smooth.plot_one_trj_uni(
                file_path,
                particle_id,
                data_insight.dataset,
                self.labels,
            )
        else:
            dynsight.onion.plot_smooth.plot_one_trj_multi(
                file_path,
                particle_id,
                data_insight.dataset,
                self.labels,
            )
        attr_dict = {"file_path": file_path, "particle_id": particle_id}
        logger.log(f"plot_one_trj() with args {attr_dict}.")

    def plot_state_populations(
        self,
        file_path: Path,
    ) -> None:
        """Plot each state's population along the trajectory."""
        dynsight.onion.plot_smooth.plot_state_populations(
            file_path,
            self.labels,
        )
        attr_dict = {"file_path": file_path}
        logger.log(f"plot_state_populations() with args {attr_dict}.")

    def plot_sankey(
        self,
        file_path: Path,
        frame_list: list[int],
    ) -> None:
        """Plot the Sankey diagram of the onion clustering.

        Parameters:
            file_path: Destination of the figure.
            frame_list: Frames forwarded to the Sankey plotter.
        """
        dynsight.onion.plot_smooth.plot_sankey(
            file_path,
            self.labels,
            frame_list,
        )
        attr_dict = {"file_path": file_path, "frame_list": frame_list}
        logger.log(f"plot_sankey() with args {attr_dict}.")

    def dump_colored_trj(self, trj: Trj, file_path: Path) -> None:
        """Save an .xyz file with the clustering labels for each atom.

        Each frame of ``trj`` (restricted to ``trj.trajslice`` when set)
        becomes one XYZ block. The block holds the atom count, a
        ``Frame <i>`` comment line, and then one
        ``<name> <x> <y> <z> <label>`` line per atom. ``<label>`` is
        ``labels[atom, frame] + 2``.

        Parameters:
            trj: Trajectory whose universe supplies atom names and
                positions.
            file_path: Destination of the .xyz file.

        Raises:
            ValueError: if ``labels`` does not have shape
                ``(n_atoms, n_frames)`` matching ``trj``.
        """
        trajslice = slice(None) if trj.trajslice is None else trj.trajslice
        n_frames = len(trj.universe.trajectory[trajslice])
        n_atoms = len(trj.universe.atoms)
        expected_shape = (n_atoms, n_frames)
        if self.labels.shape != expected_shape:
            if self.labels.ndim == len(expected_shape):
                msg = (
                    f"Shape mismatch: Trj should have {self.labels.shape[0]} "
                    f"atoms, {self.labels.shape[1]} frames, but has {n_atoms} "
                    f"atoms, {n_frames} frames."
                )
            else:
                # Indexing shape[1] would raise IndexError for non-2-D
                # labels, so report the raw shape instead.
                msg = (
                    "Shape mismatch: labels should have shape "
                    f"(n_atoms, n_frames) = {expected_shape}, but has "
                    f"shape {self.labels.shape}."
                )
            logger.log(msg)
            raise ValueError(msg)
        # Labels are written shifted by +2. The reason is not visible here;
        # presumably it makes negative labels (e.g. -1) positive for
        # visualization tools. Confirm before changing.
        lab_new = self.labels + 2
        # Atom names are frame-independent: look them up once.
        atom_names = [atom.name for atom in trj.universe.atoms]
        with file_path.open("w") as f:
            for i, ts in enumerate(trj.universe.trajectory[trajslice]):
                lines = [f"{n_atoms}\n", f"Frame {i}\n"]
                for name, (x, y, z), label in zip(
                    atom_names, ts.positions, lab_new[:, i]
                ):
                    lines.append(
                        f"{name} {x:.5f} {y:.5f} {z:.5f} {label!s}\n"
                    )
                f.writelines(lines)
        logger.log(f"Colored trj saved to {file_path}.")