#!/usr/bin/env python3

import subprocess
import re

import sys

# Splits an xrandr geometry token such as "1920x1080+0+0" on its "x" and "+"
# separators.
display_coordinates_split = re.compile(r'[x+]')

# Index orders used as sort keys into a display entry, which the rest of this
# file unpacks as [width, height, x_offset, y_offset].
HORIZONTAL_FIRST = (2, 3)  # x offset first, then y offset
VERTICAL_FIRST = (3, 2)  # y offset first, then x offset


def execute(command):
    """Run *command* and return its standard output as a list of lines.

    The command string is split on whitespace (no shell, no quoting rules),
    stdout is captured and decoded as UTF-8, surrounding whitespace is
    stripped and the rest is split on newlines. A command that prints nothing
    therefore yields ``['']``. The exit status is not inspected.
    """
    argv = command.split()
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    text = completed.stdout.decode('utf8')
    return text.strip().split('\n')


def get_displays(horizontal_first_sorting=False):
    """Return the geometry of every active display listed by ``xrandr``.

    Each display is a list ``[width, height, x_offset, y_offset]`` of ints,
    parsed from the ``WIDTHxHEIGHT+X+Y`` token on a "connected" output line.

    An output that is connected but switched off has no geometry token on its
    line. Such lines are skipped; parsing a positional token there would hand
    text like ``(normal`` to ``int()`` and raise ``ValueError``.

    :param horizontal_first_sorting: when true, order displays by x offset and
        then y offset (``HORIZONTAL_FIRST``); otherwise by y offset and then
        x offset (``VERTICAL_FIRST``).
    :return: sorted list of ``[width, height, x_offset, y_offset]`` lists.
    """
    sorting = HORIZONTAL_FIRST if horizontal_first_sorting else VERTICAL_FIRST
    geometry = re.compile(r'\d+x\d+\+\d+\+\d+')
    displays = []
    for line in execute('xrandr'):
        tokens = line.split(' ')
        # 'disconnected' is a different token, so this only matches outputs
        # that xrandr reports as connected.
        if 'connected' not in tokens:
            continue
        # Locate the geometry by shape rather than by position so that the
        # optional 'primary' marker and inactive outputs need no special case.
        token = next((t for t in tokens if geometry.fullmatch(t)), None)
        if token is None:
            continue
        displays.append([int(n) for n in display_coordinates_split.split(token)])
    return sorted(displays, key=lambda d: [d[i] for i in sorting])


def get_active_window_id():
    """Return the first output line of ``xdotool getactivewindow`` as an int."""
    output_lines = execute('xdotool getactivewindow')
    return int(output_lines[0])


def get_window_state(window_id):
    """Report which ``_NET_WM_STATE`` flags are set on a window.

    Only the first output line of ``xprop -id <window_id> _NET_WM_STATE`` is
    inspected; each flag is true when its atom name occurs in that line.

    :return: dict with boolean keys ``hmaximized``, ``vmaximized`` and
        ``focused``.
    """
    first_line = execute('xprop -id %s _NET_WM_STATE' % window_id)[0]
    atoms_by_key = {
        'hmaximized': '_NET_WM_STATE_MAXIMIZED_HORZ',
        'vmaximized': '_NET_WM_STATE_MAXIMIZED_VERT',
        'focused': '_NET_WM_STATE_FOCUSED',
    }
    return {key: atom in first_line for key, atom in atoms_by_key.items()}


def set_window_state(window_id, state):
    """Re-apply the maximized flags recorded in *state* using ``wmctrl``.

    :param window_id: window identifier passed to ``wmctrl -ir``.
    :param state: mapping with ``hmaximized`` and ``vmaximized`` entries, as
        produced by ``get_window_state``. Other keys are ignored.

    No command is run when neither flag is set.
    """
    horizontal = state['hmaximized']
    vertical = state['vmaximized']
    flags = (('maximized_horz', horizontal), ('maximized_vert', vertical))
    wanted = [name for name, enabled in flags if enabled]
    if not wanted:
        return
    execute('wmctrl -ir %s -b %s' % (window_id, ','.join(['add'] + wanted)))


def unmaximize(window_id):
    """Ask ``wmctrl`` to remove both maximized flags from the window."""
    command = 'wmctrl -ir %s -b remove,maximized_vert,maximized_horz' % window_id
    execute(command)


def get_window_location(window_id):
    """Read a window's position and size from ``xwininfo -id <window_id>``.

    Six fields are picked out of the output by label; any field that never
    appears stays 0. The reported values are combined as

        (abs_x - rel_x, abs_y - rel_y, width + rel_x, height + rel_y)

    NOTE(review): subtracting the relative offsets presumably accounts for
    window-manager decorations around the client area -- confirm.

    :return: 4-tuple of ints ``(x, y, width, height)`` computed as above.
    """
    labels = (
        'Absolute upper-left X:',
        'Absolute upper-left Y:',
        'Relative upper-left X:',
        'Relative upper-left Y:',
        'Width:',
        'Height:',
    )
    found = dict.fromkeys(labels, 0)
    for line in execute('xwininfo -id %s' % window_id):
        # The first label contained in the line claims it; the value is the
        # last whitespace-separated token.
        for label in labels:
            if label in line:
                found[label] = int(line.split()[-1])
                break
    abs_x, abs_y, rel_x, rel_y, width, height = (found[label] for label in labels)
    return abs_x - rel_x, abs_y - rel_y, width + rel_x, height + rel_y


