# (c) Copyright Rosetta Commons Member Institutions.
# (c) This file is part of the Rosetta software suite and is made available under license.
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
# (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

import os, math

from pymol.wizard import Wizard

from pymol import cmd


import prompt_manager


import rosetta_io
import pymol_io
import molComplex
import hbond

class Rosetta_design_wizard(Wizard):
    """PyMOL wizard for setting up, scoring and running Rosetta design tasks.

    One molComplex.MolComplex is kept per PyMOL object in ``self.objects``
    (keyed by object name). Per-residue design tasks can be assigned from the
    wizard menus or from the PyMOL commands registered in ``__init__``. Objects
    can be colored by task or by score term, and ``run`` executes the task and
    loads the result as a new, derived object.
    """

    # Placeholder stored by color_by_score_term for residues without a score.
    _UNSCORED = -100000

    # (PyMOL command name, task name) pairs registered in __init__ so tasks can
    # be assigned from the command line: ``<command> <selection>[, <aas>]``.
    _TASK_COMMANDS = (
        ( 'natro',   'NATRO'   ),
        ( 'nataa',   'NATAA'   ),
        ( 'allaawc', 'ALLAAwc' ),
        ( 'allaaxc', 'ALLAAxc' ),
        ( 'polar',   'POLAR'   ),
        ( 'apolar',  'APOLAR'  ),
        ( 'pikaa',   'PIKAA'   ),
        ( 'notaa',   'NOTAA'   ),
    )

    def __init__( self, ros_io, controller, _self=cmd ):
        """Initialize the wizard, register its commands and sync with PyMOL.

        ros_io     -- Rosetta interface object used for all scoring/design calls
        controller -- object notified when objects are added/modified/deleted
        _self      -- PyMOL ``cmd`` module, or an object exposing it as ``.cmd``
        """
        version = self._pymol_version()
        if version < 1:
            # PyMOL versions below 1.0 are initialized without a cmd argument.
            Wizard.__init__( self )
        elif getattr( _self, '__name__', None ) == 'pymol.cmd':
            Wizard.__init__( self, _self )
        else:
            Wizard.__init__( self, _self.cmd )

        self.do_debug = False
        self.name = "rosetta design wizard"

        self.rosetta_io = ros_io
        self.controller = controller

        # When True, get_panel also offers the chi-angle freedom and
        # score-function weights menus.
        self.advanced_features = False
        # NOTE(review): not read anywhere in this class.
        self.show_weights_menu = False

        # Quiet PyMOL's feedback for the api and cmd modules.
        # TODO: move to pymol_io
        cmd.feedback( "disable", "api", "everything" )
        cmd.feedback( "disable", "cmd", "everything" )

        self.prompt_manager = prompt_manager.Prompt_manager()

        self.setup_panel()

        # When set, keyboard events are routed to this function (see do_key).
        self.key_mode_fun = None
        # Last selection reported by PyMOL through do_select.
        self.active_sele = None

        # PyMOL object name -> molComplex.MolComplex
        self.objects = {}

        self.possible_scenes = [ 'before', 'design', 'score', 'hbond' ]
        self.scene = 'design'
        self.current_score_term    = "total"
        self.current_weights       = "standard"
        self.current_weights_patch = "score12"

        # So design tasks can be set from the keyboard / command line.
        for command_name, task_name in self._TASK_COMMANDS:
            cmd.extend( command_name, self._task_command( task_name ) )

        self.do_scene()

    @staticmethod
    def _pymol_version():
        """Return PyMOL's numeric version as a float, or 1.0 if unparseable.

        cmd.get_version() is expected to return a sequence whose second item
        is the numeric version, but a bare value is accepted as well. This
        avoids assuming it has exactly two items.
        """
        info = cmd.get_version()
        if isinstance( info, ( tuple, list ) ) and len( info ) > 1:
            info = info[ 1 ]
        try:
            return float( info )
        except ( TypeError, ValueError ):
            # Could not determine the version: assume a modern (>= 1.0) PyMOL.
            return 1.0

    def _task_command( self, task_name ):
        """Return a PyMOL command function that applies ``task_name`` to a selection."""
        def task_command( sele, aas='' ):
            return self.set_task( task_name, aas, sele )
        return task_command

    def debug( self, message ):
        """Print ``message`` only when self.do_debug is enabled."""
        if self.do_debug:
            print( message )

    def onError( self, message ):
        """Report an error message on stdout."""
        print( message )

    def setup_panel( self ):
        """Build the static menus: residue-level tasks and chi-angle freedom.

        Menu entries are [ kind, label, command ]; kind 2 is used for the menu
        header and kind 1 for the selectable items. The dynamic menus (scene,
        score terms, objects, weights) are rebuilt in get_panel.
        """
        task_menu = [ [ 2, 'Allowed Rotamers', '' ] ]
        for tname in rosetta_io.task_names:
            label = rosetta_io.task_color_prefix[ tname ] + \
                    rosetta_io.task_description[ tname ]
            task_menu.append( [ 1, label,
                                'cmd.get_wizard().set_task(name="' + tname + '")' ] )
        self.menu[ 'task' ] = task_menu

        # Every extra-chi flag item calls set_ex_flag with level '1'.
        ex_menu = [ [ 2, 'Chi Angle Freedom', '' ] ]
        for ex_flag in rosetta_io.ex_flags:
            ex_menu.append( [ 1, ex_flag,
                              "cmd.get_wizard().set_ex_flag(ex_flag='" + ex_flag + "', level='1')" ] )
        self.menu[ 'ex_flags' ] = ex_menu

    def get_panel( self ):
        """Return the wizard's side panel (PyMOL callback).

        Panel entries are [ kind, label, command ]:
          kind 1 -- title row (no command)
          kind 2 -- button that executes ``command``
          kind 3 -- button that opens the menu ``self.menu[ command ]``
        The object-, scene- and weights-dependent menus are rebuilt on every
        call, so they reflect the current state.
        """
        panel = [ [ 1, 'Rosetta Design', '' ] ]
        if self.objects:
            obj_names = sorted( self.objects )

            self.menu[ 'scene' ] = [
                [ 2, "Select View Mode", '' ],
                [ 1, "Design", 'cmd.get_wizard().set_scene("design")' ],
                [ 1, 'Score',  'cmd.get_wizard().set_scene("score")' ],
            ]
            panel.append( [ 3, 'View (%s)' % self.scene, 'scene' ] )

            if self.scene == 'design':
                panel.append( [ 3, 'Residue Level Task', 'task' ] )
                if self.advanced_features:
                    panel.append( [ 3, "Adjust Chi Angle Freedom", 'ex_flags' ] )

            elif self.scene == 'score':
                score_menu = [ [ 2, "Select Score Term", '' ] ]
                for term in self.rosetta_io.score_terms:
                    score_menu.append( [ 1, term,
                                         'cmd.get_wizard().color_by_score_term( score_term="' + term + '" )' ] )
                self.menu[ 'view_score' ] = score_menu
                panel.append( [ 3, "Color By Score Term (%s)" % self.current_score_term, 'view_score' ] )

            detail_menu = [ [ 2, "Select Object", '' ] ]
            for obj_name in obj_names:
                detail_menu.append( [ 1, obj_name,
                                      'cmd.get_wizard().print_detailed_score( obj_name="' + obj_name + '" )' ] )
            self.menu[ 'detailed_score' ] = detail_menu
            panel.append( [ 3, 'Detailed Score', 'detailed_score' ] )

            if self.advanced_features:
                # Allow the score-function weights and patch to be adjusted.
                weights, weights_patches = self.rosetta_io.get_all_weights_files()

                weights_menu = [ [ 2, "Select Score Function Weights", '' ] ]
                for weights_tag in weights:
                    weights_menu.append( [ 1, weights_tag,
                                           'cmd.get_wizard().set_weights( weights="' + weights_tag + '" ) ' ] )
                self.menu[ 'weights' ] = weights_menu

                patch_menu = [ [ 2, "Select Weights Patch", '' ] ]
                for wp in weights_patches:
                    patch_menu.append( [ 1, wp,
                                         'cmd.get_wizard().set_weights( patch="' + wp + '" ) ' ] )
                self.menu[ 'weights_patches' ] = patch_menu

                panel.append( [ 3, 'Weights (%s)' % self.current_weights,       'weights' ] )
                panel.append( [ 3, 'Patch (%s)'   % self.current_weights_patch, 'weights_patches' ] )

            run_menu = [ [ 2, "Select Object", '' ] ]
            for obj_name in obj_names:
                run_menu.append( [ 1, obj_name, 'cmd.get_wizard().run( obj_name="%s" )' % obj_name ] )
            self.menu[ 'run_task' ] = run_menu
            panel.append( [ 3, 'Run Task', 'run_task' ] )

        panel.append( [ 2, 'Done', 'cmd.set_wizard()' ] )
        return panel

    def do_key( self, k, x, y, m ):
        """Route a key press to the active key handler, if any.

        Return 1 to signal that the key event has been handled, i.e. pushing
        the key does not do anything else. Returns None when no handler is set.
        """
        if self.key_mode_fun:
            return self.key_mode_fun( k )

    ##########################
    #    mo MENU HANDLING    #
    ##########################

    def set_scene( self, scene ):
        """Menu callback: switch the view mode (how objects are colored).

        Unknown scene names are reported and ignored; the current scene is kept.
        """
        if scene not in self.possible_scenes:
            print( "Rosetta Design Wizard: WARNING:  no such view %s" % ( scene, ) )
            return

        for obj in self.objects.values():
            self.debug( "%s has %s residues." % ( obj.name, len( obj.res_ids ) ) )

        if scene != self.scene:
            self.scene = scene
            self.update_view()
            cmd.refresh_wizard()

    def update_view( self ):
        """Refresh the coloring of all objects according to self.scene."""
        self.debug( "In update_view" )
        for obj in self.objects.values():
            self.debug( "%s has %s residues." % ( obj.name, len( obj.res_ids ) ) )

        if self.scene not in self.possible_scenes:
            print( "Rosetta Design WARNING:  no such view %s" % ( self.scene, ) )

        if self.scene == 'design':
            for obj in self.objects.values():
                self.rosetta_io.get_packable( obj )
                self.rosetta_io.get_num_res_types( obj )
                pymol_io.color_by_task( obj )

        elif self.scene == 'score':
            self.color_by_score_term()

        elif self.scene == 'hbond':
            for obj in self.objects.values():
                hbond.do_hbond( obj )
                self.debug( "done with hbond rutine" )

    def color_by_score_term( self, score_term=None ):
        """Color every object by the per-residue weighted score of one term.

        score_term -- term to color by. When given, it becomes
                      self.current_score_term; otherwise the current term is used.

        Scores are shifted by the object's minimum and log-scaled,
        log(1 + score - min), so that clashes do not dominate the color range.
        Residues without a score are placed at the bottom of that scale.
        Objects with no scored residues are reported and left as they are.
        """
        if score_term:
            self.current_score_term = score_term
        else:
            score_term = self.current_score_term

        print( "Rosetta Design Wizard: color by score term %s" % score_term )

        for obj_name, obj in self.objects.items():
            print( "coloring object %s by score term %s..." % ( obj_name, score_term ) )
            print( "score_term -> %s" % score_term )

            data = dict( ( res_id, self._UNSCORED ) for res_id in obj.res_ids )
            scores = []
            for res_id in obj.rosetta_res_ids:
                score = self.rosetta_io.get_weighted_score( obj, res_id, score_term )
                if score is None:
                    # Keep the placeholder, but don't let it skew min/max/average.
                    data.setdefault( res_id, self._UNSCORED )
                else:
                    data[ res_id ] = score
                    scores.append( score )

            if not scores:
                print( "no %s scores for object %s; not recoloring" % ( score_term, obj_name ) )
                continue

            min_data = min( scores )
            max_data = max( scores )
            average = sum( scores ) / float( len( scores ) )
            print( "min = %s, max = %s, average = %s" % ( min_data, max_data, average ) )

            # This keeps clashes, when they occur, from dominating the coloring.
            # Clamping at 0 puts unscored residues at the bottom of the scale
            # instead of taking the log of a negative number.
            scaled = dict( ( res_id, math.log( 1 + max( value - min_data, 0 ) ) )
                           for res_id, value in data.items() )

            pymol_io.color_by_data_simple( obj, scaled )

    def set_task( self, name, aas=None, sele=None ):
        """Menu/command callback: give the selected residues the task ``name``.

        name -- task name (normalized via rosetta_io.parse_task_name)
        aas  -- amino acids for PIKAA/NOTAA. If missing for those tasks, the user
                is prompted. The prompt's callback re-enters this method with the
                typed text and the same selection, and 1 is returned meanwhile.
        sele -- PyMOL selection; defaults to the active selection. If there is
                none, a message is printed and None is returned.
        """
        self.do_rosetta_design_update()
        if not sele:
            sele = pymol_io.get_active_selection()
            if not sele:
                print( "Rosetta Design Wizard: you must make a selection first!" )
                return None

        name = rosetta_io.parse_task_name( name )

        # PIKAA and NOTAA need the list of amino acids.
        if name in ( 'PIKAA', 'NOTAA' ):
            if not aas:
                self.debug( "in PIKAA or NOTAA in set_task, name = %s, aas = %s" % ( name, aas ) )
                if name == 'PIKAA':
                    self.prompt_manager.set_prefix( 'ask', "Enter amino acids to allow:" )
                else:
                    self.prompt_manager.set_prefix( 'ask', "Enter amino acids to disallow:" )

                self.prompt_manager.set_text( 'ask', "" )
                self.prompt_manager.set_postfix( 'ask', "_" )
                self.prompt_manager.show( 'ask' )

                # Pass the resolved selection along, so the answer is applied to
                # the residues the task was requested for.
                on_return = lambda: \
                    self.set_task( name, self.prompt_manager.get_text( 'ask' ), sele )
                self.key_mode_fun = lambda k: \
                    self.prompt_manager.key_prompt_fun( 'ask', k, on_return )

                # wait for call back
                return 1

            # The aas are assumed to come from the prompt, so hide it.
            self.prompt_manager.hide( 'ask' )
            self.key_mode_fun = None
            command = name + " " + aas
        else:
            command = name

        for obj in self.objects.values():
            for res_id in pymol_io.residues( obj, sele ):
                self.debug( "setting for obj %s res_id %s command %s" % ( obj.name, res_id, command ) )
                self.rosetta_io.clean_residue_task( obj, res_id )
                self.rosetta_io.set_residue_task( obj, res_id, command )

        for obj_name in pymol_io.sele_by_obj( sele ):
            self.controller.object_modified( obj_name )

        # TODO mo does this really need to be refreshed?
        pymol_io.disable_active_selection()

        self.do_rosetta_design_update()

    def set_ex_flag( self, ex_flag, level, sele=None ):
        """Menu callback: set an extra-chi sampling flag on the selected residues.

        The residue task command sent to rosetta_io is "<ex_flag> LEVEL <level>".
        sele defaults to the active selection; without one, a message is
        printed and None is returned.
        """
        self.debug( "set_ex_flag( ex_flag=%s, level=%s, sele=%s)" % ( ex_flag, level, sele ) )
        self.do_rosetta_design_update()
        if not sele:
            sele = pymol_io.get_active_selection()
            if not sele:
                print( "Rosetta Design Wizard: you must make a selection first!" )
                return None

        command = "%s LEVEL %s" % ( ex_flag, level )
        for obj_name, res_ids in pymol_io.sele_by_obj( sele ).items():
            obj = self.objects[ obj_name ]
            for res_id in res_ids:
                self.rosetta_io.set_residue_task( obj, res_id, command )
            self.controller.object_modified( obj_name )

    def print_detailed_score( self, obj_name ):
        """Print rosetta_io's detailed score report for ``obj_name``."""
        print( "Score for %s" % obj_name )
        print( self.rosetta_io.get_detailed_score( self.objects[ obj_name ] ) )
        print( "" )

    def set_weights( self, weights=None, patch=None ):
        """Menu callback: change the score-function weights and/or patch.

        Arguments left as None keep their current value. Scores and the
        current view are refreshed afterwards.
        """
        if weights:
            self.current_weights = weights
        if patch:
            self.current_weights_patch = patch

        self.rosetta_io.set_weights( self.current_weights, self.current_weights_patch )

        self.show_simple_score()
        self.update_view()

    def get_view( self, obj_name ):
        """Return rosetta_io's task for the object ``obj_name``."""
        return self.rosetta_io.get_task( self.objects[ obj_name ] )

    def parse_resfile( self, obj_name, resfile ):
        """Load ``resfile`` into the object ``obj_name`` and refresh the view."""
        self.objects[ obj_name ].parse_resfile( resfile )
        self.controller.object_modified( obj_name )
        self.do_rosetta_design_update()

    def get_object_names( self ):
        """Return the names of all tracked objects as a list."""
        return list( self.objects )

    def get_native_object_names( self ):
        """Return the distinct native_name values of all tracked objects."""
        return list( set( obj.native_name for obj in self.objects.values() ) )

    def get_decoy_object_names( self, native_name ):
        """Return the names of objects whose native_name equals ``native_name``.

        Note: this does not include the native object itself.
        """
        return [ obj_name for obj_name, obj in self.objects.items()
                 if obj.native_name == native_name ]

    def get_main_object_name( self ):
        """Return pymol_io's main object name."""
        return pymol_io.main_object()

    def get_next_name( self, obj ):
        """Return the next derived name, e.g. LP17 -> LP17-001, LP17-002, ...

        The number is one more than the highest numeric suffix among tracked
        objects with the same base name. It is zero-padded to 3 digits below
        1000 and to 5 digits otherwise. Non-numeric suffixes are ignored.
        """
        base = obj.name.split( "-" )[ 0 ]
        count = 0
        for obj_name in self.objects:
            parts = obj_name.split( "-" )
            if parts[ 0 ] != base or len( parts ) < 2:
                continue
            try:
                count = max( int( parts[ 1 ] ), count )
            except ValueError:
                # e.g. "LP17-mutant": not one of our numbered copies.
                continue
        count += 1

        if count < 1000:
            return "%s-%03d" % ( base, count )
        return "%s-%05d" % ( base, count )

    def run( self, obj_name ):
        """Run the Rosetta task on ``obj_name`` and load the result as a new object.

        The new object gets the next derived name (see get_next_name). It
        inherits the parent's native name and native score (falling back to
        the parent's own name/score) and a copy of the parent's task. Unknown
        object names are reported and ignored.
        """
        obj = self.objects.get( obj_name )
        if obj is None:
            print( "Rosetta Design Wizard: Error: unknown object %s " % obj_name )
            return None

        new_name = self.get_next_name( obj )
        print( "Running task on %s -> %s" % ( obj_name, new_name ) )

        self.rosetta_io.run_task( obj )
        new_pdb = self.rosetta_io.get_pdb( obj )

        pymol_io.load_pdb_into( new_name, new_pdb )
        self.add_object( obj_name=new_name, native_name=obj.native_name )

        new_obj = self.objects[ new_name ]

        new_obj.native_name = obj.native_name if obj.native_name else obj.name
        new_obj.native_score = obj.native_score if obj.native_score else obj.score

        self.rosetta_io.copy_task( obj, new_obj )

        self.controller.object_new( new_obj.native_name )

        self.show_simple_score( new_obj )
        self.do_rosetta_design_update()

    ##########################
    #mo PYMOL EVENT HANDLING #
    ##########################
    def do_rosetta_design_update( self ):
        """Recolor the objects and resynchronize with PyMOL's object list."""
        self.debug( "do_rosetta_design_update" )
        self.update_view()
        self.do_scene()

    def get_prompt( self ):
        """PyMOL callback: the message shown in the upper right of the viewer."""
        return self.prompt_manager.prompt

    def get_event_mask( self ):
        """PyMOL callback: the events this wizard listens for."""
        return   Wizard.event_mask_pick \
               + Wizard.event_mask_select \
               + Wizard.event_mask_key \
               + Wizard.event_mask_special \
               + Wizard.event_mask_scene \
               + Wizard.event_mask_state \
               + Wizard.event_mask_frame \
               + Wizard.event_mask_dirty

    def do_select( self, sele ):
        """PyMOL callback when a selection is made: remember and enable it."""
        self.active_sele = sele
        cmd.enable( sele )
        return None

    def add_object( self, obj_name, native_name=None ):
        """Start tracking the PyMOL object ``obj_name``.

        The object is registered with rosetta_io and scored, every residue is
        given the NATRO task, and the controller is notified.
        """
        obj = molComplex.MolComplex( obj_name, native_name )

        self.rosetta_io.add_object( obj )
        self.rosetta_io.set_rosetta_res_ids( obj )

        self.objects[ obj_name ] = obj

        obj.score = self.rosetta_io.get_score( obj )
        self.set_task( 'NATRO', sele=obj.name )

        self.controller.object_modified( obj_name )

    def delete_object( self, obj_name ):
        """Stop tracking ``obj_name`` and notify the controller.

        TODO: also release the object in rosetta_io once it supports deletion.
        """
        del self.objects[ obj_name ]
        self.controller.object_deleted( obj_name )

    def do_scene( self ):
        """PyMOL callback when the scene changes (e.g. objects added/removed).

        Tracked objects are synchronized with PyMOL's object list: new PyMOL
        objects are added and vanished ones deleted. The main object's score
        is then shown.
        """
        objs = cmd.get_names( 'objects' )

        # mo synchronize pymol and rosetta
        for obj_name in objs:
            if obj_name not in self.objects:
                self.add_object( obj_name )

        # Snapshot the names: delete_object mutates self.objects.
        for obj_name in list( self.objects ):
            if obj_name not in objs:
                self.delete_object( obj_name )

        self.is_selection = ( cmd.get_names( "selections" ) != [] )

        main_obj = pymol_io.main_object()
        if main_obj and main_obj in self.objects:
            self.show_simple_score( self.objects[ main_obj ] )

        return None

    def show_simple_score( self, obj=None ):
        """Recompute an object's score and show it in the "score" prompt.

        Without ``obj``, this is done for every tracked object in turn. When
        the object has a native score, the change relative to it is shown too.
        """
        if not obj:
            for each in self.objects.values():
                self.show_simple_score( each )
            return

        obj.score = self.rosetta_io.get_score( obj )
        if obj.native_score:
            self.debug( "for obj %s obj.score -> %s, obj.native_score->%s" % ( obj.name, obj.score, obj.native_score ) )
            delta = obj.score - obj.native_score
            sign = "+" if delta > 0 else ""
            text = "Score for %s: %s change: %s%s" % \
                ( obj.name, str( round( obj.score, 2 ) ), sign, str( round( delta, 2 ) ) )
        else:
            text = "Score for " + obj.name + " " + str( round( obj.score, 2 ) )

        self.prompt_manager.set_text( "score", text )
        self.prompt_manager.show( "score" )

    def cleanup( self ):
        """PyMOL callback, triggered by cmd.set_wizard()."""
        # mo TODO delete selections
        pass

def do_rdw( app ):
    """Create the Rosetta design wizard, install it in PyMOL and open the residue table.

    app -- application object passed through to res_table.Res_table.

    The Rosetta database is expected in a ``database`` directory next to this
    file. Returns the installed Rosetta_design_wizard instance.
    """
    import rdwizard
    from rdwizard import util, controller, rosetta_io, res_table

    # NOTE(review): the settings are currently unused (the database path is
    # derived from this file's location, not the settings' 'database' entry).
    # The call is kept in case reading the settings file has side effects.
    util.read_settingsfile()

    ros_io = rosetta_io.Rosetta_io()

    # Join properly: with a bare "dirname + '/database'", an empty dirname
    # would yield the filesystem root "/database".
    module_dir = os.path.dirname( os.path.abspath( __file__ ) )
    ros_io.init_mini( os.path.join( module_dir, 'database' ) )

    design_controller = controller.Controller()

    rdw = Rosetta_design_wizard( ros_io, design_controller )
    cmd.set_wizard( rdw )

    res_table.Res_table( app=app, controller=design_controller )

    return rdw
