###################################################################
#   Part of the Rosetta Design Wizard for PyMOL
#
#   This file contains the Rosetta_io class
#
#   This is a module of routines that interface with Rosetta
#
# (c) Copyright Rosetta Commons Member Institutions.
# (c) This file is part of the Rosetta software suite and is made available under license.
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
# (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.
#
####################################################################

import os           # to get directory functions, like get_cwd()

import sys,re
import threading

sys.path.append( sys.path[-1] + 'rdwizard/' )



try:
    from rosetta_design_wizard import *
except:
    print "Rosetta Design Wizard Error: Unable to open Rosetta python Bindings."



#import minature1

import pymol_io, util


# One-letter codes of the 20 canonical amino acids, alphabetically ordered.
# Used to validate PIKAA/NOTAA amino-acid strings (parse_task_aas) and to turn
# a NOTAA list into the equivalent PIKAA list (rosettapp_resfile).  Glycine
# ('G') was previously missing, so it could never be requested or allowed.
aa_codes = ['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y']

# Task modes understood by the wizard (the resfile mode keywords it handles).
task_names = ['NATRO','NATAA', 'ALLAAwc', 'ALLAAxc', 'POLAR',
              'APOLAR', 'PIKAA', 'NOTAA']
# Modes that take an explicit amino-acid string as an extra argument.
tasks_with_aas = [ 'PIKAA', 'NOTAA' ]

# Default task mode and amino-acid string (presumably applied to residues
# that have no explicit task -- confirm against callers).
task_default_name = 'NATRO'
task_default_aa = ''


# RGB colour (each channel 0.0-1.0) for each task mode; used in
# pymol_io.color_by_task() and as the source of task_color_hex and
# task_color_prefix.  NOTE: the 'UNKOWN' key is misspelled, but it is a lookup
# key that other code may use, so it is intentionally left unchanged.
task_color_rgb = { 'NATRO'   : (0.1, 0.8, 0.1),  # a green
                   'NATAA'   : (1.0, 0.7, 0.2),  # brightorange
                   'ALLAAwc' : (0.1, 0.1, 1.0),  # br0
                   'ALLAAxc' : (0.3, 0.3, 0.8),  # tv_blue
                   'POLAR'   : (1.0, 0.0, 0.5),  # hotpink
                   'APOLAR'  : (1.0, 0.6, 0.6),  # salmon-like; NOTE(review): was labelled deepteal (~(0.1, 0.6, 0.6)) -- confirm intended colour
                   'PIKAA'   : (1.0, 0.5, 0.0),  # orange
                   'NOTAA'   : (1.0, 0.1, 0.1),  # red
                   'UNKOWN'  : (0.8, 0.8, 0.8) } # grey80




# Hex colour string for every entry of task_color_rgb (same keys, including
# 'UNKOWN'); used in Res_table.color_by_task().
task_color_hex = dict( ( mode, util.hex_from_rgb( rgb ) )
                       for mode, rgb in task_color_rgb.items() )

def prefix_from_rgb( rgb ):
    """Return the wizard-panel text colour code for an (r, g, b) triple.

    The result is a backslash followed by three digits (PyMOL's "\\RGB" text
    colour escape).  Each channel, nominally 0.0-1.0, is multiplied by 10 and
    truncated; the digit is then clamped to 0-9.  A channel of 1.0 therefore
    maps to 9, as before.  Out-of-range values (> 1.0 or negative) can no
    longer produce a multi-digit or negative field that would corrupt the
    three-digit code.

    Raises ValueError if ``rgb`` does not have exactly three elements.
    """
    r, g, b = rgb   # keeps the original "exactly three channels" contract

    digits = []
    for channel in ( r, g, b ):
        digit = int( channel * 10 )            # truncation: 0.95 -> 9
        digits.append( str( min( 9, max( 0, digit ) ) ) )
    return "\\" + "".join( digits )

# Wizard-panel text colour code for each mode listed in task_names (the
# 'UNKOWN' entry of task_color_rgb is not included); used in
# Rosetta_design_wizard.setup_panel().
task_color_prefix = dict( ( mode, prefix_from_rgb( task_color_rgb[ mode ] ) )
                          for mode in task_names )

