#!/usr/bin/python3.13
import sys, argparse, struct, os.path, datetime, uuid, re, pam
from time import sleep
from configparser import ConfigParser
import netifaces
from getpass import getpass, getuser
from subprocess import Popen, PIPE, DEVNULL
from base64 import b64decode, b64encode
import pwd, grp
import samba
from samba import Ldb
from samba.param import LoadParm, default_path
from dns import resolver, reversename
from samba.net import Net
from samba.credentials import Credentials
from samba.dcerpc import nbt
from shutil import which
from samba import NTSTATUSError
from samba.netcmd.domain import cmd_domain_join, cmd_domain_provision, cmd_domain_demote
from samba.netcmd.dns import cmd_add_record
from samba.netcmd.user import cmd_user_list, cmd_user_add, cmd_user_add_unix_attrs, cmd_user_delete, cmd_user_password
from samba.netcmd.group import cmd_group_list, cmd_group_add, cmd_group_add_members, cmd_group_add_unix_attrs, cmd_group_delete
from samba.getopt import SambaOptions, CredentialsOptions
from optparse import OptionParser
from samba.logger import get_samba_logger
import psutil, os, signal
import socket
from enum import Enum
import tdb, ldb
from ldb import MessageElement
from io import StringIO
from samba.auth import system_session
from tempfile import NamedTemporaryFile

# Locations of the system configuration files edited by this tool.  They are
# plain module globals so the functions below can read (and rebind) them.
nsswitch = "/etc/nsswitch.conf"
krb5_conf = "/etc/krb5.conf"
winbind_conf = "/etc/security/pam_winbind.conf"
# Samba debug level, kept as a string because it is handed to LoadParm's
# 'log level' and to `net -d`; any value other than '0' also makes the
# join/unjoin helpers echo the external commands they run.
debug_level = "0"

class Output(Enum):
    """Where a wrapped samba-tool command should send an output stream.

    STDOUT / STDERR forward to the process's real streams; any other choice
    (PIPE) makes SambaToolFactory capture the text in a StringIO instead.
    """
    STDOUT = 0  # real sys.stdout
    STDERR = 1  # real sys.stderr
    PIPE = 2    # captured in memory

class SambaToolFactory(object):
    """Instantiate a samba-tool command class so it can be run from Python.

    The command object gets its own SambaOptions / CredentialsOptions pair,
    with the module debug level applied to its LoadParm.  Output that is not
    routed to the real stdout/stderr is collected in ``outlog`` / ``errlog``.

    Args:
        scmd: samba-tool command class (e.g. cmd_user_add).
        args: parsed CLI namespace; used to build credentials when *creds*
            is not given.
        creds: ready-made Credentials object (takes precedence over *args*).
        stdout, stderr: Output members selecting where each stream goes.
    """
    def __init__(self, scmd, args=None, creds=None, stdout=Output.STDOUT, stderr=Output.STDERR):
        self.args = args

        option_parser = OptionParser()
        self.sambaopts = SambaOptions(option_parser)
        self.credopts = CredentialsOptions(option_parser)
        if creds:
            self.credopts.creds = creds
        elif self.args:
            # Credentials come from the CLI args (may prompt); do not let
            # samba's option handling prompt again.
            self.credopts.creds = get_creds(self.args)
            self.credopts.ask_for_password = False
        self.credopts.machine_pass = False
        loadparm = self.sambaopts.get_loadparm()
        loadparm.set('log level', debug_level)
        self.realm = loadparm.get('realm')

        self.outlog = StringIO()
        self.errlog = StringIO()
        err_stream = sys.stderr if stderr == Output.STDERR else self.errlog
        out_stream = sys.stdout if stdout == Output.STDOUT else self.outlog
        self.sobj = scmd(errf=err_stream, outf=out_stream)
        self.sobj.logger = get_samba_logger(name='ads',
                                            verbose=(int(debug_level) > 0),
                                            quiet=False, fmt='%(message)s')

    def set_realm(self, realm):
        """Override the realm held by the command's SambaOptions."""
        self.sambaopts._set_realm(realm)

    def run(self, *args, **kwargs):
        """Run the wrapped command with this factory's option objects.

        Returns (captured_stdout, captured_stderr) — empty strings for
        streams that were not captured — or -1 if the command raised; the
        exception is reported on stderr.
        """
        try:
            self.sobj.run(*args,
                          sambaopts=self.sambaopts,
                          credopts=self.credopts,
                          **kwargs)
        except Exception as err:
            sys.stderr.write('%s exception:\n%s\n' % (type(err).__name__, str(err)))
            return -1
        return (self.outlog.getvalue(), self.errlog.getvalue())

def clean_samba_db():
    """Delete every *.tdb and *.ldb file below Samba's lock, state, cache and private directories."""
    loadparm = LoadParm()
    search_roots = [
        loadparm.state_path('lock'),
        *(os.path.realpath(path) for path in (loadparm.state_path('.'),
                                              loadparm.cache_path('.'),
                                              loadparm.private_path('.'))),
    ]
    for root_dir in search_roots:
        for dirpath, _subdirs, filenames in os.walk(root_dir):
            for name in filenames:
                if name.endswith(('.tdb', '.ldb')):
                    os.remove(os.path.join(dirpath, name))

samba_version = None
def get_samba_version():
    """Return the leading "X.Y.Z" of samba.version, cached in ``samba_version``.

    samba.version may carry extra text after the numeric version, so only the
    first dotted triple is kept.  If no triple is found, the raw version string
    is cached and returned instead of raising IndexError.
    """
    global samba_version
    if not samba_version:
        # Raw string: '\d' in a plain literal is an invalid escape (warning on
        # current Pythons).  '\d+' in the last position keeps multi-digit patch
        # levels ("4.19.10", not "4.19.1").
        match = re.search(r'\d+\.\d+\.\d+', samba.version)
        samba_version = match.group(0) if match else samba.version
    return samba_version

default_realm = None
def get_default_realm():
    """Return the 'realm' setting of the default smb.conf, cached after the first non-empty lookup."""
    global default_realm
    if default_realm:
        return default_realm
    loadparm = LoadParm()
    loadparm.load_default()
    default_realm = loadparm.get('realm')
    return default_realm

def get_netbios_name(creds, realm):
    """Find an LDAP/DS-capable DC of *realm* via Net.finddc and return its NetBIOS domain name."""
    wanted_flags = nbt.NBT_SERVER_LDAP | nbt.NBT_SERVER_DS
    dc_info = Net(creds).finddc(domain=realm, flags=wanted_flags)
    return dc_info.domain_name

def user_list(creds):
    """Return the account names printed by `samba-tool user list` for the realm, or [] if the command failed."""
    tool = SambaToolFactory(cmd_user_list, creds=creds, stdout=Output.PIPE)
    outcome = tool.run(H='ldap://%s' % tool.realm)
    if outcome == -1:
        return []
    captured_stdout = outcome[0]
    return captured_stdout.strip().split('\n')

def group_list(creds):
    """Return the group names printed by `samba-tool group list` for the realm, or [] if the command failed."""
    tool = SambaToolFactory(cmd_group_list, creds=creds, stdout=Output.PIPE)
    outcome = tool.run(H='ldap://%s' % tool.realm)
    if outcome == -1:
        return []
    captured_stdout = outcome[0]
    return captured_stdout.strip().split('\n')

def getpwnam(user):
    """Return the colon-joined passwd entry for *user*, or '' when NSS does not know the name."""
    try:
        entry = pwd.getpwnam(user)
    except KeyError:
        return ''
    return ':'.join(map(str, entry))

def getgrnam(group):
    """Return a group(5)-style line "name:x:gid:member,member" for *group*.

    Returns '' when NSS does not know the group, like getpwnam(), so callers
    can use a plain truth test to detect a miss.  (Returning the bare name
    made a miss look like a hit and printed it as if it were an entry.)
    """
    try:
        g = grp.getgrnam(group)
    except KeyError:
        return ''
    return '%s:x:%d:%s' % (group, g.gr_gid, ','.join(g.gr_mem))

def getpwent(creds):
    """Return passwd lines for all local users, followed by one getpwnam() result per domain user.

    Domain users that NSS cannot resolve contribute '' entries.
    """
    local_entries = [':'.join(map(str, entry)) for entry in pwd.getpwall()]
    domain_entries = [getpwnam(name) for name in user_list(creds)]
    return local_entries + domain_entries

def getgrent(creds):
    """Print group lines for every local group, then for every domain group.

    Domain groups that NSS cannot resolve are skipped.  A valid entry always
    contains ':', so both an empty result and a bare group name count as a
    miss; nothing empty or partial is printed (nss_getpwent filters the same way).
    """
    for entry in grp.getgrall():
        # name:passwd:gid followed by the comma-separated member list
        print(':'.join(str(field) for field in entry[:-1]) + ':' + ','.join(entry.gr_mem))
    for name in group_list(creds):
        line = getgrnam(name)
        if line and ':' in line:
            print(line)

def get_creds(args):
    """Build Credentials from args.u (default: the current login name) and args.w (default: prompt)."""
    creds = Credentials()
    creds.parse_string(args.u if args.u else getuser())
    secret = args.w if args.w else getpass("%s's Password: " % creds.get_username())
    creds.set_password(secret)
    return creds

