#!/usr/bin/python3

# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# in Version 2 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; If not, see <http://www.gnu.org/licenses/>.

"""CLI tool for disk partitioning.

"setup-disk" is a small CLI tool for disk partitioning, with a backend
based on libstorage-ng.

To partition a single device we can do:
  setup-disk -d /dev/sda

Internally "setup-disk" has a default layout for a MicroOS
installation, where four partitions are created: the EFI system
partition (populated by GRUB in a different stage), the root file
system as Btrfs with several subvolumes, a read-only "/usr"
partition, and a "/var" partition with the "nodatacow" mount option.

This default layout is parameterized with a "DEVICE1" name, which is
provided by the "--device" or "-d" command line parameter.


User defined layouts
--------------------

The user can provide their own partition layout in YAML format, using
the parameter "--layout":
  setup-disk -l my_layout.yml

The layout can contain Python template variables, with the format
"${DEVICEx}" for devices and "${DEVICEx_Py}" for partitions.  The user
can indicate multiple devices on the command line by repeating the
"--device" parameter in order.

For example:
  setup-disk -d /dev/sda -d /dev/sdb -l my_layout.yml

This will replace "${DEVICE1}" with "/dev/sda" and "${DEVICE2}" with
"/dev/sdb".  The partition names are then calculated and replaced in
the template, so "${DEVICE1_P1}" will be replaced by "/dev/sda1".
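
The substitution described above can be sketched with the standard
"string.Template" class.  This is only an illustration: the "resolve"
helper, the "LAYOUT" fragment and the "max_partitions" parameter are
hypothetical, and the partition names are built by appending the
partition number to the device name, which is an assumption that
holds for names like "/dev/sda" but not for every device.

```python
import string

# Hypothetical layout fragment using the template variables described above.
LAYOUT = '''
partitions:
  devices:
    ${DEVICE1}:
      partitions:
        - number: 1
          size: rest
          type: linux
filesystems:
  ${DEVICE1_P1}:
    filesystem: ext4
    mountpoint: /
'''


def resolve(devices, layout, max_partitions=4):
    # Assumption: a partition name is the device name followed by the
    # partition number.  The real tool asks the storage library instead.
    mapping = {}
    for index, device in enumerate(devices, start=1):
        mapping[f"DEVICE{index}"] = device
        for number in range(1, max_partitions + 1):
            mapping[f"DEVICE{index}_P{number}"] = f"{device}{number}"
    # substitute() raises KeyError if the layout names an unknown device
    return string.Template(layout).substitute(mapping)


resolved = resolve(["/dev/sda"], LAYOUT)
print(resolved)
```

A layout that references "${DEVICE2}" while only one device is given
would make "substitute" fail, which is why the devices must be passed
in the same order as they are numbered in the layout.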


Layout reference
----------------

"partitions" section
====================

"setup-disk" separates partitioning the devices from providing a file
system, creating volumes or building arrays of disks (WIP). The
advantage is that this usually composes better than other approaches,
and makes it easier to add options that need to work correctly with
the rest of the system.

* "config": Dictionary. Optional.

  Subsection that stores configuration options related to the
  partitioner.

  * "label": String. Optional. Default: "gpt"

    Default partition table type (label) for the devices. The
    possible values are "gpt", "msdos" or "dos". For UEFI systems, we
    need to set it to "gpt". This value will be used for all the
    devices unless it is overridden per device.

* "devices": Dictionary.

  Devices that will be partitioned, keyed by device name. We can
  indicate already present devices, like "/dev/sda" or "/dev/hda", or
  we can use Python template variables like "${DEVICE1}". We can use
  any valid Linux device name, such as those under
  `/dev/disk/by-id/...`, `/dev/disk/by-label/...`,
  `/dev/disk/by-uuid/...` and others.

  For each device we have:

  * "label": String. Optional. Default: "gpt"

    Partition label for the device. The meaning and the possible
    values are identical to those of "label" in the "config" section.

  * "partitions": Array. Optional.

    Partitions inside a device are described with an array. Each
    element of the array is a dictionary that describes a single
    partition.

    * "number": Integer. Optional. Default: "index"

      Expected partition number. Eventually this parameter will be
      really optional, when the partitioner can deduce it from other
      parameters. For now it is better to be explicit about the
      partition number, as this will guarantee that the partition is
      found in the hard disk if present. If it is not set, the number
      will be the current index position in the array (starting at 1).

    * "size": Float or String.

      Size of the partition expressed in IEC units, for example
      "512MiB" or "8GiB". All the units need to match for partitions
      on the same device.

      The last partition can use the string "rest" to indicate that
      this partition will use all the free space available.

    * "type": String.

      A string that indicates what this partition will be used
      for. "setup-disk" recognizes several types:

      * "swap": This partition will be used for SWAP.
      * "linux": Partition used for root, home or any data.
      * "boot": Small partition used for GRUB when in BIOS and "gpt".
      * "efi": EFI partition used by GRUB when UEFI.
      * "lvm": Partition used to build an LVM physical volume.
      * "raid": Partition that will be a component of an array.

Example:

---
partitions:
  config:
    label: gpt
  devices:
    /dev/sda:
      partitions:
        - number: 1
          size: 512MiB
          type: efi
        - number: 2
          size: 8GiB
          type: linux
        - number: 3
          size: rest
          type: linux
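
The "size" values used above are converted to bytes before the
partitions are created.  A minimal sketch of that conversion follows,
assuming plain integers for the IEC multipliers (the tool itself takes
them from the storage library) and using the hypothetical names
"UNITS" and "parse_size".  As in the tool, "rest" is reported as 0,
meaning that all the remaining free space should be used.

```python
import re

# Byte multipliers for IEC units; plain integers stand in for the
# constants provided by the storage library.
UNITS = {"B": 1, "KiB": 2**10, "MiB": 2**20, "GiB": 2**30, "TiB": 2**40}


def parse_size(size):
    # "rest" is signaled as 0: use all the free space available
    if size == "rest":
        return 0

    # An integer followed by an optional unit; no unit means bytes
    match = re.fullmatch("([0-9]+) *([A-Za-z]*)", str(size).strip())
    if not match:
        raise ValueError(f"Size {size} not recognized")

    number, unit = match.groups()
    unit = unit or "B"
    if unit not in UNITS:
        raise ValueError(f"Unit {unit} not recognized")

    return int(number) * UNITS[unit]


print(parse_size("512MiB"))
print(parse_size("8GiB"))
print(parse_size("rest"))
```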


"filesystems" section
=====================

The partitions created in the previous section usually require a file
system. This section simply lists the device name (or a Python
template variable) and the file system (and properties) that will be
applied to it.

* "filesystem". String. Optional. Default: "btrfs"

  File system to apply to the device. Valid values are "reiserfs",
  "ext{2,3,4}", "btrfs", "vfat", "xfs", "jfs", "ntfs", "swap", "nfs",
  "tmpfs", "iso9660", "udf", "f2fs", "exfat" or "bitlocker".

* "mountpoint". String.

  Mount point where the device will be registered in "fstab".

* "mount_options". Array. Optional.

  Array of mount options for the partition.

* "subvolumes". Dictionary.

  For "btrfs" file systems we can specify more details.

  * "prefix". String. Optional. Default: "@"

    Name of the "btrfs" sub-volume under which the rest of the
    sub-volumes will be created. For example, if we set "prefix" as
    "@" and we create a sub-volume named "var", "setup-disk" will
    create it as "@/var".

  * "subvolume". Array.

    Each element of the array is a dictionary that describes a single
    sub-volume.

    * "path". String.

      Path name for the sub-volume.

    * "copy_on_write". Boolean. Optional. Default: "yes"

      Value for the copy-on-write option in "btrfs".

    * "mount_options". Array. Optional.

      Array of mount options for the sub-volume.

Example:

---
filesystems:
  /dev/sda1:
    filesystem: vfat
    mountpoint: /boot/efi
  /dev/sda2:
    filesystem: btrfs
    mountpoint: /
    subvolumes:
      prefix: '@'
      subvolume:
        - path: root
          mount_options: ['x-initrd.mount']
        - path: var
          copy_on_write: no
          mount_options: ['x-initrd.mount']
        - path: srv
        - path: opt
        - path: home
        - path: usr
          mount_options: ['x-initrd.mount']
        - path: boot/grub2/i386-pc
        - path: boot/grub2/x86_64-efi

"""


import argparse
import re
import string
import sys

import storage
import yaml


DEFAULT_LAYOUT = """
---
partitions:
  config:
    label: gpt
  devices:
    ${DEVICE1}:
      partitions:
        - number: 1
          size: 512MiB
          type: efi
        - number: 2
          size: 1GiB
          type: linux
        - number: 3
          size: 6GiB
          type: linux
        - number: 4
          size: rest
          type: linux

filesystems:
  ${DEVICE1_P1}:
    filesystem: vfat
    mountpoint: /boot/efi
  ${DEVICE1_P2}:
    filesystem: btrfs
    mountpoint: /
    subvolumes:
      prefix: '@'
      subvolume:
        - path: root
          mount_options: ['x-initrd.mount']
        - path: srv
        - path: opt
        - path: home
        - path: boot/grub2/i386-pc
        - path: boot/grub2/x86_64-efi
        - path: usr/local
  ${DEVICE1_P3}:
    filesystem: btrfs
    mountpoint: /usr
    mount_options: ['ro','x-initrd.mount']
  ${DEVICE1_P4}:
    filesystem: btrfs
    mountpoint: /var
    mount_options: ['nodatacow', 'x-initrd.mount']
"""

def _log_storage_milestone(message):
    storage.get_logger().write(storage.LogLevel_MILESTONE,
                               "tiu", __file__, sys._getframe(1).f_lineno,
                               "", message)

def _log_storage_error(error_msg):
    storage.get_logger().write(storage.LogLevel_ERROR,
                               "tiu", __file__, sys._getframe(1).f_lineno,
                               "", error_msg)

def _raise_exception(error_msg):
    _log_storage_error(error_msg)
    raise Exception(error_msg)

class MyCommitCallbacks(storage.CommitCallbacks):
    def message(self, message):
        _log_storage_milestone(message)
        print(f"message '{message}'")

    def error(self, message, what):
        _log_storage_error(f"'{message}' '{what}'")
        print(f"error '{message}' '{what}'")
        return False

def _resolve_devices(devices, layout):
    """Resolve device and partition names in a template."""

    # Create an environment where to find the devices and deduce
    # partition names
    environment = storage.Environment(True)
    strge = storage.Storage(environment)
    strge.probe()
    probed = strge.get_probed()

    mapping = {}
    for index, device in enumerate(devices, start=1):
        partitionable = storage.Partitionable.find_by_name(probed, device)
        mapping[f"DEVICE{index}"] = device
        # TODO - Better range criteria!
        mapping.update(
            {
                f"DEVICE{index}_P{i}": partitionable.partition_name(i)
                for i in range(1, 16)
            }
        )

    try:
        layout = string.Template(layout).substitute(mapping)
    except KeyError as e:
        _raise_exception(f"Device {e} cannot be mapped in the layout")

    return layout


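def _size(size):
    """Parse a size expressed in IEC units (e.g. "512MiB") into bytes."""
    units = {
        "B": storage.B,
        "KiB": storage.KiB,
        "MiB": storage.MiB,
        "GiB": storage.GiB,
        "TiB": storage.TiB,
        "PiB": storage.PiB,
        "EiB": storage.EiB,
    }

    # "rest" is signaled as 0: the partition takes the full free region
    if size == "rest":
        return 0

    # YAML can deliver plain numbers, so normalize to a string.  The
    # unit is optional and defaults to bytes.
    match = re.fullmatch(r"(\d+(?:\.\d+)?)\s*([A-Za-z]*)", str(size).strip())
    if not match:
        _raise_exception(f"Size {size} not recognized")

    number, unit = match.groups()
    unit = unit or "B"
    if unit not in units:
        _raise_exception(f"Unit {unit} not recognized")

    value = float(number) if "." in number else int(number)
    return int(value * units[unit])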


def _prepare_disk(devicegraph, name):
    """Search for a disk in a device map and clean the partitions."""
    # TODO check for thing that cannot be deleted, e.g. multipath
    disk = storage.Disk.find_by_name(devicegraph, name)
    if not disk:
        _raise_exception(f"Device {name} not found!")

    disk.remove_descendants()
    return disk


def _partition_table(disk, label):
    """Generate a partition table"""
    labels = {
        "msdos": (storage.PtType_MSDOS, storage.to_msdos),
        "dos": (storage.PtType_MSDOS, storage.to_msdos),
        "gpt": (storage.PtType_GPT, storage.to_gpt),
    }

    if label not in labels:
        _raise_exception(f"Partition label {label} not recognized")

    part_type, convert = labels[label]
    part_table = convert(disk.create_partition_table(part_type))
    return part_table


def _create_partition(table, size):
    """Create a partition in a partition table."""
    for slot in table.get_unused_partition_slots():
        if not slot.primary_possible:
            continue

        region = slot.region

        if region.to_bytes(region.get_length()) < size:
            continue

        # If size is zero, we allocate the full region
        if size:
            region.set_length(size // region.get_block_size())
        region = table.align(region)
        partition = table.create_partition(
            slot.name, region, storage.PartitionType_PRIMARY
        )
        return partition

    _raise_exception("No suitable partition slot found!")


def _partition(disk, label, partitions):
    """Partition a disk following a layout."""
    ids = {
        "swap": [storage.ID_SWAP],
        "linux": [storage.ID_LINUX],
        "lvm": [storage.ID_LVM],
        "raid": [storage.ID_RAID],
        "efi": [storage.ID_BIOS_BOOT, storage.ID_ESP],
        "boot": [storage.ID_BIOS_BOOT],
    }

    table = _partition_table(disk, label)

    for index, partition in enumerate(partitions, start=1):
        number = partition.get("number", index)
        if index != number:
            _raise_exception(f"Expected partition number {index} instead of {number}")

        if "size" not in partition:
            _raise_exception(f"Partition {number} has missing 'size'")
        part = _create_partition(table, _size(partition["size"]))

        type_ = partition.get("type", "linux")
        if type_ not in ids:
            _raise_exception(f"Type {type_} not valid for partition {number}")

        for id_ in ids[type_]:
            part.set_id(id_)


def _extend_mount_options(mount_point, mount_options):
    """Extend the mount options of a mount point."""
    current_mount_options = mount_point.get_mount_options()
    # current_mount_options is a VectorString, so cannot be directly
    # extended
    for mount_option in mount_options:
        current_mount_options.append(mount_option)
    mount_point.set_mount_options(current_mount_options)


def _filesystem(partition, filesystem):
    """Apply a filesystem in a partition."""
    filesystems = {
        "reiserfs": (storage.FsType_REISERFS, storage.to_reiserfs),
        "ext2": (storage.FsType_EXT2, storage.to_ext2),
        "ext3": (storage.FsType_EXT3, storage.to_ext3),
        "ext4": (storage.FsType_EXT4, storage.to_ext4),
        "btrfs": (storage.FsType_BTRFS, storage.to_btrfs),
        "vfat": (storage.FsType_VFAT, storage.to_vfat),
        "xfs": (storage.FsType_XFS, storage.to_xfs),
        "jfs": (storage.FsType_JFS, storage.to_jfs),
        # "hfs": (storage.FsType_HFS, storage.to_hfs),
        "ntfs": (storage.FsType_NTFS, storage.to_ntfs),
        "swap": (storage.FsType_SWAP, storage.to_swap),
        # "hfsplus": (storage.FsType_HFSPLUS, storage.to_hfsplus),
        "nfs": (storage.FsType_NFS, storage.to_nfs),
        # "nfs4": (storage.FsType_NFS4, storage.to_nfs4),
        "tmpfs": (storage.FsType_TMPFS, storage.to_tmpfs),
        "iso9660": (storage.FsType_ISO9660, storage.to_iso9660),
        "udf": (storage.FsType_UDF, storage.to_udf),
        # "nilfs2": (storage.FsType_NILFS2, storage.to_nilfs2),
        # "minix": (storage.FsType_MINIX, storage.to_minix),
        # "ntfs3g": (storage.FsType_NTFS3G, storage.to_ntfs3g),
        "f2fs": (storage.FsType_F2FS, storage.to_f2fs),
        "exfat": (storage.FsType_EXFAT, storage.to_exfat),
        "bitlocker": (storage.FsType_BITLOCKER, storage.to_bitlocker),
        # "vboxsf": (storage.FsType_VBOXSF, storage.to_vboxsf),
    }

    fs = filesystem.get("filesystem", "btrfs")
    if fs not in filesystems:
        _raise_exception(f"Filesystem {fs} not recognized")

    fs_type, convert = filesystems[fs]
    fs_ = convert(partition.create_filesystem(fs_type))

    if "mountpoint" in filesystem:
        mount_point = fs_.create_mount_point(filesystem["mountpoint"])
        _extend_mount_options(mount_point, filesystem.get("mount_options", []))

    if fs == "btrfs" and "subvolumes" in filesystem:
        subvolumes = filesystem["subvolumes"]
        prefix = subvolumes.get("prefix", "@")
        top_level = fs_.get_top_level_btrfs_subvolume()
        at_subvolume = top_level.create_btrfs_subvolume(prefix)

        # The following code relies on the snapper installation helper
        # to create further subvolumes, set the default subvolume and
        # other stuff. Setup is confusing. YaST does it this way.
        fs_.set_default_btrfs_subvolume(at_subvolume)
        fs_.set_configure_snapper(True)

        for subvolume in subvolumes.get("subvolume", []):
            # Honor the configured prefix instead of a hard-coded "@"
            subv = at_subvolume.create_btrfs_subvolume(f"{prefix}/{subvolume['path']}")
            mount_point = subv.create_mount_point("/" + subvolume["path"])
            _extend_mount_options(mount_point, subvolume.get("mount_options", []))
            subv.set_nocow(not subvolume.get("copy_on_write", True))


def setup_disks(layout):
    strge = storage.Storage(storage.Environment(False))
    strge.set_rootprefix("/mnt")
    strge.probe()

    staging = strge.get_staging()

    # Default label, in case there is no per-device redefinition.
    # "config" lives inside the "partitions" section of the layout.
    label_default = (
        layout.get("partitions", {}).get("config", {}).get("label", "gpt")
    )

    try:
        devices = layout["partitions"]["devices"]
    except KeyError:
        _raise_exception("Section 'devices' not found in the layout")

    if not devices:
        _raise_exception("Section 'devices' is empty")

    for device, device_config in devices.items():
        label = device_config.get("label", label_default)
        # "partitions" is optional; without it only the table is created
        partitions = device_config.get("partitions", [])

        disk = _prepare_disk(staging, device)
        _partition(disk, label, partitions)

    try:
        filesystems = layout["filesystems"]
    except KeyError:
        _raise_exception("Section 'filesystems' not found in the layout")

    if not filesystems:
        _raise_exception("Section 'filesystems' is empty")

    for partition, filesystem in filesystems.items():
        part = storage.Partition.find_by_name(staging, partition)
        _filesystem(part, filesystem)

    commit_options = storage.CommitOptions(True)
    my_commit_callbacks = MyCommitCallbacks()

    try:
        strge.calculate_actiongraph()
        strge.commit(commit_options, my_commit_callbacks)
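    except Exception as exception:
        # Only the storage library exceptions provide what()
        message = exception.what() if hasattr(exception, "what") else exception
        print(message)
        sys.exit(1)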


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TIU disk partitioner.")
    parser.add_argument(
        "-d",
        "--device",
        action="append",
        default=[],
        help="device to partition, repeat for each device",
        metavar="DEVICE",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="/var/log/tiu-setup-disk.log",
        help="name of log file",
        metavar="FILENAME",
    )
    parser.add_argument(
        "-l",
        "--layout",
        help="YAML document with partitioning layout",
        metavar="YAML",
    )
    args = parser.parse_args()

    storage.set_logger(storage.get_logfile_logger(args.output))

    if args.layout:
        with open(args.layout, "r") as f:
            layout = f.read()
    else:
        layout = DEFAULT_LAYOUT

    try:
        layout = _resolve_devices(args.device, layout)
        setup_disks(yaml.safe_load(layout))
    except Exception as e:
        if hasattr(e, "what"):
            print(f"ERROR: {e.what()}")
        else:
            print(f"ERROR: {e}")
        sys.exit(1)
