# Copyright (C) 2003 by Intevation GmbH
# Authors:
# Martin Mueller <mmueller@intevation.de>
# Bernhard Herzog <bh@intevation.de>
#
# This program is free software under the GPL (>=v2)
# Read the file COPYING coming with the software for details.

"""Basic interface to a PostGIS database"""

from __future__ import generators

try:
    import psycopg
except ImportError:
    psycopg = None

import table
import wellknowntext

from data import SHAPETYPE_POLYGON, SHAPETYPE_ARC, SHAPETYPE_POINT, RAW_WKT

def has_postgis_support():
    """Tell whether PostGIS connections are available in this Thuban.

    Availability depends only on whether the psycopg module could be
    imported when this module was loaded.
    """
    if psycopg is None:
        return False
    return True

def psycopg_version():
    """Return the value of psycopg.__version__.

    If psycopg could not be imported (see has_postgis_support()), this
    returns None instead of failing with an AttributeError on None.
    """
    if psycopg is None:
        return None
    return psycopg.__version__

if psycopg is not None:
    # Pairs of (psycopg type object, Thuban field type).
    # PostGISTable._fetch_table_information tries them in order for each
    # result column; columns matching none of them are ignored.
    # Only defined when psycopg is available.
    type_map = [
        (psycopg.STRING, table.FIELDTYPE_STRING),
        (psycopg.INTEGER, table.FIELDTYPE_INT),
        (psycopg.FLOAT, table.FIELDTYPE_DOUBLE),
    ]


class ConnectionError(Exception):

    """Raised when a database connection cannot be established.

    PostGISConnection.connect() raises it with the text of the
    underlying psycopg OperationalError as its message.
    """

    pass


class PostGISConnection:

    """Represent a PostGIS database

    A PostGISConnection instance has the following public attributes:

    dbname -- The name of the database
    host, port -- Host and port to connect to
    user -- The user name to connect as.

    All of these attributes are strings and may be empty strings to
    indicate default values.

    After a successful connect() the instance additionally has:

    connection -- the underlying DB API 2.0 connection object
    geometry_type -- the OID of the 'geometry' type, used by
                     PostGISTable to recognize geometry columns
    """

    def __init__(self, dbname, host="", user="", password="", dbtype="",
                 port=""):
        """Store the connection parameters and connect immediately.

        Raises ConnectionError if psycopg cannot connect and ValueError
        if the database has no 'geometry' type.
        """
        self.dbname = dbname
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        # Stored only; nothing in this class reads dbtype.
        self.dbtype = dbtype
        self.connect()

    def _dsn_value(self, value):
        """Internal: Quote value for use in a libpq connection string.

        libpq separates keyword=value pairs by whitespace. Values that
        contain spaces or quotes therefore have to be single-quoted, with
        backslashes and single quotes escaped by a backslash.
        """
        # "%s" rather than str() so unicode values stay unicode on Python 2
        # and non-string values such as an integer port still work.
        text = "%s" % (value,)
        return "'%s'" % text.replace("\\", "\\\\").replace("'", "\\'")

    def connect(self):
        """Internal: Establish the database connection

        Builds a libpq connection string from the non-empty connection
        attributes, opens the connection and looks up the OID of the
        geometry type.

        Raises ConnectionError if psycopg raises OperationalError while
        connecting, and ValueError if no 'geometry' type exists.
        """
        params = []
        for name in ("host", "port", "dbname", "user", "password"):
            val = getattr(self, name)
            if val:
                params.append("%s=%s" % (name, self._dsn_value(val)))
        try:
            self.connection = psycopg.connect(" ".join(params))
        except psycopg.OperationalError as val:
            raise ConnectionError(str(val))

        # determine the OID for the geometry type. This is PostGIS
        # specific.
        cursor = self.connection.cursor()
        try:
            cursor.execute("SELECT OID, typname FROM pg_type WHERE"
                           " typname = 'geometry'")
            row = cursor.fetchone()
        finally:
            cursor.close()
        self.connection.commit()
        if row is None:
            # Don't leave the freshly opened connection dangling when the
            # object cannot be initialized.
            self.connection.close()
            raise ValueError("Can't determine postgres type of geometries")
        self.geometry_type = row[0]

    def BriefDescription(self):
        """Return a brief, one-line description of the connection

        The return value is suitable for a list box of all database
        connections. The password is not part of it.
        """
        return ("postgis://%(user)s@%(host)s:%(port)s/%(dbname)s"
                % self.__dict__)

    def Close(self):
        """Close the database connection"""
        self.connection.close()

    def GeometryTables(self):
        """Return a list with the names of all tables with a geometry column

        The names are read from PostGIS's geometry_columns table.
        """
        cursor = self.connection.cursor()
        try:
            cursor.execute("SELECT f_table_name FROM geometry_columns;")
            result = [row[0] for row in cursor.fetchall()]
        finally:
            cursor.close()
        self.connection.commit()
        return result

    def cursor(self):
        """Return a DB API 2.0 cursor for the database"""
        return self.connection.cursor()



