#!/usr/bin/env python3
#
# Copyright 2018-2025 Quobyte Inc. All rights reserved.
#
import argparse
import datetime
import errno
import getpass
import json
import os
import os.path
import stat
import sys

# Module-wide verbosity flag. main() sets it to True when -v is given; the
# retention and replication-factor commands read it to decide whether to
# print each processed path.
VERBOSE=False


class PrettyTable:
    """ Generalized output generator.

        Renders a list of row dictionaries as an aligned text table or as
        CSV lines; for the "json" style the rows are handed back unchanged.

        A column is addressed either by a plain string key or by a sequence
        of keys/indexes describing a path into nested dictionaries and lists.

        Attributes:
            data(list): data to print (rows are reformatted in place by out())
            columns(list): which columns to print, first used as key to sort
            header(list): names of the columns
            style(str): json, csv, or table
            units(str): unit system handed to human_readable_bytes
            precision(int): decimal places handed to human_readable_bytes
    """

    def __init__(self, data, columns,
                 header=None, style="table", units="metric", precision=2):
        self._header = header
        self._columns = columns
        self._style = style
        assert self._style in ["json", "table", "csv"]
        self._data = data
        self._format_string = None
        self._units = units
        self._precision = precision

    @staticmethod
    def _get_nested(dictionary, args):
        """Look up `args` (a key or a path of keys/indexes) in `dictionary`.

        Returns "-" whenever the container or the resolved value is falsy
        (missing, None, 0, empty string, empty container) or a list index is
        out of range.
        """
        if not dictionary:
            return "-"
        if isinstance(args, str):
            return dictionary.get(args) or "-"
        if isinstance(args[0], int):
            # Integer path elements index into a list.
            if len(dictionary) < args[0] + 1:
                return "-"
            else:
                value = dictionary[args[0]]
        else:
            value = dictionary.get(args[0], None)
        if len(args) == 1:
            if value:
                return value
            else:
                return "-"
        else:
            return PrettyTable._get_nested(value, args[1:])

    @staticmethod
    def _set_nested(dictionary, args, value):
        """Store `value` at `args` (a key or a path of keys/indexes)."""
        if isinstance(args, str):
            dictionary[args] = value
            return
        if len(args) == 1:
            dictionary[args[0]] = value
        else:
            PrettyTable._set_nested(dictionary[args[0]], args[1:], value)

    def _process_data(self):
        """Convert cell values to their printable form, modifying rows in place.

        Lists are joined with ';', integers in columns whose key contains
        "_bytes" are converted with human_readable_bytes, and integers in
        columns whose key contains "timestamp" are treated as milliseconds
        since the epoch and rendered as local date and time.
        """
        for row in self._data:
            for keys in self._columns:
                col = self._get_nested(row, keys)
                if isinstance(col, list):
                    self._set_nested(row, keys, ';'.join(col))
                elif isinstance(col, int) and "_bytes" in ''.join(keys):
                    self._set_nested(
                        row, keys, human_readable_bytes(
                            col, self._units, self._precision))
                elif isinstance(col, int) and 'timestamp' in ''.join(keys):
                    self._set_nested(row, keys,
                        datetime.datetime.fromtimestamp(
                            float(col) / float(1000)).strftime('%Y-%m-%d %H:%M:%S'))

    def _get_format_string(self):
        """Build the str.format template used for the header and every row."""
        if self._style == "table":
            self._format_string = ""
            for i, keys in enumerate(self._columns):
                target = [str(self._get_nested(row, keys)) for row in self._data]
                if self._header:
                    target.append(self._header[i])
                else:
                    target.append(self._columns[i])
                # Column width is the longest cell plus one, capped at 45;
                # longer cells are not truncated, they just break alignment.
                len_max = len(max(target, key=len)) + 1
                len_max = min(len_max, 45)
                self._format_string += '{%i:<%s} ' % (i, len_max)
        elif self._style == "csv":
            self._format_string = \
                ','.join(["{%i}" % i for i, key in enumerate(self._columns)])

    def rename_columns(self, header):
        """Replace the column titles."""
        self._header = header

    def modify_column(self, column_name, func):
        """Apply `func` to the top-level value `column_name` of every row."""
        for row in self._data:
            row[column_name] = func(row[column_name])

    def append_column(self, key, name=None):
        """Add a column; `name` is its title when a header is in use."""
        self._columns.append(key)
        if self._header:
            self._header.append(name or key)

    def _sort_key(self, row):
        """Return the value of the first column of `row` for sorting.

        Plain string keys are read directly from the row. Nested key paths
        are resolved with _get_nested; a plain dict lookup with such a path
        would either fail (lists are unhashable) or find nothing.
        """
        first = self._columns[0]
        if isinstance(first, str):
            return row.get(first)
        return self._get_nested(row, first)

    def out(self):
        """Return the rendered output.

        For the "json" style the untouched row list is returned. Otherwise
        the result is one string: a header line followed by one line per
        row, sorted by the first column.
        """
        if self._style == "json":
            return self._data
        self._process_data()
        self._get_format_string()
        lines = [self._format_string.format(
            *(self._header or self._columns))]
        for row in sorted(self._data, key=self._sort_key):
            lines.append(self._format_string.format(
                *[self._get_nested(row, keys) for keys in self._columns]))
        return "\n".join(lines)


