# :noTabs=true:
# (c) Copyright Rosetta Commons Member Institutions.  # (c) This file is part of the Rosetta software suite and is made available under license.  # (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.  # (c) For more information, see http://www.rosettacommons.org. Questions about this can be
# (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

import os, sys, platform, os.path

import warnings
warnings.filterwarnings("ignore", "to-Python converter for .+ already registered; second conversion method ignored.", RuntimeWarning, "^rosetta\\.")

import utility, core

import rosetta.utility
import rosetta.utility.file
#import rosetta.utility.pointer

import rosetta.numeric

import rosetta.core
#import rosetta.core.coarse
import rosetta.core.chemical
import rosetta.core.conformation
#import rosetta.core.grid
import rosetta.core.id
import rosetta.core.io
import rosetta.core.io.pdb
import rosetta.core.fragment
#import rosetta.core.graph
import rosetta.core.kinematics
#import rosetta.core.kinematics.tree
#import rosetta.core.mm
#import rosetta.core.optimization
import rosetta.core.pack
import rosetta.core.pack.task
import rosetta.core.pose
#import rosetta.core.pose.metrics
import rosetta.core.scoring
import rosetta.core.scoring.hbonds
#import rosetta.core.scoring.dna
#import rosetta.core.scoring.rna
#import rosetta.core.scoring.methods
#import rosetta.core.scoring.trie
#import rosetta.core.sequence
import rosetta.protocols.moves
import rosetta.protocols.abinitio
import rosetta.protocols.docking
import rosetta.protocols.loops
import rosetta.protocols.relax
import rosetta.core.pose.signals

from rosetta.core.id import *
from rosetta.core.conformation import *
from rosetta.core.chemical import *
from rosetta.core.pose import Pose
from rosetta.core.io.pdb import pose_from_pdb, dump_pdb
from rosetta.core.scoring import *
from rosetta.core.kinematics import *
from rosetta.core.fragment import *
from rosetta.core.pack.task import *
from rosetta.core.pack.task.operation import *
from rosetta.protocols.moves import *
from rosetta.protocols.abinitio import *
from rosetta.protocols.docking import *
from rosetta.protocols.loops import *
from rosetta.protocols.relax import *


# Create global 'Platform' that will hold info of current system.
# sys.platform may carry a version suffix on Linux ('linux', 'linux2', ...),
# so that case is matched by prefix; the rest are exact lookups.
if sys.platform.startswith("linux"):
    Platform = "linux"
else:
    Platform = {"darwin": "macos", "cygwin": "cygwin"}.get(sys.platform, "_unknown_")

# First two characters of the interpreter's architecture string, e.g. '64' from '64bit'.
PlatformBits = platform.architecture()[0][:2]


# Make Pose iterable: `for residue in pose:` visits residues 1..total_residue()
def _Pose_residue_iterator(obj):
    '''Yield obj.residue(1), ..., obj.residue(obj.total_residue()) in order.

    Written as a generator function, so nothing (not even total_residue())
    is evaluated until iteration actually starts.
    '''
    for seqpos in range(1, obj.total_residue() + 1):
        yield obj.residue(seqpos)

# Install the residue iterator so that `for residue in pose:` works.
setattr(Pose, '__iter__', _Pose_residue_iterator)

#    def __iter__(self):
#        def __pose_iter():
#            for i in range(self.total_residue()): yield self.residue(i+1)
#        return __pose_iter()



def add_extend(vectype):
    '''Give `vectype` an `extend(iterable)` method built on its `append`.

    The installed method appends the items of the iterable one at a time,
    in iteration order, and returns None.
    '''
    def extendfunc(vec, othervec):
        for item in othervec:
            vec.append(item)

    vectype.extend = extendfunc


# Give every wrapped utility.vector1_* type a Python-side `extend` method.
for k, v in utility.__dict__.items():
    if not k.startswith("vector1_"):
        continue
    add_extend(v)


