from fabrictestbed_extensions.fablib.fablib import FablibManager as fablib_manager
from ipaddress import ip_address, IPv4Address, IPv4Network
from datetime import datetime
from time import sleep
from IPython import get_ipython
from concurrent import futures
import json
import os
import tarfile
import base64
import requests

# Default FABRIC sites excluded from random placement.
# NOTE(review): the reason each site is excluded is not recorded in this file;
# confirm before editing the list.
_DEFAULT_AVOID_SITES = ["RUTG", "MAX", "NEWY", "CIEN", "ATLA", "EDC", "SEAT", "EDUKY"]

# VM sizing and images for the end hosts (h*) and the P4 routers (r*).
_HOST_CORES = 2
_HOST_RAM = 2
_HOST_DISK = 10
_SWITCH_DISK = 10
_HOST_IMAGE = "default_ubuntu_22"
_SWITCH_IMAGE = "crease_ubuntu_22"

# Resources requested for every monitored L2 network.
_MONITOR_RESOURCES = {"cores": 4, "ram": 2, "disk": 10}

# End hosts; each has a single NIC_Basic port named f"{host}p0".
_HOSTNAMES = ('h1', 'h2')

# Number of NIC_Basic ports on each router; ports are named f"{router}p{i}".
_ROUTER_PORT_COUNTS = {
    'r1': 3, 'r2': 3, 'r3': 2, 'r4': 3, 'r5': 3,
    'r6': 3, 'r7': 2, 'r8': 3, 'r9': 4,
}

# Point-to-point links as (network name, interface, interface), in creation
# order. r1..r8 form a ring, r9 attaches to r2/r4/r6/r8, and h1/h2 hang off
# r1/r5 respectively.
_LINKS = (
    ("h1_r1", "h1p0", "r1p2"),
    ("h2_r5", "h2p0", "r5p2"),
    ("r1_r2", "r1p0", "r2p1"),
    ("r2_r3", "r2p0", "r3p1"),
    ("r3_r4", "r3p0", "r4p1"),
    ("r4_r5", "r4p0", "r5p1"),
    ("r5_r6", "r5p0", "r6p1"),
    ("r6_r7", "r6p0", "r7p1"),
    ("r7_r8", "r7p0", "r8p1"),
    ("r8_r1", "r8p0", "r1p1"),
    ("r2_r9", "r2p2", "r9p0"),
    ("r4_r9", "r4p2", "r9p1"),
    ("r6_r9", "r6p2", "r9p2"),
    ("r8_r9", "r8p2", "r9p3"),
)