class PostGISColumn:

    """Description of one supported column of a PostGISTable.

    name  -- the column name as reported by the database
    type  -- the Thuban field type the column's database type maps to
    index -- the column's position among the table's supported columns
    """

    def __init__(self, name, type, index):
        # 'type' shadows the builtin inside this method only. It is kept
        # because it is part of the constructor's signature.
        self.name, self.type, self.index = name, type, index


class PostGISTable:

    """A Table in a PostGIS database

    A PostgreSQL table may contain columns with types not (yet)
    supported by Thuban. Instances of this class ignore those columns
    and pretend they don't exist, i.e. they won't show up in the column
    descriptions returned by Columns() and other methods.

    Rows are identified by the value of their gid column. Row ordinals
    are positions in the table when ordered by gid. Every query that
    addresses a row by ordinal therefore sorts by gid explicitly, because
    without ORDER BY PostgreSQL returns rows in an unspecified order.
    """

    def __init__(self, db, tablename):
        """Initialize the PostGISTable.

        The db parameter should be an instance of PostGISConnection and
        tablename the name of a table in the database represented by db.
        """
        self.db = db
        self.tablename = tablename
        self.column_map = {}
        self._fetch_table_information()

    def _fetch_one(self, stmt, params=None):
        """Internal: Execute stmt and return its first result row.

        Returns None if the statement produced no rows. If params is
        given, it is passed on to execute() for the driver to substitute.
        The cursor is always closed.
        """
        cursor = self.db.cursor()
        try:
            if params is None:
                cursor.execute(stmt)
            else:
                cursor.execute(stmt, params)
            return cursor.fetchone()
        finally:
            cursor.close()

    def _fetch_all(self, stmt, params=None):
        """Internal: Execute stmt and return all of its result rows.

        If params is given, it is passed on to execute() for the driver
        to substitute. The cursor is always closed.
        """
        cursor = self.db.cursor()
        try:
            if params is None:
                cursor.execute(stmt)
            else:
                cursor.execute(stmt, params)
            return cursor.fetchall()
        finally:
            cursor.close()

    def _fetch_table_information(self):
        """Internal: Update information about the table

        Reads the table's column descriptions through a query that
        matches no rows, and keeps the columns whose database type has
        an entry in type_map. The column whose type is the database's
        geometry type is remembered in self.geometry_column. Finally the
        column lookup table and the SELECT statement used by
        ReadRowAsDict are rebuilt.
        """
        self.columns = []
        cursor = self.db.cursor()
        try:
            cursor.execute("SELECT * FROM %s WHERE 0=1" % self.tablename)
            # (column name, type code) for every result column
            fields = [(item[0], item[1]) for item in cursor.description]
        finally:
            cursor.close()

        for name, pgtype in fields:
            for pgtyp, tabletyp in type_map:
                if pgtyp == pgtype:
                    self.columns.append(PostGISColumn(name, tabletyp,
                                                      len(self.columns)))
                    break
            else:
                if pgtype == self.db.geometry_type:
                    self.geometry_column = name
                # No matching table type. Ignore the column.
                # FIXME: We should at least print a warning about
                # ignored columns

        # Start from an empty map so that a refresh drops stale entries.
        self.column_map = {}
        for col in self.columns:
            self.column_map[col.name] = col
            self.column_map[col.index] = col

        # Build query string for ReadRowAsDict
        self.query_stmt = ("SELECT %s from %s"
                           % (", ".join([col.name for col in self.columns]),
                              self.tablename))

    def DBConnection(self):
        """Return the dbconnection used by the table"""
        return self.db

    def TableName(self):
        """Return the name of the table in the database"""
        return self.tablename

    def Title(self):
        """Return the title of the table.

        The title is currently fixed and equal to the tablename
        """
        return self.tablename

    def Dependencies(self):
        """Return an empty tuple because a PostGISTable depends on nothing else
        """
        return ()

    def NumColumns(self):
        """Return the number of supported columns"""
        return len(self.columns)

    def Columns(self):
        """Return the list of PostGISColumn objects of supported columns"""
        return self.columns

    def Column(self, col):
        """Return the PostGISColumn for col, given as name or index"""
        return self.column_map[col]

    def HasColumn(self, col):
        """Return whether col (a name or index) is a supported column"""
        return col in self.column_map

    def NumRows(self):
        """Return the number of rows in the table"""
        return self._fetch_one("SELECT count(*) FROM %s"
                               % self.tablename)[0]

    def RowIdToOrdinal(self, gid):
        """Return the row ordinal given its id

        The ordinal is the number of rows with a smaller gid.
        """
        return self._fetch_one("SELECT count(*) FROM %s WHERE gid < %d;"
                               % (self.tablename, gid))[0]

    def RowOrdinalToId(self, num):
        """Return the rowid for given its ordinal

        This is the inverse of RowIdToOrdinal, so rows are ordered by gid.
        """
        return self._fetch_one("SELECT gid FROM %s ORDER BY gid"
                               " LIMIT 1 OFFSET %d;"
                               % (self.tablename, num))[0]

    def ReadRowAsDict(self, row, row_is_ordinal = 0):
        """Return a row as a dict mapping column names to values

        row is a gid, or a row ordinal if row_is_ordinal is true. Only
        supported columns are included.
        """
        if row_is_ordinal:
            stmt = self.query_stmt + " ORDER BY gid LIMIT 1 OFFSET %d" % row
        else:
            stmt = self.query_stmt + " WHERE gid = %d" % row
        values = self._fetch_one(stmt)
        result = {}
        for col, value in zip(self.columns, values):
            result[col.name] = value
        return result

    def ReadValue(self, row, col, row_is_ordinal = 0):
        """Return the value of column col in the given row

        col is a column name or index. row is a gid, or a row ordinal if
        row_is_ordinal is true.
        """
        name = self.column_map[col].name
        if row_is_ordinal:
            stmt = ("SELECT %s FROM %s ORDER BY gid LIMIT 1 OFFSET %d" %
                    (name, self.tablename, row))
        else:
            stmt = ("SELECT %s FROM %s WHERE gid = %d" %
                    (name, self.tablename, row))
        return self._fetch_one(stmt)[0]

    def ValueRange(self, col):
        """Return a tuple (min, max) of the values in column col"""
        name = self.column_map[col].name
        return tuple(self._fetch_one("SELECT min(%s), max(%s) FROM %s" %
                                     (name, name, self.tablename)))

    def UniqueValues(self, col):
        """Return a list of the distinct values in column col"""
        name = self.column_map[col].name
        rows = self._fetch_all("SELECT %s FROM %s GROUP BY %s" %
                               (name, self.tablename, name))
        return [row[0] for row in rows]

    def SimpleQuery(self, left, comparison, right):
        """Return the gids of the rows where 'left comparison right' holds

        left is a PostGISColumn. right is either a PostGISColumn or a
        value, which is handed to the driver as a query parameter.
        comparison must be one of ==, !=, <, <=, >=, >; otherwise
        ValueError is raised. The gids are returned in ascending order.
        """
        if comparison not in ("==", "!=", "<", "<=", ">=", ">"):
            raise ValueError("Comparison operator %r not allowed" % comparison)

        if comparison == "==":
            comparison = "="

        if isinstance(right, PostGISColumn):
            right_template = right.name
            params = ()
        else:
            right_template = "%s"
            params = (right,)

        query = "SELECT gid FROM %s WHERE %s %s %s ORDER BY gid;" \
                % (self.tablename, left.name, comparison, right_template)

        return [row[0] for row in self._fetch_all(query, params)]