def new_vector1_init(self, arg1=None, arg2=False):
    '''Replacement __init__ for vector1 types; expects the original __init__ saved as __old_init.

    After running the original constructor:
      * if arg1 has an __iter__ attribute, its items are added with self.extend;
      * if arg1 is exactly of type int (bool does not count), arg1 copies of
        arg2 are appended;
      * otherwise (including the default None) nothing is added.
    '''
    # Defined outside a class body, so `__old_init` is not name-mangled.
    self.__old_init()

    if hasattr(arg1, "__iter__"):
        self.extend(arg1)
        return
    if type(arg1) is not type(1):
        return

    remaining = arg1
    while remaining > 0:  # a non-positive count appends nothing
        self.append(arg2)
        remaining -= 1


def replace_init(cls, init):
    '''Install `init` as cls.__init__, keeping the previous one as cls.__old_init.

    new_vector1_init calls self.__old_init() first, so it is presumably meant
    to be installed through this function.
    '''
    # setattr spells out that the attribute is literally named '__old_init'
    # (no name mangling happens outside a class body).
    setattr(cls, '__old_init', cls.__init__)
    setattr(cls, '__init__', init)


def Vector1(l):
    ''' Create a Rosetta vector1 from a Python iterable, deducing the element type.

    The first rule that matches every element wins:
      * exactly int          -> utility.vector1_int
      * exactly float or int -> utility.vector1_double
      * exactly str          -> utility.vector1_string
      * exactly bool         -> utility.vector1_bool
    Exact type() comparisons are deliberate: bool is a subclass of int, so
    isinstance() would send a list of booleans to vector1_int.
    An empty input matches the first rule and yields an empty vector1_int.

    Raises Exception for mixed or unsupported element types.
    '''
    # Materialize once: the checks below walk the data several times, which
    # would silently exhaust a generator or any other one-shot iterator.
    items = list(l)

    if all(type(x) == int for x in items):
        t = utility.vector1_int
    elif all(type(x) == float or type(x) == int for x in items):
        t = utility.vector1_double
    elif all(type(x) == str for x in items):
        t = utility.vector1_string
    elif all(type(x) == bool for x in items):
        t = utility.vector1_bool
    else:
        raise Exception('Vector1: attemting to create vector of unknow type or mixed type vector init_list=' + str(l) )

    v = t()
    for x in items:
        v.append(x)
    return v


