#!/usr/bin/python2

# Copyright 2014 Jared Boone <jared@sharebrained.com>
#
# This file is part of gr-tpms.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; see the file COPYING.  If not, write to
# the Free Software Foundation, Inc., 51 Franklin Street,
# Boston, MA 02110-1301, USA.
#

import time

from math import floor, ceil, sqrt

from argparse import ArgumentParser

from gnuradio import analog
from gnuradio import blocks
from gnuradio import digital
from gnuradio import gr
from gnuradio import filter
from gnuradio.filter import firdes
import pmt

import tpms

from tpms.packet_check import packet_decode

from tpms.source import source_hackrf, source_rtlsdr, source_file
from tpms.ask import ask_channel_filter
from tpms.fsk import fsk_center_tracking, fsk_demodulator
from tpms.decode import clock_recovery

class packet_sink(gr.basic_block):
	"""Message-only block that passes received packets to packet_decode.

	The block has no stream inputs or outputs.  It registers one message
	input port named "packet_sink" and handles every message posted to it.
	"""

	def __init__(self, output_raw=False):
		"""Register the message port and attach the handler.

		output_raw -- stored and forwarded unchanged to packet_decode()
			for every handled message.
		"""
		super(packet_sink, self).__init__(
			"packet_handler",
			in_sig=None,
			out_sig=None,
		)
		self.output_raw = output_raw

		in_port = pmt.intern("packet_sink")
		self.message_port_register_in(in_port)
		self.set_msg_handler(in_port, self._handle)

	def _handle(self, message):
		"""Convert one PMT message to a Python dict and decode it.

		The "data" entry of the dict is replaced by a string built by
		concatenating str() of every element of the message's u8 vector
		(presumably one '0'/'1' character per bit -- confirm against the
		frame sink that produces these messages).
		"""
		packet = pmt.to_python(message)
		raw = pmt.dict_ref(message, pmt.intern("data"), pmt.PMT_NIL)
		elements = pmt.u8vector_elements(raw)
		packet['data'] = ''.join(map(str, elements))
		packet_decode(packet, output_raw=self.output_raw)

def queue_handler(queue):
	"""Return a no-argument function that drains ``queue`` forever.

	On every iteration the returned function calls ``queue.delete_head()``,
	discards whatever it returns and prints the literal text 'message'.
	The loop has no exit condition; it only ends if ``delete_head()`` raises.
	"""
	def queue_handler_fn():
		while True:
			# The dequeued item itself is not used, only its arrival.
			queue.delete_head()
			print('message')

	return queue_handler_fn

