#!/usr/bin/python
## decoystats

from Tkinter import *
from popen2 import popen2
import string
from sys import argv
from math import log10,exp,floor,log
from os import popen,getpid,system,remove
from whrandom import random
from sys import argv,exit
from glob import glob
import tkSimpleDialog
from os.path import exists
import cluster_trees ## for making clustering trees
from operator import add

def Help():
    print '\n'+'-'*75
    print 'run this script in the directory where you ran'
    print 'pFOLD.pgi <xx> <id> <chain> -decoystats -score -fa_input ... 2> <log_file>'
    print 'the four arguments are: <xx> <id> <chain> <log_file>'
    print
    print 'the script needs to be able to find your starting structures. It'
    print 'looks in ./paths.txt to get the starting structure directory.'
    print 'to give it a different paths.txt, use -paths <paths_file>'
    print '\n'+'-'*75
    exit()
    
##################### configuration #######################
## Site-specific locations of helper programs and scratch space; edit these
## for your installation. (They are presumably consumed further down the
## file -- their use is not visible at this point.)

RASMOL_COMMAND = '/net/local/bin/rasmol'
#RASMOL_COMMAND = '/net/local/bin/rasmol_16bit'
#RASMOL_COMMAND = '/net/local/bin/rasmol_32bit'
#RASMOL_COMMAND = '/home/phil/rasmol_32BIT'

TMP_DIR = '/tmp/'

#FA_CLUSTER_EXE = '/home/phil/C/fa_cluster.out'
FA_CLUSTER_EXE = '/users/pbradley/C/fa_cluster.out'
CALCULATE2D_EXE = '/users/pbradley/C/hb_bb.out'

DEFAULT_PATHS_FILE = './paths.txt' ## for starting structure directory

DATA_PLOTTER_CANVAS_WIDTH = 1000

##################### command line arguments ##############
## Expected invocation (see Help()):
##   <script> <xx> <id> <chain> <log_file> [-paths <paths_file>]
## i.e. argv has 5 entries, or 7 when the optional "-paths <file>" pair is given.
## NOTE(review): the length test runs before "-paths" is removed, so 7 entries
## without "-paths" are accepted and the last two silently ignored. "-paths"
## may appear anywhere on the command line.

args = argv[:]
if len(args) not in [5,7]:
    Help() ## print help message,die.

PATHS_FILE = DEFAULT_PATHS_FILE
if args.count('-paths'): ## get a different paths file
    pos = args.index('-paths')
    PATHS_FILE = args[pos+1]
    ## drop both "-paths" and its value so the positional indices below hold
    del args[pos]
    del args[pos]

xx = args[1]
id = args[2]
chain = args[3]
log_file = args[4]

## e.g. xx1xyz_atom_set1.pdb for xx='xx', id='1xyz', chain='_'
pdb_file = '%s%s%satom_set1.pdb'%(xx,id,chain)

## get the path to the starting structures:
## PATHS_FILE must contain exactly one line matching the grep regex
## "starting.structure" ('.' matches any character) with three
## whitespace-separated fields; the third field is the directory.
## NOTE(review): these asserts are skipped under "python -O", and the popen()
## pipes opened here are never explicitly closed.
lines = map(string.split,popen('grep "starting.structure" %s'%PATHS_FILE).readlines())
assert len(lines) == 1 and len(lines[0])==3
START_FILE_PATH = lines[0][2]
print 'START_FILE_PATH:',START_FILE_PATH

## get the path to the scorefile
## exactly one line must start with "score" and have two fields; the second
## is prepended verbatim to "<xx><id>.fasc", so it presumably ends in '/'.
lines = map(string.split,popen('grep "^score" %s'%PATHS_FILE).readlines())
assert len(lines) == 1 and len(lines[0])==2
fasc_file = '%s%s%s.fasc'%(lines[0][1],xx,id)
print 'fasc_file:',fasc_file

## file for getting chi_required info
chi_required_file = '%s%s%s.chi_required.pdb'%(xx,id,chain)

## Missing fasc/pdb/log files are fatal (Help() exits); a missing
## chi_required file only produces the warning below.
if not exists(fasc_file):
    print 'cant open fasc_file:',fasc_file
    Help()
if not exists(pdb_file):
    print 'cant open ise pdb_file:',pdb_file
    Help()
if not exists(log_file):
    print 'cant open log_file:',log_file
    Help()
if not exists(chi_required_file):
    print '-'*75
    print 'WARNING!'*10
    print 'cant find the file for getting chi_required info from rosetta:',chi_required_file
    print 'this probably means that you have an old version of decoystats.f'
    print 'in your copy of ROSETTA. Consider checking out a more recent version'
    print 'and replacing at least that file.'
    print 'You wont be able to calculate the NAT-ROT data for user-defined'
    print 'decoy subsets.'
    print 'WARNING!'*10
    print '-'*75
    
################# some global variables ###########################################


MAX_ATOMS = 14 # heavyatoms, TRP

## measure name -> default level (1 or 2).
## NOTE(review): presumably the initial DATASET1D.level for data sets with
## these names -- TODO confirm against the code that builds them.
default_level = {
    'BBDEC-O':1,  'BBDEC-E':1,
    'BBDEC-N1':1, 'BBDEC-T1':1,
    'BBDEC-N2':2, 'BBDEC-N3':2, 'BBDEC-T2':2, 'BBDEC-T3':2,
    'BBFRG-O':1,  'BBFRG-E':1,
    'BBFRG-N1':1, 'BBFRG-T1':1,
    'BBFRG-N2':2, 'BBFRG-N3':2, 'BBFRG-T2':2, 'BBFRG-T3':2,
    'SC-RMS-1':2, 'SC-RMS-2':2, 'SC-RMS-3':2,
    'MAXSUB-1':2, 'MAXSUB-2':2, 'MAXSUB-3':2,
    }

## filled from chi_required_file (xx1xyz_.chi_required.pdb) when it exists
chi_required = {}

## entries displayed by ENTRY_DIALOG: each is [label] or [label, default]
dialog_entries = []

########################################################################
####################### FUNCTIONS         ##############################
########################################################################

def angle_range(a):
    """Fold an angle in degrees into the half-open interval (-180, 180].

    Whole turns of 360 are added or subtracted one at a time, so e.g.
    -180 maps to 180 and 540.5 maps to -179.5.
    """
    while 1:
        if a > 180.0:
            a = a - 360.0
        elif a <= -180.0:
            a = a + 360.0
        else:
            return a

BB_LIST = ['A','B','G','E','O'] ## global variable: every code ppo_to_bb() can return
def ppo_to_bb(ppo):
    """Classify a residue's backbone from (phi, psi, omega) torsions in degrees.

    Each angle is first folded into (-180, 180] with angle_range(). Returns:
      'O' -- |omega| < 90 (omega near 0; presumably a cis peptide bond)
      'G' -- phi >= 0 and -100 < psi <= 100
      'E' -- phi >= 0, any other psi
      'A' -- otherwise, -125 < psi <= 50
      'B' -- otherwise, any other psi
    """
    phi = angle_range(ppo[0])
    psi = angle_range(ppo[1])
    omega = angle_range(ppo[2])
    if abs(omega) < 90:
        return 'O'
    if phi >= 0.0:
        if -100 < psi <= 100:
            return 'G'
        return 'E'
    if -125 < psi <= 50:
        return 'A'
    return 'B'

    
def intf(x):
    """Return floor(x) as an int (rounds toward -infinity, unlike int())."""
    rounded_down = floor(x)
    return int(rounded_down)


## functions for rescaling data 0-1 to 0-1
EXPONENT = log10(2)
def Rescale_power(x):
    """Map x in [0,1] to x**EXPONENT (EXPONENT = log10(2), about 0.301).

    Values below 1e-5 are returned unchanged, which also keeps log() away
    from 0. Out-of-range input trips the assert.
    """
    assert 0.0<=x<=1.0
    if x >= 0.00001:
        return exp(EXPONENT * log(x))
    return x
    
def Rescale_linear(x):
    """Identity rescaling: check that x lies in [0,1] and return it unchanged."""
    assert x >= 0.0 and x <= 1.0
    return x

## functions for recoloring 0-1 -> #abcdef

def Uniform_rainbow_color(f):
    """Map f in [0,1] to a '#rrggbb' colour on a blue->green->red ramp.

    The interval is split into four equal quarters:
      [0,    0.25)  red 0,   green 0 -> max_green,   blue 255
      [0.25, 0.5 )  red 0,   green max_green,        blue 255 -> 0
      [0.5,  0.75)  red 0 -> 255, green max_green,   blue 0
      [0.75, 1.0 ]  red 255, green max_green -> 0,   blue 0
    Green never exceeds max_green (200). Every component is passed through
    intf() so the '%02x' format always receives an integer.
    """
    assert 0<=f<=1.0
    max_green = 200

    if f<0.25:
        x = 4*f ## runs from 0 to 1
        red = 0
        green = intf( min(max_green, x*max_green) )
        blue = 255
    elif f<0.5:
        x = 4*(f-0.25) ## runs from 0 to 1
        red = 0
        green = max_green
        blue = intf( 255 - x*255)
    elif f<0.75:
        x = 4*(f-0.5) ## runs from 0 to 1
        red = intf (x*255)
        green = max_green
        blue = 0
    else:
        x = 4*(f-0.75) ## runs from 0 to 1
        red = 255
        ## intf() as in the other branches: the raw expression is a float,
        ## and '%x' formatting requires an integer
        green = intf( min(max_green, max_green-x*max_green) )
        blue = 0
    return '#%02x%02x%02x'%(red,green,blue)
    
def Uniform_red(f):
    """Map f in [0,1] to a '#rrggbb' shade from white (f=0) to pure red (f=1)."""
    assert 0<=f<=1.0
    fade = intf(255*(1-f))  ## green and blue fade out together
    return '#%02x%02x%02x'%(255, fade, fade)


## global variables used by the plotters::
## colour name -> function mapping a value in [0,1] to a '#rrggbb' string
color_functions = {}
color_functions['rainbow'] = Uniform_rainbow_color
color_functions['red'] = Uniform_red

## scale name -> function mapping a value in [0,1] into [0,1]
scale_functions = {}
scale_functions['linear'] = Rescale_linear
scale_functions['power'] = Rescale_power


########################################################################
###                        CLASSES:
########################################################################
class DATASET1D:
    """A named 1D data set, as listed in the data plotter's Data1D menu.

    Attributes:
      name, D, boxed, N -- stored exactly as passed; their structure is defined
                           by the code that builds the data set
      atom_names        -- dict; a fresh empty dict per instance by default
      type              -- 'frequency' by default ('energy' is set manually)
      subset            -- '' by default; controls display in data_plotter
      level             -- 0 (top level) by default
    """
    def __init__(self,name,D,boxed,N,atom_names = None):
        ## A literal {} default would be evaluated once and shared by every
        ## DATASET1D built without atom_names, so mutating one instance's
        ## dict would silently change all the others.
        if atom_names is None:
            atom_names = {}
        self.D = D
        self.boxed = boxed
        self.N = N
        self.name = name
        self.atom_names = atom_names
        self.type = 'frequency' ## could also be 'energy', set this manually afterword
        self.subset = '' ## could be reset afterword; used to control display in data_plotter
        self.level = 0 ## top level by default
        
        

class DATASET2D:
    """A named 2D data set (like DATASET1D but without type/level/atom_names).

    name, D, boxed and N are stored exactly as passed. subset starts as '' and
    may be reassigned later to control display in data_plotter.
    """
    def __init__(self,name,D,boxed,N):
        self.name = name
        self.N = N
        self.D = D
        self.boxed = boxed
        self.subset = ''

class AXIS:
    """Widgets for one plot axis: its frame, its label variables and its slider."""
    def __init__(self,frame,label,scale):
        self.frame, self.label, self.scale = frame, label, scale

## uses:
##  ise.get_subset_names()  -- includes fasc_plotter, current_cluster
##  ise.get_subset_decoys()

class SUBSET_PICKER:
    """Multi-select Listbox of decoy subsets, largest first.

    Each row reads "<name> (<size>)". Relies on the module-level global
    ``ise`` for get_subset_names(), get_subset_decoys() and
    get_subset_intersection().
    """
    def __init__(self,master,grid_row,grid_column):
        global ise
        ## [size, name] pairs; sort then reverse puts the biggest subsets first
        entries = [[len(ise.get_subset_decoys(name)), name]
                   for name in ise.get_subset_names()]
        entries.sort()
        entries.reverse()
        box = Listbox(master,height=len(entries)+1,
                      exportselection=0,selectmode=MULTIPLE)
        for size, name in entries:
            box.insert(END,"%s (%d)"%(name,size))
        box.grid(row=grid_row,column=grid_column)
        self.listbox = box

    def get(self):
        """Return the selected subset names (the trailing " (size)" is dropped)."""
        box = self.listbox
        return [string.split(box.get(int(i)))[0] for i in box.curselection()]

    def get_decoys(self):
        """Return the decoys common to all selected subsets (ise.get_subset_intersection)."""
        return ise.get_subset_intersection(self.get())

