# Copyright (C) 2003 by Intevation GmbH
# Authors:
# Bernhard Herzog <bh@intevation.de>
#
# This program is free software under the GPL (>=v2)
# Read the file COPYING coming with the software for details.

"""Test for Thuban.Model.postgisdb"""

import os
import unittest


try:
    import psycopg
except ImportError:
    # No psycopg available. Nothing to be done here because the
    # postgis.py support module determines this too and the tests will
    # be skipped completely.
    pass

import postgissupport

import support
support.initthuban()

from Thuban.Model.postgisdb import ConnectionError, PostGISConnection, \
     PostGISTable, PostGISShapeStore
from Thuban.Model.table import FIELDTYPE_INT, FIELDTYPE_STRING, \
     FIELDTYPE_DOUBLE
from Thuban.Model.data import SHAPETYPE_POINT, SHAPETYPE_POLYGON, \
     SHAPETYPE_ARC, RAW_WKT


class NonConnection(PostGISConnection):

    """PostGISConnection variant for tests whose connect() is a no-op,
    so it never actually connects."""

    def connect(self):
        """Do nothing instead of connecting."""
        return None



class TestBriefDescription(unittest.TestCase):

    """Test case for the string returned by BriefDescription()."""

    def test(self):
        """Test PostGISConnection.BriefDescription()"""
        # Each entry is (dbname, keyword arguments, expected description).
        # Parameters that are not given show up as empty parts of the
        # description.
        cases = [
            ("somedb", {}, "postgis://@:/somedb"),
            ("db", {"host": "here", "port": "123"},
             "postgis://@here:123/db"),
            ("db", {"user": "me", "port": "123"},
             "postgis://me@:123/db"),
        ]
        for dbname, options, expected in cases:
            connection = NonConnection(dbname, **options)
            self.assertEqual(connection.BriefDescription(), expected)


class TestPostGISConnection(unittest.TestCase):

    """Tests for PostGISConnection run against a newly created database."""

    def setUp(self):
        """Start the server and create a database.

        The database name will be stored in self.dbname, the server
        object in self.server and the db object in self.db.
        """
        postgissupport.skip_if_no_postgis()
        self.server = postgissupport.get_test_server()
        # The name is built from the last two dotted components of the
        # test id, cut down to its trailing 31 characters.
        self.dbname = ".".join(self.id().split(".")[-2:])[-31:]
        self.db = self.server.new_postgis_db(self.dbname)

    def test_gis_tables_empty(self):
        """Test PostGISConnection.GeometryTables() on empty DB"""
        db = PostGISConnection(dbname = self.dbname,
                               **self.server.connection_params("user"))
        try:
            # An empty database doesn't have any GIS tables
            self.assertEqual(db.GeometryTables(), [])
        finally:
            # Always release the connection, even if the assertion fails
            db.Close()

    def test_gis_tables_non_empty(self):
        """Test PostGISConnection.GeometryTables() on non-empty DB"""
        db = PostGISConnection(dbname = self.dbname,
                               **self.server.connection_params("user"))
        try:
            # Create a table with a geometry column through a separate
            # connection opened with the "admin" connection string.
            conn = psycopg.connect("dbname=%s " % self.dbname
                                   + self.server.connection_string("admin"))
            try:
                cursor = conn.cursor()
                cursor.execute("CREATE TABLE test (A INT);")
                cursor.execute("SELECT AddGeometryColumn(%(dbname)s, 'test',"
                               " 'geometry', -1, 'POINT', 2);",
                               {"dbname": self.dbname})
                conn.commit()
            finally:
                conn.close()

            # The table that just got a geometry column must be listed
            self.assertEqual(db.GeometryTables(), ["test"])
        finally:
            # Always release the connection, even if something failed
            db.Close()


class TestPostgisDBExceptions(unittest.TestCase):

    """Tests for exceptions raised when connecting to the database."""

    def setUp(self):
        """Start the postgis server and make it require authentication"""
        postgissupport.skip_if_no_postgis()
        server = postgissupport.get_test_server()
        self.server = server
        self.postgisdb = server.get_default_static_data_db()
        server.require_authentication(True)

    def tearDown(self):
        """Switch postgresql authentication off again"""
        self.server.require_authentication(False)

    def test_no_password(self):
        """Test PostGISConnection with omitted but required password"""
        params = self.server.connection_params("user")
        # Deliberately drop the password from the parameters
        del params["password"]

        self.assertRaises(ConnectionError, PostGISConnection,
                          dbname = self.postgisdb.dbname, **params)


