CMU 15-112: Fundamentals of Programming and Computer Science
Class Notes: Functional Programming (map / filter / reduce)


 Learning Goal: use functional programming to generate results without relying on state. In particular:
  1. Lambda functions
  2. Check 10.7
  3. Map, Filter, and Reduce
    1. Map
    2. Filter
    3. Reduce
  4. List comprehensions
  5. A real world example (optional)
  6. Check 10.8
  7. Check 10.9

  1. Lambda functions
    The lambda syntax is a way to create a function object without attaching it to a variable name. We call these functions "anonymous".

    We often use these when we want a function as an input for another function, but we don't need it by itself.

    Lambdas are defined in this way: lambda input1, input2, inputX : returnValue. Lambdas do not use return statements.
    f = lambda x : print(x)
    f("Wow!")

    lambda x : x * 2 # this does nothing
    "hello" # it's just like creating any random object -- like this string

    g = lambda x, y : x * y
    print(g(4, 28)) # Now that the lambda has a name we can use it
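
    As a small illustration of the syntax described above, a lambda behaves much like a one-line def whose body is a single returned expression. The names squareDef and squareLambda are made up for this sketch.

```python
# A regular function with an explicit return statement
def squareDef(x):
    return x ** 2

# Roughly the same function written as a lambda: the expression after
# the colon is the return value, so no return statement is needed
squareLambda = lambda x: x ** 2

print(squareDef(5), squareLambda(5))

# A lambda can also be called right away, without ever being named
print((lambda x, y: x + y)(3, 4))
```

    Naming a lambda, as squareLambda does here, is mostly useful for demonstration; when a function needs a name, a def is usually the clearer choice.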

    Since lambdas are function objects, we can pass them directly into higher-order functions (functions that take in other functions):
    def derivative(f, x):
        h = 10**-8
        return (f(x+h) - f(x))/h

    print(derivative(lambda x: 10*x + 42, 5)) # about 10
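
    Python's built-in sorted and max also take a function argument (their key parameter), so they are a common place to pass a lambda. The sketch below assumes a made-up list of (name, score) pairs.

```python
# Hypothetical data for this sketch: (name, score) pairs
scores = [("ada", 91), ("bob", 78), ("cy", 85)]

# sorted() and max() accept a key function that is called on each
# element to decide how the elements are compared
byScore = sorted(scores, key=lambda pair: pair[1])
best = max(scores, key=lambda pair: pair[1])

print("byScore =", byScore)
print("best =", best)
```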

  2. Check 10.7

  3. Map, Filter, and Reduce
    The functions map, filter, and reduce are very handy for manipulating data stored in lists! We can use map and filter to generate a new list based on an old one, and reduce to combine a list into a single value.

    1. Map
      The map function applies the given function to each element of the list and produces the results, which we usually collect into a new list.
      #Let's start with a function that takes a single integer
      def double(n):
          return n * 2

      #Here's a list of positive integers!
      someInts = [2, 4, 33, 8, 27, 73, 4, 98, 97, 100]

      #We want to apply our function to each element and store the results in a list
      #The familiar approach would be to use a loop and build up a result:
      result1 = []
      for i in someInts:
          result1.append(double(i))

      #As a function
      def myMap(fn, lst):
          result = []
          for element in lst:
              result.append(fn(element))
          return result

      print("result1 =",myMap(double, someInts))

      #We can also use map(function, inputs) to achieve the same thing:
      result2 = list(map(double, someInts))
      print("result2 =",result2)
      assert(result1==result2)

      #The map() function by itself actually returns a map object.
      #Note that we have to convert the output of map to a list.
      #Below, if we just print result3, we won't get much info.
      result3 = map(double, someInts)
      print("result3 =",result3)
      print("list(result3) =",list(result3))

      #map() can also be used with lambda functions
      result4 = list(map(lambda x: x*2, someInts))
      print("result4 =",result4)
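
      One consequence of map returning a map object rather than a list is that the object can only be walked through once. The short sketch below (with made-up variable names) shows what happens when the same map object is converted to a list twice.

```python
someInts = [2, 4, 33]

doubled = map(lambda x: x * 2, someInts)

# The first conversion walks through the map object and uses it up
firstPass = list(doubled)
# A second conversion finds nothing left to produce
secondPass = list(doubled)

print("firstPass =", firstPass)
print("secondPass =", secondPass)
```

      If the results are needed more than once, it is usually simplest to store the list rather than the map object.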

    2. Filter
      The filter function takes a function and a list, applies the function to each element of the list, and only retains elements that make the function return True.
      #Now let's say we want to find all the primes in a list
      someInts = [2, 4, 33, 8, 27, 73, 4, 98, 97, 100]

      def isPrime(n):
          if n < 2:
              return False
          for i in range(2, n):
              if n % i == 0:
                  return False
          return True

      #filter(function, inputs) also applies a specified function to each element
      #For each item in inputs, if function(item)==True, we add it to the result.
      result1 = list(filter(isPrime, someInts))
      print("result1 =", result1) #[2, 73, 97] are all the prime items

      #We can use filter() with lambda functions too
      evenInts = list(filter(lambda x: x%2==0, someInts))
      print("evenInts =",evenInts)
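
      To make the behavior of filter concrete, here is a loop-based sketch in the same spirit as myMap. The name myFilter is made up for this illustration and is not part of Python.

```python
def myFilter(fn, lst):
    result = []
    for element in lst:
        # Keep the element only when fn(element) is True
        if fn(element):
            result.append(element)
    return result

someInts = [2, 4, 33, 8, 27, 73, 4, 98, 97, 100]
isEven = lambda x: x % 2 == 0

print("myFilter =", myFilter(isEven, someInts))
assert(myFilter(isEven, someInts) == list(filter(isEven, someInts)))
```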

    3. Reduce
      Finally, the reduce function combines all the elements in the given list using the provided function, to produce a single-value result.
      #If we want to calculate the product of all items in a list, we could use a loop
      lst = [1, 3, 2, 11, 10, 9]
      result1 = 1
      for nextItem in lst:
          result1 = result1 * nextItem
      print("result1 =",result1)

      #This is the same task performed with reduce(function, inputs)
      #Note that we have to import it from functools!
      from functools import reduce
      result2 = reduce(lambda product, nextItem: product * nextItem, lst)
      print("result2 =",result2)
      assert(result1 == result2)

      #At first, product = lst[0] and nextItem = lst[1]
      #The lambda function computes lst[0] * lst[1] and stores the result in product
      #Then nextItem becomes lst[2] and the process continues

      #Let's show that by using a normally-defined function that prints its args
      def debugProduct(product, nextItem):
          print("product =",product,"\t nextItem = ",nextItem, end="\t ")
          newProduct = product * nextItem
          print("newProduct =",newProduct)
          return newProduct

      result3 = reduce(debugProduct, lst)
      print("result3 =",result3)
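
      reduce also accepts an optional third argument that serves as the starting value, which matters when the list might be empty. The sketch below shows that argument alongside a simplified loop-based version; myReduce is a made-up name, and unlike functools.reduce it always requires the starting value.

```python
from functools import reduce

def myReduce(fn, lst, initial):
    # Simplified sketch: start from initial and fold in one element at a time
    result = initial
    for element in lst:
        result = fn(result, element)
    return result

lst = [1, 3, 2, 11, 10, 9]
multiply = lambda product, nextItem: product * nextItem

# The third argument to reduce is used as the first value of product
print(reduce(multiply, lst, 1))
print(myReduce(multiply, lst, 1))

# With a starting value, an empty list simply gives that value back;
# without one, reduce raises a TypeError on an empty list
print(reduce(multiply, [], 1))
```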

  4. List comprehensions
    We already taught this concept in the 1D Lists unit, but it will be especially useful in functional programming!

    List comprehensions can be used to iteratively create lists in a very compact form. This can be handy if you would otherwise create the list using a loop and relatively basic logic.
    #A list comprehension has several parts in its general form
    result = [fn(item) for item in someIterable if someExpression(item)]
    #   fn(item)                  is the transformation
    #   for item in someIterable  is the iteration
    #   if someExpression(item)   is the filter

    #In human language, this means...
    # ...for every item in someIterable...
    # ...if someExpression(item) is True...
    # ...append the output of fn(item) to our result

    #Or, in code, this is equivalent to the following lines
    result = []
    for item in someIterable:
        if someExpression(item) == True:
            result.append( fn(item) )

    Here's an example:
    rows, cols = 3, 2
    #This line creates a 2d list of zeros with 3 rows and 2 cols as specified
    zeroLst = [ ([0] * cols) for row in range(rows) ]
    print("zeroLst =",zeroLst)
    #The outer brackets and everything in-between is the 'comprehension'

    #This line gives a list of every even integer in range(0,100)
    listOfEvensA = [i for i in range(0,100) if i%2==0]

    #We could otherwise produce the exact same result using this loop
    listOfEvensB = []
    for i in range(0,100):
        if i%2 == 0:
            listOfEvensB.append(i)
    assert(listOfEvensA==listOfEvensB)
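
    The transformation and filter parts of a comprehension play the same roles as map and filter. As a rough comparison, the sketch below builds the same list both ways; the variable names are made up for this illustration.

```python
# Squares of the even integers in range(0, 10), written two ways

# With a list comprehension: transformation, iteration, and filter in one line
viaComprehension = [i**2 for i in range(0, 10) if i % 2 == 0]

# With filter (keep the evens) and then map (square what is left)
viaMapFilter = list(map(lambda i: i**2,
                        filter(lambda i: i % 2 == 0, range(0, 10))))

print("viaComprehension =", viaComprehension)
assert(viaComprehension == viaMapFilter)
```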

    Using list comprehensions, we can do some pretty cool stuff!
    #This function returns all factors of a given integer (except 1 and itself)
    def getFactors(x):
        return [i for i in range(2,x) if x%i==0]

    #If we want a list of all primes less than x, we could do that simply as well
    def allPrimesLessThan(x):
        return [i for i in range(2,x) if len(getFactors(i))==0]

    #Or, if we want to use every part of the comprehension...
    #This function returns every value whose *square root* is less than x and prime
    def allSqrtsPrimeLessThan(x):
        return [i**2 for i in range(2,x) if len(getFactors(i))==0]

    x=54
    print("getFactors(",x,") =>",getFactors(x))
    print("allPrimesLessThan(",x,") =>",allPrimesLessThan(x))
    print("allSqrtsPrimeLessThan(",x,") =>",allSqrtsPrimeLessThan(x))

  5. A real world example (optional)
    # we have some data from ohqueue
    data = """\
     course | count
    --------+-------
      10601 |   706
      15110 |  2081
      15112 |  8883
      15210 |  3034
      15213 |  3647
      15381 |     7
      15418 |   113
      15440 |  1177
      15441 |    77
      15451 |     5
      15455 |    13
      15619 |     5
      17437 |    77
      18348 |   373
      18349 |   891
      67272 |   459
    """

    # We want to know what percentage of questions come from core cs courses

    # Clean up the data a little bit
    cleanSplitLine = lambda sl : (int(sl[0].strip()), int(sl[1].strip()))
    # Skip the two header lines and any blank lines, so that no data row is dropped
    cleanedData = [cleanSplitLine(line.split("|")) for line in data.splitlines()[2:] if line.strip()]
    print(cleanedData)

    # Then find our answer (multiplied by 100 so that it prints as a percentage)
    result = 100 * (sum(map(lambda e: e[1], filter(lambda e: e[0] in [15112, 15210, 15213, 15451], cleanedData)))
                    / sum(map(lambda e: e[1], cleanedData)))
    print("%0.2f%%" % result)

    # This is the kind of code you'd never write in industry, but might write to get a quick and dirty answer fast
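
    For contrast with the quick-and-dirty one-liner, here is one possible more readable way to compute the same kind of percentage. This is only a sketch: sampleData is a shortened, made-up list in the same (course, count) shape as cleanedData, and percentFromCourses is a hypothetical helper name.

```python
# A shortened, made-up sample in the same (course, count) shape as cleanedData
sampleData = [(15112, 50), (15210, 30), (15110, 20)]
coreCourses = {15112, 15210, 15213, 15451}

def percentFromCourses(rows, courses):
    # Count for the selected courses divided by the count for all courses,
    # scaled to a percentage
    selected = sum(count for (course, count) in rows if course in courses)
    total = sum(count for (course, count) in rows)
    return 100 * selected / total

print("%0.2f%%" % percentFromCourses(sampleData, coreCourses))
```

    Naming the intermediate sums makes each step easier to check than it is in a single nested expression.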

  6. Check 10.8

  7. Check 10.9