class PyJobDistributor:
    '''File-based distributor for producing `nstruct` decoys named <pdb_name>_<i>.pdb.

    Decoy number i is claimed by creating "<pdb_name>_<i>.pdb.in_progress".
    It counts as finished once "<pdb_name>_<i>.pdb" exists. The claim is made
    with an atomic exclusive create, so several processes sharing the output
    directory never pick the same decoy number.

    Attributes:
      pdb_name              -- prefix for decoy and score file names
      nstruct               -- number of decoys to produce
      current_num           -- number of the decoy currently claimed
      current_name          -- file name of that decoy
      job_complete          -- True once no unclaimed decoy number is left
      scorefxn              -- score function used for the score file
      native_pose           -- set to a Pose to also report CA rmsd (0 = unset)
      additional_decoy_info -- extra text appended to every score line
    '''

    def __init__(self, pdb_name, nstruct, scorefxn):
        self.pdb_name = pdb_name
        self.nstruct = nstruct
        self.current_num = 0                # current decoy number
        self.current_name = " "             # current decoy name
        self.job_complete = False           # job status
        self.scorefxn = scorefxn            # used for final score calculation
        self.native_pose = 0                # used for rmsd calculation (0 = no native)
        self.additional_decoy_info = ' '    # any additional decoy information you want stored
        self.start_decoy()                  # claims the first free decoy

    def _write(self, path, mode, text):
        '''Write `text` to `path` opened with `mode`, always closing the file.'''
        f = open(path, mode)
        try:
            f.write(text)
        finally:
            f.close()

    def _claim_decoy(self, marker_name):
        '''Atomically create the in-progress marker; return False if it already exists.'''
        import errno
        try:
            # O_EXCL makes check-and-create one atomic step, unlike exists() + open().
            # 438 == 0666 (rw for all, filtered by umask), matching what open(..., 'w') uses.
            fd = os.open(marker_name, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 438)
        except OSError:
            if sys.exc_info()[1].errno == errno.EEXIST:
                return False    # another worker claimed this decoy first
            raise
        f = os.fdopen(fd, 'w')
        try:
            f.write("this decoy is in progress")
        finally:
            f.close()
        return True

    def start_decoy(self):
        '''Claim the lowest-numbered decoy that is neither finished nor in progress.

        On success, current_num/current_name describe the claimed decoy.
        If none is left, job_complete becomes True and current_num is the
        last number checked (0 when nstruct < 1).
        '''
        if self.job_complete:
            return
        i = 1
        while i <= self.nstruct:
            name = self.pdb_name + "_" + str(i) + ".pdb"
            if not os.path.exists(name) and self._claim_decoy(name + ".in_progress"):
                self.current_num = i
                self.current_name = name
                return
            i = i + 1
        self.current_num = i - 1
        self.job_complete = True

    def output_decoy(self, pose):
        '''Save `pose` as the current decoy, append its score line, and claim the next decoy.

        Does nothing when the current decoy's ".in_progress" marker is missing.
        The score file is <pdb_name>.fasc for full-atom poses and <pdb_name>.sc
        otherwise; a header line is written when the file is first created.
        '''
        marker_name = self.current_name + ".in_progress"
        if not os.path.exists(marker_name):
            return

        dump_pdb(pose, self.current_name)  # outputs pdb file
        os.remove(marker_name)

        if pose.is_fullatom():
            scorefile = self.pdb_name + ".fasc"
        else:
            scorefile = self.pdb_name + ".sc"

        if not os.path.exists(scorefile):
            self._write(scorefile, 'w', "pdb name: " + self.pdb_name + "     nstruct: " + str(self.nstruct) + '\n')

        score = self.scorefxn(pose)  # calculates total score
        score_line = pose.energies().total_energies().weighted_string_of(self.scorefxn.weights())
        output_line = "filename: " + self.current_name + " total_score: " + str(round(score, 2))
        if self.native_pose != 0:  # calculates an rmsd if a native pose is defined
            rmsd = CA_rmsd(self.native_pose, pose)
            output_line = output_line + " rmsd: " + str(round(rmsd, 2))
        self._write(scorefile, 'a', output_line + ' ' + score_line + self.additional_decoy_info + '\n')

        self.start_decoy()

def generate_nonstandard_residue_set(params_list):
    '''Read extra residue .params files into the "fa_standard" residue type set and return it.

    params_list -- passed straight to read_files (presumably a
                   utility.vector1_string of .params paths -- TODO confirm).
    The returned set is the mutable one from nonconst_residue_type_set.
    '''
    # Indentation is now uniform; the original return line mixed spaces and a tab.
    chem_manager = ChemicalManager.get_instance()
    res_set = chem_manager.nonconst_residue_type_set("fa_standard")
    res_set.read_files(params_list,
                       chem_manager.atom_type_set("fa_standard"),
                       chem_manager.mm_atom_type_set("fa_standard"))
    return res_set

def generate_resfile_from_pose(pose, resfilename):
    '''Write a resfile that gives every residue of `pose` the NATRO command.

    The file holds a ' start' line, then one line per residue with the PDB
    residue number (right-justified, width 4), the chain (width 3), 'NATRO'
    (width 7) and two trailing spaces.
    '''
    command = "NATRO"
    info = pose.pdb_info()  # looked up once rather than twice per residue
    lines = [' start\n']
    for seqpos in range(1, pose.total_residue() + 1):
        lines.append(str(info.number(seqpos)).rjust(4)
                     + str(info.chain(seqpos)).rjust(3)
                     + command.rjust(7) + '  \n')

    # Open only once all lines are built, so an error above does not leave
    # a truncated resfile; try/finally guarantees the file is closed.
    f = open(resfilename, 'w')
    try:
        f.write(''.join(lines))
    finally:
        f.close()

def generate_resfile_from_pdb(pdbfilename, resfilename):
    '''Load `pdbfilename` into a Pose and write its all-NATRO resfile to `resfilename`.'''
    generate_resfile_from_pose(Pose(pdbfilename), resfilename)

