# Copyright (c) 2022-2023 University of Houston
# Copyright (c) 2024 Kent State University CAE Networking Lab

# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import json
import os
import os.path
import tarfile

import blinker

import fabrictestbed_extensions.fablib.slice
from fabrictestbed_extensions.fablib.fablib import FablibManager as fablib_manager
import ipywidgets as widgets
from IPython.core.getipython import get_ipython

from .config import CFG
from . import inputtable as IT


def get_local_slice_objs (exclude):
  """Return fablib ``Slice`` objects bound to names in the IPython user namespace.

  :param exclude: collection of slice names to leave out (e.g. slices that were
    already fetched from the testbed).
  :returns: list of ``Slice`` objects whose ``get_name()`` is not in *exclude*,
    in namespace iteration order. An object bound to several names is returned
    once per name.
  """
  excluded = set(exclude)
  # Fabric fouls introspection with runtime locals() insertions, which can mutate
  # the namespace while it is being iterated ("dictionary changed size during
  # iteration"). Copy the items in one step first, before any isinstance() /
  # get_name() call gets a chance to run Fabric code, and filter the copy.
  items = list(get_ipython().user_ns.items())
  return [v for _name, v in items
          if isinstance(v, fabrictestbed_extensions.fablib.slice.Slice)
          and v.get_name() not in excluded]

class LocalModelProvider():
  """Index of traffic-generator model bundles (``*.bndl``) in local repositories.

  Each bundle is opened as a tar archive and its ``model.mdata`` member is parsed
  as JSON. Models are keyed as ``"<repo name>.<model name>"``, where the model
  name is the ``"name"`` field of that JSON.
  """

  # File-name suffix identifying a model bundle inside a repository directory.
  _BUNDLE_SUFFIX = ".bndl"

  def __init__ (self, repomgr):
    """Scan every repository known to *repomgr*.

    :param repomgr: repository manager; only its ``_repoinfo`` mapping is read,
      and each value must provide ``"name"`` and ``"location"`` keys.
    """
    self._repomgr = repomgr
    self.models = {}         # model key -> parsed model.mdata dict
    self._bundle_paths = {}  # model key -> path of the bundle it was read from
    self.load()

  def load (self):
    """Scan all repositories and record the metadata of every bundle found.

    Repositories whose ``location`` is empty/None are skipped. Directory entries
    are visited in sorted order, so if two bundles in one repository declare the
    same model name the result is deterministic (the later file wins).

    :raises KeyError: if a bundle has no ``model.mdata`` member.
    """
    for repo in self._repomgr._repoinfo.values():
      location = repo["location"]
      if not location:
        continue

      for fname in sorted(os.listdir(location)):
        if not fname.endswith(self._BUNDLE_SUFFIX):
          continue

        path = os.path.join(location, fname)
        with tarfile.open(path) as tf:
          mdata = json.loads(tf.extractfile("model.mdata").read())
        key = "%s.%s" % (repo["name"], mdata["name"])
        self.models[key] = mdata
        # Remember the real file: its name need not match mdata["name"].
        self._bundle_paths[key] = path

  @property
  def names (self):
    """Sorted list of all known model keys."""
    return sorted(self.models.keys())

  def getData (self, name):
    """Return the parsed ``model.mdata`` dict for model key *name* (KeyError if unknown)."""
    return self.models[name]

  def getBundleLocation (self, name):
    """Return the filesystem path of the bundle for model key *name*.

    Uses the path recorded by :meth:`load` when available, which is correct even
    when the bundle's file name differs from its declared model name. Otherwise
    it falls back to ``<repo location>/<model name>.bndl``.
    """
    path = self._bundle_paths.get(name)
    if path is not None:
      return path
    # XXX: Don't name your repos with periods right now
    repo, mname = name.split(".", maxsplit = 1)
    return "%s/%s.bndl" % (self._repomgr._repoinfo[repo]["location"], mname)

    
