'''
@author: Juan C. Espinoza
'''

import time
import json
import numpy
import paho.mqtt.client as mqtt
import zmq
import datetime
from zmq.utils.monitor import recv_monitor_message
from functools import wraps
from threading import Thread
from multiprocessing import Process

from schainpy.model.proc.jroproc_base import Operation, ProcessingUnit
from schainpy.model.data.jrodata import JROData
from schainpy.utils import log

# Upper bounds on the number of samples kept along the x (MAXNUMX) and
# y (MAXNUMY) axes when PublishData.publish_data decimates arrays before
# sending them over MQTT.
MAXNUMX = 100
MAXNUMY = 100

class PrettyFloat(float):
    '''
    float subclass whose repr() always shows exactly two decimal places.
    '''

    def __repr__(self):
        return '{:.2f}'.format(self)

def roundFloats(obj):
    '''
    Recursively round every float in a (possibly nested) list to 2 decimals.

    Lists and tuples are rebuilt as plain lists, so the result is always
    JSON-serialisable on both Python 2 and Python 3. Floats (including
    numpy.float64, which subclasses float) are rounded with round(x, 2).
    Any other value (ints, strings, None, ...) is returned unchanged instead
    of being replaced by None.
    '''
    if isinstance(obj, (list, tuple)):
        # A list comprehension rather than map(): on Python 3 map() returns
        # a lazy iterator that json.dumps cannot serialise.
        return [roundFloats(item) for item in obj]
    if isinstance(obj, float):
        return round(obj, 2)
    return obj

def decimate(z, MAXNUMY):
    '''
    Subsample ``z`` along its second axis, keeping every row.

    The slice step is int(len(z[0]) / MAXNUMY) + 1. For example, rows of
    length 250 with MAXNUMY=100 keep every 3rd column. ``z`` must support
    two-axis slicing (e.g. a numpy array with ndim >= 2).
    '''
    n_cols = len(z[0])
    stride = 1 + int(n_cols / MAXNUMY)
    return z[:, ::stride]

