diff options
Diffstat (limited to 'modules/drug.py')
| -rw-r--r-- | modules/drug.py | 408 |
1 files changed, 408 insertions, 0 deletions
# modules/drug.py
# Reconstructed from: diff --git a/modules/drug.py b/modules/drug.py
# (new file mode 100644, index 0000000..9b6c390)
#
# Ethanol/oral pharmacokinetics module: parses "drug" events from a ledger
# frame, simulates a per-person blood alcohol curve and renders a totals
# table plus a concentration plot.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.dates as mdates
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.font_manager import FontProperties

from .base import Frame, BigFrame, ModuleResult


class SubstanceModel:
    """Interface for a substance-specific concentration model."""

    # Lookup key of the substance (see SUBSTANCE_MODELS).
    key: str = ""
    # Application routes the model can simulate (e.g. "oral").
    supported_applications: set[str] = set()
    # Y-axis label used when plotting the simulated series.
    unit_label: str = ""

    def simulate_bac(
        self,
        time_index: pd.DatetimeIndex,
        events: pd.DataFrame,
        profile: Dict[str, float],
        params: Dict[str, float],
    ) -> pd.Series:
        """Return the simulated concentration for every point of *time_index*.

        Subclasses must override this method.
        """
        raise NotImplementedError


class EthanolOralModel(SubstanceModel):
    """Orally ingested ethanol: first-order absorption, constant elimination."""

    key = "Ethanol"
    supported_applications = {"oral"}
    # NOTE(review): the value computed below is grams per (r * weight_kg),
    # i.e. per kg of distribution mass -- confirm that "g/L" is the intended
    # label before changing this user-facing string.
    unit_label = "Blood concentration [g/L]"

    ETHANOL_DENSITY_G_PER_ML = 0.789

    def simulate_bac(
        self,
        time_index: pd.DatetimeIndex,
        events: pd.DataFrame,
        profile: Dict[str, float],
        params: Dict[str, float],
    ) -> pd.Series:
        """Simulate the concentration curve of one person on *time_index*.

        Args:
            time_index: Simulation grid. The step is taken from the first two
                points, so the grid is assumed to be uniform; a grid with
                fewer than two points uses a 5 minute step.
            events: Frame with columns ``date`` and ``grams_ethanol``.
            profile: ``weight_kg`` (default 70.0) and distribution factor
                ``r`` (default 0.6, clamped to [0.1, 0.9]).
            params: ``beta_permille_per_h`` (default 0.15) and
                ``absorption_halftime_min`` (default 20.0).

        Returns:
            Series indexed by *time_index* holding body grams divided by
            ``r * weight_kg``.
        """
        weight_kg = float(profile.get("weight_kg", 70.0))
        if not np.isfinite(weight_kg) or weight_kg <= 0:
            # A non-positive weight would divide by zero below.
            weight_kg = 70.0
        r = max(0.1, min(0.9, float(profile.get("r", 0.6))))
        distribution_kg = r * weight_kg

        beta_permille_per_h = float(params.get("beta_permille_per_h", 0.15))
        t_half_abs_min = float(params.get("absorption_halftime_min", 20.0))
        t_half_abs_h = max(1e-6, t_half_abs_min / 60.0)
        ka = np.log(2.0) / t_half_abs_h  # absorption rate constant [1/h]

        if len(time_index) >= 2:
            dt_h = (time_index[1] - time_index[0]).total_seconds() / 3600.0
        else:
            dt_h = 5.0 / 60.0
        dt_h = max(1e-6, dt_h)

        # Per-step quantities are loop invariant.
        absorb_frac = 1.0 - np.exp(-ka * dt_h)
        elim_step_g = beta_permille_per_h * distribution_kg * dt_h

        # Pull the event columns out once instead of using .loc per step.
        ev = events.sort_values("date").reset_index(drop=True)
        ev_dates = [pd.Timestamp(d) for d in ev["date"]]
        ev_grams = [float(g) for g in ev["grams_ethanol"]]
        n_events = len(ev_dates)
        next_event = 0

        # Every dose is absorbed with the same rate constant, so one pooled
        # gut compartment is equivalent to tracking each dose separately.
        gut_grams = 0.0
        body_grams = 0.0
        out = np.zeros(len(time_index), dtype=float)

        for i, t in enumerate(time_index):
            # Move every dose that is due at or before this step into the gut.
            while next_event < n_events and ev_dates[next_event] <= t:
                grams = ev_grams[next_event]
                if np.isfinite(grams) and grams > 0:
                    gut_grams += grams
                next_event += 1

            absorbed = gut_grams * absorb_frac
            gut_grams -= absorbed
            body_grams = max(0.0, body_grams + absorbed - elim_step_g)
            out[i] = body_grams / distribution_kg

        return pd.Series(out, index=time_index)


