import os, sys, traceback
sys.path.append('/nix/store/qfv2ayja1zgrfbiy1nrd9f0b1y759h60-python3-3.11.6-env/lib/python3.11/site-packages')

import time
import select
import logging
import threading
import ipaddress
import pickle
import pandas as pd
import numpy as np
import hashlib
import bfrt_grpc.client as gc
import joblib
# from sklearn.ensemble import RandomForestClassifier
from tensorflow.keras.models import load_model
import tensorflow as tf


# ========== Logging Setup ==========
# Dedicated file logger for this tool. Any handlers already attached (e.g. when
# the module is executed again in the same interpreter) are removed and closed
# first so records are not written twice and old file handles are released.
logger = logging.getLogger('DigestListener')
logger.setLevel(logging.DEBUG)
for h in list(logger.handlers):
    logger.removeHandler(h)
    h.close()
log_handler = logging.FileHandler('digest_listener.log', mode='w')
log_handler.setLevel(logging.DEBUG)
# Without a formatter, records would contain only the bare message; include a
# timestamp, level and thread name so entries can be correlated.
log_handler.setFormatter(
    logging.Formatter("%(asctime)s %(levelname)s [%(threadName)s] %(message)s")
)
logger.addHandler(log_handler)


# ===== Autoencoder assets loader (minimal, fail-fast) =====

AUTOENCODER_PATH = '/home/fabric/DDoS_Detection_wP4/model/autoencoder_ddos.h5'
SCALER_PATH      = '/home/fabric/DDoS_Detection_wP4/model/scaler_ddos.pkl'
THRESH_PATH      = '/home/fabric/DDoS_Detection_wP4/model/threshold.txt'

def load_ae_assets(model_path=AUTOENCODER_PATH,
                   scaler_path=SCALER_PATH,
                   thresh_path=THRESH_PATH):
    """Load the autoencoder, its input scaler, and the anomaly threshold.

    All three paths are checked for existence and readability first. The
    threshold is then parsed and validated. Only after that are joblib and
    TensorFlow imported and used, so cheap configuration errors surface
    without paying for a TensorFlow import.

    Args:
        model_path: Keras model file, loaded with ``compile=False``.
        scaler_path: joblib-serialized scaler.
        thresh_path: text file containing a single float.

    Returns:
        ``(model, scaler, threshold)``.

    Raises:
        FileNotFoundError: a path does not exist.
        PermissionError: a path exists but is not readable.
        ValueError: the threshold text is not a finite, non-negative number.
    """
    import math
    import os

    def must_readable(path, label):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} not found: {path} (cwd={os.getcwd()})")
        if not os.access(path, os.R_OK):
            raise PermissionError(f"{label} not readable: {path}")

    must_readable(model_path,  "Autoencoder (.h5)")
    must_readable(scaler_path, "Scaler (.pkl)")
    must_readable(thresh_path, "Threshold (.txt)")

    with open(thresh_path) as f:
        raw = f.read().strip()
    try:
        threshold = float(raw)
    except ValueError:
        raise ValueError(
            f"Threshold (.txt) is not a number: {raw!r} ({thresh_path})"
        ) from None
    # Reconstruction MSE is never negative. A NaN/inf threshold would make
    # `mse > threshold` always False (every flow "Normal"). A negative one would
    # flag every flow. Both indicate a broken threshold file, so fail fast.
    if not math.isfinite(threshold) or threshold < 0:
        raise ValueError(
            f"Threshold (.txt) must be a finite, non-negative number, "
            f"got {threshold!r} ({thresh_path})"
        )

    # Deferred heavy imports: only needed once the cheap checks above pass.
    import joblib
    import tensorflow as tf

    # joblib.load unpickles the file; only point this at trusted model artifacts.
    scaler = joblib.load(scaler_path)                               # needs scikit-learn
    model  = tf.keras.models.load_model(model_path, compile=False)  # h5py required

    return model, scaler, threshold

# Load the autoencoder model, scaler and threshold once, at module import.
# Without them flows cannot be scored, so any failure is logged and the
# process exits with status 1.
try:
    ae_model, scaler, THRESHOLD = load_ae_assets()
except Exception as exc:
    logger.exception("Failed to load AE assets")
    print("FATAL: failed to load autoencoder assets:", exc)
    sys.exit(1)
else:
    logger.info("Loaded AE assets. TF=%s, threshold=%.6f", tf.__version__, THRESHOLD)


# ========== Autoencoder Input Features ==========
# Per-destination aggregate columns selected from the aggregated frame before
# preprocessing/scaling. The list order fixes the column order handed to the
# scaler, so it must match the order used when the scaler was fitted.
num_features = [
    "avg_queue_occupancy",
    "std_queue_occupancy",
    "avg_packet_length",
    "std_packet_length",
    "avg_inter_arrival",
    "std_inter_arrival",
    "total_packets",
    "total_packet_length",
    "unique_src_addrs",
]