class TestPostGISIgnoredColumns(unittest.TestCase):

    """Test PostGISTable with a table that has an unsupported column type."""

    def setUp(self):
        """Start the server and create a database.

        The database name will be stored in self.dbname, the server
        object in self.server and the db object in self.db.
        """
        postgissupport.skip_if_no_postgis()
        self.server = postgissupport.get_test_server()
        # The name is built from the last two dotted components of the
        # test id, cut down to its trailing 31 characters.
        self.dbname = ".".join(self.id().split(".")[-2:])[-31:]
        self.db = self.server.new_postgis_db(self.dbname)

    def test(self):
        """test PostGISTable on a table with unsupported data types"""
        stmt = """CREATE TABLE foo (
                        gid integer,
                        ignored bigint,
                        length float);
                  GRANT SELECT ON foo TO PUBLIC;
                  """
        self.server.execute_sql(self.dbname, "admin", stmt)

        db = PostGISConnection(dbname = self.dbname,
                               **self.server.connection_params("user"))
        try:
            table = PostGISTable(db, "foo")

            # The bigint column will be ignored because it's not mapped
            # to a known integer type, so there are only two columns
            self.assertEqual(table.NumColumns(), 2)
            self.assertEqual(table.Column(0).name, "gid")
            self.assertEqual(table.Column(1).name, "length")
        finally:
            # Always release the connection, even if an assertion fails
            db.Close()


class PostGISStaticTests(unittest.TestCase, support.FloatComparisonMixin):

    """Common base class for the PostGIS tests that use static data."""

    def setUp(self):
        """Start the server and connect to the database with static data

        This method sets the following instance attributes:

           server -- The server object

           postgisdb -- the server's default static data db object

           db -- the PostGISConnection object
        """
        postgissupport.skip_if_no_postgis()
        server = postgissupport.get_test_server()
        self.server = server
        postgisdb = server.get_default_static_data_db()
        self.postgisdb = postgisdb
        self.db = PostGISConnection(dbname = postgisdb.dbname,
                                    **server.connection_params("user"))

    def tearDown(self):
        """Close the connection opened by setUp"""
        self.db.Close()


