#!/usr/bin/env python

"""
    Simulate latency and packet loss using Traffic Control.
    WARNING: All traffic, including SSH, is impacted.
"""

from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import subprocess
import json
import logging
import os
import sys

# Module-level logger. No handler is attached here; main() adds either a
# file handler or a stream handler depending on --logfile.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

# Absolute path of the tc binary; run_tc() prefixes every tc command with it.
tc_path = '/usr/sbin/tc'

# SSH credentials read by run_ssh(). They start empty and are overwritten by
# main() with the 'sshUser' / 'sshKey' entries of the environment JSON file.
sshuser = ""
sshkey = ""


def parse_args():
    """Parse the command line and return the resulting namespace.

    Two positional arguments select the action and the group of target
    hosts; the remaining options locate the environment file and tune the
    emulated network conditions. ``outdir`` is converted to an absolute
    path before the namespace is returned.
    """
    parser = ArgumentParser(
        description=__doc__,
        formatter_class=ArgumentDefaultsHelpFormatter,
    )

    # Positional arguments.
    parser.add_argument(
        'action', choices=["setup", "show", "reset"], help="Action")
    parser.add_argument(
        'target_role', choices=["admin", "masters", "workers", "all"],
        help="Target host[s]")

    # File locations and interface selection.
    parser.add_argument(
        '--env-json-path', default='./environment.json',
        help="environment.json full path")
    parser.add_argument('--outdir', default='.', help='output directory')
    parser.add_argument('-l', '--logfile', help='logfile')
    parser.add_argument('--ifname', default="eth0", help="Interface name")

    # Numeric tuning knobs, registered in the order they appear in --help.
    numeric_options = (
        ('--latency', 50, float, "Latency"),
        ('--jitter', 20, float, "Jitter"),
        ('--packet-loss', 1, float, "Packet loss percentage"),
        ('--duplication', 0.1, float, "Packet duplication percentage"),
        ('--corruption', 0.1, float, "Packet corruption percentage"),
        ('--bandwidth', 1000, int, "Interface bandwidth capping in Mbps"),
    )
    for flag, default_value, value_type, help_text in numeric_options:
        parser.add_argument(
            flag, default=default_value, type=value_type, help=help_text)

    namespace = parser.parse_args()
    namespace.outdir = os.path.abspath(namespace.outdir)
    return namespace


def run_ssh(ipaddr, cmd):
    """Execute ``cmd`` on ``ipaddr`` through ssh.

    The ssh invocation is assembled as a single string and handed to the
    local shell, so ``cmd`` is interpreted by that shell before ssh sees it.
    The key file and user name come from the module-level ``sshkey`` and
    ``sshuser`` values. Host key checking is disabled and known hosts are
    not recorded.

    Nothing is returned and output is not captured; a non-zero exit status
    raises ``subprocess.CalledProcessError``.
    """
    ssh_options = " ".join((
        "-o LogLevel=error",
        "-oStrictHostKeyChecking=no",
        "-oUserKnownHostsFile=/dev/null",
    ))
    full_cmd = "/usr/bin/ssh {} -i {} {}@{} {}".format(
        ssh_options, sshkey, sshuser, ipaddr, cmd)
    log.info("Running %r", full_cmd)
    subprocess.check_call(full_cmd, shell=True)


def run_tc(nodeaddr, cmd):
    """Run the tc sub-command ``cmd`` on ``nodeaddr`` over ssh.

    ``cmd`` is everything that follows the tc binary on the command line,
    e.g. ``"qdisc show dev eth0"``. Whatever ``run_ssh`` returns is passed
    back unchanged.
    """
    remote_cmd = "{binary} {arguments}".format(binary=tc_path, arguments=cmd)
    return run_ssh(nodeaddr, remote_cmd)


def cleanup_tc(nodeaddr, ifname="eth0"):
    """Remove the root qdisc of ``ifname`` on ``nodeaddr``.

    Args:
        nodeaddr: Address of the node to reach over ssh.
        ifname: Interface whose root qdisc is deleted. Defaults to
            ``"eth0"`` so existing single-argument callers keep working.

    The delete is best-effort: if the command exits non-zero (for example
    because no root qdisc is configured) this is logged and treated as
    "nothing to clean up". Only ``subprocess.CalledProcessError`` is
    handled, so Ctrl-C and other unexpected errors still propagate.
    After a successful delete the remaining qdisc state is printed.
    """
    try:
        run_tc(nodeaddr, "qdisc del dev {} root".format(ifname))
    except subprocess.CalledProcessError as err:
        log.debug("No tc conf to be cleaned up on %s (exit status %s)",
                  ifname, err.returncode)
        return
    log.debug("Cleaned up tc conf on %s", ifname)
    # run_tc() does not capture output, so this only displays the resulting
    # state on the terminal; there is nothing to inspect programmatically.
    run_tc(nodeaddr, "qdisc show dev {}".format(ifname))