################ Dialog windows #################################
        
class ENTRY_DIALOG(tkSimpleDialog.Dialog):
    """Generic form dialog driven by the module-level list ``dialog_entries``.

    Each element of dialog_entries supplies a label as its first item. When it
    has exactly two items, the second item pre-fills the field. After OK,
    self.result maps each label to the text typed in its Entry.
    """
    def body(self,master):
        global dialog_entries
        self.e = {}
        row = 0
        for entry in dialog_entries:
            if len(entry)==2:
                default = entry[1]
            else:
                default = ''
            Label(master,text=entry[0]+":").grid(row=row,column=0)
            widget = Entry(master)
            if default:
                widget.insert(0, default)
            widget.grid(row=row,column=1)
            self.e[row] = widget
            row = row + 1
        return self.e[0] ## initial keyboard focus goes to the first field

    def apply(self):
        global dialog_entries
        result = {}
        self.result = result
        for i in range(len(dialog_entries)):
            result[dialog_entries[i][0]] = self.e[i].get()

class CLUSTER_DIALOG(tkSimpleDialog.Dialog):
    """Dialog collecting clustering parameters, an atom source and a decoy set.

    After OK, self.result holds one string per parameter label, plus
    'atoms' (the chosen items among "data_plotter" / "rasmol") and 'decoys'
    (the intersection of the subsets chosen in the SUBSET_PICKER).
    """
    def body(self,master):
        ## (label, default) pairs; an empty default leaves the Entry blank
        fields = [("minTopClusterSize",   '0.025'),
                  ("tryTopClusterSize",   '0.1'),
                  ("maxTopClusterSize",   '0.25'),
                  ("minClusterSize",      '5'),
                  ("minClusterThreshold", '0'),
                  ("maxClusterThreshold", '100'),
                  ("prefix",              '')]
        self.labels = [field[0] for field in fields]
        self.e = {}
        for row in range(len(fields)):
            name, default = fields[row]
            Label(master,text=name+":").grid(row=row)
            self.e[row] = Entry(master)
            if default:
                self.e[row].insert(0,default)
            self.e[row].grid(row=row,column=1)

        next_row = len(fields)
        Label(master,text="atoms:").grid(row=next_row,column=0)
        atoms = Listbox(master,height=4,exportselection=0,selectmode=MULTIPLE)
        for source in ["data_plotter","rasmol"]:
            atoms.insert(END,source)
        atoms.grid(row=next_row,column=1)
        self.atom_listbox = atoms

        Label(master,text="decoys:").grid(row=next_row+1,column=0)
        self.subset_picker = SUBSET_PICKER(master,next_row+1,1)
        return self.e[0]

    def apply(self):
        result = {}
        self.result = result
        for row in range(len(self.labels)):
            result[self.labels[row]] = self.e[row].get()

        atoms = self.atom_listbox
        result['atoms'] = [atoms.get(int(i)) for i in atoms.curselection()]

        result['decoys'] = self.subset_picker.get_decoys()

class ADD_TORSIONS_DIALOG(tkSimpleDialog.Dialog):
    """Dialog choosing residue sources and torsion angles for GNUPLOTTER_2D.add_torsions.

    After OK, self.result['residues'] lists the chosen sources among
    "data_plotter" / "rasmol". self.result['use_torsion'] maps torsion index
    0..6 (phi, psi, omega, chi1..chi4) to 1 if selected, else 0.
    """
    def body(self,master):
        ## where to take the residue selection from
        Label(master,text="residues:").grid(row=0,column=0)
        residues = Listbox(master,height=4,exportselection=0,selectmode=MULTIPLE)
        residues.insert(END,"data_plotter")
        residues.insert(END,"rasmol")
        residues.grid(row=0,column=1)
        self.residue_listbox = residues

        ## which torsion angles; a row's position is its torsion index in apply()
        Label(master,text="torsions:").grid(row=1,column=0)
        torsions = Listbox(master,height=7,exportselection=0,selectmode=MULTIPLE)
        for torsion in ['phi','psi','omega','chi1','chi2','chi3','chi4']:
            torsions.insert(END,torsion)
        torsions.grid(row=1,column=1)
        self.torsions_listbox = torsions

        return residues

    def apply(self):
        result = {}
        self.result = result

        residues = self.residue_listbox
        result['residues'] = [residues.get(int(i)) for i in residues.curselection()]

        flags = {}
        result['use_torsion'] = flags
        for i in range(7): ## bb + chi
            flags[i] = 0
        for i in self.torsions_listbox.curselection():
            flags[int(i)] = 1
            
class GET_NEW_SUBSET_DIALOG(tkSimpleDialog.Dialog):
    def body(self,master):
        ## query subset name
        Label(master,text="subset name:").grid(row=0,column=0)
        self.name_entry = Entry(master)
        self.name_entry.grid(row=0,column=1)
        
        ## listbox for selecting the subsets
        self.subset_pickers = {}
        for i in range(3): ## union of up to 3 intersections
            if i:
                Label(master,text="OR").grid(row=2*i,column=1)
            self.subset_pickers[i] = SUBSET_PICKER(master,2*i+1,1)

        ## pick by text in name
        Label(master,text="pick by text in decoy name:").grid(row=6,column=0)
        self.grep_entry = Entry(master)
        self.grep_entry.grid(row=6,column=1)
            
    def apply(self):
        ## get name
        self.result = {}
        self.result['name'] = self.name_entry.get()
        ## get decoys in union of intersections:
        big_decoy_list = {}
        for j in range(3):
            decoy_list = self.subset_pickers[j].get_decoys()
            for d in decoy_list:
                big_decoy_list[d] = 1

        ## pick subset by name
        grepper = self.grep_entry.get()
        if grepper:
            print 'choose subset by text in name:',grepper
            ks = big_decoy_list.keys()
            for name in ks:
                if not string.count(name,grepper):
                    print 'no match: (%s)'%grepper,name
                    del big_decoy_list[name]
                
        self.result['decoys'] = big_decoy_list.keys()
        
class SUBSET_SCORE_DIALOG(tkSimpleDialog.Dialog):
    """Dialog choosing the subsets that fill the fasc plotter's 'subset' column.

    After OK, self.result = {'subset_names': [selected subset names]}.
    """
    def body(self,master):
        self.subset_picker = SUBSET_PICKER(master,0,1)

    def apply(self):
        chosen = {}
        self.result = chosen
        chosen['subset_names'] = self.subset_picker.get()
        