# ========== Constants ==========
HOST_IP = "localhost"                      # gRPC host; gc_connect appends port 50052
REGISTER_COUNTER = 'r_counter'             # register scanned for active indices
OTHER_REGISTERS = ['r_srcAddr', 'r_dstAddr']
VERBOSE = True                             # default for log()
BATCH_SIZE = 1_000                         # register indices read per batched entry_get
CACHE_FILE = 'non_zero_indices.pkl'        # not referenced in the code visible here
CACHE_REFRESH_INTERVAL = 10 * 60           # seconds; not referenced in the code visible here
last_cache_update = 0
TIMESTAMP_MAX = 1 << 32                    # presumably the ingress-timestamp wrap modulus
LISTENER_TIMEOUT = 10                      # seconds without a digest before the listener stops
MAX_ENTRIES = 200_000                      # cap on buffered digest entries
MAX_SEEN = 500_000                         # cap on the dedup hash set
CAPTURE_DEFAULT_SECONDS = 30
PROGRESS_STEP_SECONDS = 5

# ========== Globals ==========
digest_entries = []                        # decoded digest dicts from the last capture
digest_listener_running = False
table_name = "SwitchIngress.control_digest"
action_name = "SwitchIngress.enable_digest"
menu_input_lock = threading.Lock()         # serializes interactive input() prompts
seen_entries = set()                       # sha256 hashes of already-recorded entries

# ========== Utility Functions ==========
def log(msg, verbose=VERBOSE):
    """Print ``msg`` to stdout unless ``verbose`` is falsy (default: VERBOSE)."""
    if not verbose:
        return
    print(msg)

def safe_int(value):
    """Coerce a digest field to an int where possible.

    bytes are decoded big-endian, ints (including bools) become plain ints,
    and None becomes 0. Any other value is returned unchanged.
    """
    if isinstance(value, bytes):
        return int.from_bytes(value, 'big')
    if isinstance(value, int):
        return int(value)
    return 0 if value is None else value

def silent_input(timeout=1):
    """Return one stripped stdin line if input arrives within ``timeout`` seconds, else None."""
    readable = select.select([sys.stdin], [], [], timeout)[0]
    return sys.stdin.readline().strip() if readable else None

def drain_stdin():
    """Discard any input already waiting on stdin, without blocking.

    Stops when stdin has nothing immediately readable, when it reaches EOF
    (``readline()`` returns ``''``), or when reading raises.
    """
    while select.select([sys.stdin], [], [], 0)[0]:
        try:
            line = sys.stdin.readline()
        except Exception:
            break
        if line == "":
            # EOF: select() keeps reporting a closed/exhausted stream as readable,
            # so without this check the loop would spin forever.
            break

def ipv4_to_int(ipv4_str):
    """Parse a dotted-quad IPv4 string into its 32-bit integer value.

    Raises:
        ValueError: "Invalid IPv4 address: <input>" when parsing fails.
    """
    try:
        addr = ipaddress.IPv4Address(ipv4_str)
    except ValueError:
        raise ValueError(f"Invalid IPv4 address: {ipv4_str}")
    return int(addr)

def convert_ip(ip_int):
    """Render the low 32 bits of an integer as dotted-quad, most significant octet first."""
    octets = [(ip_int >> shift) & 0xFF for shift in (24, 16, 8, 0)]
    return ".".join(map(str, octets))

def flush_digest_queue(
    interface,
    per_call_timeout=0.05,     # how long each poll waits
    idle_grace=0.25,           # minimum continuous idle time before declaring "drained"
    idle_streak_required=8,    # minimum consecutive empty polls before declaring "drained"
    max_seconds=6.0,           # hard wall-clock cap (increase if backlog is massive)
    max_batches=None           # optional absolute cap on batches; None = unlimited (subject to max_seconds)
):
    """
    Drain pending digest messages AFTER deleting table entries.

    Returns (drained_count, idle_reached: bool).

    - The queue counts as 'drained' once BOTH hold: at least
      `idle_streak_required` consecutive polls came back empty, AND that empty
      run has lasted at least `idle_grace` seconds (measured from the start of
      its first poll). With the defaults, 8 polls of ~0.05 s already exceed the
      0.25 s grace. The time check matters when digest_get fails without
      waiting, which would otherwise fill the streak instantly.
    - Returns idle_reached=False when `max_batches` batches were drained or
      `max_seconds` elapsed first.
    - Any exception from digest_get counts as an empty poll.
    """
    drained = 0
    deadline = time.monotonic() + max_seconds
    idle_streak = 0
    idle_since = None

    while True:
        poll_started = time.monotonic()
        if poll_started >= deadline:
            # Hard cap: stop even under continuous digest production.
            return drained, False
        try:
            interface.digest_get(timeout=per_call_timeout)
        except Exception:
            # Assumed to signal "nothing arrived within per_call_timeout".
            # Every exception is treated that way, as before.
            if idle_streak == 0:
                idle_since = poll_started
            idle_streak += 1
            if (idle_streak >= idle_streak_required
                    and time.monotonic() - idle_since >= idle_grace):
                return drained, True
            continue
        drained += 1
        idle_streak = 0
        if max_batches is not None and drained >= max_batches:
            return drained, False


