#!/usr/bin/python3
# -*- coding: utf-8 -*-

import argparse
import contextlib
import glob
import gzip
import os
import re
import sys
import tarfile
import zipfile

# Option strings of the two filter flags. The same strings are recorded next to
# each compiled pattern so the filter loop can tell includes from excludes.
_INCLUDE_OPTION = '--include'
_EXCLUDE_OPTION = '--exclude'
# Accepted values of the --delete_input switch.
_ENABLED = 'enabled'
_DISABLED = 'disabled'
# Archive type identifiers: accepted by --in_type / --out_type and returned by
# detect_file_type().
_TAR = 'tar'
_TAR_BZ2 = 'bz2'
_TAR_GZ = 'gz'
_ZIP = 'zip'


def get_binary_buffer(fileobj):
    """Return the binary stream behind ``fileobj``.

    Objects that expose a ``buffer`` attribute (text wrappers such as
    ``sys.stdin`` / ``sys.stdout`` do) yield that attribute; any other object
    is handed back unchanged.
    """
    return getattr(fileobj, 'buffer', fileobj)


def lookup(path):
    """Resolve a glob pattern to the one path it matches.

    Args:
        path: Glob pattern (or plain path), or ``None``.

    Returns:
        The single matching path, or ``None`` when ``path`` is ``None``.

    Raises:
        ValueError: If the pattern matches no file or more than one file.
    """
    if path is None:
        return None
    matches = glob.glob(path)
    if not matches:
        # Listing the working directory helps diagnose a wrong pattern.
        raise ValueError(
            'No file found for {}, current directory: {}'.format(path, os.listdir()))
    if len(matches) > 1:
        raise ValueError('More than one files found for {}: {}'.format(path, matches))
    return matches[0]


@contextlib.contextmanager
def open_or(path, mode, fallback, dirname=None):
    """Context manager yielding an opened file, or a fallback stream.

    Args:
        path: File to open, or ``None`` to use the fallback.
        mode: Mode string passed to ``open``.
        fallback: Zero-argument callable producing the stream to yield when
            ``path`` is ``None``; the yielded object is not closed here.
        dirname: Optional directory joined in front of ``path``.
    """
    if path is None:
        yield fallback()
        return
    target = path if dirname is None else os.path.join(dirname, path)
    with open(target, mode) as handle:
        yield handle


@contextlib.contextmanager
def maybe_compress(fileobj, file_type):
    """Context manager wrapping ``fileobj`` in a gzip writer when required.

    For the gzip type a ``gzip.GzipFile`` writing into ``fileobj`` is yielded
    (with ``mtime=0``, so no timestamp is stored in the gzip header); for the
    plain tar and bz2 types ``fileobj`` itself is yielded.

    Raises:
        ValueError: If ``file_type`` is not one of the output archive types.
    """
    if file_type == _TAR_GZ:
        with gzip.GzipFile(fileobj=fileobj, mode='w', mtime=0) as compressed:
            yield compressed
        return
    if file_type not in (_TAR, _TAR_BZ2):
        raise ValueError('invalid file type {}'.format(file_type))
    yield fileobj


def detect_file_type(path, file_type, for_read=True):
    """Determine the archive type to use for ``path``.

    Args:
        path: Archive path, or ``None`` when stdin/stdout is used.
        file_type: Explicitly requested type; returned as-is when not ``None``.
        for_read: ``True`` when the archive is an input, ``False`` for output.

    Returns:
        One of the archive type identifiers. For reading, compressed tarballs
        map to the plain tar type; for writing, the compression-specific type
        is returned.

    Raises:
        ValueError: If the type cannot be derived from the file name, if no
            path is given without an explicit type, or if a zip output is
            requested.
    """
    if file_type is not None:
        return file_type
    # The parser defines --in_type and --out_type (there is no --type), so
    # point the user at the option that applies to this direction.
    type_option = '--in_type' if for_read else '--out_type'
    if path is None:
        raise ValueError('{} must be specified when {}'.format(
            type_option, 'reading from stdin' if for_read else 'writing to stdout'))
    if path.endswith('.zip'):
        if not for_read:
            raise ValueError('it is not yet supported to create zip archive')
        return _ZIP
    if path.endswith('.tar'):
        return _TAR
    if path.endswith('.tar.gz'):
        return _TAR if for_read else _TAR_GZ
    if path.endswith('.tar.bz2'):
        return _TAR if for_read else _TAR_BZ2
    raise ValueError('cannot detect file type, please specify {}'.format(type_option))


