#!/usr/bin/env python3

from __future__ import annotations

import argparse
import logging
import os
from pathlib import Path
import shutil
import subprocess
import sys
import time

MESHTASTIC_HOST = "127.0.0.1"
MESHTASTIC_PORT = 4403
MESHTASTIC_TIMEOUT_SEC = 20
POLL_INTERVAL_SEC = 1.0
REQUEST_SETTLE_SEC = 0.5
MESH_APPLY_DELAY_SEC = 5.0
MESH_WRITEBACK_ENABLED = False
RFKILL_ROOT = Path("/sys/class/rfkill")
CONFIG_PROTO_PATH = Path("/var/lib/meshtasticd/.portduino/default/prefs/config.proto")

PROTO_WIRE_VARINT = 0
PROTO_WIRE_64BIT = 1
PROTO_WIRE_LENGTH_DELIMITED = 2
PROTO_WIRE_32BIT = 5
CONFIG_NETWORK_FIELD = 4
NETWORK_WIFI_ENABLED_FIELD = 1

LOG = logging.getLogger("wifisync")


def _state_label(enabled: bool) -> str:
  return "enabled" if enabled else "disabled"


def _find_meshtastic_interpreter() -> Path | None:
  candidates = []
  command_path = shutil.which("meshtastic")
  if command_path:
    candidates.append(Path(command_path))
  candidates.extend((Path("/usr/local/bin/meshtastic"), Path("/usr/bin/meshtastic")))

  seen = set()
  for script_path in candidates:
    if not script_path.is_file():
      continue
    resolved = script_path.resolve()
    if resolved in seen:
      continue
    seen.add(resolved)
    try:
      with resolved.open("r", encoding="utf-8") as handle:
        first_line = handle.readline().strip()
    except OSError:
      continue
    if not first_line.startswith("#!"):
      continue
    interpreter = first_line[2:].strip()
    if interpreter.startswith("/usr/bin/env "):
      env_target = interpreter.split(None, 1)[1]
      resolved_env = shutil.which(env_target)
      if resolved_env:
        return Path(resolved_env)
      continue
    resolved_interpreter = Path(interpreter.split()[0])
    if resolved_interpreter.exists():
      return resolved_interpreter
  return None


def _find_meshtastic_site_packages() -> Path | None:
  candidates = (
    Path("/opt/pipx/venvs/meshtastic/lib"),
    Path("/usr/local/pipx/venvs/meshtastic/lib"),
  )
  for lib_path in candidates:
    for site_packages in sorted(lib_path.glob("python*/site-packages")):
      if site_packages.is_dir():
        return site_packages
  return None


def _load_tcp_interface():
  """Import and return meshtastic's TCPInterface class, with fallbacks.

  Fallback order when the plain import fails:
    1. Prepend a pipx meshtastic site-packages directory to sys.path and retry.
    2. Re-execute this script under the interpreter named by the installed
       `meshtastic` launcher, if that differs from the current interpreter.

  Raises:
    RuntimeError: when the module cannot be imported and no other interpreter
      is available; chained from the original ModuleNotFoundError.
  """
  try:
    from meshtastic.tcp_interface import TCPInterface
  except ModuleNotFoundError as exc:
    site_packages = _find_meshtastic_site_packages()
    if site_packages and str(site_packages) not in sys.path:
      # The path entry is left in place even if the retry below fails.
      sys.path.insert(0, str(site_packages))
      try:
        from meshtastic.tcp_interface import TCPInterface
      except ModuleNotFoundError:
        pass
      else:
        return TCPInterface
    interpreter = _find_meshtastic_interpreter()
    # Only re-exec under a different interpreter; re-running the current one
    # would hit the same import failure again.
    if interpreter and interpreter.resolve() != Path(sys.executable).resolve():
      # os.execv replaces the current process and does not return on success.
      os.execv(str(interpreter), [str(interpreter), *sys.argv])
    raise RuntimeError(
      "Meshtastic Python module not found. Install the Meshtastic Python package "
      "or the global meshtastic tool."
    ) from exc
  return TCPInterface


def _read_int(path: Path) -> int:
  return int(path.read_text(encoding="ascii").strip())


def _find_wifi_rfkill() -> Path | None:
  """Return the first rfkill entry (in sorted order) whose type is "wlan".

  Entries whose `type` file cannot be read are skipped. Returns None when no
  wlan entry exists.
  """
  for candidate in sorted(RFKILL_ROOT.glob("rfkill*")):
    type_file = candidate / "type"
    try:
      device_type = type_file.read_text(encoding="ascii")
    except OSError:
      continue
    if device_type.strip() == "wlan":
      return candidate
  return None


def _read_rfkill_state(rfkill_path: Path) -> tuple[bool, bool]:
  """Return `(enabled, hard_blocked)` for an rfkill device.

  `enabled` is True unless the `soft` file reads 1; `hard_blocked` is True
  when the `hard` file reads 1.
  """
  soft_flag = _read_int(rfkill_path / "soft")
  hard_flag = _read_int(rfkill_path / "hard")
  wifi_enabled = soft_flag != 1
  return (wifi_enabled, hard_flag == 1)


def _write_rfkill_enabled(rfkill_path: Path, enabled: bool) -> None:
  value = "0" if enabled else "1"
  with (rfkill_path / "soft").open("w", encoding="ascii") as handle:
    handle.write(value)


# Portduino persists config as protobuf wire format; only network.wifi_enabled is needed here.
def _read_varint(data: bytes, offset: int) -> tuple[int, int]:
  value = 0
  shift = 0
  while offset < len(data):
    byte = data[offset]
    offset += 1
    value |= (byte & 0x7F) << shift
    if byte < 0x80:
      return (value, offset)
    shift += 7
    if shift >= 64:
      break
  raise ValueError("Malformed protobuf varint.")


def _skip_field(data: bytes, offset: int, wire_type: int) -> int:
  """Return the offset just past one field value of the given wire type.

  Args:
    data: Encoded protobuf message.
    offset: Position of the field value (immediately after its key).
    wire_type: Wire-type code taken from the field key.

  Raises:
    ValueError: for an unsupported wire type, a malformed varint, or a value
      that extends past the end of *data* (truncated input).
  """
  if wire_type == PROTO_WIRE_VARINT:
    _, offset = _read_varint(data, offset)
    return offset
  if wire_type == PROTO_WIRE_64BIT:
    end = offset + 8
  elif wire_type == PROTO_WIRE_LENGTH_DELIMITED:
    length, offset = _read_varint(data, offset)
    end = offset + length
  elif wire_type == PROTO_WIRE_32BIT:
    end = offset + 4
  else:
    raise ValueError(f"Unsupported protobuf wire type {wire_type}.")
  # Without this check a truncated file would look like a message that simply
  # ends, hiding any fields that should have followed.
  if end > len(data):
    raise ValueError("Truncated protobuf field.")
  return end


def _read_length_delimited_field(data: bytes, field_number: int) -> bytes | None:
  """Return the payload of the first occurrence of a length-delimited field.

  Args:
    data: Encoded protobuf message.
    field_number: Field number to look for at the top level of *data*.

  Returns:
    The field's raw payload bytes, or None when the field is absent.

  Raises:
    ValueError: when the field exists with a different wire type, when its
      declared length runs past the end of *data*, or when any preceding
      field is malformed.
  """
  offset = 0
  while offset < len(data):
    key, offset = _read_varint(data, offset)
    current_field = key >> 3
    wire_type = key & 0x7
    if current_field == field_number:
      if wire_type != PROTO_WIRE_LENGTH_DELIMITED:
        raise ValueError(f"Field {field_number} is not length-delimited.")
      length, offset = _read_varint(data, offset)
      end = offset + length
      # Slicing would silently shorten the payload; report truncation instead.
      if end > len(data):
        raise ValueError(f"Field {field_number} is truncated.")
      return data[offset:end]
    offset = _skip_field(data, offset, wire_type)
  return None


def _read_bool_field(data: bytes, field_number: int) -> bool | None:
  """Return the boolean value of a varint field, or None when it is absent.

  The whole message is scanned, so when the field occurs more than once the
  last occurrence wins. Raises ValueError when the field is present with a
  non-varint wire type.
  """
  result = None
  cursor = 0
  total = len(data)
  while cursor < total:
    tag, cursor = _read_varint(data, cursor)
    wire_type = tag & 0x7
    if tag >> 3 != field_number:
      cursor = _skip_field(data, cursor, wire_type)
      continue
    if wire_type != PROTO_WIRE_VARINT:
      raise ValueError(f"Field {field_number} is not varint encoded.")
    raw, cursor = _read_varint(data, cursor)
    result = bool(raw)
  return result