def ldap_posix_user(creds, user, container):
    """Build a passwd-style line from the RFC2307 attributes of the object with cn=*user*.

    The search runs below *container* in the default realm.  Missing
    attributes become empty fields.  Returns None unless exactly one object
    matches.  NOTE(review): values are formatted with '%s'; if ldb hands back
    bytes, the line will contain their b'...' repr — confirm.
    """
    posix_keys = ['uidNumber', 'gidNumber', 'gecos', 'homeDirectory', 'loginShell']
    conn = ldap_open(get_default_realm(), creds)
    matches = conn.search(container, ldb.SCOPE_SUBTREE, '(cn=%s)' % user,
                          ['sAMAccountName'] + posix_keys)
    if len(matches) != 1:
        return None
    entry = matches[0]
    fields = ['%s:x' % entry['sAMAccountName'][-1]]
    for attr in posix_keys:
        value = entry[attr][-1] if attr in entry.keys() else ''
        fields.append('%s' % (value,))
    return ':'.join(fields)

def ldap_posix_pwent(creds, container):
    """Return ldap_posix_user() for every domain user (None where there is no unique match)."""
    return [ldap_posix_user(creds, name, container) for name in user_list(creds)]

def nss_getpwnam(args):
    """Print the passwd line for args.object.

    With args.direct it is read from LDAP; otherwise via NSS, retrying as
    "name@REALM" when the plain name is unknown and not already qualified.
    """
    if args.direct:
        creds = get_creds(args)
        entry = ldap_posix_user(creds, args.object, user_container(creds))
    else:
        realm = get_default_realm()
        entry = getpwnam(args.object)
        if not entry:
            lowered = args.object.lower()
            realm_lc = realm.lower()
            qualified = lowered.endswith(realm_lc) or lowered.startswith(realm_lc)
            if not qualified:
                entry = getpwnam('%s@%s' % (args.object, realm))
    if entry:
        print(entry)

def nss_getpwuid(args):
    """Print the passwd line for the uid given in args.object.

    A numeric args.object is resolved with pwd.getpwuid.  Previously it was
    passed to getpwnam, so uids never resolved.  Non-numeric values, and uids
    that NSS does not know, still fall back to a name lookup as before.
    """
    key = args.object
    entry = ''
    if str(key).isdigit():
        try:
            entry = ':'.join(str(field) for field in pwd.getpwuid(int(key)))
        except KeyError:
            entry = ''
    if not entry:
        entry = getpwnam(key)
    if entry:
        print(entry)

def nss_getgrnam_getgrid(args):
    """Print the group line for args.object, given either a group name or a numeric gid.

    A numeric value is first resolved with grp.getgrgid.  Otherwise, or when
    that fails, the name is looked up; if the name is unknown and not already
    qualified with the default realm, "name@REALM" is tried as well.
    Misses are detected locally (an empty string means not found) rather than
    by trusting a helper's miss sentinel.
    """
    def _format(group_struct, name):
        return '%s:x:%d:%s' % (name, group_struct.gr_gid, ','.join(group_struct.gr_mem))

    def _by_name(name):
        try:
            return _format(grp.getgrnam(name), name)
        except KeyError:
            return ''

    key = args.object
    gr = ''
    if str(key).isdigit():
        try:
            found = grp.getgrgid(int(key))
            gr = _format(found, found.gr_name)
        except KeyError:
            gr = ''
    if not gr:
        gr = _by_name(key)
    if not gr:
        # The realm is only needed for the retry, so it is read lazily here;
        # an unset realm (None) simply disables the retry.
        realm = get_default_realm()
        lowered = key.lower()
        if realm and not lowered.endswith(realm.lower()) and not lowered.startswith(realm.lower()):
            gr = _by_name('%s@%s' % (key, realm))
    if gr:
        print(gr)

def nss_getpwent(args):
    """Print every passwd line, from LDAP when args.direct is set, else local + NSS; blanks are skipped."""
    creds = get_creds(args)
    use_ldap = 'direct' in args and args.direct
    if use_ldap:
        entries = ldap_posix_pwent(creds, user_container(creds))
    else:
        entries = getpwent(creds)
    for entry in filter(None, entries):
        if entry.strip():
            print(entry)

def nss_getgrent(args):
    """Print all local and domain group lines, using credentials built from *args*."""
    getgrent(get_creds(args))

def create_user(args):
    """Create domain user args.object, or (args.e) add Unix attributes to an existing one.

    args.i is an optional initializer of 7 colon-separated fields laid out like
    a passwd line: uid:<ignored>:uidNumber:gidNumber:gecos:home:shell.
    Adding attributes to an existing user (args.e) needs the uid number, so it
    requires args.i.  Previously that path raised UnboundLocalError when args.i
    was missing.

    Returns -1 on invalid input, otherwise None.
    """
    unix_attrs = {}
    uid_number = None
    if args.i:
        data = args.i.split(':')
        if len(data) != 7:
            sys.stderr.write('Invalid unix attrs initializer string\n')
            return -1
        unix_attrs['uid'] = data[0]
        if args.e:
            # addunixattrs takes the uid number as a positional argument
            uid_number = data[2]
        else:
            unix_attrs['uid_number'] = data[2]
        unix_attrs['gid_number'] = data[3]
        unix_attrs['gecos'] = data[4]
        unix_attrs['unix_home'] = data[5]
        unix_attrs['login_shell'] = data[6]
    if not args.e:
        uadd = SambaToolFactory(cmd_user_add, args)
        uadd.run(args.object, H='ldap://%s' % uadd.realm, **unix_attrs)
    else:
        if uid_number is None:
            sys.stderr.write('Adding unix attrs to an existing user requires an initializer string\n')
            return -1
        uunix_attrs = SambaToolFactory(cmd_user_add_unix_attrs, args)
        uunix_attrs.run(args.object, uid_number, H='ldap://%s' % uunix_attrs.realm, **unix_attrs)

def create_group(args):
    """Create domain group args.object, or (args.e) add a gid number to an existing one.

    args.i is an optional group(5)-style initializer "name:x:gid:members"; only
    the gid field is used.  A new group that gets a gid also gets the realm as
    its NIS domain.  Adding attributes to an existing group (args.e) requires
    args.i; previously that path silently did nothing without it.

    Returns -1 on invalid input, otherwise None.
    """
    unix_attrs = {}
    if args.i:
        data = args.i.split(':')
        if len(data) != 4:
            sys.stderr.write('Invalid unix attrs initializer string\n')
            return -1
        unix_attrs['gid_number'] = data[2]
    if not args.e:
        gadd = SambaToolFactory(cmd_group_add, args)
        if 'gid_number' in unix_attrs:
            unix_attrs['nis_domain'] = gadd.realm
        gadd.run(args.object, H='ldap://%s' % gadd.realm, **unix_attrs)
    elif 'gid_number' in unix_attrs:
        # Only the gid is passed here.  The nis_domain that was computed in
        # this branch before was never passed to the command, so that dead
        # code is gone.
        gunix_attrs = SambaToolFactory(cmd_group_add_unix_attrs, args)
        gunix_attrs.run(args.object, unix_attrs['gid_number'], H='ldap://%s' % gunix_attrs.realm)
    else:
        sys.stderr.write('Adding unix attrs to an existing group requires an initializer string\n')
        return -1

def delete_user(args):
    """Delete domain user args.object over LDAP, using credentials from *args*."""
    tool = SambaToolFactory(cmd_user_delete, args)
    ldap_url = 'ldap://%s' % tool.realm
    tool.run(args.object, H=ldap_url)

def delete_group(args):
    """Delete domain group args.object over LDAP, using credentials from *args*."""
    tool = SambaToolFactory(cmd_group_delete, args)
    ldap_url = 'ldap://%s' % tool.realm
    tool.run(args.object, H=ldap_url)

def user_checklogin(args):
    """Authenticate args.user through PAM with a prompted password; print the outcome and return the PAM code."""
    session = pam.pam()
    session.authenticate(args.user, getpass())
    succeeded = session.code == 0
    print('Authentication succeeded' if succeeded
          else 'Authentication failed, reason: %s' % session.reason)
    return session.code

def passwd(args):
    """Set a password for args.object, or change the caller's own when args.object is empty.

    With args.r a random password is generated.  Otherwise the user is asked
    twice until two matching non-empty entries are given.  With args.o the
    new password is printed.
    """
    new_password = samba.generate_random_password(8, 20) if args.r else None
    while new_password is None or new_password == '':
        new_password = getpass("New Password: ")
        confirmation = getpass("Retype Password: ")
        if new_password != confirmation:
            new_password = None
            print("Sorry, passwords do not match.")
    net_ctx = Net(get_creds(args))
    if args.object:
        net_ctx.set_password(args.object, get_default_realm(), new_password)
    else:
        net_ctx.change_password(new_password)
    if args.o:
        print(new_password)