SUBSTANCE_MODELS: Dict[str, SubstanceModel] = {
    "Ethanol": EthanolOralModel(),
}


def _clean_str(x: Any) -> str:
    """Return ``str(x)`` stripped of whitespace and one pair of matching quotes."""
    s = str(x).strip()
    # Require two characters so a lone quote is not treated as a quoted string.
    if len(s) >= 2 and s[0] == s[-1] and s[0] in ("'", '"'):
        s = s[1:-1].strip()
    return s


def _auto_time_limits(
    tmin: pd.Timestamp,
    tmax: pd.Timestamp,
    params: Optional[Dict[str, float]] = None,
) -> tuple[pd.Timestamp, pd.Timestamp]:
    """Compute plot/simulation limits around the event range [tmin, tmax].

    A 5 % margin (30 minutes when the range is empty) is added on both sides.
    When *params* is given, the right limit is extended by a tail of at least
    6 h and at most 24 h, derived from the absorption half-time and the
    elimination rate.
    """
    span = tmax - tmin
    margin = span * 0.05 if span > pd.Timedelta(0) else pd.Timedelta(minutes=30)

    x0 = tmin - margin
    x1 = tmax + margin

    if params:
        beta = max(1e-6, float(params.get("beta_permille_per_h", 0.15)))

        t_half_abs_min = float(params.get("absorption_halftime_min", 20.0))
        t_half_abs_h = max(1e-6, t_half_abs_min / 60.0)
        tau_abs_h = t_half_abs_h / np.log(2.0)  # absorption time constant

        tail_h = min(24.0, max(6.0, 8.0 * tau_abs_h, 2.0 * (1.0 / beta)))
        x1 = max(x1, tmax + pd.Timedelta(hours=tail_h))

    return x0, x1


def _as_list(value: Any) -> list:
    """Coerce a cell value to a list; scalars, strings and None become []."""
    if value is None or isinstance(value, (str, bytes)):
        return []
    try:
        return list(value)
    except TypeError:
        # Non-iterable cell, e.g. a float NaN placeholder.
        return []


def _extract_module_tuple(row: pd.Series, module_name: str) -> Optional[Tuple]:
    """Return the entry of ``params_list`` paired with *module_name*.

    ``modules_list`` and ``params_list`` are read from *row* and matched by
    position. Returns None when the module is not listed or has no
    corresponding parameter entry.
    """
    mods = _as_list(row.get("modules_list"))
    params = _as_list(row.get("params_list"))
    try:
        idx = mods.index(module_name)
    except ValueError:
        return None
    return params[idx] if idx < len(params) else None


def _make_time_grid(start: pd.Timestamp, end: pd.Timestamp, target_points: int = 800) -> pd.DatetimeIndex:
    """Build a uniform grid from *start* to *end* with whole-minute steps.

    The step aims at *target_points* points (at least 10) but is kept between
    1 and 30 minutes. A non-positive range yields a single-point index.
    """
    if end <= start:
        return pd.DatetimeIndex([start])

    total_s = (end - start).total_seconds()
    step_s = max(60.0, min(30 * 60.0, total_s / float(max(10, target_points))))
    step_min = int(max(1, round(step_s / 60.0)))
    return pd.date_range(start=start, end=end, freq=f"{step_min}min")


@dataclass
class TextFrame(Frame):
    """Frame that renders a block of preformatted text."""

    text: str

    def render(self, ax: Axes, mono_font: FontProperties) -> None:
        """Draw the text anchored at the top-left corner of *ax*."""
        ax.text(0, 1, self.text, va="top", ha="left", fontproperties=mono_font)


