# :noTabs=true:
# (c) Copyright Rosetta Commons Member Institutions.
# (c) This file is part of the Rosetta software suite and is made available under license.
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
# (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

import os, sys
import os.path

import warnings
# Silence the RuntimeWarning raised from rosetta.* modules when a Boost.Python
# to-Python converter for some type is registered more than once; per the
# warning text itself, the second registration is simply ignored.
warnings.filterwarnings("ignore", "to-Python converter for .+ already registered; second conversion method ignored.", RuntimeWarning, "^rosetta\\.")

import utility, core

import rosetta.utility
import rosetta.utility.file
import rosetta.utility.pointer

import rosetta.numeric

import rosetta.core
#import rosetta.core.coarse
import rosetta.core.chemical
import rosetta.core.conformation
#import rosetta.core.grid
import rosetta.core.id
import rosetta.core.io
import rosetta.core.io.pdb
import rosetta.core.fragment
#import rosetta.core.graph
import rosetta.core.kinematics
#import rosetta.core.kinematics.tree
#import rosetta.core.mm
#import rosetta.core.optimization
import rosetta.core.pack
import rosetta.core.pack.task
import rosetta.core.pose
#import rosetta.core.pose.metrics
import rosetta.core.scoring
import rosetta.core.scoring.hbonds
#import rosetta.core.scoring.dna
#import rosetta.core.scoring.rna
#import rosetta.core.scoring.methods
#import rosetta.core.scoring.trie
#import rosetta.core.sequence
import rosetta.protocols.moves
import rosetta.protocols.abinitio
import rosetta.protocols.docking
import rosetta.protocols.loops
import rosetta.protocols.relax

from rosetta.core.id import *
from rosetta.core.conformation import *
from rosetta.core.chemical import *
from rosetta.core.pose import Pose
from rosetta.core.io.pdb import pose_from_pdb, dump_pdb
from rosetta.core.scoring import *
from rosetta.core.kinematics import *
from rosetta.core.fragment import *
from rosetta.core.pack.task import *
from rosetta.core.pack.task.operation import *
from rosetta.protocols.moves import *
from rosetta.protocols.abinitio import *
from rosetta.protocols.docking import *
from rosetta.protocols.loops import *
from rosetta.protocols.relax import *

# Make Pose iterable: `for residue in pose` yields residues 1..total_residue().
def _Pose_residue_iterator(obj):
    def __pose_iter():
        for i in range(obj.total_residue()): yield obj.residue(i+1)
    return __pose_iter()

# Install the residue iterator on the Pose wrapper class so that
# `for residue in pose:` visits residues 1..total_residue().
setattr(Pose, '__iter__', _Pose_residue_iterator)



def add_extend(vectype):
    """Attach a list-like ``extend`` method to the wrapped vector class *vectype*.

    The installed method appends every element of the given iterable, one at
    a time, through the instance's own ``append``. It returns None.
    """
    def _extend(vec, othervec):
        for item in othervec:
            vec.append(item)

    setattr(vectype, 'extend', _extend)


# Give every utility.vector1_* wrapper class a Python-side extend() method.
# (The loop variables k and v remain bound at module level afterwards,
# as they always have.)
for k, v in utility.__dict__.items():
    if not k.startswith("vector1_"):
        continue
    add_extend(v)


def new_vector1_init(self, arg1=None, arg2=False):
    """Replacement __init__ for vector1 wrappers (installed via replace_init).

    Always calls the saved original constructor (attribute ``__old_init``;
    this is a module-level function, so no name mangling applies). Then:
      * if *arg1* has ``__iter__``, its elements are added via self.extend;
      * else if *arg1* is exactly an int, *arg2* is appended arg1 times;
      * otherwise nothing more is done.
    """
    self.__old_init()

    if hasattr(arg1, "__iter__"):
        self.extend(arg1)
        return

    if type(arg1) is not int:
        return

    for _ in range(arg1):
        self.append(arg2)


def replace_init(cls, init):
    """Install *init* as ``cls.__init__``, keeping the previous one reachable.

    The previous constructor is stored on the class under the literal
    attribute name ``__old_init`` (no name mangling, since this is not inside
    a class body), which is what new_vector1_init calls.
    """
    previous_init = cls.__init__
    setattr(cls, '__old_init', previous_init)
    setattr(cls, '__init__', init)


def Vector1(l):
    ''' Create a utility.vector1_* object, deducing the element type from *l*.

    - every element exactly int          -> utility.vector1_int
    - every element exactly int or float -> utility.vector1_double
    - every element exactly str          -> utility.vector1_string
    An empty *l* therefore yields a vector1_int. Any other mix raises Exception.
    Elements are copied in order with append().
    '''
    def _all_of(*allowed):
        return all(type(item) in allowed for item in l)

    if _all_of(int):
        vec_type = utility.vector1_int
    elif _all_of(float, int):
        vec_type = utility.vector1_double
    elif _all_of(str):
        vec_type = utility.vector1_string
    else:
        raise Exception('Vector1: attemting to create vector of unknow type or mixed type vector init_list=' + str(l) )

    result = vec_type()
    for item in l:
        result.append(item)
    return result


