# Copyright (C) 2003 by Intevation GmbH
# Authors:
# Bernhard Herzog <bh@intevation.de>
#
# This program is free software under the GPL (>=v2)
# Read the file COPYING coming with the software for details.

"""Support module for tests that use a live PostGIS database"""

__version__ = "$Revision: 1.5 $"
# $Source: /thubanrepository/thuban/test/postgissupport.py,v $
# $Id: postgissupport.py,v 1.5 2003/08/28 14:23:08 bh Exp $

import sys
import os
import time
import popen2
import shutil
import traceback

import support

try:
    import psycopg
except ImportError:
    psycopg = None

#
#       Helper code
#

def run_config_script(cmdline):
    """Run the command cmdline and return what it wrote to stdout

    Parameters:

      cmdline -- The command line as a single string. It is passed to
                 os.popen unchanged.

    Return the complete standard output of the command as a string.

    Raise a RuntimeError if closing the pipe reports an exit status,
    i.e. if the command did not terminate successfully.
    """
    pipe = os.popen(cmdline)
    try:
        result = pipe.read()
    finally:
        # Always close the pipe, even if reading failed, so that the
        # file object and the child process are not left dangling.
        status = pipe.close()
    # close() returns None on success and the exit status otherwise
    if status is not None:
        raise RuntimeError('Command %r failed' % cmdline)
    return result

def run_command(command, outfilename = None):
    """Run command as a subprocess and send its stdout and stderr to outfile

    The subprocess is run synchronously so the function returns once
    the subprocess has terminated.

    Parameters:

      command -- The command to run, in a form accepted by
                 popen2.Popen4.
      outfilename -- If not None, the name of a file to which the
                 combined stdout and stderr of the command is written.
                 If it is None the output is still captured but not
                 written to any file.

    Raise a RuntimeError if the process did not exit normally with an
    exit code of zero. The error message names the real exit code or
    the signal that terminated the process and either refers to
    outfilename or, if no outfilename was given, contains the captured
    output.
    """
    proc = popen2.Popen4(command)
    # The subprocess gets no input, so close its stdin right away.
    proc.tochild.close()
    try:
        output = proc.fromchild.read()
    finally:
        proc.fromchild.close()
    status = proc.wait()
    if outfilename is not None:
        outfile = open(outfilename, "w")
        try:
            outfile.write(output)
        finally:
            outfile.close()

    # status is the raw wait status, not the exit code, so it has to be
    # decoded before it can be reported.
    if os.WIFEXITED(status):
        exitcode = os.WEXITSTATUS(status)
        if exitcode == 0:
            return
        reason = "exited with code %d" % exitcode
    elif os.WIFSIGNALED(status):
        reason = "was terminated by signal %d" % os.WTERMSIG(status)
    else:
        reason = "terminated abnormally (wait status %d)" % status

    if outfilename:
        message = "see %s" % outfilename
    else:
        message = output
    raise RuntimeError("command %r %s.\n%s" % (command, reason, message))


def run_boolean_command(command):
    """
    Run command silently as a subprocess and report whether it succeeded

    The command is executed through run_command without an output
    file, so everything it prints is captured and thrown away.

    Return 1 if run_command completed without raising a RuntimeError
    and 0 if it raised one.
    """
    try:
        run_command(command, None)
    except RuntimeError:
        # run_command signals any kind of failure with RuntimeError
        return 0
    return 1


#
#       PostgreSQL and database
#

