from scapy.plist import PacketList
from scapy.utils import rdpcap
from scapy.compat import raw
import argparse
import pandas as pd
import enum
from tqdm import tqdm
from bisect import bisect_left

# Marker value compared (as a big-endian integer) against the last two bytes
# of every raw packet in lin_dist(); packets that do not end with it are ignored.
CREASE_PROT = 0x6587
    
class I(enum.IntEnum):
    """Column positions of one result row built by lin_dist().

    The members index the per-packet lists stored in ``b_common`` and their
    names become the column headers of the CSV written by lin_dist().
    """
    TLR = 0         # hex string of the packet trailer used as the packet key
    TIME = 1        # timestamp of the packet in the second capture (B)
    IDX_DELTA = 2   # position in A minus position in B; overwritten with the move distance for moved packets
    TIME_DELTA = 3  # timestamp in A minus timestamp in B
    IPG = 4         # gap to the previous common packet in A minus the same gap in B
    CUM_LAT = 5     # (time in A - starta) minus (time in B - startb)

def update(existing_aggregate, new_value):
    """Fold one sample into a running (count, mean, M2) aggregate.

    ``count`` is the number of samples seen so far, ``mean`` the running mean
    of all samples, and ``M2`` the accumulated squared distance from the mean.

    Args:
        existing_aggregate: ``(count, mean, m2)`` tuple describing the samples
            seen so far; ``(0, 0, 0)`` for an empty stream.
        new_value: the sample to add.

    Returns:
        A new ``(count, mean, m2)`` tuple that includes ``new_value``.
    """
    n_seen, running_mean, sq_dist = existing_aggregate
    n_seen += 1
    # Distance to the mean *before* and *after* the mean absorbs the sample;
    # their product is the increment of M2.
    dist_before = new_value - running_mean
    running_mean += dist_before / n_seen
    dist_after = new_value - running_mean
    return (n_seen, running_mean, sq_dist + dist_before * dist_after)

def finalize(existing_aggregate):
    """Turn a running aggregate into (mean, variance, sample variance).

    Args:
        existing_aggregate: ``(count, mean, m2)`` tuple as produced by update().

    Returns:
        ``(mean, m2 / count, m2 / (count - 1))``, or three NaNs when fewer
        than two samples have been aggregated.
    """
    n_seen, mean, sq_dist = existing_aggregate
    if n_seen < 2:
        # Not enough samples: every statistic is reported as NaN.
        return tuple(float("nan") for _ in range(3))
    population_variance = sq_dist / n_seen
    sample_variance = sq_dist / (n_seen - 1)
    return (mean, population_variance, sample_variance)

