#!/usr/bin/env python3
# Jupyter/CLI-friendly "Menu 1" flow:
# delete -> drain -> add IP list -> capture -> aggregate + AE predict
# - No interactive input
# - Preserves digest_messages.jsonl as JSON Lines (one dict per line)
# - Prints progress/results; does not return anything
# - IPs & duration provided via CLI args (argparse)

import sys, os, time, hashlib, ipaddress, json, logging, argparse
sys.path.append('/nix/store/qfv2ayja1zgrfbiy1nrd9f0b1y759h60-python3-3.11.6-env/lib/python3.11/site-packages')

import numpy as np
import pandas as pd

try:
    import bfrt_grpc.client as gc
except ImportError as e:
    raise RuntimeError(f"Failed to import bfrt_grpc.client: {e}")

import joblib
import tensorflow as tf

# ================== Logging ==================
# Console handler: only INFO and above reach stdout, even though the logger
# itself accepts DEBUG records.
nb_handler = logging.StreamHandler(sys.stdout)
nb_handler.setLevel(logging.INFO)

logger = logging.getLogger('DigestMenu1')
logger.setLevel(logging.DEBUG)
# Detach whatever handlers the named logger already carries so that running
# this code again does not stack duplicates and repeat every message.
for h in logger.handlers[:]:
    logger.removeHandler(h)
logger.addHandler(nb_handler)

# ================== AE assets ==================
# Default locations of the three model artefacts consumed by load_ae_assets().
# Autoencoder model file (.h5), loaded with tf.keras.models.load_model.
AUTOENCODER_PATH = '/home/fabric/DDoS_Detection_wP4/model/autoencoder_ddos.h5'
# Feature scaler (.pkl), loaded with joblib.load.
SCALER_PATH      = '/home/fabric/DDoS_Detection_wP4/model/scaler_ddos.pkl'
# Text file parsed as a single float: the reconstruction-error cutoff that
# separates "Attack flow" from "Normal flow" in _aggregate_and_predict().
THRESH_PATH      = '/home/fabric/DDoS_Detection_wP4/model/threshold.txt'

def _must_readable(path, label):
    """Verify that *path* exists and can be read by this process.

    Args:
        path: filesystem path to check.
        label: human-readable name of the asset, used in error messages.

    Raises:
        FileNotFoundError: nothing exists at *path* (the message includes the
            current working directory to help with relative paths).
        PermissionError: *path* exists but is not readable.
    """
    missing = not os.path.exists(path)
    if missing:
        raise FileNotFoundError(f"{label} not found: {path} (cwd={os.getcwd()})")

    readable = os.access(path, os.R_OK)
    if readable:
        return
    raise PermissionError(f"{label} not readable: {path}")

def load_ae_assets(model_path=AUTOENCODER_PATH, scaler_path=SCALER_PATH, thresh_path=THRESH_PATH):
    """Load the autoencoder, its feature scaler and the decision threshold.

    All three paths are validated before anything is loaded, so a missing or
    unreadable file is reported first.

    Args:
        model_path: path of the Keras model file.
        scaler_path: path of the joblib-pickled scaler.
        thresh_path: path of a text file containing one float.

    Returns:
        Tuple ``(model, scaler, threshold)``.

    Raises:
        FileNotFoundError / PermissionError: from the readability checks.
        ValueError: if the threshold file does not parse as a float.
    """
    required_assets = (
        (model_path, "Autoencoder (.h5)"),
        (scaler_path, "Scaler (.pkl)"),
        (thresh_path, "Threshold (.txt)"),
    )
    for asset_path, asset_label in required_assets:
        _must_readable(asset_path, asset_label)

    with open(thresh_path) as fh:
        raw_threshold = fh.read()
    threshold = float(raw_threshold.strip())

    # NOTE: joblib.load unpickles its input; only point it at trusted files.
    scaler = joblib.load(scaler_path)
    # Loaded without compiling; this file only ever calls predict() on it.
    model = tf.keras.models.load_model(model_path, compile=False)
    return model, scaler, threshold