# http://stackoverflow.com/questions/33188413/python-code-to-convert-from-objectsid-to-sid-representation
def convert_objsid_to_sidstr(binary):
    """Convert a binary objectSid to its "S-<rev>-<authority>-<sub>-..." string.

    Binary layout: byte 0 is the revision, byte 1 the sub-authority count N,
    bytes 2-7 the 48-bit identifier authority (big-endian), followed by N
    little-endian 32-bit sub-authorities.

    Works on bytes-like input under Python 3.  The previous version indexed
    bytes as 1-char strings and used xrange, so it failed immediately.

    Raises:
        ValueError: the buffer is truncated, has trailing data, or its
            revision is not 1 (no other revision is defined).
    """
    data = bytes(binary)
    if len(data) < 8:
        raise ValueError('objectSid too short: %d bytes' % len(data))
    revision, count = data[0], data[1]
    if revision != 1:
        raise ValueError('unsupported SID revision %d' % revision)
    authority = int.from_bytes(data[2:8], 'big')
    sub_authorities = data[8:]
    if len(sub_authorities) != 4 * count:
        raise ValueError('objectSid length mismatch: expected %d sub-authority bytes, got %d'
                         % (4 * count, len(sub_authorities)))
    parts = ['S', str(revision), str(authority)]
    parts.extend(str(value) for value in struct.unpack('<%dL' % count, sub_authorities))
    return '-'.join(parts)

def realm_to_dn(realm):
    """Turn a dotted DNS realm such as "example.com" into the base DN "dc=example,dc=com"."""
    return ','.join(map('dc={}'.format, realm.split('.')))

ldap_open_connections = {}
def ldap_open(realm, creds):
    """Return an Ldb connection to ldap://<realm> bound with *creds*.

    Connections are cached in ldap_open_connections, keyed by
    "<realm>:<username>", so repeated calls reuse one connection.
    """
    cache_key = '%s:%s' % (realm, creds.get_username())
    if cache_key not in ldap_open_connections:
        loadparm = LoadParm()
        loadparm.load_default()
        loadparm.set('log level', debug_level)
        creds.guess(loadparm)
        ldap_open_connections[cache_key] = Ldb(url="ldap://%s" % realm,
                                               session_info=system_session(),
                                               credentials=creds, lp=loadparm)
    return ldap_open_connections[cache_key]

# Well-known GUID searched below the domain DN to locate the Users container.
wkguiduc = 'A9D1CA15768811D1ADED00C04FD8D5CD'
uc = None
def user_container(creds):
    """Return the DN of the default realm's well-known Users container (cached in ``uc``)."""
    global uc
    if not uc:
        realm = get_default_realm()
        conn = ldap_open(realm, creds)
        wkguid_base = '<WKGUID=%s,%s>' % (wkguiduc, realm_to_dn(realm))
        found = conn.search(wkguid_base, ldb.SCOPE_SUBTREE, '(objectClass=container)',
                            ['distinguishedName'])
        uc = found[0]['distinguishedName'][-1]
    return uc

def print_ldap_object(obj, sidstr):
    """Print each attribute of an LDAP result as LDIF-like "name: value" lines.

    logonHours, objectGUID and objectSid are binary and are printed base64
    encoded as "name:: <base64>".  When *sidstr* is true, objectSid is
    instead printed as a plain "objectSid: S-1-..." string; "::" is only
    correct for base64 values.  *obj* is left unmodified.  Earlier versions
    replaced its values in place and printed base64 as "b'...'".
    """
    binary_keys = ('logonHours', 'objectGUID', 'objectSid')
    for key in obj.keys():
        value = obj[key]
        if key in binary_keys:
            if key == 'objectSid' and sidstr:
                for raw in value:
                    print('%s: %s' % (key, convert_objsid_to_sidstr(raw)))
            else:
                for raw in value:
                    # b64encode returns bytes; decode so the text, not its
                    # repr, is printed.
                    print('%s:: %s' % (key, b64encode(raw).decode('ascii')))
        elif isinstance(value, MessageElement):
            for i in range(len(value)):
                print('%s: %s' % (key, value.get(i)))
        else:
            print('%s: %s' % (key, value))

def attrs(args):
    """Print args.attributes of every object with cn=args.object below args.c (default: Users container).

    Without args.object the subcommand's help is shown instead.
    """
    if not args.object:
        args.help_func()
        return
    creds = get_creds(args)
    conn = ldap_open(get_default_realm(), creds)
    container = args.c or user_container(creds)
    for entry in conn.search(container, ldb.SCOPE_SUBTREE, '(cn=%s)' % args.object, args.attributes):
        print_ldap_object(entry, args.b)
        print()

def getdn(cn, container=None, creds=None):
    """Return the DN of the first object with cn=*cn* below *container*.

    Args:
        cn: common name to search for.
        container: search base; defaults to the Users container.
        creds: Credentials to bind with.  When omitted, they are built from a
            module-level ``args``, which was the only behaviour before.
            NOTE(review): ``args`` is not a parameter and is not defined in
            the visible part of this file; confirm the entry point sets it,
            or pass *creds* explicitly.

    Raises:
        IndexError: no object matches.
    """
    if creds is None:
        creds = get_creds(args)
    l = ldap_open(get_default_realm(), creds)
    if not container:
        container = user_container(creds)
    results = l.search(container, ldb.SCOPE_SUBTREE, '(cn=%s)' % cn, ['distinguishedName'])
    return results[0]['distinguishedName'][-1]

def setattrs(args):
    """Replace attribute args.attribute of the object cn=args.object with args.value.

    The object is looked up below the Users container with the credentials
    obtained here.  Previously getdn() built a second set of credentials,
    which could prompt for the password twice.

    Returns -1 (after a message on stderr) if no such object exists, else None.
    """
    creds = get_creds(args)
    conn = ldap_open(get_default_realm(), creds)
    found = conn.search(user_container(creds), ldb.SCOPE_SUBTREE,
                        '(cn=%s)' % args.object, ['distinguishedName'])
    if len(found) == 0:
        sys.stderr.write("Object '%s' not found\n" % args.object)
        return -1
    dn = found[0].dn
    ldif = 'dn: %s\nchangetype: modify\nreplace: %s\n%s: %s' % \
            (dn, args.attribute, args.attribute, args.value)
    conn.modify_ldif(ldif)

def timesync(args):
    """Step the clock once from server args.s with ntpdate, stopping ntpd around it.

    Does nothing when args.s is empty or ntpdate is not on PATH.
    """
    ntpdate = which('ntpdate')
    if not (args.s and ntpdate):
        return
    service('ntpd', 'stop')
    Popen([ntpdate, args.s]).wait()
    service('ntpd', 'start')

def config_ntp(servers):
    """Point ntpd at *servers*, step the clock from servers[0], then restart ntpd.

    Existing "server" lines and blank lines are removed from /etc/ntp.conf,
    and one "server" line per entry of *servers* is appended.  The process
    exits if /etc/ntp.conf does not exist.  With no servers, an error is
    reported and nothing is changed.  Previously ntpd was stopped first and
    the function then crashed with IndexError.
    """
    ntp_conf = '/etc/ntp.conf'
    if not os.path.exists(ntp_conf):
        sys.stderr.write('ntp not found. Package install required.\nzypper in ntp\n')
        sys.exit(1)
    if not servers:
        sys.stderr.write('No ntp servers given\n')
        return

    # stop the ntp service
    service('ntpd', 'stop')

    with open(ntp_conf) as src:
        # throw out old server list (and blank lines)
        kept = [line for line in src
                if line.strip() and line.strip().split()[0] != 'server']
    config = ''.join(kept) + ''.join('\nserver %s\n' % server for server in servers)
    # Close the file before ntpdate/ntpd read it; previously the handle was
    # left open, so the new contents were not guaranteed to be flushed yet.
    with open(ntp_conf, 'w') as dst:
        dst.write(config)

    # tell ntp to update the time
    Popen(['/usr/sbin/ntpdate', servers[0]]).wait()

    # restart ntp
    service('ntpd', 'start')

def config_smb_conf(creds, domain, autogen=False, server=False):
    """Rewrite the default smb.conf for this host's role in *domain*.

    Member (server=False): sets ADS security, workgroup/realm derived from
    *domain*, logging, profile paths and idmap.  With *autogen*, idmap uses
    the 'rid' backend with template shell/homedir; otherwise it uses the 'ad'
    backend with RFC2307 attributes.  *creds* is used only here, to look up
    the domain's NetBIOS name.
    DC (server=True): only the winbind template settings are added.

    The result is written to a temporary file in the same directory and
    renamed over smb.conf.  The old file's mode is kept, or 0644 is used for
    a new file.  The temporary file is removed if anything fails.
    """
    def _samba_at_least(major, minor):
        # Compare numerically.  The former float(version[:3]) turned "4.19.5"
        # into 4.1, so every 4.10+ release was treated as older than 4.6.
        numbers = [int(n) for n in re.findall(r'\d+', get_samba_version())[:2]]
        return tuple(numbers) >= (major, minor)

    conf = LoadParm()
    conf.load_default()
    smb_conf = conf.configfile if conf.configfile else default_path()
    if not server:
        netbios = get_netbios_name(creds, domain)
        conf.set('security', 'ads')
        conf.set('workgroup', domain.split('.')[0].upper())
        conf.set('realm', domain.upper())
        conf.set('log file', '/var/log/samba/%m.log')
        conf.set('log level', '1')
        conf.set('passdb backend', 'tdbsam')
        conf.set('map to guest', 'Bad User')
        conf.set('logon path', '\\\\%L\\profiles\\.msprofile')
        conf.set('logon home', '\\\\%L\\%U\\.9xprofile')
        conf.set('logon drive', 'P:')
        conf.set('usershare allow guests', 'yes')
        conf.set('winbind offline logon', 'yes')
        if autogen:
            conf.set('template shell', '/bin/bash')
            conf.set('template homedir', '/home/%D/%U')
            conf.set('idmap config * : backend', 'tdb')
            conf.set('idmap config * : range', '2000-3999')
            conf.set('idmap config %s : backend' % netbios, 'rid')
            conf.set('idmap config %s : range' % netbios, '4000-99999')
            if _samba_at_least(4, 6):
                conf.set('idmap config *:unix_nss_info', 'no')
            else:
                conf.set('winbind nss info', 'template')
        else:
            conf.set('idmap config *:backend', 'autorid')
            conf.set('idmap config *:range', '2000-3999')
            conf.set('idmap config %s:backend' % netbios, 'ad')
            conf.set('idmap config %s:schema_mode' % netbios, 'rfc2307')
            conf.set('idmap config %s:range' % netbios, '4000-99999')
            if _samba_at_least(4, 6):
                conf.set('idmap config %s:unix_nss_info' % netbios, 'yes')
            else:
                conf.set('winbind nss info', 'rfc2307')
    else:
        conf.set('winbind nss info', 'template')
        conf.set('template shell', '/bin/bash')
        conf.set('template homedir', '/home/%D/%U')

    with NamedTemporaryFile(delete=False, dir=os.path.dirname(smb_conf)) as tmp:
        tmp_name = tmp.name
    try:
        conf.dump(False, tmp_name)
        if os.path.exists(smb_conf):
            os.chmod(tmp_name, os.stat(smb_conf).st_mode)
        else:
            os.chmod(tmp_name, 0o644)
        os.rename(tmp_name, smb_conf)
    except BaseException:
        # Do not leave a stray temporary file next to smb.conf.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise

def config_winbind_conf():
    """Ensure pam_winbind.conf's [global] section sets cached/krb5 logins, expiry warning and mkhomedir.

    An existing file is read first so that other settings are preserved.
    """
    parser = ConfigParser()
    if os.path.exists(winbind_conf):
        parser.read(winbind_conf)
    if not parser.has_section('global'):
        parser.add_section('global')
    for option, value in (('cached_login', 'yes'),
                          ('krb5_auth', 'yes'),
                          ('krb5_ccache_type', 'FILE'),
                          ('warn_pwd_expire', '14'),
                          ('mkhomedir', 'yes')):
        parser.set('global', option, value)
    with open(winbind_conf, 'w') as out:
        parser.write(out)

def configure_nss(args):
    """Append "winbind" to the passwd and group lines of nsswitch.conf.

    Lines that already mention winbind are left untouched.  On edited lines
    "[NOTFOUND=return]" is dropped, since it would stop the lookup before
    winbind is consulted.  Both file handles are now closed deterministically;
    previously the write handle was never closed.  *args* is unused.
    """
    try:
        with open(nsswitch) as src:
            lines = src.readlines()
        out = []
        for line in lines:
            database = line.strip().split(':')[0]
            if database in ('passwd', 'group') and 'winbind' not in line:
                out.append('%s winbind\n' % line.replace('[NOTFOUND=return]', '').rstrip())
            else:
                out.append(line)
        with open(nsswitch, 'w') as dst:
            dst.write(''.join(out))
    except IOError:
        sys.stderr.write('Configure nsswitch failed, filename \'%s\' not found\n' % nsswitch)

def unconfigure_nss(args):
    """Remove "winbind" from the passwd and group lines of nsswitch.conf.

    Other lines are copied unchanged.  Both file handles are now closed
    deterministically; previously the write handle was never closed.
    *args* is unused.
    """
    try:
        with open(nsswitch) as src:
            lines = src.readlines()
        out = []
        for line in lines:
            database = line.strip().split(':')[0]
            if database in ('passwd', 'group') and 'winbind' in line:
                out.append(line.replace('winbind', ''))
            else:
                out.append(line)
        with open(nsswitch, 'w') as dst:
            dst.write(''.join(out))
    except IOError:
        sys.stderr.write('Unconfigure nsswitch failed, filename \'%s\' not found\n' % nsswitch)

def configure_pam(args):
    """Enable pam_winbind with `pam-config --add --winbind` (root only; exits otherwise).

    Root is checked via the effective uid.  getpass.getuser() only reports
    $LOGNAME/$USER, which are wrong under su/sudo -E and can be set freely.

    Returns pam-config's exit status, or None when pam-config is not installed.
    """
    if os.geteuid() != 0:
        sys.stderr.write('ads configure must be run as root\n')
        sys.exit(1)
    pam_config = which('pam-config')
    if pam_config:
        return Popen([pam_config, '--add', '--winbind']).wait()
    return None

def unconfigure_pam(args):
    """Disable pam_winbind with `pam-config --delete --winbind` (root only; exits otherwise).

    Root is checked via the effective uid rather than getpass.getuser(),
    which only reflects $LOGNAME/$USER.

    Returns pam-config's exit status, or None when pam-config is not installed.
    """
    if os.geteuid() != 0:
        sys.stderr.write('ads configure must be run as root\n')
        sys.exit(1)
    pam_config = which('pam-config')
    if pam_config:
        return Popen([pam_config, '--delete', '--winbind']).wait()
    return None

def configure(args):
    """Set up winbind logins: PAM first, then nsswitch.conf."""
    for step in (configure_pam, configure_nss):
        step(args)

def unconfigure(args):
    """Undo configure(): remove winbind from PAM, then from nsswitch.conf."""
    for step in (unconfigure_pam, unconfigure_nss):
        step(args)

def ip_addrs():
    """Return the first IPv4 address of each network interface, excluding 127.0.0.1.

    netifaces.ifaddresses() returns a dict keyed by address family.  The old
    ``len(data) > 2`` test counted families instead of checking for AF_INET.
    It skipped interfaces that had only IPv4 plus link-layer addresses, and
    raised KeyError for interfaces with three or more non-IPv4 families.
    """
    addrs = []
    for interface in netifaces.interfaces():
        inet = netifaces.ifaddresses(interface).get(netifaces.AF_INET, [])
        if inet and 'addr' in inet[0] and inet[0]['addr'] != '127.0.0.1':
            addrs.append(inet[0]['addr'])
    return addrs

def update_hostname(hostname, domain):
    """Make this host's name fully qualified in *domain* and publish it.

    Uses the current hostname when *hostname* is empty.  The FQDN is written
    to /etc/hostname and set as the kernel hostname.  In /etc/hosts, existing
    lines mentioning the FQDN are dropped, and one "<ip>\t<fqdn> <short>" line
    is added for each address returned by ip_addrs().

    Returns (fqdn, shortname).  Neither contains a newline.  Previously the
    FQDN carried a trailing "\n", which went into sethostname(), split every
    new /etc/hosts line in two, and reached callers' DNS checks.
    """
    if not hostname:
        hostname = socket.gethostname()
    hostname = hostname.strip()
    if domain.lower() not in hostname.lower():
        shortname = hostname
        hostname = '%s.%s' % (hostname, domain)
    else:
        shortname = hostname.split('.')[0]

    # Update the /etc/hostname file
    with open('/etc/hostname', 'w') as hf:
        hf.write('%s\n' % hostname)

    # Update the in-memory hostname
    socket.sethostname(hostname)

    # Add an entry to /etc/hosts so we can resolve our own name
    hosts = '/etc/hosts'
    ips = ip_addrs()
    with open(hosts) as src:
        conf = ''.join(line for line in src if hostname not in line)
    conf = conf.strip() + '\n'
    # Only the last one is used, may need to be manually configured
    for ip in ips:
        conf += '%s\t%s %s\n' % (ip, hostname, shortname)
    with open(hosts, 'w') as dst:
        dst.write(conf)

    return (hostname, shortname)

def remove_hosts_config():
    """Remove every /etc/hosts line that mentions this machine's current hostname.

    If `hostname` prints nothing, the file is left alone.  An empty name is a
    substring of every line, so the old code would have wiped the whole file.
    """
    hostname = Popen(['hostname'], stdout=PIPE).communicate()[0].strip().decode('utf-8')
    if not hostname:
        return
    hosts = '/etc/hosts'
    with open(hosts) as src:
        kept = [line for line in src if hostname not in line]
    with open(hosts, 'w') as dst:
        dst.writelines(kept)

def krb5_basic_conf(domain):
    """Overwrite krb5.conf with a minimal [libdefaults] that finds KDCs via DNS for *domain*.

    The file is written through a context manager so it is flushed and closed
    on return.  Previously the handle was left open.
    """
    lines = [
        '[libdefaults]\n',
        '\tdns_lookup_realm = false\n',
        '\tdns_lookup_kdc = true\n',
        '\tdefault_realm = %s\n' % domain,
    ]
    with open(krb5_conf, 'w') as kof:
        kof.writelines(lines)

