import sys
import usb
import time
import struct
import array
import math

class DeviceDescriptor(object) :
    """
    Identifies a USB device by vendor id and product id, and records the
    interface number that should be claimed once the device is opened.
    """

    def __init__(self, vendor_id, product_id, interface_id) :
        self.vendor_id = vendor_id
        self.product_id = product_id
        self.interface_id = interface_id

    def getDevice(self) :
        """
        Scan every USB bus and return the first device whose idVendor and
        idProduct match this descriptor, or None when nothing matches.
        The device is returned as-is: it is neither opened nor claimed.
        """
        for bus in usb.busses() :
            for candidate in bus.devices :
                if candidate.idVendor == self.vendor_id and \
                   candidate.idProduct == self.product_id :
                    return candidate
        return None

class PlugUSBDevice(object) :
    """
    Connection to the Plug over USB through the legacy PyUSB handle API
    (open / claimInterface / bulkWrite / bulkRead / releaseInterface).
    """

    PLUG_VENDOR_ID = 0x03EB
    PLUG_PRODUCT_ID = 0x6124
    PLUG_INTERFACE_ID = 0
    PLUG_BULK_IN_EP = 2
    PLUG_BULK_OUT_EP = 1
    # Timeout handed to every bulk transfer (milliseconds in the legacy
    # PyUSB API).
    PLUG_TIMEOUT = 200

    def __init__(self) :
        self.device_descriptor = DeviceDescriptor(PlugUSBDevice.PLUG_VENDOR_ID,
                                                  PlugUSBDevice.PLUG_PRODUCT_ID,
                                                  PlugUSBDevice.PLUG_INTERFACE_ID)
        # May be None if the Plug is not attached yet; open() looks it up again.
        self.device = self.device_descriptor.getDevice()
        self.handle = None

    def open(self) :
        """
        Locate the Plug, open it and claim its interface.

        Raises IOError when no device with the Plug's vendor/product id is
        attached (previously this surfaced as an AttributeError on None).
        """
        self.device = self.device_descriptor.getDevice()
        if self.device is None :
            raise IOError("Plug USB device (vendor 0x%04X, product 0x%04X) not found."
                          % (self.device_descriptor.vendor_id,
                             self.device_descriptor.product_id))
        self.handle = self.device.open()
        if sys.platform == 'darwin' :
            # XXX : For some reason, Mac OS X doesn't set the
            # configuration automatically like Linux does.
            self.handle.setConfiguration(1)
        self.handle.claimInterface(self.device_descriptor.interface_id)

    def close(self) :
        """
        Release the claimed interface.  Safe to call when the device was
        never opened or has already been closed.
        """
        if self.handle is None :
            return
        try :
            self.handle.releaseInterface()
        finally :
            # Forget the handle so a repeated close() is a no-op.
            self.handle = None

    def getDataPacket(self, bytesToGet) :
        """
        Ask the Plug for bytesToGet bytes and return what bulkRead yields.

        The request written to the bulk OUT endpoint is three bytes: a zero
        byte followed by bytesToGet as a little-endian 16-bit integer, so
        bytesToGet must lie in 0..0xFFFF (ValueError otherwise).
        """
        if not 0 <= bytesToGet <= 0xFFFF :
            raise ValueError("bytesToGet must fit in two bytes (0..65535), got %r"
                             % (bytesToGet,))
        request = chr(0) + chr(bytesToGet & 0xFF) + chr((bytesToGet >> 8) & 0xFF)
        self.handle.bulkWrite(PlugUSBDevice.PLUG_BULK_OUT_EP,
                              request,
                              PlugUSBDevice.PLUG_TIMEOUT)
        # XXX : Gah! Returns a tuple of longs.  Why doesn't it return
        # a string?
        return self.handle.bulkRead(PlugUSBDevice.PLUG_BULK_IN_EP,
                                    bytesToGet,
                                    PlugUSBDevice.PLUG_TIMEOUT)