def point_seq(window_id):
    """Yield ``(x, y)`` probe points spread over a window.

    Rows run from the window's top edge towards its bottom edge in steps of
    half its height. For every row the horizontal centre is yielded first,
    then the left edge, then the right edge.

    :param window_id: window whose geometry is read via
        ``get_window_location``.
    """
    x_, y_, x_size, y_size = get_window_location(window_id)
    # A window less than two pixels tall gives int(y_size / 2) == 0, and
    # range() raises ValueError for a zero step; fall back to a step of 1.
    row_step = max(1, int(y_size / 2))
    for y in range(y_, y_ + y_size + 1, row_step):
        yield x_ + int(x_size / 2), y
        yield x_, y
        yield x_ + x_size, y


def get_display_index(point, displays):
    """Return the index of the first display that contains *point*.

    :param point: sequence whose first two items are the x and y coordinates.
    :param displays: sequence of ``[width, height, x_offset, y_offset]``
        entries.
    :return: index into *displays*, or ``None`` when no display contains the
        point. The left/top edges are inclusive, the right/bottom edges
        exclusive.
    """
    px, py = point[:2]
    return next(
        (
            index
            for index, (width, height, left, top) in enumerate(displays)
            if left <= px < left + width and top <= py < top + height
        ),
        None,
    )


def get_display_index_of_window(window_id, displays):
    """Return the index of the display the window is considered to be on.

    The window's probe points from ``point_seq`` are tried in order and the
    first one that falls on a display decides the result.

    :return: index into *displays*, or ``None`` if no probe point lies on any
        display.
    """
    hits = (get_display_index(probe, displays) for probe in point_seq(window_id))
    return next((hit for hit in hits if hit is not None), None)


def next_display_index(index, displays, step=1):
    """Return the display index *step* positions after *index*, wrapping.

    :param index: current position in *displays*.
    :param displays: sequence of displays; only its length is used.
    :param step: signed offset; negative values walk backwards.
    """
    shifted = index + step
    count = len(displays)
    return shifted % count


def new_window_location(window_id, displays, step=1):
    """Compute where the window should go on the neighbouring display.

    The window's top-left corner is expressed as a fraction of the display it
    is currently on, and that fraction is applied to the target display's
    size and offset, so the window keeps its proportional position.

    :param window_id: window to relocate.
    :param displays: list of ``[width, height, x_offset, y_offset]`` entries.
    :param step: how many displays to advance (negative to go backwards);
        the index wraps around the list.
    :return: ``(x, y)`` target coordinates. If the window is not on any known
        display its current position is returned unchanged.
    """
    wx, wy = get_window_location(window_id)[:2]
    display_index = get_display_index_of_window(window_id, displays)
    if display_index is None:
        # No probe point is on a display (e.g. the window is off-screen).
        # Indexing displays with None would raise TypeError, so keep the
        # window where it is instead.
        return wx, wy
    owidth, oheight, owoffset, ohoffset = displays[display_index]
    xrel = (wx - owoffset) / owidth
    yrel = (wy - ohoffset) / oheight
    target_index = next_display_index(display_index, displays, step)
    nwidth, nheight, nwoffset, nhoffset = displays[target_index]
    return xrel * nwidth + nwoffset, yrel * nheight + nhoffset


def move_to_next_monitor(window_id, displays, reverse=False):
    """Move a window to the next (or previous) display, keeping its state.

    The maximized flags are recorded, the window is unmaximized, moved with
    ``xdotool windowmove`` to the position from ``new_window_location``, and
    the recorded flags are then re-applied.

    :param window_id: window to move.
    :param displays: display list as returned by ``get_displays``.
    :param reverse: when true, step backwards through *displays*.
    """
    saved_state = get_window_state(window_id)
    unmaximize(window_id)
    direction = -1 if reverse else 1
    target_x, target_y = new_window_location(window_id, displays, step=direction)
    move_command = 'xdotool windowmove %s %s %s' % (window_id, target_x, target_y)
    execute(move_command)
    set_window_state(window_id, saved_state)


if __name__ == '__main__':
    # Flags are plain membership tests on argv; no option values are parsed.
    horizontal_first = '--horizontal-first' in sys.argv
    backwards = '--reverse' in sys.argv
    displays = get_displays(horizontal_first)
    # With fewer than two displays there is nowhere to move the window to.
    if len(displays) > 1:
        move_to_next_monitor(get_active_window_id(), displays, backwards)
