#!/usr/bin/python3
"""
LDP RAW Update Generator

Christian Giese, December 2022

Copyright (C) 2020-2026, RtBrick, Inc.
SPDX-License-Identifier: BSD-3-Clause
"""
from socket import inet_aton
import argparse
import ipaddress
import logging
import struct
import sys

# scapy is a third-party dependency; abort with a short message instead of a
# traceback when it (or its LDP contrib module) cannot be loaded.
try:
    from scapy.all import *
    # Raise the threshold of log_runtime to ERROR while the LDP contrib
    # module is imported, then set it to INFO afterwards.
    log_runtime.setLevel(logging.ERROR)
    from scapy.contrib.ldp import *
    log_runtime.setLevel(logging.INFO)
except Exception:
    # `except Exception` instead of a bare `except:` so that
    # KeyboardInterrupt and SystemExit are not swallowed here.
    print("Failed to load scapy!")
    # sys.exit() instead of exit(): the latter is only injected by the
    # `site` module and is not guaranteed to exist.
    sys.exit(1)

# ==============================================================
# PATCHES
# ==============================================================

# Monkey patch scapy LDP FecTLVField

def _monkey_i2m_fec(self, pkt, x):
    if not x:
        return b""
    if isinstance(x, bytes):
        return x
    s = b"\x01\x00"
    tmp_len = 0
    fec = b""
    for o in x:
        a = ipaddress.ip_address(o[0])
        if a.version == 4:
            fec += b"\x02\x00\x01"
            # mask length
            fec += struct.pack("!B", o[1])
            # Prefix
            l = o[1]+7>>3
            p = ipaddress.v4_int_to_packed(int(a))
            fec += p[:l]
            tmp_len += 4+l
        else:
            fec += b"\x02\x00\x02"
            # mask length
            fec += struct.pack("!B", o[1])
            # Prefix
            l = o[1]+7>>3
            p = ipaddress.v6_int_to_packed(int(a))
            fec += p[:l]
            tmp_len += 4+l
    s += struct.pack("!H", tmp_len)
    s += fec
    return s

# Install the replacement encoder on the FecTLVField class itself, so it is
# used for every instance of that field.
scapy.contrib.ldp.FecTLVField.i2m = _monkey_i2m_fec


def _monkey_i2m_address(self, pkt, x):
    if not x:
        return b""
    if isinstance(x, bytes):
        return x
    
    a = ipaddress.ip_address(x[0])
    if a.version == 4:
        tmp_len = 2 + len(x) * 4
        s = b"\x01\x01" + struct.pack("!H", tmp_len) + b"\x00\x01"
        for o in x:
            s += inet_aton(o)
    else:
        tmp_len = 2 + len(x) * 16
        s = b"\x01\x01" + struct.pack("!H", tmp_len) + b"\x00\x02"
        for o in x:
            a = ipaddress.ip_address(o)
            s += ipaddress.v6_int_to_packed(int(a))
    return s

# Install the replacement encoder on the AddressTLVField class itself, so it
# is used for every instance of that field.
scapy.contrib.ldp.AddressTLVField.i2m = _monkey_i2m_address


# ==============================================================
# DEFINITIONS
# ==============================================================

# Program description handed to argparse.ArgumentParser in main().
DESCRIPTION = """
The LDP RAW update generator is a simple 
tool to generate LDP RAW update streams 
for use with the BNG Blaster.
"""

# Maps the values accepted by --log-level to logging module levels.
LOG_LEVELS = {
    'warning': logging.WARNING,
    'info': logging.INFO,
    'debug': logging.DEBUG
}

# Inclusive label range enforced by label_type();
# 1048575 == 2**20 - 1, the largest 20-bit value.
MPLS_LABEL_MIN = 1
MPLS_LABEL_MAX = 1048575

# ==============================================================
# FUNCTIONS
# ==============================================================

def init_logging(log_level: str) -> logging.Logger:
    """Configure the root logger to write to stdout and return it.

    Every call attaches one more stream handler to the root logger, so this
    is meant to be called once per process.

    Args:
        log_level: A key of ``LOG_LEVELS`` ('warning', 'info' or 'debug').

    Returns:
        The root logger, with its level and the new handler's level set to
        the selected logging level.

    Raises:
        KeyError: If ``log_level`` is not a key of ``LOG_LEVELS``.
    """
    level = LOG_LEVELS[log_level]

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(level)
    handler.setFormatter(logging.Formatter(
        '[%(asctime)s][%(levelname)-7s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'))

    log = logging.getLogger()
    log.setLevel(level)
    log.addHandler(handler)
    return log


def label_type(label: int) -> int:
    """Argparse ``type`` callback that validates an MPLS label.

    Converts ``label`` with ``int()`` and accepts it only when it lies within
    ``MPLS_LABEL_MIN`` .. ``MPLS_LABEL_MAX`` (both inclusive).

    Returns:
        The label as an ``int``.

    Raises:
        argparse.ArgumentTypeError: If the value is outside the valid range.
        ValueError: If ``label`` cannot be converted by ``int()``.
    """
    value = int(label)
    if MPLS_LABEL_MIN <= value <= MPLS_LABEL_MAX:
        return value
    raise argparse.ArgumentTypeError("MPLS label out of range %s - %s" % (MPLS_LABEL_MIN, MPLS_LABEL_MAX))


# ==============================================================
# MAIN
# ==============================================================