def config_krb5_conf(domain, server):
    """Overwrite krb5.conf with defaults for realm *domain* and a [realms] entry using *server*.

    *server* is used as both the KDC and the admin server.  The file is
    written through a context manager so it is flushed and closed before
    later steps of a join read it.  Previously the handle was left open.
    """
    lines = [
        '[libdefaults]\n',
        '\tdefault_realm = %s\n' % domain,
        '\tclockskew = 300\n',
        '\tticket_lifetime = 1d\n',
        '\tforwardable = true\n',
        '\tproxiable = true\n',
        '\tdns_lookup_realm = true\n',
        '\tdns_lookup_kdc = true\n',
        '\tudp_preference_limit = 1\n',  # disable UDP packets
        '\n\n[realms]\n',
        '\t%s = {\n\t\tkdc = %s\n\t\tadmin_server = %s\n\t\tdefault_domain = %s\n\t}\n'
        % (domain, server, server, domain),
    ]
    with open(krb5_conf, 'w') as kof:
        kof.writelines(lines)

def list_servers(domain):
    """Return the PTR (reverse DNS) name of every A record of *domain*, without the trailing dot.

    dnspython 2.x provides resolver.resolve() and deprecates resolver.query().
    The old name is only used when resolve() is not available (dnspython 1.x).
    """
    lookup = getattr(resolver, 'resolve', None) or resolver.query
    names = []
    for a_record in lookup(domain, 'A'):
        ptr_answer = lookup(reversename.from_address(a_record.address), 'PTR')
        names.append(str(ptr_answer[0])[:-1])
    return names

def service(name, command):
    """Run *command* (start/stop/enable/disable/...) on service *name* with output silenced.

    Uses systemctl when available.  Otherwise it uses the `service` wrapper,
    but never for enable/disable.  Returns the tool's exit status, or -1 when
    no suitable tool was found.
    """
    sysv_tool = which('service')
    systemd_tool = which('systemctl')
    if systemd_tool:
        argv = [systemd_tool, command, name]
    elif sysv_tool and command not in ['enable', 'disable']:
        argv = [sysv_tool, name, command]
    else:
        return -1
    return Popen(argv, stdout=DEVNULL, stderr=DEVNULL).wait()

def stash_config():
    """Move an existing krb5.conf aside to "<krb5_conf>.<timestamp>" so a fresh one can be written."""
    if not os.path.exists(krb5_conf):
        return
    stamp = datetime.datetime.now().strftime('%b-%d-%Y_%I:%M%p')
    stashed_path = '%s.%s' % (krb5_conf, stamp)
    print('Stashing krb5.conf to %s...' % stashed_path)
    os.rename(krb5_conf, stashed_path)

def join(args):
    """Join this machine to the Active Directory domain args.domain.

    Common steps:
    * stop Samba-related services;
    * write krb5.conf and ntp.conf for the chosen DC (the last entry of
      args.servers, discovered through DNS when not given);
    * make the hostname fully qualified.

    Then one of two paths runs:
    * member (default): write smb.conf and pam_winbind.conf, run
      `net ads join`, then start and enable winbind;
    * DC (args.domain_controller): wipe old Samba databases, run
      samba-tool's DC join, start and enable samba, and make sure the host's
      A record and its objectGUID CNAME under _msdcs exist.

    Finally nsswitch.conf and PAM are configured.

    Returns the non-zero `net` exit code, or -1 if the DC join fails; 1 if
    no domain controller can be found; otherwise None.  Exits the process
    when not run as root.
    """
    # Check the effective uid: getpass.getuser() only reflects $LOGNAME/$USER.
    if os.geteuid() != 0:
        sys.stderr.write('ads join must be run as root\n')
        sys.exit(1)
    creds = get_creds(args)

    for srv in ['samba', 'smbd', 'nmbd', 'winbind', 'nscd']:
        service(srv, 'stop')
    service('nscd', 'disable')

    if not args.servers:
        args.servers = list_servers(args.domain)
    if not args.servers:
        sys.stderr.write('No domain controllers found for %s\n' % args.domain)
        return 1
    server = args.servers[-1]

    stash_config()

    # configure kerberos
    print('Configuring kerberos...')
    config_krb5_conf(args.domain.upper(), server.upper())

    # Configure ntp (args.servers is non-empty here and contains server)
    print('Adding ntp servers and time syncing with AD...')
    config_ntp(args.servers)

    # Update hostname; strip so a stray newline can never reach DNS checks
    print('Updating hostname...')
    fqdn, shortname = (part.strip() for part in update_hostname(args.n, args.domain))

    if not getattr(args, 'domain_controller', False):
        # configure smb.conf
        print('Configuring smb.conf...')
        config_smb_conf(creds, args.domain, autogen=args.autogen_posix_attrs)

        # configure pam_winbind.conf
        print('Configuring pam_winbind.conf...')
        config_winbind_conf()

        # net ads join the domain
        print('Joining the domain...')
        username = creds.get_username()
        cmd = [which('net'), '--configfile=%s' % default_path(), 'ads', 'join',
               '-U%s%%%s' % (username, creds.get_password()),
               '-d', debug_level, '-S', server]
        if debug_level != '0':
            # Echo the command without the password (cmd[4] is -Uuser%pass).
            print(' '.join(cmd[:4] + ['-U%s' % username] + cmd[5:]))
        ret = Popen(cmd).wait()
        if ret != 0:
            return ret

        # start winbind
        sys.stdout.write('Starting winbind... ')
        sys.stdout.write('ok\n' if service('winbind', 'start') == 0 else 'failed\n')

        sys.stdout.write('Enabling winbind service... ')
        sys.stdout.write('ok\n' if service('winbind', 'enable') == 0 else 'failed\n')
    else:
        # Cleanup old samba database files
        clean_samba_db()

        # samba-tool domain join
        print('Joining the domain as a Domain Controller...')
        dc_join = SambaToolFactory(cmd_domain_join, args, creds)
        ret = dc_join.run(domain=args.domain, role='DC')
        if ret == -1:
            return ret

        # configure smb.conf
        print('Configuring smb.conf...')
        config_smb_conf(creds, args.domain, server=True)

        # configure pam_winbind.conf
        print('Configuring pam_winbind.conf...')
        config_winbind_conf()

        # start samba
        sys.stdout.write('Starting samba... ')
        sys.stdout.write('ok\n' if service('samba', 'start') == 0 else 'failed\n')

        sys.stdout.write('Enabling samba service... ')
        sys.stdout.write('ok\n' if service('samba', 'enable') == 0 else 'failed\n')

        # make sure the A record was added
        dns_add = SambaToolFactory(cmd_add_record, args, creds)
        for ip in ip_addrs():
            print('Verifying the DC DNS Record...')
            cmd = ['host', '-t', 'A', fqdn]
            if debug_level != '0':
                print(' '.join(cmd))
            ret = Popen(cmd, stdout=DEVNULL, stderr=DEVNULL).wait()
            if ret != 0:
                sys.stdout.write('Creating the DC DNS Record... ')
                dns_add.run(server, args.domain, shortname, 'A', ip)

        print('Verifying the objectGUID Record...')
        sleep(3) # Creating the objectGUID dies if we don't sleep a bit
        objectGUID = None
        l = ldap_open(args.domain, creds)
        results = l.search("CN=Sites,CN=Configuration,%s" % realm_to_dn(args.domain),
                           ldb.SCOPE_SUBTREE, "(invocationId=*)", ["objectguid"])
        ntds_rdn = ('CN=NTDS Settings,CN=%s,' % shortname).lower()
        for result in results:
            # Results are ldb.Message objects: the DN is result.dn and
            # attributes are looked up by name.  The old result[0]/result[1]
            # indexing assumed python-ldap style (dn, attrs) tuples.
            if ntds_rdn in str(result.dn).lower():
                # AD stores objectGUID in the little-endian Windows GUID
                # layout, so it must be decoded with bytes_le.
                objectGUID = str(uuid.UUID(bytes_le=bytes(result['objectGUID'][-1])))
                break
        if objectGUID:
            cmd = ['host', '-t', 'CNAME', '%s._msdcs.%s' % (objectGUID, args.domain)]
            if debug_level != '0':
                print(' '.join(cmd))
            ret = Popen(cmd).wait()
            if ret != 0:
                sys.stdout.write('Creating the objectGUID Record... ')
                dns_add.run(shortname, '_msdcs.%s' % args.domain,
                            objectGUID, 'CNAME', fqdn)

    # configure nss
    print('Configuring nsswitch.conf...')
    configure_nss(args)

    # configure pam
    print('Configuring pam...')
    configure_pam(args)

