aboutsummaryrefslogtreecommitdiffstats
path: root/modules/drug.py
diff options
context:
space:
mode:
authorLeonard Kugis <leonard@kug.is>2025-12-23 03:50:58 +0100
committerLeonard Kugis <leonard@kug.is>2025-12-23 03:50:58 +0100
commit9e1f202b2101be32f2eb5b1a30f041f1e25553cd (patch)
treef6eefe7cbd02b52759877102fc63098d96e13807 /modules/drug.py
parentd342b3d00f64915685b486a68b7c7b3e2e47fde6 (diff)
downloadxembu-9e1f202b2101be32f2eb5b1a30f041f1e25553cd.tar.gz
Added drug module
Diffstat (limited to 'modules/drug.py')
-rw-r--r--modules/drug.py408
1 files changed, 408 insertions, 0 deletions
diff --git a/modules/drug.py b/modules/drug.py
new file mode 100644
index 0000000..9b6c390
--- /dev/null
+++ b/modules/drug.py
@@ -0,0 +1,408 @@
+# modules/drugs.py
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import pandas as pd
+import matplotlib.dates as mdates
+from matplotlib.axes import Axes
+from matplotlib.font_manager import FontProperties
+
+from .base import Frame, BigFrame, ModuleResult
+
+class SubstanceModel:
+ key: str = ""
+ supported_applications: set[str] = set()
+ unit_label: str = ""
+
+ def simulate_bac(
+ self,
+ time_index: pd.DatetimeIndex,
+ events: pd.DataFrame,
+ profile: Dict[str, float],
+ params: Dict[str, float],
+ ) -> pd.Series:
+ raise NotImplementedError
+
+class EthanolOralModel(SubstanceModel):
+ key = "Ethanol"
+ supported_applications = {"oral"}
+ unit_label = "Blood concentration [g/L]"
+
+ ETHANOL_DENSITY_G_PER_ML = 0.789
+
+ def simulate_bac(
+ self,
+ time_index: pd.DatetimeIndex,
+ events: pd.DataFrame,
+ profile: Dict[str, float],
+ params: Dict[str, float],
+ ) -> pd.Series:
+ weight_kg = float(profile.get("weight_kg", 70.0))
+ r = float(profile.get("r", 0.6))
+ r = max(0.1, min(0.9, r))
+
+ beta_permille_per_h = float(params.get("beta_permille_per_h", 0.15))
+ t_half_abs_min = float(params.get("absorption_halftime_min", 20.0))
+
+ t_half_abs_h = max(1e-6, t_half_abs_min / 60.0)
+ ka = np.log(2.0) / t_half_abs_h # [1/h]
+
+ ev = events.sort_values("date").reset_index(drop=True)
+ ev_idx = 0
+
+ gut_pools: List[float] = []
+
+ body_grams = 0.0
+
+ elim_g_per_h = beta_permille_per_h * (r * weight_kg)
+
+ if len(time_index) >= 2:
+ dt_h = (time_index[1] - time_index[0]).total_seconds() / 3600.0
+ else:
+ dt_h = 5.0 / 60.0
+
+ dt_h = max(1e-6, dt_h)
+ absorb_frac = 1.0 - np.exp(-ka * dt_h)
+
+ out = np.zeros(len(time_index), dtype=float)
+
+ for i, t in enumerate(time_index):
+ while ev_idx < len(ev) and pd.Timestamp(ev.loc[ev_idx, "date"]) <= t:
+ g = float(ev.loc[ev_idx, "grams_ethanol"])
+ if g > 0:
+ gut_pools.append(g)
+ ev_idx += 1
+
+ absorbed_total = 0.0
+ if gut_pools:
+ new_pools = []
+ for rem in gut_pools:
+ absorbed = rem * absorb_frac
+ rem2 = rem - absorbed
+ absorbed_total += absorbed
+ if rem2 > 1e-6:
+ new_pools.append(rem2)
+ gut_pools = new_pools
+
+ body_grams += absorbed_total
+
+ body_grams = max(0.0, body_grams - elim_g_per_h * dt_h)
+
+ out[i] = body_grams / (r * weight_kg)
+
+ return pd.Series(out, index=time_index)
+
+
+SUBSTANCE_MODELS: Dict[str, SubstanceModel] = {
+ "Ethanol": EthanolOralModel(),
+}
+
+def _clean_str(x: Any) -> str:
+ s = str(x).strip()
+ if (s.startswith("'") and s.endswith("'")) or (s.startswith('"') and s.endswith('"')):
+ s = s[1:-1].strip()
+ return s
+
+def _auto_time_limits(
+ tmin: pd.Timestamp,
+ tmax: pd.Timestamp,
+ params: Optional[Dict[str, float]] = None,
+) -> tuple[pd.Timestamp, pd.Timestamp]:
+ dt = tmax - tmin
+ if dt <= pd.Timedelta(0):
+ margin = pd.Timedelta(minutes=30)
+ else:
+ margin = dt * 0.05
+
+ x0 = tmin - margin
+ x1 = tmax + margin
+
+ if params:
+ beta = float(params.get("beta_permille_per_h", 0.15))
+ beta = max(1e-6, beta)
+
+ t_half_abs_min = float(params.get("absorption_halftime_min", 20.0))
+ t_half_abs_h = max(1e-6, t_half_abs_min / 60.0)
+
+ tau_abs_h = t_half_abs_h / np.log(2.0)
+
+ tail_h = max(6.0, 8.0 * tau_abs_h, 2.0 * (1.0 / beta))
+ tail_h = min(tail_h, 24.0)
+
+ x1_tail = tmax + pd.Timedelta(hours=tail_h)
+
+ if x1_tail > x1:
+ x1 = x1_tail
+
+ return x0, x1
+
+def _extract_module_tuple(row: pd.Series, module_name: str) -> Optional[Tuple]:
+ mods = row.get("modules_list", []) or []
+ params = row.get("params_list", []) or []
+ try:
+ idx = list(mods).index(module_name)
+ except ValueError:
+ return None
+ return params[idx] if idx < len(params) else None
+
+
+def _make_time_grid(start: pd.Timestamp, end: pd.Timestamp, target_points: int = 800) -> pd.DatetimeIndex:
+ if end <= start:
+ return pd.DatetimeIndex([start])
+
+ total_s = (end - start).total_seconds()
+ step_s = max(60.0, min(30 * 60.0, total_s / float(max(10, target_points))))
+ step_min = int(max(1, round(step_s / 60.0)))
+ return pd.date_range(start=start, end=end, freq=f"{step_min}min")
+
+@dataclass
+class TextFrame(Frame):
+ text: str
+
+ def render(self, ax: Axes, mono_font: FontProperties) -> None:
+ ax.text(0, 1, self.text, va="top", ha="left", fontproperties=mono_font)
+
+@dataclass
+class BACBigFrame(BigFrame):
+ model_name: str
+ times: pd.DatetimeIndex
+ xlim_start: pd.Timestamp
+ xlim_end: pd.Timestamp
+ unit_label: str
+ participants: List[str]
+ bac_by_person: Dict[str, pd.Series]
+
+ def render(self, ax: Axes, mono_font: FontProperties) -> None:
+ ax.axis("on")
+
+ locator = mdates.AutoDateLocator(minticks=3, maxticks=7)
+ formatter = mdates.ConciseDateFormatter(locator)
+ ax.xaxis.set_major_locator(locator)
+ ax.xaxis.set_major_formatter(formatter)
+ ax.xaxis.get_offset_text().set_visible(False)
+
+ ax.set_xlim(self.xlim_start, self.xlim_end)
+ ax.set_ylabel(self.unit_label, fontproperties=mono_font)
+
+ all_vals = []
+ for p in self.participants:
+ y = self.bac_by_person[p].copy().replace([np.inf, -np.inf], np.nan).fillna(0.0)
+
+ line, = ax.plot(self.times, y.values, label=p, linewidth=1)
+ ax.fill_between(
+ self.times,
+ y.values,
+ 0.0,
+ alpha=0.18,
+ color=line.get_color(),
+ zorder=line.get_zorder() - 1,
+ )
+
+ v = y.values
+ v = v[np.isfinite(v)]
+ if v.size:
+ all_vals.append(v)
+
+ if all_vals:
+ vv = np.concatenate(all_vals)
+ vmax = float(np.nanmax(vv)) if vv.size else 0.0
+ ax.set_ylim(0.0, 0.5 if vmax <= 0 else vmax * 1.10)
+
+ ax.grid(True, alpha=0.2)
+
+ leg = ax.legend(prop=mono_font, fontsize=7, loc="best", ncols=2)
+ if leg:
+ for t in leg.get_texts():
+ t.set_fontproperties(mono_font)
+
+ for tick in ax.get_xticklabels() + ax.get_yticklabels():
+ tick.set_fontproperties(mono_font)
+
+ ax.xaxis.label.set_fontproperties(mono_font)
+ ax.yaxis.label.set_fontproperties(mono_font)
+
+def _fmt_ethanol_totals_table(ev: pd.DataFrame) -> List[str]:
+ """
+ Returns ONLY the table block (no markdown headings):
+ debitor | grams EtOH | weight kg | events
+ Expects columns: debitor, grams_ethanol, weight_kg
+ """
+ if ev.empty:
+ return ["(no data)"]
+
+ # grams total + events count
+ totals = ev.groupby("debitor")["grams_ethanol"].sum().reset_index().sort_values("debitor")
+ counts = ev.groupby("debitor").size().rename("events").reset_index()
+
+ # last known weight per debitor (by date)
+ w_last = (
+ ev.dropna(subset=["weight_kg"])
+ .sort_values(["debitor", "date"])
+ .groupby("debitor")["weight_kg"]
+ .last()
+ .rename("weight_kg")
+ .reset_index()
+ )
+
+ totals = totals.merge(counts, on="debitor", how="left").merge(w_last, on="debitor", how="left")
+ totals["events"] = totals["events"].fillna(0).astype(int)
+
+ # formatting widths
+ name_list = totals["debitor"].astype(str).tolist()
+ name_w = max([len(x) for x in name_list] + [7])
+
+ header = f"{'debitor':<{name_w}} | {'EtOH [g]':>12} | {'weight [kg]':>9} | {'events':>6}"
+ sep = "-" * len(header)
+
+ lines = [header, sep]
+ for _, r in totals.iterrows():
+ deb = str(r["debitor"])
+ grams = float(r["grams_ethanol"]) if pd.notna(r["grams_ethanol"]) else 0.0
+ w = float(r["weight_kg"]) if pd.notna(r["weight_kg"]) else float("nan")
+ events = int(r["events"])
+
+ w_str = f"{w:>9.1f}" if np.isfinite(w) else f"{'n/a':>9}"
+ lines.append(f"{deb:<{name_w}} | {grams:>12.2f} | {w_str} | {events:>6d}")
+
+ return lines
+
+class DrugModule:
+ name = "drug"
+
+ def process(self, df: pd.DataFrame, context: Dict[str, Any]) -> ModuleResult:
+ want_pdf = bool(context.get("want_pdf", True))
+ mono_font = context.get("mono_font") or FontProperties(family="DejaVu Sans Mono", size=8)
+
+ profiles: Dict[str, Dict[str, float]] = context.get("drug_profiles", {}) or {}
+
+ params = {
+ "beta_permille_per_h": float(context.get("drug_ethanol_beta_permille_per_h", 0.15)),
+ "absorption_halftime_min": float(context.get("drug_ethanol_absorption_halftime_min", 20.0)),
+ }
+
+ work = df.copy()
+ work = work[pd.notna(work["date"])]
+ work["flag"] = work["group_flag"].astype(str).str.strip().str.upper()
+ work = work[work["flag"] == "U"].copy()
+
+ warnings: List[str] = []
+ events: List[Dict[str, Any]] = []
+
+ for _, row in work.iterrows():
+ tup = _extract_module_tuple(row, "drug")
+ if not tup or len(tup) < 5:
+ warnings.append(f"Missing/invalid drug params at {row.get('date')} ({row.get('debitor')})")
+ continue
+
+ substance = _clean_str(tup[0])
+ application = _clean_str(tup[3]).lower()
+
+ if substance != "Ethanol" or application != "oral":
+ warnings.append(f"Unsupported: ({substance}, {application}) at {row.get('date')} ({row.get('debitor')})")
+ continue
+
+ try:
+ purity = float(tup[1])
+ except Exception:
+ purity = 0.0
+
+ try:
+ amount_liters = float(tup[2])
+ except Exception:
+ amount_liters = 0.0
+
+ try:
+ weight_kg = float(tup[4])
+ except Exception:
+ weight_kg = float(context.get("drug_default_weight_kg", 70.0))
+
+ purity = float(np.clip(purity, 0.0, 1.0))
+ amount_liters = max(0.0, amount_liters)
+
+ if not np.isfinite(weight_kg) or weight_kg <= 0:
+ weight_kg = float(context.get("drug_default_weight_kg", 70.0))
+
+ purity = float(np.clip(purity, 0.0, 1.0))
+ amount_liters = max(0.0, amount_liters)
+
+ grams_ethanol = amount_liters * 1000.0 * EthanolOralModel.ETHANOL_DENSITY_G_PER_ML * purity
+
+ events.append({
+ "date": pd.Timestamp(row["date"]),
+ "debitor": str(row["debitor"]),
+ "substance": "Ethanol",
+ "application": "oral",
+ "purity": purity,
+ "amount_liters": amount_liters,
+ "grams_ethanol": grams_ethanol,
+ "weight_kg": weight_kg
+ })
+
+ if not events:
+ summary = ["# DrugModule", "", "(no supported Ethanol/oral events found)"]
+ if warnings:
+ summary += ["", "Warnings:"] + [f"- {w}" for w in warnings[:50]]
+ return ModuleResult(summary_text="\n".join(summary), frames=[], bigframes=[], pages=[])
+
+ ev = pd.DataFrame(events)
+ tmin, tmax = ev["date"].min(), ev["date"].max()
+ x0, x1 = _auto_time_limits(tmin, tmax, params=params)
+
+ times = _make_time_grid(x0, x1, target_points=800)
+
+ model = SUBSTANCE_MODELS["Ethanol"]
+
+ participants = sorted(ev["debitor"].unique().tolist())
+ bac_by_person: Dict[str, pd.Series] = {}
+
+ for p in participants:
+ pe = ev[ev["debitor"] == p].copy()
+
+ w_series = pe["weight_kg"].dropna().astype(float)
+ weight_kg = float(w_series.iloc[-1]) if len(w_series) else float(context.get("drug_default_weight_kg", 70.0))
+
+ profile = {
+ "weight_kg": weight_kg,
+ "r": float(context.get("drug_default_r", 0.6)),
+ }
+
+ bac_by_person[p] = model.simulate_bac(times, pe[["date", "grams_ethanol"]], profile, params)
+
+ summary_lines: List[str] = []
+ summary_lines.append("# DrugModule")
+ summary_lines.append("")
+ summary_lines.append("## Totals")
+ summary_lines.append("")
+ summary_lines.extend(_fmt_ethanol_totals_table(ev))
+
+ summary_text = "\n".join(summary_lines)
+
+ frames: List[Frame] = []
+ bigframes: List[BigFrame] = []
+ pages: List[plt.Figure] = []
+
+ if want_pdf:
+ table_lines = _fmt_ethanol_totals_table(ev)
+ frames.append(TextFrame(
+ title="Drug: Ethanol/oral totals",
+ text="\n".join(table_lines),
+ ))
+
+ bigframes.append(
+ BACBigFrame(
+ title="EtOH pharmacokinetic analysis",
+ model_name="Ethanol/oral",
+ times=times,
+ xlim_start=x0,
+ xlim_end=x1,
+ unit_label=model.unit_label,
+ participants=participants,
+ bac_by_person=bac_by_person,
+ )
+ )
+
+ return ModuleResult(summary_text=summary_text, frames=frames, bigframes=bigframes, pages=pages)
+