class GNUPLOTTER_2D:
    def __init__(self,parent,canvas_width,canvas_border):
        self.parent = parent

        self.window = Toplevel()
        self.window.title('fasc plotter')
        self.window.geometry("%dx%d"%(canvas_width+170,canvas_width+100))

        self.native_name = self.parent.pdb_id+'.pdb'
        
        ## mode
        self.click_mode = StringVar()
        self.click_mode.set('zoom')

        ## menubar
        menu = Menu(self.window,tearoff=0)
        self.window.config(menu=menu)
        menu.add_radiobutton(label='Zoom',
                             variable = self.click_mode,
                             value='zoom')
        menu.add_radiobutton(label='Load',
                             variable = self.click_mode,
                             value='load')
        menu.add_command(label="add torsions",command = self.add_torsions)
        menu.add_command(label="configure subset score",command = self.subset_score)

        ## status bar
        self.mouse_location  = StringVar()
        self.mouse_location.set('hey')
        f = Frame(self.window)
        f.pack(side=TOP)
        Label(f,textvariable=self.mouse_location).pack(side=LEFT)
        
        self.tags = map(str,range(10)) ## dummy initialization

        self.axis = {}
        self.index = {0:0, 1:1, 2:2} ## starting values
        self.bounds = [[0,0,0],[1,1,1]] ## dummy starting values
        
        for i in range(3):

            if i==0:
                pack_side= BOTTOM
                scale_orient= HORIZONTAL
                scale_width = 35
            elif i==1:
                pack_side= LEFT
                scale_orient= VERTICAL
                scale_width = 75
            else:
                pack_side= RIGHT
                scale_orient= VERTICAL
                scale_width = 75
                
            f = Frame(self.window)
            f.pack(side=pack_side)
            label = {}
            for j in range(3):## score,min,max
                label[j] = StringVar()
                Label(f,textvariable=label[j]).pack()
            scale = Scale(f,from_=0,to=9,orient=scale_orient,
                          width=scale_width,length=(2*canvas_width)/3,
                          command = (lambda x,s=self,i=i:s.set_index(i,x)))
            scale.set(i)
            scale.pack()
            self.axis[i] = AXIS(f,label,scale)
            
        
        self.canvas = Canvas(self.window)
        self.canvas.pack(expand=YES,fill=BOTH)

        self.canvas_width  = canvas_width
        self.canvas_border = canvas_border

        self.canvas.bind("<Button-1>",self.button1_press)
        self.canvas.bind("<B1-Motion>",self.button1_motion)
        self.canvas.bind("<Motion>",self.motion)
        self.canvas.bind("<ButtonRelease-1>",self.button1_release)

        self.nativex = DoubleVar()
        self.nativey = DoubleVar()
        self.nativex.trace("w",(lambda x,y,z,s=self:s.mark_native()))
        self.nativey.trace("w",(lambda x,y,z,s=self:s.mark_native()))
        #### END INIT ###############################

    def get_score_tags(self):
        return self.tags

    def mark_native(self):
        self.canvas.delete("native_mark")
        self.canvas.create_text(self.nativex.get(),self.nativey.get(),
                                text="N",tags="native_mark")
        

    def set_current_subset(self,subset_name):
        self.update_data_points()
        self.data_to_xy()
        self.replot()

    def subset_score(self):
        dialog = SUBSET_SCORE_DIALOG(self.window,"choose subsets")
        subset_names = dialog.result['subset_names'] 
        index = self.tags.index( 'subset' )
        for p in self.fasc_data_points:
            p[1][index] = 0 ## not in any subsets

        for subset_name in subset_names:
            names = self.parent.get_subset_decoys(subset_name)
            print subset_name,len(names)
            score = subset_names.index(subset_name) + 1
            for p in self.fasc_data_points:
                if not self.fullnames.has_key(p[0]):continue
                name = self.fullnames[p[0]] 
                if name in names:
                    p[1][index] = score
        
        self.data_to_xy()
        self.replot()
            
        
    ## add bb/sc torsion angles to list of plottable tags 
    def add_torsions(self):
        log_file = self.parent.log_file
        N = self.parent.data_plotter.N ## number of residues
        dialog = ADD_TORSIONS_DIALOG(self.window,"add torsions")
        result = dialog.result
        use_torsion = result['use_torsion']
        
        ## determine which residues to use
        l = result['residues']
        use_residue = {}
        for i in range(1,N+1):
            use_residue[i] = 0
            
        if 'rasmol' in l:
            atoms = self.parent.rasmol.get_current_selection()
            for r in self.parent.atoms_to_residues(atoms):
                use_residue[r] = 1
        if 'data_plotter' in l:
            for r in self.parent.data_plotter.get_current_selection():
                use_residue[r] = 1
                    
        if reduce(add,use_residue.values()) == 0:
            print 'empty residue list: selecting all residues'
            for r in range(1,N+1):
                use_residue[r] = 1

        print 'use_residue:',use_residue
        print 'use_torsion:',use_torsion
        ## read torsions from log_file:
        torsion_names = ['phi','psi','omega',
                         'chi1','chi2','chi3','chi4']
        command = 'grep "^DS START_FILE" %s -A%d'%(log_file,N)
        print command
        data = popen(command)
        line = data.readline()
        torsions = {}
        residue3 = {}
        while line:
            if line[:4] == 'DS S':
                id = string.split(line)[2]
                name = string.split(id,'/')[-1]+'.pdb'
                if id == 'native':
                    name = self.native_name
                torsions[name] = {}
            elif line[:4] == 'DS_T':
                if 1:
                    l = string.split(line)
                    r = int(l[1])
                    residue3[r] = l[2]
                    if use_residue[r]:
                        torsions[name][r] = map(float,l[3:6])+\
                                            [float(l[6]),float(l[8]),
                                             float(l[10]),float(l[12])]
            else:
                if line[:2] != '--':
                    print 'whoah:',line[:-1]
            line = data.readline()
        data.close()

        ## think about adding new tags
        new_tags = []
        for r in range(1,N+1):
            if use_residue[r]:
                for i in range(7):
                    if use_torsion[i]:
                        ## also used below
                        tag = 'torsion %3d %s %s'%(r,residue3[r],torsion_names[i])
                        if tag not in self.tags:
                            
                            mn = 1000
                            mx = -1000
                            for name in torsions.keys():
                                mn= min(mn,torsions[name][r][i])
                                mx= max(mx,torsions[name][r][i])
                            if mn==mx:
                                print 'no variation:',tag
                            else:
                                new_tags.append(tag)

        self.tags = self.tags + new_tags
        
        for ii in range(len(self.fasc_data_points)):
            p = self.fasc_data_points[ii]
            name = p[0]
            if torsions.has_key(name):
                for r in range(1,N+1):
                    if use_residue[r]:
                        for i in range(7):
                            if use_torsion[i]:
                                tag = 'torsion %3d %s %s'%(r,residue3[r],torsion_names[i])
                                if tag in new_tags:
                                    p[1].append( torsions[name][r][i] ) 
            else:
                print 'no torsion info:',name
                p[1] = p[1] + [0.0]*(len(new_tags))
        
        self.update_axes()
                                
    def get_decoy_scores(self,score_name):
        if not self.tags.count(score_name):
            print 'bad score tag:',score_name
            return {}
        decoy_scores = {} 
        score_index = self.tags.index(score_name)
        fullnames = self.fullnames
        for p in self.fasc_data_points:
            name = p[0]
            score = p[1][score_index]
            if fullnames.has_key(name):
                decoy_scores[ fullnames[name]] = score
        return decoy_scores
    
            
    def get_current_decoys(self):
        files = []
        for b in self.bins.keys():
            for l in self.bins[b]:
                name = l[2]
                if self.fullnames.has_key(name):
                    files.append( self.fullnames[name] )
        return files

    def get_all_decoys(self):
        return self.fullnames.values()

    def c2x(self,i,j):## from canvas to x,y
        x_min = self.bounds[0][0]
        y_min = self.bounds[0][1]
        x_max = self.bounds[1][0]
        y_max = self.bounds[1][1]
        if (x_min == x_max or y_min == y_max): return 0,0
        x_scale = (self.canvas_width - 2*self.canvas_border) / (x_max-x_min)
        y_scale = (self.canvas_width - 2*self.canvas_border) / (y_max-y_min)

        x = x_min + (i-self.canvas_border)/x_scale
        y = y_max - (j-self.canvas_border)/y_scale
        return x,y
    
    def button1_press(self,event):
        mode = self.click_mode.get()
        if mode == 'zoom':
            x = event.x
            y = event.y
            self.press = [event.x,event.y]
            self.motion = 0
            self.box = self.canvas.create_rectangle(x,y,x,y)
        elif mode == 'load':
            x = self.canvas.canvasx(event.x)
            y = self.canvas.canvasy(event.y)
            name = self.find_closest_datapoint(x,y)
            print name
            #print self.fullnames
            if self.fullnames.has_key(name):
                filename = self.fullnames[name]
                self.parent.load_decoy(filename)

    def motion(self,event):
        data_x,data_y = self.c2x(event.x,event.y)## call canvasx also??
        self.mouse_location.set('x= %9.3f  y= %9.3f'%(data_x,data_y))
        
                
    def button1_motion(self,event):
        if self.click_mode.get() == 'zoom':
            self.motion = 1
            self.canvas.coords(self.box,self.press[0],self.press[1],event.x,event.y)

    def button1_release(self,event):
        if self.click_mode.get() == 'zoom':
            if self.motion:
                #print 'from: %f,%f to %f,%f'%(self.press[0],self.press[1],
                #                              event.x,event.y)
                x0,y0 = self.c2x(self.press[0],self.press[1])
                x1,y1 = self.c2x(event.x,event.y)
                z_min = self.bounds[0][2]
                z_max = self.bounds[1][2]
                self.bounds = [[min(x0,x1),min(y0,y1),z_min],
                               [max(x0,x1),max(y0,y1),z_max]]
                self.replot()

            self.canvas.delete(self.box)


    def set_index(self,i,index): ## i=0 -> x, i=1 -> y, i=2 -> z
        index = int(index)
        self.axis[i].label[0].set("%d: %s"%(index+2,self.tags[index]))
        self.index[i] = index
        self.data_to_xy()
        self.replot()

    def plot_fasc(self,fasc_file,fullnames):
        #dictionary mapping the stripped filenames to fullnames:
        self.fullnames = fullnames
        
        data = open(fasc_file,'r')

        
        ## read the score tags:
        line = string.split(data.readline())
        if line[0] != 'filename':
            print 'format problem in fasc_file:',line
            return
        L = len(line)
        self.tags = line[1:-1] + ['subset'] ## fake extra tag: subset
        self.update_axes()
        

        ## load the datapoints
        self.fasc_data_points = []
        line = data.readline()
        print 'reading the fasc_file:',fasc_file
        while line:
            l = string.split(line)
            if len(l) == L:
                name = l[0]
                values = map(float,l[1:-1]) + [0.0] ## fake extra value
                self.fasc_data_points.append([name,values])
            line = data.readline()
        data.close()
        self.update_data_points()

        print 'done reading the fasc file'

        self.set_index(0,0)## update the axis labels
        self.set_index(1,1)
        self.set_index(2,2)
        
        self.data_to_xy()
        self.replot()

    ## update self.data_points using self.fasc_data_points and the current_subset
    ## could also use a mode switch to choose torsion_data_points rather than
    ## fasc_data_points ??
    def update_data_points(self):
        print 'update_data_points: starting'
        decoys = self.parent.subsets[self.parent.current_subset.get()]
        self.data_points = []
        native_name = self.native_name
        for p in self.fasc_data_points:
            name = p[0]
            if name == native_name or \
               (self.fullnames.has_key(name) and \
                self.fullnames[name] in decoys):
                self.data_points.append( [p[0],p[1]] )
        print 'update_data_points: done'
        
    def update_axes(self):
        for i in range(3):
            self.axis[i].scale.config(to=len(self.tags)-1)

    def data_to_xy(self): ## call when plotting indices are updated:: resets self.bounds
        i0 = self.index[0]
        i1 = self.index[1]
        i2 = self.index[2]
        points = []
        x = self.data_points[0][1][i0]
        y = self.data_points[0][1][i1]
        z = self.data_points[0][1][i2]
        x_min = x
        y_min = y
        z_min = z
        x_max = x
        y_max = y
        z_max = z

        ## -180,180 for torsions
        if string.count( self.tags[i0], 'torsion'):
            x_min = -180
            x_max = 180
        if string.count( self.tags[i1], 'torsion'):
            y_min = -180
            y_max = 180
        if string.count( self.tags[i2], 'torsion'):
            z_min = -180
            z_max = 180
        
        for p in self.data_points:
            name = p[0]
            x = p[1][i0]
            y = p[1][i1]
            z = p[1][i2]

            x_min = min(x_min,x)
            y_min = min(y_min,y)
            z_min = min(z_min,z)
            x_max = max(x_max,x)
            y_max = max(y_max,y)
            z_max = max(z_max,z)
            points.append( [x,y,z,name] )

        if string.count( self.tags[i0], 'torsion'):
            x_pad = 1 ## 1 degree
        else:
            x_pad = (x_max-x_min)/20
            
        if string.count( self.tags[i1], 'torsion'):
            y_pad = 1 ## 1 degree
        else:
            y_pad = (y_max-y_min)/20
            
        ## z_pad is small to begin with:
        z_pad = 0.0001

        self.bounds = [ [x_min-x_pad, y_min-y_pad, z_min-z_pad] ,
                        [x_max+x_pad, y_max+y_pad, z_max+z_pad]]
        self.xy_points = points

    def show_bounds(self):
        for i in range(3): ## x,y
            for j in range(2): ## min,max
                self.axis[i].label[j+1].set("%f"%(self.bounds[j][i]))
            

    def find_closest_datapoint(self,x,y):
        
        xbin = int(x/self.bin_width)
        ybin = int(y/self.bin_width)
        nbins = int( self.canvas_width / self.bin_width) + 1
        bins = self.bins
        
        closest = [1000000.0,'hey']
        
        for xb in range(xbin-1,xbin+2):
            for yb in range(ybin-1,ybin+2):
                b = (xb,yb)
                if bins.has_key(b):
                    for p in bins[b]:
                        name = p[2]
                        d = (p[0] - x)**2 + (p[1] -y)**2
                        if d<closest[0]:
                            closest = [d,name]
        return closest[1]
                
    def replot(self):
        points = self.xy_points

        self.canvas.delete('all')
        self.canvas.create_rectangle(self.canvas_border,self.canvas_border,
                                     self.canvas_width-self.canvas_border,
                                     self.canvas_width-self.canvas_border)

        self.show_bounds()
        
        x_min = self.bounds[0][0]
        y_min = self.bounds[0][1]
        z_min = self.bounds[0][2]
        x_max = self.bounds[1][0]
        y_max = self.bounds[1][1]
        z_max = self.bounds[1][2]

        if (x_max == x_min or y_max == y_min): return
        
        x_scale = (self.canvas_width - 2*self.canvas_border) / (x_max-x_min)
        y_scale = (self.canvas_width - 2*self.canvas_border) / (y_max-y_min)
        z_scale = 1.0 / (z_max-z_min)

        ## for point look-up by x,y
        bin_width = 50.0
        self.bin_width = bin_width
        nbins = int( self.canvas_width / bin_width) + 1
        self.bins = {}
        for i in range(nbins):
            for j in range(nbins):
                self.bins[(i,j)] = []
                
        for p in points:
            if x_min<=p[0]<=x_max and \
               y_min<=p[1]<=y_max:
                name = p[3]
                x = self.canvas_border + (p[0]-x_min) * x_scale
                y = self.canvas_border + (y_max-p[1]) * y_scale

                if name == self.native_name:
                    self.nativex.set(x)
                    self.nativey.set(y)
                
                color = Uniform_rainbow_color( (p[2]-z_min) * z_scale)

                xbin = int(x/bin_width)
                ybin = int(y/bin_width)
                self.bins[(xbin,ybin)].append ([x,y,name])
                
                self.canvas.create_rectangle(x-2,y-2,x+2,y+2,fill=color,outline='')
        self.mark_native() ## put mark on top of points