class PostgreSQLServer:

    """A PostgreSQL server

    Instances of this class represent a PostgreSQL server with postgis
    extensions run explicitly for the test cases. Such a server has its
    own database directory and its own directory for the unix sockets so
    that it doesn't interfere with any other PostgreSQL server already
    running on the system.
    """

    def __init__(self, dbdir, port, postgis_sql, socket_dir):
        """Initialize the PostgreSQLServer object

        Parameters:

          dbdir -- The directory for the databases
          port -- The port to use
          postgis_sql -- The name of the file with the SQL statements to
                         initialize a database for postgis.
          socket_dir -- The directory for the socket files.

        When connecting to the database server use the port and host
        instance variables.
        """
        self.dbdir = dbdir
        self.port = port
        self.postgis_sql = postgis_sql
        self.socket_dir = socket_dir

        # For the client side the socket directory can be used as the
        # host if the name starts with a slash.
        self.host = os.path.abspath(socket_dir)

        # name and password for the admin and an unprivileged user
        self.admin_name = "postgres"
        self.admin_password = "postgres"
        self.user_name = "observer"
        self.user_password = "telescope"

        # Map db names to db objects
        self.known_dbs = {}

    def createdb(self):
        """Create the database in dbdir and start the server.

        First check whether the dbdir already exists and if necessary
        stop an already running postmaster and remove the dbdir
        directory completely. Then create a new database cluster in the
        dbdir and start a postmaster. Finally set the admin's password
        and create the unprivileged user.

        The output of the commands that are run is written to log files
        in dbdir. A RuntimeError is raised if one of the commands fails
        or if the postmaster doesn't become ready.
        """
        if os.path.isdir(self.dbdir):
            if self.is_running():
                self.shutdown()
            shutil.rmtree(self.dbdir)
        os.mkdir(self.dbdir)

        run_command(["initdb", "-D", self.dbdir, "-U", self.admin_name],
                    os.path.join(self.dbdir, "initdb.log"))

        # Options handed through to the postmaster by pg_ctl
        extra_opts = "-p %d" % self.port
        if self.socket_dir is not None:
            extra_opts += " -k %s" % self.socket_dir
        run_command(["pg_ctl", "-D", self.dbdir,
                     "-l", os.path.join(self.dbdir, "logfile"),
                     "-o", extra_opts, "start"],
                    os.path.join(self.dbdir, "pg_ctl-start.log"))
        # the -w option of pg_ctl doesn't work properly when the port is
        # not the default port, so we have to implement waiting for the
        # server ourselves
        self.wait_for_postmaster()

        self.alter_user(self.admin_name, self.admin_password)
        self.create_user(self.user_name, self.user_password)

    def wait_for_postmaster(self):
        """Return when the database server is running

        Internal method to wait until the postmaster process has been
        started and is ready for client connections.

        The method tries up to 60 times to list the databases with psql
        and sleeps for half a second after each failed attempt. The
        output of each attempt is written to a separate psql-N.log file
        in dbdir. If none of the attempts succeeds, raise a
        RuntimeError.
        """
        max_count = 60
        for count in range(max_count):
            try:
                run_command(["psql", "-l", "-p", str(self.port),
                             "-h", self.host, "-U", self.admin_name],
                            os.path.join(self.dbdir, "psql-%d.log" % count))
            except RuntimeError:
                # psql failed, i.e. the server isn't ready yet
                pass
            except (KeyboardInterrupt, SystemExit):
                # Never keep retrying when the user or the interpreter
                # wants to stop.
                raise
            except:
                # Report anything unexpected but keep trying
                traceback.print_exc()
            else:
                break
            time.sleep(0.5)
        else:
            raise RuntimeError("postmaster didn't start")

    def is_running(self):
        """Return true if a postmaster process is running on self.dbdir

        This method runs pg_ctl status on the dbdir so even if the
        object has just been created it is possible that this method
        returns true if there's still a postmaster process running for
        self.dbdir.
        """
        return run_boolean_command(["pg_ctl", "-D", self.dbdir, "status"])

    def shutdown(self):
        """Stop the postmaster running for self.dbdir

        The output of pg_ctl is written to pg_ctl-stop.log in dbdir.
        """
        run_command(["pg_ctl", "-m", "fast", "-D", self.dbdir, "stop"],
                    os.path.join(self.dbdir, "pg_ctl-stop.log"))

    def new_postgis_db(self, dbname, tables = None):
        """Create and return a new PostGISDatabase object using self as server

        The new database is initialized immediately and remembered
        under dbname in self.known_dbs. The tables parameter is passed
        on to the PostGISDatabase constructor.
        """
        db = PostGISDatabase(self, self.postgis_sql, dbname, tables = tables)
        db.initdb()
        self.known_dbs[dbname] = db
        return db

    def get_static_data_db(self, dbname, tables):
        """Return a PostGISDatabase for a database with the given static data

        If no database of the name dbname exists, create a new one via
        new_postgis_db and upload the data.

        If a database of the name dbname already exists and uses the
        indicated data, return that. If the already existing db uses
        different data raise a ValueError.

        The tables argument should be a sequence of table specifications
        in the form accepted by the PostGISDatabase constructor.
        """
        db = self.known_dbs.get(dbname)
        if db is not None:
            if db.has_data(tables):
                return db
            raise ValueError("PostGISDatabase named %r doesn't have tables %r"
                             % (dbname, tables))
        return self.new_postgis_db(dbname, tables)

    def get_default_static_data_db(self):
        """Return the PostGISDatabase with the default static test data

        The database is called PostGISStaticTests and is created on
        first use via get_static_data_db. It contains tables made from
        the iceland shapefiles.
        """
        dbname = "PostGISStaticTests"
        tables = [
            # Copies of the shapefiles. The gids of political and roads
            # are the shapeids; the gids of landmarks are the shapeids
            # shifted by the gid_offset.
            ("landmarks", os.path.join("..", "Data", "iceland",
                                       "cultural_landmark-point.shp"),
             [("gid_offset", 1000)]),
            ("political", os.path.join("..", "Data", "iceland",
                                             "political.shp")),
            ("roads", os.path.join("..", "Data", "iceland",
                                         "roads-line.shp")),

            # The polygon data as a MULTIPOLYGON geometry type
            ("political_multi", os.path.join("..", "Data", "iceland",
                                             "political.shp"),
             [("force_wkt_type", "MULTIPOLYGON")]),
            ]
        return self.get_static_data_db(dbname, tables)

    def connection_params(self, user):
        """Return the connection parameters for the given user

        The return value is a dictionary suitable as keyword argument
        list to PostGISConnection. The user parameter may be either
        'admin' to connect as admin or 'user' to connect as an
        unprivileged user.
        """
        return {"host": self.host, "port": self.port,
                "user": getattr(self, user + "_name"),
                "password": getattr(self, user + "_password")}

    def connection_string(self, user):
        """Return (part of) the connection string to pass to psycopg.connect

        The string contains host, port, user and password. The user
        parameter must be either 'admin' or 'user', as for
        connection_params.

        Each value is put in single quotes. Backslashes and single
        quotes inside a value are escaped with a backslash.
        """
        params = []
        for key, value in self.connection_params(user).items():
            text = "%s" % (value,)
            # Escape the backslashes first so that the backslashes
            # added for the quotes are not escaped again.
            text = text.replace("\\", "\\\\").replace("'", "\\'")
            params.append("%s='%s'" % (key, text))
        return " ".join(params)

    def execute_sql(self, dbname, user, sql):
        """Execute the sql statement

        The user parameter is used as in connection_params. The dbname
        parameter must be the name of a database in the cluster.

        The statement is executed and committed on a new connection.
        The connection is closed afterwards, even if executing the
        statement fails.
        """
        conn = psycopg.connect("dbname=%s " % dbname
                               + self.connection_string(user))
        try:
            cursor = conn.cursor()
            cursor.execute(sql)
            conn.commit()
        finally:
            conn.close()

    def require_authentication(self, required):
        """Switch authentication requirements on or off

        When started for the first time no passwords are required. Some
        tests want to explicitly test whether Thuban's password
        infrastructure works and switch password authentication on
        explicitly. When switching it on, there should be a
        corresponding call to switch it off again in the test case's
        tearDown method or in a finally: block.

        The method overwrites pg_hba.conf in dbdir and tells the
        postmaster to reload its configuration.
        """
        if required:
            contents = "local all password\n"
        else:
            contents = "local all trust\n"
        f = open(os.path.join(self.dbdir, "pg_hba.conf"), "w")
        try:
            f.write(contents)
        finally:
            f.close()
        run_command(["pg_ctl", "-D", self.dbdir, "reload"],
                    os.path.join(self.dbdir, "pg_ctl-reload.log"))


    def create_user(self, username, password):
        """Create user username with password in the database

        The values are inserted into the SQL statement without any
        quoting, so they must not contain single quotes.
        """
        self.execute_sql("template1", "admin",
                         "CREATE USER %s PASSWORD '%s';" % (username,password))

    def alter_user(self, username, password):
        """Change the user username's password in the database

        The values are inserted into the SQL statement without any
        quoting, so they must not contain single quotes.
        """
        self.execute_sql("template1", "admin",
                         "ALTER USER %s PASSWORD '%s';" % (username,password))


