#!/usr/bin/python3
# -*- Mode: Python; indent-tabs-mode: nil; tab-width: 4; coding: utf-8 -*-

from __future__ import print_function

import apt_pkg
import distro_info
import glob
import inspect
import os
import shutil
import subprocess
import unittest
from DistUpgrade.DistUpgradeController import (
    DistUpgradeController,
    component_ordering_key,
)
from DistUpgrade.DistUpgradeViewNonInteractive import (
    DistUpgradeViewNonInteractive,
)
from DistUpgrade import DistUpgradeConfigParser
import logging
import mock

# NOTE(review): presumably disables the upgrader's configuration override
# directory so settings on the host running the tests cannot leak into
# them -- confirm against DistUpgradeConfigParser.
DistUpgradeConfigParser.CONFIG_OVERRIDE_DIR = None

# Directory containing this test module; fixture data lives beneath it.
CURDIR = os.path.dirname(os.path.abspath(__file__))

# Native architecture of the host as reported by dpkg. Several tests are
# skipped, or pick architecture-specific expected output, based on it.
dpkg = subprocess.Popen(
    ('dpkg', '--print-architecture'),
    stdout=subprocess.PIPE,
)
ARCH = dpkg.communicate()[0].decode().strip()

# Ubuntu release metadata, and the supported releases that are LTS, in the
# order distro_info's supported() returns them.
di = distro_info.UbuntuDistroInfo()
LTSES = list(filter(di.is_lts, di.supported()))


class TestComponentOrdering(unittest.TestCase):
    """component_ordering_key puts main, restricted, universe and multiverse
    first, in that order, and places other components (such as "x" or "a")
    after them."""

    def test_component_ordering_key_from_set(self):
        # Iteration order of a set is arbitrary, so only the key decides.
        unordered = {"x", "restricted", "main"}
        self.assertEqual(
            sorted(unordered, key=component_ordering_key),
            ["main", "restricted", "x"])

    def test_component_ordering_key_from_list(self):
        # (input components, expected sorted result)
        cases = (
            (["x", "main"],
             ["main", "x"]),
            (["restricted", "main"],
             ["main", "restricted"]),
            (["main", "restricted"],
             ["main", "restricted"]),
            (["main", "multiverse", "restricted", "universe"],
             ["main", "restricted", "universe", "multiverse"]),
            (["a", "main", "multiverse", "restricted", "universe"],
             ["main", "restricted", "universe", "multiverse", "a"]),
        )
        for components, expected in cases:
            self.assertEqual(
                sorted(components, key=component_ordering_key),
                expected)