def human_readable_bytes(num, units, decimal_places=2):
    """Format a byte count for display.

    Args:
        num: number of bytes.
        units: 'metric' scales by 1000 (KB, MB, ...); 'bytes' disables
            conversion; any other value (e.g. 'iec') scales by 1024 and
            uses the KiB, MiB, ... suffixes.
        decimal_places: digits after the decimal point of converted values.

    Returns:
        '<num> bytes' when no conversion applies, else '<value> <suffix>'.
        Values too large for the biggest suffix are still expressed in that
        suffix instead of returning None.
    """
    factor = 1000.0 if units == 'metric' else 1024.0
    if units == "bytes" or num < factor:
        return f'{num} bytes'

    suffixes = ['KB', 'MB', 'GB', 'TB', 'PB', 'EB']
    if units != 'metric':
        suffixes = [suffix.replace('B', 'iB') for suffix in suffixes]

    for suffix in suffixes:
        num /= factor
        if num < factor:
            break
    # When the loop runs out of suffixes, `suffix` is the largest one and
    # `num` has been scaled to it.
    return f'{num:.{decimal_places}f} {suffix}'

def set_access_key(path, scope, key_id, secret):
    """Register an access key with the client through a mount point.

    The key is written as JSON to the 'quobyte.access_key' extended
    attribute of the mount point that contains `path`. When no secret is
    given it is asked for interactively.

    Returns:
        0 on success, 1 if the attribute could not be written. Exits the
        process with status 1 if the secret prompt is interrupted.
    """
    path = normalize_mount_point(path)
    if not secret:
        try:
            secret = getpass.getpass(f"Secret for access key {key_id}: ")
        except KeyboardInterrupt:
            print("interrupted, exiting.")
            sys.exit(1)

    payload = json.dumps({
        "access_key_id": key_id,
        "access_key_secret": secret,
        "access_key_scope": scope
    }).encode('utf8')

    try:
        os.setxattr(path, 'quobyte.access_key', payload, follow_symlinks=False)
    except OSError as err:
        print(f"Could not set access key for client via path '{path}': {err}. "
              "Make sure this is a Quobyte mount point.")
        return 1

    # Tell the user who is covered by the newly registered key.
    if scope == 'user':
        audience = ("for user", getpass.getuser())
    elif scope == 'client':
        audience = ("for all users",)
    else:
        audience = None
    if audience is not None:
        print("Registered access key", key_id, *audience)
    return 0

