# Copyright (c) 2001, 2002 by Intevation GmbH
# Authors:
# Bernhard Herzog <bh@intevation.de>
#
# This program is free software under the GPL (>=v2)
# Read the file COPYING coming with Thuban for details.

"""
Classes for display of a map and interaction with it
"""

__version__ = "$Revision: 1.7 $"

from math import hypot

from wxPython.wx import wxWindow,\
     wxPaintDC, wxColour, wxClientDC, wxINVERT, wxTRANSPARENT_BRUSH, wxFont,\
     EVT_PAINT, EVT_LEFT_DOWN, EVT_LEFT_UP, EVT_MOTION


from wxPython import wx

from wxproj import point_in_polygon_shape, shape_centroid


from Thuban.Model.messages import MAP_PROJECTION_CHANGED, \
     LAYERS_CHANGED, LAYER_LEGEND_CHANGED, LAYER_VISIBILITY_CHANGED
from Thuban.Model.layer import SHAPETYPE_POLYGON, SHAPETYPE_ARC, \
     SHAPETYPE_POINT
from Thuban.Model.label import ALIGN_CENTER, ALIGN_TOP, ALIGN_BOTTOM, \
     ALIGN_LEFT, ALIGN_RIGHT


from renderer import ScreenRenderer, PrinterRender

import labeldialog

from messages import SELECTED_SHAPE


#
#   The tools
#

class Tool:

    """
    Base class for the interactive tools.

    A tool receives the mouse events of the view (the canvas displaying
    the map) and keeps track of a drag operation:

    start    -- window position (x, y) where the drag began, or None
    current  -- most recent window position (x, y) of the drag, or None
    dragging -- true while a drag is in progress
    drawn    -- true while the tool's feedback graphics are visible

    Show and Hide both call draw. This relies on the DC drawing in an
    invertible mode (MapCanvas uses wxINVERT), so that drawing the same
    figure a second time erases it.
    """

    def __init__(self, view):
        """Initialize the tool. The view is the canvas displaying the map"""
        self.view = view
        self.start = self.current = None
        self.dragging = 0
        self.drawn = 0

    def Name(self):
        """Return the tool's name. The base class returns an empty string"""
        return ''

    def drag_start(self, x, y):
        """Begin a drag at window position (x, y)"""
        self.start = self.current = x, y
        self.dragging = 1

    def drag_move(self, x, y):
        """Update the current drag position to (x, y)"""
        self.current = x, y

    def drag_stop(self, x, y):
        """End the drag at window position (x, y)"""
        self.current = x, y
        self.dragging = 0

    def Show(self, dc):
        """Draw the feedback graphics on dc unless they're already visible"""
        if not self.drawn:
            self.draw(dc)
        self.drawn = 1

    def Hide(self, dc):
        """Erase the feedback graphics from dc if they're visible.

        Erasing works by drawing the figure again with the invertible
        drawing mode of dc.
        """
        if self.drawn:
            self.draw(dc)
        self.drawn = 0

    def draw(self, dc):
        """Draw the feedback graphics. Derived classes override this"""
        pass

    def MouseDown(self, event):
        """Start a drag at the event's position"""
        self.drag_start(event.m_x, event.m_y)

    def MouseMove(self, event):
        """Track the mouse position while dragging"""
        if self.dragging:
            self.drag_move(event.m_x, event.m_y)

    def MouseUp(self, event):
        """Finish the drag at the event's position.

        Derived classes call this first and then read self.start and
        self.current to act on the completed drag.
        """
        if self.dragging:
            # drag_stop records the final position AND ends the drag;
            # drag_move would leave self.dragging set after the release.
            self.drag_stop(event.m_x, event.m_y)

    def Cancel(self):
        """Abort the current drag"""
        self.dragging = 0


class RectTool(Tool):

    """Base class for tools that draw rectangles while dragging.

    The rectangle spans from the drag start position to the current
    mouse position, both in window coordinates.
    """

    def draw(self, dc):
        """Draw the rubber band rectangle on dc"""
        x0, y0 = self.start
        x1, y1 = self.current
        rect_width = x1 - x0
        rect_height = y1 - y0
        dc.DrawRectangle(x0, y0, rect_width, rect_height)