def open_archive(fileobj, file_type):
    """Open ``fileobj`` as a readable archive of the given type.

    Args:
        fileobj: Binary file object holding the archive.
        file_type: Tar or zip archive type identifier.

    Returns:
        A ``tarfile.TarFile`` or ``zipfile.ZipFile`` opened for reading.

    Raises:
        ValueError: If ``file_type`` is neither the tar nor the zip type.
    """
    if file_type == _TAR:
        # Mode 'r' probes the compression by calling tell()/seek() on the file
        # object, which fails on pipes such as a piped stdin. For non-seekable
        # inputs use tarfile's stream mode, which only reads forward.
        is_seekable = getattr(fileobj, 'seekable', None)
        mode = 'r' if is_seekable is None or is_seekable() else 'r|*'
        return tarfile.open(mode=mode, fileobj=fileobj)
    if file_type == _ZIP:
        return zipfile.ZipFile(fileobj, mode='r')
    raise ValueError('invalid file type {}'.format(file_type))


def list_archive(infile, file_type):
    """Iterate over the members of an opened archive.

    Args:
        infile: Archive object returned by ``open_archive``.
        file_type: Tar or zip archive type identifier.

    Yields:
        ``(tar_info, opener)`` pairs, where ``tar_info`` is a
        ``tarfile.TarInfo`` describing the member and ``opener`` is a callable
        that, called without arguments, opens the member's data.

    Raises:
        ValueError: If ``file_type`` is neither the tar nor the zip type.
    """
    if file_type == _TAR:
        for tar_info in infile:
            # Bind the member as a default so each opener keeps its own entry.
            yield tar_info, lambda info=tar_info: infile.extractfile(info)
    elif file_type == _ZIP:
        for zip_info in infile.infolist():
            # Zip marks directories with a trailing slash in the entry name.
            is_dir = zip_info.filename.endswith('/')
            tar_info = tarfile.TarInfo(name=zip_info.filename)
            if is_dir:
                # Without this the directory would be written as a regular file.
                tar_info.type = tarfile.DIRTYPE
            # The upper 16 bits of external_attr hold the Unix st_mode; only the
            # permission bits are written to a tar header.
            mode = (zip_info.external_attr >> 16) & 0o7777
            if not mode:
                # No Unix permissions recorded: use sane defaults instead of
                # emitting members nobody can read.
                mode = 0o755 if is_dir else 0o644
            tar_info.mode = mode
            tar_info.size = zip_info.file_size
            yield tar_info, lambda name=tar_info.name: infile.open(name)
    else:
        raise ValueError('invalid file type {}'.format(file_type))


def output_tarfile_mode(file_type):
    """Return the ``tarfile.open`` mode string for writing ``file_type``.

    Plain tar and gzip outputs use the uncompressed stream mode; bz2 output
    uses tarfile's own bz2 stream mode.

    Raises:
        ValueError: If ``file_type`` is not one of the output archive types.
    """
    if file_type == _TAR_BZ2:
        mode = 'w|bz2'
    elif file_type == _TAR or file_type == _TAR_GZ:
        mode = 'w|'
    else:
        raise ValueError('invalid file type {}'.format(file_type))
    return mode