class DATA_PLOTTER: ############################################################
    """Interactive canvas showing residue-pair (2D) and per-residue (1D) data.

    The canvas holds an N x N grid of colored boxes (lower triangle from one
    2D dataset, upper triangle from another), a stack of 1D strips drawn
    below the grid, and a region to the right of the grid reserved for the
    cluster tree (see get_cluster_bounds).  Datasets are registered with
    add_dataset2D / add_dataset1D and chosen through menus that are added to
    parent.menu.

    Residue indices are 0-based internally; the values shown in the status
    bar and returned by get_current_selection / get_rasmol_bfactor are
    1-based.
    """
    def __init__(self,parent,window):
        """Build the menus, status bar and canvas inside `window`.

        parent supplies .menu (menus are added to it), .current_subset,
        .subsets, .log_file, .tmp_file_prefix, .native_pdb_file, .clusterer
        and .data_plotter_callback(i, j).
        """
        global color_functions,scale_functions
        self.parent = parent

        self.window = window

        menu = parent.menu
        #menu = Menu(window,tearoff=0)
        #window.config(menu=menu)

        ## setup the datamenus -- we'll add new datasets here
        self.data2D_menu = Menu(menu,tearoff=0) ## 2D data
        menu.add_cascade(label="Data2D",menu=self.data2D_menu)

        self.data1D_menu = Menu(menu,tearoff=0) ## 1D data
        menu.add_cascade(label="Data1D",menu=self.data1D_menu)


        ## setup the color menu: control of color and rescaling
        ##  via these string variables (each change triggers a replot)::
        self.color_function_name = StringVar()
        self.color_function_name.set('rainbow')
        self.color_function_name.trace("w",(lambda x,y,z,s=self:s.replot())) ##
        self.scale_function_name = StringVar()
        self.scale_function_name.set('linear')
        self.scale_function_name.trace("w",(lambda x,y,z,s=self:s.replot())) ##
        self.zero_color=StringVar()
        self.zero_color.set('0') ## use color_function(0)
        self.zero_color.trace("w",(lambda x,y,z,s=self:s.replot())) ##
        self.energy_min = DoubleVar()
        self.energy_min.set(-3.0)
        self.energy_min.trace("w",(lambda x,y,z,s=self:s.replot())) ##
        self.energy_max = DoubleVar()
        self.energy_max.set(3.0)
        self.energy_max.trace("w",(lambda x,y,z,s=self:s.replot())) ##

        color_menu = Menu(menu,tearoff=0)
        menu.add_cascade(label="Color",menu=color_menu)
        for k in color_functions.keys():
            color_menu.add_radiobutton(label='Color: '+k,
                                       variable = self.color_function_name,
                                       value=k)
        color_menu.add_separator()
        for k in scale_functions.keys():
            color_menu.add_radiobutton(label='Scale: '+k,
                                       variable = self.scale_function_name,
                                       value=k)
        color_menu.add_separator()
        color_menu.add_radiobutton(label='zero-color: default',
                                   variable= self.zero_color,
                                   value='0')
        color_menu.add_radiobutton(label='zero-color: white',
                                   variable= self.zero_color,
                                   value='#ffffff')
        color_menu.add_separator()
        color_menu.add_command(label="config per-residue energy colors",
                               command = self.rescale_energy)

        ## the rasmol coloring interface
        self.rasmol_color_name = StringVar()
        self.rasmol_color_name.set("dummy")
        self.rasmol_color_menu = Menu(menu,tearoff=0)
        menu.add_cascade(label="Rasmol Color",menu=self.rasmol_color_menu)

        ## current selection
        menu.add_command(label="unselect",command=self.reset_current_selection)

        ## for tracking the current data ########################
        self.data2D = {} ## 2D data
        self.data2D_name_list = []
        ## 0: lower triangle, 1: upper triangle, 2: "both" radiobutton in the
        ## Data2D menu (only 0 and 1 trigger replots)
        self.current_data2D = {0:StringVar(),1:StringVar(),2:StringVar()}
        self.N = 0
        for i in range(2):
            self.current_data2D[i].set('dummy')
            self.current_data2D[i].trace("w",(lambda x,y,z,s=self:s.replot())) ##

        self.data1D = {} ## 1D data
        self.data1D_name_list = []
        self.data1D_show = {}

        ## make the status bar::
        self.mouse_location = StringVar()
        self.mouse_location.set('hey')
        f = Frame(self.window) ## make a frame to hold the plot titles
        f.pack(side=TOP)
        Label(f,textvariable=self.current_data2D[1]).pack(side=LEFT) ## upper
        Label(f,textvariable=self.current_data2D[0]).pack(side=LEFT) ## lower
        Label(f,textvariable=self.mouse_location).pack(side=LEFT) ## lower


        ## make the canvas
        self.canvas = Canvas(window)
        self.canvas.pack(expand=YES,fill=BOTH)
        self.canvas.bind('<Motion>',self.motion)
        self.canvas_width=DATA_PLOTTER_CANVAS_WIDTH ## initial value
        self.box_width=10 ##dummy initial value, we dont know total_residue yet
        self.SHOW_DIAG = 0 ## dont draw the diagonal elements

        self.follow_motion = 1 ## do call motion_callback_function
        ## rubber-band selection state; kept separate from the motion() method
        self.boxed = 0   ## a selection box is being dragged
        self.dragged = 0 ## mouse moved since button-1 was pressed
        self.canvas.bind("<Button-1>",self.button1_click) ## turns off follow_motion
        self.canvas.bind("<Double-Button-1>",self.double_button1_click) ## turns off follow_motion
        self.canvas.bind("<B1-Motion>",self.button1_motion)
        self.canvas.bind("<ButtonRelease-1>",self.button1_release)

        self.update_data1D_menu()
        ##### END INIT #####

    def calculate2D(self):
        """Run CALCULATE2D_EXE on the current subset and add its 2D datasets.

        Writes the subset's decoy names to <tmp_file_prefix>.calc2D.list,
        runs the program on that list plus parent.native_pdb_file, and parses
        its output lines:
            HB|CA          i j value  -> value stored for pair (i,j)
            NAT_HB|NAT_CA  i j        -> pair (i,j) is boxed in the plot
        Residue numbers in the output are 1-based; stored keys are 0-based.
        Adds 'HB-BB-BB for <subset>' and 'CA-contact (9A) for <subset>'.
        """
        subset_name = self.parent.current_subset.get()
        subset_decoys = self.parent.subsets[subset_name]

        ## make a list file
        prefix = self.parent.tmp_file_prefix
        list_file = prefix+'.calc2D.list'
        out = open(list_file,'w')
        for name in subset_decoys:
            out.write(name+'\n')
        out.close()
        command = '%s %s %s'\
                  %(CALCULATE2D_EXE,
                    list_file,
                    self.parent.native_pdb_file)
        print(command)
        data = popen(command)
        line = data.readline()

        D = {'HB':{}, 'CA':{}}
        boxed = {'HB':{}, 'CA':{}}

        while line:
            l = line.split()
            if l: ## blank lines would otherwise raise IndexError on l[0]
                tag = l[0]
                if tag in ['HB','CA']:
                    i = int(l[1])-1 ## convert to python numbering
                    j = int(l[2])-1
                    D[tag][(i,j)] = float(l[3])
                elif tag in ['NAT_HB','NAT_CA']:
                    tag = tag[4:]
                    i = int(l[1])-1 ## convert to python numbering
                    j = int(l[2])-1
                    boxed[tag][(i,j)] = 1
            line = data.readline()
        data.close()

        for tag in ['HB','CA']:
            if tag == 'HB':
                name = 'HB-BB-BB for %s'%subset_name
            else:
                name = 'CA-contact (9A) for %s'%subset_name

            dataset = DATASET2D(name,D[tag],boxed[tag],self.N)
            dataset.subset = subset_name
            self.add_dataset2D(dataset)


    def rescale_energy(self):
        """Ask the user for a new per-residue energy color range.

        Opens an ENTRY_DIALOG pre-filled (via the dialog_entries global) with
        the current energy_min / energy_max.  A cancelled dialog, a
        non-numeric entry, or min >= max leaves the range unchanged, because
        a zero-width or inverted range breaks the energy rescaling
        (emax-emin is a divisor).
        """
        global dialog_entries
        dialog_entries = [['energy_min',str( self.energy_min.get())],
                          ['energy_max',str( self.energy_max.get())] ]
        dialog = ENTRY_DIALOG(self.window,'enter per-residue energy scale')
        ## assumes ENTRY_DIALOG, like tkSimpleDialog.Dialog, leaves
        ## result as None when cancelled -- TODO confirm
        result = getattr(dialog,'result',None)
        if not result:
            return
        try:
            emin = float(result['energy_min'])
            emax = float(result['energy_max'])
        except (KeyError,ValueError):
            print('bad energy range: need two numbers')
            return
        if emin >= emax:
            print('bad energy range: energy_min must be less than energy_max')
            return
        ## each set() triggers a replot; order the two updates so the
        ## intermediate range is never empty or inverted
        if emin < self.energy_max.get():
            self.energy_min.set(emin)
            self.energy_max.set(emax)
        else:
            self.energy_max.set(emax)
            self.energy_min.set(emin)

    def update_data2D_menu(self):
        """Rebuild the Data2D menu for the datasets visible in the current subset.

        Top-level radiobuttons show a dataset in both triangles (through
        set_dataset); the 'Upper' / 'Lower' cascades pick each triangle
        separately.  Datasets tagged with another subset are omitted.
        """
        menu = self.data2D_menu
        menu.delete(0,END)

        menu.add_command(label="calculate for current subset",command=self.calculate2D)
        menu.add_separator()

        subdata_menu = {}
        for i in [1,0]:## switch order
            if i==0:
                s = 'Lower'
            else:
                s = 'Upper'
            subdata_menu[i] = Menu(menu,tearoff=0)
            menu.add_cascade(label=s,menu=subdata_menu[i])

        current_subset = self.parent.current_subset.get()
        for name in self.data2D_name_list:
            dataset = self.data2D[name]
            if dataset.subset and dataset.subset != current_subset:continue
            menu.add_radiobutton(label = name,
                                 variable = self.current_data2D[2],
                                 command=(lambda s=self,n=name:\
                                          s.set_dataset(n)),
                                 value=name)
            for i in range(2):
                subdata_menu[i].add_radiobutton(label=name,
                                                variable=self.current_data2D[i],
                                                value=name)


    def update_data1D_menu(self):
        """Rebuild the Data1D (checkbuttons) and Rasmol Color (radiobuttons) menus.

        Each dataset goes into the top menu, 'more choices' or 'even more
        choices' according to dataset.level (0, 1, other).  Datasets tagged
        with another subset are hidden from the menus and, if currently
        shown, switched off (which triggers replot1D).
        """
        for menu in [self.data1D_menu, self.rasmol_color_menu]:
            menu.delete(0,END)

            if menu ==self.data1D_menu:
                menu.add_command(label="calculate",command=self.calculate1D)
                menu.add_separator()

            more_menu = Menu(menu,tearoff=0)
            menu.add_cascade(label="more choices",menu=more_menu)
            even_more_menu = Menu(menu,tearoff=0)
            menu.add_cascade(label="even more choices",menu=even_more_menu)
            menu.add_separator()

            current_subset = self.parent.current_subset.get()
            for name in self.data1D_name_list:
                dataset = self.data1D[name]
                show = self.data1D_show[name]
                level = dataset.level
                subset = dataset.subset
                if subset and subset!= current_subset:
                    if show.get():
                        show.set(0) ## triggers replot
                    continue ## dont put this one in the menus

                if level==0:
                    m = menu
                elif level==1:
                    m = more_menu
                else:
                    m = even_more_menu

                if menu == self.data1D_menu:
                    m.add_checkbutton(label = name,variable=show)
                else:
                    m.add_radiobutton(label=name,
                                      variable=self.rasmol_color_name,
                                      value=name)


    def calculate1D(self):
        """Compute per-residue statistics for the current subset from log_file.

        Greps parent.log_file for 'DS START_FILE' records (each followed by
        2*N+1 lines) and, over the subset's decoys, accumulates:
          * backbone-bin counts per residue (DS_T torsions -> ppo_to_bb),
          * how often the rotamer agrees with the native through chi1,
            chi1-2, chi1-3 and chi1-4,
          * NAT-ROT per-atom counts when chi_required has been loaded,
          * per-term energies relative to the decoy average (DS_E columns
            3+2i minus 4+2i) and relative to the native (column 3+2i).
        The native record is expected before any decoy record.  Each quantity
        is divided by the number of subset decoys found and added as a 1D
        dataset tagged with the subset name.  Nothing is added when no
        subset decoy is found in the log.
        """
        #### calculate 1D quantities for the current subset ####
        global chi_required

        subset_name = self.parent.current_subset.get()
        subset_decoys = self.parent.subsets[subset_name]
        in_subset = {} ## O(1) membership tests while scanning the log
        for decoy in subset_decoys:
            in_subset[decoy] = 1

        ## initialize arrays:
        enumber = 0 ## flag for later
        total_decoys = 0 ## count number of decoys found in log_file and in subset_decoys
        residues = range(1,self.N+1)
        native_bb = {}
        native_rot = {}
        bb_count = {}
        chi_correct = {}
        nat_rot_count = {} ## only calculated if we successfully loaded chi_required
        for pos in residues: ## 1->N
            chi_correct[pos] = {0:0, 1:0, 2:0, 3:0}
            bb_count[pos] ={}
            for bb in BB_LIST:
                bb_count[pos][bb] = 0
            if chi_required:
                nat_rot_count[pos] = {}
                for atom_name in chi_required[pos]['atom_names']:
                    nat_rot_count[pos][atom_name] = 0

        ## read the log_file
        print('reading energies and torsions from log_file')
        command = 'grep "^DS START_FILE" %s -A%d'\
                  %(self.parent.log_file,
                    2*self.N+1)
        data = popen(command)
        decoy_id = None ## 'native' or the id of the decoy being read
        line = data.readline()
        while line:
            l = line.split()
            if line[:4] == 'DS S':
                decoy_id = l[2]
                decoy = START_FILE_PATH+'/'+decoy_id+'.pdb' ## fullname

                if decoy_id != 'native':
                    if decoy not in in_subset: ## skip to next decoy
                        line = data.readline()
                        while line and line[:4] != 'DS S':
                            line = data.readline()
                        continue
                    total_decoys = total_decoys + 1

            elif line[:4] == 'DS_T': ## torsions
                bb = ppo_to_bb( [float(v) for v in l[3:6]] )
                pos = int(l[1]) #1...N
                rot = [int(l[7]), int(l[9]), int(l[11]), int(l[13]) ]

                if decoy_id == 'native': ## native is the first entry we hit
                    native_bb[pos] = bb
                    native_rot[pos] = rot
                else:
                    bb_count[pos][bb] = bb_count[pos][bb] + 1
                    nat_rot = native_rot[pos]
                    ## chi_correct[pos][i] counts decoys matching chi1..chi(i+1)
                    for i in range(4):
                        if rot[i] == nat_rot[i]:
                            chi_correct[pos][i] = chi_correct[pos][i]+1
                        else:
                            break
                    if chi_required:
                        for atom_name in chi_required[pos]['atom_names']:
                            chi_req = chi_required[pos][atom_name]
                            if chi_req == 0 or \
                               (rot[:chi_req] == nat_rot[:chi_req]):
                                nat_rot_count[pos][atom_name] = nat_rot_count[pos][atom_name] + 1

            elif line[:5] == 'DS_E_': ## list of names
                if not enumber:
                    ## initialize the arrays
                    enames = l[1:]
                    enumber = len(enames)
                    native_E = {} ## store native energies
                    avgE = {} ## store delta to average_E
                    natE = {} ## store delta to native_E
                    for pos in residues:
                        native_E[pos] = {}
                        avgE[pos] = {}
                        natE[pos] = {}
                        for i in range(enumber):
                            avgE[pos][i] = 0.0
                            natE[pos][i] = 0.0

            elif line[:5] == 'DS_E ':
                pos = int(l[1])
                if decoy_id == 'native':
                    for i in range(enumber):
                        native_E[pos][i] = float(l[3+2*i])
                else:
                    for i in range(enumber):
                        avgE[pos][i] = avgE[pos][i] + float(l[3+2*i]) - float(l[4+2*i])
                        natE[pos][i] = natE[pos][i] + float(l[3+2*i]) - native_E[pos][i]

            line = data.readline()
        data.close()

        if not total_decoys: ## every average below divides by this
            print('no decoys from subset %s found in %s'\
                  %(subset_name,self.parent.log_file))
            return

        ## calculate new datasets:

        ## bb
        for bb in BB_LIST:
            name = 'BBDEC-%s for %s'%(bb,subset_name)
            D = {}
            boxed = {}
            N = self.N
            for pos in residues:
                D[pos-1] = float(bb_count[pos][bb]) / total_decoys
                if bb == native_bb[pos]:
                    boxed [pos-1] = 1

            dataset = DATASET1D(name,D,boxed,N)
            dataset.subset = subset_name
            if bb in ['A','B','G']:
                dataset.level = 0
            else:
                dataset.level = 1
            self.add_dataset1D(dataset)

        ## chiN correct
        for i in range(4):
            name = 'CHI%s for %s'\
                   %(''.join([str(k+1) for k in range(i+1)]),subset_name)
            D = {}
            boxed = {}
            N = self.N
            for pos in residues:
                D[pos-1] = float(chi_correct[pos][i]) / total_decoys
            dataset = DATASET1D(name,D,boxed,N)
            dataset.subset = subset_name
            if i<2:
                dataset.level = 0
            else:
                dataset.level = 1
            self.add_dataset1D(dataset)

        ## NAT-ROT
        if chi_required: ## were we able to load the info?
            name = 'NAT-ROT for %s'%subset_name
            D = {}
            boxed = {}
            N = self.N
            atom_names = {}
            for pos in range(N):
                rosetta_pos = pos + 1 ## stupid numbering systems
                atom_names[pos] = chi_required[rosetta_pos]['atom_names'][:]
                D[pos] = []
                for i in range(len(atom_names[pos])):
                    atom_name = atom_names[pos][i]
                    D[pos].append( float ( nat_rot_count[rosetta_pos][atom_name] )/total_decoys )
            dataset = DATASET1D(name,D,boxed,N,atom_names)
            dataset.level = 0
            dataset.subset = subset_name
            self.add_dataset1D(dataset)

        ## energies
        for i in range(enumber):
            for tag in ['Avg','Nat']:
                name = 'del%sE %s for %s'%(tag,enames[i],subset_name)
                D = {}
                boxed = {}
                N = self.N
                for pos in residues:
                    if tag == 'Avg':
                        D[pos-1] = avgE[pos][i] / total_decoys
                    else:
                        D[pos-1] = natE[pos][i] / total_decoys
                dataset = DATASET1D(name,D,boxed,N)
                dataset.subset = subset_name
                dataset.type = 'energy'
                if enames[i] in ['atr','res']:
                    dataset.level = 0
                elif enames[i] in ['sol','aa','dun','hbnd']:
                    dataset.level = 1
                else: ## rep,pair
                    dataset.level = 2
                self.add_dataset1D(dataset)


    def set_current_subset(self,subset):
        """Refresh both data menus after the parent's subset changed.

        `subset` is not used: the menus read parent.current_subset directly.
        """
        self.update_data2D_menu()
        self.update_data1D_menu()
        return


    def get_cluster_bounds(self): ## allocate part of canvas to cluster tree
        """Return (x0, y0, x1, y1) of the canvas area used by the cluster tree."""
        x0 = self.box_width * self.N + 50
        y0 = 25
        x1 = x0 + 200
        y1 = y0 + 600
        return x0,y0,x1,y1



    def get_current_selection(self):
        """Return the selected residues, numbered 1..N (rosetta numbering)."""
        residues = []
        D = self.current_selection.D
        for i in D:
            if D[i]==1.0:
                residues.append(i+1)#################### rosetta nbrs!!!
        return residues

    def reset_current_selection(self):
        """Clear the residue selection and redraw the 1D strips."""
        if self.N:
            for i in range(self.N):
                self.current_selection.D[i] = 0.0
            self.replot1D()

    def get_rasmol_bfactor(self): ## returns a bfactor type thingy, numbered 1->N!!!
        """Map the dataset chosen in the Rasmol Color menu onto 0..99.99 bfactors.

        Returns {resnum: value} for per-residue datasets or
        {resnum: {atom_name: value}} for per-atom ones, resnum being 1-based.
        'frequency' data go through the selected scale function; 'energy'
        data are clamped to [energy_min, energy_max] and mapped linearly.
        Returns None (after printing) for any other dataset type.
        """
        dataset = self.data1D [ self.rasmol_color_name.get() ]
        bfactor = {}
        atom_names = dataset.atom_names
        D = dataset.D
        N = dataset.N
        if dataset.type == 'frequency':
            rescale = scale_functions[ self.scale_function_name.get() ]
        elif dataset.type == 'energy':
            emin = self.energy_min.get()
            emax = self.energy_max.get()
            def rescale(x,emin=emin,emax=emax):
                e = max(emin, min(emax, x))
                return float(e-emin)/(emax-emin)
        else:
            print('bad dataset type')
            return

        for i in range(N): ## convert to a bfactor rep
            if i in D:
                if i in atom_names:
                    assert type(D[i]) == type([])
                    bfactor[i+1] = {}
                    for j in range(len(atom_names[i])):
                        name = atom_names[i][j]
                        bfactor[i+1][name] = min(99.99, 100.0*rescale( D[i][j] ))
                else:
                    assert type(D[i]) == type(0.0)
                    bfactor[i+1] = min(99.99, 100.0 * rescale( D[i] ))
        return bfactor

    def setup_color_function(self):
        """Return color_function(x, type) built from the current color settings.

        'frequency' values (asserted to lie in [0, 1]) pass through the
        selected scale and color functions; exactly 0.0 uses the zero-color
        override when one is chosen.  'energy' values are clamped to
        [energy_min, energy_max] and mapped linearly onto [0, 1] before
        coloring.  Any other type prints a message and returns None.
        """
        global color_functions,scale_functions

        cf = color_functions[ self.color_function_name.get() ]
        sf = scale_functions[ self.scale_function_name.get() ]
        zc = self.zero_color.get()
        emin = self.energy_min.get()
        emax = self.energy_max.get()

        def color_function(x,type,cf=cf,sf=sf,zc=zc,emin=emin,emax=emax):
            if type == 'frequency':
                assert 0.0<=x<=1.0
                if x==0.0 and zc!='0':
                    return zc
                else:
                    return cf(sf(x))
            elif type == 'energy':
                e = max(emin, min(emax, x))
                f = float(e-emin)/(emax-emin)
                return cf(f)
            else:
                print('bad dataset1D type!!!')


        return color_function

    def set_dataset(self,name):
        """Show 2D dataset `name` in both the lower and upper triangles."""
        for i in range(2):
            self.current_data2D[i].set(name)

    ## create a pseudo-dataset1D that displays the current residue selection
    def setup_selection(self):
        """Create the 'current_selection' 1D strip (all zeros), shown by default."""
        D= {}
        boxed = {}
        for i in range(self.N):
            D[i] = 0.0
        self.current_selection = DATASET1D('current_selection',D,boxed,self.N)
        self.current_selection.level = 0 ## top level
        self.add_dataset1D(self.current_selection,1)


    def add_dataset2D(self,dataset): ## dataset is type DATASET2D
        """Register (or replace) a 2D dataset and refresh the Data2D menu.

        The first dataset of either kind fixes N; later datasets of another
        length are rejected with a message.
        """
        name = dataset.name
        if self.N:
            if dataset.N != self.N:
                print('length mismatch in new dataset: %s %s'%(self.N, dataset.N))
                return ## dont add it
        else:
            self.N = dataset.N
            self.setup_selection()


        self.data2D[name] = dataset
        if name not in self.data2D_name_list:
            self.data2D_name_list.append(name)

        self.update_data2D_menu()

    def add_dataset1D(self,dataset,show_by_default=0): ## dataset is type DATASET1D
        """Register a new 1D dataset with a show/hide IntVar and refresh menus.

        Duplicate names and datasets whose length differs from N are rejected
        with a message.  If this is the first dataset of either kind it fixes
        N and, as in add_dataset2D, creates the current_selection strip.
        """
        name = dataset.name
        if name in self.data1D:
            print('redundant dataset!!')
            return

        if self.N:
            if dataset.N != self.N:
                print('length mismatch in new dataset: %s %s'%(self.N, dataset.N))
                return ## dont add it
        else:
            self.N = dataset.N
            if self.N: ## setup_selection re-enters here with N already set
                self.setup_selection()

        self.data1D[name] = dataset
        show = IntVar()
        show.set(show_by_default) ##default is not to show
        show.trace("w",(lambda x,y,z,s=self:s.replot1D())) ##replot1D on changes

        self.data1D_show[name] = show
        self.data1D_name_list.append( name )
        self.update_data1D_menu()

    def button1_click(self,event):
        """Left click: start a selection box on the plot, or click the cluster tree.

        On the plot side it also toggles follow_motion.
        """
        x = event.x
        y = event.y

        x0,y0,x1,y1 = self.get_cluster_bounds()
        canvas = self.canvas

        self.boxed = 0
        if x<x0:
            self.follow_motion = (self.follow_motion+1)%2
            self.press = [x,y]
            self.dragged = 0 ## NOT self.motion: that would hide the motion() handler
            self.box = canvas.create_rectangle(x,y,x,y)
            self.boxed = 1
        elif x0 <= x <= x1 and y0 <= y <= y1:
            cx = canvas.canvasx(x)
            cy = canvas.canvasy(y)
            item = canvas.find_closest(cx,cy)
            self.parent.clusterer.cluster_tree_click( item, canvas )
            #canvas.itemconfigure(item,capstyle="round")
            #self.parent.clusterer.cluster_tree_click( canvas.gettags(item) )

    def double_button1_click(self,event):
        """Left double-click: forward clicks inside the cluster-tree area."""
        x = event.x
        y = event.y

        x0,y0,x1,y1 = self.get_cluster_bounds()
        canvas = self.canvas

        self.boxed = 0
        if x<x0:
            pass
        elif x0 <= x <= x1 and y0 <= y <= y1:
            cx = canvas.canvasx(x)
            cy = canvas.canvasy(y)
            item = canvas.find_closest(cx,cy)
            self.parent.clusterer.cluster_tree_double_click( item, canvas )
            #canvas.itemconfigure(item,capstyle="round")
            #self.parent.clusterer.cluster_tree_click( canvas.gettags(item) )


    def button1_motion(self,event):
        """Drag with button 1: record movement and stretch the selection box."""
        self.dragged=1
        if self.boxed:
            self.canvas.coords(self.box,self.press[0],self.press[1],event.x,event.y)

    def button1_release(self,event):
        """Release button 1: add the residues spanned by the dragged box to the selection.

        Both the column range (x) and the row range (y) of the box are
        selected; indices outside 0..N-1 (box dragged off the grid) are
        ignored.
        """
        if not self.boxed:
            return
        self.canvas.delete(self.box)
        if not self.dragged or not self.N: ## no drag, or no selection yet
            return

        x1 = self.press[0]
        y1 = self.press[1]
        x2 = event.x
        y2 = event.y

        i1 = intf( x1/self.box_width)
        j1 = self.N-1 - intf(y1/self.box_width)
        i2 = intf( x2/self.box_width)
        j2 = self.N-1 - intf(y2/self.box_width)

        last = self.N-1
        for i in range(max(0,min(i1,i2)),min(last,max(i1,i2))+1):
            self.current_selection.D[i] = 1.0
        for i in range(max(0,min(j1,j2)),min(last,max(j1,j2))+1):
            self.current_selection.D[i] = 1.0

        self.replot1D()

    def motion(self,event):
        """Mouse motion: show (i,j) and its value; optionally notify the parent.

        The lower triangle (i >= j) reads current_data2D[0], the upper one
        current_data2D[1].  The parent callback fires only while
        follow_motion is on and (i,j) lies on the grid.
        """
        x = self.canvas.canvasx(event.x)
        y = self.canvas.canvasy(event.y)
        i = intf( x/self.box_width)
        j = self.N-1 - intf(y/self.box_width)

        if i>=j: ## lower
            name = self.current_data2D[0].get()
        else:
            name = self.current_data2D[1].get()

        if name=='dummy':return

        D = self.data2D[ name ].D
        if (i,j) in D:
            f=D[(i,j)]
        else:
            f = 0.0

        self.mouse_location.set('%d,%d: %f'%(i+1,j+1,f)) ## output:: 1->N

        if not self.follow_motion:
            return


        if 0<=i<self.N and 0<=j<self.N:
            self.parent.data_plotter_callback(i,j)


    def replot(self):
        """Redraw both the 2D grid and the 1D strips."""
        self.replot2D()
        self.replot1D()

    def replot2D(self):
        """Redraw the N x N grid from the lower/upper current 2D datasets.

        Items are tagged "2D" so they can be cleared independently; boxed
        pairs get a black outline; the diagonal is skipped unless SHOW_DIAG.
        """
        #if self.no_plot:return

        lower_set = self.current_data2D[0].get()
        upper_set = self.current_data2D[1].get()
        if upper_set not in self.data2D or\
           lower_set not in self.data2D:
            print('bad setting for current_data: %s %s'%(upper_set,lower_set))
            return

        lower_set = self.data2D[lower_set]
        upper_set = self.data2D[upper_set]
        D0 = lower_set.D
        boxed0 = lower_set.boxed
        D1 = upper_set.D
        boxed1 = upper_set.boxed
        N = self.N

        self.box_width = int(floor(self.canvas_width/self.N))
        #print 'replot box_width:',self.box_width

        self.canvas.delete('2D') ## delete the items tagged with "2D"

        color_function = self.setup_color_function()

        for i in range(N):
            for j in range(N):
                if i==j and not self.SHOW_DIAG:continue
                if i>=j: ## below the diagonal
                    D = D0
                    boxed = boxed0
                else: ## above it
                    D = D1
                    boxed = boxed1

                x1 = i * self.box_width
                y1 = (self.N-1)*self.box_width - j * self.box_width
                x2 = x1+self.box_width
                y2 = y1+self.box_width

                if (i,j) in D:
                    color = color_function( D[(i,j)], 'frequency' )
                else:
                    color = color_function( 0.0 ,'frequency')

                if (i,j) in boxed:
                    outline_color = '#000000'
                    self.canvas.create_rectangle(x1+1,y1+1,x2-1,y2-1,
                                                 tags="2D",fill=color,
                                                 outline=outline_color,width=1)
                else:
                    self.canvas.create_rectangle(x1,y1,x2,y2,
                                                 tags="2D",fill=color,outline='')


    def replot1D(self):
        """Redraw the shown 1D datasets as labelled strips below the grid.

        Per-residue datasets get one box per residue (boxed residues are
        outlined); per-atom datasets get one row per atom name.  Items are
        tagged "1D".
        """
        #if self.no_plot:return

        if not self.data1D_name_list: return ## nothing to plot

        N = self.N

        self.box_width = int(floor(self.canvas_width/self.N))
        #print 'replot1D box_width:',self.box_width

        self.canvas.delete('1D') ## delete the items tagged with "1D"

        color_function = self.setup_color_function()

        counter = 0
        height_per_atom = 1
        height = MAX_ATOMS * height_per_atom ## 14
        y_offset = (N+2) * self.box_width


        for name in self.data1D_name_list:
            show = self.data1D_show [ name ].get()
            if show:
                counter = counter + 1
                dataset = self.data1D[name]
                D = dataset.D
                boxed = dataset.boxed
                atom_names = dataset.atom_names ## empty if per_rsd info only

                dataset_type = dataset.type ## 'frequency' or 'energy'

                self.canvas.create_text(N*self.box_width,y_offset+counter*height,
                                        text=name,anchor=NW,tags="1D")
                for i in range(N):
                    if not atom_names:
                        x1 = i * self.box_width
                        y1 = y_offset + counter * height
                        x2 = x1+self.box_width
                        y2 = y1+height

                        color = color_function( D[i] , dataset_type)
                        if i in boxed:
                            outline_color = '#000000'
                            self.canvas.create_rectangle(x1+1,y1+1,x2-1,y2-1,
                                                         tags="1D",fill=color,
                                                         outline=outline_color,width=1)
                        else:
                            self.canvas.create_rectangle(x1,y1,x2,y2,
                                                         tags="1D",fill=color,outline='')

                    else: ## boxes for per-atom stuff right now
                        for j in range(len(atom_names[i])):
                            x1 = i * self.box_width
                            y1 = y_offset + counter * height + j*height_per_atom
                            x2 = x1 + self.box_width
                            y2 = y1 + height_per_atom

                            color = color_function( D[i][j], dataset_type )
                            self.canvas.create_rectangle(x1,y1,x2,y2,
                                                         tags="1D",fill=color,outline='')