def manage_retention(paths, recursive, timestamp, set_delete_after, set_retain_until, remove):
    """Apply update_retention to every path, optionally descending into directories.

    Args:
        paths: iterable of file system paths.
        recursive: when true, also process everything below directories.
        timestamp, set_delete_after, set_retain_until, remove: passed through
            to update_retention unchanged.

    Returns:
        0; failures surface as exceptions raised by update_retention.
    """
    for path in paths:
        update_retention(path, timestamp, set_delete_after, set_retain_until, remove)
        # os.path.isdir() follows symlinks while update_retention() acts on
        # the link itself, so never descend through a symlink: it could lead
        # outside the requested tree or into an endless cycle.
        if recursive and os.path.isdir(path) and not os.path.islink(path):
            entries = [os.path.join(path, name) for name in os.listdir(path)]
            manage_retention(
                entries, recursive,
                timestamp, set_delete_after, set_retain_until, remove)
    return 0

def _set_retention_xattr(path, name, value):
    """Write retention attribute `name` on `path` without following symlinks.

    Raises:
        ValueError: if the write is rejected with EINVAL; the message hints
            at the expected timestamp format.
        OSError: for every other failure.
    """
    try:
        os.setxattr(path, name, value, follow_symlinks=False)
    except OSError as e:
        if e.errno == errno.EINVAL:
            raise ValueError(
                'Please specify timestamp in ISO 8601, for example:'
                ' --timestamp=$(date -Iseconds -d "2 days")') from e
        raise


def _print_retention_lock(lock):
    """Print a summary of a decoded retention lock.

    Returns:
        1 if the lock carries no retention time, otherwise 0.
    """
    if not lock.get('retention_time', ''):
        print("No retention time set yet", lock)
        return 1

    # NOTE(review): assumes the stored time has whole seconds and a literal
    # 'Z' suffix; any other layout makes strptime raise ValueError.
    expires_at = datetime.datetime.strptime(
        lock['retention_time'], "%Y-%m-%dT%H:%M:%SZ")
    expires_at = expires_at.replace(tzinfo=datetime.timezone.utc, microsecond=0)
    now = datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
    delta = expires_at - now
    has_expired = delta.total_seconds() <= 0
    if lock.get('retain_until_retention_time', False):
        if has_expired:
            print("Target has an expired retention lock and is no longer immutable")
        else:
            print("Target is immutable until",
                  expires_at.isoformat(timespec='minutes'), "UTC ie. in", delta)
    if lock.get('delete_after_retention_time', False):
        if has_expired:
            print("Target will be deleted by the next retention task")
        else:
            print("Target will be deleted some time after",
                  expires_at.isoformat(timespec='minutes'), "UTC ie. in", delta)
    if lock.get('may_modify_retention_lock', False):
        print("Retention lock can be modified or removed")
    else:
        if lock.get('may_shorten_retention_time', False):
            print("Retention time can be shortened")
        if lock.get('may_extend_retention_time', False):
            print("Retention time can be extended")
        print("Retention lock from policy can not be modified otherwise")
    return 0


