#!/usr/bin/env python3
"""
Claude Code session analyser.

Usage:
    analyse-session <session.jsonl>
    analyse-session <session.jsonl> --save
    analyse-session <session.jsonl> --save report.md
    analyse-session <session.jsonl> --no-turns   # skip per-turn table
"""

import argparse
import json
import sys
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

# ---------------------------------------------------------------------------
# Pricing (per million tokens). Update these when Anthropic changes rates.
# ---------------------------------------------------------------------------
PRICING = {
    # Per-model rates in US dollars per million tokens, keyed by model id.
    "claude-sonnet-4-6": dict(input=3.00, cache_create=3.75, cache_read=0.30, output=15.00),
    "claude-opus-4-8": dict(input=5.00, cache_create=6.25, cache_read=0.5, output=25.00),
}

# Rates used for any model id missing from PRICING (identical to the Sonnet rates).
FALLBACK_PRICING = dict(PRICING["claude-sonnet-4-6"])


# ---------------------------------------------------------------------------
# Parsing
# ---------------------------------------------------------------------------

def load_entries(path: Path) -> list[dict]:
    """Read a JSONL session file and return the JSON objects it contains.

    Reading is deliberately best-effort: blank lines, lines that are not valid
    JSON (e.g. a truncated line), and lines whose JSON value is not an object
    are skipped rather than aborting the whole report.

    Args:
        path: Path to the ``.jsonl`` session file.

    Returns:
        The decoded JSON objects, in file order.
    """
    entries: list[dict] = []
    # Explicit UTF-8 so parsing does not depend on the platform's default
    # encoding (e.g. cp1252 on Windows would fail on non-ASCII content).
    with path.open(encoding="utf-8") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                obj = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Downstream code calls .get() on every entry, so only objects are kept.
            if isinstance(obj, dict):
                entries.append(obj)
    return entries


def _parse_ts(ts: str) -> datetime:
    """Parse an ISO-8601 timestamp, accepting a trailing ``Z`` for UTC.

    ``Z`` is rewritten to ``+00:00`` because ``datetime.fromisoformat`` only
    accepts ``Z`` from Python 3.11 onwards.
    """
    return datetime.fromisoformat(ts.replace("Z", "+00:00"))


def _is_human_prompt(entry: dict) -> bool:
    """Return True if an entry is marked human-typed (origin.kind or promptSource)."""
    origin = entry.get("origin")
    # origin may be absent or null; only a dict can carry a "kind".
    kind = origin.get("kind", "") if isinstance(origin, dict) else ""
    return kind == "human" or entry.get("promptSource", "") == "typed"


def _message_text(content) -> str:
    """Flatten message content (a string or a list of blocks) to plain text."""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    return " ".join(
        b.get("text") or ""
        for b in content
        if isinstance(b, dict) and b.get("type") == "text"
    )


def _idle_seconds(entries: list[dict], ended_at: Optional[datetime]) -> int:
    """Sum the seconds spent waiting on the user, walking entries in file order.

    An idle period starts when an assistant entry stops with ``end_turn`` and
    ends at the next human-typed user entry. Each period is clipped to end no
    later than *ended_at* (the last assistant timestamp), so a trailing prompt
    that never got a reply cannot push idle time beyond the measured wall time.
    Negative gaps (out-of-order timestamps) count as zero.
    """
    idle_s = 0
    last_end_turn_ts: Optional[datetime] = None
    for entry in entries:
        ts = entry.get("timestamp")
        if not ts:
            continue
        kind = entry.get("type")
        if kind == "assistant":
            if (entry.get("message") or {}).get("stop_reason") == "end_turn":
                last_end_turn_ts = _parse_ts(ts)
        elif kind == "user" and last_end_turn_ts is not None and _is_human_prompt(entry):
            resumed_at = _parse_ts(ts)
            if ended_at is not None:
                resumed_at = min(resumed_at, ended_at)
            idle_s += max(0, int((resumed_at - last_end_turn_ts).total_seconds()))
            last_end_turn_ts = None
    return idle_s