class TestDeb822SourcesUpdate(unittest.TestCase):
    """Tests for DistUpgradeController.updateDeb822Sources().

    Each test reads fixtures from
    ``data-deb822-sources-test/<test method name>/``. The ``in`` directory
    is copied, after ``str.format`` templating, into a scratch
    ``sources.list.d``. The rewritten files are then compared against the
    ``expect`` directory, or ``expect.<ARCH>`` when that exists.

    The fixture directory is derived from the calling method's name, so a
    test method must not be renamed without renaming its data directory.
    """

    testdir = os.path.abspath(os.path.join(CURDIR, "data-deb822-sources-test"))
    sourceparts_dir = os.path.join(testdir, "sources.list.d")

    def setUp(self):
        # Point apt's configuration at the test data directory so the
        # controller operates on the fixture sources.
        apt_pkg.config.set("Dir::Etc", self.testdir)
        apt_pkg.config.set("Dir::Etc::sourcelist", "sources.list")
        apt_pkg.config.set("Dir::Etc::sourceparts", self.sourceparts_dir)
        apt_pkg.config.set("APT::Default-Release", "")

    def tearDown(self):
        shutil.rmtree(self.sourceparts_dir, ignore_errors=True)

    def test_ubuntu_sources_with_nothing(self):
        """
        test ubuntu.sources rewrite with nothing in it
        """
        self.prepareTestSources()

        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)
        res = d.updateDeb822Sources()
        self.assertTrue(res)
        self.assertSourcesMatchExpected(
            dist=d.toDist,
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )

    @mock.patch("DistUpgrade.DistUpgradeController.DistUpgradeController._deb822SourceEntryDownloadable")
    def test_sources_rewrite(self, mock_deb822SourceEntryDownloadable):
        """
        test regular ubuntu.sources rewrite
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.useNetwork = False
        self.prepareTestSources(
            dist=d.fromDist,
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )
        d.config.set("Distro", "BaseMetaPkgs", "ubuntu-minimal")
        d.openCache(lock=False)
        mock_deb822SourceEntryDownloadable.return_value = (True, [])
        res = d.updateDeb822Sources()
        self.assertTrue(mock_deb822SourceEntryDownloadable.called)
        self.assertTrue(res)
        self.assertSourcesMatchExpected(
            dist=d.toDist,
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )

    @unittest.skipUnless(ARCH in ('amd64', 'i386'), "ports are not mirrored")
    @mock.patch("DistUpgrade.DistUpgradeController.DistUpgradeController._deb822SourceEntryDownloadable")
    def test_sources_inactive_mirror(self,
                                     mock_deb822SourceEntryDownloadable):
        """
        test ubuntu.sources rewrite of an obsolete mirror
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.config.set("Distro", "BaseMetaPkgs", "ubuntu-minimal")
        d.openCache(lock=False)
        self.prepareTestSources(dist=d.fromDist)
        mock_deb822SourceEntryDownloadable.return_value = (True, [])
        res = d.updateDeb822Sources()
        self.assertTrue(mock_deb822SourceEntryDownloadable.called)
        self.assertTrue(res)
        self.assertSourcesMatchExpected(
            from_dist=d.fromDist,
            to_dist=d.toDist,
            default_source_uri=d.default_source_uri
        )

    # Scope LANG=C to this test only. patch.dict restores os.environ
    # afterwards instead of leaking the change into later tests.
    @mock.patch.dict(os.environ, {"LANG": "C"})
    def testEOL2SupportedUpgrade(self):
        " test upgrade from an EOL release to a supported release "
        # Target the second-to-last entry of LTSES and start from the
        # release listed immediately before it in distro_info's di.all.
        to_dist = LTSES[-2]
        from_dist = di.all[di.all.index(to_dist) - 1]
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.fromDist = from_dist
        d.toDist = to_dist
        d.openCache(lock=False)
        self.prepareTestSources(dist=from_dist)
        res = d.updateDeb822Sources()
        self.assertTrue(res)
        self.assertSourcesMatchExpected(
            dist=to_dist,
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )

    # patch.dict with no values snapshots os.environ and restores it when
    # the test finishes, so removing the variable below does not leak.
    @mock.patch.dict(os.environ)
    @mock.patch("DistUpgrade.DistUpgradeController.DistUpgradeController._deb822SourceEntryDownloadable")
    def test_private_ppa_transition(self, mock_deb822SourceEntryDownloadable):
        """
        test private PPA sources rewrite with
        RELEASE_UPGRADER_ALLOW_THIRD_PARTY unset
        """
        os.environ.pop("RELEASE_UPGRADER_ALLOW_THIRD_PARTY", None)

        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)
        self.prepareTestSources(
            dist=d.fromDist,
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )
        mock_deb822SourceEntryDownloadable.return_value = (True, [])
        res = d.updateDeb822Sources()
        self.assertTrue(mock_deb822SourceEntryDownloadable.called)
        self.assertTrue(res)
        self.assertSourcesMatchExpected(
            to_dist=d.toDist,
            from_dist=d.fromDist,
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )

    @mock.patch("DistUpgrade.DistUpgradeController.DistUpgradeController._deb822SourceEntryDownloadable")
    def test_apt_cacher_and_apt_bittorent(self,
                                          mock_deb822SourceEntryDownloadable):
        """
        test transition of apt-cacher/apt-torrent uris
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)
        self.prepareTestSources(dist=d.fromDist)
        mock_deb822SourceEntryDownloadable.return_value = (True, [])
        res = d.updateDeb822Sources()
        self.assertTrue(mock_deb822SourceEntryDownloadable.called)
        self.assertTrue(res)
        self.assertSourcesMatchExpected(dist=d.toDist)

    @mock.patch("DistUpgrade.DistUpgradeController.DistUpgradeController._deb822SourceEntryDownloadable")
    def test_local_mirror(self, mock_deb822SourceEntryDownloadable):
        """
        test that a local mirror with official -backports works (LP: #1067393)
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)
        self.prepareTestSources(
            dist=d.fromDist,
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )
        mock_deb822SourceEntryDownloadable.return_value = (True, [])
        res = d.updateDeb822Sources()
        self.assertTrue(mock_deb822SourceEntryDownloadable.called)
        self.assertTrue(res)
        self.assertSourcesMatchExpected(
            dist=d.toDist,
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )

    @mock.patch("DistUpgrade.DistUpgradeController.DistUpgradeController._deb822SourceEntryDownloadable")
    def test_disable_proposed(self, mock_deb822SourceEntryDownloadable):
        """
        Test that proposed is disabled when upgrading to a development
        release.
        """
        v = DistUpgradeViewNonInteractive()
        options = mock.Mock()
        options.devel_release = True
        d = DistUpgradeController(v, options, datadir=self.testdir)
        d.openCache(lock=False)
        self.prepareTestSources(
            dist=d.fromDist,
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )
        mock_deb822SourceEntryDownloadable.return_value = (True, [])
        res = d.updateDeb822Sources()
        self.assertTrue(mock_deb822SourceEntryDownloadable.called)
        self.assertTrue(res)
        self.assertSourcesMatchExpected(
            from_dist=d.fromDist,
            to_dist=d.toDist,
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )

    def assertSourcesMatchExpected(self, **kwargs):
        """Compare the generated sources with the caller's expected data.

        Every file in ``<testdir>/<caller>/expect``, or in
        ``expect.<ARCH>`` when that directory exists, is templated with
        ``str.format(**kwargs)``. The result must exist in
        ``sources.list.d`` and equal the file of the same name there.
        Extra files in ``sources.list.d`` are not checked.

        Must be called directly from the test method, because the fixture
        directory is looked up by the caller's function name.
        """
        testdata = os.path.join(self.testdir, inspect.stack()[1][3], "expect")
        if os.path.exists(testdata + "." + ARCH):
            testdata = testdata + "." + ARCH

        for name in sorted(os.listdir(testdata)):
            expect_path = os.path.join(testdata, name)
            actual_path = os.path.join(self.sourceparts_dir, name)

            # Report a missing output file as a test failure with a clear
            # message, rather than an error from open().
            self.assertTrue(
                os.path.exists(actual_path),
                "File {} was not created during deb822 sources update"
                .format(actual_path)
            )

            with open(expect_path) as e, open(actual_path) as a:
                expect = e.read().format(**kwargs)
                actual = a.read()
                self.assertEqual(expect, actual,
                                 '\n# Actual {}:\n{}'
                                 .format(name, actual))

    def prepareTestSources(self, **kwargs):
        """Populate a fresh ``sources.list.d`` from the caller's input data.

        Each file in ``<testdir>/<caller>/in`` is templated with
        ``str.format(**kwargs)`` and written under the same name into
        ``sources.list.d``. That directory is first deleted and recreated.

        Must be called directly from the test method, because the fixture
        directory is looked up by the caller's function name.
        """
        testdata = os.path.join(self.testdir, inspect.stack()[1][3], "in")

        # Start from an empty directory so files left by an earlier test
        # cannot affect this one.
        shutil.rmtree(self.sourceparts_dir, ignore_errors=True)
        os.mkdir(self.sourceparts_dir)

        for name in os.listdir(testdata):
            path_in = os.path.join(testdata, name)
            path_out = os.path.join(self.sourceparts_dir, name)

            with open(path_in) as fin, open(path_out, 'w') as fout:
                fout.write(fin.read().format(**kwargs))


