#!/usr/bin/env python
#
# (c) Copyright 2015 Hewlett Packard Enterprise Development LP
# (c) Copyright 2017 SUSE LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

DOCUMENTATION = '''
---
module: matchcidr
author: Joe Fegan
short_description: Find a nic matching the given search criteria
description:
    - Searches the output of "ip a" for nics matching the given search criteria
options:
    cidr:
        required: true
        description:
            - Find nics that have addresses on this cidr.
    ips:
        required: true
        description:
            - List of candidate interfaces and ips, newline separated.
'''

EXAMPLES = '''
- matchcidr: cidr="10.1.11.0/24" ips="eth0 10.1.11.56/24\neth1 192.168.24.54/16\n"
'''

from socket import inet_aton

class MatchCidr(object):
    """Find the single nic that has an address on a requested subnet.

    The module parameter ``cidr`` is the subnet to look for, written as
    ``address/prefix-length``.  The module parameter ``ips`` is a newline
    separated list of candidates, each of the form ``nic address/prefix``.
    The outcome is reported through the module's ``exit_json`` (exactly one
    nic matched) or ``fail_json`` (no match, several different nics matched,
    or the input could not be parsed).
    """
    name = 'matchcidr'

    def __init__(self, module):
        """Keep the module object and pull out the two parameters we use."""
        self.module = module
        self.ips = module.params['ips']
        self.cidr = module.params['cidr']

    def fail(self, **kwargs):
        """Report failure through the module's ``fail_json``."""
        return self.module.fail_json(**kwargs)

    def succeed(self, **kwargs):
        """Report success through the module's ``exit_json``."""
        return self.module.exit_json(**kwargs)

    def mkint(self, ipstr):
        """Convert a dotted IPv4 address string to a 32 bit integer.

        The packed address from ``inet_aton`` is viewed through a
        ``bytearray`` so that each octet is an int on both Python 2 and
        Python 3 (calling ``ord`` on the elements breaks on Python 3, where
        indexing bytes already yields ints).
        """
        octets = bytearray(inet_aton(ipstr))
        return ((octets[0] << 24) |
                (octets[1] << 16) |
                (octets[2] << 8) |
                octets[3])

    def cidrmatch(self, cidr):
        """Return True if ``cidr`` (``address/prefix``) is on our subnet.

        Relies on ``self.width``, ``self.mask`` and ``self.subnet`` having
        been set, which ``execute`` does before calling this method.
        """
        ip, width = cidr.split('/')
        # A candidate with a shorter prefix than the target is rejected.
        if int(width) < self.width:
            return False
        return (self.mkint(ip) & self.mask) == self.subnet

    def execute(self):
        """Scan the candidates and report the matching nic, or fail.

        On success the result holds ``stdout`` (the nic name), ``ip`` (the
        matching address without its prefix length), ``cidr`` (the requested
        cidr) and ``subnet`` / ``mask`` as hex strings.  The same nic listed
        more than once is accepted (the first matching address is kept);
        matches on two different nics are reported as a failure.
        """
        try:
            # convert mask etc to integers.
            subnet, width = self.cidr.split('/')
            self.width = int(width)
            shift = 32 - self.width
            self.mask = 0xffffffff >> shift << shift
            self.subnet = self.mkint(subnet) & self.mask
            # look for nics matching this.
            result = None
            for candidate in self.ips.split('\n'):
                fields = candidate.split()
                if not fields:
                    # Blank line, e.g. produced by a trailing newline.
                    continue
                nic, thiscidr = fields
                if not self.cidrmatch(thiscidr):
                    continue
                if result is not None:
                    if result['stdout'] == nic:
                        continue
                    self.fail(msg='%s: multiple nics matching %s\n%s' % (self.name, self.cidr, self.ips))
                    # Do not depend on fail_json terminating the process.
                    return
                ip = thiscidr.split('/')[0]
                result = dict(stdout=nic, ip=ip, cidr=self.cidr,
                              subnet=hex(self.subnet), mask=hex(self.mask))
            if result is not None:
                self.succeed(**result)
            else:
                self.fail(msg='%s: no nic matching %s\n%s' % (self.name, self.cidr, self.ips))
        except Exception as e:
            self.fail(msg='%s: %s %s' % (self.name, str(type(e)), str(e)))


def main():
    """Declare the module arguments, then hand over to MatchCidr.

    Both ``cidr`` and ``ips`` are required.  Returns whatever
    ``MatchCidr.execute`` returns.
    """
    spec = {
        'cidr': dict(required=True),
        'ips': dict(required=True),
    }
    module = AnsibleModule(argument_spec=spec)
    return MatchCidr(module).execute()


from ansible.module_utils.basic import *
# Only run the module when this file is executed as a script, so that
# importing it (for example from a test) does not trigger main().
if __name__ == '__main__':
    main()