def setup(site=None, with_analyzer=False, verbose=False, do_install=False, conf_ifaces=False, avoid_sites=None):
    """Create (or reuse) the crease demo slice and return handles to its nodes.

    If a crinkle slice named ``crease-knit11-<user uuid>`` already exists it is
    reused; otherwise a new slice with an analyzer, two hosts, nine P4 routers
    and fourteen monitored L2 links is built and submitted.

    Args:
        site: FABRIC site for all resources. When ``None`` and a new slice is
            built, a random PTP-capable site with more than one host is
            chosen, excluding the default avoid list and ``avoid_sites``.
        with_analyzer: must be True; this demo only supports the analyzer
            variant.
        verbose: echo every shell command sent to the nodes and print the
            results of the interface-configuration jobs.
        do_install: rerun the tool-install step on an existing slice.
        conf_ifaces: rerun interface configuration on an existing slice.
        avoid_sites: extra site names to exclude from random site selection.

    Returns:
        dict with keys ``"slice"``, ``"macs"``, ``"ips"``, ``"subnets"``,
        ``"routers"`` and ``"hosts"``.

    Raises:
        Exception: if ``with_analyzer`` is False, or if any sliver reports a
            reservation notice after submit.
    """
    if not with_analyzer:
        raise Exception("Should not occur - this is a demo only for with_analyzer=True")
    fablib = fablib_manager()
    slice_name = f"crease-knit11-{fablib.get_user_info()['uuid']}"

    build_slice = False
    try:
        vslice = fablib.get_crinkle_slice(name=slice_name)
    except Exception:
        # No retrievable slice under this name: start a fresh one.
        vslice = fablib.new_crinkle_slice(name=slice_name)
        print("Building new slice")
        build_slice = True
    else:
        print("Retrieved existing slice")
        print("If slice creation did not previously finish, try rerunning with:")
        print("  do_install=True, if 'Finished installing needed tools' did not print or if the issue is p4c is not installed on nodes")
        print("  conf_ifaces=True, if there is no p4c error but the baseline pings fail")

    if build_slice:
        if site is None:
            # Only PTP-capable sites with more than one host are eligible.
            avoid = list(_DEFAULT_AVOID_SITES) + list(avoid_sites or [])
            site = fablib.get_random_site(avoid=avoid,
                                          filter_function=lambda x: x["ptp_capable"] and x["hosts"] > 1)
        _build_crease_topology(vslice, site)

    # Addressing plan, one entry per port in port order (p0, p1, ...).
    ips = {
        'h1': [IPv4Address('11.0.1.11')],
        'h2': [IPv4Address('12.0.1.11')],
        'r1': [IPv4Address('10.0.1.1'), IPv4Address('10.0.1.1'), IPv4Address('11.0.1.1')],
        'r2': [IPv4Address('10.0.1.2'), IPv4Address('10.0.1.2'), IPv4Address('10.0.1.2')],
        'r3': [IPv4Address('10.0.1.3'), IPv4Address('10.0.1.3')],
        'r4': [IPv4Address('10.0.1.4'), IPv4Address('10.0.1.4'), IPv4Address('10.0.1.4')],
        'r5': [IPv4Address('10.0.1.5'), IPv4Address('10.0.1.5'), IPv4Address('12.0.1.1')],
        'r6': [IPv4Address('10.0.1.6'), IPv4Address('10.0.1.6'), IPv4Address('10.0.1.6')],
        'r7': [IPv4Address('10.0.1.7'), IPv4Address('10.0.1.7')],
        'r8': [IPv4Address('10.0.1.8'), IPv4Address('10.0.1.8'), IPv4Address('10.0.1.8')],
        'r9': [IPv4Address('10.0.1.9'), IPv4Address('10.0.1.9'), IPv4Address('10.0.1.9'), IPv4Address('10.0.1.9')],
    }
    subnets = {
        'h1': [IPv4Network('11.0.1.0/24')],
        'h2': [IPv4Network('12.0.1.0/24')],
        'r1': [IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24'), IPv4Network('11.0.1.0/24')],
        'r2': [IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24')],
        'r3': [IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24')],
        'r4': [IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24')],
        'r5': [IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24'), IPv4Network('12.0.1.0/24')],
        'r6': [IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24')],
        'r7': [IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24')],
        'r8': [IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24')],
        'r9': [IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24'), IPv4Network('10.0.1.0/24')]
    }
    # host -> [gateway IP, gateway router, gateway router port name]
    gateways = {
        'h1': [IPv4Address('11.0.1.1'), 'r1', 'r1p2'],
        'h2': [IPv4Address('12.0.1.1'), 'r5', 'r5p2']
    }

    print("Retrieving Node Information")
    hosts = {}
    routers = {}
    macs = {}
    for hostname in _HOSTNAMES:
        host = vslice.get_node(name=hostname)
        hosts[hostname] = host
        macs[hostname] = [host.get_component(name=f'{hostname}p0').get_interfaces()[0].get_mac()]
    for routername in _ROUTER_PORT_COUNTS:
        router = vslice.get_node(name=routername)
        routers[routername] = router
        macs[routername] = [router.get_component(name=port).get_interfaces()[0].get_mac()
                            for port in _router_port_names(routername)]

    jobs = []
    if build_slice or do_install:
        jobs = _install_tools(hosts, routers)
    print("Waiting for Node Info...", end="")
    futures.wait(jobs)
    print("Done")

    if build_slice or conf_ifaces:
        _configure_interfaces(hosts, routers, ips, subnets, gateways, verbose)

    return {
        "slice": vslice,
        "macs": macs,
        "ips": ips,
        "subnets": subnets,
        "routers": routers,
        "hosts": hosts,
    }


def _router_port_names(routername):
    """Return the component names of a router's ports, e.g. ['r3p0', 'r3p1']."""
    return [f"{routername}p{i}" for i in range(_ROUTER_PORT_COUNTS[routername])]


def _build_crease_topology(vslice, site):
    """Add analyzer, hosts, routers and monitored links to vslice, then submit.

    Raises:
        Exception: if any sliver reports a non-empty reservation notice; the
            message lists the offending slivers and their notices.
    """
    vslice.add_analyzer(site=site)
    ifaces = {}

    def add_node_with_ports(name, port_count, disk, image):
        # Create the node and its NIC_Basic ports, then put every port into
        # manual addressing mode (same order as before: all adds, then modes).
        node = vslice.add_node(name=name, site=site, cores=_HOST_CORES, ram=_HOST_RAM, disk=disk, image=image)
        port_names = [f"{name}p{i}" for i in range(port_count)]
        for port_name in port_names:
            ifaces[port_name] = node.add_component(model="NIC_Basic", name=port_name).get_interfaces()[0]
        for port_name in port_names:
            ifaces[port_name].set_mode("manual")

    for hostname in _HOSTNAMES:
        add_node_with_ports(hostname, 1, _HOST_DISK, _HOST_IMAGE)
    for routername, port_count in _ROUTER_PORT_COUNTS.items():
        add_node_with_ports(routername, port_count, _SWITCH_DISK, _SWITCH_IMAGE)

    for net_name, end_a, end_b in _LINKS:
        vslice.add_monitored_l2network(name=net_name, interfaces=[ifaces[end_a], ifaces[end_b]],
                                       **_MONITOR_RESOURCES)

    vslice.submit()
    errors = {entry['sliver'].get_name(): entry['notice']
              for entry in vslice.get_error_messages()
              if entry['notice'] != ''}
    if errors:
        raise Exception(f"Reservation error occurred during slice creation: {errors}")


def _install_tools(hosts, routers):
    """Run the per-role setup commands on all nodes and wait for completion.

    Failures are reported as warnings but do not abort setup, keeping the
    best-effort behaviour. Returns the list of (already completed) futures.
    """
    # Routers only get kernel forwarding disabled (presumably so that the P4
    # switch, not the kernel, forwards packets). Nothing here installs p4c, so
    # the switch image is presumably expected to provide the P4 toolchain.
    install_switch = 'sudo sysctl net.ipv6.conf.all.forwarding=0; sudo sysctl -w net.ipv4.ip_forward=0'
    install_host = 'sudo apt-get update; sudo apt install -y net-tools python3-scapy'
    print('Installing needed tools')
    owners = {}
    for routername, router in routers.items():
        owners[router.execute_thread(install_switch)] = routername
    for hostname, host in hosts.items():
        owners[host.execute_thread(install_host)] = hostname
    total = len(owners)
    for done, job in enumerate(futures.as_completed(owners), start=1):
        print(f'{done}/{total} installs finished')
        error = job.exception()
        if error is not None:
            print(f'WARNING: install on {owners[job]} failed: {error}')
    print('Finished installing needed tools')
    return list(owners)


def _join_shell_commands(commands, verbose):
    """Join commands into one '; '-separated shell line, echoing them if verbose."""
    if verbose:
        for command in commands:
            print(f"{command};")
    return "".join(f"{command}; " for command in commands)


def _configure_interfaces(hosts, routers, ips, subnets, gateways, verbose):
    """Address host ports, install host routes/static ARP, and bring router ports up."""
    print("Configuring Node Interfaces")
    host_subnets = [subnets[hostname][0] for hostname in hosts]
    jobs = []
    for hostname, host in hosts.items():
        print(f"Configuring {hostname}...", end="")
        iface = host.get_component(name=f'{hostname}p0').get_interfaces()[0]
        own_subnet = subnets[hostname][0]
        iface.ip_addr_add(ips[hostname][0], own_subnet)
        gateway_ip, gateway_router_name, gateway_port = gateways[hostname]
        gateway_mac = routers[gateway_router_name].get_component(name=gateway_port).get_interfaces()[0].get_mac()
        dev_name = iface.get_device_name()
        commands = [
            f"sudo ip link set dev {dev_name} up",
            f"sudo ip route replace {own_subnet} dev {dev_name}",
            # Static ARP entry for the gateway port.
            f"sudo arp -s {gateway_ip} {gateway_mac}",
        ]
        commands += [f"sudo ip route replace {subnet} via {gateway_ip}"
                     for subnet in host_subnets if subnet != own_subnet]
        commands += [
            # Replaces the on-link route above: even the host's own subnet is
            # sent via the gateway.
            f"sudo ip route replace {own_subnet} via {gateway_ip} dev {dev_name}",
            f"sudo ip link set dev {dev_name} up",
        ]
        jobs.append(host.execute_thread(_join_shell_commands(commands, verbose)))
        print("Done")
    for routername, router in routers.items():
        print(f"Configuring {routername}...", end="")
        commands = [f"sudo ip link set dev {iface.get_device_name()} up" for iface in router.get_interfaces()]
        jobs.append(router.execute_thread(_join_shell_commands(commands, verbose)))
        print("Done")
    futures.wait(jobs)
    if verbose:
        for job in jobs:
            print(job.result())

def start_switches(routers):
    """Compile the P4 router programs on each router and launch simple_switch.

    For every router this uploads ``switches/router-1.p4`` and
    ``switches/router-2.p4``, runs ``p4c`` (bmv2 target, v1model arch) on both,
    then starts ``sudo simple_switch ... router-1.json`` in the background via
    ``execute_thread``. Switch port N is bound to the router's ``p<N>``
    interface.

    Args:
        routers: mapping of router name (e.g. ``'r1'``) to fablib node, such as
            ``setup(...)["routers"]``.
    """
    # Futures for the background simple_switch launches.
    jobs = []
    for routername, router in routers.items():
        port_sequence = ""
        # NOTE(review): port_num appears unused — confirm and remove.
        port_num = 0
        # Every router has at least ports p0 and p1.
        p0 = router.get_component(name=routername+"p0").get_interfaces()[0]
        p1 = router.get_component(name=routername+"p1").get_interfaces()[0]
        # Build simple_switch "-i <port>@<device>" bindings per router port count:
        # r3/r7 have 2 ports, r9 has 4, all others have 3.
        if routername in ['r3', 'r7']:
            port_sequence = f'-i 0@{p0.get_device_name()} -i 1@{p1.get_device_name()}'
        elif routername in ['r9']:
            p2 = router.get_component(name=routername+"p2").get_interfaces()[0]
            p3 = router.get_component(name=routername+"p3").get_interfaces()[0]
            port_sequence = f'-i 0@{p0.get_device_name()} -i 1@{p1.get_device_name()} -i 2@{p2.get_device_name()} -i 3@{p3.get_device_name()}'
        else:
            p2 = router.get_component(name=routername+"p2").get_interfaces()[0]
            port_sequence = f'-i 0@{p0.get_device_name()} -i 1@{p1.get_device_name()} -i 2@{p2.get_device_name()}'
        # Upload both P4 programs; paths are relative to the notebook's working directory.
        router.upload_file("switches/router-1.p4", "router-1.p4")
        router.upload_file("switches/router-2.p4", "router-2.p4")
        # Presumably blocking compiles; p4c's bmv2 backend is expected to emit router-N.json.
        # NOTE(review): router-2 is compiled but not launched here — presumably
        # kept for a later swap (see --enable-swap); confirm.
        router.execute("p4c --target bmv2 --arch v1model router-1.p4")
        router.execute("p4c --target bmv2 --arch v1model router-2.p4")
        # Launch router-1 in the background with logging to ~/switch.log.
        jobs.append(router.execute_thread(f'sudo simple_switch {port_sequence} router-1.json --log-file ~/switch.log --log-flush -- --enable-swap --disable-ra-broadcast &'))
        print(f'Starting {routername}')
    # Presumably gives the switches time to come up before callers use them.
    sleep(5)