import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import argparse
import math

# #df = pd.read_csv("overall.csv")


# # counts, bins, patches = plt.hist(df["jitter"].abs(), weights=np.ones(len(df))/len(df), bins=100)
# # plt.xlabel('Value Range')
# # plt.ylabel('Probability (%)')

# # plt.plot(df["jitter"].multiply(1_000_000))
# # plt.xlabel("Packet Index")
# # plt.ylabel("Cumulative Latency [ns]")
# #print("Filtering")
# #filtered = df[df["ipg"] != 0]
# #print("Filtered")

# vals = []
# for item in df["ipg"]:
#     if item != 0:
#         vals.append(item*1_000_000)

# for item in vals:
#     if item == 0:
#         print("err")

# # bp = plt.boxplot(vals, patch_artist=True)
# # for patch in bp["boxes"]:
# #     patch.set(facecolor="cyan")
# # #plt.setp(bp["boxes"], color="blue")
# # plt.yscale('symlog')
# # plt.xticks([])
# # plt.ylabel("Packet Inter-Arrival Times [ns]")
# plt.hist(vals, weights=np.ones(len(vals))/len(vals), bins=100)
# plt.xlabel("IAT Deviation [ns]")
# plt.ylabel("Percentage")
# #plt.yscale('log')

def graph_cum_latency(dfs, main_colors, box_colors, color_offset, letters):
    """Plot the cumulative latency of every packet, one line per run.

    Args:
        dfs: one DataFrame per run; each must have a ``CUM_LAT`` column.
        main_colors: line colours, cycled per run and shifted by ``color_offset``.
        box_colors: unused; kept so all graph_* functions share one signature.
        color_offset: index shift into ``main_colors``.
        letters: run labels; run ``i`` is shown as ``Run letters[i]``.
    """
    for run, df in enumerate(dfs):
        # CUM_LAT appears to be in seconds: graph_lat_deviation scales it by 1e9
        # to get ns. Scale by 1e6 so the values match the [us] axis label; the
        # previous factor of 1e3 produced milliseconds.
        vals = df["CUM_LAT"].multiply(1_000_000).values
        # Wrap *after* applying the offset so a non-zero offset can never index
        # past the end of the palette.
        color = main_colors[(run + color_offset) % len(main_colors)]
        plt.plot(vals, label=f"Run {letters[run]}", color=color)
    plt.xlabel("Packet Index")
    plt.ylabel("Cumulative Latency [us]")
    plt.legend()
    plt.show()

def graph_packet_dists(dfs, main_colors, box_colors, color_offset, letters):
    """Overlay per-run histograms of the absolute edit-script distance.

    Each run's bars are weighted by ``1/len(run)``, so bar heights are
    fractions of that run's packets (0..1). The y axis is labelled
    "Percentage".

    Args:
        dfs: one DataFrame per run with an ``IDX_DELTA`` column.
        main_colors: bar colours, cycled per run and shifted by ``color_offset``.
        box_colors: unused; kept for a uniform graph_* signature.
        color_offset: index shift into ``main_colors``.
        letters: run labels used in the legend.
    """
    for run, df in enumerate(dfs):
        vals = df["IDX_DELTA"].abs().values
        if len(vals) == 0:
            # np.max() raises on an empty array, and there is nothing to draw.
            continue
        # Use the largest distance as the bin count. Keep at least one bin so an
        # all-zero run still plots, and cast to int because np.histogram rejects
        # float bin counts.
        n_bins = max(int(np.max(vals)), 1)
        color = main_colors[(run + color_offset) % len(main_colors)]
        plt.hist(vals, weights=np.ones(len(vals)) / len(vals), bins=n_bins,
                 alpha=0.5, label=f"Run {letters[run]}", color=color)
    plt.xlabel("Edit Script Distance")
    plt.ylabel("Percentage")
    plt.legend()
    plt.show()

