#!/usr/bin/python

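## Fisher linear discriminant between two decoy populations, using the
## distances of residue pairs that are in contact in the native structure.
# TODO: for angles, use compute_average_angle and torsion_dist.
# TODO: include chi angles and CAs.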

from random import shuffle
from phil import *
import time
from barcode_util_CA import *
from Numeric import *
from LinearAlgebra import *

# Usage: <script> -n <native_file> <decoy_list1> <decoy_list2>
# Output lines use single-argument print(...) so they print identically
# under Python 2 (the rest of this file) and also parse under Python 3.
print(argv)
args = argv[1:]
usage = 'usage: %s -n <native_file> <decoy_list1> <decoy_list2>' % argv[0]

# "-n <native_file>" is removed from args; the native structure defines
# which residue pairs are treated as contacts below.
native_file = ''
if args.count('-n'):
    pos = args.index('-n')
    if pos + 1 >= len(args):
        raise SystemExit(usage)
    native_file = args[pos + 1]
    del args[pos:pos + 2]

# The native file is mandatory: nat_coord is read from it and used
# unconditionally for the contact search; without it nat_coord is undefined.
if not native_file or len(args) < 2:
    raise SystemExit(usage)

list1 = args[0]
list2 = args[1]

info, seq, SILENT_INPUT, coords = Read_either_file(list1)

# NATIVE records whether the native sequence matches the first decoy set's
# sequence; it is not read anywhere else in the visible script.
NATIVE = 0
nat_info, nat_coord = Read_file(native_file)
s = Get_seq(nat_info)
if nat_info and s == seq:
    NATIVE = 1

# A residue pair is a "contact" when its numbering differs by more than
# MIN_SEQ_SEP and its native distance is below CONTACT_CUTOFF (same units as
# the coordinates returned by Read_file -- presumably Angstroms, confirm).
MIN_SEQ_SEP = 4
CONTACT_CUTOFF = 9.0


def _pair_distance(p, q):
    """Return the Euclidean distance between two coordinate vectors."""
    d = array(p) - array(q)
    return sqrt(dot(d, d))


nat_contacts = []
for res1 in nat_coord.keys():
    for res2 in nat_coord.keys():
        # Requiring res1 > res2 + MIN_SEQ_SEP visits each unordered pair once
        # and skips pairs that are close in sequence.
        if int(res1) - int(res2) > MIN_SEQ_SEP:
            diff = _pair_distance(nat_coord[res1], nat_coord[res2])
            if diff < CONTACT_CUTOFF:
                print('contact %s %s %s' % (res1, res2, diff))
                nat_contacts.append((res1, res2, diff))


def _contact_distance_vectors(decoy_info, decoy_coords, contacts):
    """Return one feature vector per decoy in decoy_info.

    Each vector holds that decoy's distance for every (res1, res2, native
    distance) entry of contacts, in the same order as contacts.
    """
    vectors = []
    for decoy_name in decoy_info.keys():
        coord = decoy_coords[decoy_name]
        vectors.append([_pair_distance(coord[r1], coord[r2])
                        for r1, r2, _nat in contacts])
    return vectors


decoy_contact_vect1 = _contact_distance_vectors(info, coords, nat_contacts)

info, seq, SILENT_INPUT, coords = Read_either_file(list2)
decoy_contact_vect2 = _contact_distance_vectors(info, coords, nat_contacts)
    
def fisher(set1, set2):
    """Fit a Fisher linear discriminant separating two populations.

    Every feature is first rescaled by its pooled standard deviation about
    the grand mean of both sets. The within-class scatter matrix S1 is then
    built in that rescaled space, and the direction
    w = inverse(S1) * (mean1 - mean2) / var is solved for.

    Args:
        set1, set2: sequences of equal-length numeric feature vectors, one
            row per sample.

    Returns:
        (mean1, mean2, w, S1, var):
        - mean1, mean2: per-feature means of each set, in the original units.
        - w: the discriminant weights.
        - S1: the within-class scatter matrix in the rescaled space.
        - var: the per-feature pooled standard deviation used for rescaling
          (a standard deviation, despite its name).
        A sample x is projected as dot(x / var, w).

    Raises:
        ValueError: if either set is empty, the vector lengths differ, or a
            feature has zero spread (it cannot be rescaled).
        Whatever inverse() raises if S1 is singular.
    """
    num1 = len(set1)
    num2 = len(set2)
    if num1 == 0 or num2 == 0:
        raise ValueError('fisher: both sets must contain at least one vector')
    dim1 = len(set1[0])
    dim2 = len(set2[0])
    print('dim:  %s %s %s %s' % (num1, num2, dim1, dim2))
    if dim1 != dim2:
        raise ValueError('fisher: vector lengths differ (%d vs %d)' % (dim1, dim2))

    # Typecode 'd' (double) so integer input is not truncated by the
    # divisions below: under Python 2, '/' on integer arrays floors.
    array1 = array(set1, 'd')
    array2 = array(set2, 'd')

    sum1 = add.reduce(array1)
    sum2 = add.reduce(array2)
    mean = (sum1 + sum2) / (num1 + num2)   # grand mean over both sets
    mean1 = sum1 / num1
    mean2 = sum2 / num2
    for i in range(dim1):
        print('means %s %s %s' % (i, mean1[i], mean2[i]))

    # Rescale to unit pooled spread: var is the per-feature standard
    # deviation of all samples (both sets) about the grand mean.
    dev = concatenate((array1, array2)) - mean
    var = sqrt(add.reduce(dev * dev) / (num1 + num2))
    flat = [i for i in range(dim1) if var[i] == 0.0]
    if flat:
        raise ValueError('fisher: features with zero spread cannot be rescaled: %s' % flat)

    # Within-class scatter in the rescaled space: the sum over samples of the
    # outer product of (x/var - class_mean/var) with itself, i.e. D^T D.
    dev1 = array1 / var - mean1 / var
    dev2 = array2 / var - mean2 / var
    S1 = dot(transpose(dev1), dev1) + dot(transpose(dev2), dev2)

    # Fisher direction in the rescaled space. S1 has rank at most
    # num1 + num2 - 2, so inverse() fails when there are too few samples.
    S1_inverse = inverse(S1)
    w = dot(S1_inverse, (mean1 - mean2) / var)
    print(w)

    return mean1, mean2, w, S1, var

# Hold out ~10% of each (shuffled) population as a test set and train the
# discriminant on the rest. Splitting at len - ntest (rather than slicing
# with -ntest) keeps everything in training when ntest is 0: x[:-0] would
# be empty and x[-0:] the whole list.
shuffle(decoy_contact_vect1)
shuffle(decoy_contact_vect2)
ntest1 = len(decoy_contact_vect1) // 10
ntest2 = len(decoy_contact_vect2) // 10
split1 = len(decoy_contact_vect1) - ntest1
split2 = len(decoy_contact_vect2) - ntest2
decoy1_train = decoy_contact_vect1[:split1]
decoy1_test = decoy_contact_vect1[split1:]
decoy2_train = decoy_contact_vect2[:split2]
decoy2_test = decoy_contact_vect2[split2:]
print('SET SIZE: train1 %d test1 %d train2 %d test2 %d' % (
    len(decoy1_train), len(decoy1_test), len(decoy2_train), len(decoy2_test)))
mean1, mean2, w, S1, var = fisher(decoy1_train, decoy2_train)

# One SUMMARY line per native contact. Residue ids are printed with %s
# because they are used as dictionary keys as-is and only int()-converted
# for comparison, so they may be strings (%d would fail on those).
mean_diff = mean2 - mean1
print('SUMMARY res1 res2 nat_dist pop1_dist pop2_dist pop2-pop1 fisher ')
for i, (res1, res2, nat_dist) in enumerate(nat_contacts):
    print('SUMMARY %s %s %2.4f %2.4f %2.4f %2.4f %3.4f  ' % (
        res1, res2, nat_dist, mean1[i], mean2[i], mean_diff[i], w[i]))

# Project each held-out vector onto w in the same rescaled space that
# fisher() trained in (x / var).
for tag, pop, held_out in (('L1', 'POP1', decoy1_test), ('L2', 'POP2', decoy2_test)):
    for x in held_out:
        print('%s:  %s %s' % (tag, x[0], x[1]))
        print('%s projection:  %s' % (pop, dot(array(x) / var, w)))











