'''client server packet generator'''
import os
import sys
import socket
import time
import datetime
import random
import json
import math
import logging
import argparse
import numpy as np
import pandas as pd
from multiprocessing import Process, Queue, Pipe, current_process, active_children
from .gen_netutils import get_ipv4_addresses, fake_sleep_until
from .model_predict import ModelPredictor as pred
import csv

logger = logging.getLogger(__name__)

TCPTIMEOUT = 300
# 65507 = 65535 - 20 (IPv4 header) - 8 (UDP header): the largest payload a
# single IPv4 UDP datagram can carry; also used as the request chunk size.
MAXPDULEN = 65507
DATA65507 = "abcdefgh" * int(MAXPDULEN // 8) + "abc"  # data length 65507
DATA131072 = "abcdefgh" * int(131072 // 8)  # data length 2^17
# Byte-buffer versions of the padding, wrapped in memoryviews so slicing them
# does not copy.  b'a' is a byte string on both Python 2 and Python 3.
# MDATA65507 must be exactly MAXPDULEN bytes: it is sent as one UDP datagram.
MDATA65507 = memoryview(bytearray(b'a' * MAXPDULEN))  # data_len = 65507
MDATA131072 = memoryview(bytearray(b'a' * 131072))  # data_len = 2^17
# Protocol delimiters.  Neither character may appear in the DATA* padding
# strings ("abcdefgh...") or in the numeric request headers.
CONV_DELIM = 'y'  # marks the last request of a conversation
PDU_DELIM = 'x'  # terminates every request / response PDU
LENCONVPDUDELIM = len(CONV_DELIM) + len(PDU_DELIM)
LENPDUDELIM = len(PDU_DELIM)

def byteify(input_str):
    '''Recursively turn every text string in ``input_str`` into UTF-8 bytes.

    Dicts (both keys and values) and lists are rebuilt with their items
    converted; any other value is returned as-is.
    '''
    if isinstance(input_str, str):
        return input_str.encode('utf-8')
    if isinstance(input_str, list):
        return [byteify(item) for item in input_str]
    if isinstance(input_str, dict):
        return {byteify(key): byteify(val) for key, val in input_str.items()}
    return input_str


def Makebyte(mstr):
    '''Return ``mstr`` as a byte string suitable for socket send calls.

    ``None`` becomes an empty byte string, byte strings (and bytearrays) are
    returned as bytes unchanged, and text is UTF-8 encoded.  Works on Python 2
    (``str`` is already bytes, ``unicode`` gets encoded) and on Python 3.

    Raises:
        TypeError: if ``mstr`` is not None, bytes, bytearray or text.
    '''
    if mstr is None:
        return b''
    if isinstance(mstr, bytes):
        # Already bytes (Python 3 bytes or a Python 2 str): nothing to do.
        return mstr
    if isinstance(mstr, bytearray):
        return bytes(mstr)
    if hasattr(mstr, 'encode'):
        return mstr.encode('utf-8')
    raise TypeError('Makebyte expects str, bytes or None, got %s'
                    % type(mstr).__name__)

class GeneratorType3():
    '''Replays modelled application traffic between client and server hosts.

    The same object is used on both ends; ``client`` selects the role.  On a
    server host, one listener process is started per (ip, port, protocol)
    endpoint found in ``allusessions`` whose server IP is local.  On a client
    host, one process is started per user session whose ``cli_ip`` is local,
    and each session starts one process per connection.

    Wire format (fields separated by ``*``)::

        request  = <n>*<size_1>*<time_1>*...*<size_n>*<time_n>*<padding>PDU_DELIM
        response = <padding>PDU_DELIM           (one per requested size)
        close    = 0*CONV_DELIM PDU_DELIM

    ``time_k`` is how long after receiving the request the server waits
    before sending response ``k``.
    '''

    # Receive buffer for client replies: large enough for the biggest UDP
    # datagram, because recv on a UDP socket discards what does not fit.
    _RECV_BUFSIZE = 65536

    def __init__(self, allappmodels, allusessions, client, outdir):
        # allappmodels: app name -> model dict ('usess_model', 'conn_models')
        # allusessions: list of dicts with 'app', 'srv_ip', 'cli_ip', 'start'
        # client: truthy -> run client side, falsy -> run server side
        # outdir: stored; not read by any method of this class
        self.allappmodels = allappmodels
        self.allusessions = allusessions
        self.client = client
        self.outdir = outdir
        self.global_start = None

    def generate(self):
        '''Run the server side or the client side, depending on ``self.client``.

        ``global_start`` is the reference time that user-session ``start``
        offsets are measured from.
        '''
        self.global_start = time.time()
        if not self.client:
            logger.info("calling start servers")
            self.startSrvProcs()
        else:
            logger.info("calling start clients")
            self.startCliProcs()
        logger.debug('completed packet generation process')

    @staticmethod
    def _local_ips():
        '''Return the set of IPv4 addresses found on all local interfaces.'''
        return {ip for ips in get_ipv4_addresses().values() for ip in ips}

    def startSrvProcs(self):
        '''Start one listener process per local server (ip, port, protocol).

        Listener processes serve forever, so they are not joined here.
        '''
        try:
            logger.debug("inside startSrvProcs")
            srv_ports = self.getallusess_srvports(self._local_ips())
            listener_targets = {'udp': self.genUDPSrv, 'tcp': self.genTCPSrv}
            srvport_listeners = {}
            for ip, port, proto, app in srv_ports.values():
                target = listener_targets.get(proto)
                if target is None:
                    # No listener implementation for this protocol: skip it
                    # instead of starting an undefined or stale process.
                    logger.warning('parent start srv proc: unsupported protocol %s for app %s '
                                   'on %s:%s, skipping', proto, app, str(ip), str(port))
                    continue
                logger.debug('parent start srv proc: starting %s listener process for app %s on %s:%s',
                             proto, app, str(ip), str(port))
                childname = '-'.join(['process', app, ip, proto, str(port)])
                srv_proc = Process(target=target, args=(ip, port, app), name=childname)
                srv_proc.daemon = False
                srv_proc.start()
                logger.debug('parent start srv proc: started %s listener process for app %s on %s:%s',
                             proto, app, str(ip), str(port))
                srvport_listeners[proto + '_' + ip + ':' + str(port) + '_' + app] = srv_proc
        except KeyboardInterrupt:
            logger.debug('parent start srv proc: keyboard interrupt revd')
        logger.debug('parent start srv proc: exiting')

    def getallusess_srvports(self, local_ips):
        '''Return {"<ip>_<port>_<proto>": (ip, port, l4_proto, app)} for every
        server endpoint of the user sessions whose ``srv_ip`` is in
        ``local_ips``.  Duplicate endpoints collapse into one entry.'''
        logger.debug('getting list of all server ports from usessions')
        srvports = {}
        for usess in self.allusessions:
            srvip = usess['srv_ip']
            if srvip not in local_ips:
                continue
            app = usess['app']
            for cmodel in self.allappmodels[app]['conn_models'].values():
                port = cmodel['port_number']
                l4_proto = cmodel['l4_proto']
                # TODO: replace the dict with a set of tuples
                srvports[str(srvip) + '_' + str(port) + '_' + str(l4_proto)] = (srvip, port, l4_proto, app)
        return srvports

    def genUDPSrv(self, ip, port, app):
        '''UDP listener on ip:port: demultiplex datagrams by client address.

        Each new client address gets its own conversation process (running
        startUDPTCPSrvConv) that is fed through a Queue.  Runs until
        interrupted.
        '''
        ques = {}
        procs = {}
        proc_name = current_process().name
        # recvfrom returns bytes, so compare against the byte form of the
        # close marker (a bytes/str comparison is always False on Python 3).
        conv_end_marker = Makebyte(CONV_DELIM + PDU_DELIM)
        logger.debug('%s: opening socket for UDP listener', proc_name)
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.bind((ip, int(port)))
            logger.debug('%s: UDP listener socket opened at %s:%s', proc_name, ip, port)
            i = 0
            num_conv = 0
            while True:
                i += 1
                logger.debug('%s: listening - waiting for new UDP data %d on %s:%s',
                             proc_name, i, ip, port)
                data, address = sock.recvfrom(65536)
                logger.debug('%s: new data of len:%d received from %s',
                             proc_name, len(data), str(address))
                if address not in procs:
                    logger.debug('recognized new UDP conn from ' + str(address))
                    conv_que = Queue()
                    conv_que.cancel_join_thread()
                    conv_proc = Process(target=self.startUDPTCPSrvConv,
                                        args=(address, app, 'udp', conv_que, sock),
                                        name=proc_name + '-conv-' + str(num_conv))
                    conv_proc.daemon = False
                    conv_proc.start()
                    ques[address] = conv_que
                    procs[address] = conv_proc
                    num_conv += 1
                logger.debug('%s: putting data received in queue for subprocess to handle', proc_name)
                ques[address].put(data)
                if not data or data.endswith(conv_end_marker):
                    # Conversation over: forget it so the maps do not grow
                    # forever and a later conversation from the same address
                    # gets a fresh handler process.
                    del procs[address]
                    del ques[address]
            # the server loop runs forever: no joining, only interrupt
        except KeyboardInterrupt:
            logger.debug('%s: keyboard interrupt revd', proc_name)
        finally:
            # drop datagrams that were never handed to a conversation
            for conv_que in ques.values():
                while not conv_que.empty():
                    conv_que.get()
            sock.close()
        logger.debug('%s: exiting', proc_name)

    def genTCPSrv(self, ip, port, app):
        '''TCP listener on ip:port: one conversation process per accepted
        connection.  Runs until interrupted.'''
        proc_name = current_process().name
        logger.debug('%s: opening socket for TCP listener', proc_name)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.bind((ip, int(port)))
            sock.listen(128)
            logger.debug('%s: TCP listener socket opened at %s:%s', proc_name, ip, str(port))
            i = 0
            while True:
                i += 1
                logger.debug('%s: listening - waiting for a new TCP connection %d on '
                             '%s:%s', proc_name, i, ip, str(port))
                conn, address = sock.accept()
                logger.debug('%s: new connection received from %s', proc_name, str(address))
                proc = Process(target=self.startUDPTCPSrvConv,
                               args=(address, app, 'tcp', None, None, conn),
                               name=proc_name + '-conv-' + str(i))
                proc.daemon = False
                proc.start()
                # The child process now has its own handle to the connection;
                # release the listener's copy so fds are not leaked and the
                # connection closes once the child closes it.
                conn.close()
            # the server loop runs forever: no joining, only interrupt
        except KeyboardInterrupt:
            logger.debug('%s: keyboard interrupt revd', proc_name)
        finally:
            sock.close()
        logger.debug('%s: exiting', proc_name)

    def startUDPTCPSrvConv(self, address, app, proto=None, que=None, sock=None, conn=None):
        '''Serve one client conversation until the client closes it.

        For 'tcp', bytes are read from the accepted socket ``conn``.  For
        'udp', datagrams arrive through ``que`` (filled by genUDPSrv) and
        replies go out on the shared listener ``sock`` to ``address``.  Each
        request lists response sizes and delays; every response is sent at
        its delay after the request arrived.

        Raises:
            ValueError: if ``proto`` is not 'tcp'/'udp' or the handles that
                protocol needs are missing.
        '''
        if proto not in ('tcp', 'udp'):
            raise ValueError('proto must be "tcp" or "udp"')
        if proto == 'tcp' and conn is None:
            raise ValueError('since proto is "tcp" conn must be supplied')
        if proto == 'udp' and (sock is None or que is None):
            raise ValueError('since proto is "udp" sock and que must be supplied')
        proc_name = current_process().name
        pdu_delim_bytes = Makebyte(PDU_DELIM)
        # Bytes received but not yet consumed.  TCP may deliver several
        # back-to-back requests in one recv, so whatever follows a delimiter
        # is kept for the next request instead of being merged into this one.
        pending = bytearray()
        try:
            logger.debug('%s: starting new %s server conversation for %s',
                         proc_name, proto.upper(), str(address))
            i = 0  # index of the request being served
            while True:
                logger.debug('%s: waiting to get request - %d', proc_name, i)
                raw_req, peer_done = self._recv_request(proto, conn, que, pending, pdu_delim_bytes)
                t_pdu_start = time.time()
                if peer_done:
                    # TCP EOF or an empty UDP datagram: no more requests.
                    if raw_req:
                        logger.debug('%s: discarding %d bytes of incomplete request %d',
                                     proc_name, len(raw_req), i)
                    logger.debug('%s: converstsation closed by client', proc_name)
                    break
                logger.debug('%s: request %d  is fully received', proc_name, i)
                req = raw_req.decode('utf-8') if sys.version_info[0] >= 3 else raw_req
                arr = req.split('*')
                logger.debug('%s: details of request %d: %s', proc_name, i, str(arr[:-1]))

                # An empty request, or one ending in CONV_DELIM, is the last
                # of the conversation (its responses, if any, are still sent).
                end_this_conv = len(req) == 0 or req.endswith(CONV_DELIM)
                if end_this_conv:
                    req = req[:-len(CONV_DELIM)]
                    logger.debug('%s: converstsation closed by client', proc_name)

                try:
                    num_resp = int(arr[0]) if req else 0
                    res_arr = arr[1:-1]
                    resp_plan = [(int(res_arr[2 * j]), float(res_arr[2 * j + 1]))
                                 for j in range(num_resp)]
                except (ValueError, IndexError):
                    logger.warning('%s: malformed request %d, no response sent', proc_name, i)
                    resp_plan = []
                logger.debug('%s: %d response PDUs will be sent', proc_name, len(resp_plan))

                for j, (pdu_size, pdu_time) in enumerate(resp_plan):
                    logger.debug('%s: size of response pdu %d is %d', proc_name, j, pdu_size)
                    logger.debug('%s: fake sleeping for response time %f', proc_name, pdu_time)
                    fake_sleep_until(t_pdu_start + pdu_time - 0.1)
                    self._send_response_pdu(proto, conn, sock, address, pdu_size, proc_name)
                    logger.debug('%s: *** sent response pdu %d of size %d to '
                                 'client.', proc_name, j, pdu_size)
                if end_this_conv:
                    break
                i += 1
        except KeyboardInterrupt:
            logger.debug('%s: keyboard interrupt revd', proc_name)
        except socket.error as err:
            # e.g. the client reset the connection mid-conversation
            logger.warning('%s: socket error, ending conversation: %s', proc_name, err)
        finally:
            if conn is not None:
                conn.close()
        logger.debug('%s: exiting', proc_name)

    @staticmethod
    def _recv_request(proto, conn, que, pending, pdu_delim_bytes):
        '''Return ``(request_bytes, peer_done)`` for the next request PDU.

        ``pending`` (a bytearray) is extended and consumed in place.  The
        returned request excludes its delimiter.  ``peer_done`` is True when
        the peer stopped sending (TCP recv returned b'' or an empty datagram
        was queued) before another delimiter arrived; the request bytes are
        then whatever incomplete data was buffered.
        '''
        search_from = 0
        while True:
            end = pending.find(pdu_delim_bytes, search_from)
            if end >= 0:
                request = bytes(pending[:end])
                del pending[:end + len(pdu_delim_bytes)]
                return request, False
            # Only the last len(delim)-1 bytes could start a delimiter that
            # the next read completes, so skip re-scanning the rest.
            search_from = max(0, len(pending) - len(pdu_delim_bytes) + 1)
            data = conn.recv(4096) if proto == 'tcp' else que.get()
            if not data:
                request = bytes(pending)
                del pending[:]
                return request, True
            pending += data

    @staticmethod
    def _send_response_pdu(proto, conn, sock, address, pdu_size, proc_name):
        '''Send one response PDU of ``pdu_size`` bytes (at least 1) ending in PDU_DELIM.'''
        if proto == 'tcp':
            full_chunk = MDATA131072
            send = conn.sendall
        else:
            # A UDP datagram cannot exceed MAXPDULEN bytes; slicing keeps the
            # chunk within that limit.
            full_chunk = MDATA65507[:MAXPDULEN]

            def send(payload):
                sock.sendto(payload, address)
        # A negative size would slice DATA131072 from the end; treat it as 0.
        bytes_to_send = max(0, pdu_size)
        while bytes_to_send > len(full_chunk):
            logger.debug('%s: sending full MDATA131072 as response', proc_name)
            send(full_chunk)
            bytes_to_send -= len(full_chunk)
        last_d = DATA131072[:bytes_to_send]
        # overwrite the final byte with the delimiter (or send it alone)
        last_d = last_d[:-LENPDUDELIM] + PDU_DELIM
        logger.debug('%s: sending %d portion MDATA131072 as response', proc_name, bytes_to_send)
        send(Makebyte(last_d))

    def startCliProcs(self):
        '''creates a process to handle each user session

        Sessions whose ``cli_ip`` is a local address are started in order of
        their ``start`` offset from ``self.global_start``; then all session
        processes are joined.  Note: sorts ``self.allusessions`` in place.
        '''
        try:
            usess_procs = []
            local_ips = self._local_ips()
            self.allusessions.sort(key=lambda x: x['start'])
            for i, usess in enumerate(self.allusessions):
                if usess['cli_ip'] not in local_ips:
                    continue
                logger.debug("parent start cli proc:starting client process for usersess %d app: %s",
                             i, usess['app'])
                start = usess['start']
                logger.debug("parent start cli proc: sleeping for  %d secs before usess %d",
                             self.global_start + start - time.time(), i)
                fake_sleep_until(self.global_start + start - 0.1)
                childname = '-'.join(['process', usess['app'], 'usess', str(i)])
                usess_proc = Process(target=self.start_cli_usess, args=(usess, i), name=childname)
                usess_proc.daemon = False
                usess_proc.start()
                usess_procs.append(usess_proc)
            for proci in usess_procs:
                proci.join()
        except KeyboardInterrupt:
            logger.debug('parent start cli proc: keyboard interrupt revd')
        logger.debug('parent start cli proc: exiting')

    def start_cli_usess(self, usess, usess_idx):
        '''Run one user session: start a process per connection.

        The number of connections, their types (taken from a sampled type
        sequence) and the inter-arrival times come from the app's
        ``usess_model``.  Connections run one at a time unless the model says
        they overlap, in which case they are joined at the end.
        '''
        proc_name = current_process().name
        logger.debug('%s: started usersession %d', proc_name, usess_idx)
        conn_procs = []
        try:
            app = usess['app']
            appmodel = self.allappmodels[app]
            umodel = appmodel['usess_model']
            start_time = self.global_start + usess['start']
            num_conns = int(round(pred().predictVal(umodel['num_conns'])))
            logger.debug('%s: usess will have %d conns', proc_name, num_conns)

            ctype_seq = random.choice(umodel['usess_conn_seqs'])
            ctype_seq2 = ctype_seq[1:]

            for i in range(num_conns):
                # Once the sampled sequence is exhausted, keep cycling
                # through it without its first element.
                if i >= len(ctype_seq) and len(ctype_seq2) > 0:
                    ctype = ctype_seq2[(i - 1) % len(ctype_seq2)]
                else:
                    ctype = ctype_seq[i % len(ctype_seq)]

                cmodel = appmodel['conn_models'][str(ctype)]
                overlap = umodel['has_overlap_conns']
                childname = proc_name + '-conn-' + str(i)
                conn_proc = Process(target=self.start_cli_conn,
                                    args=(usess, cmodel, i, start_time), name=childname)
                conn_proc.daemon = False
                conn_proc.start()
                inter_conn_time = pred().predictVal(umodel['conn_inter_arrrival'])
                # the predictor may return None; treat that as no gap
                inter_conn_time = inter_conn_time if inter_conn_time else 0
                logger.debug('%s: inter conn time is  %f connpools', proc_name, inter_conn_time)
                start_time += inter_conn_time
                if not overlap:
                    conn_proc.join()
                else:
                    conn_procs.append(conn_proc)
                fake_sleep_until(start_time)
            for proci in conn_procs:
                proci.join()
        except KeyboardInterrupt:
            logger.debug('%s : keyboard interrupt revd', proc_name)
        logger.debug('%s: exiting', proc_name)

    def start_cli_conn(self, usess, cmodel, cid, starttime):
        '''Run one client connection: open a socket, send the modelled
        request bursts, then send the close message.

        Each burst sends its requests back to back; if any of them asked for
        responses, one reply is awaited before sleeping until the next
        burst's scheduled time (measured from ``starttime``).  ``cid`` is not
        used here.
        '''
        proc_name = current_process().name
        curr_time = starttime
        cli_port_sock = None
        try:
            logger.debug('%s: started process "%s" for new connection',
                         proc_name, proc_name)
            srv_port = cmodel['port_number']
            l4_proto = cmodel['l4_proto']
            target_ip_port = (usess['srv_ip'], srv_port)
            cli_port_sock = self.make_socks(1, cmodel, proc_name, target_ip_port)[0]

            num_reqs_burst = int(np.round(pred().predictVal(cmodel['num_req_burst_per_conn'])))
            logger.debug('%s: conn will have %d request bursts',
                         proc_name, num_reqs_burst)
            for curr_req_burst in range(num_reqs_burst):
                try:
                    num_req_in_burst = int(np.round(pred().predictVal(cmodel['num_req_in_req_burst'])))
                    logger.debug('%s: request burst %d will have %d requests',
                                 proc_name, curr_req_burst, num_req_in_burst)
                    # response count of the first request that asked for any
                    first_num_resps = None
                    for i in range(num_req_in_burst):
                        logger.debug('%s: sending req pdu %d-%d', proc_name, curr_req_burst, i)
                        req_size = int(np.round(pred().predictVal(cmodel['request_sizes'])))
                        logger.debug("%s: request size of req pdu %d-%d is %d",
                                     proc_name, curr_req_burst, i, req_size)
                        num_resps = int(np.round(pred().predictVal(cmodel['num_response_per_req'])))
                        logger.debug('%s: req pdu %d-%d will have %d responses',
                                     proc_name, curr_req_burst, i, num_resps)
                        resp_details = self.make_resp_details(num_resps, cmodel, proc_name)
                        logger.debug('%s: calculated resp details for req pdu %d is %s ',
                                     proc_name, curr_req_burst, resp_details)
                        if first_num_resps is None and num_resps > 0:
                            first_num_resps = num_resps
                        self.send_request(req_size, resp_details, cli_port_sock,
                                          l4_proto, target_ip_port, proc_name)

                    # TODO: check if responses are bundled
                    if first_num_resps is not None:
                        self.recv_response(cli_port_sock, first_num_resps, proc_name)

                    inter_req_time = pred().predictVal(cmodel['inter_req_burst_time_per_conn'])
                    # the predictor may return None; treat that as no gap
                    inter_req_time = inter_req_time if inter_req_time else 0
                    logger.debug('%s: inter request time after req burst %d is %d, sleeping',
                                 proc_name, curr_req_burst, inter_req_time)
                    curr_time += inter_req_time
                    fake_sleep_until(curr_time)
                except Exception as e:
                    logger.error(e, exc_info=True)
                    raise
            self.send_close_conn(cli_port_sock, l4_proto, target_ip_port, proc_name)
        except KeyboardInterrupt:
            logger.debug('%s: keyboard interrupt revd', proc_name)
        finally:
            # send_close_conn closes the socket on the normal path; this also
            # covers errors and interrupts (closing twice is harmless).
            if cli_port_sock is not None:
                cli_port_sock.close()
        logger.debug('%s: exiting', proc_name)

    def make_socks(self, num_sock, cmodel, proc_name, target_ip_port):
        '''Create ``num_sock`` client sockets for ``cmodel['l4_proto']``.

        TCP sockets are connected to ``target_ip_port``; UDP sockets are left
        unconnected (requests use sendto).  After creation/connect, each
        socket gets a 2.0 s timeout.  If anything fails, the sockets created
        so far are closed and the error is re-raised.
        '''
        sock_list = []
        try:
            for i in range(num_sock):
                if cmodel['l4_proto'] == 'tcp':
                    logger.debug('%s: starting tcp conn %d', proc_name, i)
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock_list.append(sock)
                    sock.connect(target_ip_port)
                else:
                    logger.debug('%s: starting udp %d', proc_name, i)
                    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                    sock_list.append(sock)
                sock.settimeout(2.0)
        except socket.error:
            for sock in sock_list:
                sock.close()
            raise
        logger.debug('%s: connections started', proc_name)
        return sock_list

    def make_resp_details(self, num_resps, cmodel, proc_name):
        '''Build the response header "<n>*<size_1>*<time_1>*...*" of a request.

        Sizes come from ``cmodel['response_sizes']`` (truncated to int) and
        server times from ``cmodel['server_processing_times']`` (rounded to 4
        decimals).  The header always ends with "*".
        '''
        fields = [str(num_resps)]
        for _ in range(num_resps):
            resp_size = int(pred().predictVal(cmodel['response_sizes']))
            srv_time = pred().predictVal(cmodel['server_processing_times'])
            srv_time = float("{0:.4f}".format(srv_time))
            fields.append(str(resp_size))
            fields.append(str(srv_time))
            logger.debug('%s: resp_size = %d, resp_time = %f',
                         proc_name, resp_size, srv_time)
        resp_details = '*'.join(fields) + '*'
        logger.debug('%s: resp_details = %s',
                     proc_name, resp_details)
        return resp_details

    def send_request(self, req_size, resp_details, soc, l4_proto, target_ip_port, proc_name):
        '''Send one request PDU of about ``req_size`` bytes.

        The payload starts with ``resp_details``, is padded with DATA65507
        characters and ends with PDU_DELIM; it is split into chunks of at most
        MAXPDULEN bytes (one datagram each for UDP).  A request is never
        shorter than its header plus the delimiter.
        '''
        # At least one chunk: a request whose predicted size is <= 0 must
        # still carry its header, otherwise the server never answers it.
        num_chunks = max(1, int(math.ceil(float(req_size) / MAXPDULEN)))
        bytes_to_send = req_size
        for i in range(num_chunks):
            is_last = i == num_chunks - 1
            if i == 0:
                if bytes_to_send <= len(resp_details):
                    chunk = resp_details
                else:
                    chunk = ''.join([resp_details,
                                     DATA65507[len(resp_details):bytes_to_send]])
            elif is_last:
                chunk = DATA65507[:bytes_to_send]
            else:
                chunk = DATA65507
            if is_last:
                if chunk.endswith('*'):
                    # header only: keep it intact and append the delimiter
                    chunk = ''.join([chunk, PDU_DELIM])
                else:
                    # overwrite the final padding byte so the size is kept
                    chunk = ''.join([chunk[:-LENPDUDELIM], PDU_DELIM])
            logger.debug('%s: sending chunk: %s....',
                         proc_name, chunk[:50])

            if l4_proto == 'tcp':
                soc.sendall(Makebyte(chunk))
            else:
                soc.sendto(Makebyte(chunk), target_ip_port)
            bytes_to_send -= MAXPDULEN
        logger.debug('%s: *** request sent, d_size: %d, resp_details: %s', proc_name, req_size, resp_details)

    def recv_response(self, soc, num_resps, proc_name):
        '''Wait for at most one reply PDU on ``soc`` (none if ``num_resps`` <= 0).

        Multiple response PDUs may arrive combined, so only one reply is
        awaited.  Reading stops when a received chunk ends with PDU_DELIM,
        when the peer closes the connection, or on socket timeout (both
        logged as warnings).
        '''
        pdu_delim_bytes = Makebyte(PDU_DELIM)
        numreplies = min(1, num_resps)
        received = 0
        for q in range(1, 1 + numreplies):
            logger.debug('%s: waiting for reply %d of %d total', proc_name, q, numreplies)
            try:
                while True:
                    data = soc.recv(self._RECV_BUFSIZE)
                    logger.debug('%s: portion of reply %d of length-%d received',
                                 proc_name, q, len(data))
                    if not data:
                        # peer closed: without this check the loop spins forever
                        logger.warning('%s: connection closed before reply %d was complete',
                                       proc_name, q)
                        return
                    received += len(data)
                    if data.endswith(pdu_delim_bytes):
                        break
                logger.debug('%s: full reply %d of length-%d received',
                             proc_name, q, received)
            except socket.timeout:
                logger.warning('%s: reply  %d not received within timeout', proc_name, q)

    def send_close_conn(self, soc, l4_proto, target_ip_port, proc_name):
        '''Send the end-of-conversation PDU ("0*" + CONV_DELIM + PDU_DELIM), then close ``soc``.'''
        chunk = ''.join(['0*', CONV_DELIM, PDU_DELIM])
        if l4_proto == 'tcp':
            soc.sendall(Makebyte(chunk))
        else:
            soc.sendto(Makebyte(chunk), target_ip_port)
        soc.close()
