# modules/drugs.py
"""Drug module: simulates blood concentration curves from logged intake events.

Only the ("Ethanol", "oral") combination is handled at the moment; every other
substance/application pair is reported as a warning and skipped.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.dates as mdates
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.font_manager import FontProperties

from .base import BigFrame, Frame, ModuleResult


class SubstanceModel:
    """Base interface for a substance-specific concentration model."""

    # Lookup key of the substance (e.g. "Ethanol").
    key: str = ""
    # Application routes the model can simulate; subclasses replace this set.
    supported_applications: set[str] = set()
    # Y-axis label used when plotting the simulated series.
    unit_label: str = ""

    def simulate_bac(
        self,
        time_index: pd.DatetimeIndex,
        events: pd.DataFrame,
        profile: Dict[str, float],
        params: Dict[str, float],
    ) -> pd.Series:
        """Return the simulated concentration for every entry of *time_index*.

        Subclasses must override this method.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class EthanolOralModel(SubstanceModel):
    """Oral ethanol: first-order gut absorption plus constant-rate elimination."""

    key = "Ethanol"
    supported_applications = {"oral"}
    # NOTE(review): simulate_bac returns grams per (r * weight_kg), i.e. a
    # per-kg quantity, while this label says g/L -- confirm the intended unit.
    unit_label = "Blood concentration [g/L]"

    ETHANOL_DENSITY_G_PER_ML = 0.789

    def simulate_bac(
        self,
        time_index: pd.DatetimeIndex,
        events: pd.DataFrame,
        profile: Dict[str, float],
        params: Dict[str, float],
    ) -> pd.Series:
        """Simulate the ethanol concentration over *time_index*.

        Args:
            time_index: Evaluation times. The step is taken from the first two
                entries and applied to every step, so the grid is assumed to
                be uniformly spaced.
            events: Frame with the columns ``date`` and ``grams_ethanol``.
            profile: Person data; reads ``weight_kg`` (default 70.0) and ``r``
                (default 0.6, clamped to [0.1, 0.9]). ``weight_kg`` must be
                positive.
            params: Model parameters; reads ``beta_permille_per_h`` (default
                0.15) and ``absorption_halftime_min`` (default 20.0).

        Returns:
            A float series indexed by *time_index*.
        """
        weight_kg = float(profile.get("weight_kg", 70.0))
        r = float(profile.get("r", 0.6))
        r = max(0.1, min(0.9, r))

        beta_permille_per_h = float(params.get("beta_permille_per_h", 0.15))
        t_half_abs_min = float(params.get("absorption_halftime_min", 20.0))
        t_half_abs_h = max(1e-6, t_half_abs_min / 60.0)
        ka = np.log(2.0) / t_half_abs_h  # [1/h]

        # Pull the sorted events into plain lists once, so the time loop does
        # not pay for a DataFrame lookup on every step.
        ev = events.sort_values("date").reset_index(drop=True)
        ev_dates = [pd.Timestamp(d) for d in ev["date"]]
        ev_grams = [float(g) for g in ev["grams_ethanol"]]
        n_events = len(ev_dates)
        ev_idx = 0

        gut_pools: List[float] = []
        body_grams = 0.0
        elim_g_per_h = beta_permille_per_h * (r * weight_kg)

        if len(time_index) >= 2:
            dt_h = (time_index[1] - time_index[0]).total_seconds() / 3600.0
        else:
            # A single time point carries no step information; assume 5 minutes.
            dt_h = 5.0 / 60.0
        dt_h = max(1e-6, dt_h)

        # Fraction of each gut pool absorbed during one step.
        absorb_frac = 1.0 - np.exp(-ka * dt_h)

        out = np.zeros(len(time_index), dtype=float)
        for i, t in enumerate(time_index):
            # Move every event that has happened by time t into the gut.
            while ev_idx < n_events and ev_dates[ev_idx] <= t:
                g = ev_grams[ev_idx]
                if g > 0:
                    gut_pools.append(g)
                ev_idx += 1

            absorbed_total = 0.0
            if gut_pools:
                new_pools = []
                for rem in gut_pools:
                    absorbed = rem * absorb_frac
                    rem2 = rem - absorbed
                    absorbed_total += absorbed
                    # Pools that have shrunk to (almost) nothing are dropped.
                    if rem2 > 1e-6:
                        new_pools.append(rem2)
                gut_pools = new_pools

            body_grams += absorbed_total
            body_grams = max(0.0, body_grams - elim_g_per_h * dt_h)
            out[i] = body_grams / (r * weight_kg)

        return pd.Series(out, index=time_index)


