#!/usr/bin/env python3
#
# Copyright 2018-2021 Quobyte Inc. All rights reserved.
#
import argparse
import datetime
import getpass
import json
import os
import os.path
import stat
import subprocess
import sys

# Global verbosity flag. main() sets it to True when -v is given; it is read by
# the recursive replication-factor walk and by the top-level exception handler
# (which re-raises instead of printing a one-line error when it is set).
VERBOSE=False


class PrettyTable(object):
    """ Generalized output generator.

        Attributes:
            data(list): rows to print, each a dict. For the csv and table
                styles out() rewrites cell values in place (lists joined with
                ';', integer "_bytes" / "timestamp" cells formatted).
            columns(list): which columns to print, first used as key to sort.
                Each column is a key string or a list of dict keys / list
                indices describing a path into nested row data.
            header(list): names of the columns; the column specs are used
                when omitted.
            style(str): json, csv, or table
            units(str): unit system passed to human_readable_bytes.
            precision(int): decimal places passed to human_readable_bytes.
    """

    # Rendered for cells whose value is missing, None or empty.
    _MISSING = "-"
    # "table" style columns are padded to at most this many characters.
    _MAX_COLUMN_WIDTH = 45

    def __init__(self, data, columns,
                 header=None, style="table", units="metric", precision=2):
        if style not in ("json", "table", "csv"):
            # Explicit check: an assert would be stripped under python -O.
            raise ValueError("Unsupported output style: %s" % style)
        self._header = header
        self._columns = columns
        self._style = style
        self._data = data
        self._format_string = None
        self._units = units
        self._precision = precision

    @staticmethod
    def _is_blank(value):
        """Return True if value carries nothing worth printing.

        None and empty strings/containers are blank; numbers, including 0,
        are real values and must be shown.
        """
        if value is None:
            return True
        if isinstance(value, (int, float)):
            return False
        return not value

    @staticmethod
    def _get_nested(dictionary, args):
        """Look up a possibly nested cell value.

        args is a single key, or a sequence of dict keys / list indices that
        are applied left to right. Returns "-" when a step is missing or the
        final value is blank.
        """
        if not dictionary:
            return PrettyTable._MISSING
        if isinstance(args, str):
            args = [args]
        step = args[0]
        if isinstance(step, int):
            if len(dictionary) < step + 1:
                return PrettyTable._MISSING
            value = dictionary[step]
        else:
            value = dictionary.get(step, None)
        if len(args) == 1:
            if PrettyTable._is_blank(value):
                return PrettyTable._MISSING
            return value
        return PrettyTable._get_nested(value, args[1:])

    @staticmethod
    def _set_nested(dictionary, args, value):
        """Store value at the (possibly nested) location described by args."""
        if isinstance(args, str):
            dictionary[args] = value
            return
        if len(args) == 1:
            dictionary[args[0]] = value
        else:
            PrettyTable._set_nested(dictionary[args[0]], args[1:], value)

    @staticmethod
    def _column_name(keys):
        """Flatten a column spec into one string (used for suffix matching).

        Nested specs may contain int indices, so each part is str()-ed.
        """
        if isinstance(keys, str):
            return keys
        return ''.join(str(key) for key in keys)

    def _process_data(self):
        """Convert cell values in place into their printable form."""
        for row in self._data:
            for keys in self._columns:
                col = self._get_nested(row, keys)
                name = self._column_name(keys)
                # bool is a subclass of int but is not a byte count/timestamp.
                is_int = isinstance(col, int) and not isinstance(col, bool)
                if isinstance(col, list):
                    self._set_nested(row, keys, ';'.join(str(c) for c in col))
                elif is_int and "_bytes" in name:
                    self._set_nested(
                        row, keys, human_readable_bytes(
                            col, self._units, self._precision))
                elif is_int and 'timestamp' in name:
                    # Timestamps are milliseconds since the epoch (local time).
                    self._set_nested(row, keys,
                        datetime.datetime.fromtimestamp(
                            float(col) / float(1000)).strftime('%Y-%m-%d %H:%M:%S'))

    def _labels(self):
        """Return the header labels as strings (column specs if no header)."""
        return [str(label) for label in (self._header or self._columns)]

    def _get_format_string(self):
        """Build the str.format template for one output line."""
        if self._style == "table":
            labels = self._labels()
            parts = []
            for i, keys in enumerate(self._columns):
                cells = [str(self._get_nested(row, keys)) for row in self._data]
                cells.append(labels[i])
                width = min(len(max(cells, key=len)) + 1,
                            self._MAX_COLUMN_WIDTH)
                parts.append('{%i:<%s} ' % (i, width))
            self._format_string = ''.join(parts)
        elif self._style == "csv":
            self._format_string = ','.join(
                "{%i}" % i for i in range(len(self._columns)))

    def _sort_key(self, row):
        """Sort key on the first column.

        Numbers sort numerically before all other values, and other values
        compare as strings. This avoids TypeError on mixed or missing values.
        """
        value = self._get_nested(row, self._columns[0])
        if isinstance(value, (int, float)):
            return (0, value)
        return (1, str(value))

    def rename_columns(self, header):
        self._header = header

    def modify_column(self, column_name, func):
        for row in self._data:
            row[column_name] = func(row[column_name])

    def append_column(self, key, name=None):
        self._columns.append(key)
        if self._header:
            self._header.append(name or key)

    def out(self):
        """Render the table.

        Returns the raw data for the json style, otherwise a newline-joined
        string with a header line followed by one line per row, sorted by
        the first column.
        """
        if self._style == "json":
            return self._data
        self._process_data()
        self._get_format_string()
        lines = [self._format_string.format(*self._labels())]
        for row in sorted(self._data, key=self._sort_key):
            # str() so every cell honours the width spec (lists, dicts and
            # bools do not format like text otherwise).
            lines.append(self._format_string.format(
                *[str(self._get_nested(row, keys)) for keys in self._columns]))
        return "\n".join(lines)