class TestPostGISTable(PostGISStaticTests):

    """Tests for PostGISTable using the "landmarks" table."""

    def setUp(self):
        """Extend inherited method to set self.table to a PostGISTable"""
        PostGISStaticTests.setUp(self)
        self.table = PostGISTable(self.db, "landmarks")

    def _row_1003(self):
        """Return the expected dictionary for the row with gid 1003"""
        return {"gid": 1003,
                "area": 0.0,
                "perimeter": 0.0,
                "clpoint_": 4,
                "clpoint_id": 24,
                "clptlabel": "RUINS",
                "clptflag": 0}

    def test_dbconn(self):
        """Test PostGISTable.DBConnection()"""
        connection = self.table.DBConnection()
        self.assert_(connection is self.db)

    def test_dbname(self):
        """Test PostGISTable.TableName()"""
        name = self.table.TableName()
        self.assertEqual(name, "landmarks")

    def test_title(self):
        """test PostGISTable.Title()"""
        # For now the title is simply the name of the table
        title = self.table.Title()
        self.assertEqual(title, "landmarks")

    def test_dependencies(self):
        """Test PostGISTable.Dependencies()"""
        # There are no other data containers a PostGISTable depends on
        dependencies = self.table.Dependencies()
        self.assertEqual(dependencies, ())

    def test_num_rows(self):
        """Test PostGISTable.NumRows()"""
        count = self.table.NumRows()
        self.assertEqual(count, 34)

    def test_num_columns(self):
        """Test PostGISTable.NumColumns()"""
        # Compared to the DBF file the table in the postgis db has an
        # extra "gid" column, hence one column more in the PostGISTable
        count = self.table.NumColumns()
        self.assertEqual(count, 7)

    def test_columns(self):
        """Test PostGISTable.Columns()"""
        self.assertEqual(len(self.table.Columns()), 7)
        # (position in the column list, expected name, expected type)
        expected = [(0, "gid", FIELDTYPE_INT),
                    (1, "area", FIELDTYPE_DOUBLE),
                    (5, "clptlabel", FIELDTYPE_STRING)]
        for position, name, fieldtype in expected:
            self.assertEqual(self.table.Columns()[position].name, name)
            self.assertEqual(self.table.Columns()[position].type, fieldtype)
            self.assertEqual(self.table.Columns()[position].index, position)

    def test_column(self):
        """Test PostGISTable.Column()"""
        table = self.table
        self.assertEqual(table.Column("area").name, "area")
        self.assertEqual(table.Column("area").type, FIELDTYPE_DOUBLE)
        self.assertEqual(table.Column("area").index, 1)

    def test_has_column(self):
        """Test PostGISTable.HasColumn()"""
        self.assert_(self.table.HasColumn("area"))
        self.assert_(not self.table.HasColumn("foo"))

    def test_read_row_as_dict(self):
        """Test PostGISTable.ReadRowAsDict()"""
        row = self.table.ReadRowAsDict(1003)
        self.assertEqual(row, self._row_1003())

    def test_read_row_as_dict_row_count_mode(self):
        """Test PostGISTable.ReadRowAsDict() row count address mode"""
        row = self.table.ReadRowAsDict(3, row_is_ordinal = 1)
        self.assertEqual(row, self._row_1003())

    def test_read_value(self):
        """Test PostGISTable.ReadValue()"""
        # The column is addressed by index first and then by name
        for column in (4, "clpoint_id"):
            self.assertEqual(self.table.ReadValue(1003, column), 24)

    def test_read_value_row_count_mode(self):
        """Test PostGISTable.ReadValue() row count address mode"""
        # The column is addressed by index first and then by name
        for column in (4, "clpoint_id"):
            value = self.table.ReadValue(3, column, row_is_ordinal = 1)
            self.assertEqual(value, 24)

    def test_row_id_to_ordinal(self):
        """Test PostGISTable.RowIdToOrdinal()"""
        ordinal = self.table.RowIdToOrdinal(1005)
        self.assertEqual(ordinal, 5)

    def test_row_oridnal_to_id(self):
        """Test PostGISTable.RowOrdinalToId()"""
        row_id = self.table.RowOrdinalToId(5)
        self.assertEqual(row_id, 1005)

    def test_value_range(self):
        """Test PostGISTable.ValueRange()"""
        value_range = self.table.ValueRange("clpoint_id")
        self.assertEqual(value_range, (21, 74))

    def test_unique_values(self):
        """Test PostGISTable.UniqueValues()"""
        labels = self.table.UniqueValues("clptlabel")
        # Sort in place so the comparison is independent of the order
        # in which the values are returned
        labels.sort()
        expected = ["BUILDING", "FARM", "HUT","LIGHTHOUSE",
                    "OTHER/UNKNOWN", "RUINS"]
        self.assertEqual(labels, expected)

    def test_simple_query(self):
        """Test PostGISTable.SimpleQuery()"""
        table = self.table

        # Compare a string column with a constant
        farms = table.SimpleQuery(table.Column("clptlabel"), "==", "FARM")
        self.assertEqual(farms, [1006])

        # Compare an integer column with a constant
        above_70 = table.SimpleQuery(table.Column("clpoint_id"), ">", 70)
        self.assertEqual(above_70, [1024, 1025, 1026])

        # Compare one column with another column
        id_below = table.SimpleQuery(table.Column("clpoint_id"), "<",
                                     table.Column("clpoint_"))
        self.assertEqual(id_below, [1028, 1029, 1030, 1031, 1032, 1033])


