## Build hierarchical-clustering trees from pairwise distances, and draw them

## this is a good example of bad programming
## if I wrote it again it would be a lot nicer

import string
from os import popen,system,getcwd
import sys
from whrandom import random
from math import floor
from popen2 import popen2
from operator import add
from os.path import exists


def IsALeaf(node):
    """Return True when *node* is a leaf.

    Leaves are tuples (i, i, 0.0, score), built by Make_tree. A node is
    therefore a leaf exactly when its first two fields compare equal.
    """
    first, second = node[0], node[1]
    return first == second


def Average_score (leaf_list, leaf_scores, percentile):
    """Return the *percentile*-th percentile of the pooled scores of leaf_list.

    Despite the name this is a percentile, not a mean.

    leaf_list   -- leaf indices whose score lists are pooled
    leaf_scores -- mapping leaf index -> sequence of scores
    percentile  -- 0..100; int or float. Out-of-range values are clamped to
                   the smallest / largest score.
    Raises IndexError if the pooled score list is empty.
    """
    ls = []
    for leaf in leaf_list:
        ls.extend(leaf_scores[leaf])  ## extend: avoids quadratic list re-copying
    if not ls:
        raise IndexError('Average_score: no scores for leaves %s' % (leaf_list,))
    ls.sort()

    ## floor(percentile*len/100). Explicit int() + floor division keeps this
    ## an integer index for float percentiles and under true division.
    pos = int(percentile * len(ls)) // 100
    pos = min(max(pos, 0), len(ls) - 1)  ## percentile 100 maps to the maximum

    return ls[pos]


def Make_tree(distance,num_leaves,Update_distance_matrix,leaf_scores,percentile):
    """Agglomeratively join *num_leaves* leaves into one binary tree; return the root.

    distance   -- dict keyed by (i, j) leaf-index pairs, 0 <= i, j < num_leaves.
                  MUTATED: entries keyed by node tuples are added here and by
                  Update_distance_matrix as nodes are created.
    num_leaves -- number of leaves (>= 1)
    Update_distance_matrix -- callback(new_node, nodes, distance) that stores
                  the distances between new_node and the current nodes
                  (e.g. Update_distance_matrix_AL / Update_distance_matrix_SL)
    leaf_scores, percentile -- passed to Average_score to score each node

    Nodes are tuples (child0, child1, join_distance, score). A leaf is
    (i, i, 0.0, score). Raises ValueError if num_leaves < 1.
    """
    if num_leaves < 1:
        raise ValueError('Make_tree: need at least one leaf')
    N = num_leaves

    nodes = []
    for i in range(N):
        nodes.append( (i,i,0.0,Average_score([i],leaf_scores,percentile)) )

    ## also key the leaf-to-leaf distances by the leaf node tuples
    for i in range(N):
        for j in range(N):
            distance[(nodes[i],nodes[j])] = distance[(i,j)]

    while N>1:

        ## find the two closest nodes; on ties the first pair scanned wins.
        ## None (not a magic large number) marks "nothing seen yet", so any
        ## distance magnitude works.
        min_d = None
        for i in nodes:
            for j in nodes:
                if i<=j:continue  ## visit each unordered pair once
                if min_d is None or distance[(i,j)] < min_d:
                    min_d = distance[(i,j)]
                    n1 = i
                    n2 = j

        new_node = (n1,n2,min_d,Average_score( Node_members(n1)+Node_members(n2),
                                               leaf_scores,percentile))

        ## distances from new_node to the others (n1, n2 still in `nodes`)
        Update_distance_matrix (new_node,nodes,distance)

        ## replace the two joined nodes by their parent
        nodes.append(new_node)
        del nodes[ nodes.index(n1)]
        del nodes[ nodes.index(n2)]

        N = N-1

    return nodes[0]

def Show_tree(tree,names):
    """Return *tree* as a Newick string (no trailing ';').

    A node's height is half its join distance (tree[2]/2). The branch to each
    child is the height difference, (tree[2] - child[2])/2. Every root-to-leaf
    path therefore has length tree[2]/2. Leaves are built with join distance
    0.0, so leaf branches are exactly tree[2]/2.

    names -- leaf labels indexed by leaf number
    """
    if IsALeaf(tree):
        return names [tree[0]]
    height = float(tree[2])
    parts = []
    for child in (tree[0], tree[1]):
        ## was height/2 for every child, which overstates internal branches
        branch = (height - float(child[2])) / 2
        parts.append(Show_tree(child, names) + ':' + str(branch))
    return '(' + ','.join(parts) + ')'

def Show_small(tree):
    """Return a compact '(left,right)' string of *tree*'s leaf indices.

    Leaves are rendered as repr() of their index. Branch lengths are omitted.
    """
    if IsALeaf(tree):
        ## repr() instead of the deprecated backtick syntax (same output)
        return repr(tree[0])
    return '(' + Show_small(tree[0]) + ',' + Show_small(tree[1]) + ')'