def shortest_edit_sequence(perm1_raw, perm2_raw):
    """Build a list of single-element moves that turns perm1 into perm2.

    The elements of perm2 that form a longest increasing subsequence (LIS)
    of their positions in perm1 are left in place; every other element is
    moved directly behind the element that precedes it in perm2 (or to the
    front when nothing precedes it).

    Args:
        perm1_raw: iterable of unique hashable items (the starting order).
        perm2_raw: iterable with the same items in the target order. Every
            item must also occur in perm1_raw, otherwise KeyError is raised.

    Returns:
        A list of ``(item, source_pos, target_pos, orig_pos)`` tuples.
        ``source_pos``/``target_pos`` are pop/insert indices valid when the
        moves are applied one after another (see apply_edit_script);
        ``orig_pos`` is the item's index in the unmodified perm1. An empty
        perm2 yields an empty list.
    """
    perm1 = list(perm1_raw)
    perm2 = list(perm2_raw)
    if not perm2:
        # Nothing to reorder; also avoids looking up a non-existent LIS end.
        return []

    pos_in_perm1 = {c: i for i, c in enumerate(perm1)}
    perm2_indices = [pos_in_perm1[c] for c in perm2]

    # Patience sorting: pile_tops[p] is the smallest tail value of an
    # increasing subsequence of length p + 1, pile_top_idx[p] the index in
    # perm2_indices holding that tail.
    pile_tops = []
    pile_top_idx = []
    predecessors = [-1] * len(perm2_indices)
    lis_end = -1
    for i, x in enumerate(tqdm(perm2_indices, desc="Finding LIS in perm2_indices")):
        pos = bisect_left(pile_tops, x)
        if pos == len(pile_tops):
            pile_tops.append(x)
            pile_top_idx.append(i)
            # The element that opens the last pile terminates the LIS we keep.
            lis_end = i
        else:
            pile_tops[pos] = x
            pile_top_idx[pos] = i
        if pos > 0:
            # The current top of the previous pile is the most recent smaller
            # element one pile to the left, so no backward scan is needed.
            predecessors[i] = pile_top_idx[pos - 1]

    lis_indices = set()
    k = lis_end
    while k != -1:
        lis_indices.add(k)
        k = predecessors[k]

    # Simulate moves on a working copy so indices are always correct
    index_in_perm1 = dict(pos_in_perm1)
    moves = []
    prev_lis_char = None
    lis_offset = 0
    for i, c in enumerate(tqdm(perm2, desc="Building edit script")):
        if i in lis_indices:
            prev_lis_char = c
            lis_offset = 0
            continue
        source_pos = index_in_perm1[c]
        # Before the first LIS element there is no anchor yet: use a virtual
        # anchor at index -1 so these elements are placed at the very front
        # (in perm2 order) instead of behind the first LIS element.
        anchor_pos = -1 if prev_lis_char is None else index_in_perm1[prev_lis_char]
        target_pos = anchor_pos + lis_offset
        if target_pos > source_pos:
            # Moving right: everything in between slides one slot left.
            for j in range(source_pos + 1, min(target_pos + 1, len(perm1))):
                index_in_perm1[perm1[j]] -= 1
        else:
            # Moving left: land one past the anchor, the rest slides right.
            target_pos += 1
            for j in range(source_pos - 1, max(target_pos - 1, -1), -1):
                index_in_perm1[perm1[j]] += 1
        lis_offset += 1
        index_in_perm1[c] = target_pos
        perm1.pop(source_pos)
        perm1.insert(target_pos, c)
        moves.append((c, source_pos, target_pos, pos_in_perm1[c]))
    return moves

def apply_edit_script(perm, edit_script):
    """Replay an edit script on a copy of ``perm`` and return the result.

    Args:
        perm: sequence to reorder; it is copied, not modified.
        edit_script: iterable of 4-tuples whose second and third fields are
            the index to pop from and the index to insert at.

    Returns:
        A new list with every move applied in order.
    """
    working = list(perm)
    for move in tqdm(edit_script, desc="Applying edit script"):
        _item, src, dst, _orig = move
        working.insert(dst, working.pop(src))
    return working

def _collect_tlrs(pktlist, replayer, desc):
    """Return (tlr_hex, timestamp) pairs for the marked packets in pktlist.

    A packet is kept when its last two raw bytes equal CREASE_PROT and its
    trailer matches the replayer filter: -1 keeps everything, 0 / 1 keep only
    trailers whose last two bytes are 0x0000 / 0x0001, any other value keeps
    nothing.
    """
    wanted_suffix = {0: b'\x00\x00'.hex(), 1: b'\x00\x01'.hex()}.get(replayer)
    entries = []
    for pkt in tqdm(pktlist, desc=desc):
        payload = raw(pkt)  # serialise once, reuse for marker and trailer
        if int.from_bytes(payload[-2:], 'big') != CREASE_PROT:
            continue
        tlr = payload[-14:-6].hex()
        if replayer == -1 or (wanted_suffix is not None and tlr[-4:] == wanted_suffix):
            entries.append((tlr, pkt.time))
    return entries