def _first_human_prompt(entries: list[dict]) -> Optional[str]:
    """Return the stripped text of the first non-empty human-typed prompt, or None.

    ``isMeta`` entries are skipped. Text that begins with markup such as a
    ``<...>`` wrapper is not filtered out.
    """
    for entry in entries:
        if entry.get("type") != "user" or entry.get("isMeta"):
            continue
        if not _is_human_prompt(entry):
            continue
        text = _message_text((entry.get("message") or {}).get("content", "")).strip()
        if text:
            return text
    return None


def parse_session(entries: list[dict]) -> dict:
    """Aggregate token usage, tool calls and timing from session entries.

    Args:
        entries: Decoded JSONL entries, in file order.

    Returns:
        A dict with keys:
          ``turns`` - one record per unique assistant entry, sorted by timestamp;
          ``tool_tally`` - tool name -> number of ``tool_use`` blocks;
          ``model_totals`` - model -> summed input/cache_create/cache_read/output
          tokens plus ``calls``, including advisor sub-calls;
          ``advisor_calls`` - advisor iterations (at most one per parent entry);
          ``started_at`` / ``ended_at`` - earliest / latest assistant timestamp;
          ``duration_s`` - whole seconds between them;
          ``active_s`` - ``duration_s`` minus time spent waiting on the user;
          ``initial_prompt`` - text of the first human-typed prompt, or None.
        The four timing values are None when no assistant entry has a timestamp.
    """
    seen_uuids: set[str] = set()
    turns: list[dict] = []
    tool_tally: dict[str, int] = defaultdict(int)
    model_totals: dict[str, dict] = defaultdict(lambda: defaultdict(int))
    advisor_raw: list[dict] = []

    for entry in entries:
        if entry.get("type") != "assistant":
            continue
        # NOTE(review): de-duplication is by entry uuid only. Entries lacking a
        # uuid all share "" (only the first is kept), and if one API response
        # were ever logged as several entries with distinct uuids its usage
        # would be counted more than once - confirm against the log format.
        uuid = entry.get("uuid", "")
        if uuid in seen_uuids:
            continue
        seen_uuids.add(uuid)

        msg = entry.get("message") or {}
        usage = msg.get("usage") or {}
        model = msg.get("model", "unknown")

        # Content may be a plain string; only a list of blocks can hold tool_use.
        content = msg.get("content")
        blocks = content if isinstance(content, list) else []
        tools_used = [
            b.get("name", "unknown")
            for b in blocks
            if isinstance(b, dict) and b.get("type") == "tool_use"
        ]
        for name in tools_used:
            tool_tally[name] += 1

        counts = {
            "input": usage.get("input_tokens") or 0,
            "cache_create": usage.get("cache_creation_input_tokens") or 0,
            "cache_read": usage.get("cache_read_input_tokens") or 0,
            "output": usage.get("output_tokens") or 0,
        }
        totals = model_totals[model]
        for key, value in counts.items():
            totals[key] += value
        totals["calls"] += 1

        turns.append(
            {
                "ts": entry.get("timestamp"),
                "model": model,
                **counts,
                "tools": tools_used,
                "stop_reason": msg.get("stop_reason", ""),
            }
        )

        # Advisor sub-calls are reported inside usage.iterations.
        for it in usage.get("iterations") or []:
            if isinstance(it, dict) and it.get("type") == "advisor_message":
                advisor_raw.append(
                    {
                        "parent_uuid": uuid,
                        "model": it.get("model", "unknown"),
                        "input": it.get("input_tokens") or 0,
                        "cache_create": it.get("cache_creation_input_tokens") or 0,
                        "cache_read": it.get("cache_read_input_tokens") or 0,
                        "output": it.get("output_tokens") or 0,
                    }
                )

    # Keep one advisor call per parent entry. NOTE(review): parents are already
    # unique by uuid, so this keeps only the first advisor iteration of each
    # message - confirm that later iterations are always repeats.
    seen_parents: set[str] = set()
    advisor_calls: list[dict] = []
    for call in advisor_raw:
        if call["parent_uuid"] in seen_parents:
            continue
        seen_parents.add(call["parent_uuid"])
        advisor_calls.append(call)
        totals = model_totals[call["model"]]
        for key in ("input", "cache_create", "cache_read", "output"):
            totals[key] += call[key]
        totals["calls"] += 1

    turns.sort(key=lambda t: t["ts"] or "")

    started_at = ended_at = duration_s = active_s = None
    stamps = [_parse_ts(t["ts"]) for t in turns if t["ts"]]
    if stamps:
        started_at, ended_at = min(stamps), max(stamps)
        duration_s = int((ended_at - started_at).total_seconds())
        active_s = duration_s - _idle_seconds(entries, ended_at)

    return {
        "turns": turns,
        "tool_tally": dict(tool_tally),
        "model_totals": dict(model_totals),
        "advisor_calls": advisor_calls,
        "started_at": started_at,
        "ended_at": ended_at,
        "duration_s": duration_s,
        "active_s": active_s,
        "initial_prompt": _first_human_prompt(entries),
    }