@dataclass
class BACBigFrame(BigFrame):
    """Big frame plotting one filled concentration curve per participant."""

    model_name: str
    times: pd.DatetimeIndex
    xlim_start: pd.Timestamp
    xlim_end: pd.Timestamp
    unit_label: str
    participants: List[str]
    bac_by_person: Dict[str, pd.Series]

    def render(self, ax: Axes, mono_font: FontProperties) -> None:
        """Plot all participants on *ax* using *mono_font* for every label."""
        ax.axis("on")

        locator = mdates.AutoDateLocator(minticks=3, maxticks=7)
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
        ax.xaxis.get_offset_text().set_visible(False)

        ax.set_xlim(self.xlim_start, self.xlim_end)
        ax.set_ylabel(self.unit_label, fontproperties=mono_font)

        peaks: List[float] = []
        for person in self.participants:
            series = self.bac_by_person[person]
            # Infinite and missing samples are drawn as zero.
            y = series.replace([np.inf, -np.inf], np.nan).fillna(0.0).to_numpy()

            line, = ax.plot(self.times, y, label=person, linewidth=1)
            ax.fill_between(
                self.times,
                y,
                0.0,
                alpha=0.18,
                color=line.get_color(),
                zorder=line.get_zorder() - 1,
            )
            if y.size:
                peaks.append(float(y.max()))

        if peaks:
            vmax = max(peaks)
            ax.set_ylim(0.0, 0.5 if vmax <= 0 else vmax * 1.10)

        ax.grid(True, alpha=0.2)

        # Only build a legend when there are labelled curves. ``fontsize`` is
        # not passed because matplotlib ignores it when ``prop`` is given.
        if self.participants:
            legend = ax.legend(prop=mono_font, loc="best", ncols=2)
            for label in legend.get_texts():
                label.set_fontproperties(mono_font)

        for tick in ax.get_xticklabels() + ax.get_yticklabels():
            tick.set_fontproperties(mono_font)

        ax.xaxis.label.set_fontproperties(mono_font)
        ax.yaxis.label.set_fontproperties(mono_font)


# Column widths of the totals table; header and rows must share them.
_TABLE_GRAMS_W = 12
_TABLE_WEIGHT_W = len("weight [kg]")
_TABLE_EVENTS_W = 6


def _fmt_ethanol_totals_table(ev: pd.DataFrame) -> List[str]:
    """
    Returns ONLY the table block (no markdown headings):
      debitor | EtOH [g] | weight [kg] | events
    Expects columns: debitor, date, grams_ethanol, weight_kg
    """
    if ev.empty:
        return ["(no data)"]

    # grams total + events count
    totals = ev.groupby("debitor")["grams_ethanol"].sum().reset_index().sort_values("debitor")
    counts = ev.groupby("debitor").size().rename("events").reset_index()

    # last known weight per debitor (by date)
    w_last = (
        ev.dropna(subset=["weight_kg"])
        .sort_values(["debitor", "date"])
        .groupby("debitor")["weight_kg"]
        .last()
        .rename("weight_kg")
        .reset_index()
    )

    totals = totals.merge(counts, on="debitor", how="left").merge(w_last, on="debitor", how="left")
    totals["events"] = totals["events"].fillna(0).astype(int)

    names = totals["debitor"].astype(str).tolist()
    name_w = max(len("debitor"), max((len(n) for n in names), default=0))

    header = (
        f"{'debitor':<{name_w}} | {'EtOH [g]':>{_TABLE_GRAMS_W}} | "
        f"{'weight [kg]':>{_TABLE_WEIGHT_W}} | {'events':>{_TABLE_EVENTS_W}}"
    )
    lines = [header, "-" * len(header)]

    for _, rec in totals.iterrows():
        grams = float(rec["grams_ethanol"]) if pd.notna(rec["grams_ethanol"]) else 0.0
        weight = float(rec["weight_kg"]) if pd.notna(rec["weight_kg"]) else float("nan")
        if np.isfinite(weight):
            weight_str = f"{weight:>{_TABLE_WEIGHT_W}.1f}"
        else:
            weight_str = f"{'n/a':>{_TABLE_WEIGHT_W}}"
        lines.append(
            f"{str(rec['debitor']):<{name_w}} | {grams:>{_TABLE_GRAMS_W}.2f} | "
            f"{weight_str} | {int(rec['events']):>{_TABLE_EVENTS_W}d}"
        )

    return lines


def _to_finite_float(value: Any, fallback: float) -> float:
    """Convert *value* to a finite float, returning *fallback* otherwise."""
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return fallback
    return number if np.isfinite(number) else fallback


def _warning_lines(warnings: List[str], limit: int = 50) -> List[str]:
    """Format at most *limit* warnings as a summary section (empty if none)."""
    if not warnings:
        return []
    return ["", "Warnings:"] + [f"- {w}" for w in warnings[:limit]]


