#!/usr/bin/env python
# Copyright (c) 2024  Kent State University CAE Networking Lab

import argparse
import json
import os.path
import statistics
import sys
import tarfile


def parse_args (args):
  """Parse the command line.

  -d/--dir is required: every later step reads files from that directory,
  so a missing value is rejected by argparse (exit status 2) instead of
  surfacing later as a TypeError.
  -n/--short-name is the default offered at the model-name prompt.

  Returns the argparse namespace (attributes ``model_dir`` and ``name``).
  """
  parser = argparse.ArgumentParser(
    description="Bundle a model directory into a <name>.bndl archive")
  parser.add_argument("-d", "--dir", dest="model_dir", required=True,
                      help="model directory (app_model.json, metadata.json, "
                           "net_xtic.json, dnsmap.json)")
  parser.add_argument("-n", "--short-name", dest="name", default=None,
                      help="default model name offered at the prompt")
  return parser.parse_args(args)

def validate_dir (opts):
  """Check that ``opts.model_dir`` is a directory holding the model files.

  Raises:
    ValueError: if the directory is unset or missing, or if any required
      file is not a regular file inside it.
  """
  # Guard against None/"" first: os.path.isdir(None) raises TypeError.
  if not opts.model_dir or not os.path.isdir(opts.model_dir):
    raise ValueError("Model directory does not exist")
  necessary_files = ["app_model.json", "metadata.json", "net_xtic.json", "dnsmap.json"]
  for fname in necessary_files:
    # isfile rather than exists: a directory with one of these names would
    # otherwise pass here and only fail later when it is opened.
    if not os.path.isfile(os.path.join(opts.model_dir, fname)):
      raise ValueError("Model dir file '%s' does not exist" % (fname))

def summarize_model_metadata (acdata):
  """Condense one connection model into the fields kept in the bundle.

  ``acdata`` is a model's ``conn_models["allconns"]`` entry. Both duration
  figures are products of the request-count and inter-request-time
  statistics: mean x mean and max x max respectively.
  """
  return {
    "proto" : acdata["l4_proto"],
    "port" : acdata["port_number"],
    "mtype" : acdata["num_requests_per_conn"]["mtype"],
    # NOTE(review): mean x mean is the expected duration only if request
    # count and gap length are independent -- confirm this is intended.
    "conn_mean_duration" : (acdata["num_requests_per_conn"]["mean"]
                            * acdata["inter_request_times"]["mean"]),
    "conn_max_duration" : (acdata["num_requests_per_conn"]["max"]
                           * acdata["inter_request_times"]["max"]),
  }

def parse (opts):
  """Load the model files in ``opts.model_dir`` and group observations by app.

  Reads ``app_model.json``, ``metadata.json`` and ``net_xtic.json``. Each
  app-model entry yields one record keyed by its ``app_name``. Each
  net_xtic record then adds its client IP, server IP and start time to the
  record named by its ``app`` field.

  Returns:
    ``{"conns": {app_name: {"mdata": ..., "clients": [...],
    "servers": [...], "starts": [...]}}}``. Client and server IPs are
    de-duplicated and listed in the order they were first seen.

  Raises:
    KeyError: if a net_xtic record names an app absent from the app model.
  """
  def _load_json (fname):
    # Context manager so each file handle is closed promptly.
    with open(os.path.join(opts.model_dir, fname), "r") as fh:
      return json.load(fh)

  amodel = _load_json("app_model.json")
  # Parsed only so a malformed metadata.json fails here; its contents are
  # not used in this function.
  _load_json("metadata.json")
  xdata = _load_json("net_xtic.json")

  # NOTE: records are keyed by app_name, so if two model entries share an
  # app_name the later one silently replaces the earlier one.
  conns = {}
  for model in amodel.values():
    conns[model["app_name"]] = {
      "mdata" : summarize_model_metadata(model["conn_models"]["allconns"]),
      # dicts serve as insertion-ordered sets so later prompts are stable
      "clients" : {}, "servers" : {}, "starts" : []}

  for obj in xdata:
    conn = conns[obj["app"]]
    conn["clients"].setdefault(obj["cli_ip"], None)
    conn["servers"].setdefault(obj["srv_ip"], None)
    conn["starts"].append(obj["start"])

  for conn in conns.values():
    conn["clients"] = list(conn["clients"])
    conn["servers"] = list(conn["servers"])

  return {"conns" : conns}

