#!/usr/bin/env python3
# Compare SMB packets from 2 network capture files
#
# Copyright (C) 2017 Aurelien Aptel <aurelien.aptel@gmail.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import subprocess
import re
import tempfile
import curses
import os

def strip_packet_no(arg):
    """Return *arg* with its last ``:suffix`` component removed.

    Everything from the final ':' onward is dropped. When *arg* contains no
    ':' (or cannot be split at all), it is returned unchanged.
    """
    try:
        head, colon, _tail = arg.rpartition(":")
    except Exception:
        # Same catch-all fallback as before: hand back whatever we were given.
        return arg
    return head if colon else arg

def split_packet_arg(arg):
    """Split a ``CAPFILE:NO`` argument into ``(capture_path, packet_number)``.

    The text after the last ':' must consist only of digits and is returned
    as an int; everything before it (at least one character) is the path.

    Raises:
        ValueError: if *arg* does not have the ``path:number`` form.
            ValueError is a subclass of Exception, so callers that catch
            Exception keep working.
    """
    try:
        m = re.match(r'^(.+):(\d+)$', arg)
    except TypeError:
        # Not a string at all: report it as an invalid argument too.
        m = None
    if m is None:
        raise ValueError("invalid packet:no specified")
    return (m.group(1), int(m.group(2)))


def smb_summaries(pcap):
    """Return ``{frame_number: summary_text}`` for the SMB packets in *pcap*.

    Runs ``tshark`` on the capture with a display filter selecting smb/smb2
    traffic (excluding ``browser``) and parses its one-line-per-packet
    output. For each line that matches, the leading number becomes the key
    and the text following the ``SMB``/``SMB2`` token and the number after
    it becomes the value. Lines that do not match are ignored.

    Raises subprocess.CalledProcessError if tshark exits with an error.
    """
    command = ['tshark', '-r', pcap, '-Y', '!browser && (smb||smb2)']
    listing = subprocess.check_output(command).decode('utf-8')
    summary_re = re.compile(r'''\s*(\d+).+?SMB2?\s*\d+\s*(.+)''')
    hits = (summary_re.match(row) for row in listing.split('\n'))
    return {int(hit.group(1)): hit.group(2) for hit in hits if hit}

def smb_packet(pcap, no):
    """Return the verbose tshark dissection of frame *no*, cut to its SMB part.

    Args:
        pcap: path of the capture file handed to ``tshark -r``.
        no: frame number (int) to select.

    Returns:
        The ``tshark -V`` text with everything before a line starting with
        ``SMB`` removed. Because the prefix match is greedy, the cut happens
        at the last such line. If no line starts with ``SMB`` the full
        output is returned unchanged.

    Raises subprocess.CalledProcessError if tshark exits with an error.
    """
    out = subprocess.check_output(
        ['tshark', '-r', pcap, '-Y', 'frame.number == %d' % no, '-V']
    ).decode('utf-8')
    # The 4th positional parameter of re.sub is `count`, not `flags`, so the
    # former `re.S` there was really a count of 16. No flag is needed:
    # [\s\S] already spans newlines and '^' anchors at the start of the text,
    # so there is at most one replacement.
    return re.sub(r'^[\s\S]+\nSMB', 'SMB', out, count=1)

def smb_diff(a, b):
    """Return a unified diff of two dissected SMB packets.

    Args:
        a: ``(capture_path, frame_number)`` of the first packet.
        b: ``(capture_path, frame_number)`` of the second packet.

    Returns:
        The output of ``diff -u`` on the two dissections, as a str.
    """
    # Write the temp files as UTF-8 explicitly: the dissections were decoded
    # from UTF-8 and the diff output below is decoded as UTF-8 again, so the
    # locale's default encoding must not get in between.
    with tempfile.NamedTemporaryFile(mode='w+', encoding='utf-8') as fa, \
            tempfile.NamedTemporaryFile(mode='w+', encoding='utf-8') as fb:
        fa.write(smb_packet(a[0], a[1]))
        fa.flush()
        fb.write(smb_packet(b[0], b[1]))
        fb.flush()
        try:
            out = subprocess.check_output(['diff', '-u', fa.name, fb.name])
        except subprocess.CalledProcessError as e:
            # A non-zero exit status is expected whenever the inputs differ;
            # the diff text is still available on the exception.
            out = e.output
        return out.decode('utf8')