def parse_args():
    """Parse the command-line arguments of the ``tar_utils`` source service.

    ``--include`` and ``--exclude`` share the ``patterns`` destination: each
    occurrence appends ``(option_string, compiled_regex)`` to it, so the
    resulting list preserves the order in which the filters were given.

    Returns:
        The ``argparse.Namespace`` with ``input``, ``in_type``, ``output``,
        ``out_type``, ``outdir``, ``strip_prefix``, ``delete_input`` and
        ``patterns``.
    """

    class _PatternAction(argparse.Action):
        """Record a filter pattern together with the option that supplied it."""

        def __call__(self, parser, namespace, values, option_string=None):
            # Copy before appending so the list object passed as ``default=``
            # is never mutated (it would otherwise leak patterns into any
            # later parse that reuses the same default).
            patterns = list(getattr(namespace, self.dest, None) or [])
            patterns.append((option_string, re.compile(values)))
            setattr(namespace, self.dest, patterns)

    parser = argparse.ArgumentParser(
        description='Open Build Service source service "tar_utils" that manipulates tar files')
    parser.add_argument('--in', dest='input', type=str, default=None)
    parser.add_argument('--in_type', choices=[_TAR, _ZIP], default=None)
    parser.add_argument('--out', dest='output', type=str, default=None)
    parser.add_argument('--out_type', choices=[_TAR, _TAR_BZ2, _TAR_GZ], default=None)
    parser.add_argument('--outdir', type=str, default='.')
    parser.add_argument('--strip_prefix', type=str, default=None)
    parser.add_argument('--delete_input', choices=[_ENABLED, _DISABLED], default=_DISABLED)
    parser.add_argument(_INCLUDE_OPTION, dest='patterns', default=[], action=_PatternAction)
    parser.add_argument(_EXCLUDE_OPTION, dest='patterns', default=[], action=_PatternAction)
    return parser.parse_args()


def _is_included(name, patterns):
    """Decide whether a member name passes the ordered include/exclude filters.

    The first pattern that fully matches ``name`` decides: an include pattern
    keeps the member, an exclude pattern drops it. When nothing matches, the
    last pattern decides by negation (kept if it was an exclude, dropped if it
    was an include). With no patterns at all every member is kept.
    """
    include = True
    for action, regex in patterns:
        # fullmatch() returns a Match object or None; reduce it to a real bool
        # before comparing, since a Match object never equals True or False.
        matched = regex.fullmatch(name) is not None
        include = matched == (action == _INCLUDE_OPTION)
        if matched:
            break
    return include


def main():
    """Copy the selected members of the input archive into a new tar archive.

    Reads the input archive (a file matched by ``--in`` or stdin), filters its
    members with the include/exclude patterns, optionally strips a leading
    regex prefix from member names, and writes the result to ``--out`` inside
    ``--outdir`` (or to stdout). Members whose name becomes empty after
    stripping are skipped. The input file is removed afterwards when
    ``--delete_input`` is enabled.
    """
    args = parse_args()

    input_path = lookup(args.input)
    input_file_type = detect_file_type(input_path, args.in_type)
    output_file_type = detect_file_type(args.output, args.out_type, for_read=False)

    # Compile before any file is opened so an invalid prefix regex fails
    # without creating or truncating the output file.
    strip_prefix = None
    if args.strip_prefix is not None:
        strip_prefix = re.compile('^' + args.strip_prefix)

    with contextlib.ExitStack() as stack:
        in_file = stack.enter_context(
            open_or(input_path, 'rb', lambda: get_binary_buffer(sys.stdin)))
        input_archive = stack.enter_context(
            contextlib.closing(open_archive(in_file, input_file_type)))
        out_file = stack.enter_context(
            open_or(args.output, 'wb', lambda: get_binary_buffer(sys.stdout), args.outdir))
        maybe_compressed_file = stack.enter_context(
            maybe_compress(out_file, output_file_type))
        output_archive = stack.enter_context(contextlib.closing(tarfile.open(
            mode=output_tarfile_mode(output_file_type), fileobj=maybe_compressed_file)))

        for tar_info, open_member in list_archive(input_archive, input_file_type):
            if not _is_included(tar_info.name, args.patterns):
                continue
            if strip_prefix is not None:
                tar_info.name = strip_prefix.sub('', tar_info.name, count=1)
                # A pax 'path' header would override the rewritten name.
                tar_info.pax_headers.pop('path', None)
            if not tar_info.name:
                continue
            # Only open the payload when there are bytes to copy. Links and
            # directories carry no data, and asking tarfile to open a symlink
            # whose target is not in the archive raises KeyError.
            fileobj = open_member() if tar_info.size else None
            try:
                output_archive.addfile(tar_info, fileobj=fileobj)
            finally:
                if fileobj is not None:
                    fileobj.close()

    # Remove the file that was actually read; args.input may be a glob pattern.
    if input_path is not None and args.delete_input == _ENABLED:
        os.remove(input_path)


# Run the service only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
