#!/usr/bin/python

## trees from distances

from phil import *
import score_trees_devel


def Help():
    print '\nUsage: %s <codes-file> <chisq-file> <p-val cluster threshold>\n\n'%(argv[0])
    exit()
    
## three arguments are required after the script name; otherwise print the
## usage message (Help() also terminates the script)
if not len(argv) >= 4:
    Help()

## command-line arguments:
##   codes_file -- file holding the CODE: / CODE_TORSIONS: / NAT_TORSIONS: lines
##   chisq_file -- file holding the CHISQ lines (read chisq file --> distances)
##   THRESHOLD  -- cutoff handed to Get_clusters
codes_file, chisq_file = argv[1:3]
THRESHOLD = float(argv[3])


## read the CHISQ lines of the chisq file --> distances
##
## Only these whitespace-separated fields of a "CHISQ " line are used:
##   fields 1,2 -- integer indices of the two items being compared
##   fields 3,4 -- the names of those two items
##   field  9   -- p-value (parsed as float), used directly as the distance
##
## The file is opened directly instead of going through a shell "grep" pipe,
## so paths containing spaces or shell metacharacters work, the file handle
## is always closed, and a missing file raises IOError rather than silently
## producing an empty distance matrix.
chisq_in = open(chisq_file)
try:
    lines = [ raw.split() for raw in chisq_in.readlines()
              if raw.startswith('CHISQ ') ]
finally:
    chisq_in.close()

distance = {}   ## (i,j) -> distance, stored symmetrically, 0.0 on the diagonal
scores = {}     ## index -> [1.0]; every item gets the same one-element score
names = {}      ## index -> name
N = 0           ## largest index seen, plus one
for line in lines:
    i1 = int(line[1])
    i2 = int(line[2])
    scores[i1] = [1.0]
    scores[i2] = [1.0]
    names[i1] = line[3]
    names[i2] = line[4]
    N = max(N, i1+1, i2+1)
    pval = float(line[9])
    distance[(i1,i2)] = pval
    distance[(i2,i1)] = pval
    distance[(i1,i1)] = 0.0
    distance[(i2,i2)] = 0.0

## the number of distinct indices must equal the largest index + 1, i.e. no
## index may be missing (names is later looked up for every i in range(N))
assert len(scores.keys()) == N, \
       'CHISQ indices have gaps: %d distinct indices, largest index + 1 = %d'\
       %(len(scores.keys()), N)

## call make_tree
##
## The third argument of Make_tree is the distance-matrix update function.
## The ones that have been tried here:
##     score_trees_devel.Update_distance_matrix_SL
##     score_trees_devel.Update_distance_matrix_AL
##     score_trees_devel.Update_distance_matrix_AL_GEOM    <-- currently used
## (presumably single linkage / average linkage / geometric average linkage;
##  confirm against score_trees_devel)
P = 25  ## handed to Make_tree as its last argument
update_rule = score_trees_devel.Update_distance_matrix_AL_GEOM
score_tree = score_trees_devel.Make_tree(distance, N, update_rule, scores, P)


## call plot_tree
## output file name is derived from the chisq file name
ps_file = chisq_file+'.cluster.ps'

## one label per index 0..N-1, in index order; every entry gets size 1
name_list = [ names[i] for i in range(N) ]
size_list = [1]*N

score_trees_devel.Plot_tree( score_tree, name_list, size_list, ps_file)

#print score_tree

## get clusters:

def Get_clusters(tree,names,threshold):
    if score_trees_devel.IsALeaf(tree):
        return []
    if tree[2] < threshold:
        members = score_trees_devel.Node_members( tree )
        print 'cluster:', threshold, tree[2], \
              string.join(map(lambda x,n=names: n[x],members))

        return [ members ]

    else:
        return Get_clusters(tree[0],names,threshold) + \
               Get_clusters(tree[1],names,threshold)
    
## cut the tree at THRESHOLD; Get_clusters prints every cluster it returns
clusters = Get_clusters(score_tree,names,THRESHOLD)

## NOTE: the script stops here unconditionally, so none of the module-level
## code after this call (codes-file parsing, per-cluster trees) is ever run.
exit()

## parse the codes file:
##
##   CODE: <name> <int> <int> ...
##       --> barcode[name] = list of ints
##   CODE_TORSIONS: <name> <f,f,..> <f,f,..> ...
##       --> barcode_torsions[name] = list of lists of floats
##   NAT_TORSIONS: <field 1, ignored> <f,f,..> <f,f,..> ...
##       --> nat_torsions (only the first such line is used)
##
## The file is read once with open() instead of three shell "grep" pipes, so
## paths containing spaces or shell metacharacters work, the handle is always
## closed, and a missing file raises IOError instead of looking like an empty
## file.

def _parse_torsions(fields):
    ## turn each comma-separated field into a list of floats
    return [ [ float(v) for v in field.split(',') ] for field in fields ]

barcode = {}
barcode_torsions = {}
native_fields = None    ## split fields of the first NAT_TORSIONS: line, if any