def provision(args):
    """Provision a new AD domain args.domain with this host as its first DC.

    Stops Samba-related services, fully qualifies the hostname, stashes
    krb5.conf, and wipes old Samba databases.  It then runs samba-tool's
    domain provision with the CLI options and writes smb.conf and
    pam_winbind.conf.  Next it symlinks Samba's generated krb5.conf into
    place and starts and enables samba.  Finally it configures nsswitch and PAM.

    Returns -1 if provisioning fails, otherwise None.  Exits the process
    when not run as root.
    """
    # Check the effective uid: getpass.getuser() only reflects $LOGNAME/$USER.
    if os.geteuid() != 0:
        sys.stderr.write('ads provision must be run as root\n')
        sys.exit(1)

    for srv in ['samba', 'smbd', 'nmbd', 'winbind', 'nscd']:
        service(srv, 'stop')
    service('nscd', 'disable')

    # Update hostname
    print('Updating hostname...')
    update_hostname(args.host_name, args.domain)

    stash_config()
    clean_samba_db()

    print('Provision the domain controller...')
    tool = SambaToolFactory(cmd_domain_provision, args)
    tool.set_realm(args.domain.upper())
    ret = tool.run(interactive=args.interactive, domain=args.domain,
                   domainguid=args.domain_guid, domainsid=args.domain_sid,
                   hostname=args.host_name,
                   hostip=args.host_ip, hostip6=args.host_ip6,
                   sitename=args.site, ntdsguid=args.ntds_guid,
                   invocationid=args.invocationid,
                   adminpass=args.adminpass,
                   krbtgtpass=args.krbtgtpass,
                   dns_backend=args.dns_backend,
                   dnspass=args.dnspass, root=args.root,
                   nobody=args.nobody, users=args.users, blank=args.blank,
                   serverrole=args.server_role,
                   function_level=args.function_level,
                   next_rid=args.next_rid,
                   partitions_only=args.partitions_only,
                   use_rfc2307=args.use_rfc2307)
    if ret == -1:
        return ret

    # configure smb.conf.  With server=True config_smb_conf never uses its
    # creds argument.  The old code passed an undefined name 'creds' here
    # and crashed with NameError.
    print('Configuring smb.conf...')
    config_smb_conf(None, args.domain, server=True)

    # configure pam_winbind.conf
    print('Configuring pam_winbind.conf...')
    config_winbind_conf()

    print('Configuring kerberos...')
    lp = LoadParm()
    lp.load_default()
    cmd = ['ln', '-sf', os.path.join(lp.get('private directory'), 'krb5.conf'), krb5_conf]
    if debug_level != '0':
        print(' '.join(cmd))
    Popen(cmd).wait()

    # start samba
    sys.stdout.write('Starting samba... ')
    sys.stdout.write('ok\n' if service('samba', 'start') == 0 else 'failed\n')

    sys.stdout.write('Enabling samba service... ')
    sys.stdout.write('ok\n' if service('samba', 'enable') == 0 else 'failed\n')

    # configure nss
    print('Configuring nsswitch.conf...')
    configure_nss(args)

    # configure pam
    print('Configuring pam...')
    configure_pam(args)

def demote(args):
    """Demote this AD DC and remove its local Samba/Kerberos configuration.

    Runs samba-tool's domain demote, unmounts the sysvol share path, stops
    and disables Samba services, and unconfigures PAM and nsswitch.  It then
    deletes Samba's databases, smb.conf and krb5.conf.  Exits the process
    when not run as root.
    """
    # Check the effective uid: getpass.getuser() only reflects $LOGNAME/$USER.
    if os.geteuid() != 0:
        sys.stderr.write('ads demote must be run as root\n')
        sys.exit(1)
    creds = get_creds(args)

    print('Demoting the ADDC...')
    # Hand over the credentials gathered above.  Passing only args made the
    # factory call get_creds() again, which could prompt for the password twice.
    tool = SambaToolFactory(cmd_domain_demote, args, creds)
    tool.run()

    lp = LoadParm()
    lp.load_default()
    print('Unmounting the sysvol...')
    Popen(['umount', lp.get('path', 'sysvol')]).wait()

    print('Disabling the samba service...')
    for srv in ['samba', 'smbd', 'nmbd', 'winbind']:
        service(srv, 'stop')
        service(srv, 'disable')

    print('Unconfiguring pam...')
    unconfigure_pam(args)

    print('Unconfiguring nss...')
    unconfigure_nss(args)

    # Cleanup old samba database files
    print('Deleting samba database files...')
    clean_samba_db()

    if os.path.exists(default_path()):
        print('Deleting smb.conf...')
        os.remove(default_path())
    if os.path.exists(krb5_conf):
        print('Deleting krb5.conf...')
        os.remove(krb5_conf)


def unjoin(args):
    """Leave the domain.

    On an AD DC (server role "active directory domain controller") this
    delegates to demote().  Otherwise it runs `net ads leave` and removes
    this host from /etc/hosts.  It then stops and disables Samba services,
    unconfigures PAM and nsswitch, and deletes Samba's databases, smb.conf
    and krb5.conf.  Exits the process when not run as root.
    """
    lp = LoadParm()
    lp.load_default()
    if lp.get('server role') == 'active directory domain controller':
        demote(args)
        return

    # Check the effective uid: getpass.getuser() only reflects $LOGNAME/$USER.
    if os.geteuid() != 0:
        sys.stderr.write('ads unjoin must be run as root\n')
        sys.exit(1)
    creds = get_creds(args)

    print('Unjoining the domain...')
    username = creds.get_username()
    cmd = [which('net'), 'ads', 'leave', '-d', debug_level,
           '-U%s%%%s' % (username, creds.get_password())]
    if debug_level != '0':
        # Echo without the password.  This used an undefined name 'admin'
        # and raised NameError whenever debugging was enabled.
        print(' '.join(cmd[:-1]), '-U%s' % username)
    Popen(cmd).wait()

    remove_hosts_config()

    for srv in ['samba', 'smbd', 'nmbd', 'winbind']:
        service(srv, 'stop')
        service(srv, 'disable')

    print('Unconfiguring pam...')
    unconfigure_pam(args)

    print('Unconfiguring nss...')
    unconfigure_nss(args)

    # Cleanup old samba database files
    clean_samba_db()

    if os.path.exists(default_path()):
        print('Deleting smb.conf...')
        os.remove(default_path())
    if os.path.exists(krb5_conf):
        print('Deleting krb5.conf...')
        os.remove(krb5_conf)

def info_domain(args):
    '''Print the default realm reported by get_default_realm().

    param args  Parsed command line arguments (unused)
    '''
    realm = get_default_realm()
    print(realm)

def info_cldap(args):
    '''Locate a DC at args.server via Net.finddc and print the reply fields.

    The lookup requests a server with the LDAP and DS flags set; each field of
    the response is printed as a tab-aligned "Label:<tabs>value" line.

    param args  Parsed command line arguments; args.server is the address
                to query
    '''
    anonymous = Credentials()
    lp = LoadParm()
    lp.load_default()

    wanted = nbt.NBT_SERVER_LDAP | nbt.NBT_SERVER_DS
    reply = Net(anonymous, lp).finddc(address=args.server, flags=wanted)

    # (output template, attribute of the finddc reply) in display order.
    layout = (
        ('DomainGuid:\t\t%s', 'domain_uuid'),
        ('DnsForestName:\t\t%s', 'forest'),
        ('DnsDomainName:\t\t%s', 'dns_domain'),
        ('DnsHostName:\t\t%s', 'pdc_dns_name'),
        ('NetbiosDomainName:\t%s', 'domain_name'),
        ('NetbiosComputerName:\t%s', 'pdc_name'),
        ('UserName:\t\t%s', 'user_name'),
        ('DcSiteName:\t\t%s', 'server_site'),
        ('ClientSiteName:\t\t%s', 'client_site'),
    )
    for template, attribute in layout:
        print(template % getattr(reply, attribute))

def kcc_cmd(name):
    '''Return the full path of executable *name* as found on PATH.

    Writes an error to stderr and exits with status 1 if it cannot be found.
    '''
    resolved = which(name)
    if resolved is not None:
        return resolved
    sys.stderr.write('%s is not in your path\n' % name)
    sys.exit(1)

def kinit(args):
    '''Run the system ``kinit`` with *args* passed through unchanged.

    param args  List of extra command line arguments for kinit
    returns     kinit's exit status, so the caller's sys.exit() reports failures
    '''
    return Popen([kcc_cmd('kinit'), *args]).wait()

def klist(args):
    '''Run the system ``klist`` with *args* passed through unchanged.

    param args  List of extra command line arguments for klist
    returns     klist's exit status, so the caller's sys.exit() reports failures
    '''
    return Popen([kcc_cmd('klist'), *args]).wait()

def kdestroy(args):
    '''Run the system ``kdestroy`` with *args* passed through unchanged.

    param args  List of extra command line arguments for kdestroy
    returns     kdestroy's exit status, so the caller's sys.exit() reports failures
    '''
    return Popen([kcc_cmd('kdestroy'), *args]).wait()

def ktutil(args):
    '''Run the system ``ktutil`` with *args* passed through unchanged.

    param args  List of extra command line arguments for ktutil
    returns     ktutil's exit status, so the caller's sys.exit() reports failures
    '''
    return Popen([kcc_cmd('ktutil'), *args]).wait()

def flush(args):
    '''Clear samba's gencache (lock/gencache.tdb under the state directory).

    All samba services are stopped first. Afterwards the daemon for this
    machine's role is started again: samba on an AD DC, winbind otherwise.
    The restart happens even if opening or clearing the cache fails.

    param args  Parsed command line arguments (unused)
    '''
    if getuser() != 'root':
        sys.stderr.write('ads flush must be run as root\n')
        sys.exit(1)

    # Read the configuration before touching the services so a failure here
    # cannot leave them stopped.
    lp = LoadParm()
    lp.load_default()
    is_dc = lp.get('server role') == 'active directory domain controller'

    for srv in ['samba', 'smbd', 'nmbd', 'winbind']:
        service(srv, 'stop')

    try:
        gencache = tdb.open(lp.state_path('lock/gencache.tdb'), tdb_flags=tdb.NOLOCK)
        try:
            gencache.clear()
        finally:
            gencache.close()
    finally:
        # NOTE(review): smbd/nmbd are stopped above but never restarted on
        # non-DC machines; confirm that is intended.
        service('samba' if is_dc else 'winbind', 'start')

