# Copyright (c) 2024  Kent State University CAE-NetLab

# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import cProfile
import json
import logging
import math
import os
import os.path
import pstats
import random
import threading

import blinker
import ipywidgets as widgets

from fabrictestbed_extensions.fablib.fablib import FablibManager as fablib_manager

from .config import CFG
from . import progress

def getRolesAndNodes (slmap, roles, sobj, progress = None):
  """Resolve every client/server role to its IP addresses and collect the nodes involved.

  slmap    -- role bindings from the slicecfg json: role name -> list of
              binding dicts ("Node" for client roles, "Node.Interface" for
              server roles)
  roles    -- role descriptions from model["roles"]; each has "name" and "type"
  sobj     -- the actual Slice object
  progress -- optional object whose advance() is called once per role and
              once per binding

  Returns (rmap, nodes): rmap maps role name -> list of IP strings, nodes is
  the list of unique nodes to deploy to.  Roles whose type is neither
  "client" nor "server" get no entry in rmap.
  """
  def tick ():
    if progress:
      progress.advance()

  # Build the interface lookup ourselves: going through slice.get_interface
  # in fablib takes an eternity because it ssh's to all the nodes instead of
  # checking the fim
  interfaces_by_name = {
    intf.get_name(): intf
    for slice_node in sobj.get_nodes()
    for intf in slice_node.get_interfaces()
  }

  role_ips = {}

  # Every Node "proxy" acquired from fablib is a unique object, so a set()
  # cannot deduplicate them; key on name + reservation id instead.
  unique_nodes = {}

  for role in roles:
    tick()
    role_name = role["name"]
    bindings = slmap[role_name]
    role_type = role["type"]
    if role_type not in ("client", "server"):
      continue

    addresses = []
    role_ips[role_name] = addresses
    for binding in bindings:
      tick()
      if role_type == "client":
        # Clients are bound to a node; its first interface supplies the IP
        node = sobj.get_node(binding["Node"])
        unique_nodes["%s.%s" % (node.get_name(), node.get_reservation_id())] = node
        addresses.append(str(node.get_interfaces()[0].get_ip_addr()))
      else:
        # Servers are bound to one specific interface
        bound_intf = interfaces_by_name[binding["Node.Interface"]]
        addresses.append(str(bound_intf.get_ip_addr()))
        node = bound_intf.get_node()
        unique_nodes["%s.%s" % (node.get_name(), node.get_reservation_id())] = node

  return role_ips, list(unique_nodes.values())

#def makextic (*args, **kwargs):
#  with cProfile.Profile() as profile:
#    _makextic_core(*args, **kwargs)
#    pstats.Stats(profile).strip_dirs().sort_stats("cumtime").print_stats()

