#!/usr/bin/python
import sys, itertools, optparse

# Command-line interface: three positional arguments plus options selecting
# paired-end mode, strandedness, the alignment-quality cutoff and SAM vs BAM.
optParser = optparse.OptionParser( 
   
   usage = "python %prog [options] <flattened_gff_file> <s/bam_file> <output_file>",
   
   description=
      "This script counts how many reads in <s/bam_file> fall onto each exonic " +
      "part given in <flattened_gff_file> and outputs a list of counts in " +
      "<output_file>, for further analysis with the DEXSeq Bioconductor package. " +
      "(Note: The <flattened_gff_file> should be produced with the script " +
      "prepare_exon_annotation_ensembl/refseq.py). <s/bam_file> may be '-' to indicate standard input.",
      
   epilog = 
      "Written by Simon Anders (sanders@fs.tum.de), European Molecular Biology " +
      "Laboratory (EMBL). (c) 2010. Released under the terms of the GNU General " +
      "Public License v3. Part of the 'DEXSeq' package. " + 
      "Modified by Xi Wang (xi.wang@newcastle.edu.au), UoN, Australia" )

# -p/--paired: "yes" switches to mate-pair processing.
optParser.add_option( "-p", "--paired", type="choice", dest="paired",
   choices = ( "no", "yes" ), default = "no",
   help = "'yes' or 'no'. Indicates whether the data is paired-end (default: no)" )

# -s/--stranded: "yes" and "reverse" both make feature lookup strand-aware;
# "reverse" additionally inverts the read strand before lookup.
optParser.add_option( "-s", "--stranded", type="choice", dest="stranded",
   choices = ( "yes", "no", "reverse" ), default = "yes",
   help = "'yes', 'no', or 'reverse'. Indicates whether the data is " +
      "from a strand-specific assay (default: yes). " +
      "Be sure to switch to 'no' if you use a non-strand-specific RNA-Seq library " +
      "preparation protocol. 'reverse' inverts strands and is needed for certain " +
      "protocols, e.g. paired-end with circularization."  )

# -a/--minaqual: alignments with aQual below this value are not counted.
optParser.add_option( "-a", "--minaqual", type="int", dest="minaqual",
   default = 10,
   help = "skip all reads with alignment quality lower than the given " +
      "minimum value (default: 10)" )

# -b/--bam: "yes" reads the input with a BAM reader instead of a SAM reader.
optParser.add_option( "-b", "--bam", type="choice", dest="bam",
   choices = ( "no", "yes" ), default = "no",
   help = "'yes' or 'no'. Indicates whether the data is BAM instead of SAM (default: no)" )
   
# Invoked with no arguments at all: print the full help text and stop.
if len( sys.argv ) == 1:
   optParser.print_help()
   sys.exit(1)

opts, args = optParser.parse_args()

# Exactly three positional arguments are required:
# <flattened_gff_file> <s/bam_file> <output_file>.
if len( args ) != 3:
   sys.stderr.write( "%s: Error: Please provide three arguments.\n" % sys.argv[0]
                     + "  Call with '-h' to get usage information.\n" )
   sys.exit( 1 )

# HTSeq is imported only after argument parsing, so '-h' and usage errors
# are still reported when it is missing; without it the script cannot run.
try:
   import HTSeq
except ImportError:
   sys.stderr.write(
      "Could not import HTSeq. Please install the HTSeq Python framework\n"
      "available from http://www-huber.embl.de/users/anders/HTSeq\n" )
   sys.exit( 1 )

# Positional arguments: annotation, alignments, output path.
gff_file, sam_file, out_file = args[0], args[1], args[2]

# Translate the string-valued options into flags.
reverse = opts.stranded == "reverse"
stranded = opts.stranded in ( "yes", "reverse" )   # 'reverse' is stranded too
is_PE = opts.paired == "yes"
is_BAM = opts.bam == "yes"
minaqual = opts.minaqual

# '-' selects standard input as the alignment source.
if sam_file == "-":
   sam_file = sys.stdin


# Step 1: Load the flattened annotation (produced by
# prepare_exon_annotation_ensembl/refseq.py) and index every 'exonic_part'
# feature by its genomic interval in a GenomicArrayOfSets. Each part is
# named "<gene_id>:<exonic_part_number>"; other feature types are skipped.

features = HTSeq.GenomicArrayOfSets( "auto", stranded=stranded )
for f in HTSeq.GFF_Reader( gff_file ):
   if f.type != "exonic_part":
      continue
   f.name = ":".join( ( f.attr['gene_id'], f.attr['exonic_part_number'] ) )
   features[f.iv] += f