# ========== gRPC Setup ==========
def gc_connect():
    """Open a BF-Runtime gRPC session to bf_switchd and check the digest table.

    Tries client ids 0-9 against ``f"{HOST_IP}:50052"`` (device 0) and uses
    the first one that connects. The process exits only if every id fails.
    Only client id 0 binds the "cNdAmlight" pipeline config (NOTE(review):
    presumably other ids attach to an already-bound program; confirm).
    If the module-level ``table_name`` is not found in the program, the user is
    prompted for replacement table/action names, which update the module-level
    ``table_name`` / ``action_name``.

    Returns:
        ``(interface, dev_tgt, bfrt_info)``. Calls ``sys.exit(1)`` on failure.
    """
    global table_name, action_name
    interface = None
    for bfrt_client_id in range(10):
        try:
            interface = gc.ClientInterface(
                grpc_addr=f"{HOST_IP}:50052",
                client_id=bfrt_client_id,
                device_id=0,
                num_tries=1,
            )
            break
        except Exception as e:
            # Move on to the next client id (e.g. this one may already be taken).
            logger.error("Failed to connect to gRPC server (client_id %d): %s", bfrt_client_id, str(e))
    else:
        # Every client id failed.
        print("Failed to connect to gRPC server. Ensure bf_switchd is running and no other clients are connected.")
        sys.exit(1)

    dev_tgt = gc.Target(device_id=0, pipe_id=0xffff)
    try:
        bfrt_info = interface.bfrt_info_get()
        if bfrt_client_id == 0:
            interface.bind_pipeline_config("cNdAmlight")
        logger.debug("Connected to gRPC server with client_id %d, p4_name: %s", bfrt_client_id, bfrt_info.p4_name_get())
        print(f"P4 program: {bfrt_info.p4_name_get()}")
        try:
            bfrt_info.table_get(table_name)
            logger.debug("Table %s found", table_name)
            print(f"Table {table_name} found")
        except Exception as e:
            logger.error("Table %s not found: %s", table_name, str(e))
            print(f"Table {table_name} not found: {e}")
            table_name = input("Enter the correct table name from cNdAmlight.p4 (or press Enter to use 'SwitchIngress.control_digest'): ").strip() or table_name
            action_name = input("Enter the correct action name from cNdAmlight.p4 (or press Enter to use 'SwitchIngress.enable_digest'): ").strip() or action_name
        return interface, dev_tgt, bfrt_info
    except Exception as e:
        logger.error("Failed to get bfrt_info or bind pipeline: %s", str(e))
        print(f"Failed to initialize P4 program: {e}")
        sys.exit(1)

# ========== Digest Functions ==========
def show_digest_entries(interface, dev_tgt, bfrt_info, table_name=None):
    """Print and return the destination-IP keys installed in the digest table.

    Args:
        table_name: table to read. ``None`` (default) uses the module-level
            ``table_name`` at call time, so a name corrected in gc_connect()
            is honored. Previously the default was frozen at definition time.

    Returns:
        A list of ``(dst_ip_str, data_dict)`` tuples, or ``[]`` if the table
        could not be read (the error is printed and logged).
    """
    if table_name is None:
        # The parameter shadows the global, so read the module attribute explicitly.
        table_name = globals()["table_name"]
    banner = "=" * 50
    try:
        table = bfrt_info.table_get(table_name)
        found = []
        print(banner)
        print(f"=== ENTRIES IN TABLE: {table_name} ===")
        for data, key in table.entry_get(dev_tgt, [], {"from_hw": True}):
            ip_value = key.to_dict()["hdr.inner_ipv4.dst_addr"]["value"]
            ip_text = str(ipaddress.IPv4Address(ip_value))
            found.append((ip_text, data.to_dict()))
            print(f"\tEntry key (dst IP): {ip_text}")
        if not found:
            print(f"No entries found in table: {table_name}")
        print(banner)
        print()
        logger.debug("Displayed %s entries: %s", table_name, found)
        return found
    except Exception as e:
        print(f"Error accessing table {table_name}: {e}")
        logger.error("Error accessing table %s: %s", table_name, str(e))
        return []

def _digest_entries_to_frame(entries):
    """Convert raw digest dicts into a DataFrame with one row per digest.

    Columns: dst_addr / src_addr (dotted-quad strings, or None when absent),
    queue_occupancy, packet_length, ingress_timestamp (each passed through
    safe_int). Entries whose addresses ipaddress.IPv4Address rejects are skipped.
    """
    rows = []
    for parsed in entries:
        dst_addr = parsed.get("dst_addr")
        src_addr = parsed.get("src_addr")
        try:
            dst_addr_str = str(ipaddress.IPv4Address(dst_addr)) if dst_addr is not None else None
            src_addr_str = str(ipaddress.IPv4Address(src_addr)) if src_addr is not None else None
        except Exception:
            continue
        rows.append({
            "dst_addr": dst_addr_str,
            "src_addr": src_addr_str,
            "queue_occupancy": safe_int(parsed.get("queue_occupancy", 0)),
            "packet_length": safe_int(parsed.get("packet_length", 0)),
            "ingress_timestamp": safe_int(parsed.get("ingress_timestamp", 0)),
        })
    return pd.DataFrame(rows)