# Human-readable description of each task mode (keys match task_names).
# ALLAAwc / ALLAAxc differ only in whether cysteine is allowed.
task_description = \
    { 'NATRO':'Fix natural rotamer',
      'NATAA':'Fix natural amino acid',
      'ALLAAwc':'Allow all amino acids w/cysteine',
      'ALLAAxc':'Allow all amino acids except cysteine',
      'POLAR':'Allow only polar amino acids',
      'APOLAR':'Allow only non-polar amino acids',
      'PIKAA':'Pick amino acids to allow',
      'NOTAA':'Pick amino acids to disallow' }


# Extra-rotamer ("EX") option names, presumably offered as UI choices.
ex_flags = [
    "EX 1",
    "EX ARO 1",
    "EX 2",
    "EX ARO 2",
    "EX 3",
    "EX 4",
]

# Label for each extra-chi sampling level.  The list position is meant to
# equal the value of the ExtraRotSample enum in
# src/core/pack/task/RotamerSampleOptions.hh, so the order must not change.
ex_sample_levels = [
    "No extra chi samples",   # 0
    "1 full step std dev",    # 1
    "1 1/2  step std dev",    # 2
    "2 full step std dev",    # 3
    "2 1/2  step std dev",    # 4
    "4 1/2  step std dev",    # 5
    "3 1/3  step std dev",    # 6
    "6 1/4  step std dev",    # 7
]




def make_task_name(obj, mode, aas):
    """Return the unique key for a task; currently also its selection name.

    The key is obj.name followed by a mode-specific suffix:
      * PIKAA -> "_@" + normalised amino-acid string
      * NOTAA -> "_#" + normalised amino-acid string
      * any other mode -> "_" + mode

    ``aas`` is normalised with parse_task_aas().  That function returns None
    for an empty/None ``aas``, so the suffix falls back to "" instead of
    raising TypeError when concatenating None.
    """
    if mode in ("PIKAA", "NOTAA"):
        marker = "_@" if mode == "PIKAA" else "_#"
        return obj.name + marker + (parse_task_aas(aas) or "")
    return obj.name + "_" + mode


def parse_task_aas(aas):
    """Normalise an amino-acid string for a PIKAA/NOTAA task.

    Upper-cases ``aas``, keeps only letters found in aa_codes, removes
    duplicates and returns the survivors sorted and joined into one string.
    Returns None when ``aas`` is empty or None.
    """
    if not aas:
        return None
    wanted = set(letter for letter in aas.upper() if letter in aa_codes)
    return "".join(sorted(wanted))

def parse_task_name( name ):
    """Return the canonical spelling of task mode ``name``, or None.

    Matching is case-insensitive.  The aliases ALLAA and ALLAAXC map to
    'ALLAAxc', and ALLAAWC maps to 'ALLAAwc'.  Falsy input, or a name not in
    task_names, gives None.
    """
    if not name:
        return None
    aliases = { "ALLAA": "ALLAAxc", "ALLAAWC": "ALLAAwc", "ALLAAXC": "ALLAAxc" }
    shouted = str( name ).upper()
    canonical = aliases.get( shouted, shouted )
    if canonical in task_names:
        return canonical
    return None



