'''Client/server packet generator: replays modelled application user sessions as real TCP/UDP traffic.'''
import os
import sys
import socket
import time
import json
import math
import logging
import argparse
from multiprocessing import Process, Queue
from .gen_netutils import get_ipv4_addresses, fake_sleep_until
from .model_predict import ModelPredictor as pred

logger = logging.getLogger(__name__)

TCPTIMEOUT = 300
# 65507 = largest UDP payload over IPv4 (65535 - 8 UDP header - 20 IP header);
# used as the chunk size for requests and UDP responses.
MAXPDULEN = 65507
# Text filler made only of the letters a-h, so it can never contain a
# delimiter.  Sliced (not hand-padded) so its length always equals MAXPDULEN.
DATA65507 = ("abcdefgh" * (MAXPDULEN // 8 + 1))[:MAXPDULEN]  # data length 65507
DATA131072 = "abcdefgh" * (131072 // 8)  # data length 2^17
# Zero-copy byte buffers of b'a' for bulk sends.  The b'' literal works on
# Python 2 and 3, so no version switch is needed.  MDATA65507 must be exactly
# MAXPDULEN bytes long, otherwise one UDP datagram of it is rejected as too long.
MDATA65507 = memoryview(bytearray(b'a' * MAXPDULEN))  # data length 65507
MDATA131072 = memoryview(bytearray(b'a' * 131072))  # data length 2^17
# PDU_DELIM ends every request/response PDU; CONV_DELIM placed just before it
# marks the last request of a conversation.  Neither character may occur
# anywhere in the filler data above.
CONV_DELIM = 'y'
PDU_DELIM = 'x'
LENCONVPDUDELIM = len(CONV_DELIM) + len(PDU_DELIM)
LENPDUDELIM = len(PDU_DELIM)

def byteify(input_str):
    '''Recursively UTF-8-encode every str found in dicts and lists.

    Dict keys and values and list elements are converted; a str is
    returned as its UTF-8 encoding; anything else (numbers, None, tuples,
    bytes, ...) is returned unchanged.
    '''
    if isinstance(input_str, dict):
        return {byteify(key): byteify(value) for key, value in input_str.items()}
    if isinstance(input_str, list):
        return list(map(byteify, input_str))
    if isinstance(input_str, str):
        return input_str.encode('utf-8')
    return input_str

class GeneratorGeneric(object):
    '''Replays modelled user sessions as real TCP/UDP traffic.

    One instance drives one side of the traffic, chosen by ``client``:

    * server side: one listener process per (ip, port, l4 protocol) used by
      sessions whose 'srv_ip' is a local address, plus one conversation
      process per TCP connection / per UDP client address;
    * client side: one process per session whose 'cli_ip' is a local
      address and, inside it, one process per sequence of connection pools.

    Request PDU layout (client -> server), built by _request_chunks::

        <num_resp>*<size_1>*<time_1>*...*<size_n>*<time_n>*<filler>[CONV_DELIM]PDU_DELIM

    PDU_DELIM ends every request; CONV_DELIM precedes it on the last request
    of a connection.  Each response is <size> bytes of filler whose last
    byte is PDU_DELIM.  Payloads are built as text and converted to bytes
    only at the socket boundary (_as_bytes / _as_text) so the same code
    runs on Python 2 and 3.
    '''

    def __init__(self, allappmodels, allusessions, client, outdir):
        '''
        :param allappmodels: dict app name -> app model; each model has a
            'conn_models' dict of per-connection-type model parameters.
        :param allusessions: list of user-session dicts with 'srv_ip',
            'cli_ip', 'app' and 'start' (seconds after generate() is called).
        :param client: truthy to run the client side, falsy for the servers.
        :param outdir: stored as-is; not used by this class's methods.
        '''
        self.allappmodels = allappmodels
        self.allusessions = allusessions
        self.client = client
        self.outdir = outdir
        self.global_start = None  # set by generate(); origin of usess['start']

    # ------------------------------------------------------------ helpers

    @staticmethod
    def _as_bytes(payload):
        '''Return payload in a form socket send calls accept.

        Bytes-like objects pass through unchanged; text is encoded as
        latin-1 (all payloads are ASCII, so this is lossless).
        '''
        if isinstance(payload, (bytes, bytearray, memoryview)):
            return payload
        return payload.encode('latin-1')

    @staticmethod
    def _as_text(data):
        '''Return received data as native str (a no-op on Python 2).'''
        if isinstance(data, str):
            return data
        return data.decode('latin-1')

    @staticmethod
    def _local_ips():
        '''Return the set of IPv4 addresses of all local interfaces.'''
        return set(ip for ips in get_ipv4_addresses().values() for ip in ips)

    @staticmethod
    def _recv_pdu(sock, bufsize=4096):
        '''Read from sock until the received data ends with PDU_DELIM.

        :returns: the received text without its trailing PDU_DELIM, or None
            if recv returned no data (peer closed) before the delimiter.
        '''
        parts = []
        tail = ''
        while True:
            data = sock.recv(bufsize)
            if not data:
                return None
            data = GeneratorGeneric._as_text(data)
            parts.append(data)
            # Only the last len(PDU_DELIM) chars decide completeness; this
            # avoids re-joining a possibly large PDU after every recv.
            tail = (tail + data)[-len(PDU_DELIM):]
            if tail == PDU_DELIM:
                return ''.join(parts)[:-len(PDU_DELIM)]

    @staticmethod
    def _parse_request(req):
        '''Parse the response schedule at the start of a request.

        ``req`` (PDU_DELIM already stripped) starts with
        ``<num_resp>*<size_1>*<time_1>*...*<size_n>*<time_n>*``; the final
        '*'-separated field (filler and/or CONV_DELIM) is ignored.

        :returns: (num_resp, [(size as int, time as float), ...])
        :raises ValueError: if the count or a size/time is not numeric.
        :raises IndexError: if fewer than num_resp size/time pairs are present.
        '''
        fields = req.split('*')
        num_resp = int(float(fields[0]))
        pairs = fields[1:-1]
        responses = [(int(pairs[2 * j]), float(pairs[2 * j + 1]))
                     for j in range(num_resp)]
        return num_resp, responses

    @staticmethod
    def _response_pieces(pdu_size, block, filler):
        '''Yield the pieces of one response of pdu_size bytes.

        Whole copies of block are yielded while more than len(block) bytes
        remain, then a final piece cut from filler whose last byte(s) are
        replaced by PDU_DELIM (a 0-byte remainder still yields PDU_DELIM).
        '''
        bytes_to_send = pdu_size
        while bytes_to_send > len(block):
            yield block
            bytes_to_send -= len(block)
        last_d = filler[:bytes_to_send]
        yield last_d[:-len(PDU_DELIM)] + PDU_DELIM

    @staticmethod
    def _request_chunks(req_size, resp_details, last_pdu):
        '''Yield the pieces of one request PDU of about req_size bytes.

        The first piece starts with resp_details.  Every piece is at most
        MAXPDULEN bytes, so each one fits in a single UDP datagram (as long
        as resp_details itself fits).  The last piece ends with PDU_DELIM,
        preceded by CONV_DELIM when last_pdu is true.  A request too small
        to hold its header and delimiters is sent whole, so it may exceed
        req_size.
        '''
        #TODO NB problem if len(resp_details) > MAXPDULEN
        delim = CONV_DELIM + PDU_DELIM if last_pdu else PDU_DELIM
        middle = MDATA131072[:MAXPDULEN]  # zero-copy, exactly MAXPDULEN bytes
        bytes_to_send = max(req_size, 0)
        # Always send at least one chunk: the server needs the header and the
        # delimiter, and the client will wait for its reply.
        num_chunks = max(1, int(math.ceil(float(bytes_to_send) / MAXPDULEN)))
        for i in range(num_chunks):
            first = i == 0
            last = i == num_chunks - 1
            if first:
                chunk = resp_details + DATA65507[len(resp_details):bytes_to_send]
            elif last:
                chunk = DATA65507[:bytes_to_send]
            else:
                chunk = middle
            if last:
                # Overwrite the tail of the filler with the delimiter(s), but
                # never cut into the response header of the first chunk.
                keep = max(len(chunk) - len(delim),
                           len(resp_details) if first else 0)
                chunk = chunk[:keep] + delim
            yield chunk
            bytes_to_send -= MAXPDULEN

    # ------------------------------------------------------------ driver

    def generate(self):
        '''Record the common time origin, then run the client sessions or
        the server listeners depending on self.client.'''
        #TODO: Verify that opts.ip_addr is on this computer.
        self.global_start = time.time()
        if self.client:
            logger.info("calling start clients")
            self.startCliProcs()
        else:
            logger.info("calling start servers")
            self.startSrvProcs()
        logger.debug('completed packet generation process')

    # ------------------------------------------------------------ server side

    def startSrvProcs(self):
        '''Start a listener for every local server port, then block until
        the listeners exit or the user interrupts, terminating them on Ctrl-C.'''
        logger.debug("inside startSrvProcs")
        srvports = self.getallusess_srvports(self._local_ips())
        srvport_listeners = self.make_listeners(srvports)
        #TODO: check if last user sess?
        try:
            # Listeners loop forever, so this waits (without busy-spinning)
            # until interrupted.
            for proc in srvport_listeners.values():
                proc.join()
        except KeyboardInterrupt:
            for idx, proc in srvport_listeners.items():
                logger.debug('killing listener process %s, pid %d',
                             idx, proc.pid)
                proc.terminate()

    def getallusess_srvports(self, local_ips):
        '''return dict{ip&port,tcp/udp} for this local machine

        Keys are '<srv_ip>_<port>_<l4_proto>'; values are
        (srv_ip, port, l4_proto, app).  Sessions whose 'srv_ip' is not in
        local_ips are skipped; if several apps share a key the last wins.
        '''
        logger.debug('getting list of all server ports from usessions')
        srvports = {}
        for usess in self.allusessions:
            srvip = usess['srv_ip']
            if srvip not in local_ips:
                continue
            app = usess['app']
            for cmodel in self.allappmodels[app]['conn_models'].values():
                port = cmodel['port_number']
                l4_proto = cmodel['l4_proto']
                srvports[str(srvip)+'_'+str(port)+'_'+str(l4_proto)] = (srvip, port, l4_proto, app)
        return srvports

    def make_listeners(self, srv_ports):
        '''Start one non-daemon listener process per srv_ports entry.

        :param srv_ports: dict as returned by getallusess_srvports.
        :returns: dict '<proto>_<ip>:<port>_<app>' -> started Process.
            Entries with a protocol other than 'tcp'/'udp' are logged and
            skipped.
        '''
        #TODO: make start listener in per usses?
        targets = {'udp': self.genUDPSrv, 'tcp': self.genTCPSrv}
        listener_procs = {}
        for ip, port, proto, app in srv_ports.values():
            if proto not in targets:
                logger.warning('unsupported l4 protocol %s for app %s on %s:%s;'
                               ' no listener started', proto, app, ip, port)
                continue
            logger.debug('starting %s listener process for app %s on %s:%s',
                         proto, app, str(ip), str(port))
            srv_proc = Process(target=targets[proto], args=(ip, port, app))
            srv_proc.daemon = False
            srv_proc.start()
            logger.debug('started %s listener process for app %s on %s:%s',
                         proto, app, str(ip), str(port))
            listener_procs['%s_%s:%s_%s' % (proto, ip, port, app)] = srv_proc
        return listener_procs

    def genUDPSrv(self, ip, port, app):
        '''UDP listener: hand each client address's datagrams, through its own
        Queue, to one startUDPSrvConv process for that address.'''
        procs = {}
        ques = {}
        logger.debug('opening socket for UDP listener')
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.bind((ip, int(port)))
        logger.debug('UDP listener socket opened at %s:%s', ip, port)
        end_of_conv = CONV_DELIM + PDU_DELIM
        i = 0
        try:
            while True:
                i += 1
                logger.debug('listening - waiting for new UDP data %d on %s:%s' %(i,
                             ip, port))
                data, address = sock.recvfrom(65536)
                data = self._as_text(data)
                logger.debug('new data of len:%d received from ' %(len(data)) +
                             str(address))
                if address not in procs:
                    logger.debug('recognized new UDP conn from ' + str(address))
                    ques[address] = Queue()
                    procs[address] = Process(target=self.startUDPSrvConv,
                                             args=(ques[address], address, sock,
                                                   app, i-1))
                    procs[address].daemon = True
                    procs[address].start()
                logger.debug('putting data received in queue for subprocess to handle')
                ques[address].put(data)
                if data.endswith(end_of_conv) or len(data) == 0:
                    # The conversation ends with this request; a later datagram
                    # from the same address starts a fresh conversation process.
                    del procs[address]
                    del ques[address]
        except KeyboardInterrupt:
            for proc in procs.values():
                logger.debug('killing process %d' %(proc.pid))
                proc.terminate()

    def startUDPSrvConv(self, que, address, sock, app, cli_sess_idx):
        '''Serve one UDP client address: reassemble each request from the
        datagrams put on que, send the responses it asks for from sock, and
        stop after the request that ends with CONV_DELIM.'''
        pid = os.getpid()
        try:
            logger.debug('process_%d: starting new UDP server conversation for' %(pid) +
                         str(address))
            # Full-size datagram payload: exactly MAXPDULEN bytes of filler.
            udp_block = MDATA131072[:MAXPDULEN]
            i = 0
            while True:
                req = ''
                logger.debug('process_%d: waiting to get request - %d' %(pid, i))
                while not req.endswith(PDU_DELIM):
                    data = self._as_text(que.get())
                    logger.debug('process_%d: received portion of request-%d, length '
                                 '%d' %(pid, i, len(data)))
                    req += data
                req_size = len(req)
                req = req[:-len(PDU_DELIM)]
                logger.debug('process_%d: total length of request-%d received is '
                             '%d req[:20]: %s' %(pid, i, req_size, req[:20]))
                num_resp, responses = self._parse_request(req)
                logger.debug('process_%d: %d response PDUs will be sent' %(pid, num_resp))
                end_conv_now = req.endswith(CONV_DELIM)
                t_pdu_start = time.time()
                for j, (pdu_size, pdu_time) in enumerate(responses):
                    logger.debug('process_%d: response pdu %d of size %d will be sent' %(
                                 pid, j, pdu_size))
                    logger.debug('process_%d: fake sleeping for response time %f' %(pid,
                                 pdu_time))
                    fake_sleep_until(t_pdu_start + pdu_time)
                    for piece in self._response_pieces(pdu_size, udp_block, DATA65507):
                        sock.sendto(self._as_bytes(piece), address)
                    # NOTE(review): unlike startTCPSrvConv, each response time is
                    # measured from the previous send rather than from the
                    # request; confirm which the model expects.
                    t_pdu_start = time.time()
                    logger.debug('process_%d: sent response pdu %d of size %d to client.' %(
                                 pid, j, pdu_size))
                if end_conv_now:
                    logger.debug('process_%d: received end of conversation delimiter'
                                 '... ending conv.' %(pid))
                    break
                i += 1
            logger.debug('process_%d: ending process' %(pid))
        except Exception:
            logger.exception("process %d: something bad happened", pid)

    def genTCPSrv(self, ip, port, app):
        '''TCP listener: accept connections forever, serving each one in its
        own daemon startTCPSrvConv process.'''
        logger.debug('opening socket for TCP listener')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.bind((ip, int(port)))
        sock.listen(128)
        logger.debug('TCP listener socket opened at %s:%s' %(ip, port))
        i = 0
        while True:
            i += 1
            logger.debug('listening - waiting for a new TCP connection %d on '
                         '%s:%s' %(i, ip, port))
            conn, address = sock.accept()
            logger.debug('new connection received from ' + str(address))
            proc = Process(target=self.startTCPSrvConv, args=(address, conn, app,
                                                            i-1))
            proc.daemon = True
            proc.start()
            # The child now holds its own handle on the connection.  Drop the
            # listener's copy so fds don't leak and the connection really
            # closes when the conversation process closes it.
            conn.close()

    def startTCPSrvConv(self, address, conn, app, cli_sess_idx):
        '''Serve one TCP connection: for each request, send the responses it
        asks for, until the client sends CONV_DELIM or closes the connection.
        The connection is always closed on exit.'''
        pid = os.getpid()
        try:
            logger.debug('process_%d: starting new TCP server conversation for' %(pid) +
                         str(address))
            i = 0
            while True:
                logger.debug('process_%d: waiting to get request - %d' %(pid, i))
                req = self._recv_pdu(conn, 4096)
                if req is None:
                    logger.debug('process_%d: connection closed by client' %(pid))
                    break
                if not req:
                    # bare delimiter: nothing to answer, treat as end of conversation
                    logger.debug('process_%d: conversation closed by client' %(pid))
                    break
                logger.debug('process_%d: request %d received is %s',
                             pid, i, req[:50])
                logger.debug('process_%d: details of request %d: %s',
                             pid, i, str(req.split('*')[:-1]))
                num_resp, responses = self._parse_request(req)
                logger.debug('process_%d: %d response PDUs will be sent' %(pid, num_resp))
                end_this_conv = req.endswith(CONV_DELIM)
                if end_this_conv:
                    logger.debug('process_%d: conversation closed by client' %(pid))
                # every response time is measured from the end of the request
                t_pdu_start = time.time()
                for j, (pdu_size, pdu_time) in enumerate(responses):
                    logger.debug('process_%d: size of response pdu %d is %d' %(pid, j, pdu_size))
                    logger.debug('process_%d: fake sleeping for response time %f' %(pid,
                                 pdu_time))
                    fake_sleep_until(t_pdu_start + pdu_time)
                    for piece in self._response_pieces(pdu_size, MDATA131072, DATA131072):
                        conn.sendall(self._as_bytes(piece))
                    logger.debug('process_%d: sent response pdu %d of size %d to '
                                 'client.' %(pid, j, pdu_size))
                if end_this_conv:
                    break
                i += 1
        except Exception:
            logger.exception("process %d: something bad happened", pid)
        finally:
            conn.close()

    # ------------------------------------------------------------ client side

    def startCliProcs(self):
        '''creates a process to handle each user session

        Sessions whose 'cli_ip' is local are launched in list order, each
        after sleeping until its start time (global_start + 'start'); then
        all session processes are joined.
        '''
        local_ips = self._local_ips()
        usess_procs = []
        for i, usess in enumerate(self.allusessions):
            if usess['cli_ip'] not in local_ips:
                continue
            logger.debug("starting client process for usersess %d app: %s", i, usess['app'])
            start_at = self.global_start + usess['start']
            logger.debug("sleeping for  %d secs before usess %d", start_at - time.time(), i)
            fake_sleep_until(start_at)
            usess_proc = Process(target=self.start_cli_usess, args=(usess, i))
            usess_proc.daemon = False
            usess_proc.start()
            usess_procs.append(usess_proc)

        for i, proc in enumerate(usess_procs):
            proc.join()
            logger.debug('startCliProcs: usess process %d is joined', i)

    def start_cli_usess(self, usess, usess_idx):
        '''Plan all connection pools of one user session, then run each
        sequence of pools in its own process (started at the sequence's
        first start time) and wait for them all.'''
        logger.debug('usess %d: started usersession %d', usess_idx, usess_idx)
        appmodel = self.allappmodels[usess['app']]
        conns_details = self._plan_conn_sequences(usess, appmodel, usess_idx)
        logger.debug('usess %d: %d sequences of connection pools found',
                     usess_idx, len(conns_details))
        for i, conn_seq in enumerate(conns_details):
            logger.debug('usess %d: conn seq %d : %s', usess_idx, i, str(conn_seq))
        cli_procs = []
        for i, seq_connpools in enumerate(conns_details):
            logger.debug('usess %d: conn-sequence %d has: %s',
                         usess_idx, i, str(seq_connpools))
            # sequences are sorted by start time, so sleeping here launches
            # them in schedule order
            if seq_connpools[0]['starttime'] != -1:
                fake_sleep_until(seq_connpools[0]['starttime'])
            cli_proc = Process(target=self.create_seq_conpools,
                               args=(seq_connpools, appmodel, usess_idx, i))
            cli_proc.daemon = False
            cli_proc.start()
            cli_procs.append(cli_proc)
        for proc in cli_procs:
            p_id = proc.pid
            proc.join()
            logger.debug('usess %d: process %d is joined',
                         usess_idx, p_id)

    def _plan_conn_sequences(self, usess, appmodel, usess_idx):
        '''Sample the connection pools of one user session.

        :returns: list of sequences (lists) of connection-pool dicts, sorted
            by the 'starttime' of each sequence's first pool.  With
            overlapping pools every pool is its own sequence with an absolute
            'starttime'.  Otherwise all pools of a conn type form one
            sequence: only the first has a 'starttime'; each later one has
            starttime -1 and starts 'waitinterval' seconds after the previous
            pool of the sequence has finished (see _run_connpool).
        '''
        usess_gstart_time = self.global_start + usess['start']
        conns_details = []
        for conntype, cmodel in appmodel['conn_models'].items():
            numconnpools = int(round(pred().predictVal(cmodel['num_conn_pools'])))
            overlap = cmodel['has_overlapping_conn_pools']
            logger.debug('usess %d: will have %d conn pools of conn type %s',
                         usess_idx, numconnpools, str(conntype))
            c_arr = []
            starttime = -1
            for i in range(numconnpools):
                if i == 0:
                    c_arr = []
                    connstart = pred().predictVal(cmodel['first_conn_start_time'])
                    starttime = usess_gstart_time + connstart
                    waitinterval = 0
                elif overlap:
                    c_arr = []
                    startinterval = pred().predictVal(cmodel['start_start_intervals'])
                    starttime = starttime + startinterval
                    waitinterval = 0
                else:
                    starttime = -1
                    waitinterval = pred().predictVal(cmodel['end_start_intervals'])
                poolsize = int(round(pred().predictVal(cmodel['size_conn_pools'])))
                c_arr.append({'srv_ip': usess['srv_ip'], 'cli_ip': usess['cli_ip'],
                              'app': usess['app'],
                              'isoverlap': overlap, 'starttime': starttime,
                              'waitinterval': waitinterval,
                              'srvportnumber': cmodel['port_number'],
                              'l4_proto': cmodel['l4_proto'],
                              'cpoolsize': poolsize, 'conntype': conntype})
                if overlap or i == numconnpools - 1:
                    conns_details.append(c_arr)
        conns_details.sort(key=lambda seq: seq[0]['starttime'])
        return conns_details

    def create_seq_conpools(self, seq_connpools, appmodel, usess_idx, cid):
        '''Run one sequence of connection pools, one pool after another.

        Runs as a child-process target, so errors are logged rather than
        raised.  ``cid`` (index of the sequence in its session) is unused.
        '''
        pid = os.getpid()
        try:
            logger.debug('usess %d: started process %d for new connection seq',
                         usess_idx, pid)
            for connpool in seq_connpools:
                self._run_connpool(connpool, appmodel, usess_idx, pid)
            logger.debug('usess %d: proc %d: completed all conn pools in seq',
                         usess_idx, pid)
        except Exception:
            logger.exception("usess %d: proc %d: something bad happened", usess_idx, pid)

    def _run_connpool(self, connpool, appmodel, usess_idx, pid):
        '''Wait for the pool's start, open its sockets, exchange the sampled
        number of requests on all of them, wait close_wait (clamped to
        0..5 s) and close every socket, even on error.'''
        conntype = connpool['conntype']
        cmodel = appmodel['conn_models'][conntype]
        if connpool['starttime'] != -1:
            fake_sleep_until(connpool['starttime'])
        fake_sleep_until(time.time() + connpool['waitinterval'])

        num_conns_in_pool = connpool['cpoolsize']
        is_tcp = connpool['l4_proto'] == 'tcp'
        target_ip_port = (connpool['srv_ip'], connpool['srvportnumber'])
        logger.debug('usess %d: proc %d: starting new conn_pool type: %s with conn_pool size: %d',
                     usess_idx, pid, str(conntype), num_conns_in_pool)
        sock_list = []
        try:
            for i in range(num_conns_in_pool):
                logger.debug('usess %d: proc %d: initiating conn', usess_idx, pid)
                if is_tcp:
                    logger.debug('usess %d: proc %d: starting tcp %d',
                                 usess_idx, pid, i)
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock_list.append(sock)  # registered first so a failed connect is closed too
                    sock.connect(target_ip_port)
                else:
                    logger.debug('usess %d: proc %d: starting udp %d',
                                 usess_idx, pid, i)
                    sock_list.append(socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
            logger.debug('usess %d: proc %d: %d connections started for pool',
                         usess_idx, pid, num_conns_in_pool)
            #TODO make allowance for multi sections of conns here
            #each section, pre post etc will have its num_req, reqsize,
            #resp_size, srv_time...
            num_reqs = int(round(pred().predictVal(cmodel['num_requests'])))
            logger.debug('usess %d: proc %d: connections will have %d request pdus',
                         usess_idx, pid, num_reqs)
            for npdu in range(1, num_reqs + 1):
                self._request_response(sock_list, target_ip_port, is_tcp, cmodel,
                                       npdu, num_reqs, usess_idx, pid)
            #TODO: sleep until keep_alive expires
            close_wait = pred().predictVal(cmodel['close_wait'])
            logger.debug('usess %d: proc %d: conn pool completed close waiting %d secs, before next cpool in conn sequence',
                         usess_idx, pid, close_wait)
            # time.sleep rejects negative values; sampled waits may be < 0
            time.sleep(max(0, min(5, close_wait)))
        finally:
            for sock in sock_list:
                sock.close()

    def _response_schedule(self, cmodel, num_resps, usess_idx, pid):
        '''Sample num_resps (response size, server time) pairs and encode them
        as the request header '<num_resps>*<size>*<time>*...*'.'''
        fields = [str(num_resps)]
        for _ in range(num_resps):
            resp_size = abs(int(round(pred().predictVal(cmodel['response_sizes']))))
            srv_time = pred().predictVal(cmodel['server_processing_times'])
            srv_time = float("{0:.4f}".format(srv_time))
            fields.append(str(resp_size))
            fields.append(str(srv_time))
            logger.debug('usess %d: proc %d: resp_size = %d, resp_time = %f',
                         usess_idx, pid, resp_size, srv_time)
        resp_details = '*'.join(fields) + '*'
        logger.debug('usess %d: proc %d: resp_details = %s',
                     usess_idx, pid, resp_details)
        return resp_details

    def _request_response(self, sock_list, target_ip_port, is_tcp, cmodel,
                          npdu, num_reqs, usess_idx, pid):
        '''Send request npdu (1-based, of num_reqs) on every socket of a pool,
        wait for one reply per socket, then sleep until the sampled
        inter-request time has elapsed since the request started.'''
        t_pdu_start = time.time()
        logger.debug('usess %d: proc %d: sending pdu %d', usess_idx, pid, npdu)
        ipt = pred().predictVal(cmodel['inter_request_times'])
        req_size = int(round(pred().predictVal(cmodel['request_sizes'])))
        num_resps = int(round(pred().predictVal(cmodel['num_response_per_req'])))
        resp_details = self._response_schedule(cmodel, num_resps, usess_idx, pid)
        for chunk in self._request_chunks(req_size, resp_details, npdu == num_reqs):
            logger.debug('usess %d: proc %d: sending chunk: %s .... %s',
                         usess_idx, pid, chunk[:50], chunk[-20:])
            payload = self._as_bytes(chunk)
            for sock in sock_list:
                if is_tcp:
                    sock.sendall(payload)
                else:
                    sock.sendto(payload, target_ip_port)
        # A UDP reply datagram can be up to MAXPDULEN bytes; a smaller recv
        # buffer would truncate it and lose the trailing PDU_DELIM.
        bufsize = 4096 if is_tcp else 65536
        #TODO: avoid sequential recv 4 multiconns without multiproc?
        for sock in sock_list:
            #TODO urgent? num_resp replaced with 1: multi replies are received as 1
            for q in range(min(1, num_resps)):
                logger.debug('usess %d: proc %d: waiting for reply %d',
                             usess_idx, pid, q)
                r_pdu = self._recv_pdu(sock, bufsize)
                if r_pdu is None:
                    raise socket.error('connection closed while waiting for reply')
                logger.debug('usess %d: proc %d: reply %d of length-%d received',
                             usess_idx, pid, q, len(r_pdu))
        logger.debug('usess %d: proc %d: fake sleeping until %f, (%f) before next pdu, ',
                     usess_idx, pid, t_pdu_start + ipt, ipt)
        fake_sleep_until(t_pdu_start + ipt)
