#!/usr/bin/python
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4

from __future__ import print_function

# core modules
import argparse
try:
    from configparser import ConfigParser
except ImportError:
    from ConfigParser import ConfigParser
import datetime
import logging
from   pprint import pprint
import shlex
import signal
import subprocess
import sys
import time

# external modules
import dateutil.parser
import dateutil.tz
import ldap
import ldap.cidict
import ldap.dn
import ldap.filter
from   ldap.ldapobject import ReconnectLDAPObject
import ldap.modlist
from   ldap.syncrepl import SyncreplConsumer
from   ldapurl import LDAPUrl
import ldif



def getArguments(argv=None):
    """
    Parse the command line.

    :param argv: argument list to parse; ``None`` (default) parses ``sys.argv``
    :return: ``argparse.Namespace`` with ``debug``, ``dryrun`` and ``configfile``
    """
    configfile = '/etc/dassldapsync.conf'
    parser = argparse.ArgumentParser(description='Synchronize the content of two LDAP servers.')
    parser.add_argument('-d', '--debug', action='store_true', help="enable debug output")
    parser.add_argument('-n', '--dry-run', action='store_true', dest='dryrun', help="dry run")
    # nargs='?' makes the positional optional; without it the default is
    # never used because argparse requires the argument.
    parser.add_argument('configfile', nargs='?', default=configfile,
                        help="Configuration file [default: {}]".format(configfile))
    return parser.parse_args(argv)


class Options(object):
    """
    Settings that control an LdapSync run.

    All defaults switch the corresponding feature off.
    """
    def __init__(self):
        # add destination entries that do not exist yet
        self.create = False
        # remove destination entries without a source counterpart
        self.delete = False
        # use StartTLS on the destination connection
        self.starttls = False
        # only report changes, do not write them (read by LdapSync)
        self.dryrun = False
        # not read by any code visible in this file
        self.updateonly = False
        # LDAP filter string; None means no explicit filter
        self.filter = None
        # attribute names to synchronize (compared lower-case); None = all
        self.attrlist = None
        # lower-case DN; entries at or below it are left untouched
        self.exclude = None
        # attribute used to detect renamed entries in the destination
        self.renameattr = None
        # command executed after an entry has been renamed
        self.renamecommand = None
        # lock accounts whose password is older than this many days (0 = off)
        self.pwd_max_days = 0

def readLDIFSource(path):
    """
    Parse the LDIF file at *path* and return its records.

    :param path: filesystem path of the LDIF file
    :return: the ``all_records`` list of an ``ldif.LDIFRecordList`` that has
             parsed the whole file
    """
    logging.getLogger().info("reading LDAP objects from file {}".format(path))
    with open(path, 'r') as ldif_input:
        records = ldif.LDIFRecordList(ldif_input)
        records.parse()
    return records.all_records

def readLdapSource(server, binddn, bindpw, basedn, filterstr, attrlist=None, starttls=False):
    """
    Read all entries below *basedn* from an LDAP server.

    Connects to ``server`` on port 389, optionally upgrades the connection
    with StartTLS, binds as ``binddn`` and runs a subtree search. The
    connection is unbound afterwards, also when an error occurs.

    :param filterstr: LDAP filter; ``None`` or '' selects every entry
    :param attrlist: attribute names to fetch (``None`` = all)
    :return: list of ``(dn, attributes)`` tuples as returned by ``search_s``
    """
    logger = logging.getLogger()
    logger.info("reading LDAP objects from server {}".format(server))
    ldapurl = LDAPUrl(hostport="{}:389".format(server))
    con = ldap.initialize(ldapurl.initializeUrl())
    try:
        if starttls:
            con.start_tls_s()
        con.simple_bind_s(binddn, bindpw)
        # None / '' is not a usable LDAP filter; default to matching everything.
        return con.search_s(basedn, ldap.SCOPE_SUBTREE,
                            filterstr or '(objectClass=*)', attrlist)
    finally:
        # Always release the connection; a failing unbind must not mask the
        # result or the original exception.
        try:
            con.unbind_s()
        except ldap.LDAPError:
            logger.debug("unbind from {} failed".format(server), exc_info=True)