def lin_dist(
        pktlist_a: PacketList,
        pktlist_b: PacketList,
        output_file: str,
        replayer: int = -1,
        starta: float = 0,
        startb: float = 0):
    """Compare packet order and timing of two captures and write the results.

    Packets are matched by their trailer (TLR). For every TLR present in both
    captures one row is produced (see the ``I`` enum for the columns); the
    rows are written to ``<output_file>.csv`` and summary statistics to
    ``<output_file>.txt``.

    Args:
        pktlist_a: first capture (A).
        pktlist_b: second capture (B); result rows follow B's order.
        output_file: output path without extension.
        replayer: -1 to use all marked packets, 0 or 1 to keep only trailers
            ending in 0x0000 or 0x0001 respectively.
        starta: reference time for A; 0 means "timestamp of A's first kept packet".
        startb: reference time for B; 0 means "timestamp of B's first kept packet".
    """
    hex0 = b'\x00\x00'.hex()
    hex1 = b'\x00\x01'.hex()
    m = len(pktlist_a)
    n = len(pktlist_b)
    a = _collect_tlrs(pktlist_a, replayer, "Reading PCAP 1")
    b = _collect_tlrs(pktlist_b, replayer, "Reading PCAP 2")

    a_indices = {}   # TLR -> index of its first occurrence in a
    b_indices = {}   # TLR -> index of its first occurrence in b
    b_row = {}       # TLR -> row number in b_common
    agg_a0 = agg_a1 = agg_b0 = agg_b1 = (0, 0, 0)
    a_common = []
    b_common = []
    u_a = []
    u_b = []

    for i, (tlr, ts) in enumerate(tqdm(a, desc="First PCAP 1 Loop")):
        if starta == 0 and i == 0:
            starta = ts
        if tlr in a_indices:
            print(f"Warning: Duplicate TLR {tlr} found in a at index {i}, skipping")
            continue
        a_indices[tlr] = i
        if tlr[-4:] == hex0:
            agg_a0 = update(agg_a0, i)
        elif tlr[-4:] == hex1:
            agg_a1 = update(agg_a1, i)

    for i, (tlr, ts) in enumerate(tqdm(b, desc="PCAP 2 Loop")):
        if startb == 0 and i == 0:
            startb = ts
        if tlr in b_indices:
            print(f"Warning: Duplicate TLR {tlr} found in b at index {i}, skipping")
            continue
        b_indices[tlr] = i
        if tlr in a_indices:
            b_row[tlr] = len(b_common)
            b_common.append([tlr, ts, 0, 0, 0, 0])
        else:
            u_b.append(tlr)
        if tlr[-4:] == hex0:
            agg_b0 = update(agg_b0, i)
        elif tlr[-4:] == hex1:
            agg_b1 = update(agg_b1, i)

    ap = 0
    for i, (tlr, ts) in enumerate(tqdm(a, desc="Second PCAP 1 Loop")):
        if a_indices[tlr] != i:
            # Duplicate already reported in the first pass; processing it
            # again would overwrite the row and duplicate a_common/u_a.
            continue
        if tlr not in b_indices:
            u_a.append(tlr)
            continue
        bp = b_row[tlr]
        row = b_common[bp]
        row[I.IDX_DELTA] = ap - bp
        row[I.TIME_DELTA] = ts - row[I.TIME]
        # Inter-packet gap to the previous common packet on each side; the
        # first common packet of a capture has no predecessor, so its gap is 0.
        gap_a = ts - a_common[ap - 1][1] if ap > 0 else 0
        gap_b = row[I.TIME] - b_common[bp - 1][I.TIME] if bp > 0 else 0
        row[I.IPG] = gap_a - gap_b
        row[I.CUM_LAT] = (ts - starta) - (row[I.TIME] - startb)
        ap += 1
        a_common.append([tlr, ts])

    b_tlrs = [item[I.TLR] for item in b_common]
    a_tlrs = [item[I.TLR] for item in a_common]
    # With no common packets there is nothing to reorder.
    moves = shortest_edit_sequence(b_tlrs, a_tlrs) if a_tlrs else []
    print(f"Total moves: {len(moves)}")
    for move in tqdm(moves, desc="Updating deltas"):
        # move = (tlr, source_pos, target_pos, original row in b_common)
        b_common[move[3]][I.IDX_DELTA] = move[2] - move[1]

    print(f"Unique in A: {len(u_a)}/{m}")
    print(f"Unique in B: {len(u_b)}/{n}")
    print(f"Common: {len(a_common)}")
    print("Writing stats to txt file...", end="", flush=True)
    agg_a0 = finalize(agg_a0)
    agg_a1 = finalize(agg_a1)
    agg_b0 = finalize(agg_b0)
    agg_b1 = finalize(agg_b1)
    with open(output_file + ".txt", "w") as f:
        f.write(f"Unique in A: {len(u_a)}/{m}\n")
        f.write(f"Unique in B: {len(u_b)}/{n}\n")
        f.write(f"Common: {len(a_common)}\n")
        f.write(f"Total moves: {len(moves)}\n")
        f.write(f"Aggregate A0: {agg_a0}\n")
        f.write(f"Aggregate A1: {agg_a1}\n")
        f.write(f"Aggregate B0: {agg_b0}\n")
        f.write(f"Aggregate B1: {agg_b1}\n")
        for tlr in u_a:
            f.write(f"Unique in A: {tlr} | {a_indices[tlr]}\n")
        for tlr in u_b:
            f.write(f"Unique in B: {tlr} | {b_indices[tlr]}\n")
    print("Done")

    print("Writing results to CSV...", end="", flush=True)
    df_overall = pd.DataFrame(b_common, columns=[column.name for column in I])
    df_overall.to_csv(output_file + ".csv", index=True)
    print("Done")

