#!/usr/bin/python

from phil import *
from amino_acids import extra_longer_names
import string

def Help():
    print '\nUsage: %s list_file MIN_NB MAX_CHI BB_BIN_WIDTH {-n <native file>}\n\n'\
          %(argv[0])
    exit()

## parameters

## command line: list_file MIN_NB MAX_CHI BB_BIN_WIDTH {-n <native file>} {-o <output file>}
## The optional "-flag value" pairs may appear anywhere. They are removed first so
## the positional arguments can be counted and read by position.
args = argv[1:]

def Pop_option(flag, default):
    ## remove the pair "flag value" from args and return value; default if flag is absent
    ## prints the usage and exits when flag is the last argument (value missing)
    if not args.count(flag):
        return default
    pos = args.index(flag)
    if pos+1 >= len(args):
        Help()
    value = args[pos+1]
    del args[pos:pos+2]
    return value

native_file = Pop_option('-n', '')
out_file = Pop_option('-o', 'barcode.cst')

## all four positional arguments are required
if len(args) < 4:
    Help()

list_file = args[0]


MIN_DEV = 0.0 ## smallest value for sd(phi) + sd(psi) to include as features
try:
    MIN_NB = int(args[1]) ## definition of exposed, used for sc rotamers
    MAX_CHI = int(args[2]) ## only use rot numbers up to and including this chi torsion
    BB_BIN_WIDTH = int(args[3]) ## in degrees
except ValueError:
    Help()

if BB_BIN_WIDTH <= 0: ## used as a divisor when binning backbone torsions
    Help()

## earlier parameter sets, kept for reference:

## MIN_DEV = 20.0 ## smallest value for sd(phi) + sd(psi) to include as features
## MIN_NB = 12 ## definition of exposed, used for sc rotamers
## MAX_CHI = 2 ## only use rot numbers up to and including this chi torsion
## BB_BIN_WIDTH = 20 ## in degrees

## MIN_DEV = 20.0 ## smallest value for sd(phi) + sd(psi) to include as features
## MIN_NB = 12 ## definition of exposed, used for sc rotamers
## MAX_CHI = 4 ## only use rot numbers up to and including this chi torsion
## BB_BIN_WIDTH = 5 ## in degrees


EXPOSED = tuple( [-1]*MAX_CHI ) ## internal name for SC bin if nb<MIN_NB


########################################### functions:
def In_range(a):
    ## wrap angle a (degrees) into the half-open interval (-180, 180]
    ## by adding or removing whole turns of 360
    angle = a
    while angle > 180.0:   ## too large: remove turns
        angle -= 360.0
    while angle <= -180.0: ## too small (-180 itself maps to +180): add turns
        angle += 360.0
    return angle
        
def Angle_delta(a,b):
    ## unsigned angular distance between angles a and b (degrees), in [0, 180]
    first = In_range(a)
    second = In_range(b)
    assert -180 <= first <= 180
    assert -180 <= second <= 180
    hi = max(first, second)
    lo = min(first, second)
    ## take the shorter way round: directly, or across the +-180 seam
    direct = hi - lo
    across_seam = lo + 360.0 - hi
    delta = min( direct, across_seam )
    assert delta <= 180.0
    return delta


def Get_bb( tor ):
    ## backbone (phi, psi) taken from one residue's 8-element torsion record
    ## [nb, phi, psi, omega, rot1, rot2, rot3, rot4]; omega is not used
    assert len(tor) == 8
    phi, psi = tor[1], tor[2]
    return (phi, psi)

def Get_bb_bin( bb, width=None ):
    ## bin the backbone pair bb = (phi, psi), in degrees, on a square grid
    ## width: bin width in degrees (default: the global BB_BIN_WIDTH)
    ## returns: (phi_bin, psi_bin) where bin k covers [k*width, (k+1)*width)
    ## +180 and -180 are the same torsion, so angles >= 180 are shifted down by one
    ## turn first; otherwise an angle of exactly 180 would open a bin of its own
    if width is None:
        width = BB_BIN_WIDTH

    def _bin(angle):
        ## one angle -> integer bin index
        if angle >= 180.0:
            angle = angle - 360.0
        return int(floor(angle / float(width)))

    return ( _bin(bb[0]), _bin(bb[1]) )