def _aggregate_per_destination(df):
    """Build the per-destination feature table (rounded to 2 decimals).

    Inter-arrival time is the ingress_timestamp difference between consecutive
    rows of the same dst_addr after sorting. Rows with a negative difference
    (wrap/reorder artifacts) are dropped. The first row of each destination
    gets 0. Missing standard deviations (single-row groups) become 0.
    """
    # Ensure correct order for diff per destination
    df = df.sort_values(['dst_addr', 'ingress_timestamp'], ascending=[True, True]).copy()
    df['inter_arrival_time'] = df.groupby('dst_addr')['ingress_timestamp'].diff()

    neg_mask = df['inter_arrival_time'] < 0
    dropped = int(neg_mask.sum())
    if dropped:
        logger.debug("Dropped %d digest rows with negative inter-arrival time", dropped)
    df = df.loc[~neg_mask].copy()
    df['inter_arrival_time'] = df['inter_arrival_time'].fillna(0)

    agg_df = df.groupby('dst_addr').agg(
        avg_queue_occupancy=('queue_occupancy', 'mean'),
        std_queue_occupancy=('queue_occupancy', 'std'),
        avg_packet_length=('packet_length', 'mean'),
        std_packet_length=('packet_length', 'std'),
        avg_inter_arrival=('inter_arrival_time', 'mean'),
        std_inter_arrival=('inter_arrival_time', 'std'),
        total_packets=('dst_addr', 'size'),
        total_packet_length=('packet_length', 'sum'),
        unique_src_addrs=('src_addr', 'nunique')
    ).reset_index()
    for col in ['std_queue_occupancy', 'std_packet_length', 'std_inter_arrival']:
        agg_df[col] = agg_df[col].fillna(0)
    return agg_df.round(2)


def _predict_with_autoencoder(agg_df):
    """Add reconstruction_error/prediction columns to agg_df in place and return it.

    Applies the preprocessing used here (log1p on skewed columns, which replaces
    them), scales with the loaded scaler, and labels rows whose reconstruction
    MSE exceeds THRESHOLD as "Attack flow".
    """
    if ae_model is not None and scaler is not None and not agg_df.empty:
        try:
            X = agg_df[num_features].fillna(0)

            # log1p transform for skewed cols (the "_log" columns replace the originals)
            skew_cols = ["avg_inter_arrival", "std_inter_arrival",
                         "total_packets", "total_packet_length"]
            for col in skew_cols:
                if col in X.columns:
                    X[col + "_log"] = np.log1p(X[col])
            X = X.drop(columns=[c for c in skew_cols if c in X.columns], errors="ignore")

            X_scaled = scaler.transform(X)
            recon = ae_model.predict(X_scaled)
            mse = np.mean(np.square(recon - X_scaled), axis=1)

            agg_df["reconstruction_error"] = mse
            agg_df["prediction"] = np.where(
                mse > THRESHOLD, "Attack flow", "Normal flow"
            )
            logger.debug("Generated autoencoder predictions for %d destination IPs", len(agg_df))
        except Exception as e:
            logger.error("Failed to make predictions with Autoencoder model: %s", str(e))
            print(f"Failed to make predictions: {e}")
            agg_df["prediction"] = "Prediction failed"
    else:
        if agg_df.empty:
            print("No rows left after dropping negatives.")
        if "prediction" not in agg_df.columns:
            agg_df["prediction"] = "Model not loaded"
    return agg_df


def show_digest_outcomes(interface, dev_tgt, bfrt_info, table_name=table_name, raw_top_n=10):
    """Print raw digests, per-destination aggregates, and autoencoder verdicts.

    Works on the module-level ``digest_entries`` collected by the listener.
    Prints the ``raw_top_n`` most recent raw rows and the number of unique
    destinations. It then prints one aggregated row per destination with its
    prediction. ``interface``, ``dev_tgt``, ``bfrt_info`` and ``table_name`` are
    accepted for call-site compatibility but are not used.
    """
    print("=" * 50)
    print("=== DIGEST OUTCOMES ===")
    print(f"Listener running: {digest_listener_running}")
    print(f"Processing {len(digest_entries)} digest entries for aggregation...")
    if not digest_entries:
        print("No digest entries received.")
        print("=" * 50)
        print()
        return

    df = _digest_entries_to_frame(digest_entries)
    unique_dst = df['dst_addr'].nunique() if not df.empty else 0

    if df.empty:
        print("\n(No rows to show in raw view)")
        print(f"Number of unique destination IPs: {unique_dst}")
        print("=" * 50)
        print()
        return

    raw_view = (
        df.sort_values("ingress_timestamp", ascending=False)
          .loc[:, ["ingress_timestamp", "dst_addr", "src_addr", "queue_occupancy", "packet_length"]]
          .head(raw_top_n)
    )
    print(f"\nTop {min(raw_top_n, len(df))} NON-AGGREGATED digest rows (most recent first):")
    print(raw_view.to_string(index=False))
    # Previously this count was only printed when there were no rows (always 0).
    print(f"Number of unique destination IPs: {unique_dst}")

    agg_df = _aggregate_per_destination(df)

    print("\n[DEBUG] ae_model loaded:", ae_model is not None)
    print("[DEBUG] scaler loaded:", scaler is not None)
    print("[DEBUG] THRESHOLD:", THRESHOLD)
    print("[DEBUG] agg_df shape:", agg_df.shape)

    agg_df = _predict_with_autoencoder(agg_df)

    print("\nAggregated Data by Destination IP with Predictions:")
    if not agg_df.empty:
        print(agg_df.to_string(index=False))
    else:
        print("(empty)")
    print("\nEnd of Aggregated Data")
    print("=" * 50)
    print()