class PostGISDatabase:

    """A PostGIS database in a PostgreSQLServer"""

    def __init__(self, server, postgis_sql, dbname, tables = None):
        """Initialize the PostGISDatabase

        Parameters:

            server -- The PostgreSQLServer instance containing the
                database

            postgis_sql -- Filename of the postgis.sql file with the
                postgis initialization code

            dbname -- The name of the database

            tables -- Optional description of tables to create in the
                new database. If given it should be a list of
                (tablename, shapefilename) pairs meaning that a table
                tablename will be created with the contents of the given
                shapefile or (tablename, shapefilename, extraargs)
                triples. The extraargs should be a list of key, value
                pairs to use as keyword arguments to upload_shapefile.

        The constructor only stores its arguments. The database itself
        is created by initdb.
        """
        self.server = server
        self.postgis_sql = postgis_sql
        self.dbname = dbname
        self.tables = tables

    def initdb(self):
        """Create the database and initialize it for PostGIS

        Create the database with createdb, install plpgsql with
        createlang, run the postgis_sql file with psql and make
        geometry_columns readable for everybody. Afterwards upload the
        shapefiles described by self.tables, if any.

        The output of the commands is written to log files in the
        server's dbdir. A RuntimeError is raised if one of the commands
        fails and opening the postgis_sql file raises an exception if
        the file can't be read.
        """
        server = self.server
        port = str(server.port)
        logdir = server.dbdir

        run_command(["createdb", "-p", port,
                     "-h", server.host, "-U", server.admin_name,
                     self.dbname],
                    os.path.join(logdir, "createdb.log"))
        run_command(["createlang", "-p", port,
                     "-h", server.host,  "-U", server.admin_name,
                     "plpgsql", self.dbname],
                    os.path.join(logdir, "createlang.log"))
        # for some reason psql doesn't exit with an error code if the
        # file given as -f doesn't exist, so we check manually by trying
        # to open it before we run psql
        f = open(self.postgis_sql)
        f.close()
        del f
        run_command(["psql", "-f", self.postgis_sql, "-d", self.dbname,
                     "-p", port, "-h", server.host,
                     "-U", server.admin_name],
                    os.path.join(logdir, "psql.log"))

        server.execute_sql(self.dbname, "admin",
                           "GRANT SELECT ON geometry_columns TO PUBLIC;")

        if self.tables is not None:
            for item in self.tables:
                tablename, shapefile, kw = self._unpack_table_spec(item)
                upload_shapefile(shapefile, self, tablename, **kw)

    def _unpack_table_spec(self, item):
        """Return item as a (tablename, shapefile, keyword dict) triple

        The item is one of the table specifications described in the
        constructor. The dictionary always contains the keys
        force_wkt_type and gid_offset. Their defaults are overridden by
        the key, value pairs of the item's extraargs, if present.
        """
        extra = {"force_wkt_type": None, "gid_offset": 0}
        if len(item) == 2:
            tablename, shapefile = item
        else:
            tablename, shapefile, pairs = item
            for key, val in pairs:
                extra[key] = val
        return tablename, shapefile, extra

    def has_data(self, tables):
        """Return whether tables equals the table description of the db

        The tables argument is compared with the tables argument given
        to the constructor.
        """
        return self.tables == tables