def mutate_residue(pose, resid, new_res):
    '''Mutate residue `resid` of `pose` to the amino acid with one-letter code `new_res`.

    Only position `resid` is repacked; every other position is frozen. The
    packer uses the 'standard' score function and a task initialised from the
    command line. The pose is changed in place and returned. For a pose that
    is not full-atom, a message is printed and None is returned.
    '''
    if not pose.is_fullatom():
        print("mutate_residue only works with fullatom poses")
        return

    scorefxn = create_score_function('standard')
    pack_task = TaskFactory.create_packer_task(pose)
    pack_task.initialize_from_command_line()

    # One flag per canonical amino-acid index 1..20; only the target is allowed.
    target_aa = aa_from_oneletter_code(new_res)
    allowed_aas = rosetta.utility.vector1_bool()
    for aa_index in range(1, 21):
        allowed_aas.append(bool(aa_index == target_aa))

    # Freeze everything except the residue being mutated.
    for seqpos in range(1, pose.total_residue() + 1):
        if seqpos != resid:
            pack_task.nonconst_residue_task(seqpos).prevent_repacking()
    pack_task.nonconst_residue_task(resid).restrict_absent_canonical_aas(allowed_aas)

    PackRotamersMover(scorefxn, pack_task).apply(pose)
    return pose

def standard_task_factory():
    '''Return a TaskFactory with InitializeFromCommandline and NoRepackDisulfides, in that order.'''
    factory = TaskFactory()
    factory.push_back(InitializeFromCommandline())
    # factory.push_back(IncludeCurrent())  -- intentionally left disabled
    factory.push_back(NoRepackDisulfides())
    return factory

def standard_packer_task(pose):
    '''Build a packer task for `pose` by applying the operations of standard_task_factory().'''
    return standard_task_factory().create_task_and_apply_taskoperations(pose)

def add_extra_options():
    '''Register the extra abinitio options.

    Calls the register_options hooks of AbrelaxApplication and IterativeAbrelax,
    then the options broker.
    '''
    abinitio = rosetta.protocols.abinitio
    abinitio.AbrelaxApplication.register_options()
    abinitio.IterativeAbrelax.register_options()
    abinitio.register_options_broker()


# PyMOL link code ----------------------------------------------------------------------------------
import uuid, socket, bz2
from array import array

#if Platform == "cygwin":
#        import gzip
#else:
#        import bz2