class demodulate_ask(gr.hier_block2):
	"""Hierarchical block that runs a bank of ASK receive chains.

	One complex input stream, no stream outputs.  The input goes through
	blocks.complex_to_mag and the result is fanned out to one channel
	filter per table row.  Each filter output is re-clocked at every
	listed symbol rate, and each re-clocked stream is correlated against
	every listed access code.  Every correlator feeds its own
	tpms.fixed_length_frame_sink, whose "packet_source" message port is
	connected to the "packet_sink" port of the supplied packet sink.
	"""

	def __init__(self, if_sampling_rate, packet_sink):
		"""Build and connect all ASK chains.

		if_sampling_rate -- sample rate of the input stream; divided by
			each row's decimation to get the rate seen by clock recovery.
		packet_sink -- block with a "packet_sink" message input port.

		Raises RuntimeError if a table entry's access codes are not a tuple.
		"""
		super(demodulate_ask, self).__init__(
			"demodulate_ask",
			gr.io_signature(1, 1, gr.sizeof_gr_complex*1),
			gr.io_signature(0, 0, 0),
		)

		def rate_label(rate):
			# Rate in thousands with the decimal point replaced by 'k',
			# e.g. 100000.0 -> '100k000'.  Only used in debug file names.
			return ('%.3f' % (rate / 1e3)).replace('.', 'k')

		def packet_attributes(rate, code):
			# Metadata dictionary handed to the frame sink for this chain.
			entries = (
				("modulation", pmt.intern("ask")),
				("access_code", pmt.intern(code)),
				("symbol_rate", pmt.from_long(int(round(rate)))),
			)
			result = pmt.make_dict()
			for key, value in entries:
				result = pmt.dict_add(result, pmt.intern(key), value)
			return result

		self.complex_to_mag = blocks.complex_to_mag(1)
		self.connect((self, 0), (self.complex_to_mag, 0))

		# Debug taps: set to True to write the corresponding intermediate
		# stream of every chain to a file in the working directory.
		dump_filtered = False
		dump_bits = False
		dump_correlated = False

		# Row layout:
		#   (channel filter symbol rate, decimation,
		#    ((symbol rate, (access code, ...)), ...))
		chains = (
			(4667, 40, (
				(4040, ('110011001100110011001011',)),
				(4667, ('10101010101010101010101010101001',)),
			)),
			(8400, 20, (
				(8157, ('111101010101010101010101010101011110',)),
				(8400, ('1010101010101010101010101010110',)),
			)),
			(10000, 16, (
				(10000, ('11001111101010101001',)),
			)),
		)

		for filter_rate, decimation, rate_table in chains:
			decimated_rate = float(if_sampling_rate) / decimation

			lpf = ask_channel_filter(if_sampling_rate, decimation=decimation, symbol_rate=filter_rate)
			self.connect((self.complex_to_mag, 0), (lpf, 0))

			if dump_filtered:
				lpf_name = 'lpf_ask_%s.i16' % (rate_label(decimated_rate),)
				to_short = blocks.float_to_short(1, 32767 * sqrt(2) / 2 / 2)
				self.connect((lpf, 0), (to_short, 0))
				lpf_file = blocks.file_sink(gr.sizeof_short*1, lpf_name, False)
				self.connect((to_short, 0), (lpf_file, 0))

			for symbol_rate, access_codes in rate_table:
				# A bare string here means a missing trailing comma in the table.
				if not isinstance(access_codes, tuple):
					raise RuntimeError('Access code formatting error')

				bit_clock = clock_recovery(decimated_rate, symbol_rate)
				self.connect((lpf, 0), (bit_clock, 0))

				if dump_bits:
					bits_name = 'bit_ask_%s_%s.i8' % (
						rate_label(decimated_rate),
						rate_label(symbol_rate),
					)
					bits_file = blocks.file_sink(gr.sizeof_char*1, bits_name, False)
					self.connect((bit_clock, 0), (bits_file, 0))

				for access_code in access_codes:
					matcher = digital.correlate_access_code_bb(access_code, 0)
					self.connect((bit_clock, 0), (matcher, 0))

					# 256 is passed as the frame sink's first argument
					# (presumably the frame length -- TODO confirm).
					framer = tpms.fixed_length_frame_sink(
						256,
						packet_attributes(symbol_rate, access_code),
					)
					self.connect((matcher, 0), (framer, 0))

					self.msg_connect(framer, "packet_source", packet_sink, "packet_sink")

					if dump_correlated:
						cor_name = 'cor_ask_%s_%s_%s.i8' % (
							rate_label(decimated_rate),
							rate_label(symbol_rate),
							access_code,
						)
						cor_file = blocks.file_sink(gr.sizeof_char*1, cor_name, False)
						self.connect((matcher, 0), (cor_file, 0))

class demodulate_fsk(gr.hier_block2):
	"""Hierarchical block that runs a bank of FSK receive chains.

	One complex input stream, no stream outputs.  The input goes through
	fsk_center_tracking and the result is fanned out to one fsk_demodulator
	per table row.  Each demodulator output is re-clocked at every listed
	symbol rate and correlated against every listed access code.  Every
	correlator feeds its own tpms.fixed_length_frame_sink, whose
	"packet_source" message port is connected to the "packet_sink" port
	of the supplied packet sink.
	"""

	def __init__(self, if_sampling_rate, packet_sink):
		"""Build and connect all FSK chains.

		if_sampling_rate -- sample rate of the input stream; divided by
			each row's decimation to get the rate seen by clock recovery.
		packet_sink -- block with a "packet_sink" message input port.

		Raises RuntimeError if a table entry's access codes are not a tuple.
		"""
		super(demodulate_fsk, self).__init__(
			"demodulate_fsk",
			gr.io_signature(1, 1, gr.sizeof_gr_complex*1),
			gr.io_signature(0, 0, 0),
		)

		self.center_tracking = fsk_center_tracking(if_sampling_rate)
		self.connect((self, 0), (self.center_tracking, 0))

		# Row layout:
		#   (deviation, decimation, channel rate,
		#    ((symbol rate, (access code, ...)), ...))
		chains = (
			(38400, 10, 19200, (
				(19220, ('01010101010101010101010101010110',)),
			)),
			(35000, 10, 20000, (
				(19440, ('01010101010101010101010101010110',)),
			)),
			(30000, 10, 20000, (
				(19250, ('00110011001100110011001100110011010',)),
				(19950, ('110110101110001',)),
				(20480, ('110110101110001',)),
			)),
			(20000, 10, 20000, (
				(19224, ('1010101010101010101010101010100110',)),
			)),
			(40000, 10, 10000, (
				(9910, ('00111111001',)),
			)),
			(25000, 8, 20000, (
				(20040, ('010101010011110',)),
			)),
		)

		for deviation, decimation, channel_rate, rate_table in chains:
			decimated_rate = float(if_sampling_rate) / decimation

			# Second argument is a constant 0 (presumably a frequency
			# offset -- TODO confirm against tpms.fsk.fsk_demodulator).
			demod = fsk_demodulator(if_sampling_rate, 0, deviation, decimation, channel_rate)
			self.connect((self.center_tracking, 0), (demod, 0))

			for symbol_rate, access_codes in rate_table:
				# A bare string here means a missing trailing comma in the table.
				if not isinstance(access_codes, tuple):
					raise RuntimeError('Access code formatting error')

				bit_clock = clock_recovery(decimated_rate, symbol_rate)
				self.connect((demod, 0), (bit_clock, 0))

				for access_code in access_codes:
					matcher = digital.correlate_access_code_bb(access_code, 0)
					self.connect((bit_clock, 0), (matcher, 0))

					# Metadata dictionary handed to the frame sink for this chain.
					entries = (
						("modulation", pmt.intern("fsk")),
						("deviation", pmt.from_long(int(round(deviation)))),
						("access_code", pmt.intern(access_code)),
						("symbol_rate", pmt.from_long(int(round(symbol_rate)))),
					)
					attributes = pmt.make_dict()
					for key, value in entries:
						attributes = pmt.dict_add(attributes, pmt.intern(key), value)

					# 256 is passed as the frame sink's first argument
					# (presumably the frame length -- TODO confirm).
					framer = tpms.fixed_length_frame_sink(256, attributes)
					self.connect((matcher, 0), (framer, 0))

					self.msg_connect(framer, "packet_source", packet_sink, "packet_sink")