# Load the model artefacts once, at import time. Everything further down
# relies on ae_model / scaler / THRESHOLD, so a failure here aborts the import
# with one descriptive error.
try:
    ae_model, scaler, THRESHOLD = load_ae_assets()
    logger.info("Loaded AE assets (TF=%s). Threshold=%.6f", tf.__version__, THRESHOLD)
except Exception as e:
    # Chain explicitly so the traceback reports the original failure as the
    # direct cause rather than as an incidental error during handling.
    raise RuntimeError(f"FATAL: failed to load autoencoder assets: {e}") from e

# ================== Defaults (override via CLI) ==================
# Host of the bfrt gRPC server; gc_connect() appends port 50052.
HOST_IP = "localhost"
# Table and action used when --table / --action are not supplied.
DEFAULT_TABLE  = "SwitchIngress.control_digest"
DEFAULT_ACTION = "SwitchIngress.enable_digest"
# Name handed to bfrt_info.learn_get() to decode received digests.
LEARN_OBJ      = "pipe.SwitchIngressDeparser.digest_a"
# Directory that receives digest_messages.jsonl unless --outdir is supplied.
DEFAULT_OUTDIR = "/home/fabric/DDoS_Detection_wP4/menu"

# Once digest_entries grows past this, the listener drops its oldest tenth.
MAX_ENTRIES = 200_000
# Once the de-duplication set grows past this, the listener clears it.
MAX_SEEN    = 500_000
LISTENER_TIMEOUT = 10          # stop if no new digests for this many seconds
PROGRESS_STEP_SECONDS = 5      # seconds between progress lines during capture

# ================== Globals ==================
# Raw digest dicts from the most recent listener run; reset at the start of
# _run_live_listener() and read by _aggregate_and_predict().
digest_entries = []
# SHA-256 hex digests of entries already recorded (de-duplication).
seen_entries = set()

# ================== Utilities ==================
def ipv4_to_int(ipv4_str):
    """Return the integer value of an IPv4 address.

    Args:
        ipv4_str: address in any form ``ipaddress.IPv4Address`` accepts,
            normally a dotted-quad string such as ``"192.168.0.1"``.

    Returns:
        The address as an ``int``.

    Raises:
        ValueError: if the address is invalid. The underlying ``ipaddress``
            error is attached as ``__cause__`` so its detail is not lost.
    """
    try:
        return int(ipaddress.IPv4Address(ipv4_str))
    except ValueError as exc:
        raise ValueError(f"Invalid IPv4 address: {ipv4_str}") from exc

def convert_ip(ip_int):
    """Format the low 32 bits of *ip_int* as a dotted-quad string.

    Octets are emitted most-significant first; bits above the 32nd are
    ignored by the per-octet mask.
    """
    octets = []
    for shift in (24, 16, 8, 0):
        octets.append(str((ip_int >> shift) & 0xFF))
    return '.'.join(octets)

def safe_int(value):
    """Coerce a digest field to ``int`` wherever that is unambiguous.

    Conversion rules:
      * ``None``                      -> ``0``
      * ``bytes`` / ``bytearray``     -> big-endian unsigned integer
      * ``int`` (including ``bool``), numeric strings and other integer-like
        objects exposing ``__index__`` -> ``int(value)``
      * anything else (floats, non-numeric strings, containers, ...) is
        returned unchanged, so no information is silently truncated.

    Previously only real ``int`` instances and ``bytes`` were converted, so a
    numeric string slipped through as ``str`` and would break the numeric
    aggregation performed later on these fields.
    """
    if value is None:
        return 0
    if isinstance(value, (bytes, bytearray)):
        return int.from_bytes(value, 'big')
    if isinstance(value, str) or hasattr(value, '__index__'):
        try:
            return int(value)
        except (TypeError, ValueError):
            # Not an integer literal after all: hand it back untouched.
            return value
    return value