class PlugSensors(object) :
    """
    Read, log and parse the Plug's sensor data packets.

    Packet layout assumed by this class:
      * numADCBytes bytes of ADC samples, bitsPerSample bits each, packed
        little-endian / LSB first, channelsPerScan consecutive samples
        per scan and scansPerDataPacket scans;
      * 1 byte: scans skipped;
      * 1 byte: number of ADC bytes actually filled (gives scans recorded);
      * 2 bytes: one vibration bit per scan.
    """

    # ADC channel offset within a scan -> sample name.
    # XXX : This mapping is hard coded for the Plug's 8-channel layout.
    ADC_CHANNELS = ('light', 'sound', 'voltage', 'expansion',
                    'current4', 'current2', 'current3', 'current1')

    # Order of the sample rows in a log file (shared by writer and parser).
    LOG_FIELDS = ('light', 'sound', 'vibration', 'voltage',
                  'current1', 'current2', 'current3', 'current4',
                  'expansion')

    def __init__(self,
                 bytesPerDataPacket=64,
                 bitsPerSample=10,
                 channelsPerScan=8,
                 scansPerDataPacket=6) :
        """
        Compute the packet layout and open the Plug.

        Raises ValueError if the sizes are not positive or the packet is
        too small for the ADC samples plus the four metadata bytes.
        """
        # Number of bytes the Plug returns in a sensors data packet.
        self.bytesPerDataPacket = bytesPerDataPacket

        # Resolution at which ADC samples inputs.
        self.bitsPerSample = bitsPerSample

        # Number of ADC channels sampled in a single pass.
        self.channelsPerScan = channelsPerScan

        # Number of times all ADC channels are sampled per packet.
        self.scansPerDataPacket = scansPerDataPacket

        # Kept for backward compatibility only; getSamples no longer uses
        # them.  'b'*n packs signed ints into a string, 'B'*n unpacks that
        # string as unsigned bytes.
        self.__unpack_format__ = 'B'*self.bytesPerDataPacket
        self.__pack_format__ = 'b'*self.bytesPerDataPacket

        if min(bitsPerSample, channelsPerScan, scansPerDataPacket) <= 0 :
            raise ValueError("bitsPerSample, channelsPerScan and "
                             "scansPerDataPacket must be positive.")

        # Byte offsets of everything not generated by the ADC.  '//' gives
        # the same result as Python 2's integer '/' and stays an int on
        # Python 3.
        adcBits = bitsPerSample*channelsPerScan*scansPerDataPacket
        self.numADCBytes = adcBits // 8
        self.skippedSamplesIndex = self.numADCBytes
        self.bytesUsedIndex = self.skippedSamplesIndex + 1
        self.vibrationIndex = self.skippedSamplesIndex + 2

        # Explicit check (asserts vanish under -O); also covers the
        # metadata bytes, not just the ADC bits.
        if self.vibrationIndex + 2 > self.bytesPerDataPacket :
            raise ValueError("A %d byte packet cannot hold %d bits of ADC "
                             "samples plus 4 bytes of metadata."
                             % (self.bytesPerDataPacket, adcBits))

        self.plug = PlugUSBDevice()
        self.plug.open()

    def _formatRecord(self, samples) :
        """
        Render one getSamples() result as a log-file record: a blank line,
        the time, scans recorded, scans skipped, then one tab-separated row
        per entry of LOG_FIELDS.  Rows with no samples become empty lines.
        """
        lines = ["",
                 "%f" % samples['time'],
                 "%d" % samples['scans_recorded'],
                 "%d" % samples['scans_skipped']]
        for name in self.LOG_FIELDS :
            lines.append("\t".join(["%d" % value for value in samples[name]]))
        return "\n".join(lines) + "\n"

    def logSamplesToFile(self, filename) :
        """
        Write a commented header and then one record per packet to
        filename until interrupted with <CTRL> + c.  The file is closed
        however the loop ends; errors other than KeyboardInterrupt
        propagate.
        """
        f = open(filename, 'w')
        try :
            print("To stop data collection, type <CTRL> + c.")
            f.write("# Plug data log.  Data format is:\n#\n")
            f.write("# current time in seconds\n")
            f.write("# scans recorded between the last time and the current time\n")
            f.write("# scans skipped between the last time and the current time\n")
            for name in self.LOG_FIELDS :
                f.write("# %s samples\n" % name)
            while True :
                # One write per record keeps an interrupt from splitting a
                # record in most cases; the parser tolerates it otherwise.
                f.write(self._formatRecord(self.getSamples()))
        except KeyboardInterrupt :
            print("You have successfully logged data.")
        finally :
            f.close()

    def parseSamplesFromFile(self, filename) :
        """
        Parse a log written by logSamplesToFile.

        Returns a dict mapping "seconds" (array of doubles, one timestamp
        per recorded scan) and each LOG_FIELDS name (array of unsigned
        shorts) to the concatenated samples.  A truncated final record, as
        left by an interrupted log, is ignored.
        """
        # Sensor data as unsigned shorts (minimum of two bytes each).
        sensors = {}
        for name in self.LOG_FIELDS :
            sensors[name] = array.array('H')
        # Time data as doubles.
        seconds = array.array('d')
        # time, scans recorded, scans skipped, then one row per field.
        linesPerRecord = 3 + len(self.LOG_FIELDS)

        f = open(filename, 'r')
        try :
            # Skip over the initial comments.
            line = f.readline()
            while line.startswith('#') :
                line = f.readline()
            line = f.readline() # Skip the blank line opening the first record.

            while line != '' :
                record = [line]
                for _ in range(linesPerRecord - 1) :
                    record.append(f.readline())
                # Every line the logger writes ends in a newline; anything
                # else means the log was cut off mid-record.
                if any(not text.endswith('\n') for text in record) :
                    break
                time_recorded = float(record[0])
                scans_recorded = int(record[1])
                int(record[2]) # Scans skipped: validated but not returned.
                seconds.extend([time_recorded] * scans_recorded)
                for name, row in zip(self.LOG_FIELDS, record[3:]) :
                    for x in row.split() :
                        sensors[name].append(int(x))
                f.readline() # Skip blank line.
                line = f.readline()
        finally :
            f.close()

        result = {"seconds" : seconds}
        result.update(sensors)
        return result

    def getSamples(self) :
        """
        Fetch one packet from the Plug and decode it.

        Returns a dict with 'time' (time.time() after the read),
        'scans_skipped', 'scans_recorded', 'vibration' and one list per
        name in ADC_CHANNELS, each holding scans_recorded samples.
        Raises IOError if the Plug returns a short packet.
        """
        packet = self.plug.getDataPacket(self.bytesPerDataPacket)
        # The USB layer may report bytes as signed or unsigned values;
        # masking maps both onto 0..255 (struct.pack('b') rejected >127).
        data = [byte & 0xFF for byte in packet]
        if len(data) < self.bytesPerDataPacket :
            raise IOError("Expected a %d byte packet from the Plug, got %d bytes."
                          % (self.bytesPerDataPacket, len(data)))

        samples = {}
        # Get all metadata.
        samples['time'] = time.time()
        samples['scans_skipped'] = data[self.skippedSamplesIndex]
        scans = (data[self.bytesUsedIndex]*8
                 // (self.bitsPerSample*self.channelsPerScan))
        # A corrupt byte count must not index past the ADC data.
        scans = min(scans, self.scansPerDataPacket)
        samples['scans_recorded'] = scans

        # Unpack the two bytes of vibratab data (one bit per scan).
        vibrationBytes = data[self.vibrationIndex:self.vibrationIndex+2]
        samples['vibration'] = self.unpackBits(vibrationBytes, 1)[:scans]

        # Unpack ADC data; samples of one channel are channelsPerScan apart.
        adc = self.unpackBits(data[:self.numADCBytes], self.bitsPerSample)
        limit = scans*self.channelsPerScan
        for offset, name in enumerate(self.ADC_CHANNELS) :
            samples[name] = adc[offset:limit:self.channelsPerScan]
        return samples

    def unpackBits(self, s, bits) :
        """
        Unpack a sequence s of bytes into a list of numbers each of bits length.
        Assumes numbers are stored in little endian format and represent
        unsigned integers, and that every byte is in 0..255.  Returns
        len(s)*8 // bits values; leftover high bits are ignored.
        Raises ValueError if bits is not positive.
        """
        if bits <= 0 :
            raise ValueError("bits must be positive, got %r" % (bits,))
        # Treat s as one little-endian integer; value k then occupies bits
        # k*bits .. (k+1)*bits - 1 of it.
        accumulated = 0
        for position, byte in enumerate(s) :
            accumulated |= byte << (8*position)
        mask = (1 << bits) - 1
        return [int((accumulated >> (index*bits)) & mask)
                for index in range(len(s)*8 // bits)]

def main(argv=None) :
    """
    Command-line entry point.

    argv is the argument list without the program name (defaults to
    sys.argv[1:]).  Options:
      log [FILENAME]     log samples until <CTRL> + c (default PlugSensors.dat)
      parse [FILENAME]   parse a log file and return the resulting dict
      sensors            print ten decoded sample dictionaries
    Returns the parsed dict for 'parse', otherwise None.
    """
    # Always defined, so the usage message also works when argv is passed in.
    script_name = sys.argv[0]
    if argv is None :
        argv = sys.argv[1:]
    if len(argv) == 1 :
        option = argv[0]
        filename = "PlugSensors.dat"
    elif len(argv) == 2 :
        option = argv[0]
        filename = argv[1]
    else :
        option = None
        filename = None

    if option == 'log' :
        s = PlugSensors()
        try :
            s.logSamplesToFile(filename)
        finally :
            s.plug.close()
    elif option == 'parse' :
        # Parsing a log needs no hardware, so skip __init__ (which opens
        # the USB device and fails when the Plug is not attached).
        s = PlugSensors.__new__(PlugSensors)
        return s.parseSamplesFromFile(filename)
    elif option == 'sensors' :
        s = PlugSensors()
        try :
            for _ in range(10) :
                print(s.getSamples())
        finally :
            s.plug.close()
    else :
        print("Usage: python -i %s OPTION [FILENAME]" % script_name)
        print("  where OPTION can be 'parse', 'log' or 'sensors'")

if __name__ == "__main__" :
    # Meant to be run with 'python -i' so the parsed dict stays available
    # in the interactive session as 'data'.
    data = main()
    if isinstance(data, dict) :
        print("Parsed data is now available in the 'data' dictionary.")
        print("You can access the arrays of data looking at \"data['light']\", for example.")
        print("Use 'data.keys()' to list other options.")