def add_digest_entries(interface, dev_tgt, bfrt_info, table_name=None):
    """Prompt for destination IPs and install one digest-enabling entry per IP.

    Each entry keys ``hdr.inner_ipv4.dst_addr`` and uses the module-level
    ``action_name``. Duplicate IPs in the input are collapsed. Typing ``clear``
    only prints a hint; deletion is a separate menu action.

    Args:
        table_name: table to modify. ``None`` (default) uses the module-level
            ``table_name`` at call time.
    """
    if table_name is None:
        table_name = globals()["table_name"]
    with menu_input_lock:
        show_digest_entries(interface, dev_tgt, bfrt_info, table_name)
        try:
            ip_input = input("Enter destination IPv4 addresses (space-separated, e.g., 192.168.200.30 192.168.200.40): ").strip()
            if not ip_input:
                print("No IP addresses provided.")
                return
            if ip_input.lower() == 'clear':
                # Do not hard-code a menu number: the menu has been renumbered before.
                print("Use the 'Delete digest entries' menu option to clear everything.")
                return

            try:
                table = bfrt_info.table_get(table_name)
            except Exception as e:
                print(f"Table {table_name} not found: {e}")
                table_name = input("Enter the correct table name from cNdAmlight.p4 (or press Enter to use 'SwitchIngress.control_digest'): ").strip() or table_name
                table = bfrt_info.table_get(table_name)

            # Order-preserving dedupe: a repeated key would make the batched entry_add fail.
            dst_ips = list(dict.fromkeys(ip_input.split()))
            keys, datas = [], []
            for dst_ip in dst_ips:
                dst_ip_int = ipv4_to_int(dst_ip)
                keys.append(table.make_key([gc.KeyTuple("hdr.inner_ipv4.dst_addr", dst_ip_int)]))
                datas.append(table.make_data([], action_name))
            table.entry_add(dev_tgt, keys, datas)
            print(f"\nAdded {len(datas)} digest entries successfully.")
            show_digest_entries(interface, dev_tgt, bfrt_info, table_name)
        except Exception as e:
            print(f"Failed to add digest entries: {e}")
            logger.error("Failed to add digest entries: %s", str(e))

def delete_digest_entries(interface, dev_tgt, bfrt_info, table_name=None):
    """Prompt for destination IPs (or ``clear``) and delete their digest entries.

    After deleting, residual digests are drained with flush_digest_queue().
    ``clear`` deletes every entry and also resets the in-memory
    ``digest_entries`` / ``seen_entries`` buffers. Registers are untouched.

    Args:
        table_name: table to modify. ``None`` (default) uses the module-level
            ``table_name`` at call time.
    """
    global digest_entries, seen_entries
    if table_name is None:
        table_name = globals()["table_name"]
    with menu_input_lock:
        show_digest_entries(interface, dev_tgt, bfrt_info, table_name)
        try:
            ip_input = input(
                "Enter destination IPv4 addresses to delete (space-separated), "
                "or type 'clear' to delete ALL (does NOT clear registers): "
            ).strip()

            try:
                table = bfrt_info.table_get(table_name)
            except Exception as e:
                print(f"Table {table_name} not found: {e}")
                table_name = input(
                    "Enter the correct table name from cNdAmlight.p4 (or press Enter to use "
                    "'SwitchIngress.control_digest'): "
                ).strip() or table_name
                table = bfrt_info.table_get(table_name)

            if ip_input.lower() == 'clear':
                table.entry_del(dev_tgt, [])
                print(f"\nCleared ALL entries from table: {table_name}")
                # flush_digest_queue returns (count, idle_reached); report only the count.
                drained, _idle = flush_digest_queue(interface)
                print(f"Drained {drained} residual digest batch(es) after delete.")
                digest_entries = []
                seen_entries = set()
                print("Reset in-memory buffers (digest_entries, seen_entries).")
                show_digest_entries(interface, dev_tgt, bfrt_info, table_name)
                return

            # Order-preserving dedupe: a repeated key would make the batched entry_del fail.
            dst_ips = list(dict.fromkeys(ip_input.split()))
            if not dst_ips:
                print("No IP addresses provided.")
                return

            keys = [table.make_key([gc.KeyTuple("hdr.inner_ipv4.dst_addr", ipv4_to_int(dst_ip))]) for dst_ip in dst_ips]
            table.entry_del(dev_tgt, keys)
            print(f"\nDeleted {len(keys)} digest entries successfully.")
            drained, _idle = flush_digest_queue(interface, idle_grace=0.3, per_call_timeout=0.05)
            if drained:
                print(f"Drained {drained} residual digest batch(es) after delete.")
            show_digest_entries(interface, dev_tgt, bfrt_info, table_name)

        except Exception as e:
            print(f"Failed to delete digest entries: {e}")
            logger.error("Failed to delete digest entries: %s", str(e))