def setup_netem(nodeaddr, parent_cid, args):
    """Attach a netem qdisc below class ``parent_cid`` on ``args.ifname``.

    Args:
        nodeaddr: Address of the node to reach over ssh.
        parent_cid: tc class id (e.g. ``"1:5"``) the netem qdisc hangs off.
        args: Parsed command line; ``ifname``, ``corruption``, ``latency``,
            ``jitter``, ``duplication`` and ``packet_loss`` are read.
    """
    template = (
        "qdisc add dev {iface} parent {cid} netem"
        " corrupt {corruption}%"
        # "delay" reorders packets as well
        " delay {latency}ms {jitter}ms distribution normal"
        " duplicate {duplication}%"
        " loss random {packet_loss_perc}% 25%"
    )
    cmd = template.format(
        iface=args.ifname,
        cid=parent_cid,
        corruption=args.corruption,
        latency=args.latency,
        jitter=args.jitter,
        duplication=args.duplication,
        packet_loss_perc=args.packet_loss,
    )
    run_tc(nodeaddr, cmd)


def setup_tc(nodeaddr, args):
    """Configure rate limiting and network emulation on ``args.ifname``.

    Any existing root qdisc is removed first, then an HTB root, one rate
    limited class, a netem qdisc below that class and a catch-all IPv4
    filter are created. Every command targets ``args.ifname`` so that the
    HTB hierarchy and the netem qdisc always live on the same interface.

    Args:
        nodeaddr: Address of the node to reach over ssh.
        args: Parsed command line; ``ifname`` and ``bandwidth`` are read
            here, the netem parameters are read by ``setup_netem``.

    Raises:
        subprocess.CalledProcessError: if any of the setup commands fails.
    """
    ifname = args.ifname
    cleanup_tc(nodeaddr, ifname)

    # Setup network emulation - simple example:
    # tc qdisc add dev eth0 handle 1: root htb
    # tc class add dev eth0 parent 1: classid 1:5 htb rate 1000Mbps
    # tc qdisc add dev eth0 parent 1:5  netem delay 100ms 10ms
    # tc filter add dev eth0  protocol ip prio 1 u32 match ip src 0.0.0.0/0 flowid 1:5

    # Create HTB at the root
    run_tc(nodeaddr, "qdisc add dev {} handle 1: root htb".format(ifname))

    # Create main rate limiting class
    # NOTE(review): tc parses the "Mbps" suffix as megaBYTES per second
    # ("Mbit" is megabits per second). Confirm which unit --bandwidth is
    # meant to express before changing the suffix.
    cmd = "class add dev {} parent 1: classid 1:5 " \
          "htb rate {}Mbps".format(ifname, args.bandwidth)
    run_tc(nodeaddr, cmd)

    setup_netem(nodeaddr, "1:5", args)

    # Create filter to match all traffic
    cmd = "filter add dev {} protocol ip prio 1 u32 match ip " \
          "src 0.0.0.0/0 flowid 1:5".format(ifname)
    run_tc(nodeaddr, cmd)


def main():
    """Apply the requested action to every host matching the target role.

    Reads the ssh user and key from the environment JSON file, stores them
    in the module-level ``sshuser`` / ``sshkey`` used by ``run_ssh``, then
    walks the ``minions`` list and runs ``show``, ``setup`` or ``reset``
    against each selected host's private IPv4 address on ``--ifname``.
    """
    global sshuser, sshkey
    args = parse_args()

    if args.logfile:
        handler = logging.FileHandler(os.path.abspath(args.logfile))
    else:
        handler = logging.StreamHandler()
    log.addHandler(handler)

    with open(args.env_json_path) as f:
        env = json.load(f)
    sshuser = env['sshUser']
    sshkey = env['sshKey']

    # Roles in the environment file are singular ("master", "worker");
    # stripping the trailing "s" maps the plural CLI choices onto them and
    # leaves "admin" and "all" untouched.
    role = args.target_role.rstrip('s')
    for target_block in env["minions"]:
        if role != "all" and target_block["role"] != role:
            continue
        ipaddr = target_block["addresses"]["privateIpv4"]
        if args.action == "show":
            run_tc(ipaddr, "qdisc show dev {}".format(args.ifname))
        elif args.action == "setup":
            setup_tc(ipaddr, args)
        elif args.action == "reset":
            cleanup_tc(ipaddr, args.ifname)


# Script entry point: run main() only when executed directly, and hand its
# return value to sys.exit() as the process exit status.
if __name__ == '__main__':
    sys.exit(main())