class TgenModelConfigurationTab():
  """One tab of the configurator: choose a model and assign its roles.

  The tab shows an application dropdown (a placeholder entry plus every model key
  of the manager's model provider), a remove button, and one input table per
  server/client role of the selected model.
  """

  # Dropdown entry meaning "no application selected".
  _PLACEHOLDER = "  - "

  def __init__ (self, manager):
    """Build the tab widgets.

    :param manager: owning configurator. This class reads its ``model_provider``,
      ``slice_interfaces`` and ``slice_obj``, and calls its ``setTabTitle`` and
      ``removeTab``.
    """
    self.model = None         # model.mdata dict of the selected model, or None
    self.repo = None          # repository part of the selected model key, or None
    self._manager = manager
    self._role_tables = {}    # role name -> IT.InputTable for that role

    self._apps = [self._PLACEHOLDER]
    self._apps.extend(self._manager.model_provider.names)

    self._w_dd_application = widgets.Dropdown(options = self._apps, description = "Application:")
    self._w_dd_application.observe(self.observeApplication)
    self._w_btn_delete = widgets.Button(icon = "minus-circle", tooltip = "Remove Model",
                                        layout = widgets.Layout(width="50px"))
    self._w_btn_delete.on_click(self._removeTab)
    self._w_hbox_app = widgets.HBox([self._w_dd_application, self._w_btn_delete])
    self._w_vbox_tables = widgets.VBox([])
    self._w_vbox_tab = widgets.VBox([self._w_hbox_app, self._w_vbox_tables])

  @property
  def title (self):
    """Name of the selected model, or ``""`` when none is selected."""
    if not self.model:
      return ""
    return self.model["name"]

  def _removeTab (self, _):
    """Delete-button callback: ask the manager to remove this tab."""
    self._manager.removeTab(self)

  def observeApplication (self, change):
    """Dropdown observer: load the chosen model, or clear the tab for the placeholder."""
    if change['type'] != 'change' or change['name'] != 'value':
      return

    selected = change['new']
    if selected is None or selected == self._PLACEHOLDER:
      # The placeholder is not a model key; looking it up would raise KeyError.
      self.model = None
      self.repo = None
      self._role_tables = {}
      self._w_vbox_tables.children = []
      self._manager.setTabTitle("")
      return

    self.model = self._manager.model_provider.getData(selected)
    self._manager.setTabTitle(self.model["name"])
    self.repo = selected.split(".")[0]
    self.buildModelTables()

  def buildServerTable (self, roledata):
    """Build the input table for a server role.

    Columns are the role name header, a ``Node.Interface`` dropdown over the
    manager's ``slice_interfaces`` (sorted by label), and a listen port (1-65535,
    defaulting to ``roledata["listen_port"]``).
    """
    cols = []
    cols.append(IT.HeaderColumn(roledata["name"], layout = widgets.Layout(width="100px")))
    cols.append(IT.DropdownColumn("Node.Interface",
                                  values = sorted(self._manager.slice_interfaces, key=lambda x: x[0])))
    cols.append(IT.IntInputColumn("Listen Port", min = 1, max = 65535, default = roledata["listen_port"]))

    opts = IT.TableOptions(allow_add = True, allow_delete = True, allow_reorder = False,
                           row_numbers = False, add_on_right = True)
    tbl = IT.InputTable(cols, roledata, opts)
    return tbl

  def buildClientTable (self, roledata):
    """Build the input table for a client role.

    Columns are the role name header, a ``Node`` dropdown over the sorted node
    names of the manager's ``slice_obj``, and a start offset in seconds (>= 0).
    """
    cols = []
    cols.append(IT.HeaderColumn(roledata["name"], layout = widgets.Layout(width="100px")))
    cols.append(IT.DropdownColumn("Node",
                                  values = sorted([x.get_name() for x in self._manager.slice_obj.get_nodes()])))
    cols.append(IT.IntInputColumn("Start Offset (secs)", min = 0, default = 0))

    opts = IT.TableOptions(allow_add = True, allow_delete = True, allow_reorder = False,
                           row_numbers = False, add_on_right = True)
    tbl = IT.InputTable(cols, roledata, opts)
    return tbl

  def buildModelTables (self):
    """(Re)build one table per role of the selected model and show them.

    Previous tables are discarded, so roles of a previously selected model are not
    carried into :meth:`asDict`. Roles whose ``"type"`` is neither ``"server"``
    nor ``"client"`` get no table.
    """
    builders = {"server": self.buildServerTable, "client": self.buildClientTable}
    self._role_tables = {}
    displays = []
    for role in self.model["roles"]:
      builder = builders.get(role["type"])
      if builder is None:
        # Unknown role type: skip it rather than reuse an unrelated table.
        continue
      tbl = builder(role)
      self._role_tables[role["name"]] = tbl
      displays.append(tbl.display())
    self._w_vbox_tables.children = displays

  def display (self):
    """Return the top-level widget of this tab."""
    return self._w_vbox_tab

  def asDict (self):
    """Map each role name to its table's ``getAnswerData()`` result."""
    return {k:v.getAnswerData() for k,v in self._role_tables.items()}

  def fromDict (self, data):
    """Restore the tab from saved data. TODO: not implemented; *data* is ignored."""
    pass