# ========== Start-Listener Flow (delete -> drain -> settle -> add -> capture) ==========
def start_listener_for_ips_live(interface, dev_tgt, bfrt_info, dst_ip_list,
                                duration_seconds=CAPTURE_DEFAULT_SECONDS,
                                table_name=None):
    """
    Per-run boundary:
      1) Validate the IP list and build keys (no switch state touched yet)
      2) Delete ALL digest entries to stop new digests
      3) Drain residual queue (hard-capped), then short settle
      4) Add the provided IP list
      5) Timed listener (bounded waits, overload protection), then print outcomes

    Args:
        dst_ip_list: iterable of IPv4 strings. Blank items are ignored and
            duplicates (same address) are collapsed.
        duration_seconds: capture length passed to _run_live_listener().
        table_name: digest table. ``None`` (default) uses the module-level
            ``table_name`` at call time.
    """
    if table_name is None:
        table_name = globals()["table_name"]

    if digest_listener_running:
        print("Listener already running. Stop it first.")
        return

    # Resolve table
    try:
        table = bfrt_info.table_get(table_name)
    except Exception as e:
        print(f"Table {table_name} not found: {e}")
        return

    # 1) Validate IPs and build keys BEFORE deleting anything, so a typo does not
    #    leave the table emptied and drained with nothing re-added.
    keys, datas, added = [], [], []
    seen_ints = set()
    try:
        for ip_str in dst_ip_list:
            ip_str = ip_str.strip()
            if not ip_str:
                continue
            ip_int = ipv4_to_int(ip_str)
            if ip_int in seen_ints:
                continue  # duplicate key would make the batched entry_add fail
            seen_ints.add(ip_int)
            keys.append(table.make_key([gc.KeyTuple("hdr.inner_ipv4.dst_addr", ip_int)]))
            datas.append(table.make_data([], action_name))
            added.append(ip_str)
    except Exception as e:
        print(f"Invalid IP in list or action error: {e}")
        return

    if not keys:
        print("No valid IPs provided.")
        return

    # 2) Delete all digest entries
    try:
        table.entry_del(dev_tgt, [])
        print(f"Cleared all entries in {table_name}")
    except Exception as e:
        print(f"Failed to clear table {table_name}: {e}")
        return

    # 3) Drain residual digests (hard-capped) and short settle
    drained, idle = flush_digest_queue(
        interface,
        per_call_timeout=0.05,
        idle_grace=0.25,
        idle_streak_required=8,  # 8 * 0.05s ~= 0.40s of continuous idle
        max_seconds=6.0,         # bump this if you still see spillover
        max_batches=None         # let time, not a count cap, decide
    )
    print(f"Drained {drained} residual digest batch(es) after delete. Idle reached: {idle}")
    time.sleep(0.15)  # brief settle

    # 4) Add the validated IPs
    try:
        table.entry_add(dev_tgt, keys, datas)
        print(f"Added {len(keys)} digest entries: {', '.join(added)}")
    except Exception as e:
        print(f"Failed adding entries: {e}")
        return

    # 5) Confirm and run timed listener
    show_digest_entries(interface, dev_tgt, bfrt_info, table_name)
    _run_live_listener(interface, duration_seconds)
    show_digest_outcomes(interface, dev_tgt, bfrt_info, table_name=table_name)