class LdapSync(object):
    """
    Synchronize a list of LDAP entries into a destination LDAP server.

    ``sync(searchresult)`` takes ``(dn, attributes)`` tuples (as returned by
    ``readLdapSource`` or ``readLDIFSource``) and processes them as follows:
    - moves them from ``srcbasedn`` to below ``destbasedn``;
    - skips entries below ``options.exclude``;
    - applies ``classmap``/``attrmap``;
    - modifies existing destination entries and adds missing ones
      (``options.create``);
    - removes surplus destination entries (``options.delete``).
    With ``options.dryrun`` nothing is written. Per-action outcomes are
    collected in ``self.result``.
    """

    def __init__(self, destserver,
                 destbinddn, destbindpw,
                 srcbasedn, destbasedn, options=None):
        """
        :param destserver: host name of the destination server (port 389 is used)
        :param destbinddn: DN used to bind to the destination server
        :param destbindpw: password for ``destbinddn``
        :param srcbasedn: base DN of the source entries
        :param destbasedn: base DN the entries are placed below in the destination
        :param options: ``Options`` instance; a new default one when omitted
        """
        self.logger = logging.getLogger()

        self.destserver = destserver
        self.destbasedn = destbasedn
        self.destbinddn = destbinddn
        self.destbindpw = destbindpw
        # A default of ``Options()`` in the signature would be evaluated once
        # and shared by all instances; create a fresh one per instance instead.
        self.options = options if options is not None else Options()

        self.srcbasedn = srcbasedn

        # destination connection, opened lazily by _dest_ldap_connect()
        self.con = None

        # source attribute type -> destination attribute type (case-insensitive keys)
        self.attrmap = ldap.cidict.cidict({})
        # lower-case source objectclass -> destination objectclass
        self.classmap = {}

        # lower-case objectclasses removed from source entries before syncing
        self.junk_objectclasses = []
        # lower-case attribute types that are never written to the destination
        self.junk_attrs = ["authzto",
                           "creatorsname", "createtimestamp", "contextcsn",
                           "entrycsn", "entryuuid",
                           "memberof", "modifiersname", "modifytimestamp",
                           "pwdaccountlockedtime", "pwdchangedtime", "pwdfailuretime",
                           "structuralobjectclass"]

        self.reset_result()

    @property
    def _dryrun(self):
        """True if changes should only be reported, not written."""
        # ``dryrun`` is assigned by the caller and may be missing on an
        # Options object that was never explicitly configured.
        return getattr(self.options, 'dryrun', False)

    @staticmethod
    def _as_text(value):
        """Return ``value`` decoded as UTF-8 if it is ``bytes``, otherwise unchanged."""
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    def reset_result(self):
        """Clear the per-action lists of succeeded/failed DNs."""
        self.result = {
            'excluded': {'ok': [], 'failed': []},
            'unmodified': {'ok': [], 'failed': []},
            'add': {'ok': [], 'failed': []},
            'update': {'ok': [], 'failed': []},
            'delete': {'ok': [], 'failed': []},
        }

    def _dest_ldap_connect(self):
        """Open and bind the destination connection unless one is already open."""
        if self.con is None:
            self.logger.info("connect to destination LDAP server {}".format(self.destserver))
            ldapurl = LDAPUrl(hostport="{}:389".format(self.destserver))
            con = ldap.initialize(ldapurl.initializeUrl())
            if self.options.starttls:
                con.start_tls_s()
            con.simple_bind_s(self.destbinddn, self.destbindpw)
            # Only remember the connection once it is fully set up, so a
            # failed bind is retried on the next call.
            self.con = con

    def __adapt_dn(self, dn):
        """Return ``dn`` moved from ``srcbasedn`` to ``destbasedn``."""
        if self.srcbasedn == self.destbasedn:
            return dn
        if not dn.lower().endswith(self.srcbasedn.lower()):
            self.logger.warning("{} is not below {}, not moved".format(dn, self.srcbasedn))
            return dn
        if self.srcbasedn:
            # explicit end index: dn[:-len('')] would be dn[:0] == ''
            rpath = dn[:len(dn) - len(self.srcbasedn)]
        else:
            rpath = dn + ',' if dn else ''
        new_dn = rpath + self.destbasedn
        self.logger.debug("moved {} to {}".format(dn, new_dn))
        return new_dn

    def __is_dn_included(self, dn):
        """Return False if ``dn`` is ``options.exclude`` or lies below it, else True."""
        exclude = self.options.exclude
        if exclude is None:
            return True
        exclude = exclude.lower()
        dn = dn.lower()
        # Require an RDN boundary so e.g. "ou=xpeople" is not excluded by "people".
        return not (dn == exclude or dn.endswith(',' + exclude))

    def __adapt_source_ldap_objects(self, searchresult):
        """
        Apply the configured modifications to the source entries.

        Each ``(dn, attributes)`` pair is moved below ``destbasedn``. Entries
        below ``options.exclude`` are reported as excluded and dropped.
        Objectclasses are translated via ``classmap`` (duplicates removed) and
        attribute types via ``attrmap``.

        :return: list of ``(dn, ldap.cidict.cidict)`` tuples to write
        """
        self.logger.debug("modifying LDAP objects retrieved from source LDAP")

        update_objects = []
        for src_dn, src_attrs in searchresult:
            dn = self.__adapt_dn(src_dn)
            if not self.__is_dn_included(dn):
                self.notify_excluded(dn)
                continue

            d = ldap.cidict.cidict(src_attrs)

            new_objectclasses = []
            for oc in d["objectclass"]:
                mapped = self.classmap.get(oc.lower(), oc)
                if mapped not in new_objectclasses:
                    new_objectclasses.append(mapped)
            d["objectclass"] = new_objectclasses

            # Iterate over a snapshot: the loop body renames keys of ``d``,
            # which is not allowed while iterating the live key view.
            for a in list(d.keys()):
                if a in self.attrmap:
                    attr = self.attrmap[a].lower()
                    if attr != a.lower():
                        values = d[a]
                        del d[a]
                        d[attr] = values

            update_objects.append((dn, d))
        return update_objects

    def _get_dest_entry(self, dn, entry):
        """
        Return ``(dn, attributes)`` of the destination entry matching ``dn``.

        If ``options.renameattr`` is set and ``entry`` has that attribute, the
        destination is first searched for an entry with the same value. If that
        entry exists under another DN, it is renamed to ``dn`` (not in dry-run
        mode) and ``notify_renamed`` is called. Otherwise ``dn`` is read directly.

        :return: the DN under which the entry currently exists (``dn`` after a
                 performed rename) and its attributes (limited to ``options.attrlist``)
        :raises ldap.NO_SUCH_OBJECT: if no destination entry exists
        """
        attrlist = self.options.attrlist
        renameattr = self.options.renameattr

        if renameattr and renameattr in entry:
            value = self._as_text(entry[renameattr][0])
            # Escape the value: '*', '(' etc. would otherwise change the filter
            # and could match (and rename) an unrelated entry.
            filterstr = '({}={})'.format(renameattr, ldap.filter.escape_filter_chars(value))
            searchresult = self.con.search_s(
                self.destbasedn, ldap.SCOPE_SUBTREE, filterstr, attrlist)
            # skip search references, which carry no DN
            matches = [r for r in searchresult if r[0] is not None]
            if matches:
                existing_dn, existing_entry = matches[0]
                if existing_dn.lower() == dn.lower():
                    return (existing_dn, existing_entry)
                if not self._dryrun:
                    self.__rename_dest_entry(existing_dn, dn)
                old_value = ldap.cidict.cidict(existing_entry).get(renameattr, [None])[0]
                self.notify_renamed(existing_dn, dn, old_value,
                                    entry[renameattr][0], self.options)
                if self._dryrun:
                    return (existing_dn, existing_entry)
                return (dn, existing_entry)

        searchresult = self.con.search_s(dn, ldap.SCOPE_BASE, '(objectClass=*)', attrlist)
        return searchresult[0]

    def __rename_dest_entry(self, old_dn, new_dn):
        """Rename (and, if the parent differs, move) ``old_dn`` to ``new_dn``."""
        # The rename operation takes the new RDN, not a full DN; the parent is
        # passed as newsuperior only when the entry changes its position.
        new_parts = ldap.dn.explode_dn(new_dn)
        old_parts = ldap.dn.explode_dn(old_dn)
        newsuperior = None
        if [p.lower() for p in old_parts[1:]] != [p.lower() for p in new_parts[1:]]:
            newsuperior = ','.join(new_parts[1:])
        self.con.rename_s(old_dn, new_parts[0], newsuperior)

    def __password_policy_mods(self, dn, entry, now, max_age):
        """
        Return modifications that mirror password expiry into the destination.

        This is only active when ``options.pwd_max_days > 0`` and the source
        entry has ``pwdChangedTime``. If the password is older than
        ``max_age``, an unlocked destination entry is locked by setting
        ``pwdAccountLockedTime`` to ``000001010000Z``. Otherwise an existing
        lock is removed.

        :param dn: DN of the existing destination entry
        :param entry: source attributes
        :param now: current time (timezone aware)
        :param max_age: ``timedelta`` after which a password counts as expired
        :return: list of ``(op, attr, values)`` modifications, possibly empty
        """
        if self.options.pwd_max_days <= 0 or 'pwdChangedTime' not in entry:
            return []
        pwd_changed = self._as_text(entry['pwdChangedTime'][0])
        expired = (now - dateutil.parser.parse(pwd_changed)) > max_age

        # pwdAccountLockedTime is an operational attribute and therefore not
        # part of the regular entry; retrieve it with an extra search.
        searchresult = self.con.search_s(
            dn, ldap.SCOPE_BASE, "(objectClass=*)", attrlist=['pwdAccountLockedTime'])
        locked = 'pwdAccountLockedTime' in ldap.cidict.cidict(searchresult[0][1])

        if expired and not locked:
            self.logger.info("locking {} {}".format(dn, pwd_changed))
            return [(ldap.MOD_REPLACE, 'pwdAccountLockedTime', [b'000001010000Z'])]
        if not expired and locked:
            self.logger.info("unlocking {} {}".format(dn, 'pwdAccountLockedTime'))
            return [(ldap.MOD_DELETE, 'pwdAccountLockedTime', None)]
        return []

    def _syncLdapObject(self, srcDn, srcAttributes):
        """
        Bring one destination entry in line with ``srcAttributes``.

        Existing entries are modified. The changes are restricted to
        ``options.attrlist`` and never touch ``junk_attrs``, apart from the
        password-policy lock handling. Missing entries are created when
        ``options.create`` is set. Outcomes are reported via ``notify_*``.

        :param srcDn: DN of the entry in the destination tree
        :param srcAttributes: source attributes; the ``objectClass`` list is
                              filtered in place
        """
        now = datetime.datetime.now(dateutil.tz.gettz('UTC'))
        max_age = datetime.timedelta(days=self.options.pwd_max_days)

        srcAttributes['objectClass'] = [
            oc for oc in srcAttributes['objectClass']
            if self._as_text(oc).lower() not in self.junk_objectclasses]

        try:
            destDn, destAttributes = self._get_dest_entry(srcDn, srcAttributes)
        except ldap.NO_SUCH_OBJECT:
            self.__create_dest_entry(srcDn, srcAttributes)
            return

        mod_attrs = ldap.modlist.modifyModlist(destAttributes, srcAttributes)
        if self.options.attrlist is not None:
            wanted = set(a.lower() for a in self.options.attrlist)
            mod_attrs = [m for m in mod_attrs if m[1].lower() in wanted]
        if self.junk_attrs is not None:
            mod_attrs = [m for m in mod_attrs if m[1].lower() not in self.junk_attrs]
        # Appended after filtering: pwdAccountLockedTime is in junk_attrs and
        # would otherwise always be dropped.
        mod_attrs.extend(self.__password_policy_mods(destDn, srcAttributes, now, max_age))

        if not mod_attrs:
            self.notify_unchanged(srcDn)
            return

        self.logger.debug('mod_attrs: ' + str(mod_attrs))
        try:
            if not self._dryrun:
                self.con.modify_s(srcDn, mod_attrs)
        except Exception:
            # best effort per entry: report the failure and continue
            self.logger.exception('modify failed')
            self.notify_modified(srcDn, False)
        else:
            self.notify_modified(srcDn)

    def __create_dest_entry(self, dn, attributes):
        """Add ``dn`` to the destination if ``options.create`` is set and report the outcome."""
        if not self.options.create:
            self.logger.debug("{} does not exist in destination, not created".format(dn))
            return
        try:
            entry = ldap.modlist.addModlist(attributes, self.junk_attrs)
            if not self._dryrun:
                self.con.add_s(dn, entry)
        except ldap.LDAPError:
            # e.g. objectclass/constraint violations or a missing parent entry
            self.logger.debug("add {} failed".format(dn), exc_info=True)
            self.notify_created(dn, False)
        else:
            self.notify_created(dn)

    def __syncLdapDestination(self, update_objects):
        """Create or update every ``(dn, attributes)`` pair in the destination."""
        self.logger.debug("writing data to destination LDAP")
        for dn, entry in update_objects:
            self._syncLdapObject(dn, entry)

    def __deleteDestLdapObjects(self, update_objects):
        """
        Remove destination entries below ``destbasedn`` that match
        ``options.filter``, did not come from the source and are not excluded.

        The base entry itself is never removed.
        """
        filterstr = self.options.filter or '(objectClass=*)'
        # '1.1' requests no attributes: only the DNs are needed here.
        searchresult = self.con.search_s(
            self.destbasedn, ldap.SCOPE_SUBTREE, filterstr, ['1.1'])

        keep = set(dn.lower() for dn, _ in update_objects)
        keep.add(self.destbasedn.lower())
        morituri = [dn for dn, _ in searchresult
                    if dn is not None and dn.lower() not in keep]

        # A child's DN is always longer than its parent's DN. Deleting the
        # longest DNs first therefore removes children before their parents
        # (non-leaf entries cannot be deleted).
        for dn in sorted(morituri, key=len, reverse=True):
            if not self.__is_dn_included(dn):
                continue
            try:
                if not self._dryrun:
                    self.con.delete_s(dn)
            except ldap.LDAPError:
                self.logger.debug("delete {} failed".format(dn), exc_info=True)
                self.notify_deleted(dn, False)
            else:
                self.notify_deleted(dn)

    def sync(self, searchresult):
        """
        Synchronize the ``(dn, attributes)`` entries of ``searchresult`` into
        the destination server, then print a summary.

        An empty ``searchresult`` is logged as an error and nothing is done.
        This also prevents ``options.delete`` from removing every destination
        entry.
        """
        if not searchresult:
            self.logger.error("empty source, aborting")
            return

        self._dest_ldap_connect()
        try:
            update_objects = self.__adapt_source_ldap_objects(searchresult)
            self.__syncLdapDestination(update_objects)
            if self.options.delete:
                self.__deleteDestLdapObjects(update_objects)
        finally:
            self.con.unbind_s()
            # forget the closed connection so a later sync() reconnects
            self.con = None

        self.__log_summary(True)

    def __log_summary(self, show_failed=True, show_ok=False):
        """
        Print per-action counters.

        Succeeded DNs are listed when ``show_ok`` is set or there are at most
        three of them. Failed DNs are listed when ``show_failed`` is set or
        there are at most three.
        """
        result = self.result
        for action in result.keys():
            ok = len(result[action]['ok'])
            failed = len(result[action]['failed'])
            print("{} (ok: {}, failed: {})".format(action, ok, failed))

            if ok > 0 and (show_ok or ok <= 3):
                print("succeeded:")
                print("\n".join(result[action]['ok']))

            if failed > 0 and (show_failed or failed <= 3):
                print("failed:")
                print("\n".join(result[action]['failed']))
            print()

    def get_short_dn(self, dn):
        """Return ``dn`` lower-cased with a trailing destination/source base DN removed."""
        short = dn.lower()
        for base in (self.destbasedn, self.srcbasedn):
            suffix = ',' + base.lower()
            if short.endswith(suffix):
                return short[:-len(suffix)]
        return short

    def notify_unchanged(self, dn):
        """Record ``dn`` as unchanged."""
        self.result['unmodified']['ok'].append(dn)

    def notify_excluded(self, dn):
        """Record ``dn`` as excluded by ``options.exclude``."""
        self.result['excluded']['ok'].append(dn)

    def notify_created(self, dn, ok=True):
        """Record (and log) the outcome of adding ``dn``."""
        if ok:
            self.logger.debug(u'{} created'.format(self.get_short_dn(dn)))
            self.result['add']['ok'].append(dn)
        else:
            self.logger.warning(u"failed to add {}".format(dn))
            self.result['add']['failed'].append(dn)

    def notify_modified(self, dn, ok=True):
        """Record (and log) the outcome of modifying ``dn``."""
        if ok:
            self.logger.debug(u'{} modified'.format(self.get_short_dn(dn)))
            self.result['update']['ok'].append(dn)
        else:
            self.logger.error(u"failed to modify {}".format(dn))
            self.result['update']['failed'].append(dn)

    def notify_deleted(self, dn, ok=True):
        """Record (and log) the outcome of deleting ``dn``."""
        if ok:
            self.logger.debug(u'{} deleted'.format(self.get_short_dn(dn)))
            self.result['delete']['ok'].append(dn)
        else:
            self.logger.error(u"failed to delete {}".format(dn))
            self.result['delete']['failed'].append(dn)

    def notify_renamed(self, dn, newdn, uid, newuid, options):
        """
        Report a rename and, if ``options.renamecommand`` is set, run it via
        the shell as ``<renamecommand> <dn> <newdn> <uid> <newuid>``.

        :raises subprocess.CalledProcessError: if the command exits non-zero
        """
        print(u"renamed {} -> {}".format(dn, newdn))
        if not options.renamecommand:
            return
        # The arguments come from LDAP data: quote them so they cannot inject
        # shell syntax. The configured command itself is still run by the shell.
        arguments = ' '.join(
            shlex.quote(u'' if value is None else self._as_text(value))
            for value in (dn, newdn, uid, newuid))
        subprocess.check_call(
            "{} {}".format(options.renamecommand, arguments), shell=True)