def Chi_bin(a):
    ## rotamer well (1, 2 or 3) of a chi angle a in degrees
    ## modelled on Rosetta's rot_from_chi for chi1:
    ##   [0, 120) -> 1,  [-120, 0) -> 3,  anything else (the well around +-180) -> 2
    chi = In_range(a)
    if -120 <= chi < 0:
        return 3
    if 0 <= chi < 120:
        return 1
    return 2

def Chi_range_from_bin(chi):
    ## (lower, upper) chi limits in degrees for rotamer well chi; the inverse of Chi_bin
    ## modelled on Rosetta's rot_from_chi for chi1. Well 2 wraps through 180, so its
    ## lower limit is larger than its upper limit.
    for well, limits in ( (1, (0., 120.)), (3, (-120., 0.)), (2, (120., -120.)) ):
        if chi == well:
            return limits
    return -120., 120.  # no chi angle for this residue (too short)

def Get_rot( tor ):
    ## side-chain rotamer key for one residue's 8-element torsion record
    ## [nb, phi, psi, omega, rot1, rot2, rot3, rot4]:
    ## the first MAX_CHI rotamer numbers if the neighbour count reaches MIN_NB,
    ## otherwise the shared EXPOSED key
    assert len(tor) == 8
    buried = tor[0] >= MIN_NB
    if not buried:
        return EXPOSED
    return tuple( tor[4:4+MAX_CHI] )

def _row_index(fields):
    ## returns: int(fields[0]) if fields starts with an integer token, else None
    ## (lets the section parsers stop cleanly on blank or header lines instead of raising)
    if not fields:
        return None
    try:
        return int(fields[0])
    except ValueError:
        return None

def Read_file(filename):
    ## read: the neighbor counts, sc torsion numbers, bb torsions
    ## returns: info,seq
    ##   info[pos] = [ nb, phi, psi, omega, rot1, rot2, rot3, rot4 ]   (pos = 1,2,...)
    ##   seq[pos]  = extra_longer_names[ residue name from the 'absolute' section ]
    ##   info is {} if the file is missing or the three sections are incomplete or
    ##   inconsistent with each other

    info = {}
    seq = {}

    try:
        data = open(filename,'r')
    except IOError:
        log('missing %s\n'%filename)
        return info,seq

    nb = {}
    bb = {}
    rot = {}

    try:
        L = 0 ## CA atoms seen so far = the most rows any section may have
        line = data.readline()
        while line:
            ## A section that ends early stops on a line that does not belong to it.
            ## Keep that line so it is examined again (it may be the next section header).
            pending = None

            if line[:4] == 'ATOM' and line[12:16] == ' CA ':
                L = L+1
            if line[:9] == 'res aa nb' or line[:9] == 'res  aa n':
                n = len(line.split())
                for i in range(L):
                    row = data.readline()
                    l = row.split()
                    if len(l) == n and _row_index(l) == i+1:
                        nb[i+1] = int(l[2])
                    else:
                        pending = row
                        break
            elif line[:8] == 'absolute':
                data.readline() ## skip the column-header line
                for i in range(L):
                    row = data.readline()
                    l = row.split()
                    if len(l) == 11 and l[-1] == 'chi_absolute' and _row_index(l) == i+1:
                        rot[i+1] = [ int(x) for x in l[6:10] ]
                        seq[i+1] = extra_longer_names[ l[1] ]
                    else:
                        pending = row
                        break
            elif line[:8] == 'complete':
                for i in range(L):
                    row = data.readline()
                    l = row.split()
                    if len(l) >= 5 and l[1] in ('E','H','L') and _row_index(l) == i+1:
                        bb[i+1] = [ float(x) for x in l[2:5] ] ## phi, psi, omega
                    else:
                        pending = row
                        break

            if pending is None:
                line = data.readline()
            else:
                line = pending
    finally:
        data.close()

    L = len(seq)

    if L and L == len(nb) == len(bb) == len(rot):
        for pos in nb.keys():
            if pos in bb and pos in rot:
                info[pos] = [ nb[pos] ] + bb[pos] + rot[pos]
            else:
                info = {} ## signal failure
                break

    return info,seq ## info will be {} if failed

            
########################################### functions:


## read info from the log-file:

## info[name] = { pos: [ nb, phi, psi, omega, rot1, rot2, rot3, rot4]}



files = map(lambda x:string.split(x)[0], open(list_file,'r').readlines())
n_files = len(files)
counter = 0