def main():
    """Non-interactive entry point: print the diff of two ``CAPFILE:NO`` packets.

    Parses the two positional command-line arguments, each of the form
    ``capture_file:packet_number``, and writes the unified diff of the two
    packets to standard output.
    """
    ap = argparse.ArgumentParser(description='compare smb packets')
    ap.add_argument('filea', metavar='CAPFILE1:NO', help='first file/packet number')
    ap.add_argument('fileb', metavar='CAPFILE2:NO', help='second file/packet number')
    args = ap.parse_args()

    diff = smb_diff(split_packet_arg(args.filea), split_packet_arg(args.fileb))
    # The result used to be computed and dropped; actually show it.
    print(diff, end='')

class BufferViewer(object):
    """Scrollable, line-oriented view of a text buffer drawn in a curses window.

    The buffer is split on newlines. ``top`` is the index of the first line
    shown and ``cursor`` the index of the selected line.

    Args:
        win: curses window to draw into.
        data: initial text.
        hl: when true, the cursor line is drawn in reverse video.
        diff: when true, lines starting with '-' / '+' are drawn with color
            pairs 1 / 2 respectively.
    """

    def __init__(self, win, data, hl=True, diff=False):
        self.win = win
        self.data = data
        self.lines = self.data.split("\n")
        self.top = 0
        self.cursor = 0
        self.hl = hl
        self.diff = diff

    def set_data(self, data):
        """Replace the text, clamping ``top`` and ``cursor`` to the new length."""
        self.data = data
        self.lines = self.data.split("\n")
        h = len(self.lines)-1
        self.top = max(0, min(h, self.top))
        self.cursor = max(0, max(self.top, min(h, self.cursor)))

    def height(self):
        """Return the current number of rows of the window."""
        return self.win.getmaxyx()[0]

    def width(self):
        """Return the current number of columns of the window."""
        return self.win.getmaxyx()[1]

    def is_cursor_at_bot(self):
        """Return True when the cursor is on (or past) the last visible row."""
        return self.cursor - self.top >= self.height()-1

    def is_cursor_at_top(self):
        """Return True when the cursor is on the first visible row."""
        return self.cursor == self.top

    def move(self, direction):
        """Move the cursor one line down (direction > 0) or up (direction < 0).

        The view scrolls by one line when the cursor would leave the visible
        area. Moves past either end of the buffer are ignored.
        """
        if direction > 0 and self.cursor < len(self.lines)-1:
            if self.is_cursor_at_bot():
                self.top += 1
            self.cursor += 1
        elif direction < 0 and self.cursor > 0:
            if self.is_cursor_at_top():
                self.top -= 1
            self.cursor -= 1

    def _keep_cursor_visible(self):
        """Scroll down just enough for the cursor row to be inside the window.

        Needed when the window got shorter since the last draw, which can
        leave the cursor below the last visible row.
        """
        rows = self.height()
        if rows > 0 and self.cursor - self.top >= rows:
            self.top = self.cursor - rows + 1

    def refresh(self):
        """Redraw the visible slice of the buffer and refresh the window."""
        self.win.clear()
        self._keep_cursor_visible()
        rows = self.height()
        # Lines are rendered one column narrower than the window
        # (presumably to stay clear of the window's last cell -- confirm).
        cols = max(0, self.width()-1)
        for nr, line in enumerate(self.lines[self.top:self.top+rows]):
            # str.replace returns a new string; the result must be kept.
            line = line.replace('\t', '    ')
            line = line[:cols].ljust(cols)

            attr = curses.A_NORMAL

            if self.diff:
                # startswith() is safe on an empty rendered line, unlike line[0].
                if line.startswith('-'):
                    attr = attr|curses.color_pair(1)
                elif line.startswith('+'):
                    attr = attr|curses.color_pair(2)

            if self.hl and self.cursor - self.top == nr:
                attr = attr|curses.A_REVERSE

            self.win.addstr(nr, 0, line, attr)

        self.win.refresh()


class TraceViewer(object):
    """Pane listing the SMB packet summaries of one capture file.

    Wraps a BufferViewer holding one summary line per packet, and keeps the
    sorted frame numbers so the cursor row can be mapped back to a frame.
    """

    def __init__(self, win, cap):
        # Start with an empty buffer; set_cap() fills it in.
        self.buf = BufferViewer(win, '')
        self.set_cap(cap)

    def set_cap(self, cap):
        """Load capture file *cap* and show its packet summaries."""
        self.cap = cap
        self.pkts = smb_summaries(cap)
        self.nos = sorted(self.pkts)
        summary_rows = [self.pkts[frame_no] for frame_no in self.nos]
        self.buf.set_data('\n'.join(summary_rows))

    def get_packet(self):
        """Return ``(capture_path, frame_number)`` for the row under the cursor."""
        row = self.buf.cursor
        return (self.cap, self.nos[row])

    def move(self, direction):
        """Forward a cursor move to the underlying buffer view."""
        self.buf.move(direction)

    def refresh(self):
        """Redraw the underlying buffer view."""
        self.buf.refresh()