def _run_live_listener(interface, duration_seconds):
    """Bounded, safe listener for live traffic.

    Resets the module-level ``digest_entries`` / ``seen_entries`` buffers and
    sets ``digest_listener_running`` while active. It polls
    ``interface.digest_get`` (0.8 s bounded wait per poll) until one of:
      * ``duration_seconds`` have elapsed,
      * the user enters ``m`` on stdin (watched by a daemon thread),
      * no digest has arrived for ``LISTENER_TIMEOUT`` seconds.
    Digests are decoded with the ``digest_a`` learn object. Entries whose field
    dict was already seen are skipped. Buffers are bounded: the dedup set is
    cleared past ``MAX_SEEN``, and the oldest ~10% of entries are dropped past
    ``MAX_ENTRIES``.
    """
    global digest_entries, digest_listener_running, seen_entries

    digest_entries = []
    seen_entries = set()
    digest_listener_running = True

    try:
        bfrt_info = interface.bfrt_info_get("cNdAmlight")
        digest_a = bfrt_info.learn_get("pipe.SwitchIngressDeparser.digest_a")
    except Exception as e:
        print(f"Could not init digest object: {e}")
        digest_listener_running = False
        return

    stop_evt = threading.Event()
    print(f"Listening for {duration_seconds} seconds. Press 'm' to stop early.")

    def _stop_on_m():
        while not stop_evt.is_set():
            key = silent_input(timeout=0.5)
            if key and key.lower() == "m":
                stop_evt.set()
                break

    t = threading.Thread(target=_stop_on_m, daemon=True)
    t.start()

    # Monotonic clock: durations must not jump if the wall clock is adjusted.
    start_t = time.monotonic()
    end_t = start_t + duration_seconds
    next_report = start_t + PROGRESS_STEP_SECONDS
    last_msg_t = start_t

    try:
        while not stop_evt.is_set() and time.monotonic() < end_t:
            try:
                digest = interface.digest_get(timeout=0.8)  # bounded wait
            except Exception:
                digest = None  # treated as "nothing arrived in this poll"

            data_list = None
            if digest is not None:
                try:
                    data_list = digest_a.make_data_list(digest)
                except Exception as e:
                    logger.debug("Failed to decode digest: %s", e)

            # Sampled after the (up to 0.8 s) blocking receive so arrival and
            # idle-time bookkeeping are not stale.
            now = time.monotonic()

            if data_list:
                last_msg_t = now
                for data in data_list:
                    entry = data.to_dict()
                    entry_hash = hashlib.sha256(str(sorted(entry.items())).encode()).hexdigest()
                    if entry_hash in seen_entries:
                        continue
                    seen_entries.add(entry_hash)
                    if len(seen_entries) > MAX_SEEN:
                        seen_entries.clear()

                    digest_entries.append(entry)
                    if len(digest_entries) > MAX_ENTRIES:
                        drop = max(1, len(digest_entries) // 10)
                        digest_entries = digest_entries[drop:]
                        print(f"High load: dropped {drop} oldest entries to stay responsive.")

            if now >= next_report:
                print(f"[{int(now - start_t)}s] entries: {len(digest_entries)}")
                next_report += PROGRESS_STEP_SECONDS

            if now - last_msg_t > LISTENER_TIMEOUT:
                print(f"No new digest for {LISTENER_TIMEOUT} seconds. Stopping listener.")
                stop_evt.set()

    finally:
        stop_evt.set()
        try:
            t.join(timeout=1.0)
        except Exception:
            pass
        drain_stdin()
        digest_listener_running = False
        print(f"[done] Collected {len(digest_entries)} entries.")

# ========== Register Functions (unchanged) ==========
def get_non_zero_indices(register, dev_tgt, bfrt_info):
    """Return indices of ``SwitchIngress.<register>`` whose f1 values include one > 1.

    Reads every entry from hardware. Despite the function name, the test is
    strictly ``> 1``, so indices holding only 0s and 1s are skipped
    (NOTE(review): confirm whether ``> 0`` was intended). On any error the
    problem is printed and logged and ``[]`` is returned.
    """
    try:
        table = bfrt_info.table_get(f"SwitchIngress.{register}")
        field_name = f"SwitchIngress.{register}.f1"
        hits = []
        for data, key in table.entry_get(dev_tgt, [], {"from_hw": True}):
            key_fields = key.to_dict()
            per_entry = data.to_dict().get(field_name, [])
            if any(value > 1 for value in per_entry):
                hits.append(key_fields['$REGISTER_INDEX']['value'])
        return hits
    except Exception as e:
        print(f"Error accessing register {register}: {e}")
        logger.error("Error accessing register %s: %s", register, str(e))
        return []

def get_register_values_batch(register, indices, dev_tgt, bfrt_info):
    """Read ``SwitchIngress.<register>`` at ``indices`` with one batched entry_get.

    Returns:
        ``{index: value}``, where value is element ``[1]`` of the entry's f1
        list (``[0, 0]`` is assumed when the field is missing). NOTE(review):
        the list presumably holds one value per pipe, so ``[1]`` would be pipe
        1's copy; confirm that is the intended pipe. Returns ``{}`` for empty
        ``indices``. On any error, prints/logs and maps every requested index to 0.
    """
    if not indices:
        # An empty key list makes entry_get read every entry of the register
        # (the same call shape get_non_zero_indices uses for a full scan).
        return {}
    try:
        reg_name = f"SwitchIngress.{register}"
        reg_key = f"SwitchIngress.{register}.f1"
        table = bfrt_info.table_get(reg_name)
        keys = [table.make_key([gc.KeyTuple('$REGISTER_INDEX', index)]) for index in indices]
        values = {}
        for data, key in table.entry_get(dev_tgt, keys, {"from_hw": True}):
            index = key.to_dict()['$REGISTER_INDEX']['value']
            values[index] = data.to_dict().get(reg_key, [0, 0])[1]
        return values
    except Exception as e:
        print(f"Error reading register {register} for batched indices: {e}")
        logger.error("Error reading register %s for batched indices: %s", register, str(e))
        return {index: 0 for index in indices}

def clear_registers(interface, dev_tgt, bfrt_info):
    """Delete (reset) every entry of REGISTER_COUNTER and OTHER_REGISTERS.

    Each register is attempted independently; failures are printed and logged.
    The final summary says whether all registers were cleared or lists the ones
    that failed (previously "All registers cleared." was printed regardless).
    """
    failed = []
    for register in [REGISTER_COUNTER, *OTHER_REGISTERS]:
        try:
            table = bfrt_info.table_get(f"SwitchIngress.{register}")
            table.entry_del(dev_tgt, [])
            log(f"Cleared register: {register}")
            logger.debug("Cleared register: %s", register)
        except Exception as e:
            failed.append(register)
            print(f"Failed to clear register {register}: {e}")
            logger.error("Failed to clear register %s: %s", register, str(e))
    if failed:
        print(f"Registers NOT cleared: {', '.join(failed)}")
    else:
        print("All registers cleared.")

def inspect_registers(interface, dev_tgt, bfrt_info, top_n=10):
    """Summarize register contents per destination IP and print the top ``top_n``.

    Indices come from get_non_zero_indices(REGISTER_COUNTER, ...). For each
    index, the r_dstAddr and r_counter values are read in batches of
    BATCH_SIZE. Counters are summed per destination and indices per
    destination are counted as 'UniqueSrcIPs' (NOTE(review): that is a
    unique-source count only if each index maps to one source; confirm
    against the P4 program). Rows are sorted by TotalCounter descending, then
    numeric Dst IP ascending. Per-stage timings are printed.

    Args:
        top_n: number of destinations to show (default 10, as before).
    """
    with menu_input_lock:
        print("Scanning non-zero entries in r_counter...\n")
        start_time = time.time()
        indices = get_non_zero_indices(REGISTER_COUNTER, dev_tgt, bfrt_info)
        print(f"Scanning r_counter took {time.time() - start_time:.3f} seconds")
        if not indices:
            print("No non-zero entries found.")
            return

        print(f"Processing {len(indices)} non-zero entries...")
        start_time = time.time()
        agg_data = {}
        for i in range(0, len(indices), BATCH_SIZE):
            batch_indices = indices[i:i + BATCH_SIZE]
            dst_values = get_register_values_batch('r_dstAddr', batch_indices, dev_tgt, bfrt_info)
            counter_values = get_register_values_batch(REGISTER_COUNTER, batch_indices, dev_tgt, bfrt_info)
            for index in batch_indices:
                stats = agg_data.setdefault(dst_values.get(index, 0), {'TotalCounter': 0, 'UniqueSrcIPs': 0})
                stats['TotalCounter'] += counter_values.get(index, 0)
                stats['UniqueSrcIPs'] += 1
        print(f"Fetching and aggregating register values took {time.time() - start_time:.3f} seconds")

        start_time = time.time()
        df = pd.DataFrame([
            {'Dst IP': dst, 'TotalCounter': data['TotalCounter'], 'UniqueSrcIPs': data['UniqueSrcIPs']}
            for dst, data in agg_data.items()
        ])
        # Sort on the numeric address, then convert. .copy() makes the slice an
        # independent frame so the column assignment below is not a
        # chained-assignment on a view.
        summary_df = (
            df.sort_values(by=['TotalCounter', 'Dst IP'], ascending=[False, True])
              .head(top_n)
              .copy()
        )
        summary_df['Dst IP'] = summary_df['Dst IP'].apply(convert_ip)
        print(f"Pandas processing and IP conversion took {time.time() - start_time:.3f} seconds")
        print(f"\nSummary: Top {top_n} Dst IPs by Total Counter and Unique Src IPs\n")
        print(summary_df.to_string(index=False))

# ========== Menu ==========
def menu():
    """Interactive main loop: connect via gc_connect(), then dispatch menu choices.

    The gRPC stream is torn down on every exit path: option 0, Ctrl-C, or EOF on
    stdin. Previously teardown happened only when the user picked option 0.
    """
    interface, dev_tgt, bfrt_info = gc_connect()
    print("\nWelcome to the P4 Digest and Register Menu!")

    try:
        while True:
            with menu_input_lock:
                print("\n=== MAIN MENU ===")
                print("1. Start listener (delete -> drain -> add IP list -> capture)")
                print("2. Delete digest entries")
                print("3. Show digest table entries")
                print("4. Show digest outcomes")
                print("5. Inspect registers")
                print("6. Clear registers")
                print("0. Exit")
                choice = input("Choose an option: ").strip()

            if choice == '1':
                with menu_input_lock:
                    ip_str = input("Enter destination IPv4 addresses (space-separated): ").strip()
                    dur_in = input(f"Capture duration in seconds (default {CAPTURE_DEFAULT_SECONDS}): ").strip()
                # str.isdigit() accepts characters int() rejects (e.g. superscripts),
                # so parse directly. Non-numbers and non-positive values fall back
                # to the default.
                try:
                    duration = int(dur_in)
                except ValueError:
                    duration = CAPTURE_DEFAULT_SECONDS
                if duration <= 0:
                    duration = CAPTURE_DEFAULT_SECONDS
                if not ip_str:
                    print("No IPs provided.")
                else:
                    start_listener_for_ips_live(interface, dev_tgt, bfrt_info, ip_str.split(), duration_seconds=duration)
            elif choice == '2':
                delete_digest_entries(interface, dev_tgt, bfrt_info)
            elif choice == '3':
                show_digest_entries(interface, dev_tgt, bfrt_info)
            elif choice == '4':
                show_digest_outcomes(interface, dev_tgt, bfrt_info)
            elif choice == '5':
                inspect_registers(interface, dev_tgt, bfrt_info)
            elif choice == '6':
                clear_registers(interface, dev_tgt, bfrt_info)
            elif choice == '0':
                print("Exiting program.")
                break
            else:
                print("Invalid option. Try again.")
    except (KeyboardInterrupt, EOFError):
        print("\nExiting program.")
    finally:
        try:
            interface.tear_down_stream()
        except Exception:
            pass  # best-effort cleanup on shutdown

# ========== Entry Point ==========
if __name__ == '__main__':
    # Clear the terminal ("cls" on Windows, "clear" elsewhere) before the menu.
    os.system('clear' if os.name != 'nt' else 'cls')
    menu()