class ZoomInTool(RectTool):

    """The Zoom-In Tool.

    A click zooms in by a factor of two around the clicked point. A drag
    zooms so that the dragged rectangle fills the window.
    """

    def Name(self):
        return "ZoomInTool"

    def proj_rect(self):
        """Return the rectangle spanned by start and current in projected
        coordinates as a normalized tuple (minx, miny, maxx, maxy)"""
        corners = [self.view.win_to_proj(wx_, wy_)
                   for wx_, wy_ in (self.start, self.current)]
        (ax, ay), (bx, by) = corners
        return (min(ax, bx), min(ay, by),
                max(ax, bx), max(ay, by))

    def MouseUp(self, event):
        if not self.dragging:
            return
        Tool.MouseUp(self, event)
        if self.start == self.current:
            # a plain click: zoom in by two around the clicked point
            self.view.ZoomFactor(2, center = self.current)
        else:
            # a drag: make the dragged rectangle fill the window
            self.view.FitRectToWindow(self.proj_rect())


class ZoomOutTool(RectTool):

    """The Zoom-Out Tool.

    A click zooms out by a factor of two around the clicked point. A drag
    zooms out so that the currently visible area fits into the dragged
    rectangle.
    """

    def Name(self):
        return "ZoomOutTool"

    def MouseUp(self, event):
        if self.dragging:
            Tool.MouseUp(self, event)
            sx, sy = self.start
            cx, cy = self.current
            if sx == cx and sy == cy:
                # Just a mouse click. Simply zoom out by a factor of two
                # around the clicked point (x was wrongly taken from cy).
                self.view.ZoomFactor(0.5, center = (cx, cy))
            else:
                # A drag. Zoom out to the rectangle, which is given to
                # the view in window coordinates.
                self.view.ZoomOutToRect((min(sx, cx), min(sy, cy),
                                         max(sx, cx), max(sy, cy)))


class PanTool(Tool):

    """The Pan Tool.

    While dragging, the window contents are shifted by blitting the
    view's drag DC onto itself. On release the view is translated by the
    total mouse movement.
    """

    def Name(self):
        return "PanTool"

    def MouseMove(self, event):
        if not self.dragging:
            return
        prev_x, prev_y = self.current
        Tool.MouseMove(self, event)
        new_x, new_y = self.current
        win_w, win_h = self.view.GetSizeTuple()
        drag_dc = self.view.drag_dc
        # shift the visible contents by the movement since the last event
        drag_dc.Blit(0, 0, win_w, win_h, drag_dc,
                     prev_x - new_x, prev_y - new_y)

    def MouseUp(self, event):
        if not self.dragging:
            return
        Tool.MouseUp(self, event)
        start_x, start_y = self.start
        end_x, end_y = self.current
        self.view.Translate(end_x - start_x, end_y - start_y)
        
class IdentifyTool(Tool):

    """The "Identify" Tool.

    When the mouse button is released, the view selects the shape under
    the cursor.
    """

    def Name(self):
        return "IdentifyTool"

    def MouseUp(self, event):
        pos_x, pos_y = event.m_x, event.m_y
        self.view.SelectShapeAt(pos_x, pos_y)


class LabelTool(Tool):

    """The "Label" Tool.

    When the mouse button is released, the view adds or removes a label
    at the cursor position.
    """

    def Name(self):
        return "LabelTool"

    def MouseUp(self, event):
        pos_x, pos_y = event.m_x, event.m_y
        self.view.LabelShapeAt(pos_x, pos_y)




