"""
Example solution for the afternoon OOP exercise.

written by Niko Wilbert for the G-Node Python Winter School 2010

This module contains all elements for an abstract graph representation.
"""

class Node(object):
    """Base class for nodes.

    A node holds the lists of edges attached to it. The lists are filled and
    emptied by Edge, which registers itself whenever its end points change.
    To add additional data attributes to nodes you can derive subclasses.
    """

    def __init__(self):
        """Initialize the node with empty edge lists."""
        self._out_edges = []  # edges whose tail is this node
        self._in_edges = []  # edges whose head is this node

    @property
    def out_edges(self):
        """List of edges leaving this node (the internal list, not a copy)."""
        return self._out_edges

    @property
    def in_edges(self):
        """List of edges arriving at this node (the internal list, not a copy)."""
        return self._in_edges
    

class Edge(object):
    """Directed edge from a tail node to a head node.

    The edge keeps the adjacency lists of its end nodes in sync: it appends
    itself to ``tail._out_edges`` / ``head._in_edges`` and removes itself
    again when an end point is replaced. Either end point may be None, in
    which case the edge is simply not attached on that side.
    """

    def __init__(self, tail, head):
        """Create the edge and register it with its nodes.

        tail -- starting node of the edge, or None.
        head -- end node of the edge, or None.
        """
        # Begin fully detached, then attach through the setters so the
        # registration logic lives in one place (head first, then tail).
        self._head = None
        self._tail = None
        self._set_head(head)
        self._set_tail(tail)

    def _set_head(self, head):
        """Move the edge to a new head node, updating registrations.

        Passing None detaches the edge from its current head.
        """
        previous = self._head
        if previous is not None:
            previous._in_edges.remove(self)
        self._head = head
        if head is None:
            return
        head._in_edges.append(self)

    def _set_tail(self, tail):
        """Move the edge to a new tail node, updating registrations.

        Passing None detaches the edge from its current tail.
        """
        previous = self._tail
        if previous is not None:
            previous._out_edges.remove(self)
        self._tail = tail
        if tail is None:
            return
        tail._out_edges.append(self)

    # Public accessors; assigning to them re-registers the edge with nodes.
    head = property(lambda self: self._head, _set_head)
    tail = property(lambda self: self._tail, _set_tail)

    def clear_nodes(self):
        """Detach the edge from both end nodes (tail first, then head)."""
        self.tail = None
        self.head = None


class GraphException(Exception):
    """Common base class for all exceptions related to Graph."""


class NoPathGraphException(GraphException):
    """Signals that no path connects the requested nodes."""


class Graph(object):
    """Directed (multi-)graph made up of edges and the nodes they connect.

    Note that this class does not support nodes which do not belong to any
    edge (internally only edges are stored). Edge end points that are None
    are ignored when the nodes of the graph are collected.
    """

    def __init__(self, edges=None):
        """Create a graph, optionally populated from an iterable of edges."""
        self._edges = set()
        if edges is not None:
            for edge in edges:
                self.add_edge(edge)

    @property
    def edges(self):
        """Return the set of edges in this graph.

        NOTE: this is the internal set itself, not a copy; modify the graph
        through add_edge/remove_edge instead of mutating it directly.
        """
        return self._edges

    @property
    def nodes(self):
        """Return the set of nodes connected by the edges of this graph.

        None end points (detached edge ends) are skipped.
        """
        nodes = set()
        for edge in self._edges:
            for node in (edge.tail, edge.head):
                if node is not None:
                    nodes.add(node)
        return nodes

    def add_edge(self, edge):
        """Add an edge to the graph (adding the same edge twice is a no-op)."""
        self._edges.add(edge)

    def remove_edge(self, edge):
        """Remove an edge from the graph and detach it from its nodes.

        Raises KeyError if the edge is not part of this graph; in that case
        the edge is left untouched.
        """
        # Remove first so that a failed removal does not clear a foreign edge.
        self._edges.remove(edge)
        edge.clear_nodes()

    def get_map(self):
        """Return an adjacency map of the graph keyed by node index.

        Every node gets an integer index in range(len(nodes)); the result maps
        each index to the list of head indices of its outgoing edges.
        Parallel edges produce repeated entries. Only edges that belong to
        this graph and have a head are taken into account.
        """
        # Materialize the nodes once so both passes share the same order.
        nodes = list(self.nodes)
        index = dict((node, i) for i, node in enumerate(nodes))
        out_map = {}
        for i, node in enumerate(nodes):
            # A node may carry edges that were never added to this graph.
            out_map[i] = [index[edge.head] for edge in node.out_edges
                          if edge in self._edges and edge.head is not None]
        return out_map

    def draw(self):
        """Create a graphical representation of the graph.

        Uses networkx and matplotlib, which are imported lazily so that the
        rest of the module works without them installed.
        """
        import networkx
        from matplotlib.pyplot import show
        out_map = self.get_map()
        graph = networkx.MultiDiGraph(out_map)
        networkx.draw(graph, with_labels=True)
        show()

    def print_dot(self, filename=None):
        """Create Graphviz dot source for this graph.

        Parameters
        ----------
        filename : string
            filename for the dot-file WITHOUT the extension;
            if None, the source is printed to stdout (default: None)
        """
        out_map = self.get_map()
        fn = "file" if filename is None else filename
        lines = [
            "// compile with",
            "//    dot -Tpdf {0}.dot > {0}.pdf".format(fn),
            "",
            "digraph structure {",
        ]
        # Sorted keys give a deterministic edge order in the output.
        for tail in sorted(out_map):
            for head in out_map[tail]:
                lines.append("  {0} -> {1};".format(tail, head))
        lines.append("}")
        output = "\n".join(lines) + "\n"
        if filename is not None:
            with open(filename + ".dot", "w") as f:
                f.write(output)
        else:
            # print() with a single argument behaves the same on Python 2 and 3.
            print(output)