class TestDeb822SourcesMigration(unittest.TestCase):
    """Tests for DistUpgradeController.migrateToDeb822Sources().

    Each test reads fixtures from
    ``data-deb822-migration-test/<test method name>/``.
    ``prepareTestSources`` installs the ``in`` files as follows:

    - ``sources.list`` goes into the test dir;
    - ``*.gpg`` goes into ``trusted.gpg.d`` (copied as binary);
    - everything else goes into ``sources.list.d``.

    ``assertSourcesMatchExpected`` then checks the ``expect`` files.

    NOTE(review): judging by the assertions below, migratedToDeb822()
    appears to return -1 before any migration, 0 for a partially migrated
    setup and 1 once fully migrated. Confirm against DistUpgradeController.
    """

    testdir = os.path.abspath(os.path.join(CURDIR, "data-deb822-migration-test"))
    sourceparts_dir = os.path.join(testdir, "sources.list.d")
    trustedparts_dir = os.path.join(testdir, "trusted.gpg.d")

    def setUp(self):
        # Point apt's configuration at the test data directory so the
        # controller operates on the fixture sources and keyrings.
        apt_pkg.config.set("Dir::Etc", self.testdir)
        apt_pkg.config.set("Dir::Etc::sourcelist", "sources.list")
        apt_pkg.config.set("Dir::Etc::sourceparts", self.sourceparts_dir)
        apt_pkg.config.set("Dir::Etc::trustedparts", self.trustedparts_dir)
        apt_pkg.config.set("APT::Default-Release", "")
        self.maxDiff = None

    def tearDown(self):
        # sources.list.d must be removed before the glob below runs. It
        # also matches "sources.list*", and os.remove() cannot delete a
        # directory.
        shutil.rmtree(self.sourceparts_dir, ignore_errors=True)
        shutil.rmtree(self.trustedparts_dir, ignore_errors=True)
        for path in glob.glob("{}/sources.list*".format(self.testdir)):
            os.remove(path)

    def assertSourcesMatchExpected(self, **kwargs):
        """Compare the migrated sources/keyrings with the expected data.

        Each file in ``<testdir>/<caller>/expect`` must exist in the
        matching output directory:

        - ``*.gpg`` files are looked up in ``trusted.gpg.d`` and compared
          byte-for-byte;
        - all other files are looked up in ``sources.list.d`` and compared
          after templating the expected text with ``str.format(**kwargs)``.

        Extra output files are not checked.

        Must be called directly from the test method, because the fixture
        directory is looked up by the caller's function name.
        """
        testdata = os.path.join(self.testdir, inspect.stack()[1][3], "expect")

        for name in sorted(os.listdir(testdata)):
            is_keyring = name.endswith(".gpg")
            if is_keyring:
                actual_path = os.path.join(self.trustedparts_dir, name)
            else:
                actual_path = os.path.join(self.sourceparts_dir, name)

            expect_path = os.path.join(testdata, name)

            self.assertTrue(
                os.path.exists(actual_path),
                "File {} was not created during deb822 migration"
                .format(actual_path)
            )

            if is_keyring:
                # Keyrings are binary: prepareTestSources copies them in
                # binary mode. Reading them as text may fail to decode, and
                # str.format would misinterpret any brace bytes, so compare
                # raw bytes without templating.
                with open(expect_path, 'rb') as e, \
                        open(actual_path, 'rb') as a:
                    self.assertEqual(e.read(), a.read(),
                                     '\n# Keyring {} differs from expected'
                                     .format(name))
                continue

            with open(expect_path) as e, open(actual_path) as a:
                expect = e.read().format(**kwargs)
                actual = a.read()
                self.assertEqual(expect, actual,
                                 '\n# Actual {}:\n{}'
                                 .format(name, actual))

    def prepareTestSources(self, **kwargs):
        """Install the caller's input fixtures into fresh output dirs.

        ``sources.list.d`` and ``trusted.gpg.d`` are deleted and recreated
        first. Then each file in ``<testdir>/<caller>/in`` is written out:

        - ``sources.list`` is templated and written into the test dir;
        - ``*.gpg`` is copied as raw bytes into ``trusted.gpg.d``;
        - anything else is templated and written into ``sources.list.d``.

        Templating means ``str.format(**kwargs)``.

        Must be called directly from the test method, because the fixture
        directory is looked up by the caller's function name.
        """
        testdata = os.path.join(self.testdir, inspect.stack()[1][3], "in")

        shutil.rmtree(self.sourceparts_dir, ignore_errors=True)
        os.mkdir(self.sourceparts_dir)
        shutil.rmtree(self.trustedparts_dir, ignore_errors=True)
        os.mkdir(self.trustedparts_dir)

        for name in os.listdir(testdata):
            path_in = os.path.join(testdata, name)
            in_mode = 'r'
            out_mode = 'w'

            if name == "sources.list":
                path_out = os.path.join(self.testdir, name)
            elif name.endswith(".gpg"):
                path_out = os.path.join(self.trustedparts_dir, name)
                in_mode += 'b'
                out_mode += 'b'
            else:
                path_out = os.path.join(self.sourceparts_dir, name)

            with open(path_in, in_mode) as fin, open(path_out, out_mode) as fout:
                data = fin.read()
                # Only text fixtures are templated; keyrings are copied
                # verbatim.
                if isinstance(data, str):
                    data = data.format(**kwargs)

                fout.write(data)

    def test_sources_list_migration(self):
        """
        test that the usual sources.list is migrated to an appropriate
        ubuntu.sources
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)

        if d.default_source_uri == d.security_source_uri:
            self.skipTest("Test requires different default and security source URIs")

        self.prepareTestSources(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )
        self.assertEqual(d.migratedToDeb822(), -1)

        d.migrateToDeb822Sources()

        self.assertEqual(d.migratedToDeb822(), 1)
        self.assertSourcesMatchExpected(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )

        # Make sure we leave behind a comment in sources.list
        with open(os.path.join(self.testdir, "sources.list")) as f:
            self.assertEqual(
                f.read(),
                "# Ubuntu sources have moved to {}/ubuntu.sources\n"
                .format(self.sourceparts_dir)
            )

    @unittest.skipIf(ARCH in ('amd64', 'i386'), "test is for port arches")
    def test_ports_sources_list_migration(self):
        """
        test that the usual ports sources.list is migrated to an appropriate
        ubuntu.sources
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)

        if d.default_source_uri != d.security_source_uri:
            self.skipTest("Test assumes equal default and security source URIs")

        self.prepareTestSources(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )
        self.assertEqual(d.migratedToDeb822(), -1)

        d.migrateToDeb822Sources()

        self.assertEqual(d.migratedToDeb822(), 1)
        self.assertSourcesMatchExpected(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )

        # Make sure we leave behind a comment in sources.list
        with open(os.path.join(self.testdir, "sources.list")) as f:
            self.assertEqual(
                f.read(),
                "# Ubuntu sources have moved to {}/ubuntu.sources\n"
                .format(self.sourceparts_dir)
            )

    def test_third_party_sources_migration(self):
        """
        test that third party sources from sources.list are moved to
        third-party.sources.
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)

        self.prepareTestSources(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
        )
        self.assertEqual(d.migratedToDeb822(), -1)

        d.migrateToDeb822Sources()

        self.assertEqual(d.migratedToDeb822(), 1)
        self.assertSourcesMatchExpected(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
        )

    def test_partial_migration(self):
        """
        test that only .list sources are modified during migration
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)

        self.prepareTestSources(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
        )
        self.assertEqual(d.migratedToDeb822(), 0)

        d.migrateToDeb822Sources()

        self.assertEqual(d.migratedToDeb822(), 1)
        self.assertSourcesMatchExpected(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
        )

    def test_ppa_migration(self):
        """
        test that PPA sources.list.d .list files are migrated
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)

        self.prepareTestSources()
        self.assertEqual(d.migratedToDeb822(), -1)

        d.migrateToDeb822Sources()

        self.assertEqual(d.migratedToDeb822(), 1)
        self.assertSourcesMatchExpected()

    def test_consolidate_types(self):
        """
        test that deb and deb-src entries are consolidated when appropriate
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)

        if d.default_source_uri == d.security_source_uri:
            self.skipTest("Test requires different default and security source URIs")

        self.prepareTestSources(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )
        self.assertEqual(d.migratedToDeb822(), -1)

        d.migrateToDeb822Sources()

        self.assertEqual(d.migratedToDeb822(), 1)
        self.assertSourcesMatchExpected(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
            security_source_uri=d.security_source_uri,
        )

    def test_consolidate_suites(self):
        """
        test that suites are consolidated when appropriate
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)

        self.prepareTestSources(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
        )
        self.assertEqual(d.migratedToDeb822(), -1)

        d.migrateToDeb822Sources()

        self.assertEqual(d.migratedToDeb822(), 1)
        self.assertSourcesMatchExpected(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
        )

    def test_consolidate_comps(self):
        """
        test that components are consolidated when appropriate
        """
        v = DistUpgradeViewNonInteractive()
        d = DistUpgradeController(v, datadir=self.testdir)
        d.openCache(lock=False)

        self.prepareTestSources(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
        )
        self.assertEqual(d.migratedToDeb822(), -1)

        d.migrateToDeb822Sources()

        self.assertEqual(d.migratedToDeb822(), 1)
        self.assertSourcesMatchExpected(
            dist=di.stable(),
            default_source_uri=d.default_source_uri,
        )


if __name__ == "__main__":
    import sys
    # "-v" additionally enables DEBUG-level logging via the root logger.
    # The flag stays in sys.argv, so unittest.main() still sees it too.
    if "-v" in sys.argv:
        logging.basicConfig(level=logging.DEBUG)
    unittest.main()