def _pane_geometry(rows, cols):
    """Return ``(left_w, right_w, top_h, bottom_h)`` for a rows x cols screen.

    One column between the two top panes and one row between the top and
    bottom areas are left out of the pane sizes; the separator lines are
    drawn there.
    """
    left_w = (cols-1)//2
    right_w = cols-1-left_w
    top_h = (rows-1)//2
    bottom_h = rows-1-top_h
    return left_w, right_w, top_h, bottom_h


def curses_main(stdscr):
    """Interactive viewer: two packet lists on top, their diff below.

    Parses the two ``CAPFILE:NO`` command-line arguments (only the file part
    is used here), builds the three panes and then loops on key presses:

    * ``q``                 quit
    * ``j`` / ``k``         right list down / up
    * ``d`` / ``f``         left list down / up
    * ``v`` / ``b``         diff pane down / up
    * arrow down / up       both lists down / up
    * terminal resize       recompute the pane layout

    The diff pane is recomputed whenever the selected packet of either list
    changes.
    """
    parser = argparse.ArgumentParser(description='compare smb packets')
    parser.add_argument('filea', metavar='CAPFILE1:NO', help='first file/packet number')
    parser.add_argument('fileb', metavar='CAPFILE2:NO', help='second file/packet number')
    args = parser.parse_args()

    curses.curs_set(0)
    curses.use_default_colors()
    # Color pairs used by the diff pane: 1 for '-' lines, 2 for '+' lines.
    curses.init_pair(1, curses.COLOR_RED, -1)
    curses.init_pair(2, curses.COLOR_GREEN, -1)

    screen_w = curses.COLS
    left_w, right_w, top_h, bottom_h = _pane_geometry(curses.LINES, screen_w)

    # newwin arguments are: height, width, y, x
    left_win = curses.newwin(top_h, left_w, 0, 0)
    left = TraceViewer(left_win, strip_packet_no(args.filea))

    right_win = curses.newwin(top_h, right_w, 0, left_w+1)
    right = TraceViewer(right_win, strip_packet_no(args.fileb))

    diff_win = curses.newwin(bottom_h, screen_w, top_h+1, 0)
    diff_view = BufferViewer(diff_win, "", diff=True)

    # key -> sequence of (viewer, direction) moves, applied in order
    moves = {
        'j': ((right, +1),),
        'k': ((right, -1),),
        'd': ((left, +1),),
        'f': ((left, -1),),
        'v': ((diff_view, +1),),
        'b': ((diff_view, -1),),
        'KEY_DOWN': ((left, +1), (right, +1)),
        'KEY_UP': ((left, -1), (right, -1)),
    }

    while True:
        stdscr.refresh()
        right.refresh()
        left.refresh()
        diff_view.refresh()
        # vline/hline arguments are: y, x, ch, length
        stdscr.vline(0, left_w, 0, top_h)
        stdscr.hline(top_h, 0, 0, screen_w)

        # Remember the selection so a change can be detected after the key.
        prev_left = left.get_packet()[1]
        prev_right = right.get_packet()[1]

        key = stdscr.getkey()
        if key == 'q':
            return

        if key == 'KEY_RESIZE':
            screen_h, screen_w = stdscr.getmaxyx()
            left_w, right_w, top_h, bottom_h = _pane_geometry(screen_h, screen_w)
            left_win.resize(top_h, left_w)
            right_win.resize(top_h, right_w)
            right_win.mvwin(0, left_w+1)
            diff_win.resize(bottom_h, screen_w)
            diff_win.mvwin(top_h+1, 0)
        else:
            for viewer, step in moves.get(key, ()):
                viewer.move(step)

        selected_left = left.get_packet()
        selected_right = right.get_packet()
        if selected_left[1] != prev_left or selected_right[1] != prev_right:
            diff_view.set_data(smb_diff(selected_left, selected_right))

if __name__ == '__main__':
    # The non-interactive main() entry point is left disabled; the script
    # starts the interactive viewer through curses.wrapper instead.
    #main()
    curses.wrapper(curses_main)