outfile = open (out_file, 'w')
seq = {}
info = {}
for file in files:
    counter = counter+1
    if not counter%25:
        log('%d %d\n'%(counter, n_files))
        
    f,s = Read_file( file )
    if not f:
        log('bad file: '+file)
        continue
    
    if seq:
        if s != seq:
            log('bad seq: '+file)
            continue
    else:
        seq = s
            
    info[file] = f

num_decoy=counter
NATIVE = 0
if native_file:
    f,s = Read_file( native_file )
    if f and s == seq:
        NATIVE = 1
        info[NATIVE] = f


## sort positions by deviation of bb-torsions
L = len(seq.keys())

bb_list = {}
chi_list = {}

for pos in seq.keys():
    bb_list[pos] = []
    chi_list[pos] = []

hydrophobic=['A','F','I','L','M','P','V','W','Y']
for name in info.keys():
    if 1:
##     if string.count(name,'2aas'): ## the NMR guys ### DANGER !!!
        for pos in info[name].keys():
            tor = info[name][pos]
            
            bb = Get_bb( tor )
            rot = Get_rot( tor )

            bb_list[pos].append( bb )
#            print pos,rot
            if seq[pos] in hydrophobic:
#            if rot != EXPOSED:
                chi_list[pos].append( rot )

consensus_rot={}
chi_map = {}
rot_distr={}
for pos in chi_list.keys():
    count={}
    chi_map[pos] = {}
    counter = 0
    if NATIVE:
        nat_rot = Get_rot( info[NATIVE][pos] )
    else:
        nat_rot = (0) ## shouldnt match anybody

    nat_count = 0
    for rot in chi_list[pos]:
##         if rot == Get_rot( info[NATIVE][pos]):
##             nat_count = nat_count + 1
##             state='NATIVE'
##         else:
##             state='NOTNAT'
        if not count.has_key( rot):
            count[rot]=1
        else:
            count[rot]=count[rot]+1
        if not chi_map[pos].has_key( rot ):
            chi_map[pos][rot] = counter
##             print 'MAP  SC%d SEQ %s %s counter= %d rep= %s'\
##                   %(pos,seq[pos],state,counter,string.join(map(str,list(rot)),','))
            counter = counter + 1

    if not counter in rot_distr.keys():
        rot_distr[counter]=1
    else:
        rot_distr[counter]=rot_distr[counter]+1
        
    log('CHI '+`pos,seq[pos],counter,nat_count, nat_rot`)
    for rot in chi_map[pos].keys():
        if rot == Get_rot( info[NATIVE][pos]):
            nat_count = nat_count + 1
            state='NATIVE'
        else:
            state='NOTNAT'
        print 'MAP  SC%d SEQ %s %s counter= %d rep= %s, COUNT= %s'\
            %(pos,seq[pos],state,chi_map[pos][rot],string.join(map(str,list(rot)),','),count[rot])

        if rot != EXPOSED:
          weight=-10000.0
          chi1_lower, chi1_upper = Chi_range_from_bin(rot[0])
          chi2_lower, chi2_upper = Chi_range_from_bin(rot[1])
          outfile.write ( 'SC_% s%-4d %3.2f SC_BIN %4d %8.2f %8.2f %8.2f %8.2f %8.2f\n'%(seq[pos],pos,count[rot]/float(num_decoy+1),pos,\
                weight,chi1_lower,chi1_upper,chi2_lower,chi2_upper))

    rmax=-1
    rot_max=-1
    for rot in chi_map[pos].keys():
        if count[rot] > rmax:
            rot_max=rot
            rmax=count[rot]
    consensus_rot[pos]=rot_max

for rot_count in rot_distr.keys():
    print 'ROT_DISTR: ',rot_count,rot_distr[rot_count]

## now rank the bb positions by deviations
dev_list = []

bb_feature_rsd_list = []

for pos in bb_list.keys():

    total_sd = 0
    ssd = {}
    for i in range(2): ## 0= phi, 1=psi
        ll = map(lambda x:x[i], bb_list[pos] )

        sd = {}
        for s in range(-6,7):
            shift = s*30

            l = map(lambda x:In_range(x+shift), ll)

            m = float( reduce(add,l) ) / len(l)
        
            sd[s] = sqrt( reduce(add, map(lambda x:( Angle_delta(x,m) )**2,l)) / (len(l) - 1))
            #print shift, sd[s]
            

        ssd[i] = min(sd.values())
        total_sd = total_sd + ssd[i]

    if pos >1 and pos<L and \
       total_sd >= MIN_DEV: ## add this to the list of interesting positions
