#!/usr/bin/python

## trees from distances

from phil import *
import score_trees_devel
import time
import math
from barcode_util import *

def Help():
    """Print the command-line usage summary, then terminate via exit()."""
    usage = ('\nUsage: %s  < cluster threshold> -l <silent file or list of pdbs> '
             '-n native_file -graphics  '
             '<-freq_cutoff fraction of population in largest cluster above which feature is supressed> '
             '<-topN #decoys to use for clustering> \n\n')
    print(usage % argv[0])
    exit()

print string.join(map(lambda x: argv[x],range(len(argv))))

if len(argv) < 2:
    Help()
##

args = argv[1:]
NDECOY=-1
if args.count('-topN'):
    pos = args.index('-topN')
    NDECOY = int(args[pos+1])
    del args[pos]
    del args[pos]

freq_cutoff=0.90   # threshold for eliminating features dominated by a single cluster
if args.count('-freq_cutoff'):
    pos = args.index('-freq_cutoff')
    freq_cutoff = float(args[pos+1])
    del args[pos]
    del args[pos]    

if args.count('-l'):
    pos = args.index('-l')
    list_file = args[pos+1]
    del args[pos]
    del args[pos]

NATIVE=0
if args.count('-n'):
    pos = args.index('-n')
    native_file = args[pos+1]
    del args[pos]
    del args[pos]
    NATIVE=1
    
graphics = 0
if args.count('-graphics'):
    pos = args.index('-graphics')
    fast = 1
    del args[pos]

start_time=time.time()
## read chisq file --> distances
ANGLE_THRESH_FINAL = float(args[0])
ANGLE_THRESH_FINE = 0.3 * ANGLE_THRESH_FINAL # threshold for initial fast clustering step
#ANGLE_THRESH_FINAL= 60.0 # threshold for subsequent hiearchical clustering
print 'CLUSTERING THRESHOLD: ', ANGLE_THRESH_FINAL

## read silent/list file
info,seq,SILENT_INPUT = Read_either_file( list_file )
barcode = {}
if NATIVE:
    info[NATIVE] = Read_file( native_file )

barcode_torsions = {}
## fill in torsion array for BB,CHI
for name in info.keys():
    bc_torsions = []
    name_info = info[name]
#    print name_info
    for i in range(len(seq)):
                bc_torsions.append( name_info[i+1]['BB'])
    if name == NATIVE:
        nat_torsions = bc_torsions
    else:
        barcode_torsions[name] = bc_torsions

if NATIVE: ## dont get confused later!
    del info[NATIVE]

      
def In_range(a):
    """Return angle `a` (degrees) wrapped into the half-open range [-180, 180).

    Uses math.fmod, which is exact, instead of repeatedly adding/subtracting
    360: the old loop never terminated for infinite input or for magnitudes
    so large that a - 360 == a. Infinite input now raises ValueError; NaN is
    returned unchanged.
    """
    a = math.fmod(a, 360.0)   # exact; result in (-360, 360) with the sign of a
    if a >= 180.0:
        a -= 360.0            # exact (Sterbenz): a is within a factor 2 of 360
    elif a < -180.0:
        a += 360.0
    return a
        
def Angle_delta(a, b):
    """Return the unsigned angular difference between a and b, in degrees.

    Both angles are wrapped with In_range first. The result is the shorter
    way round the circle, so it never exceeds 180.
    """
    pair = (In_range(a), In_range(b))
    for angle in pair:
        assert -180 <= angle <= 180
    hi = max(pair)
    lo = min(pair)
    short_way = hi - lo              # difference without crossing +/-180
    long_way = lo + 360.0 - hi       # difference going round through +/-180
    delta = min(short_way, long_way)
    assert delta <= 180.0
    return delta

