"""
Example solution for the afternoon Exercise of OOP.

written by Niko Wilbert for the G-Node Python Winter School 2010

This module contains all elements for an abstract graph representation.
"""
import pygraphviz as pgv
from io import BytesIO
from PIL import Image

class Node:
    """Base class for nodes.

    To add additional data attributes to nodes you can derive subclasses.
    The edge lists are filled in by Edge, which registers itself with the
    nodes at its ends.
    """

    def __init__(self):
        """Initialize the node with empty edge lists."""
        self._out_edges = []  # edges whose tail is this node
        self._in_edges = []  # edges whose head is this node

    # Public read-only access to the edge lists (the lists themselves are
    # the live internal lists, not copies).
    out_edges = property(lambda obj: obj._out_edges)
    in_edges = property(lambda obj: obj._in_edges)
    

class Edge:
    """Directed edge from a tail node to a head node.

    The edge keeps the nodes' bookkeeping in sync: it appends itself to the
    head's ``_in_edges`` and the tail's ``_out_edges`` lists, and removes
    itself from the old node whenever an end is reassigned. So it sits
    between Node and Graph in the hierarchy. Either end may be None.
    """

    def __init__(self, tail, head):
        """Create the edge tail -> head and register it with both nodes.

        tail, head -- The nodes of this edge (tail -> head), can be None.
        """
        self._head, self._tail = head, tail
        if head is not None:
            head._in_edges.append(self)
        if tail is not None:
            tail._out_edges.append(self)

    def _get_head(self):
        """Return the head node (or None)."""
        return self._head

    def _get_tail(self):
        """Return the tail node (or None)."""
        return self._tail

    def _set_head(self, head):
        """Move the edge to a new head, updating both nodes' in-lists.

        Passing None just detaches the edge from its current head.
        """
        old_head = self._head
        if old_head is not None:
            old_head._in_edges.remove(self)
        self._head = head
        if head is None:
            return
        head._in_edges.append(self)

    def _set_tail(self, tail):
        """Move the edge to a new tail, updating both nodes' out-lists.

        Passing None just detaches the edge from its current tail.
        """
        old_tail = self._tail
        if old_tail is not None:
            old_tail._out_edges.remove(self)
        self._tail = tail
        if tail is None:
            return
        tail._out_edges.append(self)

    # public accessors; assignment goes through the registering setters
    head = property(_get_head, _set_head)
    tail = property(_get_tail, _set_tail)

    def clear_nodes(self):
        """Detach the edge from both of its nodes (tail first, then head)."""
        self._set_tail(None)
        self._set_head(None)


class GraphException(Exception):
    """Common base class for the graph-related exceptions of this module."""


class NoPathGraphException(GraphException):
    """Signals that no path connects the nodes in question."""


class Graph:
    """Class to represent a complete graph with nodes and edges.

    Note that this class does not support nodes which do not belong to any
    edge (internally only edges are stored).
    """

    def __init__(self, edges=None):
        """Initialize the graph.

        edges -- Optional iterable of edges to add right away.
        """
        self._edges = set()
        if edges:
            for edge in edges:
                self.add_edge(edge)

    # The set of edges stored in this graph (the live internal set).
    edges = property(lambda obj: obj._edges)

    def get_nodes(self):
        """Return the set of nodes attached to the edges of this graph.

        Edge ends that are None (unattached) are not nodes and are skipped.
        """
        nodes = set()
        for edge in self._edges:
            for node in (edge.head, edge.tail):
                if node is not None:
                    nodes.add(node)
        return nodes

    # Must *call* get_nodes: returning the bound method itself made
    # `self.nodes` non-iterable. Going through the lambda keeps subclass
    # overrides of get_nodes effective.
    nodes = property(lambda obj: obj.get_nodes())

    def add_edge(self, edge):
        """Add an edge to the graph."""
        self._edges.add(edge)

    def remove_edge(self, edge):
        """Remove an edge from the graph and detach it from its nodes.

        Raises KeyError if the edge is not part of this graph; in that case
        the edge is left untouched.
        """
        # Remove first so a failed removal does not clear the edge's nodes.
        self._edges.remove(edge)
        edge.clear_nodes()

    def _get_agraph(self):
        """Return a directed pygraphviz AGraph for this graph.

        Nodes are identified by consecutive integers assigned in arbitrary
        (set iteration) order. Only edges stored in this graph with both a
        tail and a head are included; parallel edges are kept
        (strict=False).
        """
        # Build the node -> index table once so indices stay consistent.
        lookup_table = {node: i for i, node in enumerate(self.get_nodes())}
        # Adjacency map: tail index -> list of head indices.
        out_map = {i: [] for i in lookup_table.values()}
        for edge in self._edges:
            if edge.tail is None or edge.head is None:
                continue
            out_map[lookup_table[edge.tail]].append(lookup_table[edge.head])
        # turn this into an AGraph from graphviz, return the AGraph
        return pgv.AGraph(out_map, strict=False, directed=True)

    def draw(self, imgfile=None):
        """Draw the graph.

        Draw the graph and either show it (if imgfile is None)
        or save it to the given file.

        Parameters
        ----------
        imgfile : string
            filename for the image-file WITH the extension
            if None, image is displayed instead (default: None)

        """
        show = imgfile is None
        imgfmt = None
        if show:
            imgfile = BytesIO()
            imgfmt = "png"
        G = self._get_agraph()
        G.layout('dot')
        G.draw(imgfile, format=imgfmt)
        if show:
            # rewind the in-memory buffer before reading the image back
            imgfile.seek(0)
            img = Image.open(imgfile)
            img.show()

    def print_dot(self, filename=None):
        """Create dot source for this graph.

        The source is prefixed with a comment showing how to compile it.

        Parameters
        ----------
        filename : string
            filename for the dot-file WITHOUT the extension
            if None, the source is printed to stdout (default: None)
        """
        fn = "file" if filename is None else filename
        output = "// compile with\n"
        output += "//    dot -Tpdf {0}.dot > {0}.pdf\n".format(fn)
        output += self._get_agraph().string()
        if filename is not None:
            with open(filename + ".dot", "w") as f:
                f.write(output)
        else:
            print(output)