def analyze_starts (starts):
  """Summarize the gaps between consecutive connection start times.

  The starts are sorted first, so the gaps are non-negative whatever order
  the observations were recorded in.

  Returns a dict with the ``min``, ``max``, ``mean`` and ``median`` gap.
  Every value is None when fewer than two starts are given (including an
  empty list), because there is no gap to measure.
  """
  if len(starts) < 2:
    return {"min" : None, "max" : None, "mean" : None, "median" : None}

  ordered = sorted(starts)
  ias = [later - earlier for earlier, later in zip(ordered, ordered[1:])]
  return {"min" : min(ias),
          "max" : max(ias),
          "mean" : statistics.mean(ias),
          "median" : statistics.median(ias)}

def _input_float (prompt):
  """Prompt until the user enters text that ``float()`` accepts."""
  while True:
    raw = input(prompt)
    try:
      return float(raw)
    except ValueError:
      # A typo should not abort the session and lose earlier answers.
      print("Please enter a number")

def configure_continuous ():
  """Prompt for the settings of a continuous connection.

  Returns a dict with ``type`` ("continuous"), ``start-offset`` and
  ``restart-interval`` (both floats).
  """
  info = {"type" : "continuous"}
  info["start-offset"] = _input_float("Start offset: ")
  info["restart-interval"] = _input_float("Restart interval: ")
  return info

def configure_periodic ():
  """Prompt for the settings of a periodic connection.

  Returns a dict with ``type`` ("periodic"), ``start-offset``, ``period``
  and ``period-variance`` (all floats; the last two are prompted in
  seconds).
  """
  info = {"type" : "periodic"}
  info["start-offset"] = _input_float("Start offset: ")
  info["period"] = _input_float("Period (seconds): ")
  info["period-variance"] = _input_float("Period Variance (seconds): ")
  return info

def _prompt_nonempty (prompt, default=None):
  """Read a line from the user, re-prompting until it is non-empty.

  If the user enters nothing and ``default`` is truthy, return ``default``.
  """
  while True:
    answer = input(prompt)
    if answer:
      return answer
    if default:
      return default
    print("A value is required")


def _group_by_dns (ips, dnsmap):
  """Split ``ips`` into ``({dns_name: [ip, ...]}, [ip, ...])``.

  The dict groups IPs that have a ``dnsmap`` entry by that entry. The list
  holds the IPs without one, in input order.
  """
  named = {}
  unnamed = []
  for ip in ips:
    if ip in dnsmap:
      named.setdefault(dnsmap[ip], []).append(ip)
    else:
      unnamed.append(ip)
  return named, unnamed


def _map_roles (ips, dnsmap, noun, role_fields, roles, ipmap, app_roles):
  """Ask for a role name for each DNS-name group and each unnamed IP.

  ips: the IPs on one side (clients or servers) of one connection.
  noun: "clients" or "servers"; used only in the summary line.
  role_fields: fields of each role record apart from ``name`` (``type`` and,
    for servers, ``listen_port``).
  roles, ipmap, app_roles: updated in place -- role records by name, the set
    of role names per IP, and the set of role names for this app.
  """
  named, unnamed = _group_by_dns(ips, dnsmap)
  print("There are %d named %s, and %d %s without found names"
        % (len(named), noun, len(unnamed), noun))

  # Named hosts are shown with all their IPs; unnamed ones by the IP alone.
  groups = [("%s [%s]" % (name, ", ".join(iplist)), iplist)
            for name, iplist in named.items()]
  groups += [(str(ip), [ip]) for ip in unnamed]

  for label, iplist in groups:
    print(label)
    rname = _prompt_nonempty("Role: ")
    # NOTE(review): roles is keyed by name only, so reusing a role name for
    # both a client and a server (or for two ports) keeps only the record
    # written last -- confirm this is acceptable.
    roles[rname] = dict(name=rname, **role_fields)
    app_roles.add(rname)
    for ip in iplist:
      ipmap.setdefault(ip, set()).add(rname)