def _read_mesh_wifi_enabled(config_path: Path) -> tuple[bool, int]:
  """Return `(wifi_enabled, mtime_ns)` read from the persisted config file.

  The modification time is sampled before the contents are read. A missing
  network section or a missing wifi_enabled field both yield False.
  """
  mtime_ns = config_path.stat().st_mtime_ns
  payload = config_path.read_bytes()
  network_section = _read_length_delimited_field(payload, CONFIG_NETWORK_FIELD)
  enabled = False
  if network_section is not None:
    enabled = bool(_read_bool_field(network_section, NETWORK_WIFI_ENABLED_FIELD))
  return (enabled, mtime_ns)


def _read_config_mtime_ns(config_path: Path) -> int:
  return config_path.stat().st_mtime_ns


def _endpoint_host(endpoint: str) -> str:
  endpoint = endpoint.strip()
  if endpoint.startswith("["):
    closing = endpoint.find("]")
    if closing != -1:
      return endpoint[1:closing]
  if ":" in endpoint:
    return endpoint.rsplit(":", 1)[0]
  return endpoint


def _has_active_remote_meshtastic_client() -> bool:
  ss_path = shutil.which("ss")
  if not ss_path:
    return False
  result = subprocess.run(
    [ss_path, "-Htn", "sport", "=", f":{MESHTASTIC_PORT}"],
    capture_output=True,
    check=False,
    text=True,
  )
  if result.returncode != 0:
    return False
  for line in result.stdout.splitlines():
    fields = line.split()
    if len(fields) < 5:
      continue
    state = fields[0]
    peer_host = _endpoint_host(fields[4])
    if state == "ESTAB" and peer_host not in ("127.0.0.1", "::1"):
      return True
  return False


class MeshClient:
  """Thin wrapper that connects to the local Meshtastic node on first use."""

  def __init__(self, tcp_interface_cls) -> None:
    """Remember the TCP interface class; no connection is opened yet."""
    self._interface = None
    self._tcp_interface_cls = tcp_interface_cls

  def close(self) -> None:
    """Close the interface if one is open, ignoring errors raised by close."""
    active = self._interface
    if active is not None:
      try:
        active.close()
      except Exception:
        # Best-effort teardown: a failing close must not mask the caller's flow.
        pass
      self._interface = None

  def _open_interface(self):
    """Create, connect and configure a new interface; close it again on failure."""
    candidate = self._tcp_interface_cls(
      MESHTASTIC_HOST,
      timeout=MESHTASTIC_TIMEOUT_SEC,
      connectNow=False,
      noNodes=True,
    )
    try:
      candidate.myConnect()
      candidate.connect()
      candidate.waitForConfig()
    except Exception:
      # Release the half-open interface, then surface the original error.
      try:
        candidate.close()
      except Exception:
        pass
      raise
    return candidate

  def _node(self):
    """Return the interface's local node, connecting first when necessary."""
    if self._interface is None:
      self._interface = self._open_interface()
    return self._interface.localNode

  def connect(self) -> None:
    """Ensure the connection to the node is established."""
    self._node()

  def set_wifi_enabled(self, enabled: bool) -> None:
    """Set network.wifi_enabled on the node and write the network config.

    Disabling wraps the write in a settings transaction; enabling writes
    directly. Afterwards the call sleeps for REQUEST_SETTLE_SEC.
    """
    local_node = self._node()
    local_node.localConfig.network.wifi_enabled = enabled
    use_transaction = not enabled
    if use_transaction:
      local_node.beginSettingsTransaction()
    local_node.writeConfig("network")
    if use_transaction:
      local_node.commitSettingsTransaction()
    time.sleep(REQUEST_SETTLE_SEC)


def _set_mesh_wifi_enabled(enabled: bool) -> None:
  """Write the Wi-Fi enabled flag to the Meshtastic node, then always disconnect."""
  tcp_interface_cls = _load_tcp_interface()
  client = MeshClient(tcp_interface_cls)
  try:
    client.set_wifi_enabled(enabled)
  finally:
    client.close()


def _run_mesh_write_helper(enabled: bool, debug: bool) -> None:
  """Re-run this script as a child process to perform the Meshtastic write.

  The child is started with `--set-mesh-enabled true|false` (plus `--debug`
  when requested). A non-zero exit raises subprocess.CalledProcessError.
  """
  script_path = os.fspath(Path(__file__).resolve())
  flag_value = "true" if enabled else "false"
  argv = [sys.executable, script_path, "--set-mesh-enabled", flag_value]
  if debug:
    argv += ["--debug"]
  subprocess.run(argv, check=True)