def daemon(args):
    '''Apply args.action (start/stop/...) to args.service; root only.

    returns     Whatever service() returns, which the caller passes to sys.exit()
    '''
    if getuser() == 'root':
        return service(args.service, args.action)
    sys.stderr.write('ads daemon must be run as root\n')
    sys.exit(1)

def inspect(args):
    '''Print one setting from the default samba configuration.

    param args  Parsed command line arguments; args.setting is the parameter
                name and args.section the smb.conf section to read it from
    '''
    config = LoadParm()
    config.load_default()
    value = config.get(args.setting, args.section)
    print(value)

def is_user_ad(args):
    '''Classify args.name as found via pwd, in AD, in both, or in neither.

    "Local" means pwd.getpwnam() resolves the name. "AD" means an object with
    objectClass=user and a matching cn exists under user_container(creds).

    returns  An exit status for the caller's sys.exit():
             0 AD only, 2 local only, 3 neither, 4 both
    '''
    try:
        pwd.getpwnam(args.name)
        local = True
    except KeyError:
        local = False

    creds = get_creds(args)
    conn = ldap_open(get_default_realm(), creds)
    # Escape the name (RFC 4515) so characters such as '*', '(' or ')' in it
    # cannot change the meaning of the search filter.
    expression = '(&(objectClass=user)(cn=%s))' % ldb.binary_encode(args.name)
    results = conn.search(user_container(creds), ldb.SCOPE_SUBTREE, expression, ['dn'])
    # NOTE(review): matches on cn, not sAMAccountName; these can differ from
    # the login name that pwd.getpwnam() checks. Confirm this is intended.
    ad = len(results) > 0

    if ad:
        return 4 if local else 0
    return 2 if local else 3

def is_group_ad(args):
    '''Classify args.name as found via grp, in AD, in both, or in neither.

    "Local" means grp.getgrnam() resolves the name. "AD" means an object with
    objectClass=group and a matching cn exists under user_container(creds).

    returns  An exit status for the caller's sys.exit():
             0 AD only, 2 local only, 3 neither, 4 both
    '''
    try:
        grp.getgrnam(args.name)
        local = True
    except KeyError:
        local = False

    creds = get_creds(args)
    conn = ldap_open(get_default_realm(), creds)
    # Escape the name (RFC 4515) so characters such as '*', '(' or ')' in it
    # cannot change the meaning of the search filter.
    expression = '(&(objectClass=group)(cn=%s))' % ldb.binary_encode(args.name)
    # NOTE(review): groups are searched under user_container(); groups stored
    # in other containers/OUs may be missed only if that helper narrows the
    # base. Confirm.
    results = conn.search(user_container(creds), ldb.SCOPE_SUBTREE, expression, ['dn'])
    ad = len(results) > 0

    if ad:
        return 4 if local else 0
    return 2 if local else 3

def argparse_add_options(parser, options, ignore=()):
    '''Add samba (optparse) options to an argparse parser
    param parser    The argparse parser to append arguments to
    param options   An iterable of optparse Option objects to translate and add
    param ignore    Option dests to skip (e.g. ones the caller defines itself)
    '''
    # optparse type name -> argparse converter. 'long' is Python 2's int.
    # 'choice' needs no converter because argparse validates via `choices`.
    # Unknown or custom optparse types are left as plain strings instead of
    # being eval'd by name.
    converters = {
        'string': str,
        'int': int,
        'long': int,
        'float': float,
        'complex': complex,
    }
    for opt in options:
        if opt.dest in ignore:
            continue
        kwargs = {}
        if opt.action is not None:
            kwargs['action'] = opt.action
        converter = converters.get(opt.type)
        if converter is not None:
            kwargs['type'] = converter
        if opt.dest is not None:
            kwargs['dest'] = opt.dest
        # ('NO', 'DEFAULT') is optparse's "no default given" sentinel.
        if opt.default is not None and opt.default != ('NO', 'DEFAULT'):
            kwargs['default'] = opt.default
        # NOTE(review): optparse fills in nargs=1 for value-taking options, so
        # argparse stores those values as one-element lists; callers
        # presumably unwrap them. Confirm.
        if opt.nargs is not None:
            kwargs['nargs'] = opt.nargs
        if opt.const is not None:
            kwargs['const'] = opt.const
        if opt.choices is not None:
            kwargs['choices'] = opt.choices
        if opt.help is not None:
            kwargs['help'] = opt.help
        if opt.metavar is not None:
            kwargs['metavar'] = opt.metavar
        parser.add_argument(opt.get_opt_string(), **kwargs)