def makextic (fabmgr, model_provider, slicename, out, duration = 14400, progress = None):
  """Build the time-ordered traffic schedule for a slice and hand it back through `out`.

  fabmgr         -- fablib manager used to look up the slice by name
  model_provider -- object whose getData("<repo>.<model>") returns a model dict
                    with "roles" and "appmap" entries
  slicename      -- name of the slice; its binding is read from
                    <CFG.ddir>/slcfgs/<slicename>.json
  out            -- dict that receives the results (this function is run as a
                    thread target, so it cannot simply return them):
                      out["xtic"]  -- list of {"start", "cli_ip", "srv_ip", "app"}
                                      dicts sorted by "start"
                      out["nodes"] -- unique nodes used by any bound model
  duration       -- length of the schedule, in the same time units as the
                    models' "start-offset" / "period" values
  progress       -- optional object whose advance() is called as work proceeds
  """
  sinfo = fabmgr.get_slice(slicename)
  with open("%s/slcfgs/%s.json" % (CFG.ddir, slicename), "r") as cfgfile:
    scfg = json.load(cfgfile)

  xtic = []

  # Nodes are accumulated across every bound model (and start out empty), so
  # out["nodes"] covers all models rather than only the last one processed.
  # Node proxies are unique objects, so deduplicate on name + reservation id.
  deploynodes = {}

  for k,v in scfg.items():
    if progress:
      progress.advance()
    model = model_provider.getData("%s.%s" % (v["_repo"], k))
    rmap,modelnodes = getRolesAndNodes(v, model["roles"], sinfo, progress)
    for node in modelnodes:
      deploynodes["%s.%s" % (node.get_name(), node.get_reservation_id())] = node

    for appname,appdata in model["appmap"].items():
      if progress:
        progress.advance()
      client_ips = []
      for crole in appdata["clients"]:
        client_ips.extend(rmap[crole])
      server_ips = []
      for srole in appdata["servers"]:
        server_ips.extend(rmap[srole])

      cinfo = appdata["conn_info"]
      if cinfo["type"] == "continuous":
        # Every client starts a connection at each restart interval, each
        # time against a randomly chosen server
        num_conns = math.ceil((duration - cinfo["start-offset"])/cinfo["restart-interval"])
        for conn in range(num_conns):
          start = (cinfo["restart-interval"] * conn) + cinfo["start-offset"]
          for cip in client_ips:
            if progress: progress.advance()
            sip = random.choice(server_ips)
            xtic.append({"start": start, "cli_ip" : cip, "srv_ip" : sip, "app" : appname})
      elif cinfo["type"] == "periodic":
        # Each client follows its own timeline with a randomized gap between
        # connections.
        # NOTE(review): if "period" and "period-variance" are both 0 the gap
        # is always 0 and this loop never terminates.
        pmin = cinfo["period"] - cinfo["period-variance"]
        pmax = cinfo["period"] + cinfo["period-variance"]
        for cip in client_ips:
          curtime = cinfo["start-offset"]
          while curtime < duration:
            if progress: progress.advance()
            xtic.append({"start":curtime, "cli_ip" : cip, "srv_ip": random.choice(server_ips), "app" : appname})
            curtime += random.uniform(pmin,pmax)

  xtic.sort(key = lambda x: x["start"])
  out["xtic"] = xtic
  out["nodes"] = list(deploynodes.values())