def find_postgis_sql():
    """Return the name of the postgis.sql file

    The binary directory of the PostgreSQL installation is determined
    by running pg_config --bindir. pg_config can't tell us where
    PostgreSQL keeps its static data files, which is where a postgis
    installation usually puts postgis.sql, so the file is simply
    assumed to be $bindir/../share/postgresql/contrib/postgis.sql.

    The returned filename is not checked for existence.
    """
    output = run_config_script("pg_config --bindir")
    components = [output.strip(), "..", "share", "postgresql",
                  "contrib", "postgis.sql"]
    return os.path.join(*components)

_postgres_server = None
def get_test_server():
    """Return the server object for the test database server

    The server is created and started on the first call and the same
    object is returned by all later calls.

    The database cluster lives in the directory postgis below the temp
    dir returned by support.create_temp_dir(). The temp dir itself is
    used as the socket directory. The server uses port 6543.
    """
    global _postgres_server
    if _postgres_server is not None:
        return _postgres_server

    tempdir = support.create_temp_dir()
    server = PostgreSQLServer(os.path.join(tempdir, "postgis"), 6543,
                              find_postgis_sql(), socket_dir = tempdir)
    # Remember the server before the cluster is created
    _postgres_server = server
    server.createdb()
    return server

def shutdown_test_server():
    """Shut down the test server if one has been created

    Afterwards the module no longer references a server object.
    """
    global _postgres_server
    server = _postgres_server
    if server is None:
        return
    server.shutdown()
    _postgres_server = None