def human_readable_bytes(num, units, decimal_places=2):
    """Format a byte count for display.

    Args:
        num: number of bytes.
        units: "bytes" prints the raw count; "metric" scales by 1000 with
            suffixes KB, MB, ...; any other value scales by 1024 with IEC
            suffixes KiB, MiB, ...
        decimal_places: digits after the decimal point for scaled values.

    Returns:
        The formatted string. Values below one kilo-unit are printed as
        "<num> bytes"; values of 1000 EB (1024 EiB) or more stay expressed
        in EB (EiB) instead of falling off the end of the unit list.
    """
    factor = 1000.0 if units == 'metric' else 1024.0
    if units == "bytes" or num < factor:
        return f'{num} bytes'

    suffixes = ('KB', 'MB', 'GB', 'TB', 'PB', 'EB')
    for suffix in suffixes:
        num /= factor
        if num < factor or suffix == suffixes[-1]:
            if units != 'metric':
                suffix = suffix.replace('B', 'iB')
            return f'{num:.{decimal_places}f} {suffix}'

def get_xattr(path, name):
    """Return the value of extended attribute `name` on `path` as text.

    getfattr is run without following symlinks. When it fails, an Exception
    is raised whose message names the most likely cause: the path does not
    exist, it is not readable, or it is not part of a Quobyte volume.
    """
    command = ('getfattr', '--only-value', '--absolute-names',
               '--no-dereference', '-n', name, path)
    completed = subprocess.run(command,
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    value = completed.stdout.decode("utf-8")
    if completed.returncode == 0:
        return value
    if not os.path.exists(path):
        raise Exception(path + " does not exist")
    if not os.access(path, os.R_OK):
        raise Exception("No permission to read " + path)
    raise Exception(
        "Cannot read xattr " + name + " for path " + path +
        ". Make sure this is a file or directory of a Quobyte volume.")

def set_xattr(path, name, value):
    """Set extended attribute `name` on `path` to `value` via setfattr.

    Symlinks are not followed. Returns setfattr's stdout as text on success.
    On failure, raises Exception for a missing or non-writable path, and
    otherwise with setfattr's stderr as the message.
    """
    command = ('setfattr', '--no-dereference', '-n', name, '-v', value, path)
    completed = subprocess.run(command,
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if completed.returncode == 0:
        return completed.stdout.decode("utf-8")
    if not os.path.exists(path):
        raise Exception(path + " does not exist")
    if not os.access(path, os.W_OK):
        raise Exception("No permission to write " + path)
    raise Exception(completed.stderr.decode("utf-8"))

def set_access_key(path, scope, key_id, secret):
    """Register an access key with the Quobyte client mounted at path.

    The path is first normalized to its mount point. If no secret is given,
    it is read interactively. The key is written as a JSON document to the
    quobyte.access_key xattr.

    Args:
        path: a path on (or the root of) the Quobyte mount.
        scope: "user" or "client", stored as access_key_scope.
        key_id: the access key id.
        secret: the access key secret, or a falsy value to prompt for it.

    Returns:
        0 on success, 1 if the xattr could not be set. Exits the process
        with status 1 if the secret prompt is interrupted or hits EOF.
    """
    path = normalize_mount_point(path)
    if not secret:
        try:
            secret = getpass.getpass("Secret for access key %s: " % (key_id))
        except (KeyboardInterrupt, EOFError):
            print("interrupted, exiting.")
            sys.exit(1)
    # json.dumps escapes quotes and backslashes. Building the document with
    # string formatting produced invalid JSON for such ids or secrets.
    value = json.dumps({
        "access_key_id": key_id,
        "access_key_secret": secret,
        "access_key_scope": scope,
    })
    # NOTE(review): set_xattr passes the value as a setfattr command-line
    # argument, so the secret is briefly visible in the process list.
    try:
        set_xattr(path, 'quobyte.access_key', value)
    except Exception as e:
        print("Could not set access key for client via path '" + path +
              "': " + str(e) + ". "
              "Make sure this is a Quobyte mount point.")
        return 1
    if scope == 'user':
        print("Registered access key", key_id, "for user", getpass.getuser())
    elif scope == 'client':
        print("Registered access key", key_id, "for all users")
    return 0


def _format_quota_usage(usage_text, usage, limit):
    """Append the usage percentage of limit to usage_text.

    The percentage is omitted when the limit is 0, which would otherwise
    raise ZeroDivisionError.
    """
    limit = int(limit)
    if limit == 0:
        return str(usage_text)
    return str(usage_text) + " (" + str(int(usage) * 100 // limit) + "%)"


def show_quotas(path, units, precision):
    """Print the effective quotas reported by the quobyte.quotas xattr.

    The xattr is parsed as a JSON list of quota objects with the keys
    entity_type, consuming_entity, type, limit and usage. FILE_COUNT limits
    and usages are printed as plain numbers; all other quota types are
    formatted as byte sizes.

    Returns:
        0 on success (including "No effective quotas"), 2 if the quotas
        cannot be read or parsed (the error is printed).
    """
    try:
        quotas = json.loads(get_xattr(path, "quobyte.quotas"))
        if not quotas:
            print("No effective quotas")
            return 0
        output = []
        for quota in quotas:
            row = {
                "consuming_entity":
                    quota["entity_type"] + ": " + quota["consuming_entity"],
                "type": quota["type"],
            }
            if quota["type"] == "FILE_COUNT":
                row["limit"] = quota["limit"]
                usage_text = str(quota["usage"])
            else:
                row["limit"] = human_readable_bytes(
                    quota["limit"], units, precision)
                usage_text = human_readable_bytes(
                    quota["usage"], units, precision)
            row["usage"] = _format_quota_usage(
                usage_text, quota["usage"], quota["limit"])
            output.append(row)
        print_quotas(output, units, precision)
    except Exception as e:
        print("Cannot read quotas: " + str(e))
        return 2
    return 0


def print_quotas(quotas, units, precision):
    """Print quota rows (dicts built by show_quotas) as an aligned table."""
    columns = ["consuming_entity", "type", "limit", "usage"]
    titles = ["Consuming Entity", "Type", "Limit", "Current Usage"]
    rendered = PrettyTable(quotas, columns, titles,
                           units=units, precision=precision).out()
    print(rendered)

def show_file_metadata(path):
    """Print the quobyte.info xattr of path.

    For regular files (checked without following symlinks), the
    quobyte.ec_state xattr is printed too if it can be read. It is optional,
    so failing to read it is ignored.

    Returns:
        0 on success, 2 if quobyte.info cannot be read (the error is printed).
    """
    try:
        print(get_xattr(path, "quobyte.info"))
    except Exception as error:
        print(str(error))
        return 2

    try:
        is_regular_file = stat.S_ISREG(os.lstat(path).st_mode)
    except FileNotFoundError:
        return 0

    if is_regular_file:
        try:
            print(get_xattr(path, "quobyte.ec_state"))
        except Exception:
            pass  # ec_state is optional
    return 0

def show_file_versions(path):
    """Print the quobyte.versions xattr of path.

    Returns 0 on success, 2 if the xattr cannot be read (the error is printed).
    """
    exit_code = 0
    try:
        print(get_xattr(path, "quobyte.versions"))
    except Exception as error:
        print(str(error))
        exit_code = 2
    return exit_code

def show_file_locks(path):
    """Print the quobyte.locks xattr of path.

    Returns 0 on success (like the other show_* commands), or 2 if the xattr
    cannot be read (the error is printed).
    """
    try:
        lock_string = get_xattr(path, "quobyte.locks")
        print(lock_string)
    except Exception as e:
        print(str(e))
        return 2
    return 0

def show_whoami(path):
    """Print the credentials used for Quobyte file system access control.

    Reads the quobyte.whoami xattr of path and parses it as JSON with the
    keys username, groups, tenants, privileged_access, foreign_tenant_access
    and authentication_source.

    Returns:
        0 on success (like the other show_* commands), 2 if the xattr cannot
        be read or lacks an expected key (the error is printed).
    """
    try:
        whoami_string = get_xattr(path, "quobyte.whoami")
        whoami = json.loads(whoami_string)
        print("Effective credentials for Quobyte file system access control")
        print("Username:", whoami['username'])
        print("Groups:", ','.join(whoami['groups']))
        print("Member of tenants:", ','.join(whoami['tenants']))
        print("Privileged access:",
              "yes (like root)"
                  if whoami['privileged_access']
                  else "no")
        print("Access to foreign tenant:",
              "yes (user is not member of volume's tenant; special ACLs needed for access)"
                  if whoami['foreign_tenant_access']
                  else "no (normal POSIX access control rules apply)")
        print("Source:", whoami['authentication_source'])
    except Exception as e:
        print(str(e))
        return 2
    return 0

def set_replication_factor(factor, path, recursive):
    """Set the target replication factor of path.

    Without recursive, path is handled as a single file (verbose output on).
    With recursive, path must be a directory. Every non-symlink entry below
    it is processed: subdirectories are descended into and symlinks are
    skipped. A failure on one file is reported and processing continues with
    the remaining entries.
    """
    if not recursive:
        try_set_file_replication_factor(factor, path, True)
        return
    try:
        entries = os.listdir(path)
    except Exception as e:
        print("Cannot list entries of", path, str(e))
        return
    for entry in entries:
        full_path = os.path.join(path, entry)
        if os.path.islink(full_path):
            if VERBOSE:
                print(full_path, "is a symlink")
        elif os.path.isdir(full_path):
            set_replication_factor(factor, full_path, recursive)
        else:
            # Report per-file failures without aborting the rest of the
            # directory (previously mis-reported as a listing failure).
            try:
                try_set_file_replication_factor(factor, full_path, VERBOSE)
            except Exception as e:
                print("Cannot set replication factor of", full_path, str(e))

def try_set_file_replication_factor(factor, path, verbose):
    """Set quobyte.target_replication_factor on the regular file at path.

    Symlinks and directories are skipped with a message. Files whose
    quobyte.info mentions REED_SOLOMON coding or "mixed", or already reports
    the requested replication factor, are left unchanged.

    Raises:
        Exception: if an xattr cannot be read or written, or if quobyte.info
            reports a coding method other than NONE.
    """
    if os.path.islink(path):
        print(path, "is a symlink")
        return
    if os.path.isdir(path):
        print(path, "is a directory")
        return
    info_string = get_xattr(path, "quobyte.info")
    if ("coding_method: REED_SOLOMON" in info_string or
            "mixed" in info_string):
        if verbose:
            print("Nothing to do,", path, "is an erasure-coded file")
        return
    # NOTE(review): plain substring test on the quobyte.info text; assumes a
    # "replication_factor: N" line format — confirm.
    if "replication_factor: " + str(factor) in info_string:
        if verbose:
            print("Nothing to do,", path, "replication factor already correct")
        return
    # Only files with no coding method or coding method NONE can be changed.
    # Explicit check: an assert would be stripped under python -O.
    if ("coding_method:" in info_string and
            "coding_method: NONE" not in info_string):
        raise Exception("Unexpected coding method for " + path +
                        ", not changing its replication factor")
    set_xattr(path, "quobyte.target_replication_factor", str(factor))
    if verbose:
        print(path, "set to replication factor", factor)

def normalize_mount_point(given_path, mounts_file="/proc/mounts"):
    """Return the mount point that contains given_path.

    given_path is made absolute and matched against the mount table in
    mounts_file (the /proc/mounts format). The longest mount point equal to
    given_path, or a parent directory of it, wins. If that mount point
    differs from given_path, a notice is printed, plus a warning if its
    filesystem type is not "fuse.quobyte", and the mount point is returned.
    Otherwise given_path is returned.
    """
    import re

    def unescape(field):
        # The mount table encodes space, tab, newline and backslash in paths
        # as three-digit octal escapes such as "\040".
        return re.sub(r'\\([0-7]{3})',
                      lambda m: chr(int(m.group(1), 8)), field)

    def is_within(path, mount):
        # Component-aware prefix test: "/mnt/q" must not contain "/mnt/q2".
        return path == mount or path.startswith(mount.rstrip('/') + '/')

    given_path = os.path.abspath(given_path)
    mount_point = ''
    mount_point_type = ''
    with open(mounts_file) as f:
        for line in f:
            fields = line.split()
            if len(fields) < 3:
                continue  # blank or malformed line
            path = unescape(fields[1])
            mount_type = fields[2]
            # ">=" so a later entry for the same path (mounted on top of an
            # earlier one) wins.
            if is_within(given_path, path) and len(path) >= len(mount_point):
                mount_point = path
                mount_point_type = mount_type

    if mount_point and mount_point != given_path:
        print("Given path", given_path, "seems actually to be part of mount point", mount_point,
              ". Using", mount_point, "instead.")
        if mount_point_type.strip() != "fuse.quobyte":
            print(mount_point, "may not be a Quobyte mount point, found '" + mount_point_type + "'.")
        return mount_point
    return given_path

def main():
    """Parse the command line and dispatch to the selected sub-command.

    Returns:
        The sub-command's return value, used as the process exit code, or 2
        (after printing usage help) when no command is given.
    """
    global VERBOSE
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--unit",
        help="Output bytes unconverted or in metric (1000) units (GB, TB) or IEC (1024)"
             " units (GiB, TiB). Default is metric.",
        choices=['metric', 'iec', 'bytes'],
        default="metric")
    parser.add_argument(
        "--precision",
        metavar="<number>",
        type=int,  # reject non-numeric values at parse time
        help="Number of decimal places to output converted bytes. Default is 2.",
        default=2)
    parser.add_argument(
        "-v", action="store_true",
        help="verbose")
    subparsers = parser.add_subparsers(
        dest='command',
        title='commands',
        help='additional help')

    quota_parser = subparsers.add_parser(
        "quota", help="show effective quotas")
    quota_parser.add_argument(
        "quota_mount_point",
        metavar="<mount point>",
        help="optional path to Quobyte mount point",
        nargs='?')

    info_parser = subparsers.add_parser(
        "info", help="show file info")
    info_parser.add_argument(
        "info_path",
        metavar="<path>",
        help="path to file or directory")

    version_parser = subparsers.add_parser(
        "version", help="show list of file versions")
    version_parser.add_argument(
        "version_path",
        metavar="<path>",
        help="path to file or directory")

    locks_parser = subparsers.add_parser(
        "locks", help="show current file locks")
    locks_parser.add_argument(
        "locks_path",
        metavar="<path>",
        help="path to file or directory")

    login_parser = subparsers.add_parser(
        "set-access-key",
        help="set access key")
    login_parser.add_argument(
        "key_id",
        metavar="<access key id>",
        help="Access key id")
    login_parser.add_argument(
        "mount_point",
        metavar="<mount point>",
        help="Quobyte mount point",
        default=".",
        nargs='?')
    login_parser.add_argument(
        "--secret",
        metavar="<access key secret>",
        help="Access key secret")
    login_parser.add_argument(
        "--scope",
        choices=["user", "client"],
        default="user",
        help="Access key scope: 'user' (current user only) or 'client' "
             "(all users)")

    whoami_parser = subparsers.add_parser(
        "whoami",
        help=("show own user credentials as effective " +
            "for Quobyte file system access control"))
    whoami_parser.add_argument(
        "path",
        metavar="<path>", default="", nargs='?',
        help="path to directory")

    setreplfactor_parser = subparsers.add_parser(
        "set-replication-factor",
        help="set replication factor of replicated files")
    setreplfactor_parser.add_argument(
        "-r", action="store_true",
        help="recursive")
    setreplfactor_parser.add_argument(
        "-t", help="Target replication factor",
        choices=['1','3'], default='3')
    setreplfactor_parser.add_argument(
        "path",
        metavar="<path>", default="", nargs=1,
        help="path to file or directory")

    args = parser.parse_args()

    if args.v:
        VERBOSE=True

    if args.command == 'quota':
        if args.quota_mount_point is None:
            return show_quotas(os.getcwd(), args.unit, args.precision)
        else:
            return show_quotas(args.quota_mount_point, args.unit, args.precision)
    elif args.command == 'info':
        return show_file_metadata(args.info_path)
    elif args.command == 'version':
        return show_file_versions(args.version_path)
    elif args.command == 'locks':
        return show_file_locks(args.locks_path)
    elif args.command == 'whoami':
        return show_whoami(args.path if args.path else os.getcwd())
    elif args.command == 'set-replication-factor':
        return set_replication_factor(args.t, args.path[0], args.r)
    elif args.command == 'set-access-key':
        return set_access_key(args.mount_point, args.scope, args.key_id, args.secret)
    # No (known) command given: show usage instead of exiting silently.
    parser.print_help()
    return 2


if __name__ == "__main__":
    try:
        sys.exit(main())
    except Exception as e:
        if VERBOSE:
            raise
        # Report briefly, then exit non-zero so scripts can detect the
        # failure (falling through here would exit with status 0).
        print(str(e) + ", terminating")
        sys.exit(1)