# ---------------------------------------------------------------------------
# Cost helpers
# ---------------------------------------------------------------------------

def cost(model: str, kind: str, tokens: int) -> float:
    """Return the dollar cost of *tokens* tokens of category *kind* for *model*.

    Rates come from PRICING (dollars per million tokens). Models without an
    entry use FALLBACK_PRICING, and an unrecognised *kind* is priced at zero.
    """
    model_rates = PRICING.get(model, FALLBACK_PRICING)
    per_million = model_rates.get(kind, 0)
    millions = tokens / 1_000_000
    return millions * per_million


def model_cost(model: str, data: dict) -> float:
    """Return the total dollar cost of one model's token counts.

    *data* must provide ``input``, ``cache_create``, ``cache_read`` and
    ``output`` token counts; each is priced with :func:`cost` and summed.
    """
    categories = ("input", "cache_create", "cache_read", "output")
    return sum(cost(model, category, data[category]) for category in categories)


# ---------------------------------------------------------------------------
# Formatting
# ---------------------------------------------------------------------------

def fmt_duration(s: int) -> str:
    """Format a number of seconds as ``"Xh Ym Zs"``, omitting leading zero units.

    Examples: 59 -> "59s", 61 -> "1m 1s", 3600 -> "1h 0m 0s".
    A negative value is formatted from its magnitude with a leading "-"
    (plain ``divmod`` on a negative number would yield e.g. "-1h 56m 50s").
    """
    if s < 0:
        return "-" + fmt_duration(-s)
    hours, remainder = divmod(s, 3600)
    minutes, seconds = divmod(remainder, 60)
    if hours:
        return f"{hours}h {minutes}m {seconds}s"
    if minutes:
        return f"{minutes}m {seconds}s"
    return f"{seconds}s"


def fmt_ts(dt: Optional[datetime]) -> str:
    """Format a datetime as ``"YYYY-MM-DD HH:MM UTC"``, or "—" for None.

    Timezone-aware values are converted to UTC first so that the "UTC" label
    is accurate. Naive values are printed as-is; this assumes they are
    already UTC.
    """
    if dt is None:
        return "—"
    if dt.tzinfo is not None:
        dt = dt.astimezone(timezone.utc)
    return dt.strftime("%Y-%m-%d %H:%M UTC")


def fmt_k(n: int) -> str:
    """Abbreviate a token count as "~12K" / "~3M"; values under 1,000 print exactly.

    Both units round to the nearest whole number (previously millions were
    truncated while thousands were rounded). A count that would round up to
    "~1000K" is shown in millions instead.
    """
    if n < 1_000:
        return str(n)
    thousands = round(n / 1_000)
    if thousands < 1_000:
        return f"~{thousands}K"
    return f"~{round(n / 1_000_000)}M"


