3.5 Functional Programming

Functions as Objects

  • Functions in Scala are first-class objects. That means we can assign a function to a val and pass it to classes, objects, or other functions as an argument.
// This is a normal function (a method defined with def).
def plus1funct(x: Int): Int = x + 1
// These are functions as vals.
// The first one explicitly specifies the function type (Int => Int); the second lets the compiler infer it.
val plus1val: Int => Int = x => x + 1 // equivalent: val plus1val: (Int => Int) = { x => x + 1 }
val times2val = (x: Int) => x * 2
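
Because a function value is an ordinary object, it can also be stored in a collection, returned from another function, or combined with other functions. The following is a minimal sketch of those three uses; the names `ops`, `adder`, and `plus1ThenTimes2` are illustrative and do not appear in the original material.

```scala
val plus1: Int => Int = x => x + 1
val times2 = (x: Int) => x * 2

// Store functions in a collection, like any other value.
val ops: Map[String, Int => Int] = Map("plus1" -> plus1, "times2" -> times2)
val six = ops("times2")(3)                  // times2(3)

// Return a function from a function (adder is a hypothetical helper).
def adder(n: Int): Int => Int = x => x + n
val plus10 = adder(10)
val seventeen = plus10(7)                   // 7 + 10

// Function values have methods of their own, such as andThen.
val plus1ThenTimes2 = plus1.andThen(times2)
val eight = plus1ThenTimes2(3)              // (3 + 1) * 2
```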

val vs. def

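  • (?) Why would you want to create a val instead of a def? A val holds the function as an object, so it can be stored, returned, and passed to other functions directly. A def can also be passed where a function is expected; Scala converts it to a function value automatically.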
// create our functions
def plus2(x: Int): Int = x + 2
val plus1 = (x: Int) => x + 1
val times2 = (x: Int) => x * 2

// pass it to map, a list function
val myList = List(1, 2, 5, 9)
val myListPlus = myList.map(plus1)
val myListTimes = myList.map(times2)

// create a custom function that applies op to x, n times, using recursion
def opN(x: Int, n: Int, op: Int => Int): Int = {
  if (n <= 0) { x }
  else { opN(op(x), n-1, op) }
}

opN(7, 5, plus1)
opN(7, 5, plus2) // also OK: a def is converted to a function value when it is passed
opN(7, 3, times2)
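
The conversion from a def to a function value is called eta-expansion. The sketch below shows it happening implicitly (when a method is passed to `map`) and explicitly (when a method is assigned to a val with a function type). The names `plus1Method`, `plus1Value`, and `plus1Converted` are illustrative only.

```scala
// A method (def) and a function value (val) with the same behavior.
def plus1Method(x: Int): Int = x + 1
val plus1Value: Int => Int = x => x + 1

// Passing the method where a function is expected: the compiler eta-expands it.
val viaMethod = List(1, 2, 5, 9).map(plus1Method)
val viaValue  = List(1, 2, 5, 9).map(plus1Value)

// Assigning the method to a val with a function type also triggers the conversion.
val plus1Converted: Int => Int = plus1Method
val five = plus1Converted(4)
```

Both `viaMethod` and `viaValue` evaluate to `List(2, 3, 6, 10)`, so from the caller's point of view the two forms are interchangeable here.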
  • A def is evaluated every time it is called, while a val is evaluated once, when it is defined.
import scala.util.Random

// both x and y call the nextInt function, but x is evaluated immediately and y is a function
val x = Random.nextInt()
def y = Random.nextInt()

// x was previously evaluated, so it is a constant
println(s"x = $x")
println(s"x = $x")

// y is a function and gets reevaluated at each call, thus these produce different results
println(s"y = $y")
println(s"y = $y")
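
Random numbers make the difference visible but not repeatable. A deterministic variant, sketched below, counts how often the right-hand side runs. The counter `evaluations`, the helper `nextValue`, and the `lazy val` case are additions for illustration and are not part of the original example.

```scala
var evaluations = 0 // counts how often nextValue() has run

def nextValue(): Int = { evaluations += 1; evaluations }

val fixed = nextValue()          // right-hand side runs once, here
def fresh = nextValue()          // right-hand side runs on every reference
lazy val deferred = nextValue()  // right-hand side runs once, on first use

val a = fixed     // 1
val b = fixed     // 1 (not re-evaluated)
val c = fresh     // 2
val d = fresh     // 3 (re-evaluated)
val e = deferred  // 4 (evaluated now, on first use)
val f = deferred  // 4 (cached)
```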

FIR Filter with Window Function Passed

  • Window functions: each takes a length and a bitwidth and returns integer coefficients scaled by 2^bitwidth.
// get some math functions
import scala.math.{abs, round, cos, Pi, pow}

// simple triangular window
val TriangularWindow: (Int, Int) => Seq[Int] = (length, bitwidth) => {
  val raw_coeffs = (0 until length).map( (x: Int) => 1-abs((x.toDouble-(length-1)/2.0)/((length-1)/2.0)) )
  val scaled_coeffs = raw_coeffs.map( (x: Double) => round(x * pow(2, bitwidth)).toInt)
  scaled_coeffs
}

// Hamming window
val HammingWindow: (Int, Int) => Seq[Int] = (length, bitwidth) => {
  val raw_coeffs = (0 until length).map( (x: Int) => 0.54 - 0.46*cos(2*Pi*x/(length-1)))
  val scaled_coeffs = raw_coeffs.map( (x: Double) => round(x * pow(2, bitwidth)).toInt)
  scaled_coeffs
}
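
To get a feel for what a window produces, it helps to evaluate one for small parameters. The sketch below repeats the triangular computation in compact form so that it runs on its own, and adds a rectangular window with the same signature. `triangular` and `rectangular` are illustrative names, and the rectangular window is an assumption added here, not part of the original material.

```scala
import scala.math.{abs, pow, round}

// Same computation as the triangular window above, repeated so this snippet is self-contained.
val triangular: (Int, Int) => Seq[Int] = (length, bitwidth) => {
  val half = (length - 1) / 2.0
  (0 until length).map { i =>
    val raw = 1 - abs((i - half) / half)  // 0 at the edges, 1 in the middle
    round(raw * pow(2, bitwidth)).toInt   // scale to an integer coefficient
  }
}

// Hypothetical extra window: every tap gets the full-scale value.
val rectangular: (Int, Int) => Seq[Int] = (length, bitwidth) =>
  Seq.fill(length)(pow(2, bitwidth).toInt)

val triCoeffs  = triangular(5, 4)   // 0, 8, 16, 8, 0
val rectCoeffs = rectangular(5, 4)  // 16, 16, 16, 16, 16
```

Because both values have the type `(Int, Int) => Seq[Int]`, either one can be handed to any code that expects a window function.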
  • FIR filter: the module takes the window function as a constructor parameter.
import chisel3._

// our FIR has parameterized window length, IO bitwidth, and windowing function
class MyFir(length: Int, bitwidth: Int, window: (Int, Int) => Seq[Int]) extends Module {
  val io = IO(new Bundle {
    val in = Input(UInt(bitwidth.W))
    val out = Output(UInt((bitwidth*2+length-1).W)) // expect bit growth, conservative but lazy
  })

  // calculate the coefficients using the provided window function, mapping to UInts
  val coeffs = window(length, bitwidth).map(_.U)

  // create a Seq holding the input followed by the output of each delay register
  // note: we avoid using a Vec here since we don't need dynamic indexing
  val delays = Seq.fill(length)(Wire(UInt(bitwidth.W))).scan(io.in)( (prev: UInt, next: UInt) => {
    next := RegNext(prev)
    next
  })

  // multiply each tap by its coefficient, putting the results in "mults"
  // (zip stops at the shorter Seq, so the last delay is unused)
  val mults = delays.zip(coeffs).map{ case(delay: UInt, coeff: UInt) => delay * coeff }

  // add up multiplier outputs with bit growth
  val result = mults.reduce(_+&_)

  // connect output
  io.out := result
}

visualize(() => new MyFir(7, 12, TriangularWindow))
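
A plain Scala reference model can be handy when checking a filter like this in a test. The sketch below computes out(t) as the sum of coeffs(i) * in(t - i), which matches the tap structure of MyFir (the undelayed input multiplies the first coefficient). `firModel` is a hypothetical helper, and it assumes that samples before the start of the input are 0, whereas the hardware registers above have no reset value.

```scala
// Software model of a direct-form FIR filter.
// Assumption: samples before the start of the input are treated as 0.
def firModel(input: Seq[Int], coeffs: Seq[Int]): Seq[BigInt] =
  input.indices.map { t =>
    coeffs.zipWithIndex.map { case (c, i) =>
      if (t - i >= 0) BigInt(c) * input(t - i) else BigInt(0)
    }.sum
  }

// An impulse at t = 0 reproduces the leading coefficients, one per cycle.
val impulseResponse = firModel(Seq(1, 0, 0, 0), Seq(0, 8, 16, 8, 0))

// With all-ones coefficients the filter is a running sum over the last 3 samples.
val runningSum = firModel(Seq(1, 2, 3, 4), Seq(1, 1, 1))
```

Here `impulseResponse` is 0, 8, 16, 8 and `runningSum` is 1, 3, 6, 9. The coefficients passed in could equally come from one of the window functions, for example `firModel(samples, TriangularWindow(7, 12))` for some input `samples`.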