class MapPrintout(wx.wxPrintout):

    """
    wxPrintout class for printing Thuban maps.

    The whole map is printed on a single page. It is scaled so that its
    projected bounding box fits the printable area and is centered on it.
    """

    def __init__(self, map):
        wx.wxPrintout.__init__(self)
        self.map = map

    def GetPageInfo(self):
        # (minPage, maxPage, pageFrom, pageTo): exactly one page
        return (1, 1, 1, 1)

    def HasPage(self, pagenum):
        return pagenum == 1

    def OnPrintPage(self, pagenum):
        """Print page pagenum. Only page 1 exists.

        Always returns true; a false return value would make wxPrintout
        cancel the print job.
        """
        if pagenum == 1:
            self.draw_on_dc(self.GetDC())
        return wx.true

    def draw_on_dc(self, dc):
        """Render the map onto dc, fitted to and centered on the page.

        Draws nothing if the map has no bounding box (empty map) or if the
        bounding box has zero width or height.
        """
        width, height = self.GetPageSizePixels()
        bbox = self.map.ProjectedBoundingBox()
        if bbox is None:
            # empty map: nothing to print
            return wx.true
        llx, lly, urx, ury = bbox
        if llx == urx or lly == ury:
            # degenerate bbox; the scale would be infinite
            return wx.true
        # float() avoids truncating integer division for integer input
        scalex = float(width) / (urx - llx)
        scaley = float(height) / (ury - lly)
        scale = min(scalex, scaley)
        offx = 0.5 * (width - (urx + llx) * scale)
        offy = 0.5 * (height + (ury + lly) * scale)

        resx, resy = self.GetPPIPrinter()
        renderer = PrinterRender(dc, scale, (offx, offy), resolution = resx)
        renderer.RenderMap(self.map)
        return wx.true
        