class RASMOL_HANDLER: ##############################################################
    """Manages RasMol sessions and the RasMol command-entry bar.

    Adds a "Rasmol" cascade to parent.menu listing the sessions (the
    selected one is the current session) and a row in parent.window with a
    Recolor button, a "Follow pair" checkbox and an entry.  Pressing
    <Return> in the entry sends the text after the 'RasMol> ' prompt to the
    current session; <Up>/<Down> walk the command history.

    Every method that talks to a session first checks that one is current,
    so plot mouse-motion highlighting and button presses made before any
    session has been added are ignored instead of raising KeyError.
    """
    PROMPT = 'RasMol> ' ## text kept at the start of the command entry

    def __init__(self,parent):
        ## make an entry in parent's window,packed at top
        ## keep a list of all rasmol sessions
        ## keep a list of connected session(s)?
        ## handle sending commands to sessions, highlighting residues, etc
        self.parent = parent

        window = self.parent.window
        parent_menu = self.parent.menu
        self.menu = Menu(parent_menu,tearoff=0)
        parent_menu.add_cascade(label="Rasmol",menu=self.menu)

        ## sessions:
        self.sessions = {}
        self.current_session = StringVar()
        self.current_session.set('dummy') ## no session yet

        ## frame for entry and buttons
        frame = Frame(window)
        frame.pack(side=TOP,fill=X)

        ## setup buttons
        self.follow_pair = IntVar() ## should we follow the mouse in the 2dplotter?
        self.follow_pair.set(1)

        Button(frame,text="Recolor",command=self.recolor).pack(side=LEFT)
        Checkbutton(frame,text="Follow pair",variable=self.follow_pair).pack(side=LEFT)

        raswin = Entry(frame)
        self.command_state = [[],-1] #for command-line editting: [history, index]
        raswin.config(font=('helvetica',14,'normal'))
        raswin.insert(0,self.PROMPT)
        raswin.pack(side=TOP,fill=X)
        raswin.focus()
        raswin.bind('<Return>', (lambda event,s=self:s.raswin_fetch()))
        raswin.bind('<Up>', (lambda event,s=self:s.raswin_last_command(-1)))
        raswin.bind('<Down>', (lambda event,s=self:s.raswin_last_command(1)))
        self.raswin = raswin

    def has_current_session(self):
        """Return true if the current-session name refers to an added session."""
        return self.current_session.get() in self.sessions

    def get_current_session(self):
        """Return the current session; raises KeyError if there is none."""
        session = self.sessions[self.current_session.get()]
        return session

    def get_current_selection(self):
        """Return the ids of the atoms selected in the current RasMol session.

        Sends 'show selected atom' followed by a 'select selected' marker and
        reads replies until the marker echoes back, collecting the number
        before the '/' in the last field of each 'Group:' line.  Returns []
        when no session is open.
        """
        atoms = []
        if not self.has_current_session():
            print('no rasmol session: empty selection')
            return atoms
        session = self.get_current_session()
        session.rin.write('show selected atom\nselect selected\n')
        session.rin.flush()
        line = session.rout.readline()
        while line and 'select selected' not in line:
            #print line[:-1]
            if 'Group:' in line:
                atom = int(line.split()[-1].split('/')[0])
                if atom not in atoms:
                    atoms.append(atom)
            line = session.rout.readline()
        print('current session: %s total atoms: %d'%(session.id,len(atoms)))
        return atoms

    def add_session(self,new_session):
        """Register new_session under its id, list it in the menu and make it current."""
        id = new_session.id
        self.sessions[id] = new_session
        self.menu.add_radiobutton(label=id,variable=self.current_session,value=id)
        self.current_session.set(id)

    def highlight(self,i,j):
        """Ask the current session to highlight pair (i,j) when 'Follow pair' is on."""
        ## called on plot mouse motion, so silently skip when no session exists
        if self.follow_pair.get() and self.has_current_session():
            self.get_current_session().highlight(i,j)

    def recolor(self):
        """Recolor the current session (Recolor button)."""
        if not self.has_current_session():
            print('no rasmol session to recolor')
            return
        self.get_current_session().recolor()

    def raswin_fetch(self):
        """Send the entry's command to the current session and reset the prompt."""
        command = self.raswin.get()[len(self.PROMPT):]
        if command:
            self.command_state[0].append(command)
            self.command_state[1] = len(self.command_state[0])

            if self.has_current_session():
                session = self.get_current_session()
                session.rin.write(command+'\n')
                session.rin.flush()
            else:
                print('no rasmol session to send command to')

        self.raswin.delete(0,END)
        self.raswin.insert(0,self.PROMPT)

    def raswin_last_command(self,increment):
        """Step through the command history by `increment` (-1 older, +1 newer)."""
        new_state = self.command_state[1]+increment
        if new_state <0 or new_state >= len(self.command_state[0]):return
        self.command_state[1] = new_state

        self.raswin.delete(0,END)
        self.raswin.insert(0,self.PROMPT+self.command_state[0][ self.command_state[1]])