class SyncReplConsumer(ReconnectLDAPObject, SyncreplConsumer):
    """
    Syncrepl consumer connection to the source LDAP server.

    Keeps track of the entryUUIDs it has seen. Once a sync cookie is known,
    every added or modified entry is passed to
    ``syncrepl_entry_callback(dn, attributes)``. Deletions are only logged.
    """

    def __init__(self, dest, syncrepl_entry_callback, *args, **kwargs):
        """
        :param dest: destination connection (stored as ``dest_ldap``)
        :param syncrepl_entry_callback: callable ``(dn, attributes)`` invoked for changes
        Remaining arguments are passed to ``ReconnectLDAPObject``.
        """
        self.logger = logging.getLogger()
        # Initialise the LDAP connection first
        ReconnectLDAPObject.__init__(self, *args, **kwargs)
        # entryUUID -> DN of every entry known to this consumer
        self.__entries = dict()
        # entryUUIDs seen (present or sent) during the current present phase
        self.__presentUUIDs = set()
        self.cookie = None
        self.dest_ldap = dest
        self.syncrepl_entry_callback = syncrepl_entry_callback

    def syncrepl_get_cookie(self):
        """Return the current sync cookie (``None`` before the first one)."""
        return self.cookie

    def syncrepl_set_cookie(self, cookie):
        """Store the sync cookie sent by the server."""
        self.cookie = cookie

    def syncrepl_entry(self, dn, attributes, uuid):
        """Record an added/modified entry and forward it once a cookie is known."""
        change_type = 'modify' if uuid in self.__entries else 'add'
        self.__entries[uuid] = dn
        # an entry that was sent in full also counts as present
        self.__presentUUIDs.add(uuid)
        self.logger.debug('{}: {} ({})'.format(dn, change_type, ",".join(attributes.keys())))
        # With a cookie this is not the initial content but a change.
        if self.cookie is not None:
            self.syncrepl_entry_callback(dn, attributes)

    def syncrepl_delete(self, uuids):
        """Forget (and log) the known entries among ``uuids``; unknown UUIDs are ignored."""
        for uuid in uuids:
            if uuid not in self.__entries:
                continue
            self.logger.debug('detected deletion of entry %s (%s)', uuid, self.__entries.pop(uuid))

    def syncrepl_present(self, uuids, refreshDeletes=False):
        """
        Collect present UUIDs; ``uuids is None`` marks the end of the present phase.

        At the end of the phase, when ``refreshDeletes`` is false, known
        entries that were not reported present are treated as deleted. When
        it is true, deletions arrive via ``syncrepl_delete`` instead.
        """
        if uuids is None:
            if not refreshDeletes:
                deleted = [uuid for uuid in self.__entries if uuid not in self.__presentUUIDs]
                self.syncrepl_delete(deleted)
            # Phase is now completed, reset the set
            self.__presentUUIDs = set()
        else:
            self.logger.debug('uuids: {}'.format(','.join(uuids)))
            self.__presentUUIDs.update(uuids)

    def syncrepl_refreshdone(self):
        """Log the end of the refresh phase."""
        self.logger.info('Initial synchronization is now done, persist phase begins')