def graph_packet_dists2(dfs, main_colors, box_colors, color_offset, letters):
    """Draw one box plot of ``|IDX_DELTA|`` per run, side by side.

    Run ``i`` is placed at x = i + 1, and the box is filled with a colour
    from ``box_colors``.

    Args:
        dfs: one DataFrame per run with an ``IDX_DELTA`` column.
        main_colors: unused; kept for a uniform graph_* signature.
        box_colors: box fill colours, cycled per run and shifted by ``color_offset``.
        color_offset: index shift into ``box_colors``.
        letters: run labels used for the x-axis ticks.
    """
    for run, df in enumerate(dfs):
        vals = df["IDX_DELTA"].abs().values
        bp = plt.boxplot(vals, positions=[run + 1], widths=0.5, patch_artist=True)
        plt.setp(bp["medians"], color="black")
        # Wrap after applying the offset so the palette index stays in range.
        facecolor = box_colors[(run + color_offset) % len(box_colors)]
        for patch in bp["boxes"]:
            patch.set(facecolor=facecolor)
    n_runs = len(dfs)
    plt.xlabel("Run")
    plt.ylabel("Distance")
    plt.xticks(range(1, n_runs + 1), [f"Run {letters[i]}" for i in range(n_runs)])
    # No plt.legend(): the boxes carry no labels (the x ticks name the runs),
    # so legend() only produced a "No artists with labels" warning.
    plt.show()

def graph_packet_dists3(dfs, main_colors, box_colors, color_offset, letters):
    """Plot ``|IDX_DELTA|`` against packet index, one line per run.

    Args:
        dfs: one DataFrame per run with an ``IDX_DELTA`` column.
        main_colors: line colours, cycled per run and shifted by ``color_offset``.
        box_colors: unused; kept for a uniform graph_* signature.
        color_offset: index shift into ``main_colors``.
        letters: run labels used in the legend.
    """
    for run, df in enumerate(dfs):
        vals = df["IDX_DELTA"].abs().values
        # Wrap after applying the offset so the palette index stays in range.
        color = main_colors[(run + color_offset) % len(main_colors)]
        plt.plot(vals, label=f"Run {letters[run]}", color=color)
    plt.xlabel("Packet Index")
    plt.ylabel("Edit Script Distance")
    plt.legend()
    plt.show()

def graph_packet_dists4(dfs, main_colors, box_colors, color_offset, letters):
    """Plot, per run, how many packets have each absolute edit-script distance.

    Args:
        dfs: one DataFrame per run with an ``IDX_DELTA`` column.
        main_colors: line colours, cycled per run and shifted by ``color_offset``.
        box_colors: unused; kept for a uniform graph_* signature.
        color_offset: index shift into ``main_colors``.
        letters: run labels used in the legend.
    """
    for run, df in enumerate(dfs):
        # Count per run. The previous version shared one dict across all runs,
        # so every later curve also contained the earlier runs' packets.
        # sort_index() orders by distance so the line is drawn left to right
        # rather than in first-seen order; value_counts() drops NaN.
        counts = df["IDX_DELTA"].abs().value_counts().sort_index()
        color = main_colors[(run + color_offset) % len(main_colors)]
        plt.plot(counts.index, counts.values, label=f"Run {letters[run]}",
                 color=color, alpha=0.5)
    plt.xlabel("Edit Script Distance")
    plt.ylabel("Packets")
    plt.legend()
    plt.show()

def graph_packet_dists5(dfs, main_colors, box_colors, color_offset, letters):
    """Plot each run's per-packet ``|IDX_DELTA|``, aligned across runs by ``TLR``.

    Packets from different runs are matched by their ``TLR`` value and laid out
    on the x axis in the order they first appear in the first run.

    - A TLR missing from some run is plotted as 0 for that run.
    - TLRs that appear only in later runs are appended after the first run's
      packets; previously they raised KeyError.
    - If a TLR occurs more than once in a run, its last occurrence wins.

    Args:
        dfs: one DataFrame per run with ``TLR`` and ``IDX_DELTA`` columns.
        main_colors: line colours, cycled per run and shifted by ``color_offset``.
        box_colors: unused; kept for a uniform graph_* signature.
        color_offset: index shift into ``main_colors``.
        letters: run labels used in the legend.
    """
    n_runs = len(dfs)
    # TLR -> [first-seen order, |IDX_DELTA| in run 0, run 1, ...]
    aligned = {}
    for run, df in enumerate(dfs):
        for row in df.itertuples():
            entry = aligned.get(row.TLR)
            if entry is None:
                entry = [len(aligned)] + [0] * n_runs
                aligned[row.TLR] = entry
            entry[run + 1] = abs(row.IDX_DELTA)
    ordered = sorted(aligned.values(), key=lambda e: e[0])
    for run in range(n_runs):
        color = main_colors[(run + color_offset) % len(main_colors)]
        plt.plot([e[run + 1] for e in ordered], label=f"Run {letters[run]}",
                 color=color, alpha=0.5)
    # x is the packet position (first-run order) and y is the distance; the old
    # labels ("Edit Script Distance" / "Packets") were copied from graph_packet_dists4.
    plt.xlabel("Packet Index")
    plt.ylabel("Edit Script Distance")
    plt.legend()
    plt.show()