class RASMOL_SESSION: ##############################################################
    def __init__(self,parent,id):
        self.parent = parent
        self.id = id

        ## start a rasmol process
        rout,rin = popen2(RASMOL_COMMAND)
        self.rout = rout
        self.rin = rin
        
        ## setup pdb and script files
        prefix = self.parent.tmp_file_prefix
        self.pdb_file    = '%s.rasmol_pdb.%s.pdb'%(prefix,id)
        self.script_file = '%s.rasmol_script.%s.pdb'%(prefix,id)
        self.SHOW_H = 0 #dont show hydrogens
        self.LOADED = 0 #no structure currently loaded

        ## start listening for clicks
        self.listen_rasmol() 

        ## track residue-pair highlighting
        self.highlighted = []

    def write(self,command):
        self.rin.write(command)
        if command[-1]!='\n':self.rin.write('\n')
        self.rin.flush()

    def open(self,file,bfactor,name=''):
        ## file shouldnt be our temporary pdb-file!!
        ## bfactor = {1:{'N':0.5, 'CA':0.2, ...},...}
        ##      or = {1:0.5, 2:0.75, ..}

        self.name = name
        self.original_pdb_file = file

        if self.LOADED:        ## recover the current view
            self.write('write rasmol %s\n'%self.script_file)
            
        data = open(file,'r')
        out = open(self.pdb_file,'w')
        out.write('HEADER    %40s            %10s\n'\
                  %(string.split(file,'/')[-1],name))
        out.write('COMPND     %s\n'%file)
        line=  data.readline()

        first_resnum = -999
        while line:
            if line[:4] == 'ATOM':
                atom_name = line[12:16]
                if atom_name[1] == 'H' and not self.SHOW_H:
                    line = data.readline()
                    continue
                atom_name = string.split(atom_name)[0] ## strip the white-space
                resnum = int(line[22:26])
                if first_resnum == -999:
                    first_resnum = resnum
                    
                bf = 0.0
                if bfactor.has_key(resnum):
                    if type(bfactor[resnum]) == type(0.0):
                        bf = bfactor[resnum]
                    elif type(bfactor[resnum]) == type({}) and \
                         bfactor[resnum].has_key(atom_name):
                        bf = bfactor[resnum][atom_name]

                if resnum==first_resnum:
                    if atom_name=='N':
                        bf=0.0
                    elif atom_name=='O':
                        bf=99.99
                
                out.write('%s%6.2f\n'%(line[:60],bf))
            elif line[:6] in ['MODEL ','ENDMDL']:
                out.write(line)

            line = data.readline()
        data.close()

        #self.LOADED = 1
        if self.LOADED:
            self.write('zap\nscript %s\n'%(self.script_file))
        else:
            self.LOADED=1
            self.write('zap\nload %s\n'%(self.pdb_file))

    def recolor(self):
        bfactor = self.parent.data_plotter.get_rasmol_bfactor()
                
        tmp_file = self.pdb_file+'.tmp'
        system('cp %s %s'%(self.pdb_file,tmp_file))
        self.open(tmp_file,bfactor,self.name)
        self.write('color temperature\n')
                

    def clean_pipe(self):
        self.write('GARBAGE GARBAGE GARBAGE\n')
        line = self.rout.readline()
        while not string.count(line,'GARBAGE'):
            if self.Good_line(line):
                print line[:-1]
            line = self.rout.readline()

    def Good_line(self,line):
        return not string.count(line,'GARBAGE') and \
               not string.count(line,'^') and \
               not string.count(line,'Unrecognised') and \
               not string.count(line,'RasMol>') and \
               not string.count(line,'atoms selected')


    def listen_rasmol(self):
        global root #TK root
        self.write('GARBAGE GARBAGE GARBAGE\n')
        line = self.rout.readline()
        while not string.count(line,'GARBAGE'):
            if self.Good_line(line):
                print line[:-1]
            line = self.rout.readline()
        root.after(500,self.listen_rasmol) ## every 1/2 of a second
        
        

    def highlight(self,i,j):
        if self.highlighted == [i,j]:return
        self.write('define current selected\n')
        if self.highlighted:##un-highlight the previous pair
            self.write('select (%d,%d)\nwireframe\n'\
                       %(self.highlighted[0],self.highlighted[1]))
            
        self.highlighted = [i,j]
        self.write('select (%d,%d)\nwireframe 50\n'\
                   %(self.highlighted[0],self.highlighted[1]))
        self.write('select current\n')

