#!/usr/bin/python

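## Cluster features into a tree from chi-square p-value distances, then
## sub-cluster the decoys into per-feature "flavors" by torsion-angle distance.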

from phil import *
import score_trees_devel


def Help():
    """Print the command-line usage message, then stop the script via exit()."""
    usage = '\nUsage: %s <codes-file> <chisq-file> <p-val cluster threshold> -s <silent file> \n\n' % argv[0]
    print(usage)
    exit()
    
## command line: <codes-file> <chisq-file> <p-val threshold> [-s <silent file>]
## The optional "-s <silent file>" pair may appear anywhere; it is removed first.
args = argv[1:]
silent_file = ''
if '-s' in args:
    pos = args.index('-s')
    if pos + 1 >= len(args):
        # "-s" given with no file name after it
        Help()
    silent_file = args[pos+1]
    del args[pos:pos+2]
# Count positionals only after "-s <file>" is stripped, so that e.g.
# "prog codes -s silent" is reported as a usage error instead of an IndexError.
if len(args) < 3:
    Help()

## read chisq file --> distances
## Each "CHISQ " line names a pair of features: fields 1,2 are their integer
## indices, fields 3,4 their names, and field 9 the p-value used as distance.
codes_file = args[0]
chisq_file = args[1]
THRESHOLD = float(args[2])

# Read the file directly. The previous shell "grep" pipe broke on file names
# that contain spaces or shell metacharacters.
lines = [l.split() for l in open(chisq_file).readlines() if l.startswith('CHISQ ')]
if not lines:
    raise ValueError('no CHISQ lines found in %s' % chisq_file)

distance = {}   # (i, j) -> p-value between features i and j (0.0 on the diagonal)
scores = {}     # feature index -> [1.0] (uniform per-leaf score for Make_tree)
names = {}      # feature index -> feature name
N = 0           # one past the largest feature index seen
for line in lines:
    i1 = int(line[1])
    i2 = int(line[2])
    scores[i1] = [1.0]
    scores[i2] = [1.0]
    names[i1] = line[3]
    names[i2] = line[4]
    N = max(N, i1 + 1, i2 + 1)
    pval = float(line[9])
    distance[(i1,i2)] = pval
    distance[(i2,i1)] = pval
    distance[(i1,i1)] = 0.0
    distance[(i2,i2)] = 0.0

# Features must be numbered 0..N-1 without gaps, because later code looks up
# names[i] for every i in range(N). Use an explicit check, not an assert,
# because asserts are stripped under "python -O".
if len(scores) != N:
    raise ValueError('%s: feature indices are not contiguous 0..%d' % (chisq_file, N - 1))

## Build the feature tree from the p-value distances with the
## Update_distance_matrix_AL update rule. score_trees_devel also provides
## Update_distance_matrix_SL and Update_distance_matrix_AL_GEOM as alternative
## rules. P = 25 is passed straight through to Make_tree; its meaning is
## defined there.
P = 25
score_tree = score_trees_devel.Make_tree(distance, N,
                                         score_trees_devel.Update_distance_matrix_AL,
                                         scores, P)

## Plot the feature tree to <chisq-file>.cluster.ps, drawing every leaf with size 1.
ps_file = chisq_file + '.cluster.ps'
name_list = [names[i] for i in range(N)]
size_list = [1] * N
score_trees_devel.Plot_tree(score_tree, name_list, size_list, ps_file)

#print score_tree

## get clusters:

def Get_clusters(tree,names,threshold):
    if score_trees_devel.IsALeaf(tree):
        return []
    if tree[2] < threshold:
        members = score_trees_devel.Node_members( tree )
        print 'cluster:', threshold, tree[2], \
              string.join(map(lambda x,n=names: n[x],members))

        return [ members ]

    else:
        return Get_clusters(tree[0],names,threshold) + \
               Get_clusters(tree[1],names,threshold)
    
## features = clusters of the feature tree that merge below THRESHOLD
clusters = Get_clusters(score_tree,names,THRESHOLD)

## Parse the codes file in a single pass. The previous version used three shell
## "grep" pipes, which broke on file names with spaces or shell metacharacters.
##   CODE: <decoy> <int> <int> ...                   -> barcode[decoy]
##   CODE_TORSIONS: <decoy> <phi,psi> <phi,psi> ...  -> barcode_torsions[decoy]
##   NAT_TORSIONS: <name> <phi,psi> ...              -> nat_torsions (first such line only)
barcode = {}            # decoy name -> list of ints
barcode_torsions = {}   # decoy name -> tuple of [phi, psi] float lists
nat_torsions = None     # stays None when the file has no NAT_TORSIONS: line
NATIVE = 0
for line in open(codes_file).readlines():
    fields = line.split()
    if line.startswith('CODE:'):
        barcode[fields[1]] = [int(v) for v in fields[2:]]
    elif line.startswith('CODE_TORSIONS:'):
        barcode_torsions[fields[1]] = tuple([[float(v) for v in x.split(',')]
                                             for x in fields[2:]])
    elif line.startswith('NAT_TORSIONS:') and not NATIVE:
        NATIVE = 1
        nat_torsions = tuple([[float(v) for v in x.split(',')]
                              for x in fields[2:]])
    
def In_range(a): ## forces a to be in [-180,180)
    """Wrap an angle in degrees into the half-open interval [-180, 180).

    Values already in range are returned unchanged. Other values are wrapped
    with a single modulo operation instead of repeated +/-360 steps, so the
    cost no longer grows with |a|.

    Raises ValueError for an infinite or NaN angle. The previous step loop
    never terminated on +/-inf.
    """
    if -180.0 <= a < 180.0:
        return a
    import math  # only needed on the (rare) out-of-range path
    if math.isinf(a) or math.isnan(a):
        raise ValueError('angle must be finite, got %r' % (a,))
    a = (a + 180.0) % 360.0 - 180.0
    # Float rounding in % can return exactly 360.0 for a value a hair below
    # -180, which would map to +180; fold it back into the interval.
    if a >= 180.0:
        a = a - 360.0
    return a
        
def Angle_delta(a,b):
    """Return the unsigned angular separation of a and b in degrees.

    Both angles are first wrapped with In_range. The result is the shorter of
    the two arcs between them: either directly, or across the +/-180 seam.
    It is asserted to be at most 180.
    """
    wa = In_range(a)
    wb = In_range(b)
    assert -180 <= wa <= 180
    assert -180 <= wb <= 180
    hi, lo = max(wa, wb), min(wa, wb)
    direct = hi - lo             # arc that does not cross +/-180
    wrapped = lo + 360.0 - hi    # arc that crosses +/-180
    delta = min(direct, wrapped)
    assert delta <= 180.0
    return delta

def Dist(c1,c2):
    """Return the Hamming distance between codes c1 and c2 as an int.

    The function counts the positions i in range(len(c1)) where c1[i] != c2[i].
    It raises IndexError if c2 is shorter than c1. Empty codes give 0. The old
    reduce() without an initializer raised TypeError on empty codes, and it
    returned a bool for length-1 codes.
    """
    return sum([1 for i in range(len(c1)) if c1[i] != c2[i]])

def Torsion_distance(c1,c2): ## c1,c2 are lists of (phi,psi) pairs
    """Return 0.1 * sqrt(mean over positions of (dphi**2 + dpsi**2)).

    c1 and c2 are equal-length sequences of (phi, psi) pairs in degrees. For
    each position, the phi and psi differences come from Angle_delta, so they
    respect the +/-180 wrap. The 0.1 factor is a fixed scale; it presumably
    puts distances on the scale the tree code and its thresholds expect
    (confirm).

    Raises ValueError if the sequences differ in length or are empty. These
    cases were previously an assert and an obscure TypeError from reduce().
    """
    L = len(c1)
    if len(c2) != L:
        raise ValueError('torsion lists differ in length: %d vs %d' % (L, len(c2)))
    if L == 0:
        raise ValueError('cannot compare empty torsion lists')
    # Per-position terms are summed left to right, exactly as reduce(add, ...) did.
    total = sum([(Angle_delta(c1[x][0], c2[x][0])**2 +   ## phi
                  Angle_delta(c1[x][1], c2[x][1])**2)    ## psi
                 for x in range(L)])
    return 0.1 * sqrt(total / float(L))

# The following list-averaging functions compute the mean torsion angles of each flavor.
# Their input is a list (over the decoys in one flavor) of lists of (phi, psi) torsion pairs.
# They will need to be modified for features where averaging does not make sense
# (rotamer numbers, for example).

def add_composite_lists(l1,l2) :
    """Return a new list of lists holding the element-wise sums of l1 and l2.

    l1 and l2 are lists of lists. The inner lists may differ in length from
    one another, but l1[i] and l2[i] must have the same length, and so must
    l1 and l2 themselves (both checked with assert). The inputs are not modified.
    """
    assert len(l1) == len(l2)
    summed = []
    for row1, row2 in zip(l1, l2):
        assert len(row1) == len(row2)
        summed.append([v1 + v2 for v1, v2 in zip(row1, row2)])
    return summed


def add_list_of_composite_lists(l1) :
    """Return the element-wise arithmetic mean of a list of composite lists.

    l1 is a non-empty list of composite lists (lists of lists). All of them
    have the shape that add_composite_lists requires. The result has that same
    shape and holds the mean of each entry as a float.

    A new structure is always returned and the inputs are never modified.
    Previously, for a one-element l1, reduce() returned l1[0] itself, and its
    entries were then divided in place.

    NOTE(review): this is a plain arithmetic mean. For angles it ignores the
    +/-180 wrap (179 and -179 average to 0, not 180); confirm that this is
    acceptable for torsion flavors.

    Raises ValueError if l1 is empty.
    """
    L = len(l1)
    if L == 0:
        raise ValueError('cannot average an empty list of composite lists')
    total = reduce(add_composite_lists, l1)
    return [[v / float(L) for v in row] for row in total]

## Number the decoys 0..N-1 in barcode.keys() order. The per-feature decoy
## trees below refer to decoys by these indices.
decoys = barcode.keys()
N = len(decoys)
decoy_name = dict(zip(range(N), decoys))

# list is a list, for each flavor, of all members
# for now include all singletons as flavors so that each
#	decoy has a complete bar code                                        
def Get_flavor_clusters(tree,names,threshold,list):
    if score_trees_devel.IsALeaf(tree):
        members = score_trees_devel.Node_members( tree )
        list.append(map(lambda x, n=names: n[x],members))
        return [ members ]
    if tree[2] < threshold:
        members = score_trees_devel.Node_members( tree )
        list.append(map(lambda x, n=names: n[x],members))
#        print 'cluster:', threshold, tree[2], \
#              string.join(map(lambda x,n=names: n[x],members))

        return [ members ]
    else:
        return Get_flavor_clusters(tree[0],names,threshold,list) + \
               Get_flavor_clusters(tree[1],names,threshold,list)
nat_flavor_list = {}
nat_flavor_score = {}
flavor_list={}

for c in clusters:
    cluster_name = string.join(map(lambda x:names[x], c),',')
#    print c,string.join(map(lambda x:names[x], c))
    subcode = {}
    subcode_torsions = {}
  
    for name in barcode.keys():
        subcode[name] = map(lambda x:barcode[name][x], c)
        subcode_torsions[name] = map(lambda x:barcode_torsions[name][x],c)

    nat_tor = map(lambda x:nat_torsions[x],c)
    
    ## make distance matrix
    distance = {}
    scores = {}
    for i in range(N):
#        print 'calc distances:',i
        if NATIVE:
	    d2 = Torsion_distance( subcode_torsions[decoy_name[i]], nat_tor) 
            scores[i] = [ d2 ]
        else:
            scores[i] = [1.0]

        distance[(i,i)] = 0.0
        for j in range(i+1,N):
            d2 = Torsion_distance( subcode_torsions[decoy_name[i]],
                                   subcode_torsions[decoy_name[j]] )
            distance[(i,j)] = d2
            distance[(j,i)] = d2

    #break
    ## call make_tree
    P = 25
    cluster_tree = score_trees_devel.Make_tree(distance, N,
                                               score_trees_devel.Update_distance_matrix_AL,
                                               scores,P)
#    print cluster_tree
    ## call plot_tree
    ps_file = chisq_file+'.cluster%s.ps'%cluster_name

    name_list = []
    size_list = []
    for i in range(N):
        name_list.append( decoy_name[i] )
        size_list.append( 1 )

    score_trees_devel.Plot_tree( cluster_tree, name_list, size_list, ps_file)

    flist=[]                    
    clusters = Get_flavor_clusters(cluster_tree,decoy_name,10.,flist)                    
    flavor_list[cluster_name]=flist
    
#    print flist
    counter=0
    nat_flavor=1000
    min_nat_score=1000.0
    for flavor in flist:
	for name in flavor:
		nat_dist=Torsion_distance(subcode_torsions[name] , nat_tor)
		print 'FEATURE: ',cluster_name,'FLAVOR: ', counter,'DECOY: ',name,' NAT DIST: ', nat_dist, 'TORSIONS: ',subcode_torsions[name]
	avg_torsions=map(lambda x, n=subcode_torsions: n[x],flavor)
#	print 'avg', avg_torsions
	flavor_value=add_list_of_composite_lists(avg_torsions)
	nat_score=Torsion_distance(flavor_value , nat_tor)
	print 'FLAVOR VALUES: ',counter,flavor_value,'NAT DIST',nat_score
	if nat_score < min_nat_score :
		min_nat_score=nat_score
		nat_flavor=counter
	counter=counter+1
    nat_flavor_list[cluster_name]=nat_flavor 
    nat_flavor_score[cluster_name]=min_nat_score

#print flavor_list

## Optional silent file (-s): map the decoys in the codes file to their rms values.
rms_list = {}        # decoy name -> rms, kept as the raw string from the file
silent_exist = 0
if silent_file != '' :
    silent_exist = 1
    # Read directly instead of through a shell "grep" pipe, which broke on
    # paths with spaces or shell metacharacters.
    lines = [l.split() for l in open(silent_file).readlines() if l.startswith('SCORE')]
    if not lines:
        raise ValueError('no SCORE lines found in silent file %s' % silent_file)
    # the first SCORE line is the header naming the columns
    rms_index = lines[0].index('rms')
    description_index = lines[0].index('description')
    for line in lines :
        if line[description_index] in barcode_torsions:
            rms_list[line[description_index]] = line[rms_index]

key_list=flavor_list.keys()
key_list.sort()
print 'features',key_list
nat_code=string.join( map(str, map (lambda x, l=nat_flavor_list : l[x], key_list) ) )
print 'NATIVE FLAVOR CODE',nat_code, map(lambda x, l = nat_flavor_score : l[x], key_list )
flavor_code={}
code_count={}
#print 'nat_torsion',nat_torsions[20:22]
#print 'nat_barcode',barcode_torsions['NATIVE'][20:22]   
for name in name_list:
	flavor_code[name]=[]
	for feature in key_list :     # problem is that the order of features may not be the order in which they were subclustered
		for flavor in range(len(flavor_list[feature])):
#			print 'flavor',flavor,flavor_list[feature][flavor]
			if name in flavor_list[feature][flavor] :
				flavor_code[name].append(flavor)
        code=string.join( map(str,flavor_code[name]))
	if code == nat_code:
		print '!!!!! nat code !!!!'
	nat_dist=Torsion_distance(barcode_torsions[name],nat_torsions)

	if silent_exist :
		if rms_list.has_key(name) :
			rms=rms_list[name]
		else:
			rms=999
        print 'FLAVOR CODE: ',name,code,'   NAT_TORS_DIST:  ',nat_dist,' RMS: ',rms
        if not code_count.has_key( code ):
              code_count[code] = 0             
        
        code_count[code] = code_count[code] + 1

for code in code_count.keys():
	print 'FLAVOR CODE COUNT: ',code_count[code],' CODE ',code
