#!/usr/bin/python3
#-*- encoding: utf-8 -*-
"""
This script can be used instead of the traditional `snap` command to download
snaps and accompanying assertions. It uses the new store API (v2) which allows
creating temporary snapshots of the channel map.

To create such a snapshot run

    snap-tool cohort-create

This will print a "cohort-key" to stdout, which can then be passed to future
invocations of `snap-tool download`. Whenever a cohort key is provided, the
store will provide a view of the channel map as it existed when the key was
created.
"""

from textwrap import dedent

import argparse
import base64
import binascii
import getopt
import hashlib
import json
import os
import shutil
import subprocess
import sys
import urllib.error
import urllib.request

# Process exit statuses returned by the command line interface.
EXIT_OK = 0
EXIT_ERR = 1


class SnapError(Exception):
    """Base class of the errors raised by the Snap class."""


class SnapCraftError(SnapError):
    """Raised when a request to the snapcraft APIs fails."""


class SnapAssertionError(SnapError):
    """Raised when a request to the assertions API fails."""


class Snap:
    """Retrieve information about a snap and download it with its assertions.

    Results are cached on the instance: the snap details after the first
    successful call to get_details() and each assertion after it has been
    fetched once.
    """

    def __init__(self, name, channel="stable", arch="amd64", series=16,
            cohort_key=None, assertion_url="https://assertions.ubuntu.com",
            snapcraft_url="https://api.snapcraft.io", **kwargs):
        """
        :param name:
            The name of the snap.
        :param channel:
            The channel to operate on.
        :param arch:
            The Debian architecture of the snap (e.g. amd64, armhf, arm64, ...).
        :param series:
            The device series. This should always be 16.
        :param cohort_key:
            A cohort key to access a snapshot of the channel map.
        :param assertion_url:
            Base URL of the assertions service.
        :param snapcraft_url:
            Base URL of the snapcraft store API.
        :param kwargs:
            Ignored. Accepted so that callers may pass a larger option
            dictionary (SnapCli passes all parsed command line options).
        """
        self._name          = name
        self._channel       = channel
        self._arch          = arch
        self._series        = series
        self._cohort_key    = cohort_key
        self._assertion_url = assertion_url
        self._snapcraft_url = snapcraft_url
        self._details       = None
        self._assertions    = {}

    @classmethod
    def cohort_create(cls):
        """Get a cohort key for the current moment. A cohort key is valid
        across all snaps, channels and architectures.

        :returns:
            The cohort key, or None if the store response did not contain one.
        :raises SnapError:
            If the store request fails.
        """
        # Use cls rather than a hard-coded class name so subclasses work, too.
        return cls("core").get_details(cohort_create=True).get("cohort-key")

    def download(self, download_assertions=True):
        """Download the snap container into the current working directory.

        The snap is stored as ``<name>_<revision>.snap``. The download is
        skipped if a file of that name with the size announced by the store
        already exists. If download_assertions is True, the corresponding
        assertions are written to ``<name>_<revision>.assert``, as well.

        :raises SnapError:
            If the snap details, the snap itself or one of the assertions
            cannot be retrieved.
        """
        snap = self.get_details()

        publisher_id    = snap["publisher"]["id"]
        basename        = "{}_{}".format(snap["name"], snap["revision"])
        snap_filename   = basename + ".snap"
        assert_filename = basename + ".assert"

        already_present = os.path.exists(snap_filename) and \
            os.path.getsize(snap_filename) == snap["download"]["size"]

        if not already_present:
            self._download_snap(snap["download"]["url"], snap_filename)

        if not download_assertions:
            return

        required_assertions = [
            "account-key",
            "account",
            "snap-declaration",
            "snap-revision",
        ]

        # No account assertion is fetched for snaps published by "canonical".
        if publisher_id == "canonical":
            required_assertions.remove("account")

        # Each assertion type has a getter named get_assertion_<type>, with
        # dashes replaced by underscores, that returns the assertion text.
        assertions = [
            getattr(self, "get_assertion_" + name.replace("-", "_"))()
            for name in required_assertions
        ]

        with open(assert_filename, "w+", encoding="utf-8") as fp:
            fp.write("\n".join(assertions))

    def _download_snap(self, url, filename):
        """Stream the snap container at url into the local file filename."""
        headers = {}

        # Any value other than "0" asks the store to bypass the CDN.
        if os.environ.get("SNAPPY_STORE_NO_CDN", "0") != "0":
            headers.update({
                "X-Ubuntu-No-Cdn": "true",
                "Snap-CDN": "none",
            })

        request = urllib.request.Request(url, headers=headers)

        try:
            with urllib.request.urlopen(request) as response, \
                    open(filename, "wb") as fp:
                shutil.copyfileobj(response, fp)
        except urllib.error.URLError as e:
            raise SnapError("failed to download '{}': {}"
                    .format(self._name, str(e))) from e

    def get_details(self, cohort_create=False):
        """Get details about the snap. On subsequent calls, the cached results
        are returned. If cohort_create is set to True, a cohort key will be
        created and included in the result (the cache is bypassed and then
        refreshed in that case).

        :returns:
            The "snap" dictionary of the store response. If the response
            carries a cohort key, it is added under the key "cohort-key".
        :raises SnapError:
            If the request fails or the store reports an error for the snap.
        """
        if self._details and not cohort_create:
            return self._details

        action = {
            "action":       "download",
            "instance-key": "0",
            "name":         self._name,
            "channel":      self._channel,
        }

        # These are mutually exclusive.
        if cohort_create:
            action["cohort-create"] = True
        elif self._cohort_key:
            action["cohort-key"] = self._cohort_key

        data = {"context": [], "actions": [action]}
        request_json = json.dumps(data, ensure_ascii=False).encode("utf-8")

        try:
            response_dict = self._do_snapcraft_request("/v2/snaps/refresh",
                    data=request_json)
        except SnapCraftError as e:
            raise SnapError("failed to get details for '{}': {}"
                    .format(self._name, str(e))) from e

        results = None
        if isinstance(response_dict, dict):
            results = response_dict.get("results")

        if not results:
            raise SnapError(
                "failed to get details for '{}': no results in store response"
                    .format(self._name))

        snap_data = results[0]

        # Check for an error first: an error result is not expected to carry
        # the "snap" dictionary accessed below.
        if "error" in snap_data:
            raise SnapError(
                "failed to get details for '{}' in '{}' on '{}': {}"
                    .format(self._name, self._channel, self._arch,
                        snap_data["error"]["message"])
            )

        details = snap_data["snap"]

        # Copy the key into the snap details.
        if "cohort-key" in snap_data:
            details["cohort-key"] = snap_data["cohort-key"]

        self._details = details
        return self._details

    def get_assertion_snap_revision(self):
        """Download the snap-revision assertion associated with this snap. The
        assertion is returned as a string.

        :raises SnapError:
            If the snap details or the assertion cannot be retrieved.
        """
        if "snap-revision" not in self._assertions:
            snap = self.get_details()

            # The assertion is requested by the URL-safe base64 form of the
            # SHA3-384 digest, which the snap details provide hex-encoded.
            snap_sha3_384 = base64.urlsafe_b64encode(
                binascii.a2b_hex(snap["download"]["sha3-384"])
            ).decode("us-ascii")

            self._assertions["snap-revision"] = self._do_assertion_request(
                "/v1/assertions/snap-revision/{}".format(snap_sha3_384))

        return self._assertions["snap-revision"]

    def get_assertion_snap_declaration(self):
        """Download the snap-declaration assertion associated with this snap.
        The assertion is returned as a string.

        :raises SnapError:
            If the snap details or the assertion cannot be retrieved.
        """
        if "snap-declaration" not in self._assertions:
            snap = self.get_details()

            self._assertions["snap-declaration"] = self._do_assertion_request(
                "/v1/assertions/snap-declaration/{}/{}"
                    .format(self._series, snap["snap-id"]))

        return self._assertions["snap-declaration"]

    def get_assertion_account(self):
        """Download the account assertion associated with this snap. The
        assertion is returned as a string.

        :raises SnapError:
            If the snap details or the assertion cannot be retrieved.
        """
        if "account" not in self._assertions:
            snap = self.get_details()

            self._assertions["account"] = self._do_assertion_request(
                "/v1/assertions/account/{}".format(snap["publisher"]["id"]))

        return self._assertions["account"]

    def get_assertion_account_key(self):
        """Download the account-key assertion associated with this snap. The
        assertion will be returned as a string.

        The key is identified by the "sign-key-sha3-384" header of the
        snap-declaration assertion, which is fetched first if necessary.

        :raises SnapAssertionError:
            If the declaration names no signing key or the request fails.
        """
        if "account-key" in self._assertions:
            return self._assertions["account-key"]

        declaration_data = self.get_assertion_snap_declaration()
        sign_key_sha3 = None

        for line in declaration_data.splitlines():
            if line.startswith("sign-key-sha3-384:"):
                sign_key_sha3 = line.partition(":")[2].strip()

        # Fail explicitly instead of requesting ".../account-key/None".
        if not sign_key_sha3:
            raise SnapAssertionError(
                "snap-declaration of '{}' has no sign-key-sha3-384 header"
                    .format(self._name))

        data = self._do_assertion_request("/v1/assertions/account-key/{}"
                .format(sign_key_sha3))

        self._assertions["account-key"] = data
        return data

    def _do_assertion_request(self, path):
        """Fetch path from the assertions service and return the response
        body decoded as UTF-8.

        :raises SnapAssertionError:
            On HTTP errors as well as on connection-level failures.
        """
        # Tolerate a trailing slash on the configured base URL.
        url = self._assertion_url.rstrip("/") + "/" + path.lstrip("/")

        headers = {
            "Accept": "application/x.ubuntu.assertion",
        }

        request = urllib.request.Request(url, headers=headers)

        try:
            with urllib.request.urlopen(request) as response:
                body = response.read()
        except urllib.error.URLError as e:
            # URLError is the base class of HTTPError, so this covers both.
            raise SnapAssertionError(str(e)) from e

        return body.decode("utf-8")

    def _do_snapcraft_request(self, path, data=None):
        """Send a request to the snapcraft API and return the decoded JSON
        response.

        :param path:
            The API path, with or without a leading slash.
        :param data:
            Optional request body as bytes, passed on to
            urllib.request.Request.
        :raises SnapCraftError:
            On HTTP errors, connection-level failures and undecodable
            response bodies.
        """
        # Join with exactly one slash, whichever way the parts are written.
        url = self._snapcraft_url.rstrip("/") + "/" + path.lstrip("/")

        headers = {
            "Snap-Device-Series": str(self._series),
            "Snap-Device-Architecture": self._arch,
            "Content-Type": "application/json",
        }

        request = urllib.request.Request(url, data=data, headers=headers)

        try:
            with urllib.request.urlopen(request) as response:
                body = response.read()
        except urllib.error.URLError as e:
            # URLError is the base class of HTTPError, so this covers both.
            raise SnapCraftError(str(e)) from e

        # json.loads() lost its "encoding" keyword in Python 3.9, so decode
        # explicitly. ValueError covers both UnicodeDecodeError and
        # json.JSONDecodeError.
        try:
            response_data = json.loads(body.decode("utf-8"))
        except ValueError as e:
            raise SnapCraftError(
                "failed to decode response body: " + str(e)) from e

        return response_data