class CLUSTER_CANVAS_INTERFACE: ## hack to use cluster_trees.py code w/o many changes
    """Adapter providing the drawing calls used by cluster_trees.Canvas_tree.

    Draws on the data_plotter's canvas.  Every coordinate is shifted by the
    offsets (x0,y0) taken from the data_plotter's cluster bounds.  All
    items carry the "tree" tag so delete() can wipe the whole tree.
    """
    def __init__(self,parent):
        self.parent = parent
        self.canvas = parent.data_plotter.canvas ## share the data_plotter's canvas
        self.update_bounds()

    def delete(self):
        """Remove every item this interface has drawn."""
        self.canvas.delete('tree')

    def make_line(self,xy0,xy1,line_width,score,extra_tag,selected):
        """Draw a tree segment from xy0 to xy1, colored by score.

        A horizontal segment of the selected cluster is extended 30 pixels
        to the left and drawn with arrowheads at both ends.
        """
        options = {'width': line_width,
                   'fill': Uniform_rainbow_color(score),
                   'tags': ("tree",extra_tag)}
        start_x, start_y = xy0[0], xy0[1]
        end_x, end_y = xy1[0], xy1[1]
        if selected and start_y == end_y:
            start_x, end_x = min(start_x,end_x) - 30, max(start_x,end_x)
            options['arrow'] = BOTH
        self.canvas.create_line(self.x0 + start_x, self.y0 + start_y,
                                self.x0 + end_x, self.y0 + end_y,
                                **options)

    def make_text(self,text_string, xy, font_size,extra_tag="dummy"):
        """Draw text_string at xy (offset applied) in helvetica of font_size."""
        x = self.x0 + xy[0]
        y = self.y0 + xy[1]
        font = ('helvetica',font_size,'normal')
        self.canvas.create_text(x,y,text=text_string,font=font,
                                tags=("tree",extra_tag))

    def update_bounds(self):
        """Refresh the (x0,y0) offsets used by make_line and make_text."""
        bounds = self.parent.data_plotter.get_cluster_bounds()
        self.x0,self.y0,_x1,_y1 = bounds
        
class CLUSTERER:
    def __init__(self,parent,dir):

        self.parent = parent
        self.window = parent.window

        self.canvas_interface = CLUSTER_CANVAS_INTERFACE(parent)

        ## clustering directory
        self.dir = dir
        if not exists(dir):system('mkdir '+dir)

        ## attach clustering menu to parent
        self.menu = Menu(self.parent.menu,tearoff=0)
        self.parent.menu.add_cascade(label="Cluster",menu=self.menu)
        self.menu.add_command(label="cluster",command=self.cluster)
        self.menu.add_command(label="rasmol selected cluster",command=self.load_current_cluster)
        

        self.menu.add_separator ## session list comes after this
        
        ## manage clustering sessions
        self.sessions = []
        self.current_session = StringVar()
        self.current_session.set('dummy')
        self.current_session.trace("w",(lambda x,y,z,s=self:s.plot_tree()))

        ## for cluster-tree coloring
        self.P = IntVar()
        self.P.set(25)
        self.P.trace("w",(lambda x,y,z,s=self:s.plot_tree()))
        self.color_score = StringVar()
        self.color_score.set("score")
        self.color_score.trace("w",(lambda x,y,z,s=self:s.plot_tree()))

        ## selected cluster
        self.current_cluster = StringVar()
        self.current_cluster.set('') ## will look like: "cluster00.075"

        ## sub-menu for recoloring the tree
        m = Menu(self.menu,tearoff=0)
        self.menu.add_cascade(label="color by score",menu=m)
        tags = self.parent.fasc_plotter.get_score_tags()
        m.add_radiobutton(label="clustered-atom rmsd",
                          variable = self.color_score,
                          value = "atom_rmsd")
        for t in tags:
            m.add_radiobutton(label=t,
                              variable = self.color_score,
                              value=t)

        ## old sessions
        self.load_old_sessions()
        ## END INIT ##

    ## returns the list of decoys in current_cluster
    def get_current_decoys(self):
        file = '%s.%s.pdb' %(self.current_session.get(),self.current_cluster.get())
        decoys = map(lambda x:string.split(x)[2],
                     popen('grep "^REMARK" %s | grep NATIVE -v'%file).readlines())
        print 'selected cluster contains %d decoys'%(len(decoys))
        return decoys
        
    def cluster_tree_click(self,item,canvas):
        tag_list = canvas.gettags(item)

        for t in tag_list:
            if t[:7] == 'cluster':
                print 'cluster click:',t
                self.current_cluster.set(t)
                self.plot_tree()
                break
            
    def cluster_tree_double_click(self,item,canvas):
        tag_list = canvas.gettags(item)

        for t in tag_list:
            if t[:7] == 'cluster':
                print 'cluster click:',t
                self.current_cluster.set(t)
                self.plot_tree()
                self.load_current_cluster()
                break

            
    def load_current_cluster(self):
        file = '%s.%s.pdb' %(self.current_session.get(),self.current_cluster.get())
        if not exists(file):
            print 'missing:',file
            return
        rasmol = RASMOL_SESSION(self.parent,string.split(file,'/')[-1])
        rasmol.open(file,{})
        self.parent.rasmol.add_session(rasmol)
        

    def cluster(self):
        dialog = CLUSTER_DIALOG(self.window,"cluster settings")
        result = dialog.result
        prefix = self.dir+'/'+result['prefix']
        
        ## get the atom set:
        l = result['atoms']
        use_atom = {}
        use_residue = {}
        if 'rasmol' in l:
            for atom in self.parent.rasmol.get_current_selection():
                use_atom [atom] = 1
            print 'total atoms in rasmol selection:',reduce(add,use_atom.values())
        if 'data_plotter' in l:
            for rsd in self.parent.data_plotter.get_current_selection():
                use_residue [rsd] = 1
            print 'total residues in data_plotter selection:',\
                  reduce(add,use_residue.values())

        

        ## get the decoy set:
        decoy_list = result['decoys']
        
        ## make the template file:
        template_file = prefix+'.template'
        out = open(template_file,'w')
        data = open(self.parent.native_pdb_file,'r')
        line = data.readline()
        while line:
            if line[:4] == 'ATOM':
                atomno = int(line[6:11])
                rsd = int(line[22:26])

                if (use_atom.has_key(atomno) or not use_atom) and \
                   (use_residue.has_key(rsd) or not use_residue):
                    out.write(line)
            line = data.readline()
        data.close()
        out.close()
            
        ## make the list file
        list_file = prefix+'.list'
        out = open(list_file,'w')
        out.write(string.join(decoy_list,'\n')+'\n')
        out.close()

        N = len(decoy_list)
        minClusterSize = min(int(result['minClusterSize']),
                             max(1,intf(N*float(result['minTopClusterSize']))))
        
        command = '%s -n %s -l %s -t %s -p %s %d,%d,%d,%d %f,%f'\
                  %(FA_CLUSTER_EXE,
                    self.parent.native_pdb_file,
                    list_file,
                    template_file,
                    prefix,
                    minClusterSize,
                    max(1,intf(N*float(result['minTopClusterSize']))),
                    max(1,intf(N*float(result['tryTopClusterSize']))),
                    max(1,intf(N*float(result['maxTopClusterSize']))),
                    float(result['minClusterThreshold']),
                    float(result['maxClusterThreshold']))
                    
        print command
        system(command)

        self.add_session(prefix)
        self.current_session.set(prefix) ## triggers a plot_tree via the trace function

    def add_session(self,prefix):
        self.sessions.append(prefix)
        self.menu.add_radiobutton(label=prefix,
                                  variable=self.current_session,
                                  value=prefix)

    def plot_tree(self):
        self.canvas_interface.delete()## clear existing plot
        self.canvas_interface.update_bounds()## get allocated part of canvas
        
        prefix = self.current_session.get()
        score_name = self.color_score.get()
        P = self.P.get() ## score percentile

        
        ## load data from .info file
        info_file = prefix+'.info'
        N,names,distance,sizes,cluster_members,decoy_rmsd = \
                                                          self.parse_info_file(info_file)
        
        ## retrieve the per-decoy scores in dictionary indexed by fullname
        if score_name == 'atom_rmsd':
            decoy_scores = decoy_rmsd
        else:
            decoy_scores = self.parent.fasc_plotter.get_decoy_scores(score_name)

        ## assign score list to each cluster
        cluster_scores = {}
        for cluster in range(N):
            assert len(cluster_members[cluster]) == sizes[cluster]
            cluster_scores[cluster] = []
            for decoy in cluster_members[cluster]:
                cluster_scores[cluster].append( decoy_scores[decoy] )

        ## make average linkage score tree
        score_tree = cluster_trees.Make_tree(distance, N,
                                             cluster_trees.Update_distance_matrix_AL,
                                             cluster_scores,P)

        x0,y0,x1,y1 = self.parent.data_plotter.get_cluster_bounds()
        plot_width = x1-x0
        y_buffer = 50
        plot_height = y1-y0- y_buffer
        cluster_trees.Canvas_tree( score_tree,names,sizes,self.canvas_interface,
                                   plot_width,plot_height,self.current_cluster.get())

    def parse_info_file(self,info_file):
##         N,names,distance,sizes,cluster_members,decoy_rmsd = self.parse_info_file(info_file)
        N = 0
        names = []
        distance = {}
        sizes = []
        cluster_members = {}
        decoy_rmsd = {}
        
        data = open(info_file,'r')
        line = string.split(data.readline())

        while line:
            if line[0] == 'CLUSTER_RMSDS':
                cluster = int(line[1])
                assert cluster == len(sizes) and len(line) == 10+cluster
                sizes.append(int(line[2]))
                names.append( '%d_%d'%(cluster,int(line[2])))
                
                for i in range(cluster+1):
                    distance[(i,cluster)] = float(line[9+i])
                    distance[(cluster,i)] = float(line[9+i])
            elif line[0] == 'CLUSTER_INFO:':
                cluster = int(line[1])
                cluster_members[cluster] = map(lambda x:string.split(x,',')[1],
                                               line[9:])
            elif line[0] == 'DECOY_RMSD':
                decoy_rmsd[line[1]] = float(line[2])
            line = string.split(data.readline())
        data.close()
        N = len(sizes)
        return N,names,distance,sizes,cluster_members,decoy_rmsd
        
        
    def load_old_sessions(self):
        files = glob(self.dir+'/*.info')
        for file in files:
            self.add_session(file[:-5])