class DrugModule:
    """Report module for "drug" events; only Ethanol/oral is supported."""

    name = "drug"

    def process(self, df: pd.DataFrame, context: Dict[str, Any]) -> ModuleResult:
        """Build the ethanol totals summary and, optionally, the PDF frames.

        Args:
            df: Ledger frame. Columns read here: ``date``, ``group_flag``,
                ``debitor``, ``modules_list`` and ``params_list``. Only rows
                with a date and group flag "U" are considered.
            context: Options. Keys read: ``want_pdf``,
                ``drug_ethanol_beta_permille_per_h``,
                ``drug_ethanol_absorption_halftime_min``,
                ``drug_default_weight_kg`` and ``drug_default_r``.

        Returns:
            ModuleResult with the markdown summary; frames and bigframes are
            filled only when ``want_pdf`` is true.
        """
        want_pdf = bool(context.get("want_pdf", True))
        default_weight_kg = float(context.get("drug_default_weight_kg", 70.0))
        default_r = float(context.get("drug_default_r", 0.6))
        # NOTE(review): context["drug_profiles"] was read but never used by
        # the original code; per-person profiles are still not applied.

        params = {
            "beta_permille_per_h": float(context.get("drug_ethanol_beta_permille_per_h", 0.15)),
            "absorption_halftime_min": float(context.get("drug_ethanol_absorption_halftime_min", 20.0)),
        }

        # Select with a mask so no column is assigned on a filtered slice.
        flags = df["group_flag"].astype(str).str.strip().str.upper()
        work = df[df["date"].notna() & (flags == "U")]

        warnings: List[str] = []
        events: List[Dict[str, Any]] = []

        for _, row in work.iterrows():
            # Expected layout: (substance, purity, liters, application, weight).
            tup = _extract_module_tuple(row, "drug")
            if tup is None or not hasattr(tup, "__len__") or len(tup) < 5:
                warnings.append(f"Missing/invalid drug params at {row.get('date')} ({row.get('debitor')})")
                continue

            substance = _clean_str(tup[0])
            application = _clean_str(tup[3]).lower()

            if substance != "Ethanol" or application != "oral":
                warnings.append(f"Unsupported: ({substance}, {application}) at {row.get('date')} ({row.get('debitor')})")
                continue

            # Best effort: unparseable or non-finite numbers count as zero
            # intake, and an unusable weight falls back to the default.
            purity = min(1.0, max(0.0, _to_finite_float(tup[1], 0.0)))
            amount_liters = max(0.0, _to_finite_float(tup[2], 0.0))
            weight_kg = _to_finite_float(tup[4], default_weight_kg)
            if weight_kg <= 0:
                weight_kg = default_weight_kg

            grams_ethanol = amount_liters * 1000.0 * EthanolOralModel.ETHANOL_DENSITY_G_PER_ML * purity

            events.append({
                "date": pd.Timestamp(row["date"]),
                "debitor": str(row["debitor"]),
                "substance": "Ethanol",
                "application": "oral",
                "purity": purity,
                "amount_liters": amount_liters,
                "grams_ethanol": grams_ethanol,
                "weight_kg": weight_kg,
            })

        if not events:
            summary = ["# DrugModule", "", "(no supported Ethanol/oral events found)"]
            summary += _warning_lines(warnings)
            return ModuleResult(summary_text="\n".join(summary), frames=[], bigframes=[], pages=[])

        ev = pd.DataFrame(events)
        tmin, tmax = ev["date"].min(), ev["date"].max()
        x0, x1 = _auto_time_limits(tmin, tmax, params=params)

        times = _make_time_grid(x0, x1, target_points=800)

        model = SUBSTANCE_MODELS["Ethanol"]

        participants = sorted(ev["debitor"].unique().tolist())
        bac_by_person: Dict[str, pd.Series] = {}

        for person in participants:
            # Sort by date so "last known weight" matches the totals table.
            person_events = ev[ev["debitor"] == person].sort_values("date")

            known_weights = person_events["weight_kg"].dropna().astype(float)
            weight_kg = float(known_weights.iloc[-1]) if len(known_weights) else default_weight_kg

            profile = {"weight_kg": weight_kg, "r": default_r}
            bac_by_person[person] = model.simulate_bac(
                times, person_events[["date", "grams_ethanol"]], profile, params
            )

        table_lines = _fmt_ethanol_totals_table(ev)

        summary_lines: List[str] = ["# DrugModule", "", "## Totals", ""]
        summary_lines.extend(table_lines)
        # Skipped rows are reported here too instead of being dropped.
        summary_lines.extend(_warning_lines(warnings))
        summary_text = "\n".join(summary_lines)

        frames: List[Frame] = []
        bigframes: List[BigFrame] = []
        pages: List[Figure] = []

        if want_pdf:
            frames.append(TextFrame(
                title="Drug: Ethanol/oral totals",
                text="\n".join(table_lines),
            ))

            bigframes.append(
                BACBigFrame(
                    title="EtOH pharmacokinetic analysis",
                    model_name="Ethanol/oral",
                    times=times,
                    xlim_start=x0,
                    xlim_end=x1,
                    unit_label=model.unit_label,
                    participants=participants,
                    bac_by_person=bac_by_person,
                )
            )

        return ModuleResult(summary_text=summary_text, frames=frames, bigframes=bigframes, pages=pages)