#        log(`[total_sd, pos, ssd[0], ssd[1]]`)
        bb_feature_rsd_list.append(pos)



## setup a mapping from bb angles to features
bin_map = {}

if 0:
    gpout,gpin = popen2('gnuplot')
    gpin.write('set nokey\n')
    angle = -180
    while angle<180:
        gpin.write('set arrow from %f, graph 0 to %f, graph 1 nohead\n'\
                   %(angle,angle))
        gpin.flush()
        angle = angle + BB_BIN_WIDTH

for pos in bb_feature_rsd_list:
    bin_map[pos] = {}
    counter = 0

    if NATIVE:
        nat_bb = Get_bb( info[NATIVE][pos] ) 
        nat_bin = Get_bb_bin ( nat_bb )
    else:
        nat_bb = (0) ## shoudn match anybody
        nat_bin = (0)

    nat_counter = 0
    
    for bb in bb_list[pos]:
        bin = Get_bb_bin( bb )

        if not bin_map[pos].has_key( bin ):
            bin_map[pos][bin] = counter
#            print 'MAP BB%d counter= %d rep= %.1f,%.1f'\
#                  %(pos,counter,bb[0],bb[1])
            counter = counter + 1

        if bin == nat_bin:
            nat_counter = nat_counter + 1
            
#    log('BB %4d num_bins: %4d nat_count: %4d\n'\
#        %(pos, counter, nat_counter))


    if 0:
        out = open('junk.plot','w')
        out.write('N %f %f\n'%(nat_bb[0],nat_bb[1]))

        for bb in bb_list[pos]:
            out.write('D %f %f\n'%(bb[0],bb[1]))

        out.close()

        command = 'plot [-180:180] [-180:180] "< grep D junk.plot" u 2:3, "< grep N junk.plot" u 2:3'
        angle = -180
        while angle<=180:
            command = command + ', %f'%angle
            angle = angle + BB_BIN_WIDTH
        command = command + '\n'
        #print command
        gpin.write(command)
        gpin.flush()
        raw_input()

code_count = {}

print 'CODE_NAMES:',
## for pos in bb_feature_rsd_list:
##     print 'BB%d'%pos,
for pos in range(1,L+1):
    rot = Get_rot( info[name][pos] )
    if seq[pos] in hydrophobic: print 'SC%d'%pos,
#    if rot != EXPOSED: print 'SC%d'%pos,
print
    
bar_code=[]
nat_count=0
for pos in range(1,L+1):
        rot = consensus_rot[pos] 
        if seq[pos] in hydrophobic:
         if rot == Get_rot( info[NATIVE][pos]):
          bar_code.append('N')
          nat_count=nat_count+1
         else:
          if rot != EXPOSED:
            bar_code.append( chi_map[pos][rot] )
          else:
            bar_code.append(9)
code = string.join( map(str,bar_code))
print 'CODE:', code, ' CONSENSUS ',' NATCHAR: ' ,nat_count

natchar_count={}
for name in info.keys():

    #if not string.count( name, '2aas'):continue

    bar_code = []

##     for pos in bb_feature_rsd_list:
##         bin = Get_bb_bin( Get_bb( info[name][pos] ) )
##         bar_code.append( bin_map[pos][bin] )
    nat_count=0
    for pos in range(1,L+1):
        rot = Get_rot( info[name][pos] )
        if seq[pos] in hydrophobic:
         if rot == Get_rot( info[NATIVE][pos]):
          nat_count=nat_count+1
          bar_code.append('N')
         else:
          if rot != EXPOSED:
            bar_code.append( chi_map[pos][rot] )
          else:
            bar_code.append(9)
    code = string.join( map(str,bar_code))
    print 'CODE:', code, name, ' NATCHAR: ', nat_count
    if not natchar_count.has_key(nat_count):
        natchar_count[nat_count]=1
    else:
        natchar_count[nat_count]=natchar_count[nat_count]+1
    if not code_count.has_key( code ):
        code_count[code] = 0

    code_count[code] = code_count[code] + 1


for code in code_count.keys():
    print 'CODE_COUNT:',code, 'COUNT: ',code_count[code]
    
print 'CONSENSUS NAT COUNT: ',nat_count
key_list=natchar_count.keys()
key_list.sort()
for natchar in key_list():
    print 'NATCHAR_COUNT: ',natchar, natchar_count[natchar]





