#!/usr/bin/python
# Copyright (c) 2015 SUSE Linux GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from pprint import pformat
import os
import sys
import re
import logging
import cmdln
from FileListExtractor import FileListExtractor, Channel, RepoData
import pickle
import signal
import subprocess
import time
import yaml
import psutil
import traceback
from collections import namedtuple

try:
    from xml.etree import cElementTree as ET
except ImportError:
    import cElementTree as ET

# XDG base-directory lookup: user config dir first, then the system dirs.
USER_CONFIG_DIR = os.environ.get('XDG_CONFIG_HOME', os.path.expanduser('~/.config'))
SYS_CONFIG_DIRS = os.environ.get('XDG_CONFIG_DIRS', '/etc/xdg').split(':')
CONFIG_DIRS = [USER_CONFIG_DIR] + SYS_CONFIG_DIRS
# Config files are searched in the 'rpmlint-backports-tool' subdirectory
# of every XDG config directory.
CONFIG_DIRS = [os.path.join(d, 'rpmlint-backports-tool') for d in CONFIG_DIRS]
# Package name patterns that are never recorded in the database.
BLACKLIST = ['java-.*-ibm.*']
# Default file name of the generated file -> package database.
DBFILE = 'SLE-file-package-map.db'
# Channels matching this pattern are modules/manager products, which are
# not service-pack versioned (see service_pack()).
module_re = re.compile(".*(sle-module|sle-manager)", flags=re.IGNORECASE)

# Per-channel diff produced by Tool._pickle(): newly appeared packages,
# newly appeared files, and packages that vanished since the last run.
PickleResult = namedtuple('PickleResult', 'newpkgs newfiles missing')


def memory_usage_psutil():
    '''Return memory usage in megabytes'''
    proc = psutil.Process(os.getpid())
    bytes_per_mb = float(2 ** 20)

    # Newer psutil versions expose memory_info(); very old releases only
    # had get_memory_info().  Prefer the current API, fall back on error.
    try:
        rss_bytes = proc.memory_info()[0]
    except AttributeError:
        rss_bytes = proc.get_memory_info()[0]
    return rss_bytes / bytes_per_mb


def service_pack(name, SP=None, url=False):
    """Decorate *name* with the service pack suffix *SP*.

    Returns *name* unchanged for modules/manager products, for a missing
    or falsy SP, and for the base release 'SP0'.  For plain names the SP
    is appended; for URLs (url=True) it is spliced in after the '12'
    version component.
    """
    if name is None:
        return None
    # Modules are never service-packed; neither is an absent SP or SP0.
    if not SP or module_re.match(name) or SP == 'SP0':
        return name
    if url:
        return re.sub(r'([_\-]12)([_\-:])', r'\1-%s\2' % SP, name)
    return '%s-%s' % (name, SP)