class PySocketClient:
    '''Send arbitrarily long messages over UDP by splitting them into packets.

    Every packet is: this client's uuid bytes + three unsigned shorts
    (message id, packet index, packet count) + one slice of the payload.
    '''
    def __init__(self, udp_port=65000, udp_ip='127.0.0.1'):
        self.udp_ip, self.udp_port = udp_ip, udp_port

        # Maximum payload bytes per packet.
        # NOTE(review): the original comment suggests this was meant to be
        # readjusted automatically, but nothing visible here does so.
        self.last_accepted_packet_size = 1024*8
        self.uuid = uuid.uuid4()
        # Header counters as unsigned shorts: [message id, packet index, packet count].
        self.sentCount = array('H', [0, 0, 0])
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    def sendMessage(self, msg):
        '''Send `msg` as ceil(len(msg) / packet size) packets (one packet if msg is empty).'''
        size = self.last_accepted_packet_size
        # Exact ceiling division: no trailing empty packet when len(msg) is a multiple of size.
        count = max(1, (len(msg) + size - 1) // size)

        for i in range(count):
            self.sentCount[1] = i
            self.sentCount[2] = count
            self.sendRAWMessgae(msg[i*size:(i+1)*size])

        # The id is an unsigned short; wrap instead of raising OverflowError after 65535 messages.
        self.sentCount[0] = (self.sentCount[0] + 1) & 0xFFFF

    def sendRAWMessgae(self, msg):
        '''Prefix `msg` with the uuid and the current counters and send it as one datagram.

        (The misspelled method name is kept for existing callers.)
        '''
        buf = array('c', self.uuid.bytes)
        buf.extend(self.sentCount.tostring())
        buf.extend(msg)

        self.socket.sendto(buf, (self.udp_ip, self.udp_port))



class PyMOL_Mover(rosetta.protocols.moves.PyMover):
    '''Mover that sends the pose (and optionally per-residue energies) to PyMOL.

    Messages go through a PySocketClient (UDP, default 127.0.0.1:65000).

    keep_history  -- sent as a flag byte with every message (NOTE(review):
                     presumably tells PyMOL to keep earlier states -- confirm).
    update_energy -- if true, apply() also sends the `energy_type` energies.
    energy_type   -- 'total_energy' or a score-type name accepted by
                     score_type_from_name.
    '''
    def __init__(self, keep_history=False, update_energy=False, energy_type='total_energy'):
        rosetta.protocols.moves.PyMover.__init__(self)
        self.keep_history = keep_history
        self.update_energy = update_energy
        self.energy_type = energy_type
        self.link = PySocketClient()

    def getPoseName(self, pose):
        '''Return the name used for `pose` in PyMOL (at most 255 characters).

        Without PDBInfo the name is 'pose'. Otherwise it is the PDBInfo name,
        cut at its last '.pdb' (or unchanged if it has none).
        '''
        if not pose.pdb_info():
            return 'pose'
        head, _, tail = pose.pdb_info().name().rpartition('.pdb')
        return (head or tail)[:255]  # length must fit in the single chr() length byte

    def apply(self, pose):
        '''Send `pose` to PyMOL as bz2-compressed PDB text; also send energies if update_energy.'''
        name = self.getPoseName(pose)

        # Named pdb_stream so the os module is not shadowed in this method.
        pdb_stream = rosetta.utility.OStringStream()
        pose.dump_pdb(pdb_stream)
        # Layout: 8-char tag, history flag byte, name length byte, name, compressed PDB.
        message = 'PDB.bz2 ' + chr(self.keep_history) + chr(len(name)) + name + bz2.compress(pdb_stream.str())
        self.link.sendMessage(message)

        if self.update_energy:
            self.send_energy(pose, self.energy_type)

    def _send_RAW_Energies_old(self, pose, energyType, enegies):
        '''Older energy message: bz2-compressed raw float32 values.

        NOTE(review): not called by the code visible here; kept for compatibility.
        '''
        energyType = energyType[:255]
        name = self.getPoseName(pose)
        message = 'Ener.bz2' + chr(self.keep_history) + chr(len(name)) + name \
                             + chr(len(energyType)) + energyType \
                             + bz2.compress( array('f', enegies).tostring())
        self.link.sendMessage(message)

    def _send_RAW_Energies(self, pose, energyType, energies):
        '''Send per-residue `energies` of `pose` to PyMOL, labelled `energyType`.

        Values are rescaled to 0..255 and each residue is encoded as chain
        (1 char) + residue number ('%4d') + scaled value (2 hex digits).
        The records are bz2-compressed.
        '''
        energyType = energyType[:255]
        name = self.getPoseName(pose)

        scaled = self._scale_energy(energies)

        info = pose.pdb_info()
        records = []
        for i, value in enumerate(scaled):
            if info:
                chain = info.chain(i+1)[0]
                res = info.number(i+1)
            else:
                chain = ' '
                res = i+1
            records.append('%s%4d%02x' % (chain, res, value))

        message = 'Ener.bz2' + chr(self.keep_history) + chr(len(name)) + name \
                             + chr(len(energyType)) + energyType \
                             + bz2.compress(''.join(records))
        self.link.sendMessage(message)

    def _scale_energy(self, energies):
        '''Linearly map `energies` onto ints 0..255 (minimum -> 0, maximum -> 255).

        Returns [] for empty input and all zeros when every value is the same.
        '''
        if len(energies) == 0:
            return []
        lo = min(energies)
        hi = max(energies)
        span = hi - lo
        # Adding a tiny epsilon to the maximum (as before) is lost to float
        # rounding for non-tiny values and then divides by zero; return zeros instead.
        if span < 1e-100:
            return [0] * len(energies)
        return [int((e - lo) * 255. / span) for e in energies]

    # energy output functions
    ################################################################################
    def send_energy(self, input_pose, energy_type='total_energy'):
        '''Send the per-residue energy `energy_type` of `input_pose` to PyMOL.

        energy_type is 'total_energy' (residue total) or a score-type name.
        If the pose's energies are not up to date, a message is printed and
        nothing is sent.
        '''
        energies = input_pose.energies()
        if not energies.energies_updated():
            print('PyMOL_Mover::send_specific_energy: Energy is not updated, please score the pose first!')
            return

        seqposes = range(1, input_pose.total_residue() + 1)
        if energy_type == 'total_energy':
            output = [energies.residue_total_energy(i) for i in seqposes]
        else:
            score_type = score_type_from_name(energy_type)
            output = [energies.residue_total_energies(i)[score_type] for i in seqposes]
        self._send_RAW_Energies(input_pose, energy_type, output)

    # Alternate name; the warning message in send_energy refers to it.
    send_specific_energy = send_energy

    @staticmethod
    def _find_variable_name(target, main_vars):
        '''Return every name in `main_vars` whose value equals `target`.

        main_vars is a name -> value mapping such as globals() or vars();
        e.g. with pose = Pose(), _find_variable_name(pose, globals()) -> ['pose'].
        Prints a hint when more than one name matches.
        '''
        matches = [name for name, value in main_vars.items() if value == target]
        if len(matches) > 1:
            print('Consider reassignment, multiple variables have this object data')
        return matches
# --------------------------------------------------------------------------------------------------
class PyMOL_Observer(rosetta.core.pose.PosePyObserver):
    ''' Pose observer that sends the whole pose to PyMOL on every general event
        (e.g. geometry or energy changes), using its own PyMOL_Mover.
    '''
    def __init__(self):
        rosetta.core.pose.PosePyObserver.__init__(self)
        self.pymol = rosetta.PyMOL_Mover()

    def generalEvent(self, event):
        # Forward the pose carried by the event; PyMOL_Mover.apply sends it in full.
        changed_pose = event.pose
        self.pymol.apply(changed_pose)

# --------------------------------------------------------------------------------------------------



def _find_database():
    '''Return the absolute path of the minirosetta database, exiting the process if none is found.

    Search order: ./minirosetta_database, $PYROSETTA_DATABASE,
    $HOME/minirosetta_database, and (cygwin only) /minirosetta_database.
    '''
    if os.path.isdir('minirosetta_database'):
        database = os.path.abspath('minirosetta_database')
        print('Found minirosetta_database at %s, using it...' % database)
        return database

    if 'PYROSETTA_DATABASE' in os.environ:
        database = os.path.abspath( os.environ['PYROSETTA_DATABASE'] )
        print('PYROSETTA_DATABASE environment variable was set to: %s... using it...' % database)
        return database

    # HOME can be unset (e.g. native Windows, some service environments);
    # skip this location instead of crashing with KeyError.
    home = os.environ.get('HOME')
    if home is not None and os.path.isdir(home + '/minirosetta_database'):
        database = os.path.abspath(home + '/minirosetta_database')
        print('Found minirosetta_database at home folder, ie: %s, using it...' % database)
        return database

    if sys.platform == "cygwin" and os.path.isdir('/minirosetta_database'):
        database = os.path.abspath('/minirosetta_database')
        print('Found minirosetta_database at root folder, ie: %s, using it...' % database)
        return database

    print('Could not found minirosetta_database! Check your paths or set PyRosetta environment vars. Exiting...')
    sys.exit(1)


def init(*args):
    '''Initialise PyRosetta.

    args -- command-line style option strings for core.init. When none are
            given, ["app", "-database", <located database>, "-ex1", "-ex2aro"]
            is used. The database is always located (and the process exits if
            it cannot be), but it is only added to these defaults, not to
            caller-supplied args.
    '''
    utility.set_pyexit_callback()  # make sure that all mini sys exit just generate exceptions

    database = _find_database()
    if not args:
        args = ["app", "-database", database, "-ex1", "-ex2aro"]

    v = utility.vector1_string()
    v.extend(args)

    print(version())
    core.init(v)


def version():
    '''Return a one-line description of this build: its svn version and the svn URL.'''
    svn_version = rosetta.core.minirosetta_svn_version()
    svn_url = rosetta.core.minirosetta_svn_url()
    return "PyRosetta-%s retrieved from: %s" % (svn_version, svn_url)