# ================== gRPC ==================
def gc_connect(table_name: str):
    """Open a bfrt gRPC session and fetch the P4 program info.

    Client ids 0-9 are tried in turn until one connects. The pipeline config
    "cNdAmlight" is bound only when the session obtained client id 0. The
    presence of *table_name* is checked, but a missing table only produces a
    warning.

    Args:
        table_name: table whose existence is probed after connecting.

    Returns:
        Tuple ``(interface, dev_tgt, bfrt_info)``.

    Raises:
        RuntimeError: if no client id could connect, or if fetching/binding
            the program info fails. In the latter case the already-open
            stream is torn down first so the client id is not left occupied.
    """
    interface = None
    last_error = None
    for bfrt_client_id in range(10):
        try:
            interface = gc.ClientInterface(
                grpc_addr=f"{HOST_IP}:50052",
                client_id=bfrt_client_id,
                device_id=0,
                num_tries=1,
            )
            break
        except Exception as e:
            # Remember the failure and move on to the next client id.
            last_error = e
    if interface is None:
        raise RuntimeError(
            f"Failed to connect to gRPC server on any client_id: {last_error}"
        ) from last_error

    # NOTE(review): pipe_id=0xffff presumably addresses all pipes; confirm.
    dev_tgt = gc.Target(device_id=0, pipe_id=0xffff)
    try:
        bfrt_info = interface.bfrt_info_get()
        if bfrt_client_id == 0:
            interface.bind_pipeline_config("cNdAmlight")
        logger.info("Connected. P4 program: %s", bfrt_info.p4_name_get())
        try:
            bfrt_info.table_get(table_name)
        except Exception as e:
            logger.warning("Configured table '%s' not found: %s", table_name, e)
        return interface, dev_tgt, bfrt_info
    except Exception as e:
        # The connection succeeded but setup did not: release the stream so a
        # later attempt (e.g. re-running a notebook cell) can reuse the id.
        try:
            interface.tear_down_stream()
        except Exception as teardown_err:
            logger.warning("Stream teardown after failed init also failed: %s", teardown_err)
        raise RuntimeError(f"Failed to initialize P4 program: {e}") from e

# ================== Digest helpers ==================
def show_digest_entries(interface, dev_tgt, bfrt_info, table_name: str):
    """Print and return the destination IPs currently installed in a table.

    Args:
        interface: unused here; kept for signature compatibility.
        dev_tgt: device target passed to ``entry_get``.
        bfrt_info: object providing ``table_get``.
        table_name: name of the table to dump.

    Returns:
        List of ``(dst_ip_str, data_dict)`` tuples, one per entry. If anything
        fails while reading the table, an error line is printed and an empty
        list is returned.
    """
    banner = "=" * 50
    found = []
    try:
        tbl = bfrt_info.table_get(table_name)
        print(banner)
        print(f"=== ENTRIES IN TABLE: {table_name} ===")
        # from_hw=True is forwarded as the read flag for entry_get.
        for data, key in tbl.entry_get(dev_tgt, [], {"from_hw": True}):
            raw_dst = key.to_dict()["hdr.inner_ipv4.dst_addr"]["value"]
            dotted = str(ipaddress.IPv4Address(raw_dst))
            found.append((dotted, data.to_dict()))
            print(f"\tEntry key (dst IP): {dotted}")
        if not found:
            print(f"No entries found in table: {table_name}")
        print(banner)
    except Exception as e:
        print(f"Error accessing table {table_name}: {e}")
        return []
    return found