def Dist(c1, c2):
    """Return the Hamming distance: positions 0..len(c1)-1 where c1 and c2 differ.

    Entries of c2 beyond len(c1) are ignored. Raises IndexError if c2 is
    shorter than c1, as the original indexing did. Empty input gives 0 (the
    old reduce() raised TypeError), and the result is always an int.
    """
    if len(c2) < len(c1):
        raise IndexError('Dist: c2 is shorter than c1')
    return sum(1 for a, b in zip(c1, c2) if a != b)

def Torsion_distance(c1,c2): ## c1,c2 are lists of (phi,psi) pairs
    """Return a scaled RMS (phi, psi) deviation between two torsion lists.

    c1, c2 -- equal-length sequences; element k holds (phi, psi, ...) in
              degrees for position k. Only entries 0 and 1 are compared.
    Returns 0.1 * sqrt(mean over k of dphi_k**2 + dpsi_k**2), where each
    delta is the wrapped difference from Angle_delta. The 0.1 factor appears
    to express the result in units of 10 degrees: callers compare it against
    thresholds divided by 10. Two empty lists are at distance 0.0 (this
    previously raised TypeError).
    """
    L = len(c1)
    assert len(c2) == L
    if L == 0:
        return 0.0
    total = sum(Angle_delta(t1[0], t2[0]) ** 2 +   # phi
                Angle_delta(t1[1], t2[1]) ** 2     # psi
                for t1, t2 in zip(c1, c2))
    return 0.1 * sqrt(total / float(L))
def compute_average_angle(l1) :
    """Return the circular (vector) mean of a set of torsion lists.

    l1 -- one entry per decoy; l1[d][i][j] is angle j (degrees) of residue i
          in decoy d. Every decoy must have the same layout as l1[0].
    Returns avg, where avg[i][j] is the mean direction of l1[d][i][j] over
    all decoys d. It is computed as atan2(mean sine, mean cosine), in
    degrees (about [-180, 180]), so that e.g. 170 and -170 average to 180,
    not 0. An empty l1 yields [].
    """
    import math
    deg_to_rad = 2. * math.pi / 360.
    n_decoys = len(l1)
    if n_decoys == 0:
        return []
    avg_angle = []
    for i in range(len(l1[0])):            # residues
        residue_avg = []
        for j in range(len(l1[0][i])):     # angles within residue i
            # a real list: it is traversed twice (sin and cos), which a
            # one-shot map iterator would not allow under Python 3
            radians = [deg_to_rad * decoy[i][j] for decoy in l1]
            sin_angle = sum([math.sin(r) for r in radians]) / n_decoys
            cos_angle = sum([math.cos(r) for r in radians]) / n_decoys
            residue_avg.append((1. / deg_to_rad) * math.atan2(sin_angle, cos_angle))
        avg_angle.append(residue_avg)
    return avg_angle

decoys = info.keys()
N = len(decoys)
decoy_name = {}
for i in range(N):
    decoy_name[i] = decoys[i]

name_list = []
size_list = []
for i in range(N):
    name_list.append( decoy_name[i] )
    size_list.append( 1 )
    
nat_flavor_list = {}
nat_flavor_score = {}
flavor_avg_list={}
flavor_list={}
decoy_flavor={}
vuse_subset=0
num_decoy=N


if NDECOY > 0:
    use_subset=1
    num_decoy=NDECOY
    
for name in info.keys():
    decoy_flavor[name]={}

print seq
for ires in range(len(seq) - 1):
    if ires == 0: continue
    c=[]
    c.append(ires)
    cluster_name = 'BB:%s'%(ires+1)
    subcode = {}
    subcode_torsions = {}
  
    for name in info.keys():