def _apply_mesh_state_locally(rfkill_path: Path | None, mesh_enabled: bool) -> bool | None:
  """Make the local rfkill soft state match the Meshtastic Wi-Fi setting.

  Returns:
    None when there is no rfkill device; otherwise the local enabled state
    after the attempt. A hard-blocked device is left untouched and its
    current state is returned.
  """
  if rfkill_path is None:
    LOG.info("No Wi-Fi rfkill device is available; mesh Wi-Fi stays local-config only.")
    return None

  currently_enabled, is_hard_blocked = _read_rfkill_state(rfkill_path)
  if is_hard_blocked:
    LOG.warning(
      "Wi-Fi is hard blocked on %s; cannot apply mesh Wi-Fi %s.",
      rfkill_path.name,
      _state_label(mesh_enabled),
    )
    return currently_enabled

  if currently_enabled == mesh_enabled:
    # Already in the requested state; nothing to write.
    return currently_enabled

  _write_rfkill_enabled(rfkill_path, mesh_enabled)
  LOG.info(
    "Applied mesh Wi-Fi %s to %s.",
    _state_label(mesh_enabled),
    rfkill_path.name,
  )
  # Re-read so the caller sees what the device actually reports after the write.
  return _read_rfkill_state(rfkill_path)[0]


def _run(args: argparse.Namespace) -> int:
  """Run the sync loop between the Meshtastic Wi-Fi config and local rfkill.

  Each pass: re-discover the Wi-Fi rfkill device, read the Meshtastic state on
  the first pass (and apply it locally), detect local rfkill changes, detect
  config-file changes via its mtime, and apply any deferred mesh change once
  its delay has elapsed and no remote client is connected.

  Args:
    args: Parsed options; reads config_path, once, debug, mesh_apply_delay and
      poll_interval.

  Returns:
    0 after the first apply when `args.once` is set. Otherwise the loop has no
    normal exit and ends only when an exception propagates.
  """
  # Last known Meshtastic Wi-Fi state and the config mtime it was read at.
  mesh_state: bool | None = None
  mesh_config_mtime_ns: int | None = None
  # Last observed local rfkill enabled state (None until first read).
  local_state: bool | None = None
  # A mesh change waiting to be applied locally, and when it becomes due.
  pending_mesh_state: bool | None = None
  pending_mesh_deadline: float | None = None
  # Ensures the "waiting for disconnect" message is logged once per pending change.
  pending_mesh_wait_logged = False
  rfkill_path: Path | None = None

  while True:
    # Re-discover the rfkill device every pass so appearing/disappearing
    # devices are noticed.
    current_rfkill = _find_wifi_rfkill()
    if current_rfkill != rfkill_path:
      rfkill_path = current_rfkill
      local_state = None
      if rfkill_path is None:
        LOG.info("No Wi-Fi rfkill device found.")
      else:
        LOG.info("Using Wi-Fi rfkill device %s.", rfkill_path.name)
        if mesh_state is not None:
          # A newly found device immediately receives the known mesh state.
          local_state = _apply_mesh_state_locally(rfkill_path, mesh_state)

    if mesh_state is None:
      # First pass: errors from this initial read are not caught here and
      # propagate to the caller.
      mesh_state, mesh_config_mtime_ns = _read_mesh_wifi_enabled(args.config_path)
      LOG.info("Meshtastic Wi-Fi is %s.", _state_label(mesh_state))
      local_state = _apply_mesh_state_locally(rfkill_path, mesh_state)
      if args.once:
        return 0

    if rfkill_path is not None:
      current_local_state, hard_blocked = _read_rfkill_state(rfkill_path)
      if local_state is None:
        # No baseline yet; record the current state without treating it as a change.
        local_state = current_local_state
      elif current_local_state != local_state:
        previous_local_state = local_state
        local_state = current_local_state
        LOG.info(
          "Local Wi-Fi changed from %s to %s.",
          _state_label(previous_local_state),
          _state_label(local_state),
        )
        if hard_blocked:
          LOG.info("Ignoring local Wi-Fi change while hardware is hard blocked.")
        elif pending_mesh_state is not None:
          LOG.info("Skipping local Wi-Fi write-back while mesh apply is pending.")
        elif mesh_state is None:
          # NOTE(review): mesh_state appears to always be set by the block
          # above before this point; this branch looks unreachable. Confirm.
          LOG.info("Skipping mesh update because Meshtastic Wi-Fi state is unavailable.")
        elif local_state != mesh_state and MESH_WRITEBACK_ENABLED:
          # The write runs in a child process (this script with --set-mesh-enabled).
          _run_mesh_write_helper(local_state, args.debug)
          mesh_state = local_state
          LOG.info(
            "Updated Meshtastic Wi-Fi to %s from local rfkill state.",
            _state_label(mesh_state),
          )
        elif local_state != mesh_state:
          LOG.info("Skipping local Wi-Fi write-back; Meshtastic writes are disabled.")

    try:
      current_mesh_mtime_ns = _read_config_mtime_ns(args.config_path)
    except OSError as exc:
      LOG.warning("Skipping Meshtastic config check: %s", exc)
    else:
      # Only re-parse the config when its modification time has changed.
      if current_mesh_mtime_ns != mesh_config_mtime_ns:
        try:
          # On failure mesh_config_mtime_ns keeps its old value, so the read
          # is attempted again on the next pass.
          current_mesh_state, mesh_config_mtime_ns = _read_mesh_wifi_enabled(args.config_path)
        except (OSError, ValueError) as exc:
          LOG.warning("Skipping unreadable Meshtastic config update: %s", exc)
        else:
          if current_mesh_state != mesh_state:
            previous_mesh_state = mesh_state
            mesh_state = current_mesh_state
            LOG.info(
              "Meshtastic Wi-Fi changed from %s to %s.",
              _state_label(previous_mesh_state),
              _state_label(mesh_state),
            )
            # Runtime changes are not applied immediately; schedule them.
            # A newer change replaces any change that is still pending.
            pending_mesh_state = mesh_state
            pending_mesh_deadline = time.monotonic() + args.mesh_apply_delay
            pending_mesh_wait_logged = False
            LOG.info(
              "Deferring mesh Wi-Fi %s apply for %.1fs.",
              _state_label(mesh_state),
              args.mesh_apply_delay,
            )

    if pending_mesh_state is not None and pending_mesh_deadline is not None:
      if time.monotonic() >= pending_mesh_deadline:
        # Once due, hold the change back for as long as a non-local client is
        # connected to the Meshtastic port.
        if _has_active_remote_meshtastic_client():
          if not pending_mesh_wait_logged:
            LOG.info("Waiting for remote Meshtastic client disconnect before applying Wi-Fi.")
            pending_mesh_wait_logged = True
        else:
          local_state = _apply_mesh_state_locally(rfkill_path, pending_mesh_state)
          pending_mesh_state = None
          pending_mesh_deadline = None
          pending_mesh_wait_logged = False

    time.sleep(args.poll_interval)