codes_in = open(codes_file)
try:
    for raw in codes_in:
        if raw.startswith('CODE:'):
            fields = raw.split()
            barcode[fields[1]] = [ int(tok) for tok in fields[2:] ]
        elif raw.startswith('CODE_TORSIONS:'):
            fields = raw.split()
            barcode_torsions[fields[1]] = _parse_torsions(fields[2:])
        elif raw.startswith('NAT_TORSIONS:') and native_fields is None:
            native_fields = raw.split()
finally:
    codes_in.close()

## NATIVE flags whether native torsions were found; nat_torsions is only
## defined when NATIVE is 1
if native_fields is not None:
    NATIVE = 1
    nat_torsions = _parse_torsions(native_fields[2:])
else:
    NATIVE = 0
    
def In_range(a): ## forces a to be in [-180,180)
    """Shift `a` by multiples of 360.0 until it lies in [-180,180).

    Values of 180.0 and above are stepped down by 360.0; values below
    -180.0 are then stepped up by 360.0.  The shifted value is returned.
    """
    angle = a
    while angle >= 180.0:
        angle -= 360.0
    while angle < -180.0:
        angle += 360.0
    return angle
        
def Angle_delta(a,b):
    """Return the unsigned difference between angles `a` and `b`.

    Both angles are first brought into [-180,180) with In_range.  The result
    is the smaller of the direct difference and the difference going the
    other way round the 360.0 circle, so it never exceeds 180.0.
    """
    ang_a = In_range(a)
    ang_b = In_range(b)
    assert -180 <= ang_a <= 180
    assert -180 <= ang_b <= 180

    hi = max(ang_a,ang_b)
    lo = min(ang_a,ang_b)

    direct = hi-lo
    wrapped = lo+360.0 - hi

    ## pick the shorter way round
    delta = direct
    if wrapped < direct:
        delta = wrapped

    assert delta <= 180.0
    return delta

def Dist(c1,c2):
    """Return the number of positions at which `c1` and `c2` differ.

    Positions 0 .. len(c1)-1 are compared, so `c2` must be at least as long
    as `c1` (IndexError otherwise); extra entries of `c2` are ignored.

    Two empty sequences have distance 0.  (Folding the comparisons with
    reduce() and no initial value raised TypeError for empty input.)
    """
    mismatches = 0
    for pos in range( len(c1) ):
        if c1[pos] != c2[pos]:
            mismatches = mismatches + 1
    return mismatches

def Torsion_distance(c1,c2): ## c1,c2 are lists of (phi,psi) pairs
    """Return 0.1 * sqrt( sum_i( dphi_i**2 + dpsi_i**2 ) / L ).

    dphi_i and dpsi_i are the Angle_delta differences between the phi
    (element 0) and psi (element 1) values of c1[i] and c2[i]; L is the
    common length of `c1` and `c2` (asserted to be equal).
    """
    L = len(c1)
    assert len(c2) == L

    ## squared angular deviation at each position
    sq_devs = []
    for pos in range(L):
        d_phi = Angle_delta(c1[pos][0],c2[pos][0])
        d_psi = Angle_delta(c1[pos][1],c2[pos][1])
        sq_devs.append( d_phi**2 + d_psi**2 )

    mean_sq = reduce(add, sq_devs) / float(L)
    return 0.1 * sqrt( mean_sq )


decoys = barcode.keys()
N = len(decoys)
decoy_name = {}
for i in range(N):
    decoy_name[i] = decoys[i]

for c in clusters:
    cluster_name = string.join(map(lambda x:names[x], c),',')
    print c,string.join(map(lambda x:names[x], c))

    subcode = {}
    subcode_torsions = {}
    for name in barcode.keys():
        subcode[name] = map(lambda x:barcode[name][x], c)
        subcode_torsions[name] = map(lambda x:barcode_torsions[name][x],c)

    nat_tor = map(lambda x:nat_torsions[x],c)
    
    ## make distance matrix
    distance = {}
    scores = {}
    for i in range(N):
        print 'calc distances:',i
        if NATIVE:
            scores[i] = [ Torsion_distance( subcode_torsions[decoy_name[i]], nat_tor) ]
        else:
            scores[i] = [1.0]

        distance[(i,i)] = 0.0
        for j in range(i+1,N):
            #d1 = Dist( subcode[decoy_name[i]],
            #           subcode[decoy_name[j]] )
            d2 = Torsion_distance( subcode_torsions[decoy_name[i]],
                                   subcode_torsions[decoy_name[j]] )
            distance[(i,j)] = d2
            distance[(j,i)] = d2

    #break
    ## call make_tree
    P = 25
    cluster_tree = score_trees_devel.Make_tree(distance, N,
                                               score_trees_devel.Update_distance_matrix_AL,
                                               scores,P)


    ## call plot_tree
    ps_file = chisq_file+'.cluster%s.ps'%cluster_name

    name_list = []
    size_list = []
    for i in range(N):
        name_list.append( decoy_name[i] )
        size_list.append( 1 )
        
    score_trees_devel.Plot_tree( cluster_tree, name_list, size_list, ps_file)