class SnapCli:
    """Command line front end for the Snap class.

    Instances are callable: pass the list of command line arguments (without
    the program name) to run the selected sub-command.
    """

    def __call__(self, args):
        """Parse the command line arguments and execute the selected command.

        :param args:
            The command line arguments, without the program name.
        :returns:
            EXIT_OK on success, EXIT_ERR if the command raised a SnapError.
        """
        options = self._parse_opts(args)

        try:
            # Every handler takes the snap name (None for commands without
            # one) followed by all parsed options as keyword arguments.
            options.func(getattr(options, "snap", None), **vars(options))
        except SnapError as e:
            sys.stderr.write("snap-tool {}: {}\n".format(
                options.command, str(e)))
            return EXIT_ERR
        return EXIT_OK

    @staticmethod
    def _get_host_deb_arch():
        """Return the host's Debian architecture as printed by
        ``dpkg --print-architecture``.

        :raises OSError:
            If dpkg cannot be executed.
        :raises subprocess.CalledProcessError:
            If dpkg exits with a non-zero status.
        """
        result = subprocess.run(["dpkg", "--print-architecture"],
                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                universal_newlines=True, check=True)

        return result.stdout.strip()

    def _parse_opts(self, args):
        """Build the argument parser and parse args.

        Prints the help text and exits with EXIT_ERR if no sub-command was
        selected. If --arch is not given for a command that accepts it, the
        host architecture is looked up via dpkg.

        :returns:
            The argparse namespace; "command" holds the sub-command name and
            "func" the handler to call.
        """
        main_parser = argparse.ArgumentParser()
        subparsers = main_parser.add_subparsers(dest="command")

        parser_cohort_create = subparsers.add_parser("cohort-create",
                help="Create a cohort key for the snap store channel map.")
        parser_cohort_create.set_defaults(func=self._cohort_create)

        parser_download = subparsers.add_parser("download",
                help="Download a snap from the store.")
        parser_download.set_defaults(func=self._download)

        parser_info = subparsers.add_parser("info",
                help="Retrieve information about a snap.")
        parser_info.set_defaults(func=self._info)

        # Add common parameters.
        for parser in (parser_download, parser_info):
            parser.add_argument("--cohort-key", dest="cohort_key",
                help="A cohort key to pin the channel map to.", type=str)
            parser.add_argument("--channel", dest="channel",
                help="The publication channel to query (default: stable).",
                type=str, default="stable")
            parser.add_argument("--series", dest="series",
                help="The device series (default: 16)",
                type=int, default=16)
            # The default is resolved after parsing (see below) so that dpkg
            # only runs when the architecture is actually needed.
            parser.add_argument("--arch", dest="arch",
                help="The Debian architecture (default: host architecture).",
                type=str, default=None)
            parser.add_argument("snap", help="The name of the snap.")

        options = main_parser.parse_args(args)

        # Covers an empty argument list as well as argument lists that do not
        # select a sub-command (e.g. a lone "--").
        if options.command is None:
            main_parser.print_help()
            sys.exit(EXIT_ERR)

        # Only commands that define --arch have the attribute at all.
        if getattr(options, "arch", "") is None:
            try:
                options.arch = self._get_host_deb_arch()
            except (OSError, subprocess.CalledProcessError) as e:
                main_parser.error(
                    "failed to determine the host architecture ({}), "
                    "please specify --arch".format(e))

        return options

    def _cohort_create(self, _, **kwargs):
        """Handler for "cohort-create": print a new cohort key to stdout."""
        print(Snap.cohort_create())

    def _download(self, snap_name, **kwargs):
        """Handler for "download": download the snap and its assertions."""
        Snap(snap_name, **kwargs).download()

    def _info(self, snap_name, **kwargs):
        """Handler for "info": print a summary of the snap to stdout."""
        snap = Snap(snap_name, **kwargs)
        info = snap.get_details()

        print(dedent("""\
            name:      {}
            summary:   {}
            arch:      {}
            channel:   {}
            publisher: {}
            license:   {}
            snap-id:   {}
            revision:  {}"""
            .format(
                snap_name,
                info.get("summary", ""),
                snap._arch,
                snap._channel,
                info.get("publisher", {}).get("display-name", ""),
                info.get("license", ""),
                info.get("snap-id", ""),
                info.get("revision", "")
            ))
        )


if __name__ == "__main__":
    try:
        # Hand the status returned by the CLI to the shell; otherwise the
        # process would exit with 0 even when the command failed.
        sys.exit(SnapCli()(sys.argv[1:]))
    except KeyboardInterrupt:
        sys.stderr.write("snap-tool: caught keyboard interrupt, exiting.\n")
        sys.exit(EXIT_ERR)