def flush_digest_queue(
    interface,
    per_call_timeout=0.05,     # sec per poll
    idle_grace=0.25,           # continuous idle to declare 'drained'
    idle_streak_required=8,    # ~ idle_grace / per_call_timeout
    max_seconds=6.0,
):
    """
    Drain pending digest messages AFTER deleting table entries.

    Polls ``interface.digest_get`` until the queue looks idle or the time
    budget is spent. Any exception from ``digest_get`` is treated as
    "nothing received" for that poll.

    The queue is declared drained only when BOTH hold:
      * at least ``idle_streak_required`` consecutive polls were empty, and
      * at least ``idle_grace`` seconds have passed since the last received
        digest (or since the start, if none was received).
    ``idle_grace`` used to be accepted but ignored, so polls that failed
    instantly could report "idle" after no real waiting. With the default
    arguments 8 polls of 0.05 s already exceed 0.25 s, so behaviour is
    unchanged whenever ``digest_get`` waits out its timeout.

    Args:
        interface: object exposing ``digest_get(timeout=...)``.
        per_call_timeout: timeout in seconds handed to each poll.
        idle_grace: minimum continuous idle time in seconds.
        idle_streak_required: minimum number of consecutive empty polls.
        max_seconds: overall time budget in seconds.

    Returns (drained_batches, idle_reached: bool)
    """
    drained = 0
    idle_streak = 0
    # Monotonic clock: immune to wall-clock adjustments while measuring.
    started = time.monotonic()
    deadline = started + max_seconds
    last_activity = started
    while time.monotonic() < deadline:
        try:
            interface.digest_get(timeout=per_call_timeout)
        except Exception:
            idle_streak += 1
            idle_for = time.monotonic() - last_activity
            if idle_streak >= idle_streak_required and idle_for >= idle_grace:
                return drained, True
        else:
            drained += 1
            idle_streak = 0
            last_activity = time.monotonic()
    return drained, False

