I'm having trouble with a homework assignment: implementing the A* search algorithm in Java. I'm given the graph, the origin and the destination. The assignment makes me follow some specific requirements, so the code might look a bit odd (for example, every time I add something to the open list I must call the `protected boolean addItem(Object item)` method).
Anyway, the graph is basically the road network of the UK, and I need to find the fastest route by car from one crossroad to another (some roads in the graph don't allow cars). I managed to write code that finds the optimal solution, but the algorithm is too slow, and I'm not sure how to make it faster, since I more or less follow the A* pseudocode on Wikipedia.
The solution below finds the route in about 30 seconds, but I need it to take roughly 2 seconds. The whole graph has 103,892 nodes (crossroads) and 193,736 edges (roads).
When I run the profiler in the NetBeans IDE, it says I spend a lot of time in calls to `HashMap.put` and in the self time of `plan` (the time spent in the `plan` function itself, excluding the time spent in the functions it calls).
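Part of that `HashMap.put` time probably comes from `plan` pre-filling `gScore` for all 103,892 nodes (and `initilizeHeuristics` filling `heuristics` for all of them) on every call, even though A* typically touches only part of the graph. A minimal sketch of the alternative, assuming nodes are identified by `long` ids as in the code below: store a score only when a node is first reached, and treat a missing key as infinity through `Map.getOrDefault`. The class and method names are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

public class LazyScores {
    // Only nodes the search actually reaches get an entry; every other node
    // is treated as having an infinite g-score, so nothing is pre-filled.
    private final Map<Long, Double> gScore = new HashMap<>();

    double g(long nodeId) {
        return gScore.getOrDefault(nodeId, Double.POSITIVE_INFINITY);
    }

    void setG(long nodeId, double value) {
        gScore.put(nodeId, value);
    }

    int size() {
        return gScore.size();
    }

    public static void main(String[] args) {
        LazyScores scores = new LazyScores();
        scores.setG(42L, 0.0);
        System.out.println(scores.g(42L)); // 0.0
        System.out.println(scores.g(7L));  // Infinity
        System.out.println(scores.size()); // 1
    }
}
```

The same idea applies to the heuristic map: compute a node's heuristic the first time it is pushed instead of for every node before the search starts.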
My heuristic is the time it takes to travel the straight-line (great-circle) distance between the node and the destination at the top speed allowed on any road open to cars.
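The `getHeuristic` method in the code below computes the haversine (great-circle) distance, which suits latitude/longitude coordinates better than a planar Euclidean distance. A standalone sketch of the same heuristic, with hypothetical names (`TravelTimeHeuristic`, `haversineMetres`), keeping the question's units: metres divided by km/h, matching the edge cost `getLengthInMetres() / getAllowedMaxSpeedInKmph()`.

```java
public final class TravelTimeHeuristic {
    // Same sphere radius as the question's code (metres).
    static final double EARTH_RADIUS_M = 6378137.0;

    // Great-circle (haversine) distance in metres between two points given in degrees.
    static double haversineMetres(double lat0, double lon0, double lat1, double lon1) {
        double phi0 = Math.toRadians(lat0);
        double phi1 = Math.toRadians(lat1);
        double dPhi = phi1 - phi0;
        double dLambda = Math.toRadians(lon1 - lon0);
        double a = Math.sin(dPhi / 2) * Math.sin(dPhi / 2)
                + Math.cos(phi0) * Math.cos(phi1) * Math.sin(dLambda / 2) * Math.sin(dLambda / 2);
        double c = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a));
        return EARTH_RADIUS_M * c;
    }

    // Lower bound on the remaining travel time: straight-line distance at the
    // fastest permitted speed. Units must match the edge costs (metres / (km/h)).
    static double heuristic(double lat0, double lon0, double lat1, double lon1, double topSpeedKmph) {
        return haversineMetres(lat0, lon0, lat1, lon1) / topSpeedKmph;
    }

    public static void main(String[] args) {
        // One degree of longitude along the equator, roughly 111.3 km on this sphere.
        System.out.println(haversineMetres(0, 0, 0, 1));
        System.out.println(heuristic(51.5, -0.12, 51.5, -0.12, 112.0)); // 0.0
    }
}
```

As long as no car road is faster than `topSpeed` and (an assumption about the data) no road is shorter than the great-circle distance between its endpoints, this never overestimates, so A* stays optimal.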
The code:
```java
package student;

import cz.cvut.atg.zui.astar.AbstractOpenList;
import cz.cvut.atg.zui.astar.PlannerInterface;
import cz.cvut.atg.zui.astar.RoadGraph;
import eu.superhub.wp5.planner.planningstructure.GraphEdge;
import eu.superhub.wp5.planner.planningstructure.GraphNode;
import eu.superhub.wp5.planner.planningstructure.PermittedMode;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;

public class Planner implements PlannerInterface {

    private OpenList openList;

    @Override
    public List<GraphEdge> plan(RoadGraph graph, GraphNode origin, GraphNode destination) {
        if (origin == destination) {
            return null;
        }
        Collection<GraphNode> allNodes = graph.getAllNodes();
        int nodeSize = allNodes.size();
        openList = new OpenList(nodeSize);
        Collection<GraphEdge> allEdges = graph.getAllEdges();
        int edgeSize = allEdges.size();

        // Fastest speed limit on any car-accessible road; used by the heuristic.
        double topSpeed = 0;
        for (GraphEdge edge : allEdges) {
            if (edge.getPermittedModes().contains(PermittedMode.CAR)) {
                if (edge.getAllowedMaxSpeedInKmph() > topSpeed) {
                    topSpeed = edge.getAllowedMaxSpeedInKmph();
                }
            }
        }
        System.out.println("n" + nodeSize + "e" + edgeSize);
        openList.initilizeHeuristics(allNodes, destination, topSpeed);

        List<Long> closedSet = new ArrayList<>(nodeSize);
        HashMap<GraphNode, GraphEdge> cameFrom = new HashMap<>(edgeSize);
        HashMap<Long, Double> gScore = new HashMap<>(nodeSize); // default value infinity
        for (GraphNode node : allNodes) {
            gScore.put(node.getId(), Double.MAX_VALUE);
        }
        gScore.put(origin.getId(), 0.0);
        HashMap<Long, Double> fScore = new HashMap<>(nodeSize); // default value infinity
        openList.fScore = fScore;
        fScore.put(origin.getId(), openList.getInitilizedHeuristics(origin));
        openList.addToQueue(origin.getId());

        while (!openList.isEmpty()) {
            GraphNode current = openList.getLowestPriceNode(graph);
            if (current == destination) {
                return reconstructPath(graph, cameFrom, current);
            }
            openList.removeHead();
            closedSet.add(current.getId());
            List<GraphEdge> outcomingEdges = graph.getNodeOutcomingEdges(current.getId());
            if (outcomingEdges != null) {
                for (GraphEdge edge : outcomingEdges) {
                    // Only roads open to cars are considered.
                    if (edge.getPermittedModes().contains(PermittedMode.CAR)) {
                        Long neighbourId = edge.getToNodeId();
                        GraphNode node = graph.getNodeByNodeId(neighbourId);
                        if (closedSet.contains(neighbourId)) {
                            continue; // already expanded
                        }
                        // Edge cost: travel time = length / speed limit.
                        double tentative_gScore = gScore.get(current.getId())
                                + edge.getLengthInMetres() / edge.getAllowedMaxSpeedInKmph();
                        if (!openList.contains(neighbourId)) {
                            // First time this neighbour is seen: enqueue it with f = "infinity" for now.
                            fScore.put(neighbourId, Double.MAX_VALUE);
                            openList.addToQueue(neighbourId);
                        }
                        if (tentative_gScore >= gScore.get(neighbourId)) {
                            continue; // not a better path
                        }
                        // Better path found: record it and reorder the queue.
                        cameFrom.put(node, edge);
                        gScore.put(neighbourId, tentative_gScore);
                        fScore.put(neighbourId, gScore.get(neighbourId) + openList.getInitilizedHeuristics(node));
                        openList.ReSort(neighbourId);
                    }
                }
            }
        }
        return null;
    }

    // Walks cameFrom back from the destination. The last lookup (at the origin) adds a null,
    // which ends up at index 0 after reversing and is then removed.
    private List<GraphEdge> reconstructPath(RoadGraph graph, HashMap<GraphNode, GraphEdge> cameFrom, GraphNode current) {
        List<GraphEdge> total_path = new LinkedList<>();
        total_path.add(cameFrom.get(current));
        while (cameFrom.containsKey(current)) {
            current = graph.getNodeByNodeId(cameFrom.get(current).getFromNodeId());
            total_path.add(cameFrom.get(current));
        }
        Collections.reverse(total_path);
        total_path.remove(0);
        return total_path;
    }
}
```
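Two kinds of call in the loop above scan linearly: `closedSet.contains` on an `ArrayList` compares against every node expanded so far, and the `PriorityQueue` Javadoc states that `contains(Object)` and `remove(Object)` take linear time, so `openList.contains` and `ReSort` walk the whole open list. With around 10^5 nodes, these scans can easily dominate the run time. A minimal sketch of a common workaround (not the assignment's required structure): keep the closed set in a `HashSet`, and instead of removing and re-adding a node when its score improves, push a fresh queue entry and skip stale entries when they are polled (lazy deletion). `LazyAStar`, `Edge` and `Entry` are hypothetical stand-ins for the course's `RoadGraph`/`GraphEdge` types; skipping closed nodes assumes a consistent heuristic, as the original code also does. Records need Java 16 or later.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.LongToDoubleFunction;

public class LazyAStar {
    // Hypothetical stand-in for GraphEdge; cost is length / speed, as in the question.
    record Edge(long from, long to, double lengthMetres, double speedKmph, boolean carAllowed) {}

    // Queue entry snapshots the f-score at push time, so the heap never reads a mutable map.
    record Entry(long node, double f) {}

    // Returns the cheapest travel cost from origin to goal, or +infinity if unreachable by car.
    static double search(Map<Long, List<Edge>> out, long origin, long goal, LongToDoubleFunction h) {
        Map<Long, Double> g = new HashMap<>();  // filled lazily; absent means infinity
        Set<Long> closed = new HashSet<>();     // O(1) expected contains, unlike ArrayList
        PriorityQueue<Entry> open = new PriorityQueue<>((a, b) -> Double.compare(a.f(), b.f()));
        g.put(origin, 0.0);
        open.add(new Entry(origin, h.applyAsDouble(origin)));
        while (!open.isEmpty()) {
            Entry e = open.poll();
            if (!closed.add(e.node())) {
                continue;                       // stale duplicate of an already-expanded node
            }
            if (e.node() == goal) {
                return g.get(goal);
            }
            double gCur = g.get(e.node());
            for (Edge edge : out.getOrDefault(e.node(), List.of())) {
                if (!edge.carAllowed() || closed.contains(edge.to())) {
                    continue;
                }
                double tentative = gCur + edge.lengthMetres() / edge.speedKmph();
                if (tentative < g.getOrDefault(edge.to(), Double.POSITIVE_INFINITY)) {
                    g.put(edge.to(), tentative);
                    // No remove()/contains() on the heap: push a new entry; the old one
                    // becomes stale and is skipped when it is polled.
                    open.add(new Entry(edge.to(), tentative + h.applyAsDouble(edge.to())));
                }
            }
        }
        return Double.POSITIVE_INFINITY;
    }

    static void addEdge(Map<Long, List<Edge>> out, Edge e) {
        out.computeIfAbsent(e.from(), k -> new ArrayList<>()).add(e);
    }

    public static void main(String[] args) {
        Map<Long, List<Edge>> out = new HashMap<>();
        addEdge(out, new Edge(1, 2, 1000, 50, true));   // cost 20
        addEdge(out, new Edge(2, 3, 1000, 50, true));   // cost 20
        addEdge(out, new Edge(1, 3, 1000, 100, false)); // cost 10, but closed to cars
        addEdge(out, new Edge(1, 3, 3000, 50, true));   // cost 60
        System.out.println(search(out, 1, 3, n -> 0.0)); // 40.0
    }
}
```

Path reconstruction would add a `cameFrom` map updated next to `g.put`; in the assignment, `addItem` would presumably still have to be called on each push, which this sketch does not model.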
And the `OpenList` class:
```java
package student;

import cz.cvut.atg.zui.astar.*;
import eu.superhub.wp5.planner.planningstructure.GraphNode;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.PriorityQueue;

public class OpenList extends AbstractOpenList {

    private final HashMap<GraphNode, Double> heuristics;
    private final Comparator<Long> comparator;
    private final PriorityQueue<Long> queue;
    public HashMap<Long, Double> fScore;

    public OpenList(int size) {
        heuristics = new HashMap<>(size);
        comparator = new MyComparator();
        queue = new PriorityQueue<>(10000, comparator);
    }

    @Override
    protected boolean addItem(Object item) {
        return true;
    }

    protected boolean addToQueue(long nodeidToInsert) {
        queue.add(nodeidToInsert);
        return add(nodeidToInsert);
    }

    protected boolean isEmpty() {
        return queue.isEmpty();
    }

    // Haversine (great-circle) distance to the destination in metres, divided by topSpeed.
    private double getHeuristic(GraphNode node, GraphNode destination, double topSpeed) {
        double x0 = node.getLatitude();
        double y0 = node.getLongitude();
        double x1 = destination.getLatitude();
        double y1 = destination.getLongitude();
        double R = 6378137;
        double dLat = x1 * Math.PI / 180 - x0 * Math.PI / 180;
        double dLon = y1 * Math.PI / 180 - y0 * Math.PI / 180;
        double a = Math.sin(dLat / 2) * Math.sin(dLat / 2)
                + Math.cos(x0 * Math.PI / 180) * Math.cos(x1 * Math.PI / 180)
                * Math.sin(dLon / 2) * Math.sin(dLon / 2);
        double c = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a));
        return R * c / topSpeed;
    }

    protected GraphNode getLowestPriceNode(RoadGraph graph) {
        return graph.getNodeByNodeId(queue.peek());
    }

    void remove(long nodeid) {
        queue.remove(nodeid);
    }

    boolean contains(long nodeid) {
        return queue.contains(nodeid);
    }

    void initilizeHeuristics(Collection<GraphNode> nodes, GraphNode destination, double topSpeed) {
        for (GraphNode node : nodes) {
            heuristics.put(node, getHeuristic(node, destination, topSpeed));
        }
    }

    double getInitilizedHeuristics(GraphNode node) {
        return heuristics.get(node);
    }

    // Removes and re-inserts a node after its fScore changed, so the heap reorders it.
    void ReSort(long nodeid) {
        queue.remove(nodeid);
        queue.add(nodeid);
    }

    void removeHead() {
        queue.remove();
    }

    // Orders node ids by their current fScore.
    private class MyComparator implements Comparator<Long> {
        @Override
        public int compare(Long id1, Long id2) {
            if (fScore.get(id1) <= fScore.get(id2)) {
                return -1;
            } else {
                return 1;
            }
        }
    }
}
```
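Separately from speed, `MyComparator` never returns 0: it returns -1 whenever `fScore.get(id1) <= fScore.get(id2)`, so `compare(x, x)` is -1 and two nodes with equal f-scores each compare as smaller than the other. The `Comparator` contract requires `sgn(compare(x, y)) == -sgn(compare(y, x))` for all `x` and `y`. A small sketch contrasting the two versions, using a plain `Map<Long, Double>` in place of the `fScore` field (`ComparatorContract`, `original` and `fixed` are hypothetical names):

```java
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;

public class ComparatorContract {
    // Mirrors the question's comparator: never returns 0, so compare(x, x) == -1.
    static Comparator<Long> original(Map<Long, Double> fScore) {
        return (id1, id2) -> fScore.get(id1) <= fScore.get(id2) ? -1 : 1;
    }

    // Contract-respecting version: 0 for ties, and compare(a, b) == -compare(b, a).
    static Comparator<Long> fixed(Map<Long, Double> fScore) {
        return (id1, id2) -> Double.compare(fScore.get(id1), fScore.get(id2));
    }

    public static void main(String[] args) {
        Map<Long, Double> f = new HashMap<>();
        f.put(1L, 5.0);
        f.put(2L, 5.0);
        System.out.println(original(f).compare(1L, 2L) + " " + original(f).compare(2L, 1L)); // -1 -1
        System.out.println(fixed(f).compare(1L, 2L) + " " + fixed(f).compare(2L, 1L));       // 0 0
    }
}
```

This is unlikely to be the cause of the slowness, but `Double.compare` makes tie handling well defined.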
So, any ideas on how to make it faster? P.S. I can share the whole project if anybody wants to debug it.