class Tool(cmdln.Cmdln):
    """Command line tool that reads SUSE product/update channels, builds a
    file -> package database and uploads it for the rpmlint backports
    checks."""

    def __init__(self, *args, **kwargs):
        # BUG FIX: forward the arguments unpacked; the original passed the
        # args tuple and the kwargs dict as two positional arguments.
        cmdln.Cmdln.__init__(self, *args, **kwargs)

        self.ex = FileListExtractor()
        self.ex.save_channels_files = True
        self._channels = None          # lazily built by the channels property
        self._channels_by_arch = {}    # arch -> [channel, ...]
        self._config = None            # parsed YAML configuration dict

    def reset(self):
        """ Clear any cached data """
        self._channels = None

    @property
    def config(self):
        """The loaded configuration dict; raises if no config was loaded."""
        if self._config is None:
            raise Exception("No config file loaded")
        return self._config

    @config.setter
    def config(self, data):
        self._config = data

    @property
    def channels(self):
        """All channels described by the config (cached after first use)."""
        if self._channels is None:
            self._channels = self._get_channels()
        return self._channels

    @property
    def channels_by_arch(self):
        """arch -> channel list mapping, filled while building channels."""
        if self.channels is not None:
            return self._channels_by_arch

    def get_optparser(self):
        """Define the global (pre-subcommand) command line options."""
        parser = cmdln.CmdlnOptionParser(self)
        parser.add_option("--dry", action="store_true", help="dry run")
        parser.add_option("--debug", action="store_true", help="debug output")
        parser.add_option("--verbose", action="store_true", help="verbose")
        parser.add_option("--servicepack", type="string", dest="servicepack", help="service pack", default=None)
        parser.add_option("--config", type="string", help="config file", default=None)
        parser.add_option("--workpath", type="string", help="work path", default='.')
        return parser

    @property
    def servicepacks(self):
        """Service packs to process; --servicepack overrides the config."""
        if self.options.servicepack:
            return [self.options.servicepack]
        return self.config.get('servicepacks', None)

    def _get_channels(self):
        """Build channel objects for every configured channel, service pack
        and architecture.

        Populates self._channels_by_arch as a side effect and returns the
        flat list of all channels.  Raises when mandatory config keys
        (channels, product_channels_url, update_channels_url) are missing.
        """
        channels = []
        chan_url = {}
        chan_url['product'] = self.config.get('product_channels_url', None)
        chan_url['update'] = self.config.get('update_channels_url', None)
        chan_url['local'] = self.config.get('local_channels_url', '{name}')
        channel_configs = self.config.get('channels', None)
        channel_schema = self.config.get('channel_schema', '_channel')

        if channel_configs is None:
            raise Exception("No channels in config")
        if chan_url['product'] is None:
            raise Exception("No product_channels_url in config")
        if chan_url['update'] is None:
            raise Exception("No update_channels_url in config")

        for ID, config in channel_configs.items():
            archs = config.get('archs', ['x86_64'])
            for channeltype in ['product', 'update', 'local']:
                name = config.get(channeltype, None)
                if name is None:
                    continue
                # Modules and local channels are not service-pack versioned.
                servicepacks = self.servicepacks or ['SP0']
                if module_re.match(name) or channeltype == 'local':
                    servicepacks = ['SP0']
                for servicepack in servicepacks:
                    # 'SPn@snapshot' pins a specific snapshot of the repo.
                    if '@' in servicepack:
                        servicepack, snapshot = servicepack.split('@', 1)
                    else:
                        snapshot = ''
                    for arch in archs:
                        url = chan_url[channeltype].format(name=name,
                                                           arch=arch.replace('+', '_'),
                                                           snapshot=snapshot
                                                           )
                        url = service_pack(url, SP=servicepack, url=True)
                        sp_name = service_pack(os.path.split(name)[1], SP=servicepack)
                        sp_name = '%s-%s' % (sp_name, arch.replace('+', '_'))
                        if channel_schema == 'repomd':
                            newchannel = RepoData(sp_name, url, channeltype, logger=self.logger)
                        else:
                            newchannel = Channel(sp_name, url, channeltype, logger=self.logger)
                        newchannel.id = ID
                        newchannel.exclude = config.get('exclude', False)
                        newchannel.whitelist = config.get('whitelist', None)

                        # Multi arch channels have the multi archs in one string
                        # separated by a '+'. The packages in that channel need
                        # to be added to all arch channels.
                        if '+' in arch:
                            splitarchs = arch.split('+')
                        else:
                            splitarchs = [arch]
                        # Use a distinct loop variable so the outer 'arch'
                        # is not clobbered.
                        for subarch in splitarchs:
                            self._channels_by_arch.setdefault(subarch, []).append(newchannel)
                        channels.append(newchannel)

        return channels

    def filter_channels(self, channel_names):
        """Return the channels matching the given names, or all channels
        when no names are given."""
        if channel_names:
            names = set(channel_names)
            return [c for c in self.channels if c.name in names]

        return self.channels

    def _load_config(self, configfile):
        """Load the YAML config, trying the literal path first and then the
        XDG config directories.  Raises when nothing could be loaded."""
        def _load(path):
            self.logger.debug("Loading %s" % path)
            try:
                with open(path) as f:
                    # safe_load: config files need no arbitrary python
                    # objects, and yaml.load without a Loader is deprecated.
                    self.config = yaml.safe_load(f)
                self.logger.info("Loaded config file: %s" % path)
                return True
            except Exception as e:
                self.logger.debug("Failed to load %s. %s" % (path, e))
                return False

        # First try path exactly as specified
        if os.path.isfile(os.path.abspath(configfile)) and _load(configfile):
            return True

        # Search for config file in standard locations
        # Load first match and return
        filenames = [configfile, configfile + '.conf', configfile + '.config']
        for path in CONFIG_DIRS:
            for fname in filenames:
                file_path = os.path.abspath(os.path.join(path, fname))
                self.logger.debug("Checking for %s" % file_path)
                if not os.path.isfile(file_path):
                    continue
                if _load(file_path):
                    return True

        raise Exception('Failed to find and load config: %s' % configfile)

    def postoptparse(self):
        """Hook run after global option parsing: set up logging, load the
        config and switch to the work directory."""
        logging.basicConfig()
        self.logger = logging.getLogger(self.optparser.prog)
        if self.options.debug:
            self.logger.setLevel(logging.DEBUG)
            self.ex.debug = True
        elif self.options.verbose:
            self.logger.setLevel(logging.INFO)

        if self.options.config:
            self._load_config(self.options.config)

        if self.options.workpath:
            try:
                os.chdir(os.path.abspath(self.options.workpath))
            except Exception as e:
                raise Exception("Failed to switch to workpath %s: %s" % (self.options.workpath, e))
        self.ex.logger = self.logger

    def do_list(self, subcmd, opts, *channel_names):
        """${cmd_name}: list channels

        ${cmd_usage}
        ${cmd_option_list}
        """

        channels = self.filter_channels(channel_names)
        for c in sorted(channels):
            print("%s\n  %s\n  %s" % (c.name, c.url, c.channel_filename()))

    @cmdln.option("-f", "--force", action="store_true", help='force refresh of "product repos". By default only update repos are refreshed')
    def do_pickle(self, subcmd, opts, *channel_names):
        """${cmd_name}: generate pickle for channels

        ${cmd_usage}
        ${cmd_option_list}
        """

        channels = self.filter_channels(channel_names)
        self._pickle(sorted(channels), force=opts.force)

    def _pickle(self, channels_to_process, force=None):
        """Read the file lists of the given channels and save each channel
        as a pickle file.

        Product channels that already have a pickle are skipped unless
        *force* is set (or configured).  Returns the set of archs whose
        channels changed.
        """
        if not force:
            force = self.config.get('force', False)

        # Never record blacklisted packages or files in the database.
        self.ex.set_blacklist(BLACKLIST)
        self.ex.set_file_blacklist(self.config.get('file_blacklist', None))
        stats = {}

        def process(channel):
            """Refresh one channel; returns a PickleResult with the diff to
            the previous run, or None when the channel was skipped."""
            self.logger.debug("processing %s", channel.name)
            fn = channel.pickle_filename()
            missing = []
            newpkgs = []
            newfiles = []
            olddata = None
            if os.path.exists(fn):
                if channel.type == 'product' and not force:
                    self.logger.info('Channel type is product - skipping refresh.')
                    return None
                with open(fn, 'rb') as f:
                    olddata = pickle.load(f)
            data = self.ex.readFileLists([channel])

            # FileListExtractor will return empty package list if reading
            # repodata fails. Handle zero lists as error.
            if len(data['pkgnames']) == 0:
                # warning(): logger.warn is a deprecated alias
                self.logger.warning('Channel has no packages! Skipping...')
                return None

            stats['pkgs_processed'] = len(data['pkgnames'])
            stats['total_files'] = len(data['filenames'])
            if olddata is not None:
                missing = olddata['pkgnames'] - data['pkgnames']
                newpkgs = data['pkgnames'] - olddata['pkgnames']
                newfiles = set(data['filenames'].keys()) - set(olddata['filenames'].keys())
                data = self.ex.merge(olddata, data)
            else:
                newpkgs = ['%s packages from %s' % (len(data['pkgnames']), channel.name)]

            if newfiles or newpkgs or missing:
                with open(fn, 'wb') as f:
                    pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
                # BUG FIX: pformat() returns text; writing it to a file
                # opened in binary mode is a TypeError on Python 3.
                with open(channel.dump_filename(), 'w') as f:
                    f.write(pformat(data))

            return PickleResult(newpkgs=newpkgs, newfiles=newfiles, missing=missing)

        changed = set()
        for arch, channels in self.channels_by_arch.items():
            for channel in channels:
                if channel not in channels_to_process:
                    continue
                start = time.time()

                title = '{} {}'.format(channel.name, channel.type)
                # BUG FIX: integer division; '-' * float raises TypeError
                # on Python 3.
                dashes = '-' * ((70 - (len(title) + 2)) // 2)
                self.logger.info('%s %s %s' % (dashes, title, dashes))
                self.logger.info('repo: %s' % (channel.url))

                result = process(channel)
                if result is None:
                    self.logger.info('-' * 70)
                    continue

                self.logger.info('Processed %.2f seconds' % (time.time() - start))
                self.logger.info('Current memory usage: %.1fM' % memory_usage_psutil())
                self.logger.info('Scanned %s packages' % stats.get('pkgs_processed', 0))
                self.logger.info('Collected %s filenames' % stats.get('total_files', 0))

                if result.missing:
                    self.logger.warning("  vanished packages: %s" % ', '.join(result.missing))
                    changed.add(arch)
                if result.newpkgs:
                    self.logger.info("  new packages: %s" % ', '.join(result.newpkgs))
                    changed.add(arch)
                if result.newfiles:
                    self.logger.info("  %s new files added" % len(result.newfiles))
                    changed.add(arch)
                if not (result.newpkgs or result.newfiles or result.missing):
                    self.logger.info("  no change")
                self.logger.info('-' * 70)

        return changed

    @cmdln.option("-o", "--output", dest='filename', metavar='FILE', help="output to FILE")
    @cmdln.option("--shelve", action="store_true", help="save as shelve file")
    @cmdln.option("--by-arch", action="store_true", help="save one database per arch")
    def do_merge(self, subcmd, opts, *files):
        """${cmd_name}: merge .p files after call to pickle

        ${cmd_usage}
        ${cmd_option_list}
        """

        if not opts.filename:
            raise Exception('filename missing')

        return '\n'.join(self._merge(opts.filename, files, shelve=opts.shelve, by_arch=opts.by_arch))

    def _merge(self, filename, files, shelve=False, by_arch=False, archs=None):
        """Merge the given pickle files into one database file.

        filename -- output database file name
        files    -- pickle files to merge; all non-excluded channels if empty
        shelve   -- write a dumbdbm-backed shelve instead of a plain pickle
        by_arch  -- write one database per arch (restricted to *archs*)

        Returns the list of database file names written.
        """
        if archs is None:
            archs = self.channels_by_arch.keys()

        # If by_arch is set, then we call ourselves for each arch, merging
        # only the channels for that arch in a separate db file set
        if by_arch:
            dbfiles = []
            for arch, channels in self.channels_by_arch.items():
                if arch not in archs:
                    continue
                pfiles = [c.pickle_filename() for c in sorted(channels) if not c.exclude]
                dbfile = '-{}.'.format(arch).join(DBFILE.split('.'))
                self._merge(dbfile, pfiles, shelve=shelve)
                dbfiles.append(dbfile)

            return dbfiles

        dbfiles = [filename]
        if not files:
            files = [i.pickle_filename() for i in sorted(self.channels) if not i.exclude]

        res = {
            'filenames': dict(),
            'pkgnames': set(),
        }

        for fn in files:
            with open(fn, 'rb') as f:
                self.logger.debug("merging %s" % fn)
                data = pickle.load(f)
                res = self.ex.merge(res, data)

        tmpfn = filename + '.new'
        if shelve:
            # For the shelve backend db, Python2 uses bsddb by default,
            # Python3 does not support bsddb. The gdbm format is endian
            # sensitive which causes issues when trying to use a db generated
            # on a machine with a different endianness as the system running
            # the rpmlint checks at build time.
            # So we explicitly use dumbdbm for compatibility with python 2 and
            # 3 (i.e. SLES 12 and SLES 15) as well as little and big
            # endian systems. Also ensure same protocol (2) is used.
            #
            # The dumbdbm module changed location with python3. We try
            # to load the python2 based module first and fall back to
            # python3 if that fails.
            # 'shelve_mod' avoids shadowing the 'shelve' parameter.
            import shelve as shelve_mod
            try:
                import dumbdbm as dbm_mod
            except ImportError as e:
                self.logger.debug(e)
                self.logger.debug('trying dbm.dumb')
                # BUG FIX: __import__("dbm.dumb") returns the top level
                # 'dbm' package, whose open() picks an arbitrary backend.
                # Binding the submodule guarantees the dumbdbm format.
                import dbm.dumb as dbm_mod
            d = shelve_mod.Shelf(dbm_mod.open(tmpfn, 'n'), protocol=2)
            d.update(res)
            d.close()

            # dumbdbm creates 3 files: filename.dat, filename.dir and
            # filename.bak. We need the .dir and .dat files.
            for ext in ['.dat', '.dir']:
                os.rename(tmpfn + ext, filename + ext)
            if os.path.exists(tmpfn + '.bak'):
                os.remove(tmpfn + '.bak')
        else:
            with open(tmpfn, 'wb') as f:
                pickle.dump(res, f, pickle.HIGHEST_PROTOCOL)
            # BUG FIX: the plain pickle has no .dat/.dir/.bak companions;
            # just move the temporary file itself into place.
            os.rename(tmpfn, filename)

        self.logger.debug('Current memory usage: %s' % memory_usage_psutil())

        return dbfiles

    @cmdln.option("-f", "--force", action="store_true", help="force something")
    @cmdln.option('-c', "--check", metavar="check", action="append", help="which file to check against the rest")
    def do_check_bsk(self, subcmd, opts, *channel_names):
        """${cmd_name}: check bsk consistency

        checks if files in BSK are actually just subpackages of stuff that
        already is in other products. Also prints duplicates and what would be
        left if the duplicates and supackages were removed.

        Only operates locally. Needs previous run of pickle

        ${cmd_usage}
        ${cmd_option_list}
        """
        pkgs = dict()       # srcpkg -> set(binpkgs)
        checkpkgs = dict()  # srcpkg -> set(binpkgs)

        b2chan = dict()     # binpkg -> set(channel)

        if opts.check:
            # BUG FIX: -c gives channel *names* on the command line; map
            # them to the channel objects the code below expects.
            names = set(opts.check)
            opts.check = [c for c in self.channels if c.name in names]
        if not opts.check:
            opts.check = [c for c in self.channels if c.id == 'BSK']

        channels = self.filter_channels(channel_names)
        channels = [c for c in channels if c.name != 'sle_exceptions']

        def parse(dst, channels):
            """Fill dst (srcpkg -> set(binpkg)) and b2chan from the channel
            XML files, skipping _product and debug packages."""
            for channel in channels:
                fn = channel.channel_filename()
                with open(fn, 'rb') as f:
                    root = ET.parse(f).getroot()
                    for binaries in root.findall('binaries'):
                        for node in binaries.findall('binary'):
                            name = node.attrib['name']
                            package = node.attrib['package']
                            if package.startswith('_product:'):
                                continue
                            if name.endswith('-debuginfo') or name.endswith('-debugsource'):
                                continue
                            dst.setdefault(package, set()).add(name)
                            b2chan.setdefault(name, set()).add(channel.name)

        parse(pkgs, channels)
        self._pickle(opts.check)
        parse(checkpkgs, opts.check)

        bsk_all = set(checkpkgs.keys())
        for p in sorted(checkpkgs.keys()):
            if p in pkgs:
                missing = checkpkgs[p] - pkgs[p]
                overlap = pkgs[p] & checkpkgs[p]
                if missing:
                    bsk_all.remove(p)
                    print("separate subpackage %s: %s" % (p, ', '.join(sorted(missing))))
                if overlap:
                    if p in bsk_all:
                        bsk_all.remove(p)
                    for b in overlap:
                        print("duplicate %s: %s" % (b, ', '.join(sorted(b2chan[b]))))

        for p in sorted(bsk_all):
            print("left %s: %s" % (p, ','.join(sorted(checkpkgs[p]))))

    def do_check_exported(self, subcmd, opts, *channel_names):
        """${cmd_name}: check OBS exported packages

        check if binary rpms are exported in obs

        ${cmd_usage}
        ${cmd_option_list}
        """

        exported = set()

        import osc.conf
        # BUG FIX: osc.core is used below but was never imported here.
        import osc.core

        self.ex._init_osc()

        apiurl = osc.conf.config['apiurl']
        apipath = service_pack('build/openSUSE.org:SUSE:SLE-12:GA/standard/x86_64/_repository',
                               self.options.servicepack)
        u = osc.core.makeurl(apiurl, apipath.split('/'), ['view=binaryversions'])
        r = osc.core.http_GET(u)
        root = ET.parse(r).getroot()
        for node in root.findall('binary'):
            name = node.attrib['name']
            name = name[:-len('.rpm')]
            if name.endswith('-debuginfo') or name.endswith('-debugsource'):
                continue
            if name.endswith('-debuginfo-32bit') or name.endswith('-debugsource-32bit'):
                continue
            exported.add(name)

        channels = self.filter_channels(channel_names)

        b2chan = dict()  # binpkg -> set(channel)

        def parse(channels):
            """Collect the non-blacklisted, non-debug binaries per channel."""
            blacklist = re.compile('(' + '|'.join(BLACKLIST + ['update-test-.*']) + ')')
            for channel in channels:
                fn = channel.channel_filename()
                with open(fn, 'rb') as f:
                    root = ET.parse(f).getroot()
                    for binaries in root.findall('binaries'):
                        for node in binaries.findall('binary'):
                            name = node.attrib['name']
                            package = node.attrib['package']
                            if package.startswith('_product:'):
                                continue
                            if name.endswith('-debuginfo') or name.endswith('-debugsource'):
                                continue
                            if name.endswith('-debuginfo-32bit') or name.endswith('-debugsource-32bit'):
                                continue
                            if blacklist.match(name):
                                continue
                            b2chan.setdefault(name, set()).add(channel.name)

        parse(channels)
        needed = set(b2chan.keys())

        for p in sorted(needed - exported):
            print("missing %s: %s" % (p, ', '.join(b2chan[p])))
        for p in sorted(exported - needed):
            print("extra %s" % p)

    def do_update_obs(self, subcmd, opts):
        """${cmd_name}: upload dbfile to rpmlint-backports-data package
        """
        self._upload(DBFILE)

    def _upload(self, dbfile):
        """Upload *dbfile* to the rpmlint-backports-data package in OBS.

        Raises when the file does not exist or no project is configured.
        Honors the global --dry option.
        """
        import osc.core
        if not os.path.exists(dbfile):
            raise Exception("Can't find %s. Run pickle and merge" % dbfile)
        self.logger.info("uploading")
        self.ex._init_osc()
        # NOTE(review): 'rpmling_apiurl' looks like a typo of
        # 'rpmlint_apiurl' but is kept for compatibility with existing
        # config files.
        apiurl = self.config.get('rpmling_apiurl', 'https://api.opensuse.org')
        prj = self.config.get('rpmlint_prj', None)
        if prj is None:
            raise Exception("No project defined for obs update!")
        pkg = 'rpmlint-backports-data'
        u = osc.core.makeurl(apiurl, ['source', prj, pkg, dbfile], {})
        self.logger.info('Preparing to upload %s to %s' % (dbfile, u))
        if self.options.dry:
            self.logger.info('Dry run - upload aborted')
            return
        r = osc.core.http_PUT(u, file=dbfile)
        self.logger.info(r.read())

    @cmdln.option('-n', '--interval', metavar="minutes", type="int", help="periodic interval in minutes")
    @cmdln.option('--git-commit', action="store_true", help="commit changes to git")
    @cmdln.option("-f", "--force", action="store_true", help="force refresh of product repos (use for beta releases)")
    def do_run_bot(self, subcmd, opts):
        """${cmd_name}: update files, upload

        ${cmd_usage}
        ${cmd_option_list}
        """

        class ExTimeout(Exception):
            """raised on timeout"""

        # BUG FIX: Python 2's input() eval()s the typed line; use
        # raw_input() there and input() on Python 3.
        try:
            read_line = raw_input  # noqa: F821 -- Python 2 only
        except NameError:
            read_line = input

        if opts.interval:
            def alarm_called(nr, frame):
                raise ExTimeout()
            signal.signal(signal.SIGALRM, alarm_called)

        while True:
            self.reset()
            start = time.time()
            try:
                channels = sorted(self.channels)
                changed = self._pickle(channels, force=opts.force)
                if changed:
                    if opts.git_commit:
                        try:
                            subprocess.check_output(['git', 'init'])
                            for c in channels:
                                subprocess.check_output(['git', 'add', c.dump_filename()])
                            subprocess.check_output(['git', 'commit', '-m', 'update'])
                        except subprocess.CalledProcessError as e:
                            self.logger.warning("### git ERROR: %s" % e)
                    self.logger.info("merging")
                    # Only re-merge and upload databases of changed archs.
                    dbfiles = self._merge(DBFILE, None, shelve=True, by_arch=True, archs=changed)
                    for dbfile in dbfiles:
                        for ext in ['dat', 'dir']:
                            fn = dbfile + '.' + ext
                            self._upload(fn)
            except Exception as e:
                self.logger.error("### ERROR: %s" % e)
                self.logger.info(traceback.format_exc())

            self.logger.info('====================================================================')
            self.logger.info('Completed in %s seconds' % (time.time() - start))
            self.logger.info('Current memory usage: %.1fM' % memory_usage_psutil())
            self.logger.info('====================================================================')

            if opts.interval:
                self.logger.info("sleeping %d minutes. Press enter to check now ..." % opts.interval)
                signal.alarm(opts.interval * 60)
                try:
                    read_line()
                    # BUG FIX: cancel the still-pending alarm so it cannot
                    # fire in the middle of the next cycle.
                    signal.alarm(0)
                except ExTimeout:
                    pass
                except EOFError:
                    # no tty available, disable alarm and sleep
                    signal.alarm(0)
                    time.sleep(opts.interval * 60)
                continue
            break


if __name__ == "__main__":
    app = Tool()
    try:
        sys.exit(app.main())
    except Exception as e:
        app.logger.error('%s. exiting...' % e, exc_info=True)
        sys.exit(1)

# vim: sw=4 et