def Node_members(node):
    """Return the leaf indices under *node*.

    At every internal node, the child subtree holding the smaller minimum
    leaf index is listed first.
    """
    if IsALeaf(node):
        return [node[0]]
    left = Node_members(node[0])
    right = Node_members(node[1])
    if min(right) <= min(left):
        return right + left
    return left + right


def Update_distance_matrix_AL(new_node,old_nodes,distances): ## average linkage
    """Store average-linkage distances between *new_node* and *old_nodes*.

    The distance from new_node to a node n is the mean of distances[(i, j)]
    over all leaf pairs i under new_node and j under n. Entries are written
    under both (n, new_node) and (new_node, n). (new_node, new_node) is set
    to 0.0. The two children of new_node are skipped.

    distances must contain the leaf-index pair keys (i, j). It is mutated in
    place; nothing is returned.
    """
    n1 = new_node[0]
    n2 = new_node[1]

    l1 = Node_members(new_node)
    distances [ (new_node,new_node)] = 0.0

    for n in old_nodes:
        if n==n1 or n==n2:continue  ## the joined children themselves
        l2 = Node_members(n)

        total = 0.0
        for i in l1:
            for j in l2:
                assert i!=j  ## the two nodes must not share leaves
                total = total + distances[(i,j)]

        avg = total/(len(l1)*len(l2))
        distances[(n,new_node)] = avg
        distances[(new_node,n)] = avg

    return

def Update_distance_matrix_SL(new_node,old_nodes,distances): ## single linkage
    """Store single-linkage distances between *new_node* and *old_nodes*.

    The distance from new_node to a node n is the minimum distances[(i, j)]
    over leaves i under new_node and j under n. Entries are written under both
    (n, new_node) and (new_node, n). (new_node, new_node) is set to 0.0. The
    two children of new_node are skipped.

    distances must contain the leaf-index pair keys (i, j). It is mutated in
    place; nothing is returned.
    """
    n1 = new_node[0]
    n2 = new_node[1]

    l1 = Node_members(new_node)
    distances [ (new_node,new_node)] = 0.0

    for n in old_nodes:
        if n==n1 or n==n2:continue  ## the joined children themselves
        l2 = Node_members(n)

        ## None rather than a cap like 1000: distances of any size are kept
        min_d = None
        for i in l1:
            for j in l2:
                assert i!=j  ## the two nodes must not share leaves
                d = distances[(i,j)]
                if min_d is None or d < min_d:
                    min_d = d

        distances[(n,new_node)] = min_d
        distances[(new_node,n)] = min_d

    return

def Center(tree,node_position):
    """Return the mean of node_position over *tree*'s leaves, as a float."""
    members = Node_members(tree)
    ## start at 0.0 so integer positions still yield a float mean
    return sum([node_position[m] for m in members], 0.0) / len(members)

def Size(tree,sizes):
    """Return the sum of sizes[i] over the leaves i of *tree*.

    Summed recursively, left subtree first.
    """
    if not IsALeaf(tree):
        return Size(tree[0], sizes) + Size(tree[1], sizes)
    return sizes[tree[0]]

def _Child_edges(child, join_d, mid, node_position, sizes):
    ## the two edges tying one child to a parent joined at join_d whose
    ## vertical centre is mid: [vertical, horizontal]
    y = Center(child, node_position)
    if IsALeaf(child):
        tag = child[0]   ## leaf: its index is the cluster id
    else:
        tag = -1         ## internal node: not a single cluster
    vertical = [ [join_d, y], [join_d, mid], child[3], 1, tag]
    horizontal = [ [join_d, y], [child[2], y], child[3], Size(child, sizes), tag]
    return [vertical, horizontal]

def Fig_tree(tree,node_position,sizes): ## edge = [ [x0,y0], [x1,y1], score, size, cluster]
    """Return the dendrogram edges for *tree*, pre-order.

    For each internal node this gives, per child (tree[0] then tree[1]):
      vertical   [[d, y_child], [d, y_node]], child score, size 1, cluster
      horizontal [[d, y_child], [d_child, y_child]], child score,
                 Size(child), cluster
    Here d is the node's join distance and y the Center of the leaf
    positions. cluster is the leaf index for leaf children, else -1. The
    edges of the two subtrees follow. A leaf yields [].
    """
    if IsALeaf(tree):
        return []

    join_d = tree[2]
    mid = Center(tree, node_position)

    own = _Child_edges(tree[0], join_d, mid, node_position, sizes) + \
          _Child_edges(tree[1], join_d, mid, node_position, sizes)
    return own + \
           Fig_tree(tree[0], node_position, sizes) + \
           Fig_tree(tree[1], node_position, sizes)
    
def Node_labels(tree,sizes,node_position):
    """Return [[label, [join_distance, y]], ...] for every internal node of *tree*.

    label is repr() of the summed sizes of the node's leaves. y is the node's
    Center. Leaves produce no labels. Order: the node, then tree[0]'s subtree,
    then tree[1]'s.
    """
    if IsALeaf(tree):return []
    pos = [tree[2],Center(tree,node_position)]
    size = 0
    for leaf in Node_members(tree):
        size = size+sizes[leaf]
    ## repr() replaces the deprecated backtick syntax (identical output)
    return [ [ repr(size), pos] ] + \
           Node_labels(tree[0],sizes,node_position) + \
           Node_labels(tree[1],sizes,node_position)



