# coding=utf-8

from heapq import heappop, heappush
from graph import NoPathGraphException


class Path(object):
    """Path class.

    This class stores the data on a path needed by the SearchAlgorithm:
    the overall path weight and the last edge of the path. Paths are
    ordered by their weights only, which lets them be stored directly in a
    ``heapq`` priority queue. Equality is left at the default (identity),
    so paths remain hashable.
    """
    def __init__(self, weight, edge):
        """Initialise the path with its overall weight and its last edge."""
        self.weight = weight
        self.edge = edge

    def __repr__(self):
        """Return a debugging representation showing weight and edge."""
        return "%s(%r, %r)" % (type(self).__name__, self.weight, self.edge)

    def split(self):
        """Return the data on the path as a ``(weight, edge)`` pair."""
        return self.weight, self.edge

    # All orderings compare weights, as promised by the class docstring;
    # previously only '<' existed, so '<=' / '>=' raised TypeError.
    def __lt__(self, path):
        """Compare '<' by weight."""
        return self.weight < path.weight

    def __le__(self, path):
        """Compare '<=' by weight."""
        return self.weight <= path.weight

    def __gt__(self, path):
        """Compare '>' by weight."""
        return self.weight > path.weight

    def __ge__(self, path):
        """Compare '>=' by weight."""
        return self.weight >= path.weight

class SearchAlgorithm(object):
    """Shortest-path search for graphs, based on Dijkstra's algorithm.

    Nodes are expected to expose ``out_edges``; edges are expected to expose
    ``tail`` (the node they leave) and ``head`` (the node they reach). By
    default every edge weighs 1, so ``find`` returns a path with the fewest
    edges. Override ``_weight_func`` in a subclass to search weighted graphs;
    as Dijkstra's algorithm requires, weights should be non-negative.
    """

    def _weight_func(self, edge):
        """Return the weight of *edge*.

        Edge weights are assumed to be identical (1). Subclass and override
        this method to find the shortest way in weighted graphs.
        """
        return 1

    def find(self, start_node, end_node):
        """Return the shortest path from start_node to end_node.

        Returns a pair ``(edges, weight)``: ``edges`` is the list of edges
        leading from start_node to end_node in travel order, and ``weight``
        is the sum of their weights according to ``_weight_func``. When
        start_node equals end_node the path is empty and its weight is 0.

        Raises NoPathGraphException if end_node cannot be reached from
        start_node.
        """
        # A node is trivially connected to itself by the empty path. Without
        # this check the reconstruction below would dereference the None edge
        # stored for the start node.
        if start_node == end_node:
            return [], 0
        # best known path (weight and last edge) to every reached node
        shortest_paths = {start_node: Path(0, None)}
        # this list is used as a priority queue (ordered by weight) via heapq
        edge_heap = []
        for edge in start_node.out_edges:
            heappush(edge_heap, Path(self._weight_func(edge), edge))
        while edge_heap:
            path_weight, edge = heappop(edge_heap).split()
            known = shortest_paths.get(edge.head)
            if known is None or known.weight > path_weight:
                shortest_paths[edge.head] = Path(path_weight, edge)
                # A node can be pushed several times via different paths.
                # Stale heap entries are popped later with a weight that is
                # not smaller than the recorded one and are ignored by the
                # check above.
                for out_edge in edge.head.out_edges:
                    heappush(edge_heap,
                             Path(path_weight + self._weight_func(out_edge),
                                  out_edge))
        if end_node not in shortest_paths:
            raise NoPathGraphException(
                "There is no connection from node %s to node %s."
                % (start_node, end_node))
        # follow the stored last edges back from end_node to start_node
        path_weight = shortest_paths[end_node].weight
        path_edges = []
        current_node = end_node
        while current_node is not start_node:
            edge = shortest_paths[current_node].edge
            path_edges.append(edge)
            current_node = edge.tail
        path_edges.reverse()
        return path_edges, path_weight