def gather_input (opts, data):
  """Interactively annotate parsed model data before it is bundled.

  First prompts for the model name (defaulting to ``opts.name``), the
  author and a description. Then, for each connection in ``data["conns"]``,
  it prints instance and interarrival statistics and asks for a connection
  type and a role name for every client and server.

  ``data`` is modified in place. ``name``, ``author``, ``description``,
  ``roles`` and ``appmap`` are added. Each connection's ``starts``,
  ``clients`` and ``servers`` entries are removed, leaving ``mdata``.
  """
  name_prompt = ("Model name [%s]: " % (opts.name)) if opts.name else "Model name: "
  # Re-prompt rather than accept an empty name: write() uses it as the
  # output file name.
  data["name"] = _prompt_nonempty(name_prompt, default=opts.name)
  data["author"] = input("Model author: ")
  data["description"] = input("Description: ")

  roles = {}
  # NOTE(review): ipmap (IP -> set of role names) is filled in below but is
  # never stored in ``data``, so it does not reach the bundle -- confirm.
  ipmap = {}
  appmap = {}
  with open(os.path.join(opts.model_dir, "dnsmap.json"), "r") as fh:
    dnsmap = json.load(fh)
  configurers = {"c" : configure_continuous, "p" : configure_periodic}

  print()
  print("CONNECTION INFO")
  for k, v in data["conns"].items():
    proto = v["mdata"]["proto"]
    port = v["mdata"]["port"]
    app = {"data" : {"proto" : proto, "port" : port},
           "clients" : set(), "servers" : set()}
    appmap[k] = app

    print("=== Connection: (%s, %d) ===" % (proto, port))
    starts = v.pop("starts")
    print("Instances: %d" % (len(starts)))
    # Interarrival statistics only exist with at least two instances.
    if len(starts) >= 2:
      print("Interarrivals - min: %(min)0.3f, max: %(max)0.3f, mean: %(mean)0.3f, median: %(median)0.3f" % (analyze_starts(starts)))

    while True:
      # Accept e.g. "P", " periodic": only the first letter matters.
      ctype = input("Connection type ([p]eriodic, [c]ontinuous): ").strip().lower()[:1]
      if ctype in configurers:
        break
      print("You must choose a connection type")
    app["conn_info"] = configurers[ctype]()

    print("== Map Roles ==")
    _map_roles(v.pop("clients"), dnsmap, "clients", {"type" : "client"},
               roles, ipmap, app["clients"])
    _map_roles(v.pop("servers"), dnsmap, "servers",
               {"type" : "server", "listen_port" : port},
               roles, ipmap, app["servers"])

  # Sets are not JSON-serializable; convert before the data is written.
  for app in appmap.values():
    app["clients"] = list(app["clients"])
    app["servers"] = list(app["servers"])

  data["roles"] = list(roles.values())
  data["appmap"] = appmap


def write (opts, data):
  """Write ``<data["name"]>.bndl`` (an uncompressed tar) in the working directory.

  The archive contains the model directory's ``app_model.json`` and
  ``metadata.json``, plus a ``model.mdata`` member holding ``data``
  serialized as JSON. ``net_xtic.json`` and ``dnsmap.json`` are not
  included.
  """
  import io
  import time

  # Serialize first, so a non-serializable ``data`` fails before any
  # archive is created.
  payload = json.dumps(data).encode("utf-8")

  # Build model.mdata in memory. A temporary file in the working directory
  # could clobber an existing file and was left behind if bundling failed.
  info = tarfile.TarInfo("model.mdata")
  info.size = len(payload)
  info.mtime = int(time.time())
  info.mode = 0o644

  with tarfile.TarFile("%s.bndl" % (data["name"]), mode="w") as bndl:
    bndl.add(os.path.join(opts.model_dir, "app_model.json"), "app_model.json")
    bndl.add(os.path.join(opts.model_dir, "metadata.json"), "metadata.json")
    bndl.addfile(info, io.BytesIO(payload))


if __name__ == '__main__':
  opts = parse_args(sys.argv[1:])

  try:
    validate_dir(opts)
  except ValueError as err:
    # A bad model directory is a user error: report it without a traceback
    # (sys.exit with a string prints it to stderr and exits with status 1).
    sys.exit("error: %s" % (err))

  out = parse(opts)
  gather_input(opts, out)
  write(opts, out)

# vim: syntax=python:
