import sys
sys.path.append('/nix/store/qfv2ayja1zgrfbiy1nrd9f0b1y759h60-python3-3.11.6-env/lib/python3.11/site-packages')
import logging
import pandas as pd
import time
import bfrt_grpc.client as gc

# Logging setup: DEBUG and above go to digest_listener.log, which is
# truncated at the start of every run (mode='w').
logger = logging.getLogger('DigestListener')
logger.setLevel(logging.DEBUG)
log_handler = logging.FileHandler('digest_listener.log', mode='w')
log_handler.setLevel(logging.DEBUG)
# Without a formatter the file only receives bare messages; include a
# timestamp and level so entries can be correlated with the printed timings.
log_handler.setFormatter(
    logging.Formatter('%(asctime)s %(levelname)s %(name)s: %(message)s')
)
logger.addHandler(log_handler)

# Constants
HOST_IP = "localhost"  # host of the BF Runtime gRPC server
REGISTER_COUNTER = 'r_counter'  # register scanned to find the active indices

def convert_ip(ip_int):
    """Render the low 32 bits of *ip_int* as a dotted-quad IPv4 string.

    The most significant octet comes first; bits above bit 31 are ignored.
    """
    octet_shifts = (24, 16, 8, 0)
    octets = [str((ip_int >> shift) & 0xFF) for shift in octet_shifts]
    return '.'.join(octets)

def gc_connect():
    """Connect to the BF Runtime gRPC server and return the session handles.

    Tries client ids 0-9 in order and keeps the first one for which
    ``gc.ClientInterface`` succeeds; a failure on one id is logged and the
    next id is tried. Only the session that obtained client id 0 binds the
    ``cNdAmlight`` pipeline config.

    Returns:
        ``(interface, dev_tgt, bfrt_info)``. ``dev_tgt`` is device 0 with
        ``pipe_id=0xffff`` (presumably "all pipes" — confirm against the
        bfrt_grpc docs).

    Exits the process with status 1 if no client id can connect, or if
    fetching bfrt_info / binding the pipeline fails.
    """
    for bfrt_client_id in range(10):
        try:
            interface = gc.ClientInterface(
                grpc_addr=f"{HOST_IP}:50052",
                client_id=bfrt_client_id,
                device_id=0,
                num_tries=1,
            )
            break
        except Exception as e:
            # Record the failure and fall through to the next client id
            # instead of aborting on the first one.
            logger.error("Failed to connect to gRPC server (client_id %d): %s", bfrt_client_id, str(e))
            print(f"Failed to connect to gRPC server (client_id {bfrt_client_id}): {e}")
    else:
        # Loop finished without `break`: every client id failed.
        logger.error("Failed to connect to gRPC server with any client_id in 0-9")
        print("Failed to connect to gRPC server with any client_id in 0-9")
        sys.exit(1)

    dev_tgt = gc.Target(device_id=0, pipe_id=0xffff)
    try:
        bfrt_info = interface.bfrt_info_get()
        if bfrt_client_id == 0:
            interface.bind_pipeline_config("cNdAmlight")
        logger.debug("Connected to gRPC server with client_id %d, p4_name: %s", bfrt_client_id, bfrt_info.p4_name_get())
        print(f"P4 program: {bfrt_info.p4_name_get()}")
        return interface, dev_tgt, bfrt_info
    except Exception as e:
        logger.error("Failed to get bfrt_info or bind pipeline: %s", str(e))
        print(f"Failed to initialize P4 program: {e}")
        sys.exit(1)

def get_non_zero_indices(register, dev_tgt, bfrt_info, threshold=1):
    """Return the indices of *register* whose value exceeds *threshold* on any pipe.

    Reads ``SwitchIngress.<register>`` from hardware with an empty key list,
    which this scan relies on to return every entry, and inspects the
    per-pipe values stored under ``SwitchIngress.<register>.f1``.

    NOTE(review): despite the function name, the default ``threshold=1``
    keeps only entries strictly greater than 1, so entries equal to 1 are
    skipped. Pass ``threshold=0`` for a true non-zero scan; confirm which one
    the caller intends.

    Args:
        register: register name without the ``SwitchIngress.`` prefix.
        dev_tgt: target passed through to ``entry_get``.
        bfrt_info: object providing ``table_get``.
        threshold: an entry is kept if any of its values is ``> threshold``.

    Returns:
        Register indices in the order ``entry_get`` yields them, or ``[]``
        if the read fails (the error is printed and logged).
    """
    try:
        table = bfrt_info.table_get(f"SwitchIngress.{register}")
        # Field name is identical for every entry, so build it once.
        reg_key = f"SwitchIngress.{register}.f1"
        non_zero_indices = []
        for data, key in table.entry_get(dev_tgt, [], {"from_hw": True}):
            values = data.to_dict().get(reg_key, [])
            if any(v > threshold for v in values):
                # Decode the key only for entries that are kept.
                non_zero_indices.append(key.to_dict()['$REGISTER_INDEX']['value'])
        return non_zero_indices
    except Exception as e:
        print(f"Error accessing register {register}: {e}")
        logger.error("Error accessing register %s: %s", register, str(e))
        return []

