#!/usr/bin/env python
# :noTabs=true:


# (c) Copyright Rosetta Commons Member Institutions.
# (c) This file is part of the Rosetta software suite and is made available under license.
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
# (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

## @file   PyMOLPyRosettaServer.py
## @brief
## @author Sergey Lyskov, Johns Hopkins University


import time, socket, gzip, bz2, threading
from cStringIO import StringIO
from array import array


import pymol

#from NetLink import PR_UDPServer
# ^^^ This import does not work under Cygwin PyMOL, so the server code is included directly below.

class PR_UDPServer:
    '''Receive PyRosetta UDP datagrams and reassemble multi-piece messages.

    Every datagram starts with a 22-byte header followed by the piece payload:
        bytes 0:18  - message id (shared by all pieces of one message)
        bytes 18:22 - two native-order unsigned shorts: piece index, piece count
    '''

    # Seconds between scans of the reassembly buffer.
    CLEANUP_INTERVAL = 2.0
    # Partial messages that have not received a piece for this many seconds are dropped.
    MAX_PIECE_AGE = 10.0

    def __init__(self, udp_ip = '127.0.0.1', udp_port=65000):
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.bind( (udp_ip, udp_port) )
        # message id -> [time the last piece arrived, {piece index: payload}]
        self.buf = {}
        self.last_cleanup_time = time.time()

    def listen(self):
        '''Block until one datagram arrives and process it.

        Returns the complete message as an array('c') when the datagram is a
        single-piece message or supplies the last missing piece of a
        multi-piece message. Returns None when the piece was only buffered or
        the datagram was malformed.
        '''
        data, _addr = self.socket.recvfrom(1024*64)  # 64k buffer

        # Too short to carry the header: ignore it rather than let array()
        # raise and terminate the listening thread.
        if len(data) < 22:
            return None

        packet_id = data[:18]
        index, total = array('H', data[18:22])  # piece index, piece count

        if total == 1:  # only one message in the packet
            return array('c', data[22:])

        # An out-of-range index could never be reassembled (and would make
        # len(pieces) == total without all indices present), so drop it.
        if index >= total:
            return None

        entry = self.buf.setdefault(packet_id, [0., {}])
        pieces = entry[1]
        pieces[index] = data[22:]
        entry[0] = time.time()  # remember when this message last received a piece

        if len(pieces) == total:  # every piece 0..total-1 has arrived
            m = array('c')
            for i in range(total):
                m.extend(pieces[i])
            del self.buf[packet_id]
            return m

        # No complete message yet; use the opportunity to drop stale partials.
        self._cleanup()
        return None

    def _cleanup(self):
        '''Discard partial messages older than MAX_PIECE_AGE, at most once per CLEANUP_INTERVAL.'''
        now = time.time()
        if now - self.last_cleanup_time <= self.CLEANUP_INTERVAL:
            return
        # Reset the timer; without this the scan would run on every buffered piece.
        self.last_cleanup_time = now
        for k in list(self.buf):  # copy keys: entries are deleted while scanning
            if now - self.buf[k][0] > self.MAX_PIECE_AGE:
                print('Buffer clean up: %s' % repr(k))
                del self.buf[k]



