As I continue to play with Scala, I wanted to do something non-trivial, so I decided to implement Dijkstra’s shortest path algorithm.
As background, my experience is mostly in object-oriented analysis and design (OOAD) and programming in Java (I use it at work).
The first challenge was getting the algorithm right (my first attempt ended up as a beautiful recursive algorithm that just traversed the graph, greedily selecting the shortest arc), but Wikipedia’s page helped.
Just in case, be aware that I’m no Scala expert and no functional programming expert (I’m no expert on anything at all). I’m just trying to share what I found; I’m sure there’s a better way, so feel free to suggest it.
So far, I’ve enjoyed the experience: programming in Scala is easy and fun (at least compared to Java or C++… yes, yes, that’s not saying much; I can think of many jokes, leave yours in the comments ;-)).
Sure, I spent a while on seemingly trivial stuff, but that happens when you learn any language (or framework, or anything new!):
I tried to define a property for a class. In Scala you can do it in the declaration, like this:
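To see why the greedy idea fails, here is a small sketch (a loop version of that idea, not the code I ended up with) that walks the test graph from the end of this post, always following the cheapest outgoing arc. The string labels n1 to n6 and the adjacency-map representation are assumptions made just for this example.

```scala
object GreedyVsDijkstra {
  // The post's test graph as an adjacency map: node -> (neighbour -> arc weight).
  val graph: Map[String, Map[String, Double]] = Map(
    "n1" -> Map("n2" -> 2.0, "n3" -> 1.0),
    "n2" -> Map("n4" -> 1.0, "n5" -> 1.0),
    "n3" -> Map("n4" -> 3.0),
    "n4" -> Map("n6" -> 1.0),
    "n5" -> Map("n6" -> 3.0),
    "n6" -> Map.empty[String, Double]
  )

  // Greedy walk: from the current node, always follow the cheapest outgoing arc.
  // Assumes every node on the way has an outgoing arc until `end` is reached
  // (true for this graph); on other graphs it could throw or loop forever.
  def greedyCost(start: String, end: String): Double = {
    var current = start
    var cost = 0.0
    while (current != end) {
      val (next, w) = graph(current).reduceLeft((a, b) => if (b._2 < a._2) b else a)
      cost += w
      current = next
    }
    cost
  }

  def main(args: Array[String]): Unit =
    println(greedyCost("n1", "n6")) // 5.0
}
```

From n1 it takes the arc to n3 (cost 1) and ends up paying 1 + 3 + 1 = 5, while the shortest path n1, n2, n4, n6 costs 2 + 1 + 1 = 4. A cheap first arc says nothing about what comes after it, which is why Dijkstra’s algorithm always expands the unvisited node with the lowest total cost instead.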
class Node (label:String) { … }
But I couldn’t do node.label! Then I realized that a plain constructor parameter is private to the class. If you want accessors you need to declare it as a val (and you get a getter) or a var (and you get a getter and a setter), so I changed it to:
class Node (val label:String) { … }
And it worked :-D
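A minimal sketch of the three cases side by side (the class names Plain, WithVal and WithVar are made up for the example); the commented-out lines are the ones that would not compile:

```scala
// Hypothetical classes, made up only to compare the three kinds of constructor parameter.
class Plain(label: String) {
  // label can be used inside the class body, but it is not a public member
  def describe: String = "plain " + label
}

class WithVal(val label: String) // val: the compiler generates a getter, label

class WithVar(var label: String) // var: a getter, label, and a setter, label_=

object ConstructorParamsDemo {
  def main(args: Array[String]): Unit = {
    val p = new Plain("a")
    // println(p.label)  // does not compile: label is not accessible from outside
    println(p.describe)  // plain a

    val v = new WithVal("b")
    println(v.label)     // b
    // v.label = "x"     // does not compile: a val has no setter

    val w = new WithVar("c")
    w.label = "d"        // calls the generated setter label_=
    println(w.label)     // d
  }
}
```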
I did a quick search but couldn’t find a quick and easy functional implementation of Dijkstra’s shortest path algorithm (it seems there isn’t one, which makes sense, as the algorithm relies on shared mutable state).
One of the main steps of the algorithm is finding the unvisited node with the lowest cost. I came up with:
nodes.reduceLeft((a:Node, b:Node)=> if (a.weight<b.weight) a else b)
Maybe there’s a better way, but I learned how to use “reduce”, and it looks quite concise and clear :-)
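The standard library also has minBy, which does the same selection in one call. A small sketch comparing both, with a cut-down Node class that only has a label and a weight (an assumption for the example, not the post’s full class):

```scala
object MinByDemo {
  // Cut-down stand-in for the post's Node: just a label and a tentative weight.
  class Node(val label: String, var weight: Float = Float.PositiveInfinity)

  // The post's approach: fold over the set, keeping the lighter of each pair.
  def minReduce(nodes: Set[Node]): Node =
    nodes.reduceLeft((a, b) => if (a.weight < b.weight) a else b)

  // The same selection using the standard library's minBy.
  def minByWeight(nodes: Set[Node]): Node =
    nodes.minBy(_.weight)

  def main(args: Array[String]): Unit = {
    val nodes = Set(new Node("a", 3f), new Node("b", 1f), new Node("c"))
    println(minReduce(nodes).label)   // b
    println(minByWeight(nodes).label) // b
  }
}
```

Both throw an exception on an empty set, which is fine here because the loop only looks for the minimum while there are unvisited nodes. They may pick different nodes when two weights tie, which does not affect correctness. On Scala 2.13 the compiler may print a deprecation warning about which Float ordering minBy should use; it still compiles.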
Then, you need to check if you can improve the cost of the adjacent nodes using the one you just found, so I defined:
def improveDistance(a:Arc) ={
  if (a.start.weight+a.weight< a.end.weight) {
    a.end.weight=a.start.weight+a.weight
    a.end.previous=a.start
  }
}
Nothing magic, but then I can use “map” to apply it to all the arcs from the node:
vertx.transitions.map(improveDistance(_))
Don’t be scared by the “_”: I’m just lazy, and glad that Scala lets me avoid typing “map((a:Arc)=>improveDistance(a))”.
I didn’t use any of Scala’s unit testing frameworks, but I wrote my own sort of unit test, so I declared the arcs:
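Since improveDistance is only called for its side effect, foreach is arguably a better fit than map: map also builds a List[Unit] that is then thrown away. A small sketch with stripped-down Node and Arc classes (assumed here; this Arc takes its weight in the constructor) and a hypothetical relaxAll helper:

```scala
object ForeachDemo {
  // Cut-down stand-ins for the post's classes.
  class Node(val label: String) {
    var weight: Float = Float.PositiveInfinity
    var previous: Node = null
  }
  class Arc(val start: Node, val end: Node, val weight: Float)

  // Same relaxation step as the post's improveDistance, with an explicit Unit result.
  def improveDistance(a: Arc): Unit =
    if (a.start.weight + a.weight < a.end.weight) {
      a.end.weight = a.start.weight + a.weight
      a.end.previous = a.start
    }

  // Hypothetical helper: relax every arc, keeping only the side effects.
  // improveDistance(_) expands to a => improveDistance(a).
  def relaxAll(arcs: List[Arc]): Unit =
    arcs.foreach(improveDistance(_))

  def main(args: Array[String]): Unit = {
    val s = new Node("s")
    s.weight = 0f
    val x = new Node("x")
    val y = new Node("y")
    relaxAll(List(new Arc(s, x, 2f), new Arc(s, y, 1f)))
    println(x.label + " " + x.weight + " via " + x.previous.label) // x 2.0 via s
    println(y.label + " " + y.weight + " via " + y.previous.label) // y 1.0 via s
  }
}
```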
var a12= new Arc(n1,n2,1.0)
var a13= new Arc(n1,n3,2.2)
var a24= new Arc(n2,n4,1.5)
…
But it felt too “Java”. I wanted to try something more “DSL-ish”, and Scala helps here because symbolic names such as --> are valid method names, so I added the following to the Node class (transitions is the node’s list of outgoing arcs):
def --> (end: Node):Arc={
  transitions = new Arc(this,end) :: transitions
  transitions.head
}
Why? Because now, to add an arc of weight 2 from node N1 to node N2, I can write:
N1-->N2 weight=2
Isn’t that neat? :-)
(I used --> instead of -> because I didn’t want to freak out the people who are scared of changing the meaning of operators, and Scala already uses -> to build pairs, for example the key/value entries of a Map.)
Nodes are declared as var n1=new Node("Start"). It would be nicer to declare them as “var n2= Node "Node2"”, but it didn’t bother me that much, and I couldn’t find a quicker way other than using case classes. If somebody knows a better way, let me know!
I’ll leave you with the full code. Please suggest improvements!
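As far as I can tell, Scala reads N1-->N2 weight=2 as (N1 --> N2).weight = 2: weight is applied as a postfix selection to the Arc returned by -->, and assigning to a var member calls its generated setter weight_=. The sketch below spells both forms out with explicit parentheses, using cut-down versions of the post’s classes (this Arc starts with a weight of 0). Recent compilers, such as Scala 2.13 and Scala 3, expect import scala.language.postfixOps for the postfix form, so the parenthesised form is the more portable one.

```scala
object ArrowDemo {
  // Cut-down versions of the post's Node and Arc.
  class Node(val label: String) {
    var transitions: List[Arc] = Nil
    // As in the post: prepend a new arc to this node's list and return it.
    def -->(end: Node): Arc = {
      transitions = new Arc(this, end) :: transitions
      transitions.head
    }
  }
  class Arc(val start: Node, val end: Node) {
    var weight: Float = 0f
  }

  def main(args: Array[String]): Unit = {
    val n1 = new Node("n1")
    val n2 = new Node("n2")
    val n3 = new Node("n3")
    n1.-->(n2).weight_=(2)  // roughly what n1-->n2 weight=2 turns into
    (n1 --> n3).weight = 1  // the same thing with operator syntax and parentheses
    // Arcs are prepended, so the most recently added one comes first.
    println(n1.transitions.map(a => a.end.label + ":" + a.weight)) // List(n3:1.0, n2:2.0)
  }
}
```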
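One option that avoids case classes is a companion object with an apply method, which lets you write Node("Node2") without new. As far as I know the parentheses are still required, so Node "Node2" on its own is not possible. A minimal sketch with a cut-down Node class:

```scala
// Cut-down Node (a plain class, not a case class) plus a companion object.
class Node(val label: String) {
  override def toString: String = "Node(" + label + ")"
}

object Node {
  // Node("Node2") is shorthand for Node.apply("Node2")
  def apply(label: String): Node = new Node(label)
}

object CompanionDemo {
  def main(args: Array[String]): Unit = {
    val n2 = Node("Node2")
    println(n2) // Node(Node2)
  }
}
```

In Scala 3 this comes for free: creator applications let you call any class’s constructor without new, so Node("Node2") works even without the companion object.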
package myScala;
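class Node (val label:String) {
  // outgoing arcs from this node
  var transitions: List[Arc] = Nil
  // predecessor on the best path found so far (null until set)
  var previous: Node = _
  // tentative distance from the start node
  var weight= Float.PositiveInfinity
  // not used by the algorithm below
  var visited = false
  override def toString()= {
    label+" w:"+weight+" p:"+previous
  }
  // creates an arc from this node to end, prepends it to transitions
  // and returns it, so the caller can set its weight
  def --> (end: Node):Arc={
    transitions = new Arc(this,end) :: transitions
    transitions.head
  }
}
class Arc(var start: Node, var end: Node) {
  // arc cost; 0.0 until set
  var weight: Float = _
  override def toString()= {
    start.label+"-->"+end.label+" w:"+weight
  }
}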
package myScala;
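object Dijkstra {
  // Settles every node in the graph; end is not used, so there is no early exit
  def shortestPath(graph:Set[Node], start: Node, end: Node) = {
    var unvisited=graph
    start.weight=0
    while (!unvisited.isEmpty) {
      // the unvisited node with the lowest tentative weight
      val vertx=min(unvisited)
      // try to lower its neighbours' weights by going through it
      vertx.transitions.map(improveDistance(_))
      unvisited=unvisited-vertx
    }
  }
  // relaxes one arc: keeps the path through a.start if it is cheaper
  def improveDistance(a:Arc) ={
    if (a.start.weight+a.weight< a.end.weight) {
      a.end.weight=a.start.weight+a.weight
      a.end.previous=a.start
    }
  }
  def min(nodes: Set[Node]): Node = {
    nodes.reduceLeft((a:Node, b:Node) => if (a.weight<b.weight) a else b)
  }
  // follows the previous links back to the start; the list is end-first
  def pathTo(end:Node):List[Node] = {
    if (end == null)
      Nil
    else
      end :: pathTo(end.previous)
  }
}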
package myScala;
// may be needed on newer Scala versions for the "n1-->n2 weight=2" postfix syntax
import scala.language.postfixOps
object Test {
  /*
  n1 --2--> n2 --1--> n5
  |         |         |
  1         1         3
  |         |         |
  V         V         V
  n3 --3--> n4 --1--> n6
  */
  var n1=new Node("Start")
  var n2=new Node("Node2")
  var n3=new Node("Node3")
  var n4=new Node("Node4")
  var n5=new Node("Node5")
  var n6=new Node("End")
  n1-->n2 weight=2
  n1-->n3 weight=1
  n2-->n4 weight=1
  n3-->n4 weight=3
  n2-->n5 weight=1
  n4-->n6 weight=1
  n5-->n6 weight=3
  var graph= Set(n1, n2, n3, n4, n5, n6)
  def main(args: Array[String]): Unit = {
    Dijkstra.shortestPath(graph,n1,n6)
    println("Path")
    // pathTo returns End first, so reverse it to print from Start;
    // prints Start dist:0.0, Node2 dist:2.0, Node4 dist:3.0, End dist:4.0 (one per line)
    Dijkstra.pathTo(n6).reverse.foreach(
      (v:Node)=>println(v.label+" dist:"+v.weight)
    )
  }
}
}
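For comparison, here is one possible sketch of how the same algorithm could be written without mutating the nodes: tentative distances and predecessors live in immutable Maps that are threaded through a tail-recursive loop, and pathTo builds the path start-first with an accumulator, so no reverse is needed. It assumes string labels instead of Node objects, and that every node, including ones with no outgoing arcs, appears as a key of the graph map.

```scala
import scala.annotation.tailrec

// One possible immutable sketch: no mutable fields on nodes; tentative distances and
// predecessors are kept in Maps that are passed from one step to the next.
object FunctionalDijkstra {
  // Assumed representation: node label -> (neighbour label -> arc weight).
  // Every node, including ones with no outgoing arcs, must be a key.
  type Graph = Map[String, Map[String, Double]]

  // Returns (distance, predecessor) maps for every node reachable from start.
  def shortestPaths(graph: Graph, start: String): (Map[String, Double], Map[String, String]) = {
    @tailrec
    def loop(unvisited: Set[String],
             dist: Map[String, Double],
             prev: Map[String, String]): (Map[String, Double], Map[String, String]) = {
      // Only unvisited nodes that already have a finite distance can be picked.
      val candidates = unvisited.filter(dist.contains)
      if (candidates.isEmpty) (dist, prev)
      else {
        // Same selection as the post's min, but reading weights from the map.
        val u = candidates.reduceLeft((a, b) => if (dist(b) < dist(a)) b else a)
        // Relax every arc leaving u, threading both maps through a fold.
        val (newDist, newPrev) = graph(u).foldLeft((dist, prev)) {
          case ((d, p), (v, w)) =>
            val alt = d(u) + w
            if (alt < d.getOrElse(v, Double.PositiveInfinity)) (d.updated(v, alt), p.updated(v, u))
            else (d, p)
        }
        loop(unvisited - u, newDist, newPrev)
      }
    }
    loop(graph.keySet, Map(start -> 0.0), Map.empty)
  }

  // Follows predecessors back from end; the accumulator yields a start-first path.
  def pathTo(prev: Map[String, String], end: String): List[String] = {
    @tailrec
    def go(node: String, acc: List[String]): List[String] = prev.get(node) match {
      case Some(p) => go(p, node :: acc)
      case None    => node :: acc
    }
    go(end, Nil)
  }

  def main(args: Array[String]): Unit = {
    // The graph from the Test object, with labels instead of Node objects.
    val graph: Graph = Map(
      "Start" -> Map("Node2" -> 2.0, "Node3" -> 1.0),
      "Node2" -> Map("Node4" -> 1.0, "Node5" -> 1.0),
      "Node3" -> Map("Node4" -> 3.0),
      "Node4" -> Map("End" -> 1.0),
      "Node5" -> Map("End" -> 3.0),
      "End"   -> Map.empty[String, Double]
    )
    val (dist, prev) = shortestPaths(graph, "Start")
    pathTo(prev, "End").foreach(n => println(n + " dist:" + dist(n)))
  }
}
```

Run on that graph, it should print Start dist:0.0, Node2 dist:2.0, Node4 dist:3.0 and End dist:4.0, the same path the mutable version finds. Scanning the whole unvisited set for the minimum keeps it close to the original; a priority queue would be the usual next step for large graphs.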