class LdapSyncRepl(LdapSync):
    """
    Continuous synchronization driven by a syncrepl refreshAndPersist
    search on the source server.

    Changed entries reported by the source go through the same DN move,
    exclusion and objectclass/attribute mapping as in ``LdapSync.sync`` and
    are then written to the destination. ``shutdown`` (also installed as the
    SIGTERM handler) stops the loop.
    """

    def __init__(self, destsrv,
                 destadmindn, destadminpw,
                 basedn, destbasedn,
                 options=None, source_ldap_url_obj=None):
        """
        :param source_ldap_url_obj: ``LDAPUrl`` of the source server, with
            optional bind credentials (``who``/``cred``), base, scope,
            filter and attributes
        The other parameters are as for ``LdapSync``.
        """
        # fresh Options per instance instead of one shared default object
        if options is None:
            options = Options()
        self.watcher_running = False
        self.source_ldap_url_obj = source_ldap_url_obj
        self.ldap_credentials = False
        self.source_ldap_connection = None
        super(LdapSyncRepl, self).__init__(destsrv,
                                           destadmindn, destadminpw,
                                           basedn, destbasedn, options)
        # Installed after the base class has set self.logger, which shutdown() uses.
        signal.signal(signal.SIGTERM, self.shutdown)

    def sync(self):
        """
        Run the syncrepl loop until ``shutdown`` is called.

        The source connection is re-established whenever a session ends. Bind
        failures because the server is down, and errors while polling, are
        followed by a 5 s pause. The process exits with status 1 if the
        source rejects the bind credentials.
        """
        self._dest_ldap_connect()
        self.watcher_running = True
        try:
            while self.watcher_running:
                self.logger.info('Connecting to source LDAP server')
                # Prepare the LDAP server connection (triggers the connection as well)
                self.source_ldap_connection = SyncReplConsumer(
                    self.con,
                    self.perform_application_sync_callback,
                    self.source_ldap_url_obj.initializeUrl())
                try:
                    self._run_session(self.source_ldap_connection)
                finally:
                    # do not leak one source connection per retry
                    self._unbind_quietly(self.source_ldap_connection)
                    self.source_ldap_connection = None
        finally:
            self._unbind_quietly(self.con)
            self.con = None

    def _run_session(self, connection):
        """Bind ``connection`` if credentials are configured, then run one syncrepl search."""
        url = self.source_ldap_url_obj
        if url.who and url.cred:
            self.ldap_credentials = True
            try:
                connection.simple_bind_s(url.who, url.cred)
            except ldap.INVALID_CREDENTIALS as e:
                self.logger.error('Login to LDAP server failed: {}'.format(e))
                sys.exit(1)
            except ldap.SERVER_DOWN:
                self.logger.warning('LDAP server is down, going to retry.')
                time.sleep(5)
                return

        self.logger.info('Starting sync process')
        try:
            ldap_search = connection.syncrepl_search(
                url.dn or '',
                url.scope or ldap.SCOPE_SUBTREE,
                mode='refreshAndPersist',
                attrlist=url.attrs,
                filterstr=url.filterstr or '(objectClass=*)'
            )
            while connection.syncrepl_poll(all=1, msgid=ldap_search):
                print(".", end="")
        except KeyboardInterrupt:
            # User asked to exit
            print("aborted\n")
            self.shutdown(None, None)
        except Exception:
            # any other problem: reconnect unless we are shutting down
            if self.watcher_running:
                self.logger.exception('Encountered a problem, going to retry.')
                time.sleep(5)

    def _unbind_quietly(self, con):
        """Unbind ``con`` if it is set, logging instead of raising LDAP errors."""
        if con is None:
            return
        try:
            con.unbind_s()
        except ldap.LDAPError:
            self.logger.debug('unbind failed', exc_info=True)

    def perform_application_sync_callback(self, dn, attributes):
        """
        Write one changed source entry to the destination.

        :return: False if the destination reported NO_SUCH_OBJECT, else True
        """
        self.logger.debug('{}: src: {}'.format(dn, str(attributes)))
        try:
            # Same adaption as the batch sync (DN move, exclusion, class/attr
            # maps); the base class method is name-mangled, hence the prefix.
            for dest_dn, dest_attrs in self._LdapSync__adapt_source_ldap_objects([(dn, attributes)]):
                self._syncLdapObject(dest_dn, dest_attrs)
        except ldap.NO_SUCH_OBJECT:
            self.logger.info("SKIPPED: {} object does not exist on target".format(dn))
            return False
        return True

    def shutdown(self, signum, stack):
        """Stop the sync loop; usable as a signal handler."""
        self.logger.info('Shutting down!')
        # We are no longer running
        self.watcher_running = False

