#!/usr/bin/python3.13

import importlib.machinery
import importlib.util
import sys
from argparse import ArgumentParser,ArgumentDefaultsHelpFormatter
from os import makedirs
from os import path as os_path
from subprocess import check_output
from sys import path as sys_path

try:
      from imp import find_module,load_module
except ImportError:
      # The 'imp' module was removed in Python 3.12; keep the legacy names
      # defined so that this script can still be imported there.
      find_module = load_module = None

# Extend the module search path with a fixed system-wide site-packages
# directory before the ufo_interface imports below.
# NOTE(review): presumably this is where ufo_interface is installed -- confirm.
sys_path.append('/usr/lib/python3.13/site-packages')

from ufo_interface import s_lorentz, write_model, write_run_card
from ufo_interface.templates import sconstruct_template
from ufo_interface.message import error, progress

def parse_args(argv=None):
    """Parse the command line options of the model generator.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) the arguments are taken from sys.argv, which is the
            behaviour of a plain ``parse_args()`` call.

    Returns:
        The argparse namespace with the attributes ufo_path, ncore,
        modelflags, lorentzflags, includedir, libdir, installdir, noopt
        and nocompile.
    """
    arg_parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    arg_parser.add_argument("ufo_path",
                            help="Path to input UFO model directory")
    # type=int: without it a user-supplied value stays a string, which
    # cannot be compared against an integer core count later on.
    arg_parser.add_argument("--ncore",
                            type=int,
                            default=1,
                            help="Number of cores used for compilation")
    arg_parser.add_argument("--modelflags",
                            default="-g -O0 -fno-var-tracking --std=c++11",
                            help='Compiler flags for model source code')
    arg_parser.add_argument("--lorentzflags",
                            default="-g -O2 -ffast-math --std=c++11",
                            help='Compiler flags for lorentz source code')
    arg_parser.add_argument("--includedir",
                            default='/usr/include/SHERPA-MC',
                            help='Path to Sherpa headers')
    arg_parser.add_argument("--libdir",
                            default='/usr/lib64/SHERPA-MC',
                            help='Path to Sherpa libraries')
    arg_parser.add_argument("--installdir",
                            default='/usr/lib64/SHERPA-MC',
                            help='Installation path for new shared Model library')
    arg_parser.add_argument("--noopt",
                            action="store_true",
                            help="Disable optimization of Lorentz calculators")
    arg_parser.add_argument("--nocompile",
                            action="store_true",
                            help="Do not compile generated UFO model")

    return arg_parser.parse_args(argv)

def _find_order_index(orders, name):
    """Return the index of the first coupling order whose name is *name*.

    Raises ufo_exception if no such coupling order is in the list.
    """
    for idx, order in enumerate(orders):
        if order.name == name:
            return idx
    # Imported here because this file does not import the exception class
    # at module level.
    # NOTE(review): assumes the class lives in ufo_interface.ufo_exception
    # -- confirm against the ufo_interface package.
    from ufo_interface.ufo_exception import ufo_exception
    raise ufo_exception('No {0} coupling found in model'.format(name))


def sort_orders(orders):
    """Sort the coupling orders such that the first item is the QCD order
    and the second item is the QED order. Raise an exception if
    either of them is not contained in the model.

    The list is reordered in place by swapping elements and the same list
    object is returned. Items are identified by their 'name' attribute.
    """
    # Swap first item in list with QCD coupling
    i_qcd = _find_order_index(orders, 'QCD')
    orders[0], orders[i_qcd] = orders[i_qcd], orders[0]

    # Swap second item in list with QED coupling. The search must happen
    # after the QCD swap, since that swap may have moved the QED entry.
    i_qed = _find_order_index(orders, 'QED')
    orders[1], orders[i_qed] = orders[i_qed], orders[1]

    assert orders[0].name == 'QCD' and orders[1].name == 'QED'
    return orders