def update_retention(path, timestamp, set_delete_after, set_retain_until, remove):
    """Remove, set, re-time or show the retention lock of a single path.

    The first matching mode wins:
      * remove: delete the 'quobyte.retention_lock' attribute.
      * set_delete_after / set_retain_until: write a new lock; needs
        `timestamp`.
      * timestamp only: write 'quobyte.retention_timestamp'.
      * nothing given: print the current lock.

    All attribute operations act on `path` itself, never on the target of a
    symlink.

    Returns:
        0 normally, 1 when showing a lock that has no retention time, and
        None when showing fails because no lock is set or access is denied.

    Raises:
        ValueError: if a lock is requested without a timestamp or the
            timestamp is rejected.
        OSError: for other attribute failures.
    """
    if remove:
        # follow_symlinks=False keeps removal consistent with the set/get
        # calls below, which all operate on the link itself.
        os.removexattr(path, 'quobyte.retention_lock', follow_symlinks=False)
        if VERBOSE:
            print(path)
    elif set_delete_after or set_retain_until:
        if not timestamp:
            raise ValueError(
                'Please specify time with --timestamp=<ISO8601 timestamp>, '
                'for example: --timestamp=$(date -Iseconds -d "2 days").')
        lock = {
            'retention_time' : timestamp,
            'may_modify_retention_lock' : 'true'
        }
        if set_retain_until:
            lock['retain_until_retention_time'] = 'true'
            lock['immutable'] = 'true'
        if set_delete_after:
            lock['delete_after_retention_time'] = 'true'

        value = bytes(json.dumps(lock), encoding='utf8')
        _set_retention_xattr(path, 'quobyte.retention_lock', value)
        if VERBOSE:
            print(path)
    elif timestamp:
        _set_retention_xattr(
            path, 'quobyte.retention_timestamp', timestamp.encode())
        if VERBOSE:
            print(path)
    else:
        try:
            value = os.getxattr(path, 'quobyte.retention_lock', follow_symlinks=False)
        except OSError as e:
            if e.errno == errno.ENODATA:
                print(path, "no retention lock set")
                return
            if e.errno == errno.EACCES:
                print(path, "permission denied")
                return
            raise
        return _print_retention_lock(json.loads(value))
    return 0