def build_box_table(header: tuple, rows: list) -> str:
    """Draw *header* and *rows* as a Unicode box table with a rule between rows.

    Each column is as wide as its longest cell (after ``str()``) plus two
    characters: one leading space and at least one trailing space. Any number
    of columns is supported; every row must have at least ``len(header)`` cells.
    """
    every_row = [header] + rows
    col_widths = [
        2 + max(len(str(r[col])) for r in every_row)
        for col in range(len(header))
    ]

    def rule(left: str, joint: str, right: str) -> str:
        return left + joint.join("─" * width for width in col_widths) + right

    def render(row) -> str:
        cells = [
            " " + str(row[col]).ljust(width - 1)
            for col, width in enumerate(col_widths)
        ]
        return "│" + "│".join(cells) + "│"

    separator = rule("├", "┼", "┤")
    parts = [rule("┌", "┬", "┐"), render(header)]
    for row in rows:
        parts.extend((separator, render(row)))
    parts.append(rule("└", "┴", "┘"))
    return "\n".join(parts)


def _short_model_name(model: str) -> str:
    """Return a display name for a model id, e.g. "claude-sonnet-4-6" -> "Sonnet".

    Ids without a second dash-separated part (such as the "unknown" placeholder)
    are returned unchanged instead of raising IndexError.
    """
    parts = model.split("-")
    return parts[1].capitalize() if len(parts) > 1 else model


def _fenced(table: str) -> list[str]:
    """Wrap a rendered table in a markdown code fence followed by a blank line."""
    return ["```", table, "```", ""]


def _model_sections(model_totals: dict, advisor_models: set) -> tuple[list[str], float]:
    """Render one token/cost table per model; also return the summed cost."""
    lines: list[str] = []
    grand_total = 0.0
    for model, d in sorted(model_totals.items()):
        grand_total += model_cost(model, d)
        rates = PRICING.get(model, FALLBACK_PRICING)
        label = f"{model}" + (" *(advisor)*" if model in advisor_models else "")
        token_rows = [
            ("Input (uncached)", d["input"], rates["input"]),
            ("Cache creation", d["cache_create"], rates["cache_create"]),
            ("Cache read", d["cache_read"], rates["cache_read"]),
            ("Output", d["output"], rates["output"]),
        ]
        subtotal = sum((tok / 1_000_000) * rate for _, tok, rate in token_rows)
        tbl_rows = [
            (lbl, f"{tok:,}", f"${rate:.2f}/MTok", f"${(tok / 1_000_000) * rate:.4f}")
            for lbl, tok, rate in token_rows
        ] + [("Subtotal", "", "", f"${subtotal:.4f}")]
        lines += [f"## {label}", "", f"**API calls:** {d['calls']:,}  ", ""]
        lines += _fenced(build_box_table(("Token type", "Count", "Rate", "Cost"), tbl_rows))
    return lines, grand_total


