'''client server packet generator'''
import os
import sys
import socket
import time
import json
import math
import logging
import argparse
from multiprocessing import Process, Queue, Pipe, current_process, active_children
from .gen_netutils import get_ipv4_addresses, fake_sleep_until
from .model_predict import ModelPredictor as pred

logger = logging.getLogger(__name__)

# Seconds; not referenced by the code in this module (presumably a TCP socket
# timeout used elsewhere -- TODO confirm).
TCPTIMEOUT = 300
# Largest UDP payload that fits in one IPv4 datagram (65535 - 20 IP - 8 UDP).
MAXPDULEN = 65507
# Text filler used to pad request/response PDUs.
DATA65507 = "abcdefgh" * (MAXPDULEN // 8) + "abc"  # len == MAXPDULEN
DATA131072 = "abcdefgh" * (131072 // 8)  # len == 2**17
# Pre-built byte buffers for bulk sends; memoryview avoids copying on slicing.
# b'...' literals are bytes on Python 3 and str on Python 2, so one
# definition serves both.  MDATA65507 must not exceed MAXPDULEN, otherwise a
# UDP sendto() of it fails with "message too long".
MDATA65507 = memoryview(bytearray(b'a' * MAXPDULEN))  # len == MAXPDULEN
MDATA131072 = memoryview(bytearray(b'a' * 131072))  # len == 2**17
# End-of-conversation marker; must not occur in DATA65507 / DATA131072.
CONV_DELIM = 'y'
# End-of-PDU marker; must not occur in DATA65507 / DATA131072.
PDU_DELIM = 'x'
LENCONVPDUDELIM = len(CONV_DELIM) + len(PDU_DELIM)
LENPDUDELIM = len(PDU_DELIM)

def byteify(input_str):
    '''Recursively convert text strings inside *input_str* to UTF-8 bytes.

    Dicts (both keys and values) and lists are converted element-wise; any
    other value -- including strings that are already bytes -- is returned
    unchanged.  The text type is ``unicode`` on Python 2 and ``str`` on
    Python 3; testing against ``str`` alone would skip Python 2 unicode
    strings and try to re-encode Python 2 byte strings instead.
    '''
    text_type = type(u'')  # unicode on Python 2, str on Python 3
    if isinstance(input_str, dict):
        return {byteify(key): byteify(value)
                for key, value in input_str.items()}
    if isinstance(input_str, list):
        return [byteify(element) for element in input_str]
    if isinstance(input_str, text_type):
        return input_str.encode('utf-8')
    return input_str

class GeneratorGeneric1(object):
    '''Generate synthetic client/server traffic from application models.

    Server mode starts one listener process per local (ip, port, l4_proto)
    endpoint and serves every conversation in its own child process.
    Client mode starts one process per local user session, which fans out
    into one process per connection type, per connection pool and per socket.

    Text wire format used by both sides:
      request  = "<n>*<size_1>*<time_1>*...*<size_n>*<time_n>*" + filler,
                 with its last character replaced by PDU_DELIM (or PDU_DELIM
                 appended when no filler fits)
      response = filler whose last character is replaced by PDU_DELIM
      close    = "0*" + CONV_DELIM + PDU_DELIM
    Text is latin-1 encoded on send and decoded on receive, so the same code
    runs on Python 2 and Python 3.
    '''

    def __init__(self, allappmodels, allusessions, client, outdir):
        '''
        :param allappmodels: mapping app name -> app model; each model has a
            'conn_models' mapping of per-connection-type models.
        :param allusessions: iterable of user-session dicts with at least
            'app', 'srv_ip', 'cli_ip' and 'start' (seconds after generation
            start).
        :param client: truthy to generate client traffic, falsy for servers.
        :param outdir: output directory (stored; not used by this class).
        '''
        self.allappmodels = allappmodels
        self.allusessions = allusessions
        self.client = client
        self.outdir = outdir
        self.global_start = None

    # ------------------------------------------------------------------ utils

    @staticmethod
    def _as_bytes(payload):
        '''Return *payload* ready for socket send calls.

        Bytes-like objects (bytes, bytearray, memoryview) pass through; text
        is latin-1 encoded (every PDU character is ASCII).
        '''
        if isinstance(payload, (bytes, bytearray, memoryview)):
            return payload
        return payload.encode('latin-1')

    @staticmethod
    def _as_text(payload):
        '''Return received *payload* as text.

        latin-1 maps every byte to one character, so a PDU split across
        several recv() calls can never fail to decode.
        '''
        if isinstance(payload, (bytes, bytearray)):
            return payload.decode('latin-1')
        return payload

    @staticmethod
    def _local_ips():
        '''Return the set of IPv4 addresses reported by get_ipv4_addresses().'''
        return set(ip for ips in get_ipv4_addresses().values() for ip in ips)

    # ----------------------------------------------------------- entry point

    def generate(self):
        '''Start the server listeners or the client sessions (per self.client).'''
        # TODO: Verify that opts.ip_addr is on this computer.
        self.global_start = time.time()
        if not self.client:
            logger.info("calling start servers")
            self.startSrvProcs()
        else:
            logger.info("calling start clients")
            self.startCliProcs()
        logger.debug('completed packet generation process')

    # ---------------------------------------------------------------- server

    def startSrvProcs(self):
        '''Start one listener process per server endpoint hosted on this machine.

        Listener processes are non-daemonic and are not joined here.
        '''
        try:
            logger.debug("inside startSrvProcs")
            srv_ports = self.getallusess_srvports(self._local_ips())
            # .values() works on Python 2 and 3 (.iteritems() is Python 2 only).
            for ip, port, proto, app in srv_ports.values():
                logger.debug('parent start srv proc: starting %s listener process for app %s on %s:%s',
                             proto, app, str(ip), str(port))
                if proto == 'udp':
                    target = self.genUDPSrv
                elif proto == 'tcp':
                    target = self.genTCPSrv
                else:
                    # Without this guard the process object from the previous
                    # iteration would be started again (or NameError at first).
                    logger.error('parent start srv proc: unsupported l4 proto %r for app %s on %s:%s, skipping',
                                 proto, app, str(ip), str(port))
                    continue
                childname = '-'.join(['proc', proto, str(port)])
                srv_proc = Process(target=target, args=(ip, port, app), name=childname)
                srv_proc.daemon = False
                srv_proc.start()
                logger.debug('parent start srv proc: started %s listener process for app %s on %s:%s',
                             proto, app, str(ip), str(port))
        except KeyboardInterrupt:
            logger.debug('parent start srv proc: keyboard interrupt revd')
        logger.debug('parent start srv proc: exiting')

    def getallusess_srvports(self, local_ips):
        '''Return the server endpoints that must listen on this machine.

        :param local_ips: collection of this machine's IPv4 addresses.
        :returns: dict keyed by "<srv_ip>_<port>_<l4_proto>" whose values are
            (srv_ip, port, l4_proto, app) tuples.  Sessions that share an
            endpoint collapse into one entry (the last app seen wins).
        '''
        logger.debug('getting list of all server ports from usessions')
        srvports = {}
        for usess in self.allusessions:
            srvip = usess['srv_ip']
            if srvip not in local_ips:
                continue
            app = usess['app']
            for cmodel in self.allappmodels[app]['conn_models'].values():
                port = cmodel['port_number']
                l4_proto = cmodel['l4_proto']
                key = '_'.join([str(srvip), str(port), str(l4_proto)])
                srvports[key] = (srvip, port, l4_proto, app)
        return srvports

    def genUDPSrv(self, ip, port, app):
        '''UDP listener loop: demultiplex datagrams by client address.

        Each new client address gets a Queue plus a conversation process
        (startUDPTCPSrvConv) that replies through this listener socket.
        Every datagram is forwarded to its address's queue.  A datagram whose
        last two characters are CONV_DELIM + PDU_DELIM, or an empty datagram,
        ends that address's conversation.  Runs until interrupted.
        '''
        ques = {}
        procs = {}
        proc_name = current_process().name
        logger.debug('%s: opening socket for UDP listener', proc_name)
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.bind((ip, int(port)))
            logger.debug('%s: UDP listener socket opened at %s:%s', proc_name, ip, port)
            i = 0
            num_conv = 0
            while True:
                i += 1
                logger.debug('%s: listening - waiting for new UDP data %d on %s:%s',
                             proc_name, i, ip, port)
                data, address = sock.recvfrom(65536)
                logger.debug('%s: new data of len:%d received from %s',
                             proc_name, len(data), str(address))
                # Join conversation processes that already finished so they
                # do not pile up as zombies in this never-ending loop.
                active_children()
                if address not in procs:
                    logger.debug('recognized new UDP conn from ' + str(address))
                    que = Queue()
                    que.cancel_join_thread()
                    conv_proc = Process(target=self.startUDPTCPSrvConv,
                                        args=(address, app, 'udp', que, sock),
                                        name=proc_name + '-conv-' + str(num_conv))
                    conv_proc.daemon = False
                    conv_proc.start()
                    ques[address] = que
                    procs[address] = conv_proc
                    num_conv += 1
                logger.debug('%s: putting data received in queue for subprocess to handle', proc_name)
                ques[address].put(data)
                tail = self._as_text(data[-LENCONVPDUDELIM:-LENPDUDELIM])
                if tail == CONV_DELIM or not data:
                    # Forget the address so a later datagram from it starts
                    # a fresh conversation.
                    del procs[address]
                    del ques[address]
            # the server loop runs forever; it only ends by interrupt
        except KeyboardInterrupt:
            logger.debug('%s: keyboard interrupt revd', proc_name)
        finally:
            for que in ques.values():
                while not que.empty():
                    que.get()
            sock.close()
        logger.debug('%s: exiting', proc_name)

    def genTCPSrv(self, ip, port, app):
        '''TCP listener loop: serve every accepted connection in a child process.'''
        proc_name = current_process().name
        logger.debug('%s: opening socket for TCP listener', proc_name)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.bind((ip, int(port)))
            sock.listen(128)
            logger.debug('%s: TCP listener socket opened at %s:%s', proc_name, ip, str(port))
            i = 0
            while True:
                i += 1
                logger.debug('%s: listening - waiting for a new TCP connection %d on '
                             '%s:%s', proc_name, i, ip, str(port))
                conn, address = sock.accept()
                logger.debug('%s: new connection received from %s', proc_name, str(address))
                # Join conversation processes that already finished (zombies).
                active_children()
                proc = Process(target=self.startUDPTCPSrvConv,
                               args=(address, app, 'tcp', None, None, conn),
                               name=proc_name + '-conv-' + str(i))
                proc.daemon = False
                proc.start()
                # The child owns the connection now; drop the listener's copy
                # so descriptors do not accumulate with every connection.
                conn.close()
            # the server loop runs forever; it only ends by interrupt
        except KeyboardInterrupt:
            logger.debug('%s: keyboard interrupt revd', proc_name)
        finally:
            sock.close()
        logger.debug('%s: exiting', proc_name)

    def startUDPTCPSrvConv(self, address, app, proto=None, que=None, sock=None, conn=None):
        '''Serve one client conversation until it ends.

        Each request (read up to PDU_DELIM) is parsed as
        "<n>*<size_1>*<time_1>*...*<filler>" and answered with n response
        PDUs; response j is sent no earlier than about time_j seconds after
        the request was parsed.  A request ending in CONV_DELIM, an empty
        request, or the peer disconnecting ends the conversation.

        :param address: client address (UDP reply destination; logged).
        :param app: application name (not used in the body).
        :param proto: 'tcp' or 'udp'.
        :param que: UDP only - queue fed with this client's datagrams.
        :param sock: UDP only - listener socket used to send replies.
        :param conn: TCP only - accepted connection; closed on exit.
        '''
        assert proto in ['tcp', 'udp'], 'proto must be "tcp" or "udp"'
        if proto == 'tcp':
            assert conn is not None, 'since proto is "tcp" conn must be supplied'
        else:
            assert sock is not None and que is not None, 'since proto is "udp" sock and que must be supplied'
        proc_name = current_process().name
        try:
            logger.debug('%s: starting new %s server conversation for %s',
                         proc_name, proto.upper(), str(address))
            i = 0  # number of requests handled so far
            while True:
                logger.debug('%s: waiting to get request - %d', proc_name, i)
                req = self._read_srv_request(proto, conn, que)
                if not req:
                    # None: peer closed / empty datagram.  '': a bare
                    # PDU_DELIM without a header (int('') would raise).
                    logger.debug('%s: converstsation closed by client', proc_name)
                    break
                logger.debug('%s: request %d received is %s', proc_name, i, req[:50])
                arr = req.split('*')
                logger.debug('%s: details of request %d: %s', proc_name, i, str(arr[:-1]))
                num_resp = int(arr[0])
                logger.debug('%s: %d response PDUs will be sent', proc_name, num_resp)
                end_this_conv = req[-len(CONV_DELIM):] == CONV_DELIM
                if end_this_conv:
                    logger.debug('%s: converstsation closed by client', proc_name)
                # Response times are offsets from this moment.
                t_pdu_start = time.time()
                res_arr = arr[1:-1]  # alternating size, time; arr[-1] is filler
                for j in range(num_resp):
                    pdu_size = int(res_arr[2 * j])
                    pdu_time = float(res_arr[2 * j + 1])
                    logger.debug('%s: size of response pdu %d is %d', proc_name, j, pdu_size)
                    logger.debug('%s: fake sleeping for response time %f', proc_name, pdu_time)
                    fake_sleep_until(t_pdu_start + pdu_time - 0.2)
                    self._send_srv_response(proto, pdu_size, conn, sock, address)
                    logger.debug('%s: sent response pdu %d of size %d to '
                                 'client.', proc_name, j, pdu_size)
                if end_this_conv:
                    break
                i += 1
        except KeyboardInterrupt:
            logger.debug('%s: keyboard interrupt revd', proc_name)
        finally:
            if conn is not None:
                try:
                    conn.shutdown(socket.SHUT_RDWR)
                except socket.error as exc:
                    # The client may already have closed or reset the link.
                    logger.debug('%s: shutdown failed: %s', proc_name, exc)
                conn.close()
        logger.debug('%s: exiting', proc_name)

    def _read_srv_request(self, proto, conn, que):
        '''Read one request PDU from *conn* (TCP) or *que* (UDP).

        Returns the PDU text without its trailing PDU_DELIM, or None when an
        empty read/datagram arrives first.  Returning None avoids waiting
        forever for a delimiter that can no longer arrive.
        '''
        req = ''
        while True:
            if proto == 'tcp':
                data = conn.recv(4096)
            else:
                data = que.get()
            if not data:
                return None
            req += self._as_text(data)
            if req[-LENPDUDELIM:] == PDU_DELIM:
                return req[:-LENPDUDELIM]

    def _send_srv_response(self, proto, pdu_size, conn, sock, address):
        '''Send one response PDU of *pdu_size* bytes whose last byte is PDU_DELIM.

        TCP bulk writes use MDATA131072.  UDP bulk datagrams are capped at
        MAXPDULEN so each fits one datagram.  Each write is subtracted by its
        real length, so the trailing datagram is also at most MAXPDULEN.
        A size <= 0 still sends a lone PDU_DELIM so the client sees the PDU end.
        '''
        if proto == 'tcp':
            bulk = MDATA131072
        else:
            bulk = MDATA65507[:MAXPDULEN]
        # Negative sizes would otherwise slice DATA131072 from the end and
        # send ~131 KB.
        remaining = max(pdu_size, 0)
        while remaining > len(bulk):
            if proto == 'tcp':
                conn.sendall(bulk)
            else:
                sock.sendto(bulk, address)
            remaining -= len(bulk)
        last_d = DATA131072[:remaining]
        last_d = self._as_bytes(last_d[:-LENPDUDELIM] + PDU_DELIM)
        if proto == 'tcp':
            conn.sendall(last_d)
        else:
            sock.sendto(last_d, address)

    # ---------------------------------------------------------------- client

    def startCliProcs(self):
        '''Start one process per user session whose client IP is local.

        Each process starts at its session's 'start' offset from
        self.global_start.  Waits for all of them to finish.
        '''
        try:
            usess_procs = []
            local_ips = self._local_ips()
            for i, usess in enumerate(self.allusessions):
                if usess['cli_ip'] not in local_ips:
                    continue
                logger.debug("parent start cli proc:starting client process for usersess %d app: %s", i, usess['app'])
                start = usess['start']
                logger.debug("parent start cli proc: sleeping for  %d secs before usess %d",
                             self.global_start + start - time.time(), i)
                fake_sleep_until(self.global_start + start - 0.2)
                childname = '-'.join(['proc', 'usess' + str(i)])
                usess_proc = Process(target=self.start_cli_usess, args=(usess, i), name=childname)
                usess_proc.daemon = False
                usess_proc.start()
                usess_procs.append(usess_proc)
            for proci in usess_procs:
                proci.join()
        except KeyboardInterrupt:
            logger.debug('parent start cli proc: keyboard interrupt revd')
        logger.debug('parent start cli proc: exiting')

    def start_cli_usess(self, usess, usess_idx):
        '''Start one process per connection type of the session's app model,
        then wait for them all.'''
        proc_name = current_process().name
        logger.debug('%s: started usersession %d', proc_name, usess_idx)
        conn_type_procs = []
        try:
            appmodel = self.allappmodels[usess['app']]
            usess_start_time_epoch = self.global_start + usess['start']
            for conntype, cmodel in appmodel['conn_models'].items():
                childname = proc_name + '-ctype' + conntype
                conn_type_proc = Process(target=self.start_cli_conn_type,
                                         args=(cmodel, usess, usess_start_time_epoch),
                                         name=childname)
                conn_type_proc.daemon = False
                conn_type_proc.start()
                conn_type_procs.append(conn_type_proc)
            for proci in conn_type_procs:
                proci.join()
        except KeyboardInterrupt:
            logger.debug('%s : keyboard interrupt revd', proc_name)
        logger.debug('%s: exiting', proc_name)

    def start_cli_conn_type(self, cmodel, usess, usess_start_time_epoch):
        '''Start the connection pools of one connection type.

        The first pool starts first_conn_start_time after the session start.
        When pools overlap, each next pool starts start_start_intervals after
        the previous start.  Otherwise each next pool starts
        end_start_intervals after the previous pool finished, so pools run
        one at a time.
        '''
        proc_name = current_process().name
        try:
            numconnpools = int(round(pred().predictVal(cmodel['num_conn_pools'])))
            overlap = cmodel['has_overlapping_conn_pools']
            last_start = time.time()
            last_end = last_start
            logger.debug('%s: will have %d conn pools',
                         proc_name, numconnpools)
            conn_pool_procs = []
            for i in range(numconnpools):
                if i == 0:
                    connstart = pred().predictVal(cmodel['first_conn_start_time'])
                    starttime = usess_start_time_epoch + connstart
                elif overlap:
                    waitinterval = pred().predictVal(cmodel['start_start_intervals'])
                    starttime = last_start + waitinterval
                else:
                    waitinterval = pred().predictVal(cmodel['end_start_intervals'])
                    starttime = last_end + waitinterval
                last_start = starttime
                child_name = proc_name + '-cpool' + str(i)
                logger.debug('%s: fake sleeping for %f before starting %s',
                             proc_name, starttime - 0.01 - time.time(), child_name)
                fake_sleep_until(starttime - 0.2)
                conn_pool_proc = Process(target=self.start_cli_conn_pool,
                                         args=(usess, cmodel, i),
                                         name=child_name)
                conn_pool_proc.daemon = False
                conn_pool_proc.start()
                conn_pool_procs.append(conn_pool_proc)
                if not overlap:
                    conn_pool_proc.join()
                last_end = time.time()
            for proci in conn_pool_procs:
                proci.join()
        except KeyboardInterrupt:
            logger.debug('%s: keyboard interrupt revd', proc_name)
        logger.debug('%s: exiting', proc_name)

    def start_cli_conn_pool(self, usess, cmodel, cid):
        '''Run one connection pool.

        Opens the pool's sockets and sends its request bursts round-robin
        over them.  Then it tells every socket process to close the
        conversation and waits for them.

        :param cid: index of this pool within its connection type (not used
            in the body).
        '''
        proc_name = current_process().name
        try:
            logger.debug('%s: started process %d for new connection pool',
                         proc_name, os.getpid())
            num_conns_in_pool = int(round(pred().predictVal(cmodel['size_conn_pools'])))
            logger.debug('%s: has pool of %d connections.. initiating',
                         proc_name, num_conns_in_pool)
            target_ip_port = (usess['srv_ip'], cmodel['port_number'])
            sockproc_list = self.make_sock_procs(num_conns_in_pool, cmodel, proc_name, target_ip_port)

            num_req_bursts = int(round(pred().predictVal(cmodel['num_req_burst_per_conn_pool'])))
            logger.debug('%s: connpool will have %d request bursts',
                         proc_name, num_req_bursts)
            nreq = 0  # requests sent so far; drives the round-robin
            for nburst in range(num_req_bursts):
                try:
                    b_start_time = time.time()
                    num_req_in_burst = int(round(pred().predictVal(cmodel['num_req_in_req_burst'])))
                    logger.debug('%s: request burst %d will have %d requests',
                                 proc_name, nburst, num_req_in_burst)
                    if not sockproc_list and num_req_in_burst > 0:
                        # No connections to send on; `nreq % 0` below would
                        # raise ZeroDivisionError.
                        logger.warning('%s: pool has no connections, skipping %d requests of burst %d',
                                       proc_name, num_req_in_burst, nburst)
                        num_req_in_burst = 0
                    for i in range(num_req_in_burst):
                        logger.debug('%s: sending req pdu %d-%d', proc_name, nburst, i)
                        req_size = int(round(pred().predictVal(cmodel['request_sizes'])))
                        num_resps = int(round(pred().predictVal(cmodel['num_response_per_req'])))
                        resp_details = self.make_resp_details(num_resps, cmodel, proc_name)
                        logger.debug('%s: calculated resp details for req pdu %d-%d is %s ',
                                     proc_name, nburst, i, resp_details)
                        socki = nreq % len(sockproc_list)
                        pipex = sockproc_list[socki][1]
                        dat = ':'.join([str(req_size), str(num_resps), resp_details])
                        logger.debug('%s: dat is = %s', proc_name, dat)
                        logger.debug('%s: sending req pdu %d-%d to pipe of sock %d of %d',
                                     proc_name, nburst, i, socki, len(sockproc_list))
                        pipex.send([req_size, num_resps, resp_details])
                        logger.debug('%s: sent req pdu %d-%d to pipe of sock %d of %d',
                                     proc_name, nburst, i, socki, len(sockproc_list))
                        nreq += 1
                    inter_burst_time = pred().predictVal(cmodel['inter_req_burst_time'])
                    tx = inter_burst_time + b_start_time
                    logger.debug('%s: fake sleeping for %d, before next req_burst ',
                                 proc_name, tx - 0.01 - time.time())
                    fake_sleep_until(tx - 0.2)
                except Exception as e:
                    logger.error(e, exc_info=True)
                    raise
            close_wait = pred().predictVal(cmodel['close_wait'])
            # time.sleep() rejects negative values; the model may predict one
            # (assumption about predictVal's range -- TODO confirm).
            time.sleep(max(0, min(5, close_wait)))
            for _sock, pipex, _sockproc in sockproc_list:
                pipex.send([0, 0, 'end_conn'])
            time.sleep(2)
            for _sock, _pipex, sockproc in sockproc_list:
                sockproc.join()
        except KeyboardInterrupt:
            logger.debug('%s: keyboard interrupt revd', proc_name)
        logger.debug('%s: exiting', proc_name)

    def make_sock_procs(self, num_conns_in_pool, cmodel, proc_name, target_ip_port):
        '''Open *num_conns_in_pool* sockets and start one sock_proc per socket.

        TCP sockets are connected to *target_ip_port* here; UDP sockets are
        left unconnected.  Returns a list of (sock, parent_pipe_end, process)
        tuples.  Requests reach a socket's process through its pipe.  The
        parent's socket copy is already closed when this returns.
        '''
        sockproc_list = []
        for i in range(num_conns_in_pool):
            if cmodel['l4_proto'] == 'tcp':
                logger.debug('%s: starting tcp conn %d', proc_name, i)
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.connect(target_ip_port)
            else:
                logger.debug('%s: starting udp %d', proc_name, i)
                sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            par_conn, child_conn = Pipe()
            sockproc = Process(target=self.sock_proc,
                               args=(cmodel, sock, child_conn, target_ip_port),
                               name=proc_name + '-sk' + str(i))
            sockproc.daemon = False
            sockproc.start()
            # The child process owns the socket from here on.  Closing the
            # parent's copy frees its descriptor and lets the child's close()
            # actually terminate the TCP connection.
            sock.close()
            sockproc_list.append((sock, par_conn, sockproc))
        logger.debug('%s: connections started', proc_name)
        return sockproc_list

    def make_resp_details(self, num_resps, cmodel, proc_name):
        '''Build the response header "<n>*<size_1>*<time_1>*...*".

        Sizes come from cmodel['response_sizes'] (rounded to int).  Server
        times come from cmodel['server_processing_times'] (rounded to 4
        decimals).
        '''
        fields = [str(num_resps)]
        for _ in range(num_resps):
            resp_size = int(round(pred().predictVal(cmodel['response_sizes'])))
            srv_time = pred().predictVal(cmodel['server_processing_times'])
            srv_time = float("{0:.4f}".format(srv_time))
            fields.extend([str(resp_size), str(srv_time)])
            logger.debug('%s: resp_size = %d, resp_time = %f',
                         proc_name, resp_size, srv_time)
        resp_details = '*'.join(fields) + '*'
        logger.debug('%s: resp_details = %s',
                     proc_name, resp_details)
        return resp_details

    def sock_proc(self, cmodel, socki, child_conn, target_ip_port):
        '''Per-socket worker process.

        For every [req_size, num_resps, resp_details] message received on
        *child_conn*, it sends the request and waits for the reply.  When
        resp_details is 'end_conn', it sends the close PDU and stops.  The
        socket is closed on exit.
        '''
        proc_name = current_process().name
        try:
            logger.debug('%s sock process started awaiting pkt details from pipe', proc_name)
            i = 0  # messages received so far
            while True:
                req_size, num_resps, resp_details = child_conn.recv()
                logger.debug('%s pdu details %d received from pipe', proc_name, i)
                l4_proto = cmodel['l4_proto']
                self.send_request(req_size, resp_details, socki, l4_proto, target_ip_port, proc_name)
                self.recv_response(socki, num_resps, proc_name)
                if resp_details == 'end_conn':
                    self.send_close_conn(socki, l4_proto, target_ip_port, proc_name)
                    logger.debug('%s end conn %d received from pipe', proc_name, i)
                    break
                i += 1
        except EOFError:
            logger.error('%s: parent closed pipe', proc_name)
        except KeyboardInterrupt:
            logger.debug('%s: keyboard interrupt revd', proc_name)
        except Exception as e:
            logger.error(e, exc_info=True)
            raise
        finally:
            # No-op if send_close_conn already closed it.
            socki.close()
        logger.debug('%s: exiting', proc_name)

    def send_request(self, req_size, resp_details, soc, l4_proto, target_ip_port, proc_name):
        '''Send one request PDU of about *req_size* bytes.

        The PDU starts with *resp_details* and is padded with filler up to
        req_size.  Its last character is replaced by PDU_DELIM, or PDU_DELIM
        is appended when it ends in '*'.  It goes out in chunks of at most
        MAXPDULEN bytes (one datagram each for UDP).  Nothing is sent when
        req_size <= 0.
        '''
        num_chunks = int(math.ceil(float(req_size) / MAXPDULEN))
        bytes_to_send = req_size
        for i in range(num_chunks):
            is_first = i == 0
            is_last = i == num_chunks - 1
            if is_first:
                if bytes_to_send <= len(resp_details):
                    chunk = resp_details
                else:
                    chunk = ''.join([resp_details,
                                     DATA65507[len(resp_details):bytes_to_send]])
            elif not is_last:
                # Middle chunks are pure filler; the prebuilt buffer avoids
                # building a new string per chunk.
                chunk = MDATA65507[:MAXPDULEN]
            else:
                chunk = DATA65507[:bytes_to_send]
            if is_last:
                if chunk[-1] == '*':
                    chunk = ''.join([chunk, PDU_DELIM])
                else:
                    chunk = ''.join([chunk[:-LENPDUDELIM], PDU_DELIM])
            logger.debug('%s: sending chunk: %s....',
                         proc_name, chunk[:50])
            payload = self._as_bytes(chunk)
            if l4_proto == 'tcp':
                soc.sendall(payload)
            else:
                soc.sendto(payload, target_ip_port)
            bytes_to_send -= MAXPDULEN
        logger.debug('%s: request sent ', proc_name)

    def recv_response(self, soc, num_resps, proc_name):
        '''Block until one response PDU (ending in PDU_DELIM) has arrived.

        Only one PDU is awaited even when num_resps > 1, because back-to-back
        replies are received combined.  Nothing is read when num_resps <= 0.
        Returns early if the peer closes or an empty datagram arrives.
        '''
        r_pdu = ''
        for q in range(min(1, num_resps)):
            logger.debug('%s: waiting for reply %d', proc_name, q)
            while True:
                # 65536 exceeds any UDP datagram.  A smaller buffer silently
                # truncates a datagram and drops its trailing PDU_DELIM,
                # which would block this loop forever.
                data = soc.recv(65536)
                if not data:
                    logger.debug('%s: connection closed before reply %d completed', proc_name, q)
                    return
                r_pdu += self._as_text(data)
                if r_pdu[-LENPDUDELIM:] == PDU_DELIM:
                    r_pdu = r_pdu[:-LENPDUDELIM]
                    break
            logger.debug('%s: reply %d of length-%d received',
                         proc_name, q, len(r_pdu))

    def send_close_conn(self, soc, l4_proto, target_ip_port, proc_name):
        '''Send the end-of-conversation PDU ("0*" + CONV_DELIM + PDU_DELIM),
        then close *soc*.'''
        chunk = self._as_bytes(''.join(['0*', CONV_DELIM, PDU_DELIM]))
        if l4_proto == 'tcp':
            soc.sendall(chunk)
        else:
            soc.sendto(chunk, target_ip_port)
        soc.close()