class PostGISShape:

    """One geometry fetched from a PostGIS table.

    The geometry is kept in the raw textual form handed to the
    constructor (PostGISShapeStore passes the result of AsText(), i.e.
    well-known text). It is parsed only when Points() is called.
    """

    def __init__(self, shapeid, data):
        self.shapeid, self.data = shapeid, data

    def compute_bbox(self):
        """
        Return the bounding box of the shape as a tuple (minx,miny,maxx,maxy)
        """
        # Flatten all parts into one list of (x, y) pairs, unpacking each
        # point exactly once.
        points = [(x, y) for part in self.Points() for x, y in part]
        xs = [point[0] for point in points]
        ys = [point[1] for point in points]
        return (min(xs), min(ys), max(xs), max(ys))

    def ShapeID(self):
        """Return the id (gid) this shape was created with."""
        return self.shapeid

    def Points(self):
        """Return the coordinates parsed by wellknowntext.parse_wkt_thuban.

        The result is iterated as a sequence of parts, each a sequence of
        (x, y) pairs.
        """
        return wellknowntext.parse_wkt_thuban(self.data)

    def RawData(self):
        """Return the unparsed geometry data."""
        return self.data


# Translate the geometry type names found in the "type" column of
# PostGIS's geometry_columns table into Thuban shape types.
# MULTIPOLYGON shares SHAPETYPE_POLYGON with POLYGON.
# NOTE(review): LINESTRING is not listed, so ShapeType() raises KeyError
# for such tables -- confirm whether that is intended.
shapetype_map = dict(POLYGON=SHAPETYPE_POLYGON,
                     MULTIPOLYGON=SHAPETYPE_POLYGON,
                     MULTILINESTRING=SHAPETYPE_ARC,
                     POINT=SHAPETYPE_POINT)