class PyJobDistributor:
    """Minimal file-based job distributor for writing numbered decoys.

    Decoy files are named "<pdb_name>_<i>.pdb" for i = 1..nstruct. A slot is
    claimed by creating "<name>.in_progress". Slots whose .pdb or
    .in_progress file already exists are skipped, presumably so that several
    processes can share one output prefix.
    NOTE(review): the exists-then-create check is not atomic, so two
    processes starting at the same moment could claim the same slot.
    """

    def __init__( self, pdb_name, nstruct, scorefxn ):
        self.pdb_name = pdb_name
        self.nstruct = nstruct
        self.current_num = 0            # current decoy number
        self.current_name = " "         # current decoy file name
        self.job_complete = False       # True once no free slot remains
        self.scorefxn = scorefxn        # called as scorefxn(pose) for the final score
        self.native_pose = 0            # 0 = no native; set to a Pose to add a CA rmsd column
        self.additional_decoy_info = ' '  # extra text appended to every score line
        self.start_decoy()              # claim the first free decoy slot

    def start_decoy( self ):
        """Claim the lowest-numbered free decoy slot.

        Creates "<name>.in_progress" for that slot and sets current_name and
        current_num. If every slot 1..nstruct is taken (or nstruct <= 0),
        sets job_complete to True instead, and current_num is left at the
        last slot examined.
        """
        if self.job_complete:
            return
        claimed = False
        i = 1
        while not claimed and i <= self.nstruct:
            candidate = self.pdb_name + "_" + str(i) + ".pdb"
            marker = candidate + ".in_progress"
            if not os.path.exists(candidate) and not os.path.exists(marker):
                with open(marker, 'w') as f:
                    f.write("this decoy is in progress")
                self.current_name = candidate
                claimed = True
            i = i + 1
        self.current_num = i - 1
        if not claimed:
            self.job_complete = True

    def output_decoy( self, pose):
        """Write *pose* as the current decoy, log its score, and claim the next slot.

        Does nothing if the current slot's .in_progress marker does not exist
        (e.g. the job is already complete). Otherwise it dumps the PDB and
        removes the marker. It then appends one line to
        "<pdb_name>.fasc" (full-atom pose) or "<pdb_name>.sc" (otherwise),
        writing a header line first if the score file does not exist yet.
        """
        marker = self.current_name + ".in_progress"
        if not os.path.exists(marker):
            return

        dump_pdb(pose, self.current_name)  # outputs pdb file
        os.remove(marker)

        score_tag = ".fasc" if pose.is_fullatom() else ".sc"
        scorefile = self.pdb_name + score_tag
        if not os.path.exists(scorefile):
            # The original code wrote `f.close` without parentheses, so this
            # handle was never explicitly closed; `with` guarantees it.
            with open(scorefile, 'w') as f:
                f.write("pdb name: " + self.pdb_name + "     nstruct: " + str(self.nstruct) + '\n')

        score = self.scorefxn(pose)  # calculates total score
        score_line = pose.energies().total_energies().weighted_string_of( self.scorefxn.weights())
        output_line = "filename: " + self.current_name + " total_score: " + str(round(score, 2))
        if self.native_pose != 0:  # rmsd column only when a native pose is set
            rmsd = CA_rmsd(self.native_pose, pose )
            output_line = output_line + " rmsd: " + str(round(rmsd, 2))
        with open(scorefile, 'a') as f:
            f.write(output_line + ' ' + score_line + self.additional_decoy_info + '\n')  # outputs scorefile

        self.start_decoy()

def generate_nonstandard_residue_set( params_list ):
    """Read extra residue .params files into the "fa_standard" residue type set.

    params_list: the list of params files passed to read_files (presumably a
        utility.vector1_string of paths -- TODO confirm).
    Returns the residue type set obtained from
    ChemicalManager.get_instance().nonconst_residue_type_set("fa_standard"),
    after read_files() has been called on it.
    NOTE(review): this presumably modifies the shared fa_standard set held by
    the ChemicalManager rather than a copy -- confirm.
    """
    # Indentation fixed: the original return line mixed spaces and a tab,
    # which Python 3 rejects with TabError.
    manager = ChemicalManager.get_instance()
    res_set = manager.nonconst_residue_type_set("fa_standard")
    res_set.read_files(params_list,
                       manager.atom_type_set("fa_standard"),
                       manager.mm_atom_type_set("fa_standard"))
    return res_set

def generate_resfile_from_pose(pose,resfilename):
    """Write a resfile to *resfilename* that marks every residue of *pose* NATRO.

    The file begins with a " start" line. It then has one line per residue
    1..total_residue(): the PDB residue number right-justified to width 4,
    the PDB chain to width 3, and "NATRO" to width 7, followed by two spaces.
    Numbering comes from pose.pdb_info().
    """
    mode = "NATRO"  # renamed from `id`, which shadowed the builtin
    with open(resfilename, 'w') as f:  # closed even if pdb_info access fails
        f.write(' start\n')
        info = pose.pdb_info()
        for i in range(1, pose.total_residue() + 1):
            f.write('%4s%3s%7s  \n' % (info.number(i), info.chain(i), mode))