def _run_live_listener(interface, duration_seconds, jsonl_file, learn_obj=LEARN_OBJ):
    """
    Capture digests for a bounded time and persist them as JSON Lines.

    - polls for up to duration_seconds (each poll waits at most 0.8 s)
    - stops early if no new messages arrive for LISTENER_TIMEOUT seconds
    - truncates jsonl_file, then appends each new raw digest dict to it
      (one JSON per line)
    - skips entries whose content hash was already seen
    - mirrors console progress prints every PROGRESS_STEP_SECONDS

    Captured entries are left in the module-level ``digest_entries`` list
    (reset at the start of every call). Returns nothing. If the digest object
    cannot be initialised, a message is printed and the function returns
    without touching jsonl_file.
    """
    global digest_entries, seen_entries
    digest_entries = []
    seen_entries = set()

    try:
        bfrt_info = interface.bfrt_info_get("cNdAmlight")
        digest_a = bfrt_info.learn_get(learn_obj)
    except Exception as e:
        print(f"Could not init digest object: {e}")
        return

    # ensure file exists/emptied
    out_dir = os.path.dirname(jsonl_file)
    if out_dir:
        # A bare filename has no directory part, and os.makedirs('') raises.
        os.makedirs(out_dir, exist_ok=True)
    with open(jsonl_file, 'w'):
        pass

    # Monotonic clock: durations must not be affected by wall-clock changes.
    start_t = time.monotonic()
    end_t   = start_t + duration_seconds
    next_report = start_t + PROGRESS_STEP_SECONDS
    last_msg_t  = start_t

    print(f"Listening for {duration_seconds} seconds…")
    while time.monotonic() < end_t:
        now = time.monotonic()
        try:
            digest = interface.digest_get(timeout=0.8)  # bounded wait
            data_list = digest_a.make_data_list(digest)
        except Exception:
            # Any failure to fetch or decode counts as "nothing this round".
            data_list = None

        if data_list:
            last_msg_t = now
            with open(jsonl_file, 'a') as f:
                for data in data_list:
                    entry = data.to_dict()  # raw dict, preserved
                    entry_hash = hashlib.sha256(str(sorted(entry.items())).encode()).hexdigest()
                    if entry_hash in seen_entries:
                        continue
                    # Serialise before recording anything, so an entry that
                    # cannot be encoded never leaves a partial line on disk
                    # or an in-memory record without a matching line.
                    line = json.dumps(entry)
                    seen_entries.add(entry_hash)
                    if len(seen_entries) > MAX_SEEN:
                        # Bound memory; de-duplication restarts from scratch.
                        seen_entries.clear()
                    digest_entries.append(entry)
                    if len(digest_entries) > MAX_ENTRIES:
                        # Drop the oldest ~10% to stay under the cap.
                        drop = max(1, len(digest_entries) // 10)
                        digest_entries = digest_entries[drop:]
                    f.write(line + '\n')  # JSON Lines

        if now >= next_report:
            print(f"[{int(now - start_t)}s] entries: {len(digest_entries)}")
            next_report += PROGRESS_STEP_SECONDS

        if now - last_msg_t > LISTENER_TIMEOUT:
            print(f"No new digest for {LISTENER_TIMEOUT} seconds. Stopping.")
            break

    print(f"[done] Collected {len(digest_entries)} entries. Saved to {jsonl_file}")

def _aggregate_and_predict():
    """
    Build per-dst aggregates, run AE predictions, and print results.
    Uses global digest_entries, ae_model, scaler, THRESHOLD.
    Prints only; returns nothing.

    Steps:
      1. Normalise each raw digest dict (IPv4 addresses to dotted strings,
         numeric fields through safe_int); rows whose addresses cannot be
         parsed are skipped.
      2. Print the five rows with the largest ingress_timestamp.
      3. Compute per-destination inter-arrival times and aggregate
         mean/std/count/sum/unique-source statistics (rounded to 2 decimals).
      4. Build the model input (log1p on the skewed columns, which are then
         replaced by their "*_log" versions), scale it, and label each
         destination by comparing the autoencoder's mean squared
         reconstruction error against THRESHOLD.
    """
    if not digest_entries:
        print("No digest entries received.")
        return

    processed_entries = []
    for parsed in digest_entries:
        dst_addr = parsed.get("dst_addr")
        src_addr = parsed.get("src_addr")
        try:
            dst_addr_str = str(ipaddress.IPv4Address(dst_addr)) if dst_addr is not None else None
            src_addr_str = str(ipaddress.IPv4Address(src_addr)) if src_addr is not None else None
        except Exception:
            # Unparseable address: drop the row rather than abort the report.
            continue
        queue_occupancy  = safe_int(parsed.get("queue_occupancy", 0))
        packet_length    = safe_int(parsed.get("packet_length", 0))
        ingress_timestamp= safe_int(parsed.get("ingress_timestamp", 0))
        processed_entries.append({
            "dst_addr": dst_addr_str,
            "src_addr": src_addr_str,
            "queue_occupancy": queue_occupancy,
            "packet_length": packet_length,
            "ingress_timestamp": ingress_timestamp
        })

    df = pd.DataFrame(processed_entries)
    if df.empty:
        print("(No valid rows after parsing)")
        return

    # Raw view (recent first) — FAST
    raw_top_n = 5
    cols = ["ingress_timestamp", "dst_addr", "src_addr", "queue_occupancy", "packet_length"]
    try:
        raw_view = df.nlargest(raw_top_n, "ingress_timestamp")[cols]
    except (TypeError, ValueError):
        # pandas raises TypeError (not ValueError) when the column is not a
        # numeric dtype, so both are caught. The coerced column is written
        # back to df so the sorting/diff below also sees numeric timestamps.
        df["ingress_timestamp"] = pd.to_numeric(df["ingress_timestamp"], errors="coerce").fillna(0)
        raw_view = df.nlargest(raw_top_n, "ingress_timestamp")[cols]

    print(f"\nTop {min(raw_top_n, len(df))} NON-AGGREGATED digest rows (most recent first):", flush=True)
    for r in raw_view.itertuples(index=False):
        print(r, flush=True)


    # Inter-arrival per destination
    df = df.sort_values(['dst_addr', 'ingress_timestamp'], ascending=[True, True]).copy()
    df['inter_arrival_time'] = df.groupby('dst_addr')['ingress_timestamp'].diff()
    # Defensive: discard rows with a negative gap, then treat the first
    # packet of each destination (NaN gap) as gap 0.
    neg_mask = df['inter_arrival_time'] < 0
    df = df.loc[~neg_mask].copy()
    df['inter_arrival_time'] = df['inter_arrival_time'].fillna(0)

    # Aggregate
    agg_df = df.groupby('dst_addr').agg(
        avg_queue_occupancy=('queue_occupancy', 'mean'),
        std_queue_occupancy=('queue_occupancy', 'std'),
        avg_packet_length=('packet_length', 'mean'),
        std_packet_length=('packet_length', 'std'),
        avg_inter_arrival=('inter_arrival_time', 'mean'),
        std_inter_arrival=('inter_arrival_time', 'std'),
        total_packets=('dst_addr', 'size'),
        total_packet_length=('packet_length', 'sum'),
        unique_src_addrs=('src_addr', 'nunique')
    ).reset_index()
    # std is NaN for single-packet groups; report it as 0 instead.
    for col in ['std_queue_occupancy', 'std_packet_length', 'std_inter_arrival']:
        agg_df[col] = agg_df[col].fillna(0)
    agg_df = agg_df.round(2)

    # === Apply same preprocessing as training ===
    num_features = [
        "avg_queue_occupancy", "std_queue_occupancy",
        "avg_packet_length", "std_packet_length",
        "avg_inter_arrival", "std_inter_arrival",
        "total_packets", "total_packet_length", "unique_src_addrs"
    ]
    if not agg_df.empty:
        X = agg_df[num_features].fillna(0).copy()
        # log1p transform for skewed columns used in training
        skew_cols = ["avg_inter_arrival", "std_inter_arrival", "total_packets", "total_packet_length"]
        for col in skew_cols:
            if col in X.columns:
                X[col + "_log"] = np.log1p(X[col])
        # The raw skewed columns are replaced by their *_log counterparts,
        # which therefore end up as the last columns of X.
        X = X.drop(columns=[c for c in skew_cols if c in X.columns], errors="ignore")

        # Scale and AE reconstruction error
        X_scaled = scaler.transform(X)
        recon = ae_model.predict(X_scaled, verbose=0)
        mse = np.mean(np.square(recon - X_scaled), axis=1)

        agg_df["reconstruction_error"] = mse
        agg_df["prediction"] = np.where(mse > THRESHOLD, "Attack flow", "Normal flow")

    print("\nAggregated Data by Destination IP with Predictions:")
    if not agg_df.empty:
        print(agg_df.to_string(index=False))
    else:
        print("(empty)")
    print("\nEnd of Aggregated Data\n" + "=" * 50)

# ================== Main "Menu 1" Runner (prints only) ==================
def run_menu1(ip_addresses, duration_seconds=30,
              target_dir=DEFAULT_OUTDIR, table_name=DEFAULT_TABLE, action_name=DEFAULT_ACTION):
    """
    Non-interactive Menu 1:
      1) Connect
      2) Delete all digest entries
      3) Drain residual digests (bounded)
      4) Add provided IP list
      5) Timed listener and save JSONL
      6) Aggregate + predict outcomes and print results

    Args:
        ip_addresses: iterable of IPv4 strings to install as table keys;
            invalid ones are reported and skipped.
        duration_seconds: maximum capture time for the listener.
        target_dir: directory that receives digest_messages.jsonl.
        table_name: table to clear and repopulate.
        action_name: action attached to every added entry.

    Prints progress and results; returns nothing. Once connected, the gRPC
    stream is torn down on every exit path, including when a later step
    raises, so the client id is never left occupied.
    """
    os.makedirs(target_dir, exist_ok=True)
    jsonl_file = os.path.join(target_dir, 'digest_messages.jsonl')

    try:
        interface, dev_tgt, bfrt_info = gc_connect(table_name)
    except Exception as e:
        print(e)
        return

    try:
        # Resolve table
        try:
            table = bfrt_info.table_get(table_name)
        except Exception as e:
            print(f"Table {table_name} not found: {e}")
            return

        # 1) Delete all entries
        try:
            table.entry_del(dev_tgt, [])
            print(f"Cleared all entries in {table_name}")
        except Exception as e:
            print(f"Failed to clear table {table_name}: {e}")
            return

        # 2) Drain residual digests
        drained, idle = flush_digest_queue(
            interface,
            per_call_timeout=0.05,
            idle_grace=0.25,
            idle_streak_required=8,
            max_seconds=6.0
        )
        print(f"Drained {drained} residual digest batch(es) after delete. Idle reached: {idle}")
        time.sleep(0.15)

        # 3) Add provided IP list
        keys, datas, added = [], [], []
        for ip_str in ip_addresses or []:
            try:
                cleaned = ip_str.strip()
                ip_int = ipv4_to_int(cleaned)
                key = table.make_key([gc.KeyTuple("hdr.inner_ipv4.dst_addr", ip_int)])
                data = table.make_data([], action_name)
            except Exception as e:
                print(f"Invalid IP '{ip_str}': {e}")
                continue
            # Append only after both objects were built, so keys and datas
            # can never get out of step.
            keys.append(key)
            datas.append(data)
            added.append(cleaned)

        if keys:
            try:
                table.entry_add(dev_tgt, keys, datas)
                print(f"Added {len(keys)} digest entries: {', '.join(added)}")
            except Exception as e:
                print(f"Failed adding entries: {e}")
        else:
            print("No valid IPs provided — skipping add.")

        # Show table
        show_digest_entries(interface, dev_tgt, bfrt_info, table_name)

        # 4) Timed listener (bounded)
        _run_live_listener(interface, duration_seconds, jsonl_file, learn_obj=LEARN_OBJ)

        # 5) Outcomes (aggregate + AE predict)
        _aggregate_and_predict()
    finally:
        # Cleanly tear down stream (best effort: a teardown failure must not
        # mask the outcome of the run itself).
        try:
            interface.tear_down_stream()
        except Exception:
            pass

    print(f"\nAll done. Digest messages saved to {jsonl_file}")

# ================== CLI ==================
def parse_args(argv=None):
    """Parse the command-line options of the Menu-1 runner.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, exactly as before. Passing an
            explicit list lets notebooks call this without touching
            ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with ``ips`` (list of str), ``duration`` (int),
        ``outdir``, ``table`` and ``action``.
    """
    p = argparse.ArgumentParser(description="Menu-1 digest runner (non-interactive)")
    p.add_argument("--ips", nargs="*", default=[],
                   help="Destination IPv4 addresses (space-separated). Example: --ips 192.168.200.30 192.168.200.40")
    p.add_argument("--duration", type=int, default=30,
                   help="Capture duration in seconds (default: 30)")
    p.add_argument("--outdir", default=DEFAULT_OUTDIR,
                   help=f"Output directory for digest_messages.jsonl (default: {DEFAULT_OUTDIR})")
    p.add_argument("--table", default=DEFAULT_TABLE,
                   help=f"Digest table name (default: {DEFAULT_TABLE})")
    p.add_argument("--action", default=DEFAULT_ACTION,
                   help=f"Digest action name (default: {DEFAULT_ACTION})")
    return p.parse_args(argv)

if __name__ == "__main__":
    args = parse_args()
    # Trim every supplied address and discard tokens that are blank afterwards.
    ip_list = []
    for raw_ip in args.ips or []:
        trimmed = raw_ip.strip()
        if trimmed:
            ip_list.append(trimmed)
    run_menu1(
        ip_addresses=ip_list,
        duration_seconds=args.duration,
        target_dir=args.outdir,
        table_name=args.table,
        action_name=args.action,
    )