def show_quotas(path, units, precision):
    """Print the effective quotas reported for `path`.

    Reads the 'quobyte.quotas' extended attribute (a JSON list) and prints
    one table row per quota. Limits and usage of quotas other than
    FILE_COUNT are shown as converted byte values.

    Args:
        path: path inside a Quobyte mount.
        units: unit system for byte values ('metric', 'iec' or 'bytes').
        precision: decimal places for converted byte values.

    Returns:
        0; attribute and JSON errors propagate as exceptions.
    """
    quotas_string = os.getxattr(path, "quobyte.quotas", follow_symlinks=False)
    quotas = json.loads(quotas_string)
    if not quotas:
        print("No effective quotas")
        return 0

    output = []
    for quota in quotas:
        usage = int(quota["usage"])
        limit = int(quota["limit"])
        # A limit of 0 has no meaningful fill level; leave the percentage
        # out instead of dividing by zero.
        percent = " (" + str(usage * 100 // limit) + "%)" if limit else ""

        row = {}
        row["consuming_entity"] = quota["entity_type"] + ": " + quota["consuming_entity"]
        row["type"] = quota["type"]
        if quota["type"] != "FILE_COUNT":
            row["limit"] = human_readable_bytes(quota["limit"], units, precision)
            row["usage"] = \
                human_readable_bytes(quota["usage"], units, precision) + percent
        else:
            row["limit"] = quota["limit"]
            row["usage"] = str(quota["usage"]) + percent
        output.append(row)
    print_quotas(output, units, precision)
    return 0


def print_quotas(quotas, units, precision):
    """Print prepared quota rows as a table with fixed columns."""
    column_keys = ["consuming_entity", "type", "limit", "usage"]
    column_titles = ["Consuming Entity", "Type", "Limit", "Current Usage"]
    rendered = PrettyTable(
        quotas, column_keys, column_titles,
        units=units, precision=precision).out()
    print(rendered)

def show_file_metadata(path):
    """Print the 'quobyte.info' attribute of `path`.

    For regular files the 'quobyte.ec_state' attribute is printed as well
    when it can be read; failing to read it is not an error.

    Returns:
        2 if 'quobyte.info' cannot be read or printed (the error text is
        printed instead), otherwise 0.
    """
    try:
        info = os.getxattr(path, "quobyte.info", follow_symlinks=False)
        print(info.decode())
    except Exception as err:
        print(str(err))
        return 2

    # Only regular files are queried for the erasure-coding state.
    try:
        is_regular_file = stat.S_ISREG(os.lstat(path).st_mode)
    except FileNotFoundError:
        return 0
    if not is_regular_file:
        return 0

    try:
        ec_state = os.getxattr(path, "quobyte.ec_state", follow_symlinks=False)
        print(ec_state.decode())
    except Exception:  # the attribute is optional
        return 0
    return 0

def show_file_versions(path):
    """Print the decoded 'quobyte.versions' attribute of `path`; return 0."""
    raw_versions = os.getxattr(path, "quobyte.versions", follow_symlinks=False)
    text = raw_versions.decode()
    print(text)
    return 0

def show_file_locks(path):
    """Print the locks reported by the 'quobyte.locks' attribute of `path`.

    The attribute value is JSON. A value starting with '{' is treated as
    the legacy format, whose surrounding braces are stripped before parsing.
    An empty payload (for example a legacy "{}") is reported as "no locks"
    instead of failing in the JSON parser.

    Returns:
        0; attribute and JSON errors propagate as exceptions.
    """
    value = os.getxattr(path, "quobyte.locks", follow_symlinks=False).decode()
    if value.startswith('{'):
        value = value.strip()[1:-1]  # remove legacy {}
    locks = json.loads(value) if value.strip() else None
    if not locks:
        print("No locks on", path)
    else:
        print(locks)
    return 0

def show_whoami(path):
    """Print the caller's effective credentials.

    The credentials are read as JSON from the 'quobyte.whoami' attribute of
    `path`. Returns 0; a missing field raises KeyError.
    """
    whoami = json.loads(
        os.getxattr(path, "quobyte.whoami", follow_symlinks=False))

    print("Effective credentials for Quobyte file system access control")
    print("  Username:", whoami['username'])
    print("  Groups:", ','.join(whoami['groups']))
    print("  Member of tenants:", ','.join(whoami['tenants']))

    if whoami['privileged_access']:
        privileged = "yes (like root)"
    else:
        privileged = "no"
    print("  Privileged access:", privileged)

    if whoami.get('foreign_tenant_access', False):
        foreign = "yes (user is not member of volume's tenant; special ACLs needed for access)"
    else:
        foreign = "no (normal POSIX access control rules apply)"
    print("  Access to foreign tenant:", foreign)

    print("  Authenticated by:", whoami['authentication_source'])
    return 0

def set_replication_factor(paths, recursive, factor):
    """Request replication factor `factor` for every path.

    Args:
        paths: iterable of file system paths.
        recursive: when true, also process everything below directories.
        factor: target replication factor, passed on unchanged.

    Returns:
        0; failures surface as exceptions from the per-file helper.
    """
    for path in paths:
        try_set_file_replication_factor(path, VERBOSE, factor)
        # os.path.isdir() follows symlinks; the per-file helper refuses
        # symlinks, so do not descend through them either. This also avoids
        # endless recursion on symlink cycles.
        if recursive and os.path.isdir(path) and not os.path.islink(path):
            entries = [os.path.join(path, name) for name in os.listdir(path)]
            set_replication_factor(entries, recursive, factor)
    return 0

def try_set_file_replication_factor(path, verbose, factor):
    """Request replication factor `factor` for one replicated file.

    Symlinks and directories are reported and skipped. Files whose
    'quobyte.info' text marks them as erasure coded, or which already
    report the requested factor, are left untouched. Otherwise the factor
    is written to the 'quobyte.target_replication_factor' attribute.

    Args:
        path: file to change.
        verbose: print what was done or why the file was skipped.
        factor: target replication factor (string or int).

    Raises:
        OSError: if an attribute cannot be read or written, except that
            "permission denied" while reading is only reported.
    """
    if os.path.islink(path):
        print(path, "is a symlink")
        return
    if os.path.isdir(path):
        print(path, "is a directory")
        return
    try:
        info_string = os.getxattr(path, "quobyte.info", follow_symlinks=False).decode()
    except OSError as e:
        if e.errno == errno.EACCES:
            print(path, "permission denied")
            return
        raise
    if ("coding_method: REED_SOLOMON" in info_string or
        "mixed" in info_string):
        if verbose:
            print(path, "is an erasure-coded file")
        return
    if "replication_factor: " + str(factor) in info_string:
        if verbose:
            print(path, "replication factor is already", factor)
        return
    assert "coding_method:" not in info_string or "coding_method: NONE" in info_string
    # os.setxattr() only accepts bytes-like values; passing a str raises
    # TypeError.
    os.setxattr(path, "quobyte.target_replication_factor",
                str(factor).encode('utf8'), follow_symlinks=False)
    if verbose:
        print(path, "set to replication factor", factor)

def normalize_mount_point(given_path):
    """Return the mount point that contains `given_path`.

    /proc/mounts is searched for the longest mount point that equals
    `given_path` or is one of its parent directories. If that mount point
    differs from the given path a notice is printed, plus a warning when its
    file system type is not 'fuse.quobyte'.

    Args:
        given_path: absolute or relative path.

    Returns:
        The enclosing mount point, or the absolute form of `given_path` if
        it already is a mount point or no mount point matches.

    Raises:
        OSError: if /proc/mounts cannot be read.
    """
    given_path = os.path.abspath(given_path)
    mount_point = ''
    mount_point_type = ''
    with open("/proc/mounts", "r", encoding="utf8") as f:
        for line in f:
            fields = line.split()
            if len(fields) < 3:
                continue  # blank or malformed line
            path, mount_type = fields[1:3]
            # Match whole path components only: "/mnt/quobyte2" is not
            # inside "/mnt/quobyte".
            prefix = path if path.endswith('/') else path + '/'
            is_inside = given_path == path or given_path.startswith(prefix)
            if is_inside and len(path) > len(mount_point):
                mount_point = path
                mount_point_type = mount_type

    if mount_point and mount_point != given_path:
        print("Given path", given_path, "seems actually to be part of mount point", mount_point,
              ". Using", mount_point, "instead.")
        if mount_point_type.strip() != "fuse.quobyte":
            print(mount_point, f"may not be a Quobyte mount point, found '{mount_point_type}'.")
        return mount_point
    else:
        return given_path

def main():
    """Parse the command line and run the selected command.

    Returns:
        The exit status of the executed command, or 2 (after printing the
        help text) when no command was given.
    """
    # NOTE(review): the f-strings elsewhere in this file need Python 3.6, so
    # an older interpreter presumably fails before reaching this check.
    if sys.version_info < (3,3):
        print("Python version too old, should be 3.3 or newer. Some commands may cause errors.")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--unit",
        help="Output bytes unconverted or in metric (1000) units (GB, TB) or IEC (1024)"
             " units (GiB, TiB). Default is metric.",
        choices=['metric', 'iec', 'bytes'],
        default="metric")
    parser.add_argument(
        "--precision",
        metavar="<number>",
        help="Number of decimal places to output converted bytes. Default is 2.",
        # Validate here so a non-numeric value is rejected with a usage
        # error instead of failing later while formatting output.
        type=int,
        default=2)
    parser.add_argument(
        "-v", action="store_true",
        help="verbose")
    subparsers = parser.add_subparsers(
        dest='command',
        title='commands',
        help='additional help')

    quota_parser = subparsers.add_parser(
        "quota", help="show effective quotas")
    quota_parser.add_argument(
        "quota_mount_point",
        metavar="<mount point>",
        default=os.getcwd(),
        help="optional path to Quobyte mount point",
        nargs='?')

    info_parser = subparsers.add_parser(
        "info", help="show file info")
    info_parser.add_argument(
        "info_path",
        metavar="<path>",
        help="path to file or directory")

    version_parser = subparsers.add_parser(
        "versions", help="show list of file versions")
    version_parser.add_argument(
        "version_path",
        metavar="<path>",
        help="path to file or directory")

    locks_parser = subparsers.add_parser(
        "locks", help="show current file locks")
    locks_parser.add_argument(
        "locks_path",
        metavar="<path>",
        help="path to file or directory")

    access_key_parser = subparsers.add_parser(
        "set-access-key",
        help="set access key")
    access_key_parser.add_argument(
        "key_id",
        metavar="<access key id>",
        help="Access key id")
    access_key_parser.add_argument(
        "mount_point",
        metavar="<mount point>",
        help="Quobyte mount point",
        default=".",
        nargs='?')
    access_key_parser.add_argument(
        "--secret",
        metavar="<access key secret>",
        help="Access key secret")
    access_key_parser.add_argument(
        "--scope",
        choices=["user", "client"],
        default="user",
        help="Register the access key for the current user only (user) or"
             " for all users (client). Default is user.")

    retention_parser = subparsers.add_parser(
        "retention",
        help="Configure retention lock")
    retention_parser.add_argument(
        "--recursive", "-R", action="store_true",
        help="Recursive")
    retention_parser.add_argument(
        "--set-delete-after", action="store_true",
        help="Delete file/directory/... after retention timestamp")
    retention_parser.add_argument(
        "--set-retain-until", action="store_true",
        help="File/directory/... is immutable until retention timestamp")
    retention_parser.add_argument(
        "--remove", action="store_true",
        help="Remove retention lock")
    retention_parser.add_argument(
        "--timestamp",
        help="Retention lock timestamp, in ISO 8601 like 2024-12-03T10:15:30.00Z")
    retention_parser.add_argument(
        "path",
        metavar="<path>", default="", nargs='+',
        help="paths to file or directory")

    whoami_parser = subparsers.add_parser(
        "whoami",
        help=("show own user credentials as effective " +
            "for Quobyte file system access control"))
    whoami_parser.add_argument(
        "mount_point",
        default=os.getcwd(),
        metavar="<path>", nargs='?',
        help="path to an entry in a Quobyte mount point")

    setreplfactor_parser = subparsers.add_parser(
        "set-replication-factor",
        help="set replication factor of replicated files")
    setreplfactor_parser.add_argument(
        "--recursive", "-R", action="store_true",
        help="recursive")
    setreplfactor_parser.add_argument(
        "-t", help="Target replication factor",
        choices=['1','3'], default='3')
    setreplfactor_parser.add_argument(
        "path",
        metavar="<path>", default="", nargs='+',
        help="path to file or directory")

    args = parser.parse_args()

    if args.v:
        global VERBOSE
        VERBOSE=True

    if args.command == 'quota':
        return show_quotas(args.quota_mount_point, args.unit, args.precision)
    elif args.command == 'info':
        return show_file_metadata(args.info_path)
    elif args.command == 'versions':
        return show_file_versions(args.version_path)
    elif args.command == 'locks':
        return show_file_locks(args.locks_path)
    elif args.command == 'whoami':
        return show_whoami(args.mount_point)
    elif args.command == 'set-replication-factor':
        return set_replication_factor(args.path, args.recursive, args.t)
    elif args.command == 'set-access-key':
        return set_access_key(args.mount_point, args.scope, args.key_id, args.secret)
    elif args.command == 'retention':
        return manage_retention(
            args.path, args.recursive, args.timestamp,
            args.set_delete_after, args.set_retain_until, args.remove)

    # No command given: explain the usage instead of exiting silently.
    parser.print_help(sys.stderr)
    return 2


# Script entry point: run main() and turn any error into a one-line message
# plus exit status 1. With -v, errors other than OSError are re-raised so
# the traceback is shown.
if __name__ == "__main__":
    try:
        sys.exit(main())
    except OSError as e:
        # Use the symbolic constant: the numeric value of EOPNOTSUPP is not
        # the same on every Linux architecture.
        if e.errno == errno.EOPNOTSUPP:
            print("Command needs to be run against a Quobyte mount point, terminating.")
        else:
            print(str(e) + ", terminating")
    except Exception as e:
        if VERBOSE:
            raise
        else:
            print(str(e) + ", terminating")
    sys.exit(1)