class PostGISShapeStore(PostGISTable):

    """Shapestore interface to a table in a PostGIS database

    Shapes are identified by the table's gid column. Their geometries
    are fetched as text through AsText().
    """

    def Table(self):
        """Return self since a PostGISShapeStore is its own table."""
        return self

    def OrigShapeStore(self):
        """Return None since the PostGISShapeStore is not derived from another
        """
        return None

    def ShapeType(self):
        """Return the type of the shapes in the shapestore.

        The type name is read from PostGIS's geometry_columns table and
        translated with shapetype_map. Names missing from that mapping
        raise KeyError.
        """
        cursor = self.db.cursor()
        try:
            cursor.execute("SELECT type FROM geometry_columns WHERE"
                           " f_table_name=%s", (self.tablename,))
            result = cursor.fetchone()[0]
        finally:
            cursor.close()
        return shapetype_map[result]

    def RawShapeFormat(self):
        """Return the raw data format of the shape data.

        For the PostGISShapeStore this is RAW_WKT.
        """
        return RAW_WKT

    def NumShapes(self):
        """Return the number of shapes (the number of rows in the table)"""
        # The number of shapes is the same as the number of rows,
        # assuming that the geometry can't be NULL.
        return self.NumRows()

    def BoundingBox(self):
        """Return the bounding box of all shapes in the postgis table

        The result is a tuple (minx, miny, maxx, maxy). It is None if
        extent() yields NULL, which is what an aggregate returns when
        there are no (non-NULL) geometries.
        """
        cursor = self.db.cursor()
        try:
            # Using the extent function is postgis specific. An OGC
            # Simple Features compliant solution would be to use a query
            # like "SELECT AsText(Envelope(the_geom)) FROM mytable;" and
            # calculate the bounding box by hand from that
            cursor.execute("SELECT extent(%s) FROM %s;"
                           % (self.geometry_column, self.tablename))
            result = cursor.fetchone()
        finally:
            cursor.close()
        if not result or result[0] is None:
            return None
        (minx, miny), (maxx, maxy) \
                      = wellknowntext.parse_wkt_thuban(result[0])[0]
        return (minx, miny, maxx, maxy)

    def Shape(self, shapeid):
        """Return the PostGISShape of the row whose gid is shapeid"""
        cursor = self.db.cursor()
        try:
            cursor.execute("SELECT AsText(%s) FROM %s WHERE gid=%d"
                           % (self.geometry_column, self.tablename, shapeid))
            wkt = cursor.fetchone()[0]
        finally:
            cursor.close()
        return PostGISShape(shapeid, wkt)

    def _shapes_from_cursor(self, cursor):
        """Internal: Yield a PostGISShape for each (gid, wkt) row of cursor.

        The cursor is closed once all of its rows have been consumed.
        """
        while 1:
            result = cursor.fetchone()
            if result is None:
                break
            yield PostGISShape(result[0], result[1])
        cursor.close()

    def AllShapes(self):
        """Generate all shapes of the table in ascending gid order"""
        cursor = self.db.cursor()
        cursor.execute("SELECT gid, AsText(%s) FROM %s ORDER BY gid"
                       % (self.geometry_column, self.tablename))
        for shape in self._shapes_from_cursor(cursor):
            yield shape

    def ShapesInRegion(self, bbox):
        """Generate all shapes overlapping the region given by bbox.

        bbox is a tuple (left, bottom, right, top). Candidates are
        selected with PostGIS's && operator against the bbox rectangle,
        and shapes are generated in ascending gid order.
        """
        # NOTE: The original authors marked this query as working for
        # PostGIS < 0.8. Compatibility with later PostGIS versions is not
        # established here.
        left, bottom, right, top = bbox
        geom = (("POLYGON((" + ", ".join(["%f %f"] * 5) + "))")
                % (left, bottom, left, top, right, top, right, bottom,
                   left, bottom))
        cursor = self.db.cursor()
        cursor.execute("SELECT gid, AsText(%s) FROM %s"
                       " WHERE %s && GeometryFromText('%s', -1) ORDER BY gid"
                       % (self.geometry_column, self.tablename,
                          self.geometry_column, geom))
        for shape in self._shapes_from_cursor(cursor):
            yield shape