def get_register_values_batch(register, indices, dev_tgt, bfrt_info, pipe_index=1):
    """Read *register* at each of *indices* and return ``{index: value}``.

    Each entry's ``SwitchIngress.<register>.f1`` field holds a list of
    per-pipe values. The element at *pipe_index* (non-negative, default 1)
    is returned.

    NOTE(review): get_non_zero_indices accepts an entry if ANY pipe passes
    its filter, while this reads a single pipe. Confirm that the traffic of
    interest is on pipe *pipe_index*.

    Args:
        register: register name without the ``SwitchIngress.`` prefix.
        indices: register indices to read.
        dev_tgt: target passed through to ``entry_get``.
        bfrt_info: object providing ``table_get``.
        pipe_index: which per-pipe value to return.

    Returns:
        ``{index: value}`` for the entries returned by ``entry_get``.
        An entry whose field is missing or has no value at *pipe_index* maps
        to 0. Returns ``{}`` for empty *indices*. If the read fails, every
        requested index maps to 0 and the error is printed and logged.
    """
    if not indices:
        # An empty key list makes entry_get return every entry in the register
        # (get_non_zero_indices reads the full register that way), so an
        # empty batch must not issue a read.
        return {}
    try:
        table = bfrt_info.table_get(f"SwitchIngress.{register}")
        reg_key = f"SwitchIngress.{register}.f1"
        keys = [table.make_key([gc.KeyTuple('$REGISTER_INDEX', index)]) for index in indices]
        values = {}
        for data, key in table.entry_get(dev_tgt, keys, {"from_hw": True}):
            index = key.to_dict()['$REGISTER_INDEX']['value']
            per_pipe = data.to_dict().get(reg_key, [])
            # A short or missing value list yields 0 for this entry only,
            # instead of an IndexError that would zero the whole batch.
            values[index] = per_pipe[pipe_index] if pipe_index < len(per_pipe) else 0
        return values
    except Exception as e:
        print(f"Error reading register {register} for batched indices: {e}")
        logger.error("Error reading register %s for batched indices: %s", register, str(e))
        return {index: 0 for index in indices}

def _aggregate_by_dst(indices, dev_tgt, bfrt_info, batch_size):
    """Sum counter values and count register indices per destination address.

    Reads ``r_dstAddr`` and ``REGISTER_COUNTER`` at *indices* in batches of
    *batch_size* (must be positive) and returns
    ``{dst: {'TotalCounter': int, 'UniqueSrcIPs': int}}``.
    """
    agg_data = {}
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        dst_values = get_register_values_batch('r_dstAddr', batch_indices, dev_tgt, bfrt_info)
        counter_values = get_register_values_batch(REGISTER_COUNTER, batch_indices, dev_tgt, bfrt_info)
        for index in batch_indices:
            dst = dst_values.get(index, 0)
            entry = agg_data.setdefault(dst, {'TotalCounter': 0, 'UniqueSrcIPs': 0})
            entry['TotalCounter'] += counter_values.get(index, 0)
            # Counts register indices per destination; this equals distinct
            # sources only if each index tracks one source — confirm in the P4 code.
            entry['UniqueSrcIPs'] += 1
    return agg_data


def _top_destinations(agg_data, top_n):
    """Return a DataFrame of the *top_n* destinations with dotted-quad IPs.

    Rows are ordered by TotalCounter descending, ties broken by ascending
    numeric destination address.
    """
    rows = [
        {'Dst IP': dst, 'TotalCounter': data['TotalCounter'], 'UniqueSrcIPs': data['UniqueSrcIPs']}
        for dst, data in agg_data.items()
    ]
    # Explicit columns keep sort_values valid even when agg_data is empty.
    df = pd.DataFrame(rows, columns=['Dst IP', 'TotalCounter', 'UniqueSrcIPs'])
    summary_df = (
        df.sort_values(by=['TotalCounter', 'Dst IP'], ascending=[False, True])
        .head(top_n)
        # Detach from df so the column assignment below does not trigger
        # pandas' SettingWithCopyWarning.
        .copy()
    )
    summary_df['Dst IP'] = summary_df['Dst IP'].apply(convert_ip)
    return summary_df


def inspect_registers(top_n=5, batch_size=1000):
    """Scan the counter register, aggregate per destination, and print a summary.

    Steps:
      1. Connect via gc_connect().
      2. Find the REGISTER_COUNTER indices selected by get_non_zero_indices.
      3. Read r_dstAddr and REGISTER_COUNTER at those indices in batches of
         *batch_size*.
      4. Sum the counters per destination and print the *top_n*
         destinations.

    The elapsed time of each phase is printed. The gRPC stream is torn down
    on every exit path, including the "no entries" early return and errors.
    """
    interface, dev_tgt, bfrt_info = gc_connect()
    try:
        print(f"Scanning non-zero entries in {REGISTER_COUNTER}...\n")
        start_time = time.perf_counter()
        indices = get_non_zero_indices(REGISTER_COUNTER, dev_tgt, bfrt_info)
        print(f"Scanning {REGISTER_COUNTER} took {time.perf_counter() - start_time:.3f} seconds")

        if not indices:
            print("No non-zero entries found.")
            return

        print(f"Processing {len(indices)} non-zero entries...")
        start_time = time.perf_counter()
        agg_data = _aggregate_by_dst(indices, dev_tgt, bfrt_info, batch_size)
        print(f"Fetching and aggregating register values took {time.perf_counter() - start_time:.3f} seconds")

        start_time = time.perf_counter()
        summary_df = _top_destinations(agg_data, top_n)

        print("=" * 80)
        print(f"Pandas processing and IP conversion took {time.perf_counter() - start_time:.3f} seconds")
        print("=" * 80)
        print(f"\nSummary: Top {top_n} Dst IPs by Total Counter and Unique Src IPs\n")
        print(summary_df.to_string(index=False))
        print("=" * 80)
    finally:
        interface.tear_down_stream()

if __name__ == '__main__':
    # Script entry point: run one connect / scan / summarize pass.
    inspect_registers()