class Rosetta_io:

    def __init__(self):

        self.score_terms = []

        self.rosetta_engine = "minirosetta"
        self.database_path = None

        self.do_debug = False

    ##########################################

    def debug( self, message ):
        if self.do_debug:
            print message

    def init_mini( self, database_path ):
        """Initialise the Rosetta library from ``database_path`` and load score terms.

        The binding's init_mini() runs in a worker thread that is joined
        straight away, so this call still blocks until initialisation ends.
        Afterwards self.score_terms holds the binding's score terms followed
        by the pseudo-term 'total'.  The list is rebuilt on every call, so
        re-initialising cannot accumulate duplicates.  Duplicates would make
        get_weighted_score( ..., 'total' ) count terms twice.
        """
        self.database_path = database_path

        # NOTE(review): start()+join() is kept from the original; presumably
        # initialisation must not run on the caller's thread -- confirm before
        # simplifying to a direct call.
        worker = threading.Thread( target=init_mini, args=( database_path, ) )
        worker.start()
        worker.join()

        # Replace the contents in place so outside references to the list
        # see the fresh terms.
        self.score_terms[:] = list( get_score_terms() ) + [ 'total' ]

    def report_error( self, message):
        """Print ``message`` as a Rosetta library error.

        When Rosetta is initialised and get_tracer() returns something truthy,
        that tracer output is printed as well.
        """
        print( "Rosetta Library Error: %s" % ( message, ) )
        if not get_init_state():
            return
        tracer = get_tracer()
        if tracer:
            print( tracer )


    def set_weights( self, weights_tag, patch_tag ):
        """Select the weights set ``weights_tag`` with patch ``patch_tag``.

        Only reports an error if Rosetta is not initialised.
        """
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot set the weights." )
            return
        set_weights( weights_tag, patch_tag )

    def get_weights( self ):
        """Return the binding's current weights.

        Returns None, after reporting an error, if Rosetta is not initialised.
        """
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot get weights" )
            return None
        return get_weights()

    def nullfunc( self ):
        """Delegate to util.log() and return None.

        NOTE(review): despite the name this is not a no-op; what util.log()
        does is defined in the util module, not here.
        """
        util.log()

    def add_object( self, obj ):
        """Create the Rosetta pose obj.name from pymol_io.pdb( obj ).

        If Rosetta is not initialised, reports an error and creates nothing.
        """
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot initialize pose %s " % (obj.name) )
            return
        init_pose( obj.name, pymol_io.pdb( obj ) )

    def set_rosetta_res_ids(self, obj):
        """Set obj.rosetta_res_ids to the residue ids of pose obj.name.

        The ids are stored in the order get_residues() yields them.  Reports
        an error instead if Rosetta is not initialised.
        """
        if not get_init_state():
            self.report_error("Rosetta is not yet initialized, cannot get rosetta res_ids for pose %s" % obj.name)
            return
        # Reset first, so the attribute is an empty list if get_residues fails.
        obj.rosetta_res_ids = []
        obj.rosetta_res_ids.extend(get_residues(obj.name))


    def set_task( self, obj, task_string ):
        """Set the task of pose obj.name from ``task_string``.

        ``task_string`` is converted with str() before being handed to the
        binding; the error message suggests it is resfile text.

        Returns True on success.  Returns False if Rosetta is not initialised
        or the binding raised; the error is reported in both cases.
        """
        self.debug( "setting task for object %s " % obj.name)
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot set task for pose %s" % obj.name )
            return False
        try:
            set_task( obj.name, str( task_string ) )
        except Exception as err:
            # Exception (not a bare except) so KeyboardInterrupt / SystemExit
            # still propagate; the cause is appended to the message.
            self.report_error( "There was an error with the resfile reader: %s" % ( err, ) )
            return False
        return True

    def clean_residue_task( self, obj, res_id ):
        """Clear the task of residue ``res_id`` -- a (chain, resi) pair -- in pose obj.name.

        Residues absent from obj.rosetta_res_ids are ignored.  Errors are
        reported rather than raised.
        """
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot clean residue task" )
            return
        try:
            if res_id in obj.rosetta_res_ids:
                clean_residue_task( obj.name, res_id[0], res_id[1] )
        except Exception as err:
            # Narrowed from a bare except so interrupts are not swallowed.
            self.report_error("There was an error trying to clean residue %s in object %s: %s" % (res_id, obj.name, err) )


    def set_residue_task( self, obj, res_id, command):
        """Apply task ``command`` to residue ``res_id`` -- a (chain, resi) pair -- of pose obj.name.

        Residues absent from obj.rosetta_res_ids are ignored.  Errors are
        reported rather than raised.
        """
        self.debug("set_residue_task, %s, %s, %s" % ( obj, res_id, command ))
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot set the residue task for the object %s." % obj.name )
            return
        try:
            if res_id in obj.rosetta_res_ids:
                set_residue_task( obj.name, res_id[0], res_id[1], command)
        except Exception as err:
            # Narrowed from a bare except so interrupts are not swallowed.
            self.report_error( "There was an error while trying to set res_id %s with task command %s: %s" % (res_id, command, err) )


    def get_task( self, obj ):
        """Return the task of pose obj.name as reported by the binding.

        If Rosetta is not initialised, an error is reported and None is
        returned.  ``task`` is pre-set to None so this path cannot raise
        UnboundLocalError.
        """
        self.debug("rosetta_io.get_task( %s )" % obj.name )
        task = None
        if get_init_state():
            task = get_task( obj.name )
        else:
            self.report_error( "Rosetta is not yet initialized, cannot get task for pose %s." % obj.name )
        self.debug( "rosetta_io.get_task -> here is the task:" )
        self.debug( str( task ) )
        return task

    def get_packable( self, obj ):
        """Fill obj.packable with a res_id -> packable-flag dict for pose obj.name.

        The binding's get_packable() yields one flag per residue, in the
        order of obj.rosetta_res_ids.  If the lengths disagree, the id list
        is refreshed once and the lengths are then asserted equal.  Reports an
        error instead if Rosetta is not initialised.
        """
        self.debug( "in get_packable %s." % obj.name )
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot get packable information for pose %s" % obj.name )
            return
        obj.packable = {}
        flags = get_packable( obj.name )
        if len( flags ) != len( obj.rosetta_res_ids ):
            self.set_rosetta_res_ids( obj )
        assert len( obj.rosetta_res_ids ) == len( flags )
        for index, flag in enumerate( flags ):
            obj.packable[ obj.rosetta_res_ids[ index ] ] = flag

    def get_num_res_types( self, obj ):
        """Fill obj.num_res_types with a res_id -> count dict for pose obj.name.

        Counts come from the binding's num_res_types(), one per residue, in
        the order of obj.rosetta_res_ids.  The error message describes them as
        per-residue rotamer counts.  If the lengths disagree, the id list is
        refreshed once and the lengths are then asserted equal.  Reports an
        error instead if Rosetta is not initialised.
        """
        self.debug("in get_num_res_types %s." % obj.name )
        if get_init_state():
            obj.num_res_types = {}
            counts = num_res_types( obj.name )
            self.debug( "num_res_types(obj.name) %s %s" % ( counts, len( counts ) ) )

            if len( counts ) != len( obj.rosetta_res_ids ):
                self.set_rosetta_res_ids( obj )
            assert len( obj.rosetta_res_ids ) == len( counts )
            for index, count in enumerate( counts ):
                obj.num_res_types[ obj.rosetta_res_ids[ index ] ] = count
        else:
            self.report_error("Rosetta is not yet initialized, cannot get the number of rotamers at each residue for pose %s." % obj.name )

        self.debug("done with get_num_res_types")


    def run_task( self, obj ):
        """Run the current task on pose obj.name.

        Reports an error if Rosetta is not initialised.
        """
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot run task for pose %s." % obj.name )
            return
        run_task( obj.name )


    def copy_task( self, src_obj, dest_obj ):
        """Copy the task of pose src_obj.name onto pose dest_obj.name.

        Reports an error if Rosetta is not initialised.
        """
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot copy task from pose %s to pose %s." % (src_obj.name, dest_obj.name) )
            return
        self.set_task( dest_obj, self.get_task( src_obj ) )

    def get_pdb( self, obj ):
        """Return the binding's PDB output for pose obj.name.

        Returns "" (after reporting an error) if Rosetta is not initialised.
        """
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot get pdb for pose %s" % obj.name )
            return ""
        return get_pdb( obj.name )

    def get_score( self, obj):
        """Return the binding's score for pose obj.name.

        Returns 0 (after reporting an error) if Rosetta is not initialised.
        """
        self.debug( "get_score from pose %s" % obj.name )
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot get score for pose %s" % obj.name )
            return 0
        return get_score(obj.name)

    def get_detailed_score( self, obj ):
        """Return the binding's detailed score for pose obj.name.

        Returns None (after reporting an error) if Rosetta is not initialised.
        """
        self.debug( "get_detailed_score from pose %s" % obj.name )
        if not get_init_state():
            self.report_error( "Rosetta is not yet initialized, cannot get detailed score for pose %s." % obj.name )
            return None
        return get_detailed_score( obj.name )

    def get_weighted_score( self, obj, res_id, score_term ):
        if not res_id in obj.rosetta_res_ids: return None
        assert score_term in self.score_terms
        chain, resi = res_id
        if score_term == 'total':
            total = 0
            for term in self.score_terms:
                if term != 'total':
                    total += get_weighted_score( obj.name, chain, resi, term )
            return total
        else:
            return get_weighted_score( obj.name, chain, resi, score_term )

    def resfile_readin( self, obj ):
        if self.rosetta_engine == "minirosetta":
            self.minirosetta_resfile_readin(obj)
        else:
            self.rosettapp_refile_readin(obj.name, obj.resfilename)

    def resfile_writeout( self, obj ):
        """Write the resfile for ``obj`` with the writer matching self.rosetta_engine.

        The rosetta++ path prints a warning that it is nonfunctional before
        attempting the write.
        """
        if self.rosetta_engine != "minirosetta":
            self.report_error( "Rosetta Design Wizard: Warning: rosetta++ resfile write out nonfunctional!" )
            self.rosettapp_resfile_writeout( obj.residues )
            return
        self.minirosetta_resfile_writeout( obj )

    def minirosetta_resfile_readin(self, obj, filename=None ):
        """Read a mini-Rosetta resfile for ``obj`` through mini_template.

        ``filename`` is forwarded to mini_template.resfile_readin(); it used to
        be discarded and None passed instead.

        NOTE(review): mini_template is not imported in the visible part of this
        module -- confirm where it is expected to come from.
        """
        mini_template.resfile_readin( self, obj, filename=filename )

    def get_all_weights_files( self ):
        filenames = os.listdir( self.database_path + "/scoring/weights" )
        weights = []
        weight_patches = []
        for filename in filenames:

            if filename.split(".")[ 1 ] == "wts":
                weights.append( filename.split(".")[ 0 ] )
            elif filename.split(".")[ 1 ] == "wts_patch":
                weight_patches.append( filename.split(".")[ 0 ] )

        return weights, weight_patches

    def rosettapp_resfile_parse(self, obj, text):
        """Parse rosetta++ resfile ``text`` into a dict task id -> [(chain, resi), ...].

        Non-blank lines up to and including the one whose first field is
        'start' are skipped.  Each later non-blank line must contain
        ``chain  seqpos  resi  mode``; modes in Task.mode_aas need an
        amino-acid string as a fifth field.  This is the layout produced by
        rosettapp_resfile().  Every line is examined, including a final line
        without a trailing newline.

        Each residue is grouped under the id of the Task named by
        make_task_name().  That Task is fetched with obj.get_task_by_name(),
        or a new Task( obj, mode, aas ) is created.

        Printed diagnostics (none raise):
          * a warning for residues specified on more than one line (0-based
            line indexes);
          * an error for malformed lines or unknown modes, which are skipped.
        A message is also printed if no 'start' line was found.
        """
        residues_specified = {}   # (chain, resi) -> index of the first line naming it
        residues_by_task = {}     # task id -> [(chain, resi), ...] (the result)
        begin = False             # becomes True once the 'start' line is seen

        for i, raw_line in enumerate(text.split("\n")):
            line = raw_line.split()

            # skip blank lines (before or after 'start'); '#' comments are
            # not recognised by this parser
            if not line:
                continue

            # skip all lines above and including the line 'start'
            if not begin:
                if line[0] == 'start':
                    begin = True
                continue

            if len(line) < 4:
                print("res_table ERROR: malformed line %s: %r" % (i, raw_line))
                continue

            chain = line[0].upper()
            # column 1 is the sequential index; the residue number is column 2
            try:
                resi = int(line[2])
            except ValueError:
                print("res_table ERROR: bad residue number on line %s: %r" % (i, raw_line))
                continue

            # warn if the residue already has a specification
            if (chain, resi) in residues_specified:
                print("res_table WARNING: residue %s has duplicate specification: lines %s and %s"
                      % (str((chain, resi)), i, residues_specified[(chain, resi)]))
            else:
                residues_specified[(chain, resi)] = i

            mode = line[3]
            if mode not in Task.modes:
                # skip instead of continuing with a stale or unbound ``aas``
                print("res_table ERROR: residue mode %s not recognized" % mode)
                continue
            if mode in Task.mode_aas:
                if len(line) < 5:
                    print("res_table ERROR: mode %s needs an amino-acid list on line %s" % (mode, i))
                    continue
                aas = line[4]
            else:
                aas = None

            task_name = make_task_name(obj, mode, aas)
            task = obj.get_task_by_name(task_name)
            if not task:
                task = Task(obj, mode, aas)
            residues_by_task.setdefault(task.id, []).append((chain, resi))

        if not begin:
            print("Rosetta Design Wizard: Rosetta++ resfile parser: missing start line")

        return residues_by_task

    ##TODO mo maybe replace with the above function...
    def rosettapp_resfile_readin(self, obj, filename=None):

        if filename:
            self.filelocation = filename

        self.resfile = open(self.filelocation)
        res_modes = []
        begin = False
        for line in self.resfile:

            #skip all lines above and including the line 'start'
            if begin == False:
                if line.strip() == 'start':
                    begin = True
                continue

            entry = line.split()
            if len(entry) == 0:
                continue
            elif len(entry)==5:
                chain, _, resi, mode, pikaa = entry
            elif len(entry) == 4:
                chain, _, resi, mode  = entry
                pikaa = ''
            else:
                print "Rosetta Design Wizard: Load Resfile Error: bad line", entry
            resi = int(resi)
            for i in range(len(self.res_modes)):
                if mode == self.res_modes[i]:
                    res_modes.append(((obj,chain,resi),(i,pikaa)))

        self.resfile.close()
        if begin == False:
            print "Rosetta Design Wizard: Load Resfile Error: missing start line"

            #        for res_id, (mode,pikaa) in res_modes:
            #            residue = self.get_residue(res_id)
            #            if not residue:
            #                print "Rosetta Design Wizard Load Error: residue ", res_id, "not in current object!"
            #            else:
            #                residue.mark(mode,pikaa)


    def rosettapp_resfile(self, obj):
        """Return the text of a rosetta++ resfile describing ``obj``'s tasks.

        A 'start' line is followed by one line per residue in
        obj.rosetta_res_ids, with these fields:
          * chain (width 2);
          * 1-based sequential index (width 5);
          * residue number (width 5);
          * mode (width 6);
          * for task modes in Task.mode_aas, two spaces and the amino-acid
            string.
        NOTAA tasks are written as the equivalent PIKAA list: every code in
        aa_codes that the task does not exclude.  The text ends with a
        newline.
        """
        # NOTE(review): the original subscripted obj.get_task_by_residue[...]
        # while obj.get_task_by_name is called as a method elsewhere in this
        # file; accept either a callable or a mapping until confirmed.
        lookup = obj.get_task_by_residue
        lines = ['start']
        for index, (chain, resi) in enumerate(obj.rosetta_res_ids, 1):
            key = (chain, resi)
            task = lookup(key) if callable(lookup) else lookup[key]
            if task.mode == 'NOTAA':
                mode = 'PIKAA'
                aas = ''.join([aa for aa in aa_codes if aa not in task.aas])
            else:
                mode = task.mode
                aas = task.aas

            line = (chain.rjust(2) +
                    str(index).rjust(5) +
                    str(resi).rjust(5) +
                    mode.rjust(6))
            if task.mode in Task.mode_aas:
                line += '  ' + aas
            lines.append(line)

        # join once instead of growing a string inside the loop
        return '\n'.join(lines) + '\n'

    ##TODO replace this with the above function...
    def rosettapp_resfile_writeout(self, residues):
        outlines = ['start']
        f = open(filelocation, 'w')
        for i in range(len(residues)):
            line = residues[i].chain.rjust(2) +\
                   i.rjust(5) +\
                   residues[i].resi.rjust(5) +\
                   residues[i].mode.rjust(6) +\
                   '  ' +\
                   residues[i].mode_pikaa

            outlines.append(line)

        f.writelines(outlines)
        f.close()


    def minirosetta_resfile_writeout(self,obj):
        """Write a mini-Rosetta resfile for ``obj`` through mini_template.

        NOTE(review): the original passed an undefined name ``poseid``.  The
        pose id used throughout this class is obj.name, so that is passed
        here -- confirm against mini_template.resfile_writeout().
        mini_template is not imported in the visible part of this module.
        """
        mini_template.resfile_writeout(obj.name)


    #mo used for building rosetta++ command strings
    def make_arg(self, var):
        if self.opt_vars[var]:
            if self.opt_var_seltype == 'counter':
                return '-'+var+' '+self.opt_vars[var]+' '
            else:
                return '-'+var+' '

    def rosettapp_build_command_string(self):
    ## def build_rosetta_plus_plus_command_string(self):
        c = '-s ' + self.object + '.pdb ' + '-design -fixbb '
        for var in opt_vars.keys():
            if opt_var_group[var] <> 'main':
                c = c + make_arg(var)
        return c

    def parse_score_terms(self, obj):
        """returns a (dictionary score_id -> (dictionary res_id -> float))"""

        pdb = obj.pdb()