#        subcode[name] = map(lambda x:barcode[name][x], c)
        subcode_torsions[name] = map(lambda x:barcode_torsions[name][x],c)

    nat_tor = map(lambda x:nat_torsions[x],c)
    residue_torsions=map(lambda x, n=subcode_torsions: n[x],name_list)
    avg_torsions=compute_average_angle(residue_torsions)
    decoy_dev=0.
    nat_dev=0.
    for i in range(num_decoy):
        decoy_dev=decoy_dev+Torsion_distance(subcode_torsions[name_list[i]],avg_torsions)
        nat_dev=nat_dev+Torsion_distance(subcode_torsions[name_list[i]],nat_tor)
    if decoy_dev/num_decoy  < ANGLE_THRESH_FINAL/10. :
        continue
    print 'NAT DEV: %0.2f  Decoy_dev:  %0.2f  %s' %(nat_dev/num_decoy,decoy_dev/num_decoy,cluster_name)

    ## make distance matrix
    distance = {}
    scores = {}
    print 'TIME before distance calc',time.time()-start_time
    for i in range(num_decoy):
        if NATIVE:   ##  should not ever be the case
	    d2 = Torsion_distance( subcode_torsions[decoy_name[i]], nat_tor) 
            scores[i] = [ d2 ]
        else:
            scoreqs[i] = [1.0]

        distance[(i,i)] = 0.0

#  for larger datasets, don't have to store distance matrix with
# local maximum clustering method as is done currently
        for j in range(i+1,num_decoy):
            d2 = Torsion_distance( subcode_torsions[decoy_name[i]],
                                   subcode_torsions[decoy_name[j]] )
            distance[(i,j)] = d2
            distance[(j,i)] = d2
    print 'TIME after distance  calc',time.time()-start_time

    
    P=25
    flist=[]
    new_flav_clust=score_trees_devel.cluster_max(distance,num_decoy,ANGLE_THRESH_FINE/10.)
    for new_clust in new_flav_clust:
      size = len(new_clust)
      if size  > 0 :
	flist.append(map(lambda x, n=decoy_name: n[x],new_clust))	
	first=new_clust[0]
        av_dev_from_first=reduce(add,map(lambda x, d=distance : d[(first,x)],new_clust))/size
        max_dev_from_first=reduce(max,map(lambda x: distance[(first,x)],new_clust))
	av_dist=0
	counter=0
	for i in new_clust:
	    for j in new_clust:
		av_dist=av_dist+distance[(i,j)]
		counter=counter+1
	av_dist=av_dist/counter
        print 'INIT CLUS SIZE: %s Av dev: %.3f Max dev:  %.3f Ave dist:  %.3f'%(size,av_dev_from_first,max_dev_from_first,av_dist) 
    print 'TIME after INIT_cluster',time.time()-start_time
###### condense clusters with quick hiearchical clustering. optional  #######
    condensed_flist=[]
    condensed_flav_clust=[]
    flavor_value=[]
    flavor_value_list=[]
    for flavor in flist:
            avg_torsions=map(lambda x, n=subcode_torsions: n[x],flavor)
	    flavor_value=compute_average_angle(avg_torsions)
	    flavor_value_list.append(flavor_value)
    flav_dist={}
    for flav1 in flavor_value_list:
	dist_string=string.join(map(lambda x : '%.1f  '%x,map(lambda x: Torsion_distance(x,flav1),flavor_value_list)))
	for flav2 in flavor_value_list:
	    dist=Torsion_distance(flav1,flav2)
	    if len(flav2) < 3 or len(flav1) < 3: dist = dist - 1.0  # disfavor singletons
	    flav_dist[(flavor_value_list.index(flav1),flavor_value_list.index(flav2))]=dist		
    score_tree = score_trees_devel.Make_tree_threshold(flav_dist, len(flavor_value_list),score_trees_devel.Update_distance_matrix_AL,scores,P,ANGLE_THRESH_FINAL/10.)