class DeploymentManager():
  """Notebook panel for picking a slice, generating its traffic config and listing its deployable nodes."""

  # Dropdown entry that stands for "no slice selected"
  _NO_SLICE = "  - "

  def __init__ (self, ttm):
    """Create the widgets and load the list of slices.

    ttm -- owning object; its _slicecfg.model_provider is used when
           generating the traffic configuration
    """
    self.fab = fablib_manager(CFG.get("fabric.rc_path"))
    self._ttm = ttm
    self.slices = []
    self.snames = []
    self._dnode_data = {}

    # While True, dropdown value changes do not rebuild the panel
    self._lock_slicechange = False
    self._w_dd_slice = widgets.Dropdown(options = self.snames, description = "Slice:")

    self._w_button_refresh = widgets.Button(icon = "refresh", layout = widgets.Layout(width="50px"), tooltip = "Refresh Slices")
    self._w_button_refresh.on_click(self._refreshSlices)

    self._w_hbox_slice = widgets.HBox([self._w_dd_slice, self._w_button_refresh])
    self._w_outer_box = widgets.VBox([self._w_hbox_slice])

    # Widget set for no slice config
    self._noslc_template = "<i class=\"fa fa-exclamation-circle\"></i>&nbsp;No slice binding found for slice %(slice)s.  Please" \
                           " bind applcation models before deploying."
    self._w_noslc_html = widgets.HTML()
    self._w_noslc_hbox = widgets.HBox([self._w_noslc_html])

    # widget set for known slices with configs
    self._w_generate_progress = progress.EndlessProgress()
    self._w_generate_label = widgets.Label(value="Traffic model configuration: ")
    self._w_generate_button = widgets.Button(description = "Generate")
    self._w_generate_button.on_click(self._generateTrafficConfig)
    self._w_box_generate = widgets.HBox([self._w_generate_label, self._w_generate_progress.display(), self._w_generate_button])

    # Container for deployable nodes
    self._w_box_deploy = widgets.VBox([])
    self._w_deploy_button = widgets.Button(description = "Deploy")
    self._w_deploy_button.on_click(self._doDeploy)

    self._sig_status_update = blinker.signal("status.update")

    # Fill the dropdown and start observing it only now that every widget
    # exists, so the selection change caused by the initial fill cannot run
    # _buildSliceDeploy against a half-constructed object.
    self._fillSlices()
    self._w_dd_slice.observe(self._observeSlice, names = ["value"])

  def display (self):
    """Return the top-level widget to show in the notebook."""
    return self._w_outer_box

  def _observeSlice (self, change):
    """Rebuild the panel when the selected slice changes, unless changes are locked."""
    if self._lock_slicechange:
      return
    self._buildSliceDeploy()

  def _doDeploy (self, change):
    """Click handler for the Deploy button; currently does nothing."""
    return

  def _generateTrafficConfig (self, button):
    """Generate xtic.json for the selected slice and list its deployable nodes.

    The schedule is written to <general.staging_dir>/<slice>/xtic.json and
    one progress bar per node is placed in the deploy container.
    """
    self._w_generate_progress.reset()
    sname = self._w_dd_slice.value
    staging_path = "%s/%s" % (CFG.get("general.staging_dir"), sname)
    os.makedirs(staging_path, exist_ok = True)

    out = {}
    t1 = threading.Thread(target = makextic,
                          args = (self.fab, self._ttm._slicecfg.model_provider, sname, out),
                          kwargs = {"progress" : self._w_generate_progress})
    t1.start()
    t1.join()

    # An exception inside the worker thread does not propagate to join();
    # the only visible symptom is that the results were never stored.
    if "xtic" not in out or "nodes" not in out:
      self._w_generate_progress.done()
      self._sig_status_update.send(level = logging.ERROR,
                                   message = "Traffic configuration generation failed for slice %s" % (sname))
      return

    with open("%s/xtic.json" % (staging_path), "w+") as xf:
      xf.write(json.dumps(out["xtic"]))

    self._w_generate_progress.done()
    self._sig_status_update.send(level = logging.DEBUG, message = str(out["nodes"]))

    if not out["nodes"]:
      self._sig_status_update.send(level = logging.WARNING, message = "No nodes found for deployment")
      return

    # First row holds the Deploy button, followed by one row per node
    dnode_boxes = [widgets.HBox([self._w_deploy_button], layout = widgets.Layout(align_content="center"))]
    self._dnode_data = {}
    for dnode in out["nodes"]:
      pbar = widgets.IntProgress(min = 0, max = 100, bar_style= "info", orientation = "horizontal")
      label = widgets.Label(value = dnode.get_name())
      self._dnode_data[dnode.get_name()] = pbar
      dnode_boxes.append(widgets.HBox([label, pbar]))
    self._w_box_deploy.children = dnode_boxes

  def _buildSliceDeploy (self):
    """Show the widget set that matches the currently selected slice.

    No slice selected   -> only the slice selector
    No slice binding    -> selector plus a "no slice binding" notice
    Binding file exists -> selector, generate row and deploy container
    """
    sname = self._w_dd_slice.value
    if sname == self._NO_SLICE:
      self._w_outer_box.children = [self._w_hbox_slice]
      self._w_box_deploy.children = []
      return

    scfgpath = "%s/slcfgs/%s.json" % (CFG.ddir, sname)
    if not os.path.exists(scfgpath):
      self._w_noslc_html.value = self._noslc_template % ({"slice" : sname})
      self._w_outer_box.children = [self._w_hbox_slice, self._w_noslc_hbox]
      self._w_box_deploy.children = []
      return

    self._w_generate_progress.reset()
    self._w_outer_box.children = [self._w_hbox_slice, self._w_box_generate, self._w_box_deploy]

  def _fillSlices (self):
    """Reload the slice list from fablib and reset the dropdown options."""
    self.slices = self.fab.get_slices()
    self.snames = [self._NO_SLICE]
    self.snames.extend([x.get_name() for x in self.slices])
    self._w_dd_slice.options = self.snames

  def _refreshSlices (self, change):
    """Click handler for the refresh button: reload slices, keeping the selection if possible.

    If the previously selected slice no longer exists the dropdown falls
    back to the "no slice" entry and the panel is rebuilt to match.
    """
    oldvalue = self._w_dd_slice.value
    self._lock_slicechange = True
    try:
      self._fillSlices()
      if oldvalue in self.snames:
        self._w_dd_slice.value = oldvalue
      else:
        self._w_dd_slice.value = self._NO_SLICE
    finally:
      # Always release the lock, otherwise a failed refresh would leave the
      # dropdown permanently unable to rebuild the panel
      self._lock_slicechange = False

    # Observers were suppressed above, so sync the panel by hand when the
    # selection could not be preserved
    if self._w_dd_slice.value != oldvalue:
      self._buildSliceDeploy()