class top_block(gr.top_block):
	"""Top-level flowgraph: one sample source feeding both demodulators.

	The selected source is connected to demodulate_ask and demodulate_fsk,
	which share a single packet_sink.  With args.bursts set, the source is
	also routed through tpms.burst_detector into a tagged file sink.
	"""

	def __init__(self, source, args):
		"""Create and connect the flowgraph.

		source -- 'hackrf', 'rtlsdr' or 'file'.
		args -- parsed command-line namespace; reads tuned_frequency,
			if_rate, file, bursts and raw.

		Raises RuntimeError for any other ``source`` value.
		"""
		super(top_block, self).__init__(
			"top_block"
		)

		if source in ('hackrf', 'rtlsdr'):
			# Both live sources take the same (frequency, rate) arguments.
			make_source = source_hackrf if source == 'hackrf' else source_rtlsdr
			self.source = make_source(args.tuned_frequency, args.if_rate)
		elif source == 'file':
			self.source = source_file(args.file)
		else:
			raise RuntimeError('No source specified')

		if args.bursts:
			self.burst_detector = tpms.burst_detector()
			self.burst_file_sink = blocks.tagged_file_sink(gr.sizeof_gr_complex*1, args.if_rate)

			self.connect((self.source, 0), (self.burst_detector, 0))
			self.connect((self.burst_detector, 0), (self.burst_file_sink, 0))

		# One sink instance receives packets from both demodulators.
		self.packet_sink = packet_sink(output_raw=args.raw)

		self.demodulate_ask = demodulate_ask(args.if_rate, self.packet_sink)
		self.connect((self.source, 0), (self.demodulate_ask, 0))

		self.demodulate_fsk = demodulate_fsk(args.if_rate, self.packet_sink)
		self.connect((self.source, 0), (self.demodulate_fsk, 0))

def main():
	"""Parse the command line, build the flowgraph and run it until done.

	A file given with -f/--file takes precedence over -s/--source.  The
	IF sampling rate (-i/--if-rate) is mandatory; when it is missing the
	parser prints a usage error and exits with status 2.  Ctrl-C stops
	the flowgraph and waits for it to finish shutting down.
	"""
	parser = ArgumentParser()
	parser.add_argument('-f', '--file', type=str, default=None, help="Input file path for offline processing")
	parser.add_argument('-s', '--source', type=str, default=None, help="Source for live baseband data (hackrf, rtlsdr)")
	parser.add_argument('-b', '--bursts', action="store_true", help="Save bursts of significant energy to separate files")
	parser.add_argument('-r', '--raw', action="store_true", help="Include raw (undecoded) bits in output")
	parser.add_argument('-i', '--if-rate', type=float, default=None, help="IF sampling rate")
	parser.add_argument('-t', '--tuned-frequency', type=float, default=None, help="Receiver source center frequency")
	args = parser.parse_args()

	# The demodulators compute float(if_rate) / decimation, so a missing
	# rate would otherwise surface as an opaque TypeError in the middle of
	# flowgraph construction.  Fail early with a usage message instead.
	if args.if_rate is None:
		parser.error('an IF sampling rate is required (-i/--if-rate)')

	# An input file overrides any live source selection.
	source = 'file' if args.file else args.source

	tb = top_block(source, args)
	tb.start()

	try:
		tb.wait()
	except KeyboardInterrupt:
		# stop() only requests shutdown; the second wait() blocks until
		# the flowgraph has actually finished before main() returns.
		tb.stop()
		tb.wait()

# Run the receiver only when this file is executed as a script, so the
# classes above can be imported without starting a flowgraph.
if __name__ == '__main__':
	main()