def _parse_args() -> argparse.Namespace:
  parser = argparse.ArgumentParser(
    description="Keep Meshtastic Wi-Fi config and local rfkill state in sync.",
  )
  parser.add_argument(
    "--once",
    action="store_true",
    help="Apply mesh Wi-Fi config locally once, then exit.",
  )
  parser.add_argument(
    "--poll-interval",
    type=float,
    default=POLL_INTERVAL_SEC,
    help=f"Seconds between local rfkill polls (default: {POLL_INTERVAL_SEC}).",
  )
  parser.add_argument(
    "--mesh-apply-delay",
    type=float,
    default=MESH_APPLY_DELAY_SEC,
    help=(
      "Seconds to wait before applying runtime mesh Wi-Fi changes locally "
      f"(default: {MESH_APPLY_DELAY_SEC})."
    ),
  )
  parser.add_argument(
    "--config-path",
    type=Path,
    default=CONFIG_PROTO_PATH,
    help=f"Meshtastic config.proto path (default: {CONFIG_PROTO_PATH}).",
  )
  parser.add_argument(
    "--debug",
    action="store_true",
    help="Enable debug logging.",
  )
  parser.add_argument(
    "--set-mesh-enabled",
    choices=("true", "false"),
    help=argparse.SUPPRESS,
  )
  return parser.parse_args()


def main() -> int:
  """Configure logging, dispatch to the requested mode and return an exit code.

  Returns 0 on success or on Ctrl-C, and 1 when any other exception escapes.
  """
  options = _parse_args()
  log_level = logging.INFO
  if options.debug:
    log_level = logging.DEBUG
  logging.basicConfig(
    level=log_level,
    format="%(asctime)s %(levelname)s %(message)s",
  )
  try:
    if options.set_mesh_enabled is None:
      return _run(options)
    # Helper mode: perform a single Meshtastic write and exit.
    _set_mesh_wifi_enabled(options.set_mesh_enabled == "true")
    return 0
  except KeyboardInterrupt:
    LOG.info("Stopping.")
    return 0
  except Exception as exc:
    LOG.error("wifisync failed: %s", exc)
    return 1


# Script entry point: exit the process with the status code returned by main().
if __name__ == "__main__":
  raise SystemExit(main())
