#!/usr/bin/python

## Build and plot average- and single-linkage cluster trees from the
## cluster distances stored in <prefix>.info

import string
from os import popen,system
from os.path import exists
import sys
from whrandom import random
from math import floor
from popen2 import popen2
import score_trees_devel
from sys import stderr


def Print_help_message():
    """Print the command-line usage text for this script to stdout.

    Covers every option the script parses (-gs, -cartoons, -m, -n), the
    three positional arguments, and the names of the two PostScript tree
    files that are written.
    """
    rule = '-'*75
    ## Single-argument print(...) emits exactly what the print statement did.
    print('\n'+rule)
    print('Usage: %s {-gs} {-cartoons <N>} {-m <M>} {-n <S>} <prefix> <score-index> <P>'%sys.argv[0])
    print('-gs gives grey-scale colors, default is rainbow')
    print('-cartoons gives little cartoon plots for N of the clusters (thanks jens meiler)')
    print('-m skips clusters with fewer than M members (default 0: keep all)')
    print('-n adds the native structure as an extra leaf with score S')
    print('\nLoads cluster distances from <prefix>.info')
    print('Loads decoy scores from the <score-index> column of the ')
    print('  DECOY_SCORE lines in <prefix>.info\n')
    print('Score assigned to a cluster is the Pth percentile score. (0<=P<=100)')
    ## The output names carry the score index and percentile (both
    ## zero-padded), plus the cartoon count when -cartoons is given.
    print('Makes two trees:\n 1) average linkage: "<prefix>.average_linkage_cluster_tree.I<score-index>.P<P>.ps"')
    print(' 2) single linkage: "<prefix>.single_linkage_cluster_tree.I<score-index>.P<P>.ps"')
    print('  (<score-index> and <P> are zero-padded; with -cartoons the names end in ".C<N>.ps")\n'+rule+'\n')



## ----------------------------------------------------------------------
## Command-line parsing.
##
## Optional flags are stripped out of `args` first; whatever remains must
## be the three positional arguments <prefix> <score-index> <P>.  The
## positional count is therefore checked AFTER the flags are removed, so
## that e.g. "-gs prefix 1" is reported as a usage error instead of
## failing later with an IndexError.

def _Usage_error(message):
    ## Report a command-line problem on stderr, show the usage text and
    ## stop with a non-zero exit status (explicit exit rather than an
    ## assert, which would be skipped under "python -O").
    sys.stderr.write('ERROR: %s\n'%message)
    Print_help_message()
    sys.exit(1)


def _Take_option(flag):
    ## Remove "<flag> <value>" from args and return the value string;
    ## return None (args untouched) when the flag is absent.
    if flag not in args:
        return None
    pos = args.index(flag)
    if pos+1 >= len(args):
        _Usage_error('option %s requires a value'%flag)
    value = args[pos+1]
    del args[pos:pos+2]
    return value


args = sys.argv[1:]

## -gs: grey-scale colors instead of the default rainbow
GREY_SCALE = 0
if '-gs' in args:
    GREY_SCALE = 1
    args.remove('-gs')

## -cartoons <N>: number of clusters that get a cartoon plot
num_cartoons = 0
_opt_value = _Take_option('-cartoons')
if _opt_value is not None:
    num_cartoons = int(_opt_value)

## -m <M>: clusters smaller than M are not loaded from the .info file
min_cluster_size = 0
_opt_value = _Take_option('-m')
if _opt_value is not None:
    min_cluster_size = int(_opt_value)

## -n <S>: include the native as an extra leaf with score S
## (native_score is only defined when -n was given)
USE_NATIVE = 0
_opt_value = _Take_option('-n')
if _opt_value is not None:
    USE_NATIVE = 1
    native_score = float(_opt_value)

if len(args) < 3:
    _Usage_error('expected <prefix> <score-index> <P>, got %d positional argument(s)'
                 %len(args))

prefix = args[0]
score_index = int(args[1])
P = int(args[2])

if not exists(prefix+'.info'):
    ## parenthesised so the whole file name is formatted into the message
    ## ('%' binds tighter than '+')
    sys.stderr.write('WARNING: couldnt open info_file: %s\n'%(prefix+'.info'))
    Print_help_message()
    sys.exit(1)