def graph_iat_deviation(dfs, main_colors, box_colors, color_offset, letters,
                        output_path="/home/ubuntu/output/iat_graph.png"):
    """Save a 3D bar chart of per-run inter-arrival-time (``IPG``) histograms.

    ``IPG`` is multiplied by 1e9 (plotted as ns) and binned on a signed,
    roughly logarithmic scale from -10^5 to 10^5. Each run is one row of
    bars along the y axis.

    Bar heights are fractions of the run's packets (weights ``1/len``), even
    though the z axis reads "Percentage". np.histogram ignores samples outside
    the outermost edges, so a row can sum to less than 1.

    Args:
        dfs: one DataFrame per run with an ``IPG`` column.
        main_colors: bar colours, cycled per run and shifted by ``color_offset``.
        box_colors: unused; kept for a uniform graph_* signature.
        color_offset: index shift into ``main_colors``.
        letters: run labels for the y-axis ticks.
        output_path: file the figure is saved to. The default is the
            previously hard-coded location.
    """
    fig = plt.figure(figsize=(7, 7))
    plt.rcParams['font.size'] = 10
    ax = fig.add_subplot(projection='3d')

    # The bin layout does not depend on the data, so build it once, not per run.
    # Plot position x = +-(k + 1) carries the "+-10^k" tick. The bar centred
    # there holds magnitudes between 10^k / 2 and 10^k. The half-step bar at
    # +-(k + 0.5) holds magnitudes between 10^(k-1) and 10^k / 2.
    pow_min = -5
    pow_max = 5
    xbins, xcenter, xwidth = [], [], []
    for i in range(pow_min, 0):
        magnitude = 10 ** abs(i)
        xbins += [-magnitude, -magnitude // 2]
        xcenter += [i - 1, i - 0.5]
        xwidth += [0.5, 0.5]
    # Magnitudes below 1 get the bins [-1, 0) and [0, 1), drawn wider around x = 0.
    xbins += [-1, 0, 1]
    xcenter += [-1.25 / 2, 1.25 / 2]
    xwidth += [1.25, 1.25]
    for i in range(1, pow_max + 1):
        magnitude = 10 ** i
        xbins += [magnitude // 2, magnitude]
        xcenter += [i + 0.5, i + 1]
        xwidth += [0.5, 0.5]

    for run, df in enumerate(dfs):
        vals = df["IPG"].multiply(1_000_000_000).values
        histvals, _ = np.histogram(vals, bins=xbins, weights=np.ones(len(vals)) / len(vals))
        # Wrap after applying the offset so the palette index stays in range.
        color = main_colors[(run + color_offset) % len(main_colors)]
        ax.bar(xcenter, align='center', width=xwidth, height=histvals, zs=run, zdir="y",
               label=f"Run {letters[run]}", color=color, alpha=0.5)

    n_runs = len(dfs)
    ax.set_xlabel("Inter-Arrival Time Deviation [ns]")
    ax.set_ylabel("Run")
    ax.set_zlabel("Percentage")
    ax.set_yticks(range(n_runs))
    ax.set_yticklabels([f"{letters[i]}" for i in range(n_runs)])
    ax.set_xticks(list(range(-6, 7)))
    ax.axes.set_xlim3d(-6, 6)
    xticks = ([f"$-10^{abs(i)}$" for i in range(-5, 1)]
              + [0]
              + [f"$10^{i}$" for i in range(0, 6)])
    ax.set_xticklabels(xticks)
    plt.savefig(output_path)

def graph_lat_deviation(dfs, main_colors, box_colors, color_offset, letters,
                        output_path="/home/ubuntu/output/latency_graph.png"):
    """Save a 3D bar chart of per-run cumulative-latency (``CUM_LAT``) histograms.

    ``CUM_LAT`` is multiplied by 1e9 (plotted as ns) and binned on a signed,
    roughly logarithmic scale from -10^10 to 10^10. Each run is one row of
    bars along the y axis.

    Bar heights are fractions of the run's packets (weights ``1/len``), even
    though the z axis reads "Percentage". np.histogram ignores samples outside
    the outermost edges.

    Args:
        dfs: one DataFrame per run with a ``CUM_LAT`` column.
        main_colors: bar colours, cycled per run and shifted by ``color_offset``.
        box_colors: unused; kept for a uniform graph_* signature.
        color_offset: index shift into ``main_colors``.
        letters: run labels for the y-axis ticks.
        output_path: file the figure is saved to. The default is the
            previously hard-coded location.
    """
    fig = plt.figure(figsize=(7, 7))
    plt.rcParams['font.size'] = 10
    ax = fig.add_subplot(projection='3d')

    # The bin layout does not depend on the data, so build it once, not per run.
    # Plot position x = +-(k + 1) carries the "+-10^k" tick. The bar centred
    # there holds magnitudes between 10^k / 2 and 10^k. The half-step bar at
    # +-(k + 0.5) holds magnitudes between 10^(k-1) and 10^k / 2.
    pow_min = -10
    pow_max = 10
    xbins, xcenter, xwidth = [], [], []
    for i in range(pow_min, 0):
        magnitude = 10 ** abs(i)
        xbins += [-magnitude, -magnitude // 2]
        xcenter += [i - 1, i - 0.5]
        xwidth += [0.5, 0.5]
    # Magnitudes below 1 get the bins [-1, 0) and [0, 1), drawn wider around x = 0.
    xbins += [-1, 0, 1]
    xcenter += [-1.25 / 2, 1.25 / 2]
    xwidth += [1.25, 1.25]
    for i in range(1, pow_max + 1):
        magnitude = 10 ** i
        xbins += [magnitude // 2, magnitude]
        xcenter += [i + 0.5, i + 1]
        xwidth += [0.5, 0.5]

    for run, df in enumerate(dfs):
        vals = df["CUM_LAT"].multiply(1_000_000_000).values
        histvals, _ = np.histogram(vals, bins=xbins, weights=np.ones(len(vals)) / len(vals))
        # Wrap after applying the offset so the palette index stays in range.
        color = main_colors[(run + color_offset) % len(main_colors)]
        ax.bar(xcenter, align='center', width=xwidth, height=histvals, zs=run, zdir="y",
               label=f"Run {letters[run]}", color=color, alpha=0.5)

    n_runs = len(dfs)
    # The previous label, "Inter-Arrival Time Deviation [ns]", was copied from
    # the IAT plot; this chart bins CUM_LAT.
    ax.set_xlabel("Cumulative Latency [ns]")
    ax.set_ylabel("Run")
    ax.set_zlabel("Percentage")
    ax.set_yticks(range(n_runs))
    ax.set_yticklabels([f"{letters[i]}" for i in range(n_runs)])
    # NOTE(review): the bins reach +-10^10 (x = +-11), but the view and tick
    # labels stop at x = +-7 (+-10^6). Bars beyond may be hidden or drawn
    # outside the box; confirm this zoom is intended.
    ax.set_xticks(list(range(-7, 8)))
    ax.axes.set_xlim3d(-7, 7)
    xticks = ([f"$-10^{abs(i)}$" for i in range(-6, 1)]
              + [0]
              + [f"$10^{i}$" for i in range(0, 7)])
    ax.set_xticklabels(xticks)
    plt.savefig(output_path)

def graph_iat_box(dfs, main_colors, box_colors, color_offset, letters):
    """Draw one box plot of inter-arrival time (``IPG`` x 1e9, plotted as ns) per run.

    The y axis uses a symmetric log scale, so negative deviations stay visible.

    Args:
        dfs: one DataFrame per run with an ``IPG`` column.
        main_colors: unused; kept for a uniform graph_* signature.
        box_colors: box fill colours, cycled per run and shifted by ``color_offset``.
        color_offset: index shift into ``box_colors``.
        letters: run labels used for the x-axis ticks.
    """
    for run, df in enumerate(dfs):
        vals = df["IPG"].multiply(1_000_000_000).values
        bp = plt.boxplot(vals, positions=[run + 1], widths=0.5, patch_artist=True)
        plt.setp(bp["medians"], color="black")
        # Wrap after applying the offset so the palette index stays in range.
        facecolor = box_colors[(run + color_offset) % len(box_colors)]
        for patch in bp["boxes"]:
            patch.set(facecolor=facecolor)
    n_runs = len(dfs)
    plt.xlabel("Run")
    plt.ylabel("Inter-Arrival Time [ns]")
    plt.xticks(range(1, n_runs + 1), [f"{letters[i]}" for i in range(n_runs)])
    plt.yscale('symlog')
    plt.show()

def metric_o(dfs, main_colors, box_colors, color_offset, letters):
    """Compute the ordering metric O for every run and print it.

    O = sum(|IDX_DELTA|) / (n * (n - 1) / 2). This is the total absolute edit
    distance normalised by the number of packet pairs in a run of n packets.
    A run with fewer than two packets has no pairs, so its O is 0.0; the
    formula previously divided by zero there.

    The colour/label arguments are unused. They exist only so the metric_*
    functions share the graph_* signature.

    Returns:
        List of O values, one per run, in ``dfs`` order.
    """
    retvals = []
    for run, df in enumerate(dfs):
        n = len(df)
        pairs = (n * (n - 1)) / 2
        total = df["IDX_DELTA"].abs().sum()
        value = total / pairs if pairs else 0.0
        print(f"Run {run+1} O: {value}")
        retvals.append(value)
    return retvals

def metric_i(dfs, main_colors, box_colors, color_offset, letters):
    """Compute the inter-arrival metric I for every run and print it.

    I = sum(|IPG|) / (2 * dt). Here dt is the run's last ``TIME`` minus its
    first ``TIME``.

    The colour/label arguments are unused. They exist only so the metric_*
    functions share the graph_* signature.

    Returns:
        List of I values, one per run, in ``dfs`` order.
    """
    retvals = []
    for run, df in enumerate(dfs):
        # Vectorised sum; the old loop also built a list that was never read.
        total = df["IPG"].abs().sum()
        dt = df["TIME"].iloc[-1] - df["TIME"].iloc[0]
        value = total / (2 * dt)
        print(f"Run {run+1} I: {value}")
        retvals.append(value)
    return retvals

def metric_l(dfs, main_colors, box_colors, color_offset, letters):
    """Compute the latency metric L for every run and print it.

    L = sum(|CUM_LAT|) / (n * dt). Here n is the number of packets and dt is
    the run's last ``TIME`` minus its first ``TIME``.

    The colour/label arguments are unused. They exist only so the metric_*
    functions share the graph_* signature.

    Returns:
        List of L values, one per run, in ``dfs`` order.
    """
    retvals = []
    for run, df in enumerate(dfs):
        # Vectorised sum; the old loop built a list only to take its length.
        total = df["CUM_LAT"].abs().sum()
        dt = df["TIME"].iloc[-1] - df["TIME"].iloc[0]
        value = total / (len(df) * dt)
        print(f"Run {run+1} L: {value}")
        retvals.append(value)
    return retvals

def graph_consistency(m_us, m_os, m_is, m_ls):
    """Show a 3D scatter of the per-run consistency metrics.

    Each point is one run. It sits at (I, O, L) on the
    (x = inter-arrival, y = ordering, z = latency) axes and is coloured by its
    uniqueness value U through the 'viridis' colour map.
    """
    figure = plt.figure()
    axes3d = figure.add_subplot(111, projection='3d')
    points = axes3d.scatter(m_is, m_os, m_ls, c=m_us, cmap='viridis', s=50)
    axes3d.set_ylabel('Ordering')
    axes3d.set_xlabel('Inter-Arrival Time')
    axes3d.set_zlabel('Latency')
    colorbar = plt.colorbar(points, pad=0.1, shrink=0.5)
    colorbar.set_label('Uniqueness')
    plt.show()

def distance_metrics(dfs):
    """Summarise the ``IDX_DELTA`` column of each run.

    Returns one 8-element list per run, in this order:

        [mean, std, abs mean, abs std, min, abs min, max, abs max]

    The "abs" entries are computed on ``|IDX_DELTA|``. ``std`` is pandas'
    default sample standard deviation.
    """
    summaries = []
    for frame in dfs:
        delta = frame["IDX_DELTA"]
        magnitude = delta.abs()
        summaries.append([
            delta.mean(),
            delta.std(),
            magnitude.mean(),
            magnitude.std(),
            delta.min(),
            magnitude.min(),
            delta.max(),
            magnitude.max(),
        ])
    return summaries


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Visualize packet data from a CSV file.")
    parser.add_argument("-c", help="Color offset", type=int, default=0)
    # required=True: a missing -i used to leave args.i as None and crash with
    # a TypeError instead of printing a usage message.
    parser.add_argument("-i", nargs='+', action="append", type=str, required=True,
                        help="Path to the CSV files containing packet data.")
    args = parser.parse_args()
    # "lightorange" and "lightpurple" are not matplotlib named colours (using
    # them raises ValueError); the XKCD names are the closest valid ones.
    box_colors = ["lightblue", "xkcd:light orange", "lightgray", "xkcd:light purple",
                  "cyan", "lightpink", "olivedrab", "lightgray"]
    main_colors = ["blue", "orange", "gray", "purple", "cyan", "pink", "olive", "gray"]
    letters = "BCDEFGHIJKLMNOPQRSTUVWXYZ"

    # action="append" yields one list per -i flag. Flatten them so that
    # "-i a.csv -i b.csv" loads both groups; previously only the first was read.
    csv_paths = [path for group in args.i for path in group]
    dfs = [pd.read_csv(path) for path in csv_paths]
    print(f"Loaded {len(dfs)} CSV files.")
    plt.rcParams['figure.figsize'] = (7, 6)
    plt.rcParams['font.size'] = 12

    graph_iat_deviation(dfs, main_colors, box_colors, args.c, letters)
    graph_lat_deviation(dfs, main_colors, box_colors, args.c, letters)

    # Uniqueness (U) is not derived from the CSVs here; it is fixed at 0 per run.
    m_us = [0] * len(dfs)
    for i, m_u in enumerate(m_us):
        print(f"Run {i+1} U: {m_u}")
    m_os = metric_o(dfs, main_colors, box_colors, args.c, letters)
    m_is = metric_i(dfs, main_colors, box_colors, args.c, letters)
    m_ls = metric_l(dfs, main_colors, box_colors, args.c, letters)
    # Print the distance statistics whenever any run was reordered, not only
    # when the first run was.
    if any(m_o > 0 for m_o in m_os):
        dms = distance_metrics(dfs)
        print("Distance Metrics:")
        for dm in dms:
            print(f"Avg: {dm[0]:.2f}, Std: {dm[1]:.2f}, Abs Avg: {dm[2]:.2f}, Abs Std: {dm[3]:.2f}, Min: {dm[4]}, Abs Min: {dm[5]}, Max: {dm[6]}, Abs Max: {dm[7]}")
    for i, (m_u, m_o, m_i, m_l) in enumerate(zip(m_us, m_os, m_is, m_ls)):
        k = 1 - math.sqrt(m_u ** 2 + m_o ** 2 + m_i ** 2 + m_l ** 2) / 2
        print(f"Run {i+1} Consistency: {k * 100:.5f}%")

    for i, df in enumerate(dfs):
        row0 = df.iloc[0]
        rowN = df.iloc[-1]
        total_time = rowN["TIME"] - row0["TIME"]
        print(f"Run {i+1} total time: {total_time}s, packets: {len(df)}, pps: {len(df)/total_time}")

    for run, df in enumerate(dfs, start=1):
        # IPG appears to be in seconds (it is scaled by 1e9 for the [ns] plots),
        # so the 1e-8 threshold below is the +-10 ns window.
        n_packets = len(df)
        c_iat = int((df["IPG"].abs() < 0.00000001).sum())
        # Guard against an empty CSV, which previously raised ZeroDivisionError.
        pct = c_iat / n_packets * 100 if n_packets else 0.0
        print(f"Run {run} IAT Within +- 10 ns: {pct:.2f}%")