def _caching_section(turns: list[dict], model_totals: dict, advisor_calls: list[dict]) -> list[str]:
    """Compare the main model's cache-read rate with the advisor's uncached rate.

    Returns no lines unless the session has advisor calls and at least one
    non-advisor model.
    """
    advisor_models = {c["model"] for c in advisor_calls}
    main_models = [m for m in model_totals if m not in advisor_models]
    if not advisor_models or not main_models:
        return []

    # Pick the most-called model on each side. max() keeps the first of equal
    # candidates; sorting the advisor set makes that choice deterministic.
    main_model = max(main_models, key=lambda m: model_totals[m]["calls"])
    advisor_model = max(
        sorted(advisor_models),
        key=lambda m: model_totals.get(m, {}).get("calls", 0),
    )
    main_rates = PRICING.get(main_model, FALLBACK_PRICING)
    adv_rates = PRICING.get(advisor_model, FALLBACK_PRICING)

    # Peak context: last non-zero cache_read of the main model. This assumes
    # turns are in time order and the cached prefix only grows.
    main_reads = [
        t["cache_read"] for t in turns if t["model"] == main_model and t["cache_read"] > 0
    ]
    main_ctx = main_reads[-1] if main_reads else 0

    # Average uncached input per call of the chosen advisor model.
    adv_inputs = [c["input"] for c in advisor_calls if c["model"] == advisor_model]
    adv_ctx = sum(adv_inputs) // len(adv_inputs) if adv_inputs else 0

    main_rate = main_rates["cache_read"]
    adv_rate = adv_rates["input"]
    ratio = int(round(adv_rate / main_rate)) if main_rate else 0

    tbl = build_box_table(
        (
            "",
            f"{_short_model_name(main_model)} (cached)",
            f"{_short_model_name(advisor_model)} (uncached)",
        ),
        [
            ("Context tokens", fmt_k(main_ctx), fmt_k(adv_ctx)),
            ("Rate paid", f"${main_rate:.2f}/MTok", f"${adv_rate:.2f}/MTok"),
            ("Effective cost ratio", "1×", f"{ratio}×"),
        ],
    )
    return ["## Caching Efficiency", ""] + _fenced(tbl)


def _tool_section(tool_tally: dict) -> list[str]:
    """Render the tool-usage table (most-used first), or nothing if no tools ran."""
    if not tool_tally:
        return []
    total = sum(tool_tally.values())
    rows = [
        (tool, str(count), f"{(count / total) * 100:.0f}%")
        for tool, count in sorted(tool_tally.items(), key=lambda x: -x[1])
    ]
    return ["## Tool Calls", ""] + _fenced(build_box_table(("Tool", "Calls", "%"), rows))


def _turn_section(turns: list[dict]) -> list[str]:
    """Render one row per assistant turn with the seconds elapsed since the previous one."""
    if not turns:
        return []
    prev_ts = None
    rows = []
    for i, t in enumerate(turns, start=1):
        if t["ts"]:
            dt = datetime.fromisoformat(t["ts"].replace("Z", "+00:00"))
            elapsed = f"+{int((dt - prev_ts).total_seconds())}s" if prev_ts is not None else "start"
            prev_ts = dt
        else:
            elapsed = "?"
        action = ", ".join(t["tools"]) if t["tools"] else t["stop_reason"]
        rows.append((
            str(i), elapsed,
            f"{t['input']:,}", f"{t['cache_create']:,}", f"{t['cache_read']:,}", f"{t['output']:,}",
            action,
        ))
    header = ("#", "+time", "In", "Cache create", "Cache read", "Out", "Action")
    return ["## Per-Turn Breakdown", ""] + _fenced(build_box_table(header, rows))


def build_report(session_path: Path, data: dict, include_turns: bool = True) -> str:
    """Build the markdown session report.

    Args:
        session_path: The session file; only its name is shown.
        data: The dict produced by ``parse_session``.
        include_turns: Whether to append the per-turn breakdown table.

    Returns:
        The report as a markdown string. Tables are box-drawn inside code fences.
    """
    advisor_models = {c["model"] for c in data["advisor_calls"]}

    lines: list[str] = ["# Claude Code Session Report", "", f"**File:** `{session_path.name}`  "]
    # Collapse whitespace so a multi-line prompt stays on its header line and
    # cannot inject headings or a bare ``` fence line into the report.
    prompt = " ".join((data.get("initial_prompt") or "").split())
    if prompt:
        lines.append(f"**Prompt:** {prompt}  ")
    lines.append(f"**Started:** {fmt_ts(data['started_at'])}  ")
    lines.append(f"**Ended:** {fmt_ts(data['ended_at'])}  ")
    if data["duration_s"] is not None:
        lines.append(f"**Wall time:** {fmt_duration(data['duration_s'])}  ")
    if data["active_s"] is not None:
        lines.append(
            f"**Active duration:** {fmt_duration(data['active_s'])} *(wall time minus user input wait)*  "
        )
    lines.append("")

    model_lines, grand_total_cost = _model_sections(data["model_totals"], advisor_models)
    lines += model_lines
    lines += ["## Grand Total", "", f"**~${grand_total_cost:.4f}** across all models  ", ""]

    lines += _caching_section(data["turns"], data["model_totals"], data["advisor_calls"])
    lines += _tool_section(data["tool_tally"])
    if include_turns:
        lines += _turn_section(data["turns"])

    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Terminal output (strip markdown formatting for readability)