def main():
    """Parse the command line and write the requested LDP PDUs to a file.

    Builds ``--address-num`` consecutive addresses starting at
    ``--address-base`` and ``--prefix-num`` consecutive prefixes starting at
    ``--prefix-base``. Labels are assigned to the prefixes by cycling through
    ``--label-num`` labels starting at ``--label-base`` (restarting early if
    ``MPLS_LABEL_MAX`` would be exceeded). One LDP PDU per address and per
    prefix is written to ``--file``; with ``--withdraw`` the withdraw
    message variants are generated instead. With ``--pcap`` the same PDUs are
    additionally written to a PCAP file.

    Exits with status 0 without touching any file when there is nothing to
    generate.
    """
    # parse arguments
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument('-l', '--lsr-id', metavar='ADDRESS', type=ipaddress.ip_address, required=True, help='LSR identifier')
    parser.add_argument('-i', '--message-id-base', metavar='N', type=int, default=1000, help='message identifier base')
    parser.add_argument('-w', '--withdraw', action="store_true", help='withdraw')
    parser.add_argument('-a', '--address-base', metavar='ADDRESS', type=ipaddress.ip_address, help='address message base')
    parser.add_argument('-A', '--address-num', metavar='N', type=int, default=0, help='address message count')
    parser.add_argument('-p', '--prefix-base', metavar='PREFIX', type=ipaddress.ip_network, help='label mapping base prefix')
    parser.add_argument('-P', '--prefix-num', metavar='N', type=int, default=0, help='label mapping prefix count')
    parser.add_argument('-m', '--label-base', metavar='LABEL', type=label_type, default=10000, help='label base')
    parser.add_argument('-M', '--label-num', metavar='N', type=int, default=1, help='label count')
    parser.add_argument('-f', '--file', type=str, default="out.ldp", help='output file')
    parser.add_argument('--append', action="store_true", help="append to file if exist")
    parser.add_argument('--pcap', metavar='FILE', type=str, help="write LDP updates to PCAP file")
    parser.add_argument('--log-level', type=str, default='info', choices=LOG_LEVELS.keys(), help='logging Level')
    args = parser.parse_args()

    # init logging
    log = init_logging(args.log_level)

    # Here we will store packets for optional PCAP output
    pcap_packets = []
    def pcap(message):
        if args.pcap:
            # Every PDU gets its own source port, starting at 10000. The
            # modulo keeps the value inside the 16-bit port range when more
            # than 55536 PDUs are generated.
            sport = 10000 + len(pcap_packets) % (65536 - 10000)
            pcap_packets.append(Ether()/IP()/TCP(sport=sport, dport=646, seq=1, flags='PA')/message)

    addresses = []
    if args.address_base is not None and args.address_num > 0:
        log.info("init %s addresses", args.address_num)

        address = args.address_base
        for index in range(args.address_num):
            # Advance only between entries: computing the successor of the
            # last address is unnecessary and raises when that address is the
            # highest one of its address family.
            if index:
                address += 1
            log.debug("add address %s", address)
            addresses.append(address)

    prefixes = []
    if args.prefix_base is not None and args.prefix_num > 0:
        log.info("init %s labeled prefixes", args.prefix_num)

        prefix = args.prefix_base
        label_index = 0
        label = args.label_base
        for index in range(args.prefix_num):
            # Same as for addresses: step to the next prefix (same length,
            # starting right after the previous broadcast address) only when
            # another prefix is actually needed.
            if index:
                prefix = ipaddress.ip_network("%s/%s" % (prefix.broadcast_address+1, prefix.prefixlen))
            log.debug("add prefix %s via label %s", prefix, label)
            prefixes.append((prefix, label))

            # Cycle through label_num labels starting at label_base; restart
            # at label_base when the count or MPLS_LABEL_MAX is exceeded.
            label_index += 1
            if label_index < args.label_num:
                label = args.label_base + label_index
                if label > MPLS_LABEL_MAX:
                    label_index = 0
                    label = args.label_base
            else:
                label_index = 0
                label = args.label_base

    if not addresses and not prefixes:
        # Nothing to generate; leave the output file untouched.
        sys.exit(0)

    if args.append:
        log.info("open file %s (append)", args.file)
        file_flags = "ab"
    else:
        log.info("open file %s (replace)", args.file)
        file_flags = "wb"

    # Pick the message classes once instead of branching per PDU.
    address_message = LDPAddressWM if args.withdraw else LDPAddress
    label_message = LDPLabelWM if args.withdraw else LDPLabelMM

    with open(args.file, file_flags) as f:

        # Message identifiers increase by one per generated PDU.
        mid = args.message_id_base

        for address in addresses:
            pdu = LDP(id=args.lsr_id)/address_message(id=mid, address=[str(address)])
            mid += 1

            pdu_bin = pdu.build()
            pcap(pdu)
            f.write(pdu_bin)

        for prefix, label in prefixes:
            fec = [(str(prefix.network_address), int(prefix.prefixlen))]
            pdu = LDP(id=args.lsr_id)/label_message(id=mid, fec=fec, label=label)
            mid += 1

            pdu_bin = pdu.build()
            pcap(pdu)
            f.write(pdu_bin)

    if args.pcap:
        log.info("create PCAP file %s", args.pcap)
        try:
            wrpcap(args.pcap, pcap_packets)
        except Exception as e:
            # Best effort: the raw output file has already been written.
            log.error("failed to create PCAP file")
            log.debug(e)

    log.info("finished")


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()