#!/usr/bin/env python3
"""
PowerSearch Grid Edge Agent (MVP).

Design goals:
- Explicit opt-in resource sharing (policy.enabled defaults to false).
- Pull-based scheduling (agent polls the control-plane; no peer push).
- Signed, typed job manifests (Ed25519) and whitelisted handlers only.
- User-first throttling: avoid doing work during busy periods / quiet hours.
"""

from __future__ import annotations

import argparse
import base64
import hashlib
import ipaddress
import json
import os
import platform
import re
import shutil
import socket
import subprocess
import sys
import time
import urllib.robotparser
from datetime import datetime, timezone
from html.parser import HTMLParser
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.parse import urljoin, urlparse
from urllib.request import HTTPRedirectHandler, Request, build_opener, urlopen

try:
    from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey  # type: ignore
except Exception:  # pragma: no cover
    Ed25519PublicKey = None

AGENT_VERSION = "0.2.6"
DEFAULT_USER_AGENT = f"PowerSearchGridEdgeAgent/{AGENT_VERSION}"


def public_key_fingerprint_sha256(public_key_b64: str) -> str:
    """Compute sha256 fingerprint of a raw Ed25519 public key (base64 encoded)."""
    s = str(public_key_b64 or "").strip()
    if not s:
        return ""
    try:
        raw = base64.b64decode(s.encode("utf-8"), validate=True)
    except Exception:
        return ""
    try:
        return hashlib.sha256(raw).hexdigest()
    except Exception:
        return ""


def now_utc_iso() -> str:
    return datetime.now(timezone.utc).isoformat()


def canonical_json(obj: Any) -> bytes:
    """Serialize with sorted keys and no whitespace so both ends produce
    byte-identical input for Ed25519 signing/verification."""
    return json.dumps(obj, sort_keys=True, separators=(",", ":"), ensure_ascii=False).encode("utf-8")


def _read_json(path: str) -> dict[str, Any]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _write_json(path: str, obj: Any) -> None:
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2, sort_keys=True)
        f.write("\n")
    os.replace(tmp, path)


def http_json(
    url: str,
    *,
    method: str = "GET",
    token: str | None = None,
    payload: dict[str, Any] | None = None,
    timeout_s: float = 15.0,
) -> dict[str, Any]:
    headers = {
        "Accept": "application/json",
        "User-Agent": DEFAULT_USER_AGENT,
    }
    body: bytes | None = None
    if payload is not None:
        body = json.dumps(payload).encode("utf-8")
        headers["Content-Type"] = "application/json"
    if token:
        headers["Authorization"] = f"Bearer {token}"

    req = Request(url, data=body, headers=headers, method=method.upper())
    try:
        with urlopen(req, timeout=timeout_s) as resp:
            raw = resp.read()
        return json.loads(raw.decode("utf-8"))
    except HTTPError as exc:
        try:
            data = exc.read().decode("utf-8", errors="ignore")
        except Exception:
            data = ""
        msg = f"http {exc.code} {exc.reason}"
        if data.strip():
            msg += f": {data[:200].strip()}"
        raise RuntimeError(msg)
    except (URLError, TimeoutError) as exc:
        raise RuntimeError(f"network error: {exc}")


def _safe_int(x: Any, default: int) -> int:
    try:
        return int(x)
    except Exception:
        return default


def policy_defaults() -> dict[str, Any]:
    return {
        "enabled": False,
        "cpu_max_percent": 20,
        "ram_max_gb": 4,
        "gpu_max_percent": 15,
        "disk_max_gb": 50,
        # Safety: refuse new work when disk is tight (prevents stuck jobs + corrupted writes).
        # Measured against the user's home filesystem.
        "disk_min_free_gb": 2,
        "network_upload_mbps": 3,
        "network_download_mbps": 10,
        "idle_only": True,
        "plugged_in_only": False,
        "quiet_hours": ["08:00-18:00"],
        "allowed_job_types": ["health_check", "crawl_url", "ollama_chat"],
        # Crawl safety (defense-in-depth): allowlist targets and block private IPs by default.
        # If empty, defaults to the control-plane host (base_url).
        "crawl_allowlist_domains": [],
        "allow_private_crawl_ips": False,
        "crawl_max_redirects": 3,
        "max_concurrent_jobs": 1,
        "reserve_cores_for_user": 2,
        "reserve_ram_gb_for_user": 4,
        "thermal_throttle_enabled": True,
        "emergency_stop_enabled": True,
        "local_user_priority": True,
    }


def validate_policy(policy: dict[str, Any]) -> tuple[bool, str]:
    if not isinstance(policy, dict):
        return False, "policy must be an object"
    try:
        cpu = int(policy.get("cpu_max_percent", 0))
        ram = int(policy.get("ram_max_gb", 0))
        disk = int(policy.get("disk_max_gb", 0))
        disk_min = int(policy.get("disk_min_free_gb", 0))
        if cpu < 0 or cpu > 95:
            return False, "cpu_max_percent must be 0..95"
        if ram < 0 or ram > 512:
            return False, "ram_max_gb must be 0..512"
        if disk < 0 or disk > 4096:
            return False, "disk_max_gb must be 0..4096"
        if disk_min < 0 or disk_min > 4096:
            return False, "disk_min_free_gb must be 0..4096"
    except Exception:
        return False, "invalid numeric policy fields"
    return True, ""


def _minutes_local() -> int:
    t = time.localtime()
    return int(t.tm_hour) * 60 + int(t.tm_min)


def _hhmm_to_minutes(raw: str) -> int | None:
    m = re.match(r"^\s*(\d{1,2}):(\d{2})\s*$", str(raw or ""))
    if not m:
        return None
    hh = int(m.group(1))
    mm = int(m.group(2))
    if hh < 0 or hh > 23 or mm < 0 or mm > 59:
        return None
    return hh * 60 + mm


def in_quiet_hours(policy: dict[str, Any]) -> bool:
    ranges = policy.get("quiet_hours") or []
    if not isinstance(ranges, list) or not ranges:
        return False
    now_m = _minutes_local()
    for r in ranges:
        if not isinstance(r, str) or "-" not in r:
            continue
        start_s, end_s = r.split("-", 1)
        start = _hhmm_to_minutes(start_s)
        end = _hhmm_to_minutes(end_s)
        if start is None or end is None:
            continue
        if start <= end:
            if start <= now_m <= end:
                return True
        else:
            # Wrap-around, e.g. 22:00-06:00
            if now_m >= start or now_m <= end:
                return True
    return False
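
# Example: with quiet_hours=["22:00-06:00"], 23:30 and 05:00 fall inside the
# window (the wrap-around branch above), while 12:00 does not.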


def is_system_idle(policy: dict[str, Any]) -> bool:
    try:
        load1 = float(os.getloadavg()[0])
        cpu = os.cpu_count() or 1
        reserve = max(0, _safe_int(policy.get("reserve_cores_for_user"), 0))
        usable = max(1, cpu - reserve)
        # Heuristic: allow background work when load is well below usable cores.
        return load1 < max(0.3, usable * 0.25)
    except Exception:
        return True


def _meminfo_bytes() -> dict[str, int]:
    out = {"mem_total": 0, "mem_available": 0}
    try:
        with open("/proc/meminfo", "r", encoding="utf-8") as f:
            for line in f:
                if line.startswith("MemTotal:"):
                    out["mem_total"] = int(line.split()[1]) * 1024
                elif line.startswith("MemAvailable:"):
                    out["mem_available"] = int(line.split()[1]) * 1024
    except Exception:
        return out
    return out


def has_enough_free_ram(policy: dict[str, Any]) -> bool:
    """Best-effort RAM reserve guard (Linux-only today).

    If we can read MemAvailable, avoid taking work when doing so would violate
    `reserve_ram_gb_for_user`. On platforms without `/proc/meminfo`, we treat
    RAM availability as unknown and do not block work.
    """
    reserve_gb = 0.0
    try:
        reserve_gb = float(policy.get("reserve_ram_gb_for_user") or 0.0)
    except Exception:
        reserve_gb = 0.0
    if reserve_gb <= 0.0:
        return True
    mem = _meminfo_bytes()
    avail = int(mem.get("mem_available") or 0)
    if avail <= 0:
        return True
    reserve_bytes = int(reserve_gb * (1024**3))
    return avail >= max(0, reserve_bytes)


def has_enough_free_disk(policy: dict[str, Any]) -> bool:
    """Best-effort disk reserve guard.

    If the home filesystem has less than `disk_min_free_gb` free, we block new
    work (protects small nodes from running out of space).
    """
    min_free_gb = 0
    try:
        min_free_gb = int(policy.get("disk_min_free_gb", 0) or 0)
    except Exception:
        min_free_gb = 0
    if min_free_gb <= 0:
        return True
    try:
        disk = shutil.disk_usage(os.path.expanduser("~"))
    except Exception:
        return True
    try:
        free = int(disk.free)
    except Exception:
        free = 0
    reserve_bytes = int(min_free_gb) * (1024**3)
    return free >= max(0, reserve_bytes)


def on_ac_power() -> bool:
    """Best-effort "plugged in" check.

    If we cannot determine power state, default to True (do not block work).
    """
    try:
        if sys.platform.startswith("linux") and os.path.isdir("/sys/class/power_supply"):
            base = "/sys/class/power_supply"
            found_mains = False
            saw_explicit_offline = False
            for name in os.listdir(base):
                d = os.path.join(base, name)
                if not os.path.isdir(d):
                    continue
                tpath = os.path.join(d, "type")
                opath = os.path.join(d, "online")
                try:
                    with open(tpath, "r", encoding="utf-8") as f:
                        t = f.read().strip().lower()
                except Exception:
                    t = ""
                if t not in {"mains", "ac"}:
                    continue
                found_mains = True
                try:
                    with open(opath, "r", encoding="utf-8") as f:
                        online = f.read().strip()
                except Exception:
                    online = ""
                if online == "1":
                    return True
                if online == "0":
                    saw_explicit_offline = True
                    continue
                # Unknown power state: don't block work.
                return True
            # If we found at least one mains/AC supply but none are online, treat as battery.
            if found_mains and saw_explicit_offline:
                return False
            return True
    except Exception:
        return True
    try:
        if sys.platform == "darwin":
            p = subprocess.run(["pmset", "-g", "batt"], capture_output=True, text=True, timeout=2.0)
            out = (p.stdout or "") + "\n" + (p.stderr or "")
            if "AC Power" in out:
                return True
            if "Battery Power" in out:
                return False
    except Exception:
        return True
    return True


def cpu_temp_c() -> float | None:
    """Best-effort CPU temperature (Celsius), Linux-only."""
    if not sys.platform.startswith("linux"):
        return None
    base = "/sys/class/thermal"
    if not os.path.isdir(base):
        return None
    temps: list[float] = []
    try:
        for name in os.listdir(base):
            if not name.startswith("thermal_zone"):
                continue
            tpath = os.path.join(base, name, "temp")
            try:
                with open(tpath, "r", encoding="utf-8") as f:
                    raw = f.read().strip()
                if not raw:
                    continue
                val = float(raw)
                # Most kernels expose millidegrees C.
                if val > 1000.0:
                    val = val / 1000.0
                if 0.0 < val < 140.0:
                    temps.append(val)
            except Exception:
                continue
    except Exception:
        return None
    if not temps:
        return None
    return max(temps)


def build_heartbeat_payload(policy: dict[str, Any]) -> dict[str, Any]:
    disk = shutil.disk_usage(os.path.expanduser("~"))
    mem = _meminfo_bytes()
    load = None
    try:
        load = list(os.getloadavg())
    except Exception:
        load = None
    return {
        "ok": True,
        "agent_version": AGENT_VERSION,
        "platform": platform.platform(),
        "python": platform.python_version(),
        "time_utc": now_utc_iso(),
        "policy_enabled": bool(policy.get("enabled")),
        "cpu_count": os.cpu_count() or 1,
        "loadavg": load,
        "mem_total_bytes": mem.get("mem_total", 0),
        "mem_available_bytes": mem.get("mem_available", 0),
        "home_total_bytes": int(disk.total),
        "home_free_bytes": int(disk.free),
    }


def compute_work_blockers(
    policy: dict[str, Any],
    *,
    effective_enabled: bool,
    emergency_stop: bool,
    allowed_types_set: set[str],
    key_pin_mismatch: bool,
) -> list[str]:
    blockers: list[str] = []

    if emergency_stop and bool(policy.get("emergency_stop_enabled", True)):
        blockers.append("emergency_stop")

    if key_pin_mismatch:
        blockers.append("key_pin_mismatch")

    if not bool(policy.get("enabled")):
        blockers.append("policy_disabled")

    if not effective_enabled:
        return blockers

    if not allowed_types_set:
        blockers.append("no_allowed_job_types")

    if bool(policy.get("idle_only")) and not is_system_idle(policy):
        blockers.append("not_idle")

    if not has_enough_free_ram(policy):
        blockers.append("low_ram")

    if not has_enough_free_disk(policy):
        blockers.append("low_disk")

    if bool(policy.get("plugged_in_only")) and not on_ac_power():
        blockers.append("on_battery")

    if bool(policy.get("thermal_throttle_enabled")):
        temp = cpu_temp_c()
        if temp is not None and temp >= 85.0:
            blockers.append("thermal")

    if in_quiet_hours(policy):
        blockers.append("quiet_hours")

    return blockers
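
# Example: with {"enabled": False} in the policy (and no emergency stop or key
# pin mismatch), the only blocker reported is "policy_disabled"; checks behind
# the effective-enabled gate are skipped, so a disabled node is not also
# flagged "not_idle", "quiet_hours", etc.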


def _normalize_ollama_base_url(raw: str) -> str:
    s = str(raw or "").strip().rstrip("/")
    if not s:
        s = "http://127.0.0.1:11434"
    if not s.startswith(("http://", "https://")):
        s = "http://" + s
    return s


def ollama_chat(ollama_base_url: str, *, model: str, prompt: str, timeout_s: float) -> dict[str, Any]:
    model = str(model or "").strip()
    prompt = str(prompt or "").strip()
    if not model:
        raise RuntimeError("ollama_model not configured on this node")
    if not prompt:
        raise RuntimeError("missing prompt")
    if len(prompt) > 8000:
        raise RuntimeError("prompt too long")
    base = _normalize_ollama_base_url(ollama_base_url)
    url = base + "/api/chat"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,
    }
    started = time.time()
    data = http_json(url, method="POST", payload=payload, timeout_s=timeout_s)
    latency_s = max(0.0, time.time() - started)
    if not isinstance(data, dict):
        return {"ok": True, "message": {"role": "assistant", "content": str(data)}, "latency_s": latency_s}
    out = dict(data)
    out["ok"] = True
    out["latency_s"] = latency_s
    out["agent_time_utc"] = now_utc_iso()
    out.setdefault("model", model)
    return out
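
# For reference, a successful result from this handler looks roughly like the
# shape below (the `message` object comes from Ollama's non-streaming
# /api/chat response and is assumed, not enforced here):
#
#   {"ok": True, "model": "...", "message": {"role": "assistant", "content": "..."},
#    "latency_s": 1.23, "agent_time_utc": "..."}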


class _TextExtractor(HTMLParser):
    def __init__(self) -> None:
        super().__init__(convert_charrefs=True)
        self._in_script = False
        self._in_style = False
        self._in_title = False
        self._title: list[str] = []
        self._parts: list[str] = []

    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
        t = (tag or "").lower()
        if t in {"script", "noscript"}:
            self._in_script = True
        if t == "style":
            self._in_style = True
        if t == "title":
            self._in_title = True

    def handle_endtag(self, tag: str) -> None:
        t = (tag or "").lower()
        if t in {"script", "noscript"}:
            self._in_script = False
        if t == "style":
            self._in_style = False
        if t == "title":
            self._in_title = False
        if t in {"p", "br", "div", "li", "h1", "h2", "h3", "h4", "h5", "h6"}:
            self._parts.append("\n")

    def handle_data(self, data: str) -> None:
        if self._in_script or self._in_style:
            return
        if self._in_title:
            self._title.append(data)
            return
        self._parts.append(data)

    def title(self) -> str:
        return re.sub(r"\s+", " ", " ".join(self._title)).strip()

    def text(self) -> str:
        raw = " ".join(self._parts)
        raw = raw.replace("\u00a0", " ")
        raw = re.sub(r"[ \t\r\f\v]+", " ", raw)
        raw = re.sub(r"\n{2,}", "\n", raw)
        return raw.strip()


class _NoRedirect(HTTPRedirectHandler):
    """Surface 3xx responses as HTTPError so each redirect hop can be
    re-validated (allowlist + SSRF guard) before it is followed."""

    def redirect_request(self, req: Request, fp: Any, code: int, msg: str, headers: Any, newurl: str) -> Any:  # type: ignore[override]
        raise HTTPError(req.full_url, code, msg, headers, fp)


_NO_REDIRECT_OPENER = build_opener(_NoRedirect())


def _policy_domain_allowlist(policy: dict[str, Any], *, base_url: str) -> set[str]:
    raw = policy.get("crawl_allowlist_domains")
    parts: list[str] = []
    if isinstance(raw, str):
        parts = [p for p in re.split(r"[\s,]+", raw) if p]
    elif isinstance(raw, list):
        parts = [str(p) for p in raw if isinstance(p, str)]
    allowed = {p.strip().lower() for p in parts if str(p or "").strip()}

    if not allowed:
        try:
            host = str(urlparse(base_url or "").hostname or "").strip().lower()
        except Exception:
            host = ""
        if host:
            allowed.add(host)
    return allowed


def _host_allowed(host: str, allowlist: set[str]) -> bool:
    h = str(host or "").strip().lower()
    if not h:
        return False
    if "*" in allowlist:
        return True
    for a in allowlist:
        if not a:
            continue
        if h == a or h.endswith("." + a):
            return True
    return False
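
# Note the exact-or-subdomain semantics: an allowlist of {"example.com"} admits
# "example.com" and "sub.example.com" but not "evilexample.com", because the
# suffix match requires a leading dot.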


def _resolved_ips(host: str) -> list[str]:
    h = str(host or "").strip()
    if not h:
        return []
    try:
        ipaddress.ip_address(h)
        return [h]
    except ValueError:
        pass
    try:
        infos = socket.getaddrinfo(h, None, type=socket.SOCK_STREAM)
    except Exception:
        return []
    ips: list[str] = []
    for info in infos:
        sockaddr = info[4]
        if not sockaddr:
            continue
        ip = str(sockaddr[0]).strip()
        if ip and ip not in ips:
            ips.append(ip)
    return ips


def _ssrf_guard(host: str, *, allow_private: bool) -> None:
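    """Refuse hosts whose resolved addresses are loopback, link-local, or
    otherwise non-global; private ranges pass only when the policy sets
    allow_private_crawl_ips. Resolution here is separate from urllib's own
    resolution at connect time, so treat this as defense-in-depth rather than
    an airtight DNS-rebinding guarantee."""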
    ips = _resolved_ips(host)
    if not ips:
        raise RuntimeError("dns resolution failed")
    for ip in ips:
        try:
            addr = ipaddress.ip_address(str(ip))
        except ValueError:
            continue
        if addr.is_loopback or addr.is_link_local:
            raise RuntimeError("host resolves to loopback/link-local ip")
        if addr.is_private:
            if not allow_private:
                raise RuntimeError("host resolves to private ip (set allow_private_crawl_ips=true)")
            continue
        if not addr.is_global:
            raise RuntimeError("host resolves to non-global ip")


def _validate_crawl_url(url: str, *, policy: dict[str, Any], base_url: str) -> str:
    try:
        parsed = urlparse(url or "")
    except Exception:
        raise RuntimeError("invalid url")
    if parsed.scheme not in {"http", "https"}:
        raise RuntimeError("only http/https urls are allowed")
    if parsed.username or parsed.password:
        raise RuntimeError("userinfo in url is not allowed")
    host = str(parsed.hostname or "").strip().lower()
    if not host:
        raise RuntimeError("missing host")

    allowlist = _policy_domain_allowlist(policy, base_url=base_url)
    if not _host_allowed(host, allowlist):
        raise RuntimeError("host not allowed by crawl_allowlist_domains")

    allow_private = bool(policy.get("allow_private_crawl_ips"))
    _ssrf_guard(host, allow_private=allow_private)
    return str(url)


def _open_with_redirects(
    url: str,
    *,
    headers: dict[str, str],
    timeout_s: float,
    max_redirects: int,
    policy: dict[str, Any],
    base_url: str,
):
    cur = str(url or "").strip()
    seen: set[str] = set()
    for _ in range(max(0, int(max_redirects)) + 1):
        if not cur or cur in seen:
            raise RuntimeError("redirect loop")
        seen.add(cur)
        _validate_crawl_url(cur, policy=policy, base_url=base_url)
        req = Request(cur, headers=headers)
        try:
            return _NO_REDIRECT_OPENER.open(req, timeout=timeout_s)
        except HTTPError as exc:
            if int(getattr(exc, "code", 0) or 0) in {301, 302, 303, 307, 308}:
                loc = str(exc.headers.get("Location") or "").strip()
                if not loc:
                    raise RuntimeError("redirect without Location header")
                cur = urljoin(cur, loc)
                continue
            raise
    raise RuntimeError("too many redirects")


def crawl_url(
    url: str,
    *,
    max_bytes: int,
    timeout_s: float,
    respect_robots: bool,
    policy: dict[str, Any],
    base_url: str,
) -> dict[str, Any]:
    _validate_crawl_url(url, policy=policy, base_url=base_url)
    parsed = urlparse(url)

    ua = DEFAULT_USER_AGENT
    if respect_robots:
        rp = urllib.robotparser.RobotFileParser()
        robots_url = urljoin(f"{parsed.scheme}://{parsed.netloc}", "/robots.txt")
        _validate_crawl_url(robots_url, policy=policy, base_url=base_url)
        try:
            # Fetch robots.txt without following redirects automatically (SSRF defense).
            with _open_with_redirects(
                robots_url,
                headers={"User-Agent": ua, "Accept": "text/plain, */*;q=0.1"},
                timeout_s=min(8.0, float(timeout_s)),
                max_redirects=min(2, _safe_int(policy.get("crawl_max_redirects"), 3)),
                policy=policy,
                base_url=base_url,
            ) as resp2:
                raw2 = resp2.read(128 * 1024)
            text2 = raw2.decode("utf-8", errors="replace")
            rp.parse(text2.splitlines())
            if not rp.can_fetch(ua, url):
                raise RuntimeError("blocked by robots.txt")
        except RuntimeError:
            raise
        except Exception:
            # If robots.txt cannot be fetched (including a 404), refuse to crawl.
            # This is stricter than the common "missing robots.txt means allow"
            # convention; the default posture here is deliberately conservative.
            raise RuntimeError("robots.txt unavailable (refusing crawl)")

    headers = {"User-Agent": ua, "Accept": "text/html, text/plain;q=0.9, */*;q=0.1"}
    max_redirects = _safe_int(policy.get("crawl_max_redirects"), 3)
    started = time.time()
    final_url = url
    with _open_with_redirects(
        url,
        headers=headers,
        timeout_s=timeout_s,
        max_redirects=max_redirects,
        policy=policy,
        base_url=base_url,
    ) as resp:
        final_url = str(resp.geturl() or url).strip() or url
        ctype = str(resp.headers.get("Content-Type") or "").lower()
        raw = b""
        while True:
            chunk = resp.read(min(64 * 1024, max(1, max_bytes - len(raw))))
            if not chunk:
                break
            raw += chunk
            if len(raw) >= max_bytes:
                break
    latency_s = max(0.0, time.time() - started)

    # Content-type safety gate (best-effort).
    if ctype:
        ctype_main = ctype.split(";", 1)[0].strip()
        if any(ctype_main.startswith(p) for p in ("image/", "audio/", "video/")):
            raise RuntimeError(f"unsupported content-type: {ctype_main}")
        if any(tok in ctype_main for tok in ("application/pdf", "application/zip", "application/x-gzip", "application/gzip")):
            raise RuntimeError(f"unsupported content-type: {ctype_main}")
    if b"\x00" in raw:
        raise RuntimeError("binary content (null byte) refused")

    # Best-effort decode. latin-1 accepts any byte sequence, so this loop only
    # leaves `text` empty when the body itself is empty.
    text = ""
    for enc in ("utf-8", "utf-16", "latin-1"):
        try:
            text = raw.decode(enc)
            break
        except Exception:
            continue
    if not text:
        raise RuntimeError("empty response body")

    title = ""
    content = ""
    if "text/html" in ctype or "<html" in text.lower():
        parser = _TextExtractor()
        parser.feed(text)
        title = parser.title()
        content = parser.text()
    else:
        content = re.sub(r"\s+", " ", text).strip()

    if not content:
        raise RuntimeError("empty content")

    content_hash = hashlib.sha256(content.encode("utf-8", errors="ignore")).hexdigest()
    return {
        "ok": True,
        "url": final_url,
        "title": title,
        "content": content,
        "content_hash": content_hash,
        "content_type": ctype,
        "fetched_utc": now_utc_iso(),
        "latency_s": latency_s,
        "bytes": len(raw),
    }


def verify_job_signature(manifest: dict[str, Any], signature_b64: str, public_key_b64: str) -> None:
    if Ed25519PublicKey is None:
        raise RuntimeError("cryptography missing: cannot verify job signatures")
    try:
        sig = base64.b64decode(signature_b64.encode("utf-8"), validate=True)
        pub = base64.b64decode(public_key_b64.encode("utf-8"), validate=True)
    except Exception:
        raise RuntimeError("invalid signature/public key encoding")
    try:
        key = Ed25519PublicKey.from_public_bytes(pub)
        key.verify(sig, canonical_json(manifest))
    except Exception:
        raise RuntimeError("job signature verification failed")


def run_job(
    manifest: dict[str, Any],
    policy: dict[str, Any],
    *,
    base_url: str,
    ollama_base_url: str,
    ollama_model: str,
) -> dict[str, Any]:
    job_type = str(manifest.get("job_type") or "").strip()
    input_obj = manifest.get("input") if isinstance(manifest.get("input"), dict) else {}
    constraints = manifest.get("constraints") if isinstance(manifest.get("constraints"), dict) else {}

    if job_type == "health_check":
        hb = build_heartbeat_payload(policy)
        if ollama_model:
            hb["ollama_model"] = str(ollama_model)
        return {"ok": True, "kind": "health_check", "time_utc": now_utc_iso(), "heartbeat": hb}
    if job_type == "crawl_url":
        url = str(input_obj.get("url") or "").strip()
        if not url:
            raise RuntimeError("missing url")
        max_bytes = max(32 * 1024, min(_safe_int(constraints.get("max_bytes"), 700_000), 2_000_000))
        timeout_s = max(3.0, min(float(constraints.get("timeout_s") or 18.0), 60.0))
        respect_robots = bool(constraints.get("respect_robots", True))
        return crawl_url(
            url,
            max_bytes=max_bytes,
            timeout_s=timeout_s,
            respect_robots=respect_robots,
            policy=policy,
            base_url=base_url,
        )
    if job_type == "ollama_chat":
        prompt = str(input_obj.get("prompt") or "").strip()
        requested_model = str(input_obj.get("model") or "").strip()
        model_to_use = requested_model or ollama_model
        if requested_model and ollama_model and requested_model != ollama_model:
            raise RuntimeError("requested model not available on this node")
        timeout_s = float(constraints.get("timeout_s") or manifest.get("timeout_s") or 90.0)
        timeout_s = max(5.0, min(timeout_s, 600.0))
        return ollama_chat(ollama_base_url, model=model_to_use, prompt=prompt, timeout_s=timeout_s)
    raise RuntimeError(f"unsupported job_type: {job_type}")
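
# For reference, the manifest fields consumed by run_job() and the main loop
# (the control-plane may sign additional fields; only these are read here):
#
#   {"job_id": "...", "job_type": "crawl_url", "assigned_node_id": "...",
#    "input": {"url": "https://example.com/"},
#    "constraints": {"max_bytes": 700000, "timeout_s": 18, "respect_robots": True}}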


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to agent.json")
    ap.add_argument("--once", action="store_true", help="Run one heartbeat + poll cycle and exit")
    ap.add_argument("--doctor", action="store_true", help="Run local diagnostics and exit")
    args = ap.parse_args()

    cfg = _read_json(args.config)
    cfg_dir = os.path.dirname(os.path.abspath(os.path.expanduser(args.config)))
    emergency_stop_path = os.path.join(cfg_dir, "EMERGENCY_STOP")
    base_url = str(cfg.get("base_url") or "").strip().rstrip("/")
    node_id = str(cfg.get("node_id") or "").strip()
    token = str(cfg.get("node_token") or "").strip()
    # Forward-compatible merge: new defaults should take effect even if the
    # on-disk config was written by an older agent/installer.
    policy = policy_defaults()
    policy_raw = cfg.get("policy")
    if isinstance(policy_raw, dict):
        policy.update(policy_raw)
    ollama_base_url = str(cfg.get("ollama_base_url") or "").strip()
    ollama_model = str(cfg.get("ollama_model") or "").strip()
    public_key_b64 = str(cfg.get("signing_public_key_b64") or "").strip()
    expected_pubkey_fpr = str(cfg.get("signing_public_key_fingerprint_sha256_expected") or "").strip().lower()
    ok, err = validate_policy(policy)
    if not ok:
        print(f"[agent] invalid policy: {err}", file=sys.stderr)
        return 2
    if not base_url or not node_id or not token:
        print("[agent] config missing base_url/node_id/node_token", file=sys.stderr)
        return 2

    hb_url = urljoin(base_url + "/", f"api/grid/nodes/{node_id}/heartbeat")
    poll_url = urljoin(base_url + "/", "api/grid/jobs/poll")
    key_url = urljoin(base_url + "/", "api/grid/public-key")

    last_hb = 0.0
    hb_interval_s = 35.0
    poll_interval_s = 8.0

    last_emergency_stop: bool | None = None

    def ensure_public_key() -> str:
        # Trust-on-first-use: with no cached key we fetch one from the
        # control-plane and persist it; a pinned fingerprint (when configured)
        # is enforced on both the cached and the freshly fetched key.
        nonlocal public_key_b64
        if public_key_b64:
            fpr = public_key_fingerprint_sha256(public_key_b64)
            if expected_pubkey_fpr and (not fpr or fpr != expected_pubkey_fpr):
                raise RuntimeError("public key fingerprint mismatch (pinned key)")
            return public_key_b64
        data = http_json(key_url, token=token, timeout_s=10.0)
        pk = str(data.get("public_key_b64") or "").strip()
        if not pk:
            raise RuntimeError("missing public key from control-plane")
        fpr = public_key_fingerprint_sha256(pk)
        if expected_pubkey_fpr and (not fpr or fpr != expected_pubkey_fpr):
            raise RuntimeError("public key fingerprint mismatch (pinned key)")
        public_key_b64 = pk
        cfg["signing_public_key_b64"] = pk
        if fpr:
            cfg["signing_public_key_fingerprint_sha256"] = fpr
        _write_json(args.config, cfg)
        return pk

    def run_doctor() -> int:
        problems: list[str] = []
        warnings: list[str] = []

        print("PowerSearch Grid Edge Agent — doctor")
        print(f"- agent_version: {AGENT_VERSION}")
        print(f"- python: {platform.python_version()} ({sys.executable})")
        print(f"- platform: {platform.platform()}")
        print(f"- config: {os.path.abspath(os.path.expanduser(args.config))}")
        print(f"- base_url: {base_url}")
        print(f"- node_id: {node_id}")
        print(f"- node_token: {'present' if bool(token) else 'missing'}")

        pub_fpr = public_key_fingerprint_sha256(public_key_b64)
        if public_key_b64:
            print(f"- signing_public_key_fingerprint_sha256: {pub_fpr or 'unavailable'}")
        else:
            print("- signing_public_key_fingerprint_sha256: missing (will fetch on first poll)")
        if expected_pubkey_fpr:
            print(f"- pinned_public_key_fingerprint_sha256_expected: {expected_pubkey_fpr}")
            if not pub_fpr:
                problems.append("pinned key set but current public key fingerprint is unavailable/invalid")
            elif pub_fpr != expected_pubkey_fpr:
                problems.append("pinned key mismatch (expected != current)")

        if Ed25519PublicKey is None:
            problems.append("cryptography missing (cannot verify job signatures)")

        emergency_stop = False
        try:
            emergency_stop = os.path.exists(emergency_stop_path)
        except Exception:
            emergency_stop = False
        print(f"- emergency_stop: {'active' if emergency_stop else 'not set'} ({emergency_stop_path})")

        print(f"- policy.enabled: {bool(policy.get('enabled'))}")
        print(f"- policy.allowed_job_types: {policy.get('allowed_job_types')}")
        print(f"- policy.crawl_allowlist_domains: {policy.get('crawl_allowlist_domains')}")

        # Control-plane connectivity (read-only).
        try:
            d = http_json(urljoin(base_url + '/', 'health'), timeout_s=10.0)
            ok_flag = bool(d.get('ok') is True or str(d.get('status') or '').strip().lower() == 'ok')
            print(f"- control_plane.health: {'ok' if ok_flag else 'warn'}")
            if not ok_flag:
                warnings.append("control-plane /health did not return ok")
        except Exception as exc:
            print(f"- control_plane.health: error ({exc})")
            problems.append("control-plane /health unreachable")

        try:
            d = http_json(urljoin(base_url + '/', 'api/grid/status'), timeout_s=12.0)
            ok_flag = bool(d.get('ok') is True)
            nodes_online = d.get('nodes_online')
            nodes_total = d.get('nodes_total')
            print(f"- control_plane.grid_status: {'ok' if ok_flag else 'warn'} (nodes {nodes_online}/{nodes_total})")
            if not ok_flag:
                warnings.append("control-plane /api/grid/status did not return ok")
        except Exception as exc:
            print(f"- control_plane.grid_status: error ({exc})")
            problems.append("control-plane /api/grid/status unreachable")

        try:
            d = http_json(urljoin(base_url + '/', f'api/grid/nodes/{node_id}/status_public'), timeout_s=12.0)
            ok_flag = bool(d.get('ok') is True)
            state = str(d.get('state') or '').strip() or 'unknown'
            connected = d.get('connected')
            last_seen = str(d.get('last_seen_utc') or '').strip()
            print(f"- node.status_public: {'ok' if ok_flag else 'warn'} (state={state} connected={connected} last_seen={last_seen or 'n/a'})")
        except Exception as exc:
            print(f"- node.status_public: error ({exc})")
            warnings.append("node status_public not reachable (node may not be registered yet)")

        # Ollama connectivity (best-effort).
        base = _normalize_ollama_base_url(ollama_base_url)
        try:
            d = http_json(base + '/api/tags', timeout_s=4.0)
            models = d.get('models')
            count = len(models) if isinstance(models, list) else 0
            print(f"- ollama: ok ({base} · {count} model(s))")
        except Exception as exc:
            print(f"- ollama: unavailable ({base} · {exc})")
            if ollama_model:
                warnings.append("ollama_model set but Ollama is unreachable")

        # systemd --user status (Linux best-effort).
        if sys.platform.startswith('linux') and shutil.which('systemctl'):
            try:
                p = subprocess.run(['systemctl', '--user', 'is-active', 'powersearch-grid-agent'], capture_output=True, text=True, timeout=2.5)
                state = (p.stdout or p.stderr or '').strip() or 'unknown'
                print(f"- systemd_user.service_active: {state}")
            except Exception as exc:
                print(f"- systemd_user.service_active: error ({exc})")

        if problems:
            print("\nProblems:")
            for x in problems[:12]:
                print(f"- {x}")
        if warnings:
            print("\nWarnings:")
            for x in warnings[:12]:
                print(f"- {x}")

        if problems:
            print("\nResult: FAIL")
            return 1
        if warnings:
            print("\nResult: WARN")
            return 0
        print("\nResult: OK")
        return 0

    if args.doctor:
        return run_doctor()

    while True:
        now = time.time()

        emergency_stop = False
        if bool(policy.get("emergency_stop_enabled", True)):
            try:
                emergency_stop = os.path.exists(emergency_stop_path)
            except Exception:
                emergency_stop = False
        if last_emergency_stop is None or emergency_stop != last_emergency_stop:
            if emergency_stop:
                print(f"[agent] emergency stop active (found {emergency_stop_path})")
            elif last_emergency_stop is not None:
                print("[agent] emergency stop cleared")
            last_emergency_stop = emergency_stop

        effective_enabled = bool(policy.get("enabled")) and (not emergency_stop)
        enabled = bool(effective_enabled)
        allowed_types = policy.get("allowed_job_types") or []
        allowed_types_set = {str(x) for x in allowed_types if isinstance(x, str)}

        key_pin_mismatch = False
        if expected_pubkey_fpr:
            fpr = public_key_fingerprint_sha256(public_key_b64)
            if not fpr:
                # If we have no cached key, attempt to fetch it before declaring a mismatch.
                if enabled:
                    try:
                        public_key_b64 = ensure_public_key()
                        fpr = public_key_fingerprint_sha256(public_key_b64)
                    except Exception:
                        fpr = ""
                if not fpr:
                    key_pin_mismatch = True
            elif fpr != expected_pubkey_fpr:
                key_pin_mismatch = True

        blockers = compute_work_blockers(
            policy,
            effective_enabled=enabled,
            emergency_stop=emergency_stop,
            allowed_types_set=allowed_types_set,
            key_pin_mismatch=key_pin_mismatch,
        )
        can_work = enabled and (not blockers)

        if now - last_hb >= hb_interval_s:
            try:
                payload = build_heartbeat_payload(policy)
                payload["policy_enabled"] = bool(effective_enabled)
                payload["emergency_stop"] = bool(emergency_stop)
                payload["work_allowed"] = bool(can_work)
                payload["work_blockers"] = blockers
                fpr = public_key_fingerprint_sha256(public_key_b64)
                if fpr:
                    payload["signing_public_key_fingerprint_sha256"] = fpr
                if ollama_model:
                    payload["ollama_model"] = ollama_model
                http_json(hb_url, method="POST", token=token, payload=payload, timeout_s=8.5)
                last_hb = now
                print(f"[agent] heartbeat ok {payload.get('time_utc')}")
            except Exception as exc:
                print(f"[agent] heartbeat error: {exc}", file=sys.stderr)

        if can_work:
            try:
                pk = ensure_public_key()
                caps = build_heartbeat_payload(policy)
                caps["policy_enabled"] = bool(effective_enabled)
                caps["emergency_stop"] = bool(emergency_stop)
                caps["work_allowed"] = bool(can_work)
                caps["work_blockers"] = blockers
                fpr = public_key_fingerprint_sha256(public_key_b64)
                if fpr:
                    caps["signing_public_key_fingerprint_sha256"] = fpr
                if ollama_model:
                    caps["ollama_model"] = ollama_model
                resp = http_json(
                    poll_url,
                    method="POST",
                    token=token,
                    payload={"node_id": node_id, "capabilities": caps, "policy": policy},
                    timeout_s=15.0,
                )
                job = resp.get("job") if isinstance(resp, dict) else None
                if job and isinstance(job, dict):
                    manifest = job.get("manifest") if isinstance(job.get("manifest"), dict) else None
                    sig = str(job.get("signature_b64") or "").strip()
                    if not manifest or not sig:
                        raise RuntimeError("malformed job offer")
                    verify_job_signature(manifest, sig, pk)
                    job_type = str(manifest.get("job_type") or "").strip()
                    job_id = str(manifest.get("job_id") or "").strip()
                    assigned = str(manifest.get("assigned_node_id") or "").strip()
                    if assigned and assigned != node_id:
                        raise RuntimeError("job assigned to different node")
                    if job_type not in allowed_types_set:
                        raise RuntimeError(f"job_type not allowed by policy: {job_type}")

                    print(f"[agent] running job {job_id} ({job_type})")
                    result_ok = True
                    result_obj: dict[str, Any] = {}
                    err_s = ""
                    try:
                        result_obj = run_job(
                            manifest,
                            policy,
                            base_url=base_url,
                            ollama_base_url=ollama_base_url,
                            ollama_model=ollama_model,
                        )
                    except Exception as exc:
                        result_ok = False
                        err_s = str(exc)
                        result_obj = {"ok": False, "error": err_s}

                    try:
                        submit_url = urljoin(base_url + "/", f"api/grid/jobs/{job_id}/result")
                        http_json(
                            submit_url,
                            method="POST",
                            token=token,
                            payload={"node_id": node_id, "ok": result_ok, "result": result_obj, "error": err_s},
                            timeout_s=30.0,
                        )
                        print(f"[agent] submitted job {job_id} ok={result_ok}")
                    except Exception as exc:
                        msg = str(exc)
                        # If an operator cancels a job while we're working, the control-plane rejects the result
                        # with a 409. Treat this as a normal outcome (not an agent error).
                        if msg.startswith("http 409") and ("status=canceled" in msg or "status=cancelled" in msg or "job not active" in msg):
                            print(f"[agent] job {job_id} result rejected (likely canceled): {msg}")
                        else:
                            print(f"[agent] submit error: {exc}", file=sys.stderr)
                else:
                    # No job available.
                    pass
            except Exception as exc:
                print(f"[agent] poll error: {exc}", file=sys.stderr)

        if args.once:
            break
        time.sleep(poll_interval_s if enabled else max(15.0, poll_interval_s))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