def check_model(model_name, model_path):
    """Import the UFO model package and prepare it for code generation.

    Args:
        model_name: Name of the UFO model, i.e. the name of the Python
            package directory.
        model_path: Directory that contains the model package directory.

    Returns:
        The imported model module with three adjustments applied:
        'non_ct_couplings' is added, 'all_lorentz' is reduced to the
        structures referenced by entries of 'all_vertices', and
        'all_orders' is sorted such that QCD comes first and QED second.

    Prints an error and terminates the process with exit status 1 if the
    model name clashes with a built-in model or the model cannot be
    imported.
    """
    # Check for conflicts with built-in models
    if model_name in ("SM", "HEFT", "TauPi"):
        error("Model name {0} conflicts with built-in model. Please rename your UFO model directory.".format(model_name))
        exit(1)

    # Try to import the model to check if UFO path is ok. Only
    # 'model_path' is searched, so a module of the same name elsewhere on
    # sys.path is never picked up by accident.
    try:
        spec = importlib.machinery.PathFinder.find_spec(model_name, [model_path])
        if spec is None or spec.loader is None:
            raise ImportError("No importable package '{0}' in '{1}'".format(model_name, model_path))
        model = importlib.util.module_from_spec(spec)
        # Register the package before executing it so that imports inside
        # the model that refer to the package itself can be resolved.
        sys.modules[model_name] = model
        try:
            spec.loader.exec_module(model)
        except BaseException:
            # Do not leave a half-initialised module behind
            sys.modules.pop(model_name, None)
            raise
    except ImportError:
        # Report the path built from the arguments instead of reading the
        # script-level 'args' global, which only exists when this file is
        # run as a script.
        error("Could not import UFO model from input path \"{0}\", make sure this is a path to a valid UFO model".format(os_path.join(model_path, model_name)))
        exit(1)

    # For NLO models: need to filter out CT_Couplings
    model.non_ct_couplings = model.all_couplings
    if hasattr(model, "all_CTvertices"):
        non_ct_vertices = [vtx for vtx in model.all_vertices if vtx not in model.all_CTvertices]
        # Flatten the per-vertex coupling collections into one list.
        # (dict views cannot be concatenated onto a list with sum().)
        model.non_ct_couplings = [cpl for vtx in non_ct_vertices
                                  for cpl in vtx.couplings.values()]

    # For NLO models: filter out lorentz structures
    # that ONLY appear in CT vertices as they are sometimes ill-formatted
    needed_names = {l.name for v in model.all_vertices for l in v.lorentz}
    model.all_lorentz = [l for l in model.all_lorentz if l.name in needed_names]

    # Sort coupling orders to ensure first order is QCD, second
    # order is QED
    model.all_orders = sort_orders(model.all_orders)

    return model

def make_output_dir(path):
    """Make sure that 'path' is a directory, creating it if necessary.

    Missing parent directories are created as well. If 'path' exists but
    is not a directory, an error is printed and the process terminates
    with exit status 1.
    """
    # Nothing to do if the directory is already there
    if os_path.isdir(path):
        return

    # Something that is not a directory occupies the path
    if os_path.exists(path):
        error("Could not write to output path \"{0}\", file with the same name existing".format(path))
        exit(1)

    makedirs(path)

if __name__ == "__main__":

    # Extract command line args
    args = parse_args()

    # Check the path to the UFO models: split the absolute path into the
    # containing directory and the name of the model package
    arg_path = os_path.abspath(args.ufo_path)
    model_path, model_name = os_path.split(arg_path.rstrip('/'))
    model = check_model(model_name, model_path)

    # Output paths: all generated sources go into a '.sherpa'
    # subdirectory of the model directory
    out_dir = os_path.join(arg_path, '.sherpa')
    sconstruct_file_path = os_path.join(out_dir, 'SConstruct')
    model_out_path = os_path.join(out_dir, 'Model.C')
    make_output_dir(out_dir)

    # Need this list in order to write model source code
    lorentzes = [s_lorentz(l) for l in model.all_lorentz]

    # Write model source code
    progress("Generating model source code")
    parameter_map = write_model(model, model_name, model_out_path)

    # Write lorentz calculator source code, skipping structures for which
    # has_ghosts() is true
    optimize = not args.noopt
    for lor in lorentzes:
        if lor.has_ghosts():
            continue
        progress("Generating source code for lorentz calculator '{0}'".format(lor.name()))
        lor.write(out_dir, optimize, parameter_map)

    # Generate and write SConstruct file
    lib_name = 'Sherpa{0}'.format(model_name)
    sconstruct_content = sconstruct_template.substitute(libname=lib_name,
                                                        includedir=args.includedir,
                                                        libdir=args.libdir,
                                                        installdir=args.installdir,
                                                        modelflags=args.modelflags,
                                                        lorentzflags=args.lorentzflags)
    with open(sconstruct_file_path, 'w') as sconstruct_file:
        sconstruct_file.write(sconstruct_content)

    # compile and install
    if not args.nocompile:
        progress("Compiling sources using scons")
        # The option value may still be a string if it was given on the
        # command line, so convert it explicitly. The job count is passed
        # to scons exactly once.
        ncore = int(args.ncore)
        sconsargs = ['scons', '-C', out_dir, '-j{0}'.format(ncore), 'install']
        # Decode the captured output so that text, not a bytes
        # representation, is handed to progress()
        progress(check_output(sconsargs, universal_newlines=True))

    # write example run card to working dir; prepend underscores until the
    # file name does not clash with an existing file
    run_card_path = "{0}_Example_Run.dat".format(model_name)
    while os_path.exists(run_card_path):
        run_card_path = "_" + run_card_path
    progress("Writing example run card '{0}' to working directory".format(run_card_path))
    write_run_card(model, model_name, run_card_path)

    progress("Finished generating model '{0}'\nPlease cite Eur.Phys.J. C, 75 3 (2015) 137\nif you make use of Sherpa's BSM features".format(model_name))

    exit(0)