class ISE_DECOY:
    def __init__(self,window,log_file,pdb_file,fasc_file):
        self.window = window
        self.window.title('data_plotter')

        self.native_pdb_file = pdb_file
        xx = string.split(pdb_file,'/')[-1][:2]
        self.xx = xx
        self.pdb_id = string.split(pdb_file,'/')[-1][2:6]
        self.pdb_chain = string.split(pdb_file,'/')[-1][6]
        
        self.log_file = log_file

        ## the main ise_decoy menu
        self.menu = Menu(self.window,tearoff=0)
        self.window.config(menu=self.menu)
        self.subset_menu = Menu(self.menu,tearoff=0)
        self.menu.add_cascade(label="decoy subsets",menu=self.subset_menu)

        ## handling decoy subsets
        #self.subset_list = ['ALL']
        self.subsets = {}
        self.current_subset = StringVar()
        self.current_subset.set('ALL')
        self.subset_menu.add_command(label="new subset",command=self.get_new_subset)
        self.subset_menu.add_separator()
        self.subset_menu.add_radiobutton(label="ALL",variable=self.current_subset,
                                         value="ALL")
        # move this to the end (calls to various children):
        #self.current_subset.trace("w",(lambda x,y,z,s=self:s.trace_current_subset()))

        ## tmp files:
        pid = getpid()
        self.tmp_file_prefix = '%s/ise_tmp.%d'%(TMP_DIR,pid)

        ## rasmol ################
        self.rasmol = RASMOL_HANDLER(self)
        self.rasmol_native = RASMOL_SESSION(self,'NATIVE')
        self.rasmol_native.open(pdb_file,{},'NATIVE')
        
        self.rasmol_decoy = RASMOL_SESSION(self,'DECOY')
        self.rasmol.add_session(self.rasmol_decoy)
        self.rasmol.add_session(self.rasmol_native)
        
        ## 1d,2d plotter
        self.data_plotter = DATA_PLOTTER( self,window)
        datasets1D,datasets2D,fullnames = Read_datasets(log_file)# load data
        for d in datasets2D:# add to plotter
            self.data_plotter.add_dataset2D(d)
        for d in datasets1D:
            if default_level.has_key(d.name):
                d.level = default_level[d.name]
            self.data_plotter.add_dataset1D(d)

        ## define ALL subset decoy list
        self.subsets['ALL'] = fullnames.values()

        ## gnuplotter
        self.fasc_plotter = GNUPLOTTER_2D(self,600,25)
        self.fasc_plotter.plot_fasc(fasc_file,fullnames)

        ## clustering
        self.clusterer = CLUSTERER(self,'./%s%s_ise_cluster/'%(xx,self.pdb_id))

        ## get rid of old tmp files
        self.clean_tmp_files()

        ## trace the current_subset variable
        self.current_subset.trace("w",(lambda x,y,z,s=self:s.trace_current_subset()))
        self.load_old_subsets()
        ## END INIT ##################

    def get_subset_names(self):
        return self.subsets.keys() + ['fasc_plotter','current_cluster']

    def get_subset_decoys(self,subset_name):
        if self.subsets.has_key( subset_name ):
            return self.subsets[subset_name]
        elif subset_name == 'fasc_plotter':
            return self.fasc_plotter.get_current_decoys()
        elif subset_name == 'current_cluster':
            return self.clusterer.get_current_decoys()
        else:
            print 'bad subset name:',subset_name
            return []
        
    def load_old_subsets(self):
        subset_file = '%s%s%s_ise_subsets.txt'\
                      %(self.xx,self.pdb_id,self.pdb_chain)
        if not exists(subset_file):
            return
        for line in map(string.split,open(subset_file,'r').readlines()):
            self.add_subset( line[0],line[1:])
        


    ## what should we do when the current subset changes?
    def trace_current_subset(self):
        self.fasc_plotter.set_current_subset(self.current_subset.get())
        self.data_plotter.set_current_subset(self.current_subset.get())
        
        
    ## returns the list of decoys in an intersection
    ## decoy names are as in the log-file lines: DS START_FILE ...
    ## plus the start_path; not the trimmed names in the .fasc file
    ## ie -- complete path to the file; same as REMARK lines in cluster files
    ##
    ## subset_list consists of names of defined decoy subsets PLUS
    ## the two special names:
    ## "fasc_plotter" -- the decoys currently visible in the fasc_plotter
    ## "current_cluster" -- the decoys in the currently selected cluster
    def get_subset_intersection(self,subset_list):
        if not subset_list:
            return []
        print 'finding intersection:',subset_list
        big_decoy_list = self.subsets['ALL'][:]
        for subset in subset_list:
            if subset == 'fasc_plotter':
                decoys = self.fasc_plotter.get_current_decoys()
            elif subset == 'current_cluster':
                decoys = self.clusterer.get_current_decoys()
            else:
                if not self.subsets.has_key(subset):
                    print 'bad subset name!!!',subset
                    continue
                decoys = self.subsets[subset]
            l = range(len(big_decoy_list))
            l.reverse()
            for i in l:
                d = big_decoy_list[i]
                if d not in decoys:
                    del big_decoy_list[i]
        print 'done finding intersection; len=',len(big_decoy_list)
        return big_decoy_list
            

    ## function for defining a new subset
    def get_new_subset(self):
        dialog = GET_NEW_SUBSET_DIALOG(self.window,"define a new subset")
        result = dialog.result
        subset_name = result['name']
        if not subset_name:
            subset_name = 'subset%d'%(1+len(self.subsets.keys()))

        self.add_subset( subset_name, result['decoys'],1)

        ## add it to the subset file
        subset_file = '%s%s%s_ise_subsets.txt'\
                      %(self.xx,self.pdb_id,self.pdb_chain)
        out = open(subset_file,'a')
        out.write('%s %s\n'\
                  %(string.join(string.split(subset_name),'_'),
                    string.join(result['decoys'])))
        out.close()

    def add_subset(self,subset_name,decoys,update_current=0):
        self.subsets[subset_name] = decoys[:]

        self.subset_menu.add_radiobutton(label=subset_name,
                                         variable=self.current_subset,
                                         value=subset_name)
        if update_current:
            self.current_subset.set(subset_name)
        
    def clean_tmp_files(self):
        print 'cleaning tmp files from previous sessions'
        pid_list = map(lambda x:string.split(x)[1],popen('/bin/ps ux').readlines())
        print pid_list
        me = string.split(popen('whoami').readline())[0]
        print 'I am:',me
        command = "/bin/ls -l %s/ise_tmp* | awk '($3==\"%s\")'"\
                  %(TMP_DIR,me)
        print command
        files = map(lambda x:string.split(x)[-1],
                    popen(command).readlines())
        #files = glob('%s/ise_tmp*'%TMP_DIR)
        for file in files:
            id = string.split(string.split(file,'/')[-1],'.')[1]
            if id not in pid_list:
                print 'delete:',file
                remove(file)

    def data_plotter_callback(self,i,j):
        self.rasmol.highlight(i+1,j+1)
        

    def load_decoy(self,name):
        print 'load decoy:',name
        self.rasmol_decoy.open(name,{},'DECOY')

    def atoms_to_residues(self,atom_list):
        file = self.native_pdb_file
        lines = popen('grep "^ATOM" '+file).readlines()
        rsd_list = []
        for line in lines:
            atom = int(line[6:11])
            rsd = int(line[22:26])
            if atom in atom_list and rsd not in rsd_list:
                rsd_list.append(rsd)
        return rsd_list

############################################################################
############################################################################
############################################################################

def Read_datasets(file):
    
    data = open(file,'r')
    line = data.readline()
    N = 0
    D = {}
    boxed = {}
    atom_names = {}
    fullnames = {}

    print 'reading decoystats log file:',file
    counter = 0
    while line:
        if line[:2] == 'DS':
            l = string.split(line)
            if l[1] == 'START_FILE':
                counter = counter +1
                if not counter%100:
                    print counter
                filename = START_FILE_PATH+'/'+l[2]+'.pdb'
                name = string.split(filename,'/')[-1]
                if fullnames.has_key(name) and \
                   filename != fullnames[name]:
                    print 'repeat name:\n%s\n%s\n%s'\
                          %(name,filename,fullnames[name])
                    #del fullnames[name]
                else:
                    fullnames[name] = filename
            elif l[0]=='DS_PAIR':
                assert len(l) == 5
                tag= (l[0],l[1])
                if not D.has_key(tag):
                    D[tag] ={}
                i = int(l[2])-1 ## convert to python numbering
                j = int(l[3])-1
                N= max(max(i+1,j+1),N)
                D[tag][(i,j)] = float(l[4])
            if l[0]=='DS_PAIR_BOX':
                assert len(l) == 4
                tag= (l[0][:-4],l[1])
                if not boxed.has_key(tag):
                    boxed[tag] = {}
                i = int(l[2])-1 ## convert to python numbering
                j = int(l[3])-1
                N= max(max(i+1,j+1),N)
                boxed[tag][(i,j)] = 1
            elif l[0] == 'DS_RSD':
                assert len(l) == 4
                tag = (l[0],l[1])
                if not D.has_key(tag):
                    D[tag] ={}
                i = int(l[2])-1 ## convert to python numbering
                N= max(i+1,N)
                D[tag][i] = float(l[3])
            elif l[0] == 'DS_RSD_BOX':
                assert len(l) == 3
                tag = (l[0][:-4],l[1])
                if not boxed.has_key(tag):
                    boxed[tag] ={} 
                i = int(l[2])-1 ## convert to python numbering
                N= max(i+1,N)
                boxed[tag][i] = 1
            elif l[0] == 'DS_ATOM':
                assert len(l) == 6
                tag = (l[0],l[1])
                if not D.has_key(tag):
                    D[tag] ={}
                    atom_names[tag] = {}
                i = int(l[2])-1 ## convert to python numbering
                j = int(l[3])##atom number
                N= max(i+1,N)
                if j==1:
                    D[tag][i] = []
                    atom_names[tag][i] = []
                D[tag][i].append( float(l[5]) )
                atom_names[tag][i].append(l[4])
                assert len(atom_names[tag][i]) == j
        line = data.readline()
    data.close()
    
    datasets1D = []
    datasets2D = []
    ks = D.keys()
    ks.sort()
    for tag in ks:
        name = tag[1]
        if not boxed.has_key(tag):boxed[tag] = {}
        if tag[0] == 'DS_PAIR':
            datasets2D.append(DATASET2D(name,D[tag],boxed[tag],N))
        elif tag[0] == 'DS_RSD':
            datasets1D.append(DATASET1D(name,D[tag],boxed[tag],N))
        elif tag[0] == 'DS_ATOM':
            datasets1D.append(DATASET1D(name,D[tag],boxed[tag],N,atom_names[tag]))

    return datasets1D,datasets2D,fullnames

def Load_chi_required(filename):
    global chi_required ## global var!!

    if not exists(filename):
        print 'aargh: cant open:',filename
        chi_required = {}
        return
    data = open(filename,'r')
    line = data.readline()
    while line:
        if line[:4] == 'ATOM':
            atom_name = line[12:16]
            if atom_name[1] == 'H': ## skip hydrogens
                line = data.readline()
                continue
            
            atom_name = string.split(atom_name)[0] ## strip the white-space

            resnum = int(line[22:26])

            if not chi_required.has_key( resnum): ## ROSETTA numbering-- like bfactor
                chi_required[resnum] = {}
                chi_required[resnum]['atom_names'] = []

            chi_required[resnum][atom_name] = int(float(line[60:66]))
            chi_required[resnum]['atom_names'].append( atom_name )
            
        line = data.readline()
    data.close()
    return

##########################################################################
##########################  MAIN  ########################################
##########################################################################

## setup the info: which atoms depend on which chi angles
## (Load_chi_required fills the global chi_required)
Load_chi_required(chi_required_file)

## Tk root window; it must exist before ISE_DECOY is built, because
## RASMOL_SESSION.listen_rasmol schedules itself with the global root.after
root = Tk()
root.geometry("620x620")

## currently, this global variable (ise) is accessed by some of the dialogs to get
## subset lists, etc. bad form but hard to avoid
ise = ISE_DECOY(root,log_file,pdb_file,fasc_file)

## hand control to the Tk event loop
root.mainloop()