class TestPostGISShapestorePoint(PostGISStaticTests):

    """Tests for PostGISShapeStore objects with POINT data"""

    def setUp(self):
        """Extend inherited method to set self.store to a PostGISShapeStore"""
        PostGISStaticTests.setUp(self)
        self.store = PostGISShapeStore(self.db, "landmarks")

    #
    # First, some tests that should be independent of the shapetype, so
    # it shouldn't be necessary to repeat them for other shapetypes
    #

    def test_dependencies(self):
        """Test PostGISShapeStore.Dependencies()"""
        # A PostGISShapeStore depends on no other data container
        self.assertEqual(self.store.Dependencies(), ())

    def test_table(self):
        """Test PostGISShapeStore.Table() with POINT shapes"""
        # A PostGISShapeStore is its own table
        self.assert_(self.store.Table() is self.store)

    def test_orig_shapestore(self):
        """Test PostGISShapeStore.OrigShapeStore() with POINT shapes"""
        # A PostGISShapeStore is not derived from another shape store
        self.assert_(self.store.OrigShapeStore() is None)

    def test_raw_format(self):
        """Test PostGISShapeStore.RawShapeFormat() with POINT shapes"""
        self.assertEqual(self.store.RawShapeFormat(), RAW_WKT)

    def test_all_shapes(self):
        """Test PostGISShapeStore.AllShapes()"""
        # The expected ids start at 1000 and are consecutive.  The
        # range is turned into a list explicitly so that the comparison
        # does not depend on range() itself returning a list.
        self.assertEqual([s.ShapeID() for s in self.store.AllShapes()],
                         list(range(1000, 1000 + self.store.NumShapes())))

    #
    # Shapetype specific tests
    #

    def test_shape_type(self):
        """Test PostGISShapeStore.ShapeType() with POINT shapes"""
        self.assertEqual(self.store.ShapeType(), SHAPETYPE_POINT)

    def test_num_shapes(self):
        """Test PostGISShapeStore.NumShapes() with POINT shapes"""
        self.assertEqual(self.store.NumShapes(), 34)

    def test_bounding_box(self):
        """Test PostGISShapeStore.BoundingBox() with POINT shapes"""
        self.assertFloatSeqEqual(self.store.BoundingBox(),
                                 [-23.806047439575195, 63.405960083007812,
                                  -15.12291431427002, 66.36572265625])

    def test_shape_shapeid(self):
        """Test PostGISShapeStore.Shape(i).ShapeID() with POINT shapes"""
        self.assertEqual(self.store.Shape(1005).ShapeID(), 1005)

    def test_shape_points(self):
        """Test PostGISShapeStore.Shape(i).Points() with POINT shapes"""
        self.assertPointListEquals(self.store.Shape(1000).Points(),
                                   [[(-22.711074829101562, 66.36572265625)]])

    def test_shape_raw_data(self):
        """Test PostGISShapeStore.Shape(i).RawData() with POINT shapes"""
        self.assertEqual(self.store.Shape(1000).RawData(),
                         'POINT(-22.7110748291016 66.36572265625)')

    def test_shapes_in_region(self):
        """Test PostGISShapeStore.ShapesInRegion() with POINT shapes"""
        shapes = self.store.ShapesInRegion((-20.0, 64.0, -24.0, 67))
        self.assertEqual([s.ShapeID() for s in shapes],
                         [1000, 1001, 1002, 1003, 1004, 1005, 1027])


class TestPostGISShapestoreArc(PostGISStaticTests):

    """Tests for PostGISShapeStore objects with MULTILINESTRING data"""

    def setUp(self):
        """Extend inherited method to set self.store to a PostGISShapeStore"""
        PostGISStaticTests.setUp(self)
        self.store = PostGISShapeStore(self.db, "roads")

    def test_shape_type(self):
        """Test PostGISShapeStore.ShapeType() with ARC shapes"""
        shapetype = self.store.ShapeType()
        self.assertEqual(shapetype, SHAPETYPE_ARC)

    def test_num_shapes(self):
        """Test PostGISShapeStore.NumShapes() with ARC shapes"""
        count = self.store.NumShapes()
        self.assertEqual(count, 839)

    def test_bounding_box(self):
        """Test PostGISShapeStore.BoundingBox() with ARC shapes"""
        bbox = self.store.BoundingBox()
        self.assertFloatSeqEqual(bbox,
                                 [-24.450359344482422, 63.426830291748047,
                                  -13.55668830871582, 66.520111083984375])

    def test_shape_shapeid(self):
        """Test PostGISShapeStore.Shape(i).ShapeID() with ARC shapes"""
        shape = self.store.Shape(5)
        self.assertEqual(shape.ShapeID(), 5)

    def test_shape_points(self):
        """Test PostGISShapeStore.Shape(i).Points() with ARC shapes"""
        points = self.store.Shape(32).Points()
        self.assertPointListEquals(points,
                                   [[(-15.0821743011474, 66.2773818969726),
                                     (-15.0263500213623, 66.2733917236328)]])

    def test_shape_raw_data(self):
        """Test PostGISShapeStore.Shape(i).RawData() with ARC shapes"""
        raw = self.store.Shape(32).RawData()
        expected = ("MULTILINESTRING((-15.0821743011475 66.2773818969727,"
                    "-15.0263500213623 66.2733917236328))")
        self.assertEqual(raw, expected)

    def test_shapes_in_region(self):
        """Test PostGISShapeStore.ShapesInRegion() with ARC shapes"""
        ids = []
        for shape in self.store.ShapesInRegion((-24.0, 64.5, -23.5, 65.0)):
            ids.append(shape.ShapeID())
        self.assertEqual(ids, [573, 581, 613])