##       if graphics :
## ############  do complete hiearchical clustering for Phil's cluster graphics##################### 
##         cluster_tree = score_trees_devel.Make_tree(flav_dist, len(flavor_value_list),score_trees_devel.Update_distance_matrix_AL,scores,P)
##         ps_file = chisq_file+'.cluster%s.ps'%cluster_name
##         score_trees_devel.Plot_tree( cluster_tree, name_list, size_list, ps_file)
##         flist=[]                    
##         clusters = Get_flavor_clusters(cluster_tree,decoy_name,10.,flist)
## def Get_flavor_clusters(tree,names,threshold,list):
##     if score_trees_devel.IsALeaf(tree):
##         members = score_trees_devel.Node_members( tree )
##         list.append(map(lambda x, n=names: n[x],members))
##         return [ members ]
##     if tree[2] < threshold:
##         members = score_trees_devel.Node_members( tree )
##         list.append(map(lambda x, n=names: n[x],members))
## #        print 'cluster:', threshold, tree[2], \
## #              string.join(map(lambda x,n=names: n[x],members))
##         return [ members ]
    if len(score_tree) < 2 :
        print 'ONLY ONE FLAVOR!  OMIT FEATURE: ',cluster_name
        continue
    
    for node in score_tree:
	members = score_trees_devel.Node_members(node)
        temp=reduce(add,map(lambda x: flist[x],members))
        if len(temp) > 3 :
           condensed_flist.append(reduce(add,map(lambda x: flist[x],members)))
           condensed_flav_clust.append(reduce(add,map(lambda x: new_flav_clust[x],members)))
    print 'CONDENSE CLUSTERS'
    for new_clust in condensed_flav_clust:
      size = len(new_clust)
      if size  > 0 :	
	first=new_clust[0]
        av_dev_from_first=reduce(add,map(lambda x, d=distance : d[(first,x)],new_clust))/size
        max_dev_from_first=reduce(max,map(lambda x: distance[(first,x)],new_clust))
	av_dist=0
	counter=0
	for i in new_clust:
	    for j in new_clust:
		av_dist=av_dist+distance[(i,j)]
		counter=counter+1
	av_dist=av_dist/counter
        print 'FINAL CLUS SIZE: %s Av dev: %.3f Max dev:  %.3f Ave dist:  %.3f'%(size,av_dev_from_first,max_dev_from_first,av_dist) 

    size_list=[] 
    counter=0
    for clus in condensed_flist:
	size_list.append( (len(clus),counter) )
	counter=counter+1

    size_list.sort()
    size_list.reverse()

    if size_list[0][0] > freq_cutoff*float(num_decoy):
        print 'eliminate feature dominated by single flavor',size_list[0][0],num_decoy,freq_cutoff*float(num_decoy)
        continue    #omit features dominated by a single flavor

    order=[]
    for i in range(len(size_list)):
	order.append(int(size_list[i][1]))

    flist = map(lambda x : condensed_flist[x],order)

    print 'TIME after final cluster',time.time()-start_time
    flavor_list[cluster_name]=flist

    counter=0
    nat_flavor=1000
    min_nat_score=1000.0
    flavor_value_list=[]
    decoy_flavor_dist={}
    min_decoy_flavor_dist={}
    for name in info.keys():
        min_decoy_flavor_dist[name]=1000.
    for flavor in flist:
	for name in flavor:
		nat_dist=Torsion_distance(subcode_torsions[name] , nat_tor)
		torsion_str=[]
		for tor in subcode_torsions[name] :
		    torsion_str.append(string.join(map(lambda x : '%.1f '%x,tor)))
		print 'FEATURE: %s FLAVOR:  %s  DECOY: %s  NAT DIST:  %.2f TORSIONS: %s'%(cluster_name,counter,name,nat_dist,torsion_str)
 	avg_torsions=map(lambda x, n=subcode_torsions: n[x],flavor)
        flavor_value=compute_average_angle(avg_torsions)
 	flavor_value_list.append(flavor_value)

        for name in info.keys() :
            decoy_flavor_dist[name]=Torsion_distance(subcode_torsions[name], flavor_value)
            if decoy_flavor_dist[name] < min_decoy_flavor_dist[name] :
                min_decoy_flavor_dist[name]=decoy_flavor_dist[name]
                decoy_flavor[name][cluster_name]=counter