def Canvas_tree(tree, names, sizes, plotter, plot_width, plot_height, selected_cluster):
    """Draw *tree* as a dendrogram through *plotter*.

    tree             -- root node as built by Make_tree
    names            -- leaf labels indexed by leaf number
    sizes            -- cluster sizes indexed by leaf number; branch line
                        widths are proportional to the sizes they carry
    plot_width, plot_height -- drawing area in pixels
    selected_cluster -- tag string 'clusterNN.SSS'; leaf edges with an
                        equal tag are drawn with selected=1

    plotter methods used:
      .make_line([x0,y0], [x1,y1], line_width, normalized_score, extra_tag, selected)
      .make_text(text, [x,y], font)             -- scale, legend, size labels
      .make_text(text, [x,y], font, extra_tag)  -- leaf names
    """
    branch_width_pixels = min(100,plot_height/5)

    ## allocate widths for branches; widths measure cluster sizes
    w_factor = float( branch_width_pixels) / sum(sizes)

    def Branch_width(size, w_factor=w_factor):
        ## line width in pixels (at least 1) for a branch carrying `size`
        return max(1,int(floor(0.5+ size*w_factor)))

    total = 0
    for s in sizes:
        total = total+Branch_width(s)
    remainder = plot_height - total
    cluster_width = float(remainder)/len(names) ## padding allotted to each cluster

    ## position leaves vertically from the top (plot_height) down, in
    ## Node_members order. Each leaf consumes the shared padding plus its
    ## OWN branch width. (This previously used the stale loop variable `s`,
    ## i.e. the last entry of sizes, for every leaf.)
    nodes = Node_members(tree)
    node_position = {}
    mark = plot_height
    for leaf in nodes:
        node_position[leaf] = mark
        mark = mark - cluster_width - Branch_width(sizes[leaf])

    edges = Fig_tree(tree,node_position,sizes) ## each edge = [[x0,y0],[x1,y1],score,size,cluster]


    ## font size grows with the per-cluster padding, clamped to [5, 18]
    ## NOTE(review): original author asked "is this still right??" -- unverified
    font = min(18, max (5, int(floor( 0.5 + (cluster_width+7.5)/10))))

    ## rescale the x-positions: [min_rmsd, max_rmsd] -> [0, plot_width]
    max_rmsd = tree[2]
    min_rmsd = tree[2]
    for e in edges:
        if e[0][0]>0: min_rmsd = min(min_rmsd,e[0][0])
        if e[1][0]>0: min_rmsd = min(min_rmsd,e[1][0])
    min_rmsd = max(0,min_rmsd-0.5)

    rmsd_range = max_rmsd - min_rmsd
    if rmsd_range <= 0:
        ## degenerate tree (every join at distance 0): avoid division by zero
        rmsd_range = 1.0

    def Transform(rmsd,min_rmsd = min_rmsd, rmsd_range = rmsd_range,plot_width = plot_width):
        return int (floor ( 0.5 + plot_width * (rmsd - min_rmsd) / rmsd_range))


    ## rescale colors
    scores = []
    for e in edges:
        scores.append(e[2])
    min_score = min(scores)
    max_score = max(scores)
    if max_score == min_score:
        max_score = max_score + 1

    ## write the edges
    for e in edges:
        start = [ Transform (max(e[0][0],min_rmsd)), e[0][1]] ## rescale x-position
        stop = [ Transform (max(e[1][0],min_rmsd)), e[1][1]]

        normalized_score = float( e[2] - min_score)/(max_score-min_score)
        line_width = Branch_width(e[3])
        selected = 0
        if e[4]>=0: ## it's a real cluster edge
            cluster = e[4]
            extra_tag = 'cluster%02d.%03d'%(cluster,sizes[cluster])
            if extra_tag == selected_cluster:
                selected = 1
        else:
            extra_tag = 'dummy'

        plotter.make_line(start,stop,line_width,normalized_score,extra_tag,selected)


    ## show scale
    plotter.make_line([Transform(min_rmsd),5], [Transform(max_rmsd),5],3,1.0,'dummy',0)

    for i in range(int(floor(min_rmsd+1)),1+int(floor(tree[2]))):
        plotter.make_text( str(i), [Transform(i),0], 18)


    plotter.make_text( 'Colors: from blue (%7.2f) to red (%7.2f)'%(min_score,max_score),
                       [0,25],10)


    ## label leaves
    for i in range(len(names)):
        extra_tag = 'cluster%02d.%03d'%(i,sizes[i])
        plotter.make_text(names[i],
                          [Transform(min_rmsd),node_position[i]],
                          font,extra_tag)

    ## label internal vertices with sizes
    for l in Node_labels (tree,sizes,node_position):
        plotter.make_text(l[0], [Transform(l[1][0]),l[1][1]], font)

    return



 
