# (c) Copyright Rosetta Commons Member Institutions.
# (c) This file is part of the Rosetta software suite and is made available under license.
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
# (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.
import pymol
from pymol import cmd
import rosetta_io

#for make_gradient
import colorsys,sys

### This is just a collection of pymol functions

#WARNING calls to this should be replaced with proper multiobject
#handling
def main_object():
    """Return the name of the most reasonable object to act on.

    That is the last entry of get_visible_objects(), or None when no
    object is visible.
    """
    candidates = get_visible_objects()
    return candidates[-1] if candidates else None

def get_visible_objects():
    """Return the names of visible PyMOL objects, in cmd.get_names order.

    Objects whose names start with "_" (PyMOL-internal) are skipped. The
    first element of an object's cmd.get_vis() entry is treated as its
    enabled/visible flag (1 == visible).
    """
    vis = cmd.get_vis()
    return [name for name in cmd.get_names( 'objects' )
            if not name.startswith("_") and vis[ name ][ 0 ] == 1]


def get_active_selection():
    """Return the name of the first enabled selection, or None.

    Selections whose names start with "_" are ignored. Per the original
    author, PyMOL enables at most one selection at a time (toggle with
    cmd.enable / cmd.disable). The first element of each cmd.get_vis()
    entry is treated as the enabled flag; get_vis() itself is otherwise
    undocumented here (it is what cmd.scene() uses).
    """
    vis = cmd.get_vis()
    enabled = (name for name in cmd.get_names('selections')
               if not name.startswith("_") and vis[name][0] == 1)
    return next(enabled, None)

def disable_active_selection():
    """Disable the currently enabled selection, if there is one."""
    current = get_active_selection()
    if not current:
        return
    cmd.disable( current )

def CAs(obj):
    """Return the selection string for the alpha-carbon (atom name CA)
    atoms of the given object."""
    object_name = obj.name
    return object_name + " and n. CA"

def model(obj):
    """Return cmd.get_model() for the PyMOL object named by obj.name."""
    object_name = obj.name
    return cmd.get_model(object_name)

def pdb(obj):
    """Return cmd.get_pdbstr() (PDB-format text) for the PyMOL object
    named by obj.name."""
    object_name = obj.name
    return cmd.get_pdbstr(object_name)


def sele_by_obj(sele):
    """Map every loaded PyMOL object name to the residues of sele in it.

    sele -- PyMOL selection string

    Returns a dict obj_name -> list of (chain, resi) tuples as produced
    by all_residues; objects with no residues in sele map to an empty
    list.
    """
    ret = {}
    for obj_name in cmd.get_names('objects'):
        # Parenthesize sele so an "or" inside it cannot escape the
        # object restriction.
        ret[obj_name] = all_residues('(%s) and %s' % (sele, obj_name))

    return ret

def all_residues(sele):
    """Return the distinct residues matched by a PyMOL selection.

    sele -- PyMOL selection string. It is parenthesized before the chain
            filter is appended, so selections containing "or" are
            restricted correctly.

    Returns a list, in no particular order, of (chain, resi) tuples.
    chain is the chain name reported by cmd.get_chains() and resi is the
    residue number converted with int(). A residue number that is not a
    plain integer (e.g. one with an insertion code) makes int() raise
    ValueError.
    """
    res_sele = {}
    for chain in cmd.get_chains():
        # Need the quotes around the chain in case chain is " "; PyMOL
        # reduces this to "".
        chain_sele = '(%s) and chain "%s"' % (sele, chain.strip())
        for at in cmd.get_model(chain_sele).atom:
            res_sele[(chain, int(at.resi))] = 1

    # Return a real list so callers can sort it (a dict view cannot be
    # sorted in place under Python 3).
    return list(res_sele)


def residues(obj, sele=None):
    """Return an unsorted list of (chain, resi) ids for the residues of
    the given object that are in sele.

    obj  -- object whose ``name`` is the PyMOL object name
    sele -- optional PyMOL selection string; defaults to the whole object
    """
    if not sele:
        sele = obj.name
    # Query only this object. Going through sele_by_obj would run the
    # per-chain PyMOL queries for every loaded object and then discard
    # all but one result. The parentheses keep an "or" inside sele from
    # escaping the object restriction.
    return all_residues('(%s) and %s' % (sele, obj.name))

def residue_sele_str(obj, res_id):
    """Return a PyMOL selection string for one residue or residue range.

    obj    -- object whose ``name`` is the PyMOL object name
    res_id -- (chain, resi). resi is either a residue number or a
              (begin, end) tuple describing an inclusive range.

    The result has the form "/obj//chain/resi" or "/obj//chain/begin-end".
    """
    chain, resi = res_id
    if isinstance(resi, tuple):
        begin, end = resi
        resi_str = "%s-%s" % (begin, end)
    else:
        resi_str = str(resi)
    return "/%s//%s/%s" % (obj.name, chain, resi_str)

def selection(obj, res_ids):
    """Build a PyMOL selection string for residues of the given object.

    PyMOL's selection algebra is slow, so runs of consecutive residue
    numbers on the same chain are collapsed into ranges first.

    obj     -- object whose ``name`` is the PyMOL object name
    res_ids -- iterable of (chain, resi) tuples with integer resi (e.g. a
               list or a tuple); it is not modified

    Returns "" when res_ids is empty. Otherwise returns one
    residue_sele_str per residue or range, joined with " or ".
    """
    # Sort a copy instead of calling res_ids.sort(). This leaves the
    # caller's list untouched and also accepts tuples and other iterables.
    ordered = sorted(res_ids)
    if not ordered:
        return ""

    compact = []

    def flush(chain, beg, end):
        # A run of one residue is emitted as a bare resi; longer runs
        # are emitted as an inclusive (beg, end) range.
        compact.append((chain, beg if beg == end else (beg, end)))

    for _, resi in ordered:
        assert isinstance(resi, int)

    cur_chain, beg = ordered[0]
    cur = beg
    for chain, resi in ordered[1:]:
        if chain == cur_chain and resi == cur + 1:
            cur += 1
        else:
            flush(cur_chain, beg, cur)
            cur_chain = chain
            beg = cur = resi
    flush(cur_chain, beg, cur)

    return " or ".join([residue_sele_str(obj, res_id)
                        for res_id in compact])


def load_pdb_into( name, new_pdb ):
    """Load PDB-format text into the PyMOL object `name` without moving
    the camera.

    The current view is stored under the key 'prior' before loading and
    recalled afterwards.
    """
    cmd.view('prior', 'store')
    cmd.load_raw(content=new_pdb, format='pdb', object=name)
    cmd.view('prior', 'recall')


def color_by_task( obj ):
    """Color obj according to how its residues are specified in the packer.

    Steps:
    - Each residue in obj.res_ids is shaded with the "yellow_magenta"
      palette by its obj.num_res_types value (0 when missing), via
      color_by_data_simple.
    - Carbon atoms of residues whose obj.packable entry is false are
      colored green (residues missing from obj.packable count as
      packable).
    - All non-carbon atoms get PyMOL's "atomic" coloring.
    """
    shade = dict((res_id, obj.num_res_types.get(res_id, 0))
                 for res_id in obj.res_ids)
    color_by_data_simple( obj, shade, palette = "yellow_magenta" )

    frozen = [res_id for res_id in obj.res_ids
              if not obj.packable.get(res_id, True)]
    frozen_sele = selection( obj, frozen )
    if frozen_sele:
        cmd.color( "green", "(elem C and(" + frozen_sele + "))", quiet=1 )

    cmd.color( "atomic", "all and not elem C", quiet=1)

def get_residue_colors( obj ):
    """Return {(chain, resi): color tuple} for the CA atoms of obj.

    resi is the residue number as cmd.iterate reports it. It is not
    converted with int(), unlike the res_ids built elsewhere in this
    module. Each value is whatever cmd.get_color_tuple returns for the
    atom's color index.
    """
    pymol.stored.colors = []
    cmd.iterate( "%s and n. ca" % obj.name, "stored.colors.append( (chain, resi, name, color))")
    # The selection already asks for "n. ca"; keep only atoms whose name
    # is exactly 'CA' (expected to be all of them).
    return dict(((chain, resi), cmd.get_color_tuple(color))
                for chain, resi, name, color in pymol.stored.colors
                if name == 'CA')



##def color_by_task( obj ):
##    """ color the object by the task names """
##    task_name = {}
##    for res_id in obj.res_ids:
##
##        if obj.task_name[ res_id ] in task_name:
##            task_name[ obj.task_name[ res_id ] ].append( res_id )
##        else:
##            task_name[ obj.task_name[ res_id ] ] = [ res_id ]
##
##    for name, res_ids in task_name.iteritems():
##        name = rosetta_io.parse_task_name( name )
##
##        if not name: name = 'UNKOWN'
##
##        sele_str = selection( obj, res_ids )
##        cmd.color( "atomic",
##                   "((" + sele_str + ") and not elem C)",
##                   quiet=1 )
##        cmd.set_color( name, rosetta_io.task_color_rgb[ name ] )
##        cmd.color( name, "(elem C and (" + sele_str + "))", quiet=1 )
##

#TODO mo make better
def default_display(obj):
    """Apply the default display settings to obj.

    Currently a no-op. The intended settings are kept below as disabled
    reference code.
    """
    # cmd.hide("everything", obj.name)
    # cmd.show("lines", obj.name)
    # cmd.show("cartoon", obj.name)
    # cmd.set("cartoon_fancy_helices", 1, obj.name)
    # cmd.set("cartoon_highlight_color", "grey50", obj.name)
    # cmd.set("cartoon_side_chain_helper", 1, obj.name)
    # cmd.util.performance(0)  # maximum quality
    # cmd.rebuild()
    # cmd.hide("(hydro and " + obj.name + ")")
    return None

def color_by_data_simple(obj, data, palette="blue_red"):
    """Color obj's residues by per-residue values using a PyMOL palette.

    obj     -- object with ``name`` (PyMOL object name) and ``res_ids``
    data    -- dict mapping (chain, int resi) -> value; every CA atom's
               residue is looked up in it
    palette -- name of a PyMOL spectrum palette

    Each value is written into the B-factor of the residue's CA atom.
    cmd.spectrum then colors by that B-factor, per residue. A warning is
    printed if data and obj.res_ids differ in length.

    The palette can be any of the following strings:

    blue_green            green_white_magenta   red_cyan
    blue_magenta          green_white_red       red_green
    blue_red              green_white_yellow    red_white_blue
    blue_white_green      green_yellow          red_white_cyan
    blue_white_magenta    green_yellow_red      red_white_green
    blue_white_red        magenta_blue          red_white_yellow
    blue_white_yellow     magenta_cyan          red_yellow
    blue_yellow           magenta_green         red_yellow_green
    cbmr                  magenta_white_blue    rmbc
    cyan_magenta          magenta_white_cyan    yellow_blue
    cyan_red              magenta_white_green   yellow_cyan
    cyan_white_magenta    magenta_white_yellow  yellow_cyan_white
    cyan_white_red        magenta_yellow        yellow_green
    cyan_white_yellow     rainbow               yellow_magenta
    cyan_yellow           rainbow2              yellow_red
    gcbmry                rainbow2_rev          yellow_white_blue
    green_blue            rainbow_cycle         yellow_white_green
    green_magenta         rainbow_cycle_rev     yellow_white_magenta
    green_red             rainbow_rev           yellow_white_red
    green_white_blue      red_blue              yrmbcg
    """
    ca_atoms = CAs(obj)
    if len(data) != len(obj.res_ids):
        print("Rosetta Design Wizard: Warning: Trying to color by data but number of residues is inconsistent.")

    # The alter expression refers to this dict as "stored.data".
    pymol.stored.data = data
    cmd.alter( ca_atoms, 'b = stored.data[ (chain, int( resi ) ) ]' )
    cmd.spectrum(expression='b', palette=palette, selection=ca_atoms,
                 byres=1, quiet=1)


"""
AUTHOR

	Robert L. Campbell, (adapted by Matthew O'Meara)

USAGE

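	color_by_data(obj, res_ids, data, custom_bins=None,
	              mode='hist' or 'ramp',
	              gradient='bgr' or 'rgb' or 'bwr' or 'rwb'
	                  or 'bmr' or 'rmb' or 'gray' or 'reversegray',
	              nbins=11, sat=1.0, value=1.0)

    obj is the object to color. res_ids and data are parallel
    sequences of (chain, resi) residue ids and their values.
    custom_bins optionally gives ascending upper bin edges for
    'hist' mode. The default gradient is 'rgb'.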

    This function allows coloring of a collection of residues with
    given data, following a gradient of colours.  The gradients can
    be:
	'bgr': blue -> green   -> red
	'rgb': red  -> green   -> blue
	'bwr': blue -> white   -> red
	'rwb': red  -> white   -> blue
	'bmr': blue -> magenta -> red
	'rmb': red  -> magenta -> blue

	('rainbow' and 'reverserainbow' can be used as
	synonyms for 'bgr' and 'rgb' respectively and 'grey'
	can be used as a synonym for 'gray').

    The data range can be divided in either of two modes: 'hist'
    or 'ramp'. 'hist' is like a histogram (equal-sized increments,
    leading to unequal numbers of residues in each bin). 'ramp' is a
    ramp of data ranges, with the ranges chosen to provide an equal
    number of residues in each group.

    You can also specify the saturation and value (i.e. the "s" and "v"
    in the "HSV" color scheme) to be used for the gradient.
    The defaults are 1.0 for both "sat" and "value".

    In the case of the gray scale gradients, "sat" sets
    the minimum intensity (normally black) and "value"
    sets the maximum (normally white).

"""
# function for creating the gradient
def make_gradient(sel,gradient,nbins,sat,value):
    """Define PyMOL colors 'col0' .. 'col<nbins-1>' along a gradient.

    sel        -- unused; kept so existing callers keep working
    gradient   -- one of 'bgr'/'rainbow', 'rgb'/'reverserainbow', 'bmr',
                  'rmb', 'bwr', 'rwb', 'gray'/'grey', 'reversegray'/'reversegrey'
    nbins      -- number of colors to define
    sat, value -- HSV saturation and value applied to the gradient. For
                  the gray gradients they are instead the intensities of
                  the first and last color ('gray') or of the last and
                  first color ('reversegray').

    Returns the list of color names. For an unrecognised gradient the
    names are still returned, but no colors are defined.
    """
    coldesc = ['col' + str(j) for j in range(nbins)]
    rgbs = _gradient_rgbs(gradient, nbins, sat, value)
    if rgbs is not None:
        for color_name, rgb in zip(coldesc, rgbs):
            cmd.set_color(color_name, rgb)
    return coldesc


def _gradient_rgbs(gradient, nbins, sat, value):
    """Return nbins [r, g, b] lists for the named gradient, or None if the
    gradient name is not recognised."""
    # Denominator for the per-bin step. It is guarded so that a single bin
    # does not divide by zero; that lone bin gets the gradient's first color.
    span = float(max(nbins - 1, 1))
    # Integer midpoint; '//' keeps this an int on Python 3 as well.
    half = nbins // 2
    steps = range(nbins)

    def rise(j):
        # 0.0 at the first bin, reaching 1.0 at the midpoint, then clipped.
        return min(1.0, j * 2 / span)

    def fall(j):
        # 1.0 (clipped) up to the midpoint, reaching 0.0 at the last bin.
        return min(1.0, (nbins - j - 1) * 2 / span)

    if gradient in ('bgr', 'rainbow'):
        # HSV hue from blue (2/3) down to red (0), so the last bin is red.
        return [list(colorsys.hsv_to_rgb(
                    colorsys.TWO_THIRD - colorsys.TWO_THIRD * j / span,
                    sat, value))
                for j in steps]
    if gradient in ('rgb', 'reverserainbow'):
        # HSV hue from red (0) up to blue (2/3), so the last bin is blue.
        return [list(colorsys.hsv_to_rgb(colorsys.TWO_THIRD * j / span,
                                         sat, value))
                for j in steps]
    if gradient == 'bmr':
        # blue -> magenta -> red
        return [_hsv_scaled_rgb([rise(j), 0.0, fall(j)], sat, value)
                for j in steps]
    if gradient == 'rmb':
        # red -> magenta -> blue
        return [_hsv_scaled_rgb([fall(j), 0.0, rise(j)], sat, value)
                for j in steps]
    if gradient == 'bwr':
        # blue -> white over the first half, white -> red over the rest.
        return [_hsv_scaled_rgb([rise(j), rise(j), fall(j)] if j < half
                                else [rise(j), fall(j), fall(j)],
                                sat, value)
                for j in steps]
    if gradient == 'rwb':
        # red -> white over the first half, white -> blue over the rest.
        return [_hsv_scaled_rgb([fall(j), rise(j), rise(j)] if j < half
                                else [fall(j), fall(j), rise(j)],
                                sat, value)
                for j in steps]
    if gradient in ('gray', 'grey'):
        # Grays from intensity 'sat' to intensity 'value'.
        return [list(colorsys.hsv_to_rgb(0, 0, sat + (value - sat) * j / span))
                for j in steps]
    if gradient in ('reversegray', 'reversegrey'):
        # Grays from intensity 'value' to intensity 'sat'.
        return [list(colorsys.hsv_to_rgb(0, 0, value - (value - sat) * j / span))
                for j in steps]
    return None


def _hsv_scaled_rgb(rgb, sat, value):
    """Return rgb as a list, with its HSV saturation multiplied by sat and
    its HSV value multiplied by value."""
    hue, s, v = colorsys.rgb_to_hsv(*rgb)
    return list(colorsys.hsv_to_rgb(hue, s * sat, v * value))

def color_by_data(obj, res_ids, data, custom_bins = None,
                  mode="hist", gradient="rgb",
                  nbins=11, sat=1., value=1.):
    """Color residues of obj by binning their data values onto a gradient.

    obj         -- object whose ``name`` is the PyMOL object to color
    res_ids     -- sequence of (chain, resi) residue ids
    data        -- values parallel to res_ids
    custom_bins -- optional ascending upper bin edges ('hist' mode only).
                   Values above the last edge are clipped to it.
    mode        -- 'hist' (equal-width bins) or 'ramp' (equal-count bins)
    gradient    -- gradient name understood by make_gradient
    nbins       -- number of bins when custom_bins is not used (values
                   <= 1 are reset to 11)
    sat, value  -- HSV saturation/value passed to make_gradient

    Prints a warning and returns without coloring for an unknown mode,
    an unknown gradient, or no residues.

    Raises:
        ValueError: custom_bins is not in ascending order.
    """
    nbins = int(nbins)
    sat = float(sat)
    value = float(value)

    # str.lower() returns a new string, so the result has to be kept.
    gradient = gradient.lower()
    mode = mode.lower()

    # Sanity checking
    if len(res_ids) != len(data):
        print("\n     WARNING: Inequal number of x-values and y-values\n")

    if nbins <= 1:
        print("\n     WARNING: You specified nbins<=1, which doesn't make sense...resetting nbins=11\n")
        nbins = 11

    if mode not in ('hist', 'ramp'):
        print("\n     WARNING: Unknown mode %s    ----->   Nothing done.\n" % mode)
        return

    # Checked independently of the gradient, so that passing custom_bins
    # can never bypass gradient validation.
    if custom_bins and mode != 'hist':
        print("\n     WARNING: Custom bins can only be used with histogram binning")

    if gradient not in ('bgr', 'rgb', 'rainbow', 'reverserainbow', 'bwr',
                        'rwb', 'bmr', 'rmb', 'gray', 'grey',
                        'reversegray', 'reversegrey'):
        print("\n     WARNING: Unknown gradient: %s    ----->   Nothing done.\n" % gradient)
        return

    if len(res_ids) == 0:
        print("\n     WARNING: No residues specified\n")
        return

    # Sort the residues by their data value (stable sort; zip stops at the
    # shorter of the two inputs).
    pairs = sorted(zip(res_ids, data), key=lambda pair: pair[1])
    if not pairs:
        # res_ids were given but no data; the count warning was printed.
        return
    res_ids = [res_id for res_id, _ in pairs]
    data = [datum for _, datum in pairs]

    if mode == 'ramp':
        binned = _ramp_bins(res_ids, nbins)
    else:
        if custom_bins:
            bins = list(custom_bins)
            if bins != sorted(bins):
                raise ValueError("custom_bins must be in ascending order")
            if data[-1] > bins[-1]:
                print("\n     WARNING: data_outside max custom bin. Clipping data to %s" % bins[-1])
                data = [min(datum, bins[-1]) for datum in data]
        else:
            # Equal-width bins identified by their UPPER edges, so the
            # maximum value is included. float() avoids integer division
            # on integer data.
            min_bound, max_bound = data[0], data[-1]
            bin_width = (max_bound - min_bound) / float(nbins)
            bins = [min_bound + (i + 1) * bin_width for i in range(nbins)]
            # Pin the last edge to the maximum so rounding cannot leave the
            # largest values uncolored.
            bins[-1] = max_bound
        binned = _hist_bins(res_ids, data, bins)

    # note as of PyMol 1.1, PyMol's selection algebra is SLOW.  Its
    # worth one's while to compress a collection of residues into a
    # minimal number of ranges before making a selection.  This is
    # done in selection(*args)
    sel = [selection(obj, members) if members else '' for members in binned]

    # make_gradient spreads its colors over (nbins - 1) steps, so always
    # request at least two. Only the first len(sel) colors are used.
    colors = make_gradient(sel, gradient, max(len(sel), 2), sat, value)
    for color, sele_str in zip(colors, sel):
        if sele_str:
            cmd.color(color, sele_str)


def _ramp_bins(res_ids, nbins):
    """Split res_ids (already sorted by value) into nbins near-equal groups.

    When nbins does not divide len(res_ids) evenly, each of the first
    len(res_ids) % nbins groups gets one extra residue. Returns a list of
    lists.
    """
    base, remainder = divmod(len(res_ids), nbins)
    binned = []
    start = 0
    for j in range(nbins):
        size = base + 1 if j < remainder else base
        binned.append(list(res_ids[start:start + size]))
        start += size
    return binned


def _hist_bins(res_ids, data, upper_edges):
    """Group res_ids into one list per entry of ascending upper_edges.

    res_ids and data are parallel and sorted by data. Each residue goes
    into the first bin whose upper edge is >= its value; values above the
    last edge are left out.
    """
    binned = []
    i = 0
    for edge in upper_edges:
        members = []
        # Check the bound first: once every residue is placed, data[i]
        # would raise IndexError.
        while i < len(data) and data[i] <= edge:
            members.append(res_ids[i])
            i += 1
        binned.append(members)
    return binned