def reason_for_not_running_tests():
    """
    Determine whether postgis tests can be run and return a reason they can't

    The return value is a string describing the first problem found or
    the empty string if no problem was found.

    There's no fool-proof way to reliably determine this short of
    actually running the tests but we try the following here:

     - test whether popen2.Popen4 exists
     - test whether pg_ctl --help can be run successfully
     - test whether the postgis_sql can be opened
       The name of the postgis_sql file is determined by find_postgis_sql()
     - psycopg can be imported successfully.
    """
    # run_command currently uses Popen4 which is not available under
    # Windows, for example.
    if not hasattr(popen2, "Popen4"):
        return "Can't run PostGIS test because popen2.Popen4 does not exist"

    try:
        run_command(["pg_ctl", "--help"], None)
    except RuntimeError:
        return "Can't run PostGIS tests because pg_ctl fails"

    # Only catch the errors that indicate a failing pg_config so that
    # e.g. a KeyboardInterrupt is not turned into a skip reason.
    try:
        postgis_sql = find_postgis_sql()
    except (RuntimeError, EnvironmentError):
        return "Can't run PostGIS tests because postgis.sql can't be found"

    try:
        f = open(postgis_sql)
        f.close()
    except EnvironmentError:
        return "Can't run PostGIS tests because postgis.sql can't be opened"

    # The test for psycopg was already done when this module was
    # imported so we only have to check whether it was successful
    if psycopg is None:
        return "Can't run PostGIS tests because psycopg can't be imported"

    return ""


_cannot_run_postgis_tests = None
def skip_if_no_postgis():
    """Raise support.SkipTest if the PostGIS tests can't be run

    The reason is determined with reason_for_not_running_tests() on
    the first call and cached in a module level variable for all later
    calls. Return normally if the reason is empty.
    """
    global _cannot_run_postgis_tests
    reason = _cannot_run_postgis_tests
    if reason is None:
        reason = reason_for_not_running_tests()
        _cannot_run_postgis_tests = reason
    if reason:
        raise support.SkipTest(reason)

def coords_to_point(coords):
    """Return the WKT representation of the point in coords as a string

    Only the first item of coords is used. It must be a pair of
    coordinates.
    """
    first = coords[0]
    x, y = first
    return "POINT(" + repr(x) + " " + repr(y) + ")"

def coords_to_polygon(coords):
    """Return the WKT representation of the polygon in coords as a string

    The coords parameter is a sequence of rings and each ring is a
    sequence of coordinate pairs given as tuples.
    """
    ring_texts = []
    for ring in coords:
        vertex_texts = []
        for vertex in ring:
            vertex_texts.append("%r %r" % vertex)
        ring_texts.append(", ".join(vertex_texts))
    return "POLYGON((%s))" % "), (".join(ring_texts)

def coords_to_multilinestring(coords):
    """Return the WKT representation of the arc in coords as a string

    The coords parameter is a sequence of parts and each part is a
    sequence of coordinate pairs given as tuples. Each part becomes one
    line string of the MULTILINESTRING.
    """
    lines = [", ".join(["%r %r" % point for point in part])
             for part in coords]
    return "MULTILINESTRING((" + "), (".join(lines) + "))"