SUBSTANCE_MODELS: Dict[str, SubstanceModel] = {
    "Ethanol": EthanolOralModel(),
}


def _clean_str(x: Any) -> str:
    """Return ``str(x)`` stripped of whitespace and one pair of matching quotes."""
    s = str(x).strip()
    # Require two characters so a lone quote is not mistaken for a quoted
    # (empty) string.
    if len(s) >= 2 and s[0] == s[-1] and s[0] in ("'", '"'):
        s = s[1:-1].strip()
    return s


def _to_float(value: Any, default: float) -> float:
    """Convert *value* to float, returning *default* when it cannot be converted."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _auto_time_limits(
    tmin: pd.Timestamp,
    tmax: pd.Timestamp,
    params: Optional[Dict[str, float]] = None,
) -> tuple[pd.Timestamp, pd.Timestamp]:
    """Compute plot limits around the event span ``[tmin, tmax]``.

    A 5 % margin is added on both sides (30 minutes when the span is empty).
    When *params* is given, the right limit is extended so that it lies at
    least ``tail_h`` hours after *tmax*, where ``tail_h`` is derived from the
    absorption half-time and the elimination rate and kept within 6-24 hours.
    """
    dt = tmax - tmin
    if dt <= pd.Timedelta(0):
        margin = pd.Timedelta(minutes=30)
    else:
        margin = dt * 0.05

    x0 = tmin - margin
    x1 = tmax + margin

    if params:
        beta = float(params.get("beta_permille_per_h", 0.15))
        beta = max(1e-6, beta)
        t_half_abs_min = float(params.get("absorption_halftime_min", 20.0))
        t_half_abs_h = max(1e-6, t_half_abs_min / 60.0)
        tau_abs_h = t_half_abs_h / np.log(2.0)

        tail_h = max(6.0, 8.0 * tau_abs_h, 2.0 * (1.0 / beta))
        tail_h = min(tail_h, 24.0)

        x1_tail = tmax + pd.Timedelta(hours=tail_h)
        if x1_tail > x1:
            x1 = x1_tail

    return x0, x1


def _extract_module_tuple(row: pd.Series, module_name: str) -> Optional[Tuple]:
    """Return the parameter tuple stored for *module_name* in *row*.

    ``modules_list`` and ``params_list`` are read as parallel sequences; the
    result is ``None`` when the module is not listed or has no matching entry.
    """
    mods = row.get("modules_list", []) or []
    params = row.get("params_list", []) or []
    try:
        idx = list(mods).index(module_name)
    except ValueError:
        return None
    return params[idx] if idx < len(params) else None


def _make_time_grid(
    start: pd.Timestamp,
    end: pd.Timestamp,
    target_points: int = 800,
) -> pd.DatetimeIndex:
    """Build a uniform time grid from *start* to *end*.

    The step aims at roughly *target_points* samples, is rounded to whole
    minutes and kept between 1 and 30 minutes. A non-positive span yields a
    single-point index.
    """
    if end <= start:
        return pd.DatetimeIndex([start])
    total_s = (end - start).total_seconds()
    step_s = max(60.0, min(30 * 60.0, total_s / float(max(10, target_points))))
    step_min = int(max(1, round(step_s / 60.0)))
    return pd.date_range(start=start, end=end, freq=f"{step_min}min")


@dataclass
class TextFrame(Frame):
    """Frame that draws a block of preformatted text."""

    text: str

    def render(self, ax: Axes, mono_font: FontProperties) -> None:
        """Draw the text anchored at the top-left corner of *ax*."""
        ax.text(0, 1, self.text, va="top", ha="left", fontproperties=mono_font)


@dataclass
class BACBigFrame(BigFrame):
    """Large frame plotting one concentration curve per participant."""

    model_name: str
    times: pd.DatetimeIndex
    xlim_start: pd.Timestamp
    xlim_end: pd.Timestamp
    unit_label: str
    participants: List[str]
    bac_by_person: Dict[str, pd.Series]

    def render(self, ax: Axes, mono_font: FontProperties) -> None:
        """Plot every participant's curve with a shaded area below it."""
        ax.axis("on")

        locator = mdates.AutoDateLocator(minticks=3, maxticks=7)
        formatter = mdates.ConciseDateFormatter(locator)
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(formatter)
        ax.xaxis.get_offset_text().set_visible(False)

        ax.set_xlim(self.xlim_start, self.xlim_end)
        ax.set_ylabel(self.unit_label, fontproperties=mono_font)

        all_vals = []
        for p in self.participants:
            # Infinite and missing samples are plotted as zero.
            y = (
                self.bac_by_person[p]
                .copy()
                .replace([np.inf, -np.inf], np.nan)
                .fillna(0.0)
            )
            line, = ax.plot(self.times, y.values, label=p, linewidth=1)
            ax.fill_between(
                self.times,
                y.values,
                0.0,
                alpha=0.18,
                color=line.get_color(),
                zorder=line.get_zorder() - 1,
            )
            v = y.values
            v = v[np.isfinite(v)]
            if v.size:
                all_vals.append(v)

        if all_vals:
            vv = np.concatenate(all_vals)
            vmax = float(np.nanmax(vv)) if vv.size else 0.0
            # Leave 10 % headroom; use a fixed range when every curve is flat.
            ax.set_ylim(0.0, 0.5 if vmax <= 0 else vmax * 1.10)

        ax.grid(True, alpha=0.2)

        leg = ax.legend(prop=mono_font, fontsize=7, loc="best", ncols=2)
        if leg:
            for t in leg.get_texts():
                t.set_fontproperties(mono_font)

        for tick in ax.get_xticklabels() + ax.get_yticklabels():
            tick.set_fontproperties(mono_font)
        ax.xaxis.label.set_fontproperties(mono_font)
        ax.yaxis.label.set_fontproperties(mono_font)