# ---------------------------------------------------------------------------

def print_report(report: str) -> None:
    """Print a markdown report to stdout with its markdown syntax removed.

    Lines inside ``` fences are printed verbatim. "# " headings become banners
    and "## " headings become "--- title ---". Other lines lose their bold
    markers, their italic markers and their inline-code backticks. Markdown
    pipe tables are laid out in fixed-width columns, with separator rows dropped.
    """
    import re

    bold = re.compile(r"\*\*(.+?)\*\*")
    # Single-asterisk emphasis only when it hugs non-space text, so literal
    # asterisks such as "2 * 3" are left alone.
    italic = re.compile(r"\*(?=\S)(.+?)(?<=\S)\*")

    def plain(text: str) -> str:
        return italic.sub(r"\1", bold.sub(r"\1", text)).replace("`", "")

    in_fence = False
    for line in report.splitlines():
        if line.strip() == "```":
            in_fence = not in_fence
            continue
        if in_fence:
            print(line)
        elif line.startswith("# "):
            print("\n" + "=" * 60)
            print(line[2:].upper())
            print("=" * 60)
        elif line.startswith("## "):
            print("\n--- " + plain(line[3:]) + " ---")
        elif line.startswith("**") and line.endswith("  "):
            # "**Field:** value  " lines: drop markup and the trailing line break.
            print(plain(line).rstrip())
        elif line.startswith("|"):
            if not line.startswith("|---"):
                cells = [c.strip().replace("`", "").replace("*", "") for c in line.strip("|").split("|")]
                print("  " + "  ".join(f"{c:<22}" for c in cells))
        else:
            print(plain(line))


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main() -> None:
    """Command-line entry point: print a session report and optionally save it.

    Exits with status 1 when the session path is not an existing regular file
    or when the report cannot be written.
    """
    parser = argparse.ArgumentParser(
        description="Analyse a Claude Code session JSONL file and report token spend."
    )
    parser.add_argument("file", type=Path, help="Path to the session .jsonl file")
    parser.add_argument(
        "--save",
        nargs="?",
        const=True,
        metavar="OUTPUT.md",
        help="Save report as markdown. Omit filename to auto-name from session file.",
    )
    parser.add_argument(
        "--no-turns",
        action="store_true",
        help="Omit the per-turn breakdown table (useful for long sessions).",
    )
    args = parser.parse_args()

    path = args.file.expanduser().resolve()
    # is_file() rather than exists(): a directory would otherwise pass this
    # check and then fail with an unhandled IsADirectoryError.
    if not path.is_file():
        print(f"Error: file not found: {path}", file=sys.stderr)
        sys.exit(1)

    entries = load_entries(path)
    session = parse_session(entries)
    report = build_report(path, session, include_turns=not args.no_turns)

    print_report(report)

    if args.save:
        # A bare --save yields the const True: name the report after the session file.
        if args.save is True:
            out_path = path.with_suffix(".md")
        else:
            out_path = Path(args.save).expanduser().resolve()
        try:
            # Explicit UTF-8: the report contains box-drawing characters, "—" and
            # "×", which a non-UTF-8 default encoding (e.g. cp1252) cannot encode.
            out_path.write_text(report, encoding="utf-8")
        except OSError as err:
            print(f"Error: could not write report to {out_path}: {err}", file=sys.stderr)
            sys.exit(1)
        print(f"\nReport saved to: {out_path}")

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