def generate_resfile_from_pdb(pdbfilename,resfilename):
    """Load *pdbfilename* into a Pose and write an all-NATRO resfile for it.

    Delegates to generate_resfile_from_pose(); the output goes to *resfilename*.
    """
    generate_resfile_from_pose(Pose(pdbfilename), resfilename)

def mutate_residue(pose, resid, new_res):
    """Repack residue *resid* of *pose* as amino acid *new_res*, in place.

    new_res: one-letter code, converted with aa_from_oneletter_code().
    Every other residue has repacking prevented. The repack uses
    create_score_function('standard') and a packer task initialised from the
    command line.

    Returns the same *pose* object. If the pose is not full-atom, it prints a
    message and returns None without changing anything.
    """
    if not pose.is_fullatom():
        print("mutate_residue only works with fullatom poses")
        return

    scorefxn = create_score_function('standard')
    pack_task = TaskFactory.create_packer_task(pose)
    pack_task.initialize_from_command_line()

    # One flag per canonical amino acid index 1..20; only the target is True.
    allowed = rosetta.utility.vector1_bool()
    target_aa = aa_from_oneletter_code(new_res)
    for aa_index in range(1, 21):
        allowed.append(bool(aa_index == target_aa))

    for seqpos in range(1, pose.total_residue() + 1):
        if seqpos == resid:
            continue
        pack_task.nonconst_residue_task(seqpos).prevent_repacking()

    # Presumably disallows every canonical aa whose flag is False -- confirm.
    pack_task.nonconst_residue_task(resid).restrict_absent_canonical_aas(allowed)

    PackRotamersMover(scorefxn, pack_task).apply(pose)
    return pose

def standard_task_factory():
    """Return a new TaskFactory holding the standard task operations.

    The operations, pushed in order: InitializeFromCommandline,
    IncludeCurrent, NoRepackDisulfides.
    """
    factory = TaskFactory()
    for operation_type in (InitializeFromCommandline, IncludeCurrent, NoRepackDisulfides):
        factory.push_back(operation_type())
    return factory

def standard_packer_task(pose):
    """Build a packer task for *pose* from standard_task_factory()'s operations."""
    return standard_task_factory().create_task_and_apply_taskoperations(pose)

def add_extra_options():
    """Call the option-registration hooks from rosetta.protocols.abinitio.

    In order: AbrelaxApplication.register_options(),
    IterativeAbrelax.register_options(), register_options_broker().
    """
    abinitio = rosetta.protocols.abinitio
    abinitio.AbrelaxApplication.register_options()
    abinitio.IterativeAbrelax.register_options()
    abinitio.register_options_broker()


def _find_pyrosetta_database():
    """Locate the minirosetta database directory, or exit the process.

    Locations are tried in order:
      1. ./minirosetta_database (relative to the current working directory)
      2. the PYROSETTA_DATABASE environment variable
      3. ~/minirosetta_database
    Prints which location was chosen and returns its absolute path. If none
    is found, prints an error and calls sys.exit(1).
    """
    if os.path.isdir('minirosetta_database'):
        database = os.path.abspath('minirosetta_database')
        print('Found minirosetta_database at %s, using it...' % database)
        return database

    if 'PYROSETTA_DATABASE' in os.environ:
        database = os.path.abspath( os.environ['PYROSETTA_DATABASE'] )
        print('PYROSETTA_DATABASE environment variable was set to: %s... using it...' % database)
        return database

    # expanduser('~') still works when $HOME is unset (e.g. on Windows),
    # where os.environ['HOME'] would raise KeyError.
    home_database = os.path.join(os.path.expanduser('~'), 'minirosetta_database')
    if os.path.isdir(home_database):
        database = os.path.abspath(home_database)
        print('Found minirosetta_database at home folder, ie: %s, using it...' % database)
        return database

    print('Could not found minirosetta_database! Check your paths or set PyRosetta environment vars. Exiting...')
    sys.exit(1)


def init(*args):
    """Initialise the Rosetta core.

    With no arguments, the database is located with _find_pyrosetta_database()
    and the command line ["app", "-database", <database>, "-ex1", "-ex2aro"]
    is used. With arguments, they are passed to core.init unchanged and no
    database search is done, so the caller supplies -database as needed.
    Prints version() before calling core.init.
    """
    utility.set_pyexit_callback()  # make sure that all mini sys exit just generate exceptions

    # Search for a database only when we are building the default command
    # line. Previously the search (and its sys.exit) ran even when the caller
    # had passed an explicit -database.
    if not args:
        args = ["app", "-database", _find_pyrosetta_database(), "-ex1", "-ex2aro"]

    v = utility.vector1_string()
    v.extend(args)
    print(version())
    core.init(v)


def version():
    """Return the PyRosetta release string with the core's SVN revision and URL."""
    svn_version = rosetta.core.minirosetta_svn_version()
    svn_url = rosetta.core.minirosetta_svn_url()
    return "PyRosetta-Release1.0-%s retrieved from: %s" % (svn_version, svn_url)