def _fmt_ethanol_totals_table(ev: pd.DataFrame) -> List[str]:
    """
    Returns ONLY the table block (no markdown headings):
    debitor | grams EtOH | weight kg | events

    Expects columns: debitor, date, grams_ethanol, weight_kg
    """
    if ev.empty:
        return ["(no data)"]

    # grams total + events count
    totals = (
        ev.groupby("debitor")["grams_ethanol"]
        .sum()
        .reset_index()
        .sort_values("debitor")
    )
    counts = ev.groupby("debitor").size().rename("events").reset_index()

    # last known weight per debitor (by date)
    w_last = (
        ev.dropna(subset=["weight_kg"])
        .sort_values(["debitor", "date"])
        .groupby("debitor")["weight_kg"]
        .last()
        .rename("weight_kg")
        .reset_index()
    )

    totals = totals.merge(counts, on="debitor", how="left").merge(
        w_last, on="debitor", how="left"
    )
    totals["events"] = totals["events"].fillna(0).astype(int)

    # formatting widths
    name_list = totals["debitor"].astype(str).tolist()
    name_w = max([len(x) for x in name_list] + [7])
    # The weight column is sized to its header so header and rows line up.
    weight_w = len("weight [kg]")

    header = (
        f"{'debitor':<{name_w}} | {'EtOH [g]':>12} | "
        f"{'weight [kg]':>{weight_w}} | {'events':>6}"
    )
    sep = "-" * len(header)

    lines = [header, sep]
    for _, r in totals.iterrows():
        deb = str(r["debitor"])
        grams = float(r["grams_ethanol"]) if pd.notna(r["grams_ethanol"]) else 0.0
        w = float(r["weight_kg"]) if pd.notna(r["weight_kg"]) else float("nan")
        events = int(r["events"])

        w_str = f"{w:>{weight_w}.1f}" if np.isfinite(w) else f"{'n/a':>{weight_w}}"
        lines.append(f"{deb:<{name_w}} | {grams:>12.2f} | {w_str} | {events:>6d}")

    return lines