class PolygonTests:

    """Tests shared by the POLYGON and MULTIPOLYGON test cases

    Both test cases run the same tests because they are based on the
    same data. The tests use self.store and the assertion methods of
    the class this one is combined with.
    """

    def test_shape_type(self):
        """Test PostGISShapeStore.ShapeType() with POLYGON shapes"""
        shapetype = self.store.ShapeType()
        self.assertEqual(shapetype, SHAPETYPE_POLYGON)

    def test_num_shapes(self):
        """Test PostGISShapeStore.NumShapes() with POLYGON shapes"""
        count = self.store.NumShapes()
        self.assertEqual(count, 156)

    def test_bounding_box(self):
        """Test PostGISShapeStore.BoundingBox() with POLYGON shapes"""
        bbox = self.store.BoundingBox()
        self.assertFloatSeqEqual(bbox,
                                 [-24.546524047851562, 63.286754608154297,
                                  -13.495815277099609, 66.563774108886719])

    def test_shape_shapeid(self):
        """Test PostGISShapeStore.Shape(i).ShapeID() with POLYGON shapes"""
        shape = self.store.Shape(5)
        self.assertEqual(shape.ShapeID(), 5)

    def test_shape_points(self):
        """Test PostGISShapeStore.Shape(i).Points() with POLYGON shapes"""
        points = self.store.Shape(4).Points()
        # The ring is closed: the last point repeats the first one
        self.assertPointListEquals(points,
                                   [[(-22.40639114379882, 64.714111328125),
                                     (-22.41621208190918, 64.716003417968),
                                     (-22.40605163574218, 64.719200134277),
                                     (-22.40639114379882, 64.714111328125)]])

    def test_shapes_in_region(self):
        """Test PostGISShapeStore.ShapesInRegion() with POLYGON shapes"""
        ids = []
        for shape in self.store.ShapesInRegion((-23.0, 65.5, -22.8, 65.25)):
            ids.append(shape.ShapeID())
        self.assertEqual(ids, [47, 56, 59, 61, 62, 71, 144])


class TestPostGISShapestorePolygon(PolygonTests, PostGISStaticTests):

    """Tests for PostGISShapeStore objects with POLYGON data"""

    def setUp(self):
        """Extend inherited method to set self.store to a PostGISShapeStore"""
        PostGISStaticTests.setUp(self)
        self.store = PostGISShapeStore(self.db, "political")

    def test_shape_type(self):
        """Test PostGISShapeStore.ShapeType() with POLYGON shapes"""
        shapetype = self.store.ShapeType()
        self.assertEqual(shapetype, SHAPETYPE_POLYGON)

    def test_shape_raw_data(self):
        """Test PostGISShapeStore.Shape(i).RawData() with POLYGON shapes"""
        raw = self.store.Shape(4).RawData()
        expected = ("POLYGON((-22.4063911437988 64.714111328125,"
                    "-22.4162120819092 64.7160034179688,"
                    "-22.4060516357422 64.7192001342773,"
                    "-22.4063911437988 64.714111328125))")
        self.assertEqual(raw, expected)


class TestPostGISShapestoreMultiPolygon(PolygonTests, PostGISStaticTests):

    """Tests for PostGISShapeStore objects with MULTIPOLYGON data"""

    def setUp(self):
        """Extend inherited method to set self.store to a PostGISShapeStore"""
        PostGISStaticTests.setUp(self)
        self.store = PostGISShapeStore(self.db, "political_multi")

    def test_shape_raw_data(self):
        """Test PostGISShapeStore.Shape(i).RawData() with MULTIPOLYGON shapes
        """
        raw = self.store.Shape(4).RawData()
        expected = ("MULTIPOLYGON(((-22.4063911437988 64.714111328125,"
                    "-22.4162120819092 64.7160034179688,"
                    "-22.4060516357422 64.7192001342773,"
                    "-22.4063911437988 64.714111328125)))")
        self.assertEqual(raw, expected)



# When executed as a script, run the tests through the support module.
if __name__ == "__main__":
    support.run_tests()