class SliceTgenConfigurator():
  """Top-level widget: pick a slice, configure one or more models in tabs, save to JSON."""

  # Dropdown entry meaning "no slice selected".
  _PLACEHOLDER = "  - "

  def __init__ (self, repomgr):
    """Create a fablib manager, collect candidate slices and build the widgets.

    Slices come from ``fab.get_slices()`` plus any ``Slice`` objects found in the
    IPython namespace whose names the testbed did not already report.

    :param repomgr: repository manager handed to :class:`LocalModelProvider`.
    """
    self.fab = fablib_manager(CFG.get("fabric.rc_path"))
    self.slices = self.fab.get_slices()
    # When both the testbed and locals() know about a slice, we prefer the testbed
    self.slices.extend(get_local_slice_objs(exclude = [x.get_name() for x in self.slices]))
    self._repomgr = repomgr

    self.snames = [self._PLACEHOLDER]
    self.snames.extend([x.get_name() for x in self.slices])

    self.slice_obj = None          # currently selected slice, or None
    self.model_provider = LocalModelProvider(self._repomgr)
    self.slice_interfaces = []     # (label, interface name) pairs of slice_obj

    self._w_dd_slice = widgets.Dropdown(options = self.snames, description = "Slice:")
    self._w_dd_slice.observe(self.observeSlice)
    self._w_hbox_slice = widgets.HBox([self._w_dd_slice])

    self._w_tab_models = widgets.Tab()

    self._w_btn_model = widgets.Button(icon="plus", description = "Add Model")
    self._w_btn_model.on_click(self.buttonAddModel)

    self._w_btn_save = widgets.Button(icon = "floppy-disk", description = "Save")
    self._w_btn_save.on_click(self.buttonSave)

    self._w_hbox_buttons = widgets.HBox([self._w_btn_model, self._w_btn_save])

    self._w_vbox_outer = widgets.VBox([self._w_hbox_slice, self._w_tab_models, self._w_hbox_buttons])

    self._tabs = []      # tab widgets, parallel to _tabdata
    self._tabdata = []   # TgenModelConfigurationTab objects

    self.loadSliceData()

    self.addTab()

  def observeSlice (self, change):
    """Slice dropdown observer: reload slice data when the value changes."""
    if change['type'] == 'change' and change['name'] == 'value':
      # TODO: Some kind of save protection
      # NOTE(review): tables already built in existing tabs still list the
      # previous slice's nodes/interfaces until their model is re-selected.
      self.loadSliceData()

  def buttonAddModel (self, button):
    """"Add Model" button callback: append a new empty tab."""
    self.addTab()

  def buttonSave (self, button):
    """Write every titled tab's configuration to ``<CFG.ddir>/slcfgs/<slice>.json``.

    The JSON maps model name -> ``tab.asDict()`` plus a ``"_repo"`` entry. Tabs
    with no model selected are skipped, and two tabs with the same model name
    overwrite each other. Does nothing while the placeholder is selected, so no
    ``"  - .json"`` file is produced.
    """
    slice_name = self._w_dd_slice.value
    if slice_name == self._PLACEHOLDER:
      return

    outdict = {}
    for tab in self._tabdata:
      if not tab.title:
        continue
      tabcfg = tab.asDict()
      tabcfg["_repo"] = tab.repo
      outdict[tab.title] = tabcfg

    slcdir = os.path.join(CFG.ddir, "slcfgs")
    os.makedirs(slcdir, exist_ok = True)
    with open(os.path.join(slcdir, "%s.json" % slice_name), "w") as scf:
      json.dump(outdict, scf)

  def loadSliceData (self):
    """Select the slice named in the dropdown and cache its interfaces.

    ``slice_obj`` and ``slice_interfaces`` are reset first, so nothing from a
    previously selected slice survives; with the placeholder selected both stay
    empty. Each ``slice_interfaces`` entry is ``(label, interface name)`` with
    label ``"<node>.<2nd '-'-separated field of the interface name> (<network>)"``.

    :raises IndexError: if no known slice has the selected name.
    """
    self.slice_obj = None
    self.slice_interfaces = []
    selected = self._w_dd_slice.value
    if selected == self._PLACEHOLDER:
      return

    self.slice_obj = [x for x in self.slices if x.get_name() == selected][0]
    for node in self.slice_obj.get_nodes():
      nname = node.get_name()
      for interface in node.get_interfaces():
        iname = interface.get_name()
        iface_str = "%s.%s (%s)" % (nname, iname.split("-")[1], interface.get_network().get_name())
        self.slice_interfaces.append((iface_str, iname))

  def setTabTitle (self, title):
    """Set the title of the currently selected model tab."""
    self._w_tab_models.set_title(self._w_tab_models.selected_index, title)

  def rebuildTabTitles (self):
    """Reassign all tab titles from the tab objects, in order."""
    self._w_tab_models.titles = [tab.title for tab in self._tabdata]

  def addTab (self):
    """Append a new model tab and select it."""
    newtab = TgenModelConfigurationTab(self)
    self._tabdata.append(newtab)
    self._tabs.append(newtab.display())
    self._w_tab_models.children = self._tabs
    self._w_tab_models.selected_index = len(self._tabs) - 1

  def removeTab (self, tab):
    """Remove *tab*; if it was the last one, replace it with a fresh empty tab."""
    self._tabdata.remove(tab)
    self._tabs = [x.display() for x in self._tabdata]
    if not self._tabdata:
      self.addTab()
    else:
      self._w_tab_models.children = self._tabs
      self._w_tab_models.selected_index = len(self._tabs) - 1
    self.rebuildTabTitles()

  def display (self):
    """Return the top-level widget."""
    return self._w_vbox_outer