class MapCanvas(wxWindow):

    """A widget that displays a map and offers some interaction.

    Projected coordinates are mapped to window coordinates with a uniform
    scale and an offset (see proj_to_win and win_to_proj). The y axis is
    flipped so that projected y grows upwards on screen.

    Painting is deferred: OnPaint only sets a flag, and OnIdle then
    renders the map into an off-screen bitmap and blits it to the window.
    Left mouse button events are forwarded to the current tool, if any.
    """

    def __init__(self, parent, winid, interactor):
        wxWindow.__init__(self, parent, winid)
        self.SetBackgroundColour(wxColour(255, 255, 255))
        self.map = None
        # view transform: window = (scale * x + offx, -scale * y + offy)
        self.scale = 1.0
        self.offset = (0, 0)
        # true while the left button is held down with a tool active
        self.dragging = 0
        # client DC used by the tools for feedback while dragging
        self.drag_dc = None
        self.tool = None
        # set by OnPaint, consumed by OnIdle which does the actual drawing
        self.redraw_on_idle = 0
        EVT_PAINT(self, self.OnPaint)
        EVT_LEFT_DOWN(self, self.OnLeftDown)
        EVT_LEFT_UP(self, self.OnLeftUp)
        EVT_MOTION(self, self.OnMotion)
        wx.EVT_IDLE(self, self.OnIdle)
        self.interactor = interactor
        self.interactor.Subscribe(SELECTED_SHAPE, self.shape_selected)

    def OnPaint(self, event):
        dc = wxPaintDC(self)
        if self.map is not None and self.map.HasLayers():
            # We have a non-empty map. Redraw it in idle time
            self.redraw_on_idle = 1
        else:
            # If we've got no map or if the map is empty, simply clear
            # the screen.

            # XXX it's probably possible to get rid of this. The
            # background color of the window is already white and the
            # only thing we may have to do is to call self.Refresh()
            # with a true argument in the right places.
            dc.BeginDrawing()
            dc.Clear()
            dc.EndDrawing()

    def do_redraw(self):
        """Render the map into a memory bitmap and blit it to the window.

        This should only be called if we have a non-empty map.
        """
        width, height = self.GetSizeTuple()
        bitmap = wx.wxEmptyBitmap(width, height)
        dc = wx.wxMemoryDC()
        dc.SelectObject(bitmap)
        dc.BeginDrawing()

        # clear the background
        dc.SetBrush(wx.wxWHITE_BRUSH)
        dc.SetPen(wx.wxTRANSPARENT_PEN)
        dc.DrawRectangle(0, 0, width, height)

        # XXX the selection was meant to be shown only if
        # self.interactor.selected_map is self.map; that check is
        # currently disabled, so the selection is always drawn.
        selected_layer = self.interactor.selected_layer
        selected_shape = self.interactor.selected_shape

        # draw the map into the bitmap
        renderer = ScreenRenderer(dc, self.scale, self.offset)
        renderer.RenderMap(self.map, selected_layer, selected_shape)

        dc.EndDrawing()

        # blit the bitmap to the screen
        clientdc = wxClientDC(self)
        clientdc.BeginDrawing()
        clientdc.Blit(0, 0, width, height, dc, 0, 0)
        clientdc.EndDrawing()

    def Print(self):
        """Print the current map with the standard wx print dialog"""
        printer = wx.wxPrinter()
        printout = MapPrintout(self.map)
        printer.Print(self, printout, wx.true)
        printout.Destroy()

    def SetMap(self, map):
        """Display map (may be None) and fit it to the window.

        Moves the redraw and projection subscriptions from the old map to
        the new one.
        """
        redraw_channels = (LAYERS_CHANGED, LAYER_LEGEND_CHANGED,
                           LAYER_VISIBILITY_CHANGED)
        if self.map is not None:
            for channel in redraw_channels:
                self.map.Unsubscribe(channel, self.redraw)
            self.map.Unsubscribe(MAP_PROJECTION_CHANGED,
                                 self.projection_changed)
        self.map = map
        if self.map is not None:
            for channel in redraw_channels:
                self.map.Subscribe(channel, self.redraw)
            self.map.Subscribe(MAP_PROJECTION_CHANGED, self.projection_changed)
        self.FitMapToWindow()
        # force a redraw. If map is not empty, it's already been called
        # by FitMapToWindow but if map is empty it hasn't been called
        # yet so we have to explicitly call it.
        self.redraw()

    def Map(self):
        """Return the displayed map, or None"""
        return self.map

    def redraw(self, *args):
        """Request a repaint. Extra args allow use as a message callback"""
        self.Refresh(0)

    def projection_changed(self, *args):
        self.FitMapToWindow()
        self.redraw()

    def set_view_transform(self, scale, offset):
        """Set the scale and offset of the view transform and redraw"""
        self.scale = scale
        self.offset = offset
        self.redraw()

    def proj_to_win(self, x, y):
        """\
        Return the point in window coords given by projected coordinates x y
        """
        offx, offy = self.offset
        return (self.scale * x + offx, -self.scale * y + offy)

    def win_to_proj(self, x, y):
        """\
        Return the point in projected coordinates given by window coords x y
        """
        offx, offy = self.offset
        return ((x - offx) / self.scale, (offy - y) / self.scale)

    def FitRectToWindow(self, rect):
        """Make rect, in projected coordinates (llx, lly, urx, ury), fill
        the window and center it. Rectangles with zero width or height are
        ignored.
        """
        width, height = self.GetSizeTuple()
        llx, lly, urx, ury = rect
        if llx == urx or lly == ury:
            # zero width or zero height. Do Nothing
            return
        # float() avoids truncating integer division for integer input
        scalex = float(width) / (urx - llx)
        scaley = float(height) / (ury - lly)
        scale = min(scalex, scaley)
        offx = 0.5 * (width - (urx + llx) * scale)
        offy = 0.5 * (height + (ury + lly) * scale)
        self.set_view_transform(scale, (offx, offy))

    def FitMapToWindow(self):
        """\
        Set the scale and offset so that the map is centered in the
        window. Does nothing if there is no map or the map is empty.
        """
        if self.map is None:
            return
        bbox = self.map.ProjectedBoundingBox()
        if bbox is not None:
            self.FitRectToWindow(bbox)

    def ZoomFactor(self, factor, center = None):
        """Multiply the zoom by factor and center on center.

        The optional parameter center is a point in window coordinates
        that should be centered. If it is omitted, it defaults to the
        center of the window
        """
        width, height = self.GetSizeTuple()
        scale = self.scale * factor
        offx, offy = self.offset
        if center is not None:
            cx, cy = center
        else:
            cx = width / 2
            cy = height / 2
        offset = (factor * (offx - cx) + width / 2,
                  factor * (offy - cy) + height / 2)
        self.set_view_transform(scale, offset)

    def ZoomOutToRect(self, rect):
        """Zoom out so that the currently displayed region fits into rect.

        rect is (sx, sy, ex, ey) in window coordinates with sx <= ex and
        sy <= ey. Rectangles with zero width or height are ignored, as in
        FitRectToWindow; otherwise the scale would become zero.
        """
        # determine the bbox of the displayed region in projected
        # coordinates
        width, height = self.GetSizeTuple()
        llx, lly = self.win_to_proj(0, height - 1)
        urx, ury = self.win_to_proj(width - 1, 0)

        sx, sy, ex, ey = rect
        if sx == ex or sy == ey or llx == urx or lly == ury:
            # degenerate target or source rectangle. Do nothing
            return
        scalex = float(ex - sx) / (urx - llx)
        scaley = float(ey - sy) / (ury - lly)
        scale = min(scalex, scaley)

        offx = 0.5 * ((ex + sx) - (urx + llx) * scale)
        offy = 0.5 * ((ey + sy) + (ury + lly) * scale)
        self.set_view_transform(scale, (offx, offy))

    def Translate(self, dx, dy):
        """Shift the view by dx, dy window pixels"""
        offx, offy = self.offset
        self.set_view_transform(self.scale, (offx + dx, offy + dy))

    def ZoomInTool(self):
        self.tool = ZoomInTool(self)

    def ZoomOutTool(self):
        self.tool = ZoomOutTool(self)

    def PanTool(self):
        self.tool = PanTool(self)

    def IdentifyTool(self):
        self.tool = IdentifyTool(self)

    def LabelTool(self):
        self.tool = LabelTool(self)

    def CurrentTool(self):
        """Return the name of the current tool, or None if there is none
        or its name is empty"""
        if self.tool is None:
            return None
        return self.tool.Name() or None

    def OnLeftDown(self, event):
        if self.tool is not None:
            self.drag_dc = wxClientDC(self)
            # INVERT lets the tools erase their feedback by drawing again
            self.drag_dc.SetLogicalFunction(wxINVERT)
            self.drag_dc.SetBrush(wxTRANSPARENT_BRUSH)
            self.CaptureMouse()
            self.tool.MouseDown(event)
            self.tool.Show(self.drag_dc)
            self.dragging = 1

    def OnLeftUp(self, event):
        if self.dragging:
            # The mouse is only captured in OnLeftDown when a tool is
            # active; wx does not allow ReleaseMouse without a matching
            # CaptureMouse.
            self.ReleaseMouse()
            self.tool.Hide(self.drag_dc)
            self.tool.MouseUp(event)
            self.drag_dc = None
        self.dragging = 0

    def OnMotion(self, event):
        if self.dragging:
            self.tool.Hide(self.drag_dc)
            self.tool.MouseMove(event)
            self.tool.Show(self.drag_dc)

    def OnIdle(self, event):
        if self.redraw_on_idle:
            self.do_redraw()
        self.redraw_on_idle = 0

    def shape_selected(self, layer, shape):
        self.redraw()

    def find_shape_at(self, px, py, select_labels = 0, selected_layer = 1):
        """Determine the shape at point px, py in window coords

        Return the shape and the corresponding layer as a tuple (layer,
        shape). Return (None, None) if nothing is found or there is no
        map.

        If the optional parameter select_labels is true (default false)
        search through the labels. If a label is found return its index
        as the shape and None as the layer.

        If the optional parameter selected_layer is true (default), only
        search in the currently selected layer.
        """
        if self.map is None:
            return None, None

        map_proj = self.map.projection
        if map_proj is not None:
            forward = map_proj.Forward
        else:
            forward = None

        scale = self.scale
        offx, offy = self.offset

        if select_labels:
            labels = self.map.LabelLayer().Labels()

            if labels:
                dc = wxClientDC(self)
                font = wxFont(10, wx.wxSWISS, wx.wxNORMAL, wx.wxNORMAL)
                dc.SetFont(font)
                # search from the last (topmost) label downwards
                for i in range(len(labels) - 1, -1, -1):
                    label = labels[i]
                    x = label.x
                    y = label.y
                    text = label.text
                    if forward:
                        x, y = forward(x, y)
                    x = x * scale + offx
                    y = -y * scale + offy
                    width, height = dc.GetTextExtent(text)
                    if label.halign == ALIGN_LEFT:
                        # nothing to be done
                        pass
                    elif label.halign == ALIGN_RIGHT:
                        x = x - width
                    elif label.halign == ALIGN_CENTER:
                        x = x - width/2
                    if label.valign == ALIGN_TOP:
                        # nothing to be done
                        pass
                    elif label.valign == ALIGN_BOTTOM:
                        y = y - height
                    elif label.valign == ALIGN_CENTER:
                        y = y - height/2
                    if x <= px < x + width and y <= py <= y + height:
                        return None, i

        if selected_layer:
            layer = self.interactor.SelectedLayer()
            if layer is not None:
                layers = [layer]
            else:
                # no layer selected. Use an empty list to effectively
                # ignore all layers.
                layers = []
        else:
            layers = self.map.Layers()

        # search from the last (topmost) layer downwards
        for layer_index in range(len(layers) - 1, -1, -1):
            layer = layers[layer_index]

            # search only in visible layers
            if not layer.Visible():
                continue

            filled = layer.fill is not None
            stroked = layer.stroke is not None

            layer_proj = layer.projection
            if layer_proj is not None:
                inverse = layer_proj.Inverse
            else:
                inverse = None

            shapetype = layer.ShapeType()

            select_shape = -1
            if shapetype == SHAPETYPE_POLYGON:
                for i in range(layer.NumShapes() - 1, -1, -1):
                    result = point_in_polygon_shape(layer.shapefile.cobject(),
                                                    i,
                                                    filled, stroked,
                                                    map_proj, layer_proj,
                                                    scale, -scale, offx, offy,
                                                    px, py)
                    if result:
                        select_shape = i
                        break
            elif shapetype == SHAPETYPE_ARC:
                for i in range(layer.NumShapes() - 1, -1, -1):
                    result = point_in_polygon_shape(layer.shapefile.cobject(),
                                                    i, 0, 1,
                                                    map_proj, layer_proj,
                                                    scale, -scale, offx, offy,
                                                    px, py)
                    if result < 0:
                        select_shape = i
                        break
            elif shapetype == SHAPETYPE_POINT:
                for i in range(layer.NumShapes() - 1, -1, -1):
                    shape = layer.Shape(i)
                    x, y = shape.Points()[0]
                    if inverse:
                        x, y = inverse(x, y)
                    if forward:
                        x, y = forward(x, y)
                    x = x * scale + offx
                    y = -y * scale + offy
                    # a point is hit within a 5 pixel radius
                    if hypot(px - x, py - y) < 5:
                        select_shape = i
                        break

            if select_shape >= 0:
                return layer, select_shape
        return None, None

    def SelectShapeAt(self, x, y):
        """Select the shape at window position x, y in any layer"""
        layer, shape = self.find_shape_at(x, y, selected_layer = 0)
        # If layer is None, then shape will also be None. We don't want
        # to deselect the currently selected layer, so we simply select
        # the already selected layer again.
        if layer is None:
            layer = self.interactor.SelectedLayer()
        self.interactor.SelectLayerAndShape(layer, shape)

    def LabelShapeAt(self, x, y):
        """Remove or add a label at window position x, y.

        If a label is under the point it is removed. Otherwise, if a shape
        of the selected layer is under the point, the label text is asked
        for with labeldialog.run_label_dialog and a label is added at the
        polygon centroid, the point itself, or the middle vertex of an
        arc. Label positions are stored unprojected.
        """
        if self.map is None:
            return
        label_layer = self.map.LabelLayer()
        layer, shape_index = self.find_shape_at(x, y, select_labels = 1)
        if layer is None and shape_index is not None:
            # a label was selected
            label_layer.RemoveLabel(shape_index)
        elif layer is not None:
            text = labeldialog.run_label_dialog(self, layer.table, shape_index)
            if text:
                map_proj = self.map.projection
                layer_proj = layer.projection

                shapetype = layer.ShapeType()
                if shapetype == SHAPETYPE_POLYGON:
                    # presumably scale 1 / offset 0 yields the centroid in
                    # projected map coordinates; undo the map projection
                    x, y = shape_centroid(layer.shapefile.cobject(),
                                          shape_index,
                                          map_proj, layer_proj, 1, 1, 0, 0)
                    if map_proj is not None:
                        x, y = map_proj.Inverse(x, y)
                else:
                    shape = layer.Shape(shape_index)
                    if shapetype == SHAPETYPE_POINT:
                        x, y = shape.Points()[0]
                    else:
                        # assume SHAPETYPE_ARC: use the middle vertex
                        points = shape.Points()
                        x, y = points[len(points) // 2]
                    if layer_proj is not None:
                        x, y = layer_proj.Inverse(x, y)
                if shapetype == SHAPETYPE_POLYGON:
                    halign = ALIGN_CENTER
                else:
                    # points, arcs and any other shape type
                    halign = ALIGN_LEFT
                valign = ALIGN_CENTER
                label_layer.AddLabel(x, y, text,
                                     halign = halign, valign = valign)