def get_ldap_url_obj(self, configsection=None):
    """
    Build the ``LDAPUrl`` describing the syncrepl source from a config section.

    The leading ``self`` parameter is unused. Both
    ``get_ldap_url_obj(section)`` and ``get_ldap_url_obj(None, section)``
    are accepted.

    Keys read: ``server``, ``baseDn``, ``bindDn``, ``bindPassword``
    (``basePassword`` is accepted as a fallback), ``filter`` and
    ``attributes`` (comma separated). The server is contacted on port 389.
    """
    if configsection is None:
        configsection = self
    attrs = None
    attributes = configsection.get('attributes')
    if attributes is not None:
        attrs = attributes.split(',')
    # Pass the components as keyword arguments instead of formatting them
    # into a URL string, where the base DN would have to be URL-escaped.
    return LDAPUrl(
        hostport='{}:389'.format(configsection.get('server')),
        dn=configsection.get('baseDn'),
        who=configsection.get('bindDn'),
        cred=configsection.get('bindPassword', configsection.get('basePassword')),
        filterstr=configsection.get('filter'),
        attrs=attrs,
    )


if __name__ == "__main__":
    logging.basicConfig(format='%(levelname)s %(module)s.%(funcName)s: %(message)s', level=logging.INFO)
    # Kept at module level: other code in this file may log through the global ``logger``.
    logger = logging.getLogger()

    args = getArguments()
    if args.debug:
        logger.setLevel(logging.DEBUG)
    conffile = args.configfile

    config = ConfigParser()
    # ConfigParser.read() silently skips files it cannot read; fail early
    # instead of with a confusing NoSectionError further down.
    if not config.read(conffile):
        logger.error("cannot read configuration file {}".format(conffile))
        sys.exit(1)

    srcfile = config.get("source", "file", fallback=None)

    basedn = config.get("source", "baseDn")
    filterstr = config.get("source", "filter", fallback=None)

    if srcfile is None:
        srv = config.get("source", "server")
        admindn = config.get("source", "bindDn")
        adminpw = config.get("source", "bindPassword")
        starttls = config.getboolean("source", "starttls", fallback=False)

    destsrv = config.get("destination", "server")
    destadmindn = config.get("destination", "bindDn")
    destadminpw = config.get("destination", "bindPassword")
    destbasedn = config.get("destination", "baseDn")
    if config.has_option("destination", "rdn"):
        logger.warning("setting rdn is currently ignored")

    options = Options()
    exclude = config.get("destination", "excludesubtree", fallback=None)
    if exclude is not None:
        options.exclude = exclude.lower()

    options.dryrun = args.dryrun
    options.create = config.getboolean("destination", "create", fallback=False)
    options.delete = config.getboolean("destination", "delete", fallback=False)
    options.starttls = config.getboolean("destination", "starttls", fallback=False)
    options.renameattr = config.get("destination", "detectRename", fallback=None)
    # NOTE(review): "detectRename" holds the attribute name (renameattr
    # above); the command is read from the assumed key "renameCommand".
    options.renamecommand = config.get("destination", "renameCommand", fallback=None)
    options.pwd_max_days = config.getint("source", "pwd_max_days", fallback=0)
    options.filter = filterstr

    # Set source.attrlist as global option.
    # If source would use less attributes than dest,
    # all attributes not retrieved from source would be deleted from dest.
    # Names are stripped and lower-cased because modifications are filtered
    # by comparing lower-cased attribute types against this list.
    attributes = config.get("source", "attributes", fallback=None)
    options.attrlist = None
    if attributes:
        options.attrlist = [a.strip().lower() for a in attributes.split(",") if a.strip()] or None

    if config.get('source', 'mode', fallback=None) == 'syncrepl':
        ldapsync = LdapSyncRepl(
            destsrv, destadmindn, destadminpw, basedn, destbasedn,
            options,
            # get_ldap_url_obj() takes an unused leading ``self`` argument
            source_ldap_url_obj=get_ldap_url_obj(None, config['source']))
        ldapsync.sync()
    else:
        if srcfile:
            objects = readLDIFSource(srcfile)
        else:
            objects = readLdapSource(srv, admindn, adminpw,
                                     basedn, filterstr, options.attrlist, starttls)

        ldapsync = LdapSync(destsrv, destadmindn, destadminpw, basedn, destbasedn, options)
        ldapsync.sync(objects)