if __name__ == "__main__":
    # Command-line entry point: compare pcap1 against every file given with -i
    # and write one result set per comparison to "<output>_<k>.txt/.csv".
    parser = argparse.ArgumentParser(description="Compare two PCAP files for consistency.")
    parser.add_argument("pcap1", type=str, help="Path to the first PCAP file")
    parser.add_argument("output", type=str, help="Name of the output file (without extension)")
    parser.add_argument("-i", nargs='+', type=str, help="Paths to the other PCAP file(s) to compare against the first one.")
    parser.add_argument("-r", type=int, default=-1, help="Replayer ID")
    parser.add_argument("-t", nargs='+', type=float, help="Start times")
    args = parser.parse_args()

    # Validate the arguments before the (slow) PCAP parsing starts.
    if not args.i or not args.i[0]:
        print("No second PCAP file provided, exiting.")
        raise SystemExit(1)
    if args.t and len(args.t) < len(args.i) + 1:
        # -t carries the start time of pcap1 followed by one per -i file;
        # fewer values would fail with an IndexError halfway through the run.
        parser.error(
            f"-t needs {len(args.i) + 1} start times "
            f"(pcap1 plus one per -i file), got {len(args.t)}")

    print("Reading packets from PCAP file 1...", end="", flush=True)
    skip1 = ["runsL2/run13.pcap"]
    a = rdpcap(args.pcap1)
    if args.pcap1 in skip1:
        # NOTE(review): the message says "a packet" but two are dropped - confirm intent.
        print(f"Skipping a packet in {args.pcap1} as it is in the skip list.")
        a = a[2:]
    print("Done", flush=True)

    start1 = args.t[0] if args.t else 0
    for i, pcap2 in enumerate(args.i):
        print(f"Comparing with PCAP file {i + 2} ({pcap2})...")
        if not pcap2:
            print("No second PCAP file provided, exiting.")
            raise SystemExit(1)
        print(f"Reading packets from PCAP file {i + 2}...", end="", flush=True)
        b = rdpcap(pcap2)
        print("Done", flush=True)
        if pcap2 in skip1:
            print(f"Skipping a packet in {pcap2} as it is in the skip list.")
            b = b[2:]
        start2 = args.t[i + 1] if args.t else 0
        lin_dist(a, b, args.output + f"_{i + 1}", args.r, start1, start2)