class PR_PyMOLServer:
    '''Apply reassembled PyRosetta messages (structures, energy colourings) to PyMOL.'''

    def processPacket(self, msg):
        ''' Format description:
            bytes 0:8     - packet type
            byte  8       - flags
            byte  9       - len(name)
            bytes 10:...  - name
            bytes ...:... - message itself

            msg is an array('c'). Recognised packet types:
              'PDB     ' - plain PDB text, always loaded into state 1
              'PDB.gzip' - gzip-compressed PDB text
              'PDB.bz2 ' - bz2-compressed PDB text
              'Ener.bz2' - bz2-compressed per-residue colour records
              'Ene.gzip' - gzip-compressed per-residue colour records
            Any other type is reported and ignored.
        '''
        ptype = msg[:8].tostring()
        flags = ord(msg[8])
        name_len = ord(msg[9])
        name = msg[10: 10+name_len].tostring()
        # Convert the payload to a plain string once: read_pdbstr is given PDB
        # text, and StringIO / bz2 handle strings directly.
        data = msg[10+name_len:].tostring()

        if ptype == 'PDB     ':   # pdb file, no compression
            pymol.cmd.read_pdbstr(data, name, 1)

        elif ptype == 'PDB.gzip':   # pdb file, gzip compression
            self._load_pdb(gzip.GzipFile('', 'r', 0, StringIO(data)).read(), name, flags)

        elif ptype == 'PDB.bz2 ':   # pdb file, bz2 compression
            self._load_pdb(bz2.decompress(data), name, flags)

        elif ptype == 'Ener.bz2':   # energies info, bz2 compression
            self._color_residues(name, bz2.decompress(self._energy_payload(data)))

        elif ptype == 'Ene.gzip':   # energies info, gzip compression
            payload = self._energy_payload(data)
            self._color_residues(name, gzip.GzipFile('', 'r', 0, StringIO(payload)).read())

        else:
            print('Unknown packet type: %s, - ignoring...' % ptype)

    @staticmethod
    def _load_pdb(pdb, name, flags):
        '''Load decompressed PDB text into object `name`.

        The state passed to read_pdbstr is `flags ^ 1` (1 when flags is 0);
        when flags is non-zero the view also jumps to the last frame.
        '''
        pymol.cmd.read_pdbstr(pdb, name, flags ^ 1)
        if flags: pymol.cmd.frame( pymol.cmd.count_frames() )

    @staticmethod
    def _energy_payload(data):
        '''Skip the length-prefixed energy-type label and return the compressed body.'''
        e_type_len = ord(data[0])
        return data[1 + e_type_len:]

    @staticmethod
    def _color_residues(name, s):
        '''Colour residues of object `name` from 7-character records.

        Each record is: chain (1 char), residue number (4 chars), and a
        2-char hex suffix selecting one of the 'R00'..'Rff' colours.
        '''
        try:
            for i in range(0, len(s), 7):
                pymol.cmd.color('R%s' % s[i+5:i+7], '%s and chain %s and resi %s' % (name, s[i], s[i+1:i+5]))

        except pymol.parsing.QuietException:
            print('Coloring failed... did you forget to send geometry first?')


# Create the 'R00'..'Rff' colour ramp used by the energy packets: index 0 is
# white and index 255 is pure red, with green and blue fading together.
for i in range(256):
    # Float division on both channels; the old blue term used `i/255`, which
    # truncates to 0 under Python 2 and left blue at 1.0 for every i < 255.
    fade = 1. - i / 255.
    pymol.cmd.set_color('R%02x' % i, [1., fade, fade])

def main(ip, port):
    '''Run the PyMOL <---> PyRosetta link: receive messages forever and apply them.

    Started on a daemon thread by start_rosetta_server. If a packet fails to
    process, the traceback is printed and the loop keeps listening, so one bad
    message does not stop the link. The UDP socket is closed if the loop exits.
    '''
    import traceback

    print('PyMOL <---> PyRosetta link started!')

    udp_serv = PR_UDPServer(ip, port)
    PS = PR_PyMOLServer()
    try:
        while True:
            r = udp_serv.listen()
            if r:
                try:
                    PS.processPacket(r)
                except Exception:
                    # Server-loop boundary: report the failure and keep serving.
                    traceback.print_exc()
    finally:
        udp_serv.socket.close()


def start_rosetta_server(ip='', port=65000):
    '''Start the PyMOL <---> PyRosetta link on a background daemon thread.

    ip   -- address to listen on; when empty, the host's own address is looked up.
    port -- UDP port to listen on. Arguments typed at the PyMOL prompt arrive
            as strings, so the port is converted with int() here rather than
            failing later inside the listener thread.
    '''
    port = int(port)
    if not ip:
        try:
            ip = socket.gethostbyname(socket.gethostname())
        except socket.error:
            ip = '127.0.0.1'
        # Any 127.x.x.x address is loopback (some systems map the hostname to
        # 127.0.1.1), so it cannot be used to reach this host from elsewhere.
        if ip.startswith('127.'):
            print("Unable to automatically determine your IP address.  Please specify it manually. e.g. start_rosetta_server 192.168.0.1")
            return

    thread = threading.Thread(target=main, args=[ip, port])
    thread.daemon = True  # do not keep PyMOL alive just for this listener
    thread.start()

# Expose start_rosetta_server as a PyMOL command.
pymol.cmd.extend('start_rosetta_server', start_rosetta_server)

# Start listening on the loopback address as soon as this script is loaded.
start_rosetta_server(ip='127.0.0.1', port=65000)

#### To use PyMOLPyRosettaServer over a network, uncomment the line below and set the first argument to your IP address
#start_rosetta_server('192.168.0.1', 65000)