if 1:

    ## ------------------------------------------------------------------
    ## Parse the .info file.
    ##
    ## Three kinds of lines are used (fields are whitespace separated):
    ##   CLUSTER_RMSDS  field 1 = cluster index, field 2 = cluster size,
    ##                  field 4 = value stored in rmsd_to_native,
    ##                  fields 9.. = distances to clusters 0..index
    ##   CLUSTER_INFO:  field 1 = cluster index, fields 9.. = comma
    ##                  separated entries whose 2nd item is the decoy id
    ##   DECOY_SCORE:   last field = decoy id, field <score_index> = score
    ## NOTE(review): 'CLUSTER_RMSDS' is matched without a trailing colon,
    ## unlike the other two tags -- confirm against the .info writer.
    info_file = prefix+'.info'

    sizes = {}            ## cluster index -> cluster size
    cluster_members = {}  ## cluster index -> list of decoy ids
    decoy_score = {}      ## decoy id -> score from the chosen column
    distance = {}         ## (i,j) -> distance, stored for both orders
    rmsd_to_native = {}   ## cluster index -> field 4 of CLUSTER_RMSDS

    data = open(info_file,'r')
    sys.stderr.write('parsing .info file: %s\n'%info_file)
    try:
        for raw_line in data:
            line = raw_line.split()
            if not line:
                ## A blank line is skipped; it must not be mistaken for
                ## end-of-file and cut the parse short.
                continue

            tag = line[0]
            if tag == 'CLUSTER_RMSDS' and int(line[2]) >= min_cluster_size:
                cluster = int(line[1])
                rmsd_to_native[cluster] = float(line[4])
                ## Clusters must arrive as 0,1,2,... and each line must
                ## carry exactly cluster+1 distances after 9 lead fields.
                if cluster != len(sizes) or len(line) != 10+cluster:
                    raise ValueError('%s: unexpected CLUSTER_RMSDS line for cluster %d '
                                     '(out of order or wrong number of fields)'
                                     %(info_file,cluster))
                sizes[cluster] = int(line[2])
                for i in range(cluster+1):
                    d = float(line[9+i])
                    distance[(i,cluster)] = d
                    distance[(cluster,i)] = d
            elif tag == 'CLUSTER_INFO:':
                cluster = int(line[1])
                cluster_members[cluster] = [entry.split(',')[1] for entry in line[9:]]
            elif tag == 'DECOY_SCORE:':
                decoy_score[line[-1]] = float(line[score_index])
    finally:
        ## close the file even if a malformed line raised
        data.close()

    N = len(sizes)
    print('N= %d'%N)

    ## ------------------------------------------------------------------
    ## Optionally append the native as one extra leaf of size 1 whose
    ## distance to cluster i is rmsd_to_native[i].
    if USE_NATIVE:
        native_index = N
        N = N+1
        sizes[native_index] = 1
        distance[(native_index,native_index)] = 0.0
        for i in range(native_index):
            distance[(i,native_index)] = rmsd_to_native[i]
            distance[(native_index,i)] = rmsd_to_native[i]

    ## ------------------------------------------------------------------
    ## Per-leaf score lists and display names ("<index>_<size>").
    cluster_scores = {}
    names = {}
    for cluster in range(N):
        if USE_NATIVE and cluster == native_index:
            cluster_scores[cluster] = [ native_score ]
            names[cluster] = score_trees_devel.NATIVE_NAME
            continue

        members = cluster_members[cluster]
        if sizes[cluster] != len(members):
            raise ValueError('%s: cluster %d has size %d but %d members listed'
                             %(info_file,cluster,sizes[cluster],len(members)))
        cluster_scores[cluster] = [decoy_score[decoy] for decoy in members]
        names[cluster] = repr(cluster)+'_'+repr(sizes[cluster])

    ## ------------------------------------------------------------------
    ## Build and plot one tree per linkage rule.
    ## def Make_tree(distance,num_leaves,Update_distance_matrix,leaf_scores,percentile):
    if num_cartoons:
        ps_suffix = '.I%02d.P%03d.C%d.ps'%(score_index,P,num_cartoons)
    else:
        ps_suffix = '.I%02d.P%03d.ps'%(score_index,P)

    linkages = [ ('average_linkage', score_trees_devel.Update_distance_matrix_AL),
                 ('single_linkage',  score_trees_devel.Update_distance_matrix_SL) ]

    for linkage,Update_distance_matrix in linkages:
        ps_file = prefix+'.'+linkage+'_cluster_tree'+ps_suffix

        score_tree = score_trees_devel.Make_tree(distance, N,
                                                 Update_distance_matrix,
                                                 cluster_scores,P)

        stderr.write('making %s\n'%ps_file)
        score_trees_devel.Plot_tree( score_tree, names, sizes, ps_file, GREY_SCALE,
                                     num_cartoons, prefix)