def coords_to_multipolygon(coords):
    """Return the WKT representation of the polygon in coords as a string

    The coords parameter is a sequence of rings and each ring is a
    sequence of coordinate pairs given as tuples. Every ring becomes a
    polygon of its own in the resulting MULTIPOLYGON.
    """
    rings = []
    for ring in coords:
        rings.append(", ".join(["%r %r" % vertex for vertex in ring]))
    body = ")), ((".join(rings)
    return "MULTIPOLYGON(((" + body + ")))"

# Map the names of the WKT geometry types to the functions that convert
# a list of coordinates to the WKT representation of that type. Used by
# upload_shapefile to pick the converter for the geometry column's type.
wkt_converter = {
    "POINT": coords_to_point,
    "MULTILINESTRING": coords_to_multilinestring,
    "POLYGON": coords_to_polygon,
    "MULTIPOLYGON": coords_to_multipolygon,
    }

def upload_shapefile(filename, db, tablename, force_wkt_type = None,
                     gid_offset = 0):
    """Upload a shapefile into a new database table

    Parameters:

    filename -- The name of the shapefile

    db -- The PostGISDatabase instance representing the database

    tablename -- The name of the table to create and into which the data
                is to be inserted

    force_wkt_type -- If given the real WKT geometry type to use instead
                of the default that would be chosen based on the type of
                the shapefile. It must be one of the keys of
                wkt_converter.

    gid_offset -- A number to add to the shapeid to get the value for
                the gid column (default 0)

    The table gets a gid column, one column for each field of the
    shapefile's DBF file and the geometry column the_geom. Everybody is
    granted SELECT on the new table.

    All statements are executed as admin on one connection and
    committed together at the end. The connection is closed in any
    case, so nothing is committed if an exception is raised.
    """
    import dbflib
    import shapelib

    # We build this map here because we need shapelib which can only be
    # imported after support.initthuban has been called which we can't
    # easily do in this module because it's imported by support.
    shp_to_wkt = {
        shapelib.SHPT_POINT: "POINT",
        shapelib.SHPT_ARC: "MULTILINESTRING",
        shapelib.SHPT_POLYGON: "POLYGON",
        }

    dbname = db.dbname
    conn = psycopg.connect("dbname=%s " % dbname
                           + db.server.connection_string("admin"))
    try:
        cursor = conn.cursor()

        shp = shapelib.ShapeFile(filename)
        dbf = dbflib.DBFFile(filename)
        typemap = {dbflib.FTString: "VARCHAR",
                   dbflib.FTInteger: "INTEGER",
                   dbflib.FTDouble: "DOUBLE PRECISION"}

        # Build the column definitions for CREATE TABLE and the
        # matching named placeholders for the INSERT statement in
        # parallel.
        insert_formats = ["%(gid)s"]
        fields = ["gid INT"]
        for i in range(dbf.field_count()):
            ftype, name, width, prec = dbf.field_info(i)
            fields.append("%s %s" % (name, typemap[ftype]))
            insert_formats.append("%%(%s)s" % name)
        stmt = "CREATE TABLE %s (\n    %s\n);" % (tablename,
                                                  ",\n    ".join(fields))
        cursor.execute(stmt)

        numshapes, shapetype, mins, maxs = shp.info()
        # Only look up the default type if it's really needed, so that
        # a forced type also works for shapefile types without default.
        if force_wkt_type:
            wkttype = force_wkt_type
        else:
            wkttype = shp_to_wkt[shapetype]
        convert = wkt_converter[wkttype]

        cursor.execute("select AddGeometryColumn('%s',"
                       "'%s', 'the_geom', '-1', '%s', 2);"
                       % (dbname, tablename, wkttype))

        insert_formats.append("GeometryFromText(%(the_geom)s, -1)")

        insert = ("INSERT INTO %s VALUES (%s)"
                  % (tablename, ", ".join(insert_formats)))

        for i in range(numshapes):
            # The record is used as the dictionary with the values for
            # the named placeholders of the INSERT statement.
            data = dbf.read_record(i)
            data["gid"] = i + gid_offset
            data["the_geom"] = convert(shp.read_object(i).vertices())
            cursor.execute(insert, data)

        cursor.execute("GRANT SELECT ON %s TO PUBLIC;" % tablename)

        conn.commit()
    finally:
        conn.close()
