#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""flash_samd11_openocd.py

Purpose: Use OpenOCD to erase / program / verify / reset an ATSAMD11 (e.g. Seeed XIAO SAMD11) and optionally open a 9600bps serial monitor.

Features:
    - Supports .bin (with configurable base address) and .hex files.
    - Optional full chip erase (enabled by default; disable via --no-chip-erase).
    - Configurable OpenOCD binary/interface/target script/adapter speed.
    - Modes: erase-only / program-only / verify-only / combined flow.
    - Optional serial monitor after programming.

Dependencies:
    - OpenOCD installed and on PATH (or specify via --openocd-bin; --auto-install-openocd tries a system package manager).
    - Optional pyserial (auto-installed only if serial monitor is requested).
    - intelhex only needed if you plan to add bin->hex conversion logic (not required for direct bin/hex usage here).

Examples:
    python flash_samd11_openocd.py --file samd11.bin
    python flash_samd11_openocd.py --file samd11.bin --base-addr 0x00000000 --serial-port COM7
    python flash_samd11_openocd.py --file app.hex --no-chip-erase
    python flash_samd11_openocd.py --file samd11.bin --openocd-speed 4000 --verify-only
    python flash_samd11_openocd.py --file samd11.bin --erase-only

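Exit codes:
    0 success
    1 argument / file error, or OpenOCD not found (and not auto-installed)
    2 OpenOCD failed (argparse also exits with 2 on malformed command-line options)
    3 serial monitor failed (programming already succeeded)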
"""

import argparse
import subprocess
import sys
import os
import shutil
import logging
import time
from typing import Optional, Tuple

# Root logger configuration: timestamped messages at INFO level and above.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
log = logging.getLogger('samd11_openocd')

# Default load address for .bin images (default of --base-addr).
DEFAULT_BASE_ADDR = 0x00000000

# Default script names (adjust if environment differs)
DEFAULT_INTERFACE = 'cmsis-dap'  # default of --openocd-interface; expanded to interface/<name>.cfg
DEFAULT_TARGET_CFG = 'at91samdXX.cfg'  # default of --openocd-target-cfg; looked up under target/
DEFAULT_SPEED = 1000  # default of --openocd-speed, in kHz


def which(cmd: str) -> Optional[str]:
    """Locate executable *cmd* on PATH; return its full path or None when absent."""
    located = shutil.which(cmd)
    return located


def try_install_openocd(auto: bool) -> Tuple[bool, str]:
    """Attempt to install OpenOCD using whatever package manager is available.

    Candidate installers are tried in order until one leaves an ``openocd``
    (or ``open-ocd``) executable on PATH:

    * Windows: choco -> winget -> scoop
    * macOS: brew -> port (via sudo)
    * Linux / other Unix: apt-get (update, then install) -> dnf -> pacman (via sudo)

    Args:
        auto: Opt-in switch; when False nothing is attempted.

    Returns:
        ``(success, message)``. ``success`` is True only when an install
        command exited with status 0 *and* the executable is now on PATH.
    """
    if not auto:
        return False, "auto-install disabled"

    platform = sys.platform
    candidates = []  # (label, argv) pairs, tried in order
    if platform.startswith('win'):
        if which('choco'):
            candidates.append(('choco', ['choco', 'install', 'openocd', '-y']))
        if which('winget'):
            # winget may prompt for agreements; pass the silent/accept switches
            candidates.append(('winget', ['winget', 'install', '--id', 'OpenOCD.OpenOCD', '--silent', '--accept-package-agreements', '--accept-source-agreements']))
        if which('scoop'):
            candidates.append(('scoop', ['scoop', 'install', 'openocd']))
    elif platform == 'darwin':
        if which('brew'):
            candidates.append(('brew', ['brew', 'install', 'open-ocd']))  # Homebrew formula name is open-ocd
        if which('port'):
            candidates.append(('macports', ['sudo', 'port', 'install', 'openocd']))
    else:
        # Assume Linux / Unix
        if which('apt-get'):
            candidates.append(('apt', ['sudo', 'apt-get', 'update']))
            candidates.append(('apt', ['sudo', 'apt-get', 'install', '-y', 'openocd']))
        if which('dnf'):
            candidates.append(('dnf', ['sudo', 'dnf', 'install', '-y', 'openocd']))
        if which('pacman'):
            # NOTE(review): `-Sy` without `-u` is a partial-upgrade pattern the
            # Arch wiki warns against; consider plain `-S` instead.
            candidates.append(('pacman', ['sudo', 'pacman', '-Sy', '--noconfirm', 'openocd']))

    if not candidates:
        return False, 'No known package manager found for automatic OpenOCD installation.'

    for label, cmd in candidates:
        log.info('Attempting OpenOCD install via %s: %s', label, ' '.join(cmd))
        try:
            # errors='replace': installer output in an unexpected encoding must
            # not abort the attempt with UnicodeDecodeError.
            res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                 text=True, errors='replace')
        except Exception as e:  # best effort: e.g. OSError if the tool cannot be started
            log.warning('Install attempt with %s failed: %s', label, e)
            continue
        if res.returncode != 0:
            # Previously skipped silently; surface the status and output tail.
            tail = (res.stdout or '').strip().splitlines()[-5:]
            log.warning('Install attempt with %s exited with status %s%s', label,
                        res.returncode, (':\n' + '\n'.join(tail)) if tail else '')
            continue
        # Re-check availability after a successful command.
        if which('openocd') or which('open-ocd'):
            return True, f'Installed via {label}'
        # e.g. `apt-get update` succeeds without installing anything.
        log.debug('%s finished but OpenOCD is not on PATH yet', label)
    return False, 'All install attempts failed.'

def build_openocd_base_args(args) -> list:
    """Return the OpenOCD executable followed by its ``-f`` script options.

    ``args.openocd_interface`` is a bare name expanded to
    ``interface/<name>.cfg``; ``args.openocd_target_cfg`` is a file name
    placed under ``target/``.
    """
    interface_script = 'interface/{}.cfg'.format(args.openocd_interface)
    target_script = 'target/{}'.format(args.openocd_target_cfg)
    argv = [args.openocd_bin]
    argv += ['-f', interface_script]
    argv += ['-f', target_script]
    return argv


def build_openocd_command_sequence(args) -> list:
    """Build the ordered list of OpenOCD ``-c`` commands for the selected mode.

    Attributes read from ``args``: ``file``, ``base_addr``, ``openocd_speed``,
    ``erase_only``, ``program_only``, ``verify_only``, ``no_chip_erase``,
    ``skip_erase``, ``no_program`` and ``reset_after``.

    Flow: ``adapter speed`` -> ``init`` -> ``halt`` -> optional
    ``at91samd chip-erase`` -> optional ``program <image> verify`` (or, in
    verify-only mode, a read-only ``verify_image``) -> optional
    ``reset run`` -> ``exit``.

    ``.bin`` images get ``args.base_addr`` as their load offset; any other
    extension is passed without an offset, so OpenOCD uses the addresses
    stored in the file.
    """
    is_bin = os.path.splitext(args.file)[1].lower() == '.bin'
    # Commands are parsed as Tcl: brace-quote the path so spaces and
    # backslashes (e.g. C:\fw\app.bin) are passed verbatim, and prefer forward
    # slashes on Windows.
    path = args.file.replace(os.sep, '/') if os.sep != '/' else args.file
    image = '{' + path + '}'
    offset = f'0x{args.base_addr:08X}' if is_bin else ''

    cmds = [f'adapter speed {args.openocd_speed}', 'init', 'halt']

    # Full chip erase: always for --erase-only; otherwise only in the default
    # combined flow. --program-only relies on `program` erasing what it writes.
    if args.erase_only or not (args.verify_only or args.program_only
                               or args.no_chip_erase or args.skip_erase):
        cmds.append('at91samd chip-erase')

    reset_queued = False
    wants_program = args.program_only or not (args.erase_only or args.verify_only)
    if wants_program and not args.no_program:
        words = ['program', image, 'verify']
        if args.reset_after:
            words.append('reset')
            reset_queued = True
        if is_bin:
            words.append(offset)
        cmds.append(' '.join(words))

    if args.verify_only:
        # verify_image compares flash against the file without writing;
        # re-running `program` here would re-flash the device.
        cmds.append(f'verify_image {image} {offset} bin' if is_bin else f'verify_image {image}')

    # Track the reset explicitly: searching command text for "reset" would be
    # fooled by file names such as factory_reset.bin.
    if args.reset_after and not reset_queued:
        cmds.append('reset run')

    cmds.append('exit')
    return cmds


def run_openocd(args) -> int:
    """Run OpenOCD and show a coarse, stage-based progress line.

    The command line comes from ``build_openocd_base_args`` plus one ``-c``
    per entry of ``build_openocd_command_sequence``. OpenOCD's merged
    stdout/stderr is streamed to the log, and a ``[Progress]`` line on stdout
    advances when a stage-completion message is recognised.

    Stage weights: erase 20%, program 60%, verify 15%, reset 5%. Stages not
    part of this run are dropped and the rest renormalised. The bar is only
    forced to 100% when OpenOCD exits with status 0.

    Returns:
        0 on success; 2 if OpenOCD could not be started or exited non-zero.
    """
    base_args = build_openocd_base_args(args)
    seq = build_openocd_command_sequence(args)
    full_cmd = list(base_args)
    for c in seq:
        full_cmd.extend(['-c', c])
    log.info('OpenOCD command: %s', ' '.join(full_cmd))

    # Which stages this run contains.
    stage_present = {
        'erase': any('chip-erase' in s for s in seq),
        'program': any(s.startswith('program ') for s in seq),
        'verify': any((s.startswith('program ') and ' verify' in s) or s.startswith('verify_image ')
                      for s in seq),
        # Use the flag rather than searching command text, which would also
        # match file names containing "reset".
        'reset': bool(args.reset_after),
    }

    # Lower-cased OpenOCD output fragments that mark each stage as complete,
    # checked in this order; at most one stage advances per output line.
    # `at91samd chip-erase` reports "chip erase started"; verify_image reports
    # "verified N bytes in ..."; `program` echoes "** Verified OK **".
    checks = (
        ('erase', lambda t: 'chip erase' in t and any(w in t for w in ('started', 'finished', 'done', 'success'))),
        ('program', lambda t: 'programming finished' in t or 'program completed' in t),
        ('verify', lambda t: 'verified ok' in t or 'verify successful' in t or ('verified ' in t and ' bytes' in t)),
        ('reset', lambda t: 'resetting target' in t or 'reset run' in t),
    )

    base_weights = {'erase': 0.20, 'program': 0.60, 'verify': 0.15, 'reset': 0.05}
    active = [k for k in base_weights if stage_present[k]]
    total_weight = sum(base_weights[k] for k in active)
    weights = {k: base_weights[k] / total_weight for k in active}  # empty if no stage detected
    finished = set()
    progress = 0.0

    def emit():
        pct = min(100.0, progress * 100.0)
        sys.stdout.write(f"\r[Progress] {pct:6.2f}%")
        sys.stdout.flush()

    try:
        proc = subprocess.Popen(full_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                text=True, errors='replace', bufsize=1)
    except FileNotFoundError:
        log.error('OpenOCD executable not found. Install it or set --openocd-bin with full path.')
        return 2
    except OSError as exc:  # e.g. PermissionError: path exists but is not executable
        log.error('Could not start OpenOCD: %s', exc)
        return 2

    emit()
    try:
        for line in proc.stdout:  # type: ignore
            text = line.rstrip('\n')
            if text:
                log.info(text)
            lower = text.lower()
            for stage, matches in checks:
                if stage in weights and stage not in finished and matches(lower):
                    progress += weights[stage]
                    finished.add(stage)
                    emit()
                    break
        proc.wait()
    except BaseException:
        # Ctrl+C or an unexpected error: do not leave OpenOCD running.
        proc.kill()
        proc.wait()
        raise
    finally:
        # Only claim completion when OpenOCD actually succeeded.
        if proc.returncode == 0:
            progress = 1.0
        emit()
        sys.stdout.write('\n')
        if proc.stdout is not None:
            proc.stdout.close()

    if proc.returncode != 0:
        log.error('OpenOCD failed, return code=%s', proc.returncode)
        return 2
    return 0


def auto_pick_serial_port() -> Optional[str]:
    """Return the device name of the single serial port present, else None.

    None is also returned when pyserial cannot be imported, or when zero or
    several ports are enumerated (an automatic choice would be ambiguous).
    """
    try:
        from serial.tools import list_ports  # type: ignore
    except Exception:
        return None
    found = list(list_ports.comports())
    if len(found) != 1:
        return None
    (only_port,) = found
    return only_port.device


def open_serial_monitor(port: str, baud: int = 9600):
    """Print lines received on serial *port* until Ctrl+C is pressed.

    pyserial is installed on demand with ``pip`` if it cannot be imported.
    Incoming bytes are split on ``\\n``. A trailing ``\\r`` is dropped so
    CRLF-terminated output prints cleanly, and undecodable bytes are replaced.
    Any partial line still buffered when the monitor stops is printed before
    returning.

    Raises:
        subprocess.CalledProcessError: if installing pyserial fails.
        Exception: errors opening/reading the port are logged and re-raised.
    """
    try:
        import serial  # type: ignore
    except ImportError:
        import importlib
        log.info('Installing pyserial for serial monitor...')
        subprocess.run([sys.executable, '-m', 'pip', 'install', 'pyserial'], check=True)
        # The import system caches directory listings; refresh them so the
        # package installed a moment ago is visible to this process.
        importlib.invalidate_caches()
        import serial  # type: ignore

    def show(raw: bytes) -> None:
        raw = raw.rstrip(b'\r')
        try:
            print(raw.decode(errors='replace'))
        except Exception:
            # e.g. UnicodeEncodeError on a console that cannot represent the
            # decoded text: fall back to the bytes repr.
            print(raw)

    buf = bytearray()
    try:
        with serial.Serial(port, baudrate=baud, timeout=0.2) as ser:
            log.info('[Serial %s %sbps] Press Ctrl+C to exit', port, baud)
            while True:
                data = ser.read(256)
                if data:
                    buf.extend(data)
                    while b'\n' in buf:
                        line, _, rest = buf.partition(b'\n')
                        buf = bytearray(rest)
                        show(bytes(line))
                time.sleep(0.02)
    except KeyboardInterrupt:
        log.info('Serial monitor stopped')
    except Exception as e:
        log.error('Serial monitor error: %s', e)
        raise
    finally:
        # Don't lose output that arrived without a terminating newline.
        if buf:
            show(bytes(buf))


def parse_address(text: str) -> int:
    """argparse ``type=`` converter for ``--base-addr``.

    Accepts any ``int(text, 0)`` literal (decimal, ``0x``, ``0o``, ``0b``) and
    rejects values outside the 32-bit address space. Such values would
    otherwise be formatted into a malformed OpenOCD offset like ``0x-0000001``.

    Raises:
        argparse.ArgumentTypeError: for malformed or out-of-range input.
    """
    try:
        value = int(text, 0)
    except ValueError:
        raise argparse.ArgumentTypeError(f'invalid address: {text!r}') from None
    if not 0 <= value <= 0xFFFFFFFF:
        raise argparse.ArgumentTypeError(f'address out of 32-bit range: {text!r}')
    return value


def parse_args():
    """Parse the command line and return the resulting ``argparse.Namespace``."""
    p = argparse.ArgumentParser(description='OpenOCD flashing utility for ATSAMD11 (XIAO SAMD11)')
    p.add_argument('--file', required=True, help='Input .bin or .hex file to flash')
    p.add_argument('--base-addr', type=parse_address, default=DEFAULT_BASE_ADDR, help='Base address for .bin (default 0x00000000)')
    p.add_argument('--openocd-bin', default='openocd', help='OpenOCD executable name or absolute path')
    p.add_argument('--openocd-interface', default=DEFAULT_INTERFACE, help='Interface script (without path/.cfg)')
    p.add_argument('--openocd-target-cfg', default=DEFAULT_TARGET_CFG, help='Target cfg file name')
    p.add_argument('--openocd-speed', type=int, default=DEFAULT_SPEED, help='SWD/JTAG speed kHz (default 1000)')
    p.add_argument('--no-chip-erase', action='store_true', help='Skip explicit chip erase (program will erase required sectors)')
    p.add_argument('--skip-erase', action='store_true', help='Skip any erase stage completely (dangerous)')
    p.add_argument('--erase-only', action='store_true', help='Only erase then exit')
    p.add_argument('--program-only', action='store_true', help='Only program (erase as needed)')
    p.add_argument('--verify-only', action='store_true', help='Only verify image')
    p.add_argument('--no-program', action='store_true', help='Do not program (combine with erase-only)')
    p.add_argument('--reset-after', action='store_true', help='Reset and run after operations')
    p.add_argument('--serial-port', help='Open 9600bps serial monitor after programming (auto-pick if single port)')
    p.add_argument('--no-serial-auto', action='store_true', help='Disable auto serial pick (monitor only if --serial-port provided)')
    p.add_argument('--auto-install-openocd', action='store_true', help='Attempt to automatically install OpenOCD if not found')
    return p.parse_args()


def main():
    """CLI entry point: validate inputs, run OpenOCD, then optionally open a serial monitor.

    Exit status: 0 on success; 1 for conflicting options, a missing input
    file, or an unavailable OpenOCD; 2 when OpenOCD fails; 3 when the serial
    monitor fails.
    """
    args = parse_args()

    # Cheap argument checks first, so a bad invocation never triggers a
    # package-manager install.
    if sum([args.erase_only, args.verify_only, args.program_only]) > 1:
        log.error('Only one of erase-only / verify-only / program-only may be specified')
        sys.exit(1)

    if not os.path.isfile(args.file):
        log.error('Input file does not exist or is not a regular file: %s', args.file)
        sys.exit(1)

    if which(args.openocd_bin) is None and not os.path.isfile(args.openocd_bin):
        log.warning('OpenOCD not found at provided path/name: %s', args.openocd_bin)
        success, msg = try_install_openocd(args.auto_install_openocd)
        if not success:
            log.error('OpenOCD still not available. %s', msg)
            sys.exit(1)
        log.info('Auto-install success: %s', msg)
        # The requested name/path was unusable; run the binary the package
        # manager just placed on PATH instead.
        args.openocd_bin = which('openocd') or which('open-ocd') or args.openocd_bin

    rc = run_openocd(args)
    if rc != 0:
        sys.exit(rc)

    # Serial monitor: an explicit port wins; otherwise auto-pick unless disabled.
    if args.serial_port:
        sel_port = args.serial_port
    elif args.no_serial_auto:
        sel_port = None
    else:
        sel_port = auto_pick_serial_port()

    if sel_port:
        if args.reset_after:
            time.sleep(1.0)  # allow enumeration delay after reset
        try:
            open_serial_monitor(sel_port, 9600)
        except Exception as exc:
            log.error('Serial monitor failed: %s', exc)
            sys.exit(3)

    sys.exit(0)

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