class DrugModule:
    """Module that turns "drug" entries into totals and concentration curves."""

    name = "drug"

    def process(self, df: pd.DataFrame, context: Dict[str, Any]) -> ModuleResult:
        """Evaluate all supported drug events in *df*.

        Rows are used when ``date`` is set and ``group_flag`` equals "U"
        (case-insensitive, surrounding whitespace ignored). The "drug"
        parameter tuple of a row is read positionally as
        ``(substance, purity, amount_liters, application, weight_kg)``.

        Context keys read: ``want_pdf``, ``drug_ethanol_beta_permille_per_h``,
        ``drug_ethanol_absorption_halftime_min``, ``drug_default_weight_kg``
        and ``drug_default_r``.

        Returns:
            A ModuleResult with the text summary (totals table plus any
            warnings) and, when ``want_pdf`` is truthy, a totals text frame
            and a concentration plot frame.
        """
        want_pdf = bool(context.get("want_pdf", True))
        # NOTE(review): the context may also carry "mono_font" and
        # "drug_profiles"; neither is consumed here -- confirm whether
        # per-person profiles were meant to override the defaults below.
        params = {
            "beta_permille_per_h": float(
                context.get("drug_ethanol_beta_permille_per_h", 0.15)
            ),
            "absorption_halftime_min": float(
                context.get("drug_ethanol_absorption_halftime_min", 20.0)
            ),
        }
        default_weight_kg = float(context.get("drug_default_weight_kg", 70.0))

        # Copy after filtering so the column assignment below writes to an
        # independent frame rather than a slice.
        work = df[pd.notna(df["date"])].copy()
        work["flag"] = work["group_flag"].astype(str).str.strip().str.upper()
        work = work[work["flag"] == "U"]

        warnings: List[str] = []
        events: List[Dict[str, Any]] = []

        for _, row in work.iterrows():
            tup = _extract_module_tuple(row, "drug")
            if not tup or len(tup) < 5:
                warnings.append(
                    f"Missing/invalid drug params at {row.get('date')} ({row.get('debitor')})"
                )
                continue

            substance = _clean_str(tup[0])
            application = _clean_str(tup[3]).lower()
            if substance != "Ethanol" or application != "oral":
                warnings.append(
                    f"Unsupported: ({substance}, {application}) at {row.get('date')} ({row.get('debitor')})"
                )
                continue

            # Unparseable purity/amount count as zero intake; an unusable
            # weight falls back to the configured default.
            purity = float(np.clip(_to_float(tup[1], 0.0), 0.0, 1.0))
            amount_liters = max(0.0, _to_float(tup[2], 0.0))
            weight_kg = _to_float(tup[4], default_weight_kg)
            if not np.isfinite(weight_kg) or weight_kg <= 0:
                weight_kg = default_weight_kg

            grams_ethanol = (
                amount_liters
                * 1000.0
                * EthanolOralModel.ETHANOL_DENSITY_G_PER_ML
                * purity
            )

            events.append({
                "date": pd.Timestamp(row["date"]),
                "debitor": str(row["debitor"]),
                "substance": "Ethanol",
                "application": "oral",
                "purity": purity,
                "amount_liters": amount_liters,
                "grams_ethanol": grams_ethanol,
                "weight_kg": weight_kg,
            })

        if not events:
            summary = ["# DrugModule", "", "(no supported Ethanol/oral events found)"]
            if warnings:
                summary += ["", "Warnings:"] + [f"- {w}" for w in warnings[:50]]
            return ModuleResult(
                summary_text="\n".join(summary), frames=[], bigframes=[], pages=[]
            )

        # Stable sort by date so "last weight" means the chronologically last
        # one, matching what the totals table reports.
        ev = (
            pd.DataFrame(events)
            .sort_values("date", kind="stable")
            .reset_index(drop=True)
        )

        tmin, tmax = ev["date"].min(), ev["date"].max()
        x0, x1 = _auto_time_limits(tmin, tmax, params=params)
        times = _make_time_grid(x0, x1, target_points=800)

        model = SUBSTANCE_MODELS["Ethanol"]
        default_r = float(context.get("drug_default_r", 0.6))

        participants = sorted(ev["debitor"].unique().tolist())
        bac_by_person: Dict[str, pd.Series] = {}

        for p in participants:
            pe = ev[ev["debitor"] == p].copy()
            w_series = pe["weight_kg"].dropna().astype(float)
            weight_kg = float(w_series.iloc[-1]) if len(w_series) else default_weight_kg

            profile = {
                "weight_kg": weight_kg,
                "r": default_r,
            }
            bac_by_person[p] = model.simulate_bac(
                times, pe[["date", "grams_ethanol"]], profile, params
            )

        # The table is built once and shared by the summary and the frame.
        table_lines = _fmt_ethanol_totals_table(ev)

        summary_lines: List[str] = ["# DrugModule", "", "## Totals", ""]
        summary_lines.extend(table_lines)
        if warnings:
            # Skipped rows are reported here too, not only when nothing
            # usable was found.
            summary_lines += ["", "## Warnings", ""]
            summary_lines += [f"- {w}" for w in warnings[:50]]
        summary_text = "\n".join(summary_lines)

        frames: List[Frame] = []
        bigframes: List[BigFrame] = []
        pages: List[Figure] = []

        if want_pdf:
            frames.append(TextFrame(
                title="Drug: Ethanol/oral totals",
                text="\n".join(table_lines),
            ))

            bigframes.append(
                BACBigFrame(
                    title="EtOH pharmacokinetic analysis",
                    model_name="Ethanol/oral",
                    times=times,
                    xlim_start=x0,
                    xlim_end=x1,
                    unit_label=model.unit_label,
                    participants=participants,
                    bac_by_person=bac_by_person,
                )
            )

        return ModuleResult(
            summary_text=summary_text, frames=frames, bigframes=bigframes, pages=pages
        )