def argparser():
    '''Build the top-level ``ads`` argument parser with all of its subcommands.

    Each (sub)parser sets a ``func`` default: the handler to run, which is
    called with the parsed namespace, or with the list of unknown arguments
    for the Kerberos wrappers. Most also set ``help_func`` so that
    incomplete commands print the matching help text.
    '''
    description = "Active Directory Swiss army knife for samba.\nFor join, unjoin, provisioning, demotion, user/group and password administration,\nldap attribute modification, posix enablement, kdc timesync, pam and nss configuration,\ndaemon start/stop, cache flush, etc.\nThe ads command attempts to maintain compatibility with the proprietary vastool command,\nwhile also adding additional features relevant to samba (such as kdc provisioning)."
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-v', '--version', action='version', version='ads 2.0', help='ads version')
    parser.add_argument('-u', help='Authenticating user')
    parser.add_argument('-w', help='Authenticating password')
    parser.add_argument('-d', help='debug level')
    parser.set_defaults(func=lambda args: args.help_func())
    parser.set_defaults(help_func=parser.print_help)
    subparsers = parser.add_subparsers()

    nss_parser = subparsers.add_parser('nss', help='Run nss functions')
    nss_subparsers = nss_parser.add_subparsers()
    getpwnam_parser = nss_subparsers.add_parser('getpwnam')
    getpwnam_parser.add_argument('object', help='username')
    getpwnam_parser.add_argument('-d', '--direct', help='bypass the nss layer and return results directly from ldap', action='store_true')
    getpwnam_parser.set_defaults(func=nss_getpwnam)
    getpwuid_parser = nss_subparsers.add_parser('getpwuid')
    getpwuid_parser.add_argument('object', help='uid')
    getpwuid_parser.set_defaults(func=nss_getpwuid)
    getgrnam_parser = nss_subparsers.add_parser('getgrnam')
    getgrnam_parser.add_argument('object', help='groupname')
    getgrnam_parser.set_defaults(func=nss_getgrnam_getgrid)
    getgrid_parser = nss_subparsers.add_parser('getgrid')
    getgrid_parser.add_argument('object', help='gid')
    getgrid_parser.set_defaults(func=nss_getgrnam_getgrid)
    getpwent_parser = nss_subparsers.add_parser('getpwent')
    getpwent_parser.add_argument('-d', '--direct', help='bypass the nss layer and return results directly from ldap', action='store_true')
    getpwent_parser.set_defaults(func=nss_getpwent)
    getgrent_parser = nss_subparsers.add_parser('getgrent')
    getgrent_parser.set_defaults(func=nss_getgrent)
    nss_parser.set_defaults(func=lambda args: args.help_func())
    nss_parser.set_defaults(help_func=nss_parser.print_help)

    user_parser = subparsers.add_parser('user', help='Manage Active Directory users')
    user_subparsers = user_parser.add_subparsers()
    checklogin_parser = user_subparsers.add_parser('checklogin')
    checklogin_parser.add_argument('user')
    checklogin_parser.set_defaults(func=user_checklogin)
    user_parser.set_defaults(func=lambda args: args.help_func())
    user_parser.set_defaults(help_func=user_parser.print_help)

    create_parser = subparsers.add_parser('create', help='Create Active Directory users/groups')
    create_parser.add_argument('-e', help='Operate on an existing object', action='store_true')
    create_parser.add_argument('-i', help='Passwd line for unix enable')
    create_parser.add_argument('-c', help='Container to create object in')
    create_subparsers = create_parser.add_subparsers()
    create_user_parser = create_subparsers.add_parser('user')
    create_user_parser.add_argument('object', help='user')
    create_user_parser.set_defaults(func=create_user)
    create_group_parser = create_subparsers.add_parser('group')
    create_group_parser.add_argument('object', help='group')
    create_group_parser.set_defaults(func=create_group)
    create_parser.set_defaults(func=lambda args: args.help_func())
    create_parser.set_defaults(help_func=create_parser.print_help)

    delete_parser = subparsers.add_parser('delete', help='Delete Active Directory users/groups')
    delete_subparsers = delete_parser.add_subparsers()
    delete_user_parser = delete_subparsers.add_parser('user')
    delete_user_parser.add_argument('object', help='user')
    delete_user_parser.set_defaults(func=delete_user)
    delete_group_parser = delete_subparsers.add_parser('group')
    delete_group_parser.add_argument('object', help='group')
    delete_group_parser.set_defaults(func=delete_group)
    delete_parser.set_defaults(func=lambda args: args.help_func())
    delete_parser.set_defaults(help_func=delete_parser.print_help)

    list_parser = subparsers.add_parser('list', help='List Active Directory users/groups')
    list_subparsers = list_parser.add_subparsers()
    list_users_parser = list_subparsers.add_parser('users')
    list_users_parser.set_defaults(func=nss_getpwent)
    list_groups_parser = list_subparsers.add_parser('groups')
    list_groups_parser.set_defaults(func=nss_getgrent)
    list_parser.set_defaults(func=lambda args: args.help_func())
    list_parser.set_defaults(help_func=list_parser.print_help)

    passwd_parser = subparsers.add_parser('passwd', help='Change Active Directory user passwords')
    # -r and -o are flags. store_const keeps the default None (as before)
    # but no longer swallows the following word (e.g. the user name) as
    # the option's value.
    passwd_parser.add_argument('-r', action='store_const', const=True, help='Sets the password to a random value')
    passwd_parser.add_argument('-o', action='store_const', const=True, help='Will output the new password to stdout')
    passwd_parser.add_argument('object', nargs='?')
    passwd_parser.set_defaults(func=passwd)
    passwd_parser.set_defaults(help_func=passwd_parser.print_help)

    attrs_parser = subparsers.add_parser('attrs', help='List Active Directory object attributes')
    attrs_parser.add_argument('-b', action='store_true', help='Convert sid to human readable form')
    attrs_parser.add_argument('-c', help='Container to search in')
    attrs_parser.add_argument('-g', help='Treat the object as a group name', action='store_true')
    attrs_parser.add_argument('object')
    attrs_parser.add_argument('attributes', nargs='*')
    attrs_parser.set_defaults(func=attrs)
    attrs_parser.set_defaults(help_func=attrs_parser.print_help)

    setattrs_parser = subparsers.add_parser('setattrs', help='Modify Active Directory object attributes')
    setattrs_parser.add_argument('object', help='distinguishedName')
    setattrs_parser.add_argument('attribute')
    setattrs_parser.add_argument('value')
    setattrs_parser.set_defaults(func=setattrs)
    setattrs_parser.set_defaults(help_func=setattrs_parser.print_help)

    join_parser = subparsers.add_parser('join', help='Join this computer to an Active Directory domain')
    join_parser.add_argument('--autogen-posix-attrs', action='store_true')
    join_parser.add_argument('--domain-controller', action='store_true', help='Join the machine as an Active Directory Domain Controller member server')
    join_parser.add_argument('--disable-pam', action='store_true', help='Don\'t configure pam during the join')
    join_parser.add_argument('domain')
    join_parser.add_argument('-n', help='Join as hostname')
    join_parser.add_argument('servers', nargs='*')
    join_parser.set_defaults(func=join)
    join_parser.set_defaults(help_func=join_parser.print_help)

    unjoin_parser = subparsers.add_parser('unjoin', help='Unjoin this computer from the Active Directory domain')
    unjoin_parser.set_defaults(func=unjoin)
    unjoin_parser.set_defaults(help_func=unjoin_parser.print_help)

    provision_parser = subparsers.add_parser('provision', help='Provision an Active Directory Domain Controller')
    provision_parser.add_argument('domain', help='NetBIOS domain name to use')
    # Reuse samba-tool's provision options; 'domain' is skipped because it is
    # the positional argument defined just above.
    argparse_add_options(provision_parser, cmd_domain_provision.takes_options, ['domain'])
    provision_parser.set_defaults(func=provision)
    provision_parser.set_defaults(help_func=provision_parser.print_help)

    demote_parser = subparsers.add_parser('demote', help='Demote an Active Directory Domain Controller')
    demote_parser.set_defaults(func=demote)
    demote_parser.set_defaults(help_func=demote_parser.print_help)

    timesync_parser = subparsers.add_parser('timesync', help='Synchronize machine time with Active Directory')
    timesync_parser.add_argument('-s', help='Server to sync with')
    timesync_parser.set_defaults(func=timesync)
    timesync_parser.set_defaults(help_func=timesync_parser.print_help)

    configure_parser = subparsers.add_parser('configure', help='Configure pam and nss for winbind authentication')
    configure_subparsers = configure_parser.add_subparsers()
    configure_pam_parser = configure_subparsers.add_parser('pam')
    configure_pam_parser.set_defaults(func=configure_pam)
    configure_nss_parser = configure_subparsers.add_parser('nss')
    configure_nss_parser.set_defaults(func=configure_nss)
    configure_parser.set_defaults(func=configure)
    configure_parser.set_defaults(help_func=configure_parser.print_help)

    unconfigure_parser = subparsers.add_parser('unconfigure', help='Unconfigure pam and nss for winbind authentication')
    unconfigure_subparsers = unconfigure_parser.add_subparsers()
    unconfigure_pam_parser = unconfigure_subparsers.add_parser('pam')
    unconfigure_pam_parser.set_defaults(func=unconfigure_pam)
    unconfigure_nss_parser = unconfigure_subparsers.add_parser('nss')
    unconfigure_nss_parser.set_defaults(func=unconfigure_nss)
    unconfigure_parser.set_defaults(func=unconfigure)
    unconfigure_parser.set_defaults(help_func=unconfigure_parser.print_help)

    info_parser = subparsers.add_parser('info', help='Get information about the domain')
    info_subparsers = info_parser.add_subparsers()
    info_domain_parser = info_subparsers.add_parser('domain')
    info_domain_parser.set_defaults(func=info_domain)
    info_cldap_parser = info_subparsers.add_parser('cldap')
    info_cldap_parser.add_argument('server')
    info_cldap_parser.set_defaults(func=info_cldap)
    info_parser.set_defaults(func=lambda args: args.help_func())
    info_parser.set_defaults(help_func=info_parser.print_help)

    # The Kerberos wrappers use add_help=False so that -h/--help is treated as
    # an unknown argument and forwarded to the underlying tool.
    kinit_parser = subparsers.add_parser('kinit', help='Request an initial ticket-granting ticket', add_help=False)
    kinit_parser.set_defaults(func=kinit)

    klist_parser = subparsers.add_parser('klist', help='Lists the Kerberos principal and Kerberos tickets held in a  credentials  cache', add_help=False)
    klist_parser.set_defaults(func=klist)

    kdestroy_parser = subparsers.add_parser('kdestroy', help='Destroys the user\'s active Kerberos authorization tickets', add_help=False)
    kdestroy_parser.set_defaults(func=kdestroy)

    ktutil_parser = subparsers.add_parser('ktutil', help='Invokes a command interface from which an administrator can read, write, or edit entries in a keytab', add_help=False)
    ktutil_parser.set_defaults(func=ktutil)

    flush_parser = subparsers.add_parser('flush', help='Deletes all cache entries')
    flush_parser.set_defaults(func=flush)
    flush_parser.set_defaults(help_func=flush_parser.print_help)

    daemon_parser = subparsers.add_parser('daemon', help='Start, stop or restart the samba or winbind service')
    daemon_parser.add_argument('action', choices=['start', 'stop', 'restart', 'enable', 'disable', 'status'])
    daemon_parser.add_argument('service', choices=['samba', 'smbd', 'nmbd', 'winbind'])
    daemon_parser.set_defaults(func=daemon)
    daemon_parser.set_defaults(help_func=daemon_parser.print_help)

    inspect_parser = subparsers.add_parser('inspect', help='Returns the value of a configuration file setting')
    inspect_parser.add_argument('section')
    inspect_parser.add_argument('setting')
    inspect_parser.set_defaults(func=inspect)
    inspect_parser.set_defaults(help_func=inspect_parser.print_help)

    # isad and isvas share the same user/group sub-commands. Both inherit the
    # subparsers action of parent_isad_parser through `parents`, and argparse
    # shares that action object by reference. So sub-parsers added to
    # isad_subparser below appear under both commands.
    parent_isad_parser = argparse.ArgumentParser(add_help=False)
    isad_subparser = parent_isad_parser.add_subparsers()
    isad_parser = subparsers.add_parser('isad', parents=[parent_isad_parser],
            help='Used to check if a given user is an Active Directory user')
    isvas_parser = subparsers.add_parser('isvas', parents=[parent_isad_parser],
            help='Used to check if a given user is an Active Directory user. This is an alias to isad')
    isad_user_parser = isad_subparser.add_parser('user')
    isad_user_parser.add_argument('name')
    isad_group_parser = isad_subparser.add_parser('group')
    isad_group_parser.add_argument('name')
    isad_user_parser.set_defaults(func=is_user_ad)
    isad_group_parser.set_defaults(func=is_group_ad)
    isad_parser.set_defaults(func=lambda args: args.help_func())
    isad_parser.set_defaults(help_func=isad_parser.print_help)
    isvas_parser.set_defaults(func=lambda args: args.help_func())
    isvas_parser.set_defaults(help_func=isvas_parser.print_help)

    return parser

if __name__ == "__main__":
    parser = argparser()
    args, unknownargs = parser.parse_known_args()

    if args.d:
        debug_level = args.d

    if args.func in [kinit, klist, kdestroy, ktutil]:
        # The Kerberos wrappers forward every argument ads does not recognize
        # to the underlying tool.
        sys.exit(args.func(unknownargs))
    elif not unknownargs:
        sys.exit(args.func(args))
    else:
        # Unrecognized arguments are a usage error: report them, show the
        # relevant help and exit with argparse's usage-error status (2)
        # instead of reporting success.
        sys.stderr.write('ads: error: unrecognized arguments: %s\n' % ' '.join(unknownargs))
        args.help_func()
        sys.exit(2)