#  assign native to flavor        
	nat_score=Torsion_distance(flavor_value , nat_tor)
	torsion_str=[]
	for tor in flavor_value :
	    torsion_str.append(string.join(map(lambda x : '%.1f '%x,tor)))
        freq = float(len(flavor))/float(num_decoy)
	print 'FLAVOR VALUES: ',cluster_name,counter,torsion_str,'NAT DIST',nat_score,'SIZE: ',freq
        if nat_score < min_nat_score :
		min_nat_score=nat_score
		nat_flavor=counter
	counter=counter+1
    nat_torsion_str=[]
    for tor in nat_tor :
        nat_torsion_str.append(string.join(map(lambda x : '%.1f '%x,tor)))
    print 'NATIVE TORSIONS: ',cluster_name,nat_torsion_str
    nat_flavor_list[cluster_name]=nat_flavor 
    nat_flavor_score[cluster_name]=min_nat_score
    flav_dist={}
    for flav1 in flavor_value_list:
	dist_string=string.join(map(lambda x : '%.1f  '%x,map(lambda x: Torsion_distance(x,flav1),flavor_value_list)))
#	print dist_string
	
	for flav2 in flavor_value_list:
	    flav_dist[(flavor_value_list.index(flav1),flavor_value_list.index(flav2))]=Torsion_distance(flav1,flav2)
# optional--cluster the cluster centers to make sure well spaced #
##     score_tree = score_trees_devel.Make_tree_threshold(flav_dist, len(flavor_value_list),score_trees_devel.Update_distance_matrix_AL,scores,P,8.0)
##     for node in score_tree:
## 	members = score_trees_devel.Node_members(node)
## 	print members

    flavor_avg_list[cluster_name]=flavor_value_list
print 'TIME before print flavor_list',time.time()-start_time

key_list=flavor_list.keys()
key_list.sort()
print 'features',key_list
nat_code=string.join( map(str, map (lambda x, l=nat_flavor_list : l[x], key_list) ) )
print 'NATIVE FLAVOR CODE',nat_code, map(lambda x, l = nat_flavor_score : l[x], key_list )
flavor_code={}
code_count={}
score_code_count={}
rms_code_count={}
   
for name in decoys:
	flavor_code[name]=[]
	for feature in key_list :     # problem is that the order of features may not be the order in which they were subclustered
		for flavor in range(len(flavor_list[feature])):
#			print 'flavor',flavor,flavor_list[feature][flavor]
                        found=0
			if name in flavor_list[feature][flavor] :
				flavor_code[name].append(flavor)
                                found=1
#                                print flavor,decoy_flavor[name][feature]
                                break
                if not found : # decoy was in small cluster that was cut from flavor list
                    flavor_code[name].append(decoy_flavor[name][feature])
# a bit overkill-flavor_codes found two different ways
        code=string.join( map(str,flavor_code[name]))
	if code == nat_code:       
		print '!!!!! nat code !!!!'
	nat_dist=Torsion_distance(barcode_torsions[name],nat_torsions)

        if info[name].has_key('rms'):
            rms=info[name]['rms']
        else:
            rms=999
            
        if info[name].has_key('score'):
            score=info[name]['score']
        else:
            score=999

        print 'FLAVOR CODE: ',name,code,'   NAT_TORS_DIST:  ',nat_dist,' RMS: ',rms,' SCORE: ',score
        if not code_count.has_key( code ):
              code_count[code] = 0             
              rms_code_count[code]=0
              score_code_count[code]=0
        code_count[code] = code_count[code] + 1
        rms_code_count[code]=rms_code_count[code]+rms
        score_code_count[code]=score_code_count[code]+score

for code in code_count.keys():
	print 'FLAVOR CODE COUNT: ',code_count[code],' CODE ',code,' score: ',score_code_count[code]/code_count[code],' rms: ',rms_code_count[code]/code_count[code]

print 'TIME:  end',time.time()-start_time  