# Counters: the special '_'-prefixed bins collect reads that are unaligned,
# below the quality cutoff, or that cannot be assigned to exactly one gene.
num_reads = 0
counts = dict.fromkeys( ( '_empty', '_ambiguous', '_lowaqual', '_notaligned' ), 0 )

# Pre-register every exonic part with a zero count. The counting loop
# increments these keys directly, and parts that receive no reads still
# appear in the output.
for iv, s in features.steps():
   counts.update( ( f.name, 0 ) for f in s )

# Helper used by the read-counting loops to invert the strand of a matched
# reference interval.
def reverse_strand( s ):
   """Return the strand opposite to `s`.

   Parameters:
      s -- strand symbol, expected to be "+" or "-".
   Returns:
      "-" for "+", and "+" for "-".
   Raises:
      SystemError -- for any other value (exception type kept unchanged
                     for backward compatibility).
   """
   if s == "+":
      return "-"
   elif s == "-":
      return "+"
   else:
      # Call syntax works on both Python 2 and Python 3; the old
      # `raise SystemError, "..."` form is a SyntaxError on Python 3.
      raise SystemError( "illegal strand" )

# Step 2: Go through the aligned reads and assign each read (single-end) or
# read pair (paired-end) to the exonic parts it overlaps.

def _overlapping_features( aln, flip_strand ):
   """Return the set of features overlapped by the 'M' CIGAR segments of `aln`.

   If `flip_strand` is true, the strand of each matched reference interval
   is inverted in place (on the CIGAR operation) before the lookup.
   """
   found = set()
   for cigop in aln.cigar:
      # Only 'M' operations are looked up; every other CIGAR type is ignored.
      if cigop.type != "M":
         continue
      if flip_strand:
         cigop.ref_iv.strand = reverse_strand( cigop.ref_iv.strand )
      for iv, s in features[cigop.ref_iv].steps():
         found.update( s )
   return found

def _tally( rs ):
   """Credit one read (or pair) that overlaps the features in `rs`.

   Feature names have the form "<gene_id>:<exonic_part_number>". If all
   overlapped parts belong to one gene, each of them gets +1. Otherwise the
   read counts as '_empty' (no overlap) or '_ambiguous' (several genes).
   """
   gene_names = set( f.name.split(":")[0] for f in rs )
   if not gene_names:
      counts[ '_empty' ] += 1
   elif len( gene_names ) > 1:
      counts[ '_ambiguous' ] += 1
   else:
      for f in rs:
         counts[ f.name ] += 1

if not is_BAM:
   tmp_obj = HTSeq.SAM_Reader( sam_file )
else:
   tmp_obj = HTSeq.BAM_Reader( sam_file )

if not is_PE:

   num_reads = 0
   for a in tmp_obj:
      if not a.aligned:
         counts[ '_notaligned' ] += 1
         continue
      if a.aQual < minaqual:
         counts[ '_lowaqual' ] += 1
         continue
      _tally( _overlapping_features( a, reverse ) )
      num_reads += 1
      if num_reads % 100000 == 0:
         sys.stderr.write( "%d reads processed.\n" % num_reads )

else: # paired-end

   num_reads = 0
   # The checks below allow either mate (af / ar) to be missing (falsy).
   for af, ar in HTSeq.pair_SAM_alignments( tmp_obj ):
      if af and ar and not af.aligned and not ar.aligned:
         counts[ '_notaligned' ] += 1
         continue
      # Discard the pair only if BOTH mates are below the cutoff. A pair with
      # one usable mate is still counted via that mate below.
      if af and ar and af.aQual < minaqual and ar.aQual < minaqual:
         counts[ '_lowaqual' ] += 1
         continue
      rs = set()
      if af and af.aligned and af.aQual >= minaqual and af.iv.chrom in features.chrom_vectors:
         rs |= _overlapping_features( af, reverse )
      if ar and ar.aligned and ar.aQual >= minaqual and ar.iv.chrom in features.chrom_vectors:
         # The mates get opposite strand treatment: the first is flipped only
         # with --stranded=reverse, the second only without it.
         rs |= _overlapping_features( ar, not reverse )
      _tally( rs )
      num_reads += 1
      if num_reads % 100000 == 0:
         sys.stderr.write( "%d reads processed.\n" % num_reads )

 
# Step 3: Write out the results: one "<name>\t<count>" line per counter
# (exonic parts and the special '_' bins alike), sorted by name. The context
# manager guarantees the file is closed even if a write fails.

with open( out_file, "w" ) as fout:
   for fn in sorted( counts ):
      fout.write( "%s\t%d\n" % ( fn, counts[fn] ) )