class throttle(object):
    '''
    Decorator that drops calls made more often than once per time period.

    A call arriving before ``throttle_period`` has elapsed since the last
    executed call is *skipped*: the wrapper returns None without calling
    the function. It is not delayed.

    Passing ``coerce=True`` as a keyword argument forces the call through
    regardless of the period. The ``coerce`` keyword is consumed by the
    wrapper and never reaches the wrapped function.

    Example -- run ``foo`` at most once a minute:

        @throttle(minutes=1)
        def foo():
            pass

        for i in range(10):
            foo()          # only the first call actually runs foo

    NOTE: the last-call timestamp is stored on the throttle instance. Every
    function decorated by the same instance therefore shares one budget.
    '''

    def __init__(self, seconds=0, minutes=0, hours=0):
        self.throttle_period = datetime.timedelta(
            seconds=seconds, minutes=minutes, hours=hours
        )
        # datetime.min guarantees the very first call is never throttled.
        self.time_of_last_call = datetime.datetime.min

    def __call__(self, fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            coerce = kwargs.pop('coerce', None)
            now = datetime.datetime.now()
            if not coerce and now - self.time_of_last_call < self.throttle_period:
                # Too soon since the last executed call: drop this one.
                return None
            self.time_of_last_call = now
            return fn(*args, **kwargs)

        return wrapper

class Data(object):
    '''
    Object to hold data to be plotted.

    ``self.data`` maps each plot key to its stored values:

    * Keys containing 'spc' ('spc', 'cspc') hold only the latest ndarray.
    * Every other key holds a dict ``{utctime: ndarray}``, with one entry
      per received block.
    '''

    def __init__(self, plottypes, throttle_value):
        self.plottypes = plottypes
        self.throttle = throttle_value
        self.ended = False
        self.localtime = False
        # Also initialised here (not only in setup()) so that str() and the
        # ``in`` operator work on an object that has not been set up yet.
        self.data = {}
        self.__times = []
        self.__heights = []
        self.__all_heights = set()

    def __str__(self):
        dum = ['{}{}'.format(key, self.shape(key)) for key in self.data]
        return 'Data[{}][{}]'.format(';'.join(dum), len(self.__times))

    def __len__(self):
        return len(self.__times)

    def __getitem__(self, key):
        '''
        Return the stored data for ``key``.

        For 'spc'-like keys the latest array is returned as-is. Otherwise the
        per-time arrays are stacked in ascending time order. When the result
        has more than one dimension, axes 0 and 1 are swapped, so time
        becomes the second axis.
        '''
        if key not in self.data:
            raise KeyError(log.error('Missing key: {}'.format(key)))

        if 'spc' in key:
            ret = self.data[key]
        else:
            ret = numpy.array([self.data[key][x] for x in self.times])
            if ret.ndim > 1:
                ret = numpy.swapaxes(ret, 0, 1)
        return ret

    def __contains__(self, key):
        return key in self.data

    def setup(self):
        '''
        Reset the object: clear all stored times, heights and data.

        One empty entry is created per plot type. Any plot type containing
        'snr' (e.g. 'snr_db') is stored under the single key 'snr'.
        '''
        self.ended = False
        self.data = {}
        self.__times = []
        self.__heights = []
        self.__all_heights = set()
        for plot in self.plottypes:
            if 'snr' in plot:
                plot = 'snr'
            self.data[plot] = {}

    def shape(self, key):
        '''
        Get the shape of the one-element data for the given key.

        Returns (0,) when nothing has been stored for ``key`` yet.
        '''
        if len(self.data[key]):
            if 'spc' in key:
                return self.data[key].shape
            return self.data[key][self.__times[0]].shape
        return (0,)

    def update(self, dataOut):
        '''
        Store the products of a new dataOut block.

        The block is keyed by ``dataOut.utctime``. A block whose time was
        already stored is ignored.
        '''
        tm = dataOut.utctime
        if tm in self.__times:
            return

        self.parameters = getattr(dataOut, 'parameters', [])
        self.pairs = dataOut.pairsList
        self.channels = dataOut.channelList
        self.interval = dataOut.getTimeInterval()
        self.localtime = dataOut.useLocalTime
        if 'spc' in self.plottypes or 'cspc' in self.plottypes:
            self.xrange = (dataOut.getFreqRange(1)/1000., dataOut.getAcfRange(1), dataOut.getVelRange(1))
        self.__heights.append(dataOut.heightList)
        self.__all_heights.update(dataOut.heightList)
        self.__times.append(tm)

        for plot in self.plottypes:
            if plot == 'spc':
                z = dataOut.data_spc/dataOut.normFactor
                self.data[plot] = 10*numpy.log10(z)
            if plot == 'cspc':
                self.data[plot] = dataOut.data_cspc
            if plot == 'noise':
                self.data[plot][tm] = 10*numpy.log10(dataOut.getNoise()/dataOut.normFactor)
            if plot == 'rti':
                self.data[plot][tm] = dataOut.getPower()
            if plot == 'snr_db':
                self.data['snr'][tm] = dataOut.data_SNR
            if plot == 'snr':
                self.data[plot][tm] = 10*numpy.log10(dataOut.data_SNR)
            if plot == 'dop':
                self.data[plot][tm] = 10*numpy.log10(dataOut.data_DOP)
            if plot == 'mean':
                self.data[plot][tm] = dataOut.data_MEAN
            if plot == 'std':
                self.data[plot][tm] = dataOut.data_STD
            if plot == 'coh':
                self.data[plot][tm] = dataOut.getCoherence()
            if plot == 'phase':
                self.data[plot][tm] = dataOut.getCoherence(phase=True)
            if plot == 'output':
                self.data[plot][tm] = dataOut.data_output
            if plot == 'param':
                self.data[plot][tm] = dataOut.data_param

    def normalize_heights(self):
        '''
        Ensure same-dimension of the data for different heightList.

        Blocks recorded with fewer heights are embedded into a NaN-filled
        array spanning the sorted union of all heights seen since setup().
        Keys containing 'spc' are skipped: they hold a single array, not a
        per-time dict.
        '''
        H = numpy.array(sorted(self.__all_heights))
        for key in self.data:
            if 'spc' in key:
                continue
            shape = self.shape(key)[:-1] + H.shape
            for tm, obj in list(self.data[key].items()):
                h = self.__heights[self.__times.index(tm)]
                if H.size == h.size:
                    continue
                index = numpy.where(numpy.isin(H, h))[0]
                dummy = numpy.full(shape, numpy.nan)
                if len(shape) == 2:
                    dummy[:, index] = obj
                else:
                    dummy[index] = obj
                self.data[key][tm] = dummy

        self.__heights = [H for tm in self.__times]

    @staticmethod
    def _round_for_json(value):
        '''
        Round an array-like to 2 decimals and convert it to nested lists.

        Complex input is split into {'real': ..., 'imag': ...} because JSON
        has no complex type.
        '''
        arr = numpy.asarray(value)
        if numpy.iscomplexobj(arr):
            return {'real': numpy.round(arr.real, 2).tolist(),
                    'imag': numpy.round(arr.imag, 2).tolist()}
        return numpy.round(arr, 2).tolist()

    def jsonify(self, decimate=False):
        '''
        Convert the latest block of every product to a JSON string.

        The JSON object has one entry per plot key (values rounded to 2
        decimals), plus 'timestamp' (latest utctime) and 'interval'.
        Returns '{}' when no block has been stored since setup().

        ``decimate`` is accepted for interface compatibility but is
        currently ignored.
        '''
        if not self.__times:
            return json.dumps({})

        tm = self.times[-1]
        ret = {}
        for key, value in self.data.items():
            if isinstance(value, dict):
                # Time-indexed product: only the latest block is exported.
                if tm not in value:
                    continue
                value = value[tm]
            ret[key] = self._round_for_json(value)

        ret['timestamp'] = float(tm)
        ret['interval'] = self.interval
        return json.dumps(ret)

    @property
    def times(self):
        '''
        Return the list of times of the current data, sorted ascending.
        '''
        ret = numpy.array(self.__times)
        ret.sort()
        return ret

    @property
    def heights(self):
        '''
        Return the list of heights of the most recently stored block.
        '''
        return numpy.array(self.__heights[-1])

class PublishData(Operation):
    '''
    Operation to send data over zmq and/or MQTT.

    * ``zeromq=1`` (default): every dataOut is sent whole with
      ``send_pyobj`` over a zmq PUSH socket connected to ``server``.
    * ``mqtt=1``: a reduced, decimated JSON payload for ``plottype`` is
      published to the MQTT topic ``topic + plottype``.
    '''

    __attrs__ = ['host', 'port', 'delay', 'zeromq', 'mqtt', 'verbose']

    def __init__(self, **kwargs):
        """Create an unconfigured publisher; setup() runs on the first run()."""
        Operation.__init__(self, **kwargs)
        self.isConfig = False
        self.client = None
        self.zeromq = None
        self.mqtt = None

    def on_disconnect(self, client, userdata, rc):
        '''
        paho-mqtt ``on_disconnect`` callback.

        ``rc == 0`` means the disconnection was requested via disconnect()
        (e.g. from close()). Only other codes are unexpected and trigger a
        reconnection attempt. Reconnecting unconditionally would reopen the
        connection that close() just shut down.
        '''
        if rc != 0:
            log.warning('Unexpected disconnection.')
            self.connect()

    def connect(self):
        '''
        Connect the MQTT client to ``self.host:self.port`` and start its
        network loop thread.

        On any error the failure is logged and ``self.client`` is set to
        False. publish_data() and close() then skip MQTT.
        '''
        log.warning('trying to connect')
        try:
            self.client.connect(
                host=self.host,
                port=self.port,
                keepalive=60*10,
                bind_address='')
            self.client.loop_start()
        except Exception:
            log.error('MQTT Conection error.')
            self.client = False

    def setup(self, port=1883, username=None, password=None, clientId="user", zeromq=1, verbose=True, **kwargs):
        '''
        Configure the publisher from the operation parameters.

        Recognised ``kwargs``: topic, delay, plottype, host, mqtt, server.
        ``username`` and ``password`` are accepted but not used.
        '''
        self.counter = 0
        self.topic = kwargs.get('topic', 'schain')
        self.delay = kwargs.get('delay', 0)
        self.plottype = kwargs.get('plottype', 'spectra')
        self.host = kwargs.get('host', "10.10.10.82")
        # ``port`` is a named parameter, so it can never appear in kwargs.
        self.port = port
        self.clientId = clientId
        self.cnt = 0
        self.zeromq = zeromq
        self.mqtt = kwargs.get('mqtt', 0)
        self.client = None
        self.verbose = verbose
        if self.mqtt == 1:
            self.client = mqtt.Client(
                client_id=self.clientId + self.topic + 'SCHAIN',
                clean_session=True)
            self.client.on_disconnect = self.on_disconnect
            self.connect()
        if self.zeromq == 1:
            context = zmq.Context()
            self.zmq_socket = context.socket(zmq.PUSH)
            server = kwargs.get('server', 'zmq.pipe')

            if 'tcp://' in server:
                address = server
            else:
                address = 'ipc:///tmp/%s' % server

            self.zmq_socket.connect(address)
            # Short pause after connecting; presumably gives the PUSH socket
            # time to establish before the first send.
            time.sleep(1)

    @staticmethod
    def _stride(length, maxnum):
        '''Slice step that keeps at most ``maxnum`` of ``length`` samples.'''
        return int(length / maxnum) + 1

    def _mqtt_payload(self):
        '''Build the JSON-serialisable MQTT payload for ``self.plottype``.'''
        dataOut = self.dataOut
        yData = dataOut.heightList[:2].tolist()
        channels = ['Ch %s' % ch for ch in dataOut.channelList]
        # Index arrays by position, not by channel number: channel numbers
        # need not start at 0. Assumes the first data axis follows the
        # channelList order -- TODO confirm.
        nch = len(dataOut.channelList)

        if self.plottype == 'spectra':
            zdB = 10*numpy.log10(dataOut.data_spc/dataOut.normFactor)
            xlen, ylen = zdB[0].shape
            dx = self._stride(xlen, MAXNUMX)
            dy = self._stride(ylen, MAXNUMY)
            Z = [zdB[i][::dx, ::dy].tolist() for i in range(nch)]
            return {
                'timestamp': dataOut.utctime,
                'data': roundFloats(Z),
                'channels': channels,
                'interval': dataOut.getTimeInterval(),
                'type': self.plottype,
                'yData': yData
            }

        if self.plottype in ('rti', 'power'):
            z = dataOut.data_spc/dataOut.normFactor
            avgdB = 10*numpy.log10(numpy.average(z, axis=1))
            _, ylen = z[0].shape
            dy = self._stride(ylen, MAXNUMY)
            AVG = [avgdB[i][::dy].tolist() for i in range(nch)]
            return {
                'timestamp': dataOut.utctime,
                'data': roundFloats(AVG),
                'channels': channels,
                'interval': dataOut.getTimeInterval(),
                'type': self.plottype,
                'yData': yData
            }

        if self.plottype == 'noise':
            noisedB = 10*numpy.log10(dataOut.getNoise()/dataOut.normFactor)
            return {
                'timestamp': dataOut.utctime,
                'data': roundFloats(noisedB.reshape(-1, 1).tolist()),
                'channels': channels,
                'interval': dataOut.getTimeInterval(),
                'type': self.plottype,
                'yData': yData
            }

        if self.plottype == 'snr':
            data = dataOut.data_SNR
            avgdB = 10*numpy.log10(data)
            dy = self._stride(data[0].size, MAXNUMY)
            AVG = [avgdB[i][::dy].tolist() for i in range(nch)]
            return {
                'timestamp': dataOut.utctime,
                'data': roundFloats(AVG),
                'channels': channels,
                'type': self.plottype,
                'yData': yData
            }

        log.warning("Tipo de grafico invalido")
        return {
            'data': 'None',
            'timestamp': 'None',
            'type': None
        }

    def publish_data(self):
        '''Send the current dataOut over MQTT and/or zmq, as configured.'''
        self.dataOut.finished = False
        # self.client is False after a failed connect(): skip MQTT then.
        if self.mqtt == 1 and self.client:
            payload = self._mqtt_payload()
            self.client.publish(self.topic + self.plottype, json.dumps(payload), qos=0)

        if self.zeromq == 1:
            if self.verbose:
                log.log(
                    'Sending {} - {}'.format(self.dataOut.type, self.dataOut.datatime),
                    self.name
                )
            self.zmq_socket.send_pyobj(self.dataOut)

    def run(self, dataOut, **kwargs):
        '''Configure on the first call, then publish dataOut and wait ``delay`` s.'''
        self.dataOut = dataOut
        if not self.isConfig:
            self.setup(**kwargs)
            self.isConfig = True

        self.publish_data()
        time.sleep(self.delay)

    def close(self):
        '''Send a final ``finished`` dataOut over zmq and shut down the sockets.'''
        if self.zeromq == 1:
            self.dataOut.finished = True
            self.zmq_socket.send_pyobj(self.dataOut)
            time.sleep(0.1)
            self.zmq_socket.close()
        if self.client:
            # Disconnect while the network loop is still running so the
            # DISCONNECT packet goes out, then stop the loop thread.
            self.client.disconnect()
            self.client.loop_stop()


class ReceiverData(ProcessingUnit):
    '''
    Processing unit that pulls dataOut objects from a zmq PULL socket.

    ``server`` is used verbatim when it contains 'tcp://'. Otherwise it is
    treated as an ipc endpoint name under /tmp.
    '''

    __attrs__ = ['server']

    def __init__(self, **kwargs):

        ProcessingUnit.__init__(self, **kwargs)

        self.isConfig = False
        endpoint = kwargs.get('server', 'zmq.pipe')
        self.address = endpoint if 'tcp://' in endpoint else 'ipc:///tmp/%s' % endpoint
        self.dataOut = JROData()

    def setup(self):
        '''Create the zmq context, bind the PULL socket and log the address.'''
        self.context = zmq.Context()
        self.receiver = self.context.socket(zmq.PULL)
        self.receiver.bind(self.address)
        # short pause after binding, presumably to let the socket settle
        time.sleep(0.5)
        log.success('ReceiverData from {}'.format(self.address))

    def run(self):
        '''Block until one object arrives and store it as self.dataOut.'''
        if not self.isConfig:
            self.setup()
            self.isConfig = True

        # NOTE(review): recv_pyobj unpickles whatever arrives. With a tcp://
        # address, only bind on networks where every peer is trusted.
        received = self.receiver.recv_pyobj()
        self.dataOut = received
        message = '{} - {}'.format(received.type, received.datatime.ctime())
        log.log(message, 'Receiving')


class PlotterReceiver(ProcessingUnit, Process):
    '''
    Separate process that collects dataOut objects sent over zmq.

    It accumulates the received blocks in a Data buffer and forwards that
    buffer to the plotting side over a zmq PUB socket.
    '''

    throttle_value = 5
    __attrs__ = ['server', 'plottypes', 'realtime', 'localtime', 'throttle']

    def __init__(self, **kwargs):

        ProcessingUnit.__init__(self, **kwargs)
        Process.__init__(self)
        self.mp = False
        self.isConfig = False
        self.isWebConfig = False
        # Number of accepted publisher connections (see event_monitor).
        self.connections = 0
        server = kwargs.get('server', 'zmq.pipe')
        plot_server = kwargs.get('plot_server', 'zmq.web')
        if 'tcp://' in server:
            address = server
        else:
            address = 'ipc:///tmp/%s' % server

        if 'tcp://' in plot_server:
            plot_address = plot_server
        else:
            plot_address = 'ipc:///tmp/%s' % plot_server

        self.address = address
        self.plot_address = plot_address
        self.plottypes = [s.strip() for s in kwargs.get('plottypes', 'rti').split(',')]
        self.realtime = kwargs.get('realtime', False)
        self.localtime = kwargs.get('localtime', True)
        self.throttle_value = kwargs.get('throttle', 5)
        self.sendData = self.initThrottle(self.throttle_value)
        # Dates already seen; a new date starts a new Data buffer in run().
        self.dates = []
        self.setup()

    def setup(self):
        '''Create the Data buffer for the configured plot types.'''
        self.data = Data(self.plottypes, self.throttle_value)
        self.isConfig = True

    def event_monitor(self, monitor):
        '''
        Consume zmq socket-monitor events for the receiver socket.

        Each accepted connection increments ``self.connections``. Departures
        are accounted for in run(), when a publisher sends its final
        ``finished`` message. The loop ends on EVENT_MONITOR_STOPPED.
        '''
        while monitor.poll():
            evt = recv_monitor_message(monitor)
            if evt['event'] == zmq.EVENT_ACCEPTED:
                self.connections += 1
            if evt['event'] == zmq.EVENT_MONITOR_STOPPED:
                break
        monitor.close()
        print('event monitor thread done!')

    def initThrottle(self, throttle_value):
        '''Return a sender wrapper limited to one call per ``throttle_value`` s.'''

        @throttle(seconds=throttle_value)
        def sendDataThrottled(fn_sender, data):
            fn_sender(data)

        return sendDataThrottled

    def send(self, data):
        '''Publish ``data`` (pickled) on the plots PUB socket.'''
        log.success('Sending {}'.format(data), self.name)
        self.sender.send_pyobj(data)

    def run(self):
        '''
        Process entry point: receive dataOut objects forever and forward the
        accumulated Data buffer.

        Whenever the date of an incoming block changes, the current buffer
        is sent and a new one is started. In non-realtime mode, forwarding
        is throttled to once per ``throttle`` seconds, except right after a
        date change (``coerce=True``).
        '''
        log.success(
            'Starting from {}'.format(self.address),
            self.name
        )

        self.context = zmq.Context()
        self.receiver = self.context.socket(zmq.PULL)
        self.receiver.bind(self.address)
        monitor = self.receiver.get_monitor_socket()
        self.sender = self.context.socket(zmq.PUB)
        if self.realtime:
            self.sender_web = self.context.socket(zmq.PUB)
            self.sender_web.connect(self.plot_address)
            time.sleep(1)

        if 'server' in self.kwargs:
            self.sender.bind("ipc:///tmp/{}.plots".format(self.kwargs['server']))
        else:
            self.sender.bind("ipc:///tmp/zmq.plots")

        time.sleep(2)

        t = Thread(target=self.event_monitor, args=(monitor,))
        t.start()

        # Both are read below even when the first message carries no data
        # (flagNoData), so they must be bound before the loop.
        dt = None
        coerce = False

        while True:
            dataOut = self.receiver.recv_pyobj()
            if not dataOut.flagNoData:
                if dataOut.type == 'Parameters':
                    tm = dataOut.utctimeInit
                else:
                    tm = dataOut.utctime
                if dataOut.useLocalTime:
                    if not self.localtime:
                        tm += time.timezone
                    dt = datetime.datetime.fromtimestamp(tm).date()
                else:
                    if self.localtime:
                        tm -= time.timezone
                    dt = datetime.datetime.utcfromtimestamp(tm).date()
                coerce = False
                if dt not in self.dates:
                    # Data defines __len__: an empty buffer is falsy, so
                    # nothing is sent before the first block arrives.
                    if self.data:
                        self.data.ended = True
                        self.send(self.data)
                        coerce = True
                    self.data.setup()
                    self.dates.append(dt)

                self.data.update(dataOut)

            if dataOut.finished is True:
                self.connections -= 1
                if self.connections == 0 and dt in self.dates:
                    self.data.ended = True
                    self.send(self.data)
                    self.data.setup()
            else:
                if self.realtime:
                    self.send(self.data)
                else:
                    self.sendData(self.send, self.data, coerce=coerce)
                    coerce = False

    def sendToWeb(self):
        '''
        Publish, once, the kwargs of every operation that has a 'plot' key
        as JSON on a config PUB socket.

        For tcp:// addresses the config socket uses ``plot_address``'s port
        + 1; otherwise it uses ``plot_address + '.config'``.
        '''
        if not self.isWebConfig:
            context = zmq.Context()
            sender_web_config = context.socket(zmq.PUB)
            if 'tcp://' in self.plot_address:
                dum, address, port = self.plot_address.split(':')
                conf_address = '{}:{}:{}'.format(dum, address, int(port)+1)
            else:
                conf_address = self.plot_address + '.config'
            sender_web_config.bind(conf_address)
            time.sleep(1)
            for kwargs in self.operationKwargs.values():
                if 'plot' in kwargs:
                    log.success('[Sending] Config data to web for {}'.format(kwargs['code'].upper()))
                    sender_web_config.send_string(json.dumps(kwargs))
            self.isWebConfig = True
