"""
Example solution for the afternoon Exercise of OOP.

written by Niko Wilbert for the G-Node Python Winter School 2010.

Modified and upgraded by Jonas Eschle.

This module contains all elements for an abstract graph representation.
"""
from io import BytesIO

import pygraphviz as pgv


class Node:
    """Base class for nodes.

    A node only keeps two lists: the edges leaving it and the edges
    arriving at it. The lists are filled and emptied by ``Edge`` when an
    edge is created or when its head/tail is reassigned.

    To add additional data attributes to nodes you can derive subclasses.
    """

    def __init__(self):
        """Initialize the node with no attached edges."""
        self._out_edges = []  # outgoing edges (this node is the edge's tail)
        self._in_edges = []  # incoming edges (this node is the edge's head)

    @property
    def out_edges(self):
        """Return the list of outgoing edges.

        This is the internal list itself, not a copy.
        """
        return self._out_edges

    @property
    def in_edges(self):
        """Return the list of incoming edges.

        Counterpart of ``out_edges``; this is the internal list itself,
        not a copy.
        """
        return self._in_edges


class Edge:
    """Directed connection from a tail node to a head node.

    An edge keeps the bookkeeping lists of its nodes up to date: it appends
    itself to ``tail._out_edges`` and ``head._in_edges`` and removes itself
    again when an end is reassigned. Either end may be None, in which case
    the edge is not registered on that side.
    """

    def __init__(self, tail, head):
        """Create the edge and register it with the given nodes.

        Params:
            tail: Node the edge starts from, or None.
            head: Node the edge points to, or None.
        """
        # Start detached so the property setters can do the registration.
        self._head = None
        self._tail = None
        # The head is registered before the tail.
        self.head = head
        self.tail = tail

    @property
    def head(self):
        return self._head

    @head.setter
    def head(self, head):
        """Move the head of the edge to another node (or to None).

        The edge is removed from the incoming list of the previous head and,
        unless the new head is None, appended to the incoming list of the
        new one.
        """
        previous = self._head
        if previous is not None:
            previous._in_edges.remove(self)
        self._head = head
        if head is not None:
            head._in_edges.append(self)

    @property
    def tail(self):
        return self._tail

    @tail.setter
    def tail(self, tail):
        """Move the tail of the edge to another node (or to None).

        The edge is removed from the outgoing list of the previous tail and,
        unless the new tail is None, appended to the outgoing list of the
        new one.
        """
        previous = self._tail
        if previous is not None:
            previous._out_edges.remove(self)
        self._tail = tail
        if tail is not None:
            tail._out_edges.append(self)

    def clear_nodes(self):
        """Detach the edge from both of its nodes (tail first, then head)."""
        self.tail = None
        self.head = None


class GraphException(Exception):
    """Root of the exception hierarchy used by the graph classes."""


class NoPathGraphException(GraphException):
    """Raised to signal that no path exists between two nodes."""


class Graph:
    """Class to represent a complete graph with nodes and edges.

    Note that this class does not support nodes which do not belong to any
    edge (internally only edges are stored).
    """

    def __init__(self, edges=None):
        """Graph with given edges.

        Args:
            edges (optional, iterable): Edges of the graph.
        """
        self._edges = set()
        if edges is not None:
            for edge in edges:
                self.add_edge(edge)

    @property
    def edges(self):
        """Return the set of edges (the internal set, not a copy)."""
        return self._edges

    @property
    def nodes(self):
        """Get the nodes in this graph.

        Returns:
            set: Every node that is the head or the tail of an edge of this
                graph. Unset (None) edge ends are not reported as nodes.
        """
        nodes = set()
        for edge in self._edges:
            nodes.add(edge.head)
            nodes.add(edge.tail)
        # An edge may have an unset end; None is not a node.
        nodes.discard(None)
        return nodes

    def add_edge(self, edge):
        """Add an edge to the graph."""
        self._edges.add(edge)

    def remove_edge(self, edge):
        """Remove an edge from the graph and detach it from its nodes.

        Raises:
            KeyError: If the edge is not part of this graph. In that case
                the edge is left untouched.
        """
        # Remove first: a foreign edge must not be detached from its nodes.
        self._edges.remove(edge)
        edge.clear_nodes()

    def _out_map(self):
        """Build the adjacency mapping used to create the graphviz graph.

        Returns:
            dict: Maps the integer index of every node of this graph to the
                list of indices of the nodes reachable over one outgoing
                edge. Only edges that belong to this graph and have a head
                are followed.
        """
        # Enumerate the nodes once so that every index is assigned from the
        # same iteration.
        index_of = {node: i for i, node in enumerate(self.nodes)}
        out_map = {}
        for node, i in index_of.items():
            # A node can carry edges that were never added to this graph;
            # those (and edges without a head) are skipped.
            out_map[i] = [
                index_of[out_edge.head]
                for out_edge in node.out_edges
                if out_edge in self._edges and out_edge.head is not None
            ]
        return out_map

    def _get_agraph(self):
        """Return a directed, non-strict graphviz AGraph of this graph."""
        return pgv.AGraph(self._out_map(), strict=False, directed=True)

    def draw(self, imgfile=None):
        """Draw the graph.

        draw the graph and either show it (if imgfile is None)
        or save it to the given file path

        Parameter
        ---------
        imgfile : string
            filename for the image-file WITH the extension
            if None, image is displayed instead (default: None)

        Raises
        ------
        ImportError
            if the image should be displayed but PIL cannot be imported
        """
        show = imgfile is None
        imgfmt = None
        if show:
            # Fail before doing any layout work if the viewer is unavailable.
            try:
                from PIL import Image
            except ImportError as error:
                raise ImportError("PIL maybe fails to install."
                                  " If so, do not use the 'show'.") from error
            # Render into memory as PNG; there is no file name to infer the
            # format from.
            imgfile = BytesIO()
            imgfmt = "png"
        G = self._get_agraph()
        G.layout('dot')
        G.draw(imgfile, format=imgfmt)
        if show:
            # Rewind in case the writer left the position at the end of the
            # buffer.
            imgfile.seek(0)
            img = Image.open(imgfile)
            img.show()

    def print_dot(self, filename=None):
        """Create dot source for this graph.

        Parameter
        ---------
        filename : string
            filename for the dot-file WITHOUT the extension
            if None, the source is printed to stdout (default: None)
        """
        # The header comment names the file to compile; "file" is only a
        # placeholder when printing to stdout.
        fn = "file" if filename is None else filename
        output = "".join([
            "// compile with\n",
            "//    dot -Tpdf {0}.dot > {0}.pdf\n".format(fn),
            self._get_agraph().string(),
        ])
        if filename is not None:
            with open(filename + ".dot", "w") as f:
                f.write(output)
        else:
            print(output)
