For-Comprehensions and Monads: Elegant Composition

Introduction

For-comprehensions (also called for-yield expressions) are one of Scala's most elegant features for composing operations across collections and monadic types. They provide a readable, imperative-style syntax for functional operations like map, flatMap, and filter.

Behind the scenes, the compiler rewrites a for-comprehension into calls to map, flatMap, and withFilter (and foreach when there is no yield), so it works with any type that defines the methods a given comprehension needs. This lesson will teach you to use for-comprehensions effectively and understand the monadic patterns that power them.
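
To make that translation concrete, the sketch below writes one comprehension twice: once with for syntax and once as the method chain the compiler produces, at least approximately. The names xs, ys, sugared, and desugared are illustrative only and are not used elsewhere in the lesson.

```scala
// Illustrative values; not part of the lesson's later examples.
val xs = List(1, 2, 3)
val ys = List(10, 20)

// A comprehension with two generators and one guard...
val sugared = for {
  x <- xs
  if x % 2 == 1
  y <- ys
} yield x + y

// ...corresponds roughly to this chain: every generator except the last
// becomes flatMap, the last becomes map, and a guard becomes withFilter
// on the generator it follows.
val desugared =
  xs.withFilter(x => x % 2 == 1)
    .flatMap(x => ys.map(y => x + y))

println(sugared)    // List(11, 21, 13, 23)
println(desugared)  // List(11, 21, 13, 23)
```

Because the translation is purely syntactic, any type that offers these methods with suitable signatures can appear on the right of `<-`.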

Basic For-Comprehensions

Simple Transformations

// Basic for-comprehension with collections
val numbers = List(1, 2, 3, 4, 5)

// Traditional approach
val doubled = numbers.map(_ * 2)
val evens = numbers.filter(_ % 2 == 0)
val doubledEvens = numbers.filter(_ % 2 == 0).map(_ * 2)

// For-comprehension approach
val forDoubled = for (n <- numbers) yield n * 2
val forEvens = for (n <- numbers if n % 2 == 0) yield n
val forDoubledEvens = for {
  n <- numbers
  if n % 2 == 0
} yield n * 2

println(s"Traditional doubled: $doubled")
println(s"For-comp doubled: $forDoubled")
println(s"Traditional evens: $evens")
println(s"For-comp evens: $forEvens")
println(s"Traditional doubled evens: $doubledEvens")
println(s"For-comp doubled evens: $forDoubledEvens")

// Multiple generators (Cartesian product)
val letters = List('a', 'b', 'c')
val digits = List(1, 2, 3)

val combinations = for {
  letter <- letters
  digit <- digits
} yield s"$letter$digit"

println(s"Combinations: $combinations")
// List(a1, a2, a3, b1, b2, b3, c1, c2, c3)

// Filtered combinations
val filteredCombinations = for {
  letter <- letters
  digit <- digits
  if digit % 2 == 1  // Only odd digits
} yield s"$letter$digit"

println(s"Filtered combinations: $filteredCombinations")
// List(a1, a3, b1, b3, c1, c3)

// Nested transformations
val matrix = List(
  List(1, 2, 3),
  List(4, 5, 6),
  List(7, 8, 9)
)

val flattened = for {
  row <- matrix
  element <- row
} yield element

val evenElements = for {
  row <- matrix
  element <- row
  if element % 2 == 0
} yield element

val squaredOdds = for {
  row <- matrix
  element <- row
  if element % 2 == 1
} yield element * element

println(s"Flattened: $flattened")     // List(1, 2, 3, 4, 5, 6, 7, 8, 9)
println(s"Even elements: $evenElements") // List(2, 4, 6, 8)
println(s"Squared odds: $squaredOdds")   // List(1, 9, 25, 49, 81)
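
The generators above are independent of each other, but a later generator may also use values bound by an earlier one, and a value definition (name = expression) can name an intermediate result. The following is a small sketch of both ideas; n, orderedPairs, and oddSums are made-up names for illustration.

```scala
val n = 4

// The range for j depends on i, so only pairs with i < j are produced.
val orderedPairs = for {
  i <- (1 to n).toList
  j <- (i + 1 to n).toList
} yield (i, j)

println(orderedPairs)
// List((1,2), (1,3), (1,4), (2,3), (2,4), (3,4))

// sum is computed once per pair and reused in the guard and the yield.
val oddSums = for {
  i <- (1 to n).toList
  j <- (i + 1 to n).toList
  sum = i + j
  if sum % 2 == 1
} yield s"$i+$j=$sum"

println(oddSums)
// List(1+2=3, 1+4=5, 2+3=5, 3+4=7)
```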

Pattern Matching in For-Comprehensions

// For-comprehensions with pattern matching
case class Person(name: String, age: Int, city: String)

val people = List(
  Person("Alice", 25, "New York"),
  Person("Bob", 30, "San Francisco"),
  Person("Charlie", 35, "New York"),
  Person("Diana", 28, "Chicago"),
  Person("Eve", 32, "San Francisco")
)

// Extract names from people in specific cities
val newYorkers = for {
  Person(name, _, "New York") <- people
} yield name

val youngPeople = for {
  Person(name, age, city) <- people
  if age < 30
} yield s"$name from $city"

val adultNewYorkers = for {
  person@Person(_, age, "New York") <- people
  if age >= 25
} yield person

println(s"New Yorkers: $newYorkers")
println(s"Young people: $youngPeople")
println(s"Adult New Yorkers: $adultNewYorkers")

// Working with tuples
val coordinates = List((1, 2), (3, 4), (5, 6), (7, 8))

val sumCoordinates = for {
  (x, y) <- coordinates
} yield x + y

val positiveQuadrant = for {
  (x, y) <- coordinates
  if x > 0 && y > 0
} yield (x, y)

println(s"Sum coordinates: $sumCoordinates")      // List(3, 7, 11, 15)
println(s"Positive quadrant: $positiveQuadrant") // List((1,2), (3,4), (5,6), (7,8))

// Deconstructing complex data
case class Order(id: Int, items: List[String], total: Double)

val orders = List(
  Order(1, List("laptop", "mouse"), 1050.0),
  Order(2, List("book", "pen"), 25.0),
  Order(3, List("monitor", "keyboard", "mouse"), 450.0)
)

val expensiveItems = for {
  Order(id, items, total) <- orders
  if total > 100
  item <- items
} yield (id, item, total)

println("Expensive order items:")
expensiveItems.foreach { case (id, item, total) =>
  println(s"  Order $id: $item (total: $$$total)")
}

// Extracting values from a List of Options with pattern matching
val maybeData = List(Some("apple"), None, Some("banana"), Some("cherry"), None)

val uppercasedData = for {
  Some(data) <- maybeData
} yield data.toUpperCase

val lengthyData = for {
  Some(data) <- maybeData
  if data.length > 5
} yield data

println(s"Uppercased data: $uppercasedData")  // List(APPLE, BANANA, CHERRY)
println(s"Lengthy data: $lengthyData")        // List(banana, cherry)
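
In Scala 2, a generator whose pattern can fail to match, such as Some(data) <- maybeData, quietly skips the elements that do not match. Scala 3 prefers the explicit form case Some(data) <- maybeData and, depending on compiler version and settings, may warn about or reject the bare pattern. When portability matters, the same "keep only what matches" behaviour can be written with ordinary collection methods. This sketch uses a made-up list, maybeFruit, that mirrors maybeData.

```scala
// Illustrative data, mirroring maybeData above.
val maybeFruit: List[Option[String]] = List(Some("apple"), None, Some("banana"))

// collect applies a partial function and drops elements it is not defined for.
val viaCollect = maybeFruit.collect { case Some(fruit) => fruit.toUpperCase }

// flatten removes the None entries first; map then transforms what is left.
val viaFlatten = maybeFruit.flatten.map(_.toUpperCase)

println(viaCollect)  // List(APPLE, BANANA)
println(viaFlatten)  // List(APPLE, BANANA)
```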

For-Comprehensions with Options

Safe Operations

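// Working with Options to avoid null pointer exceptions
// Note: Person and people are redefined here with an optional address, replacing the
// earlier versions; run this section in a fresh REPL session or a separate file.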
case class Address(street: String, city: String, zipCode: String)
case class Person(name: String, age: Int, address: Option[Address])

val people = List(
  Person("Alice", 25, Some(Address("123 Main St", "New York", "10001"))),
  Person("Bob", 30, None),
  Person("Charlie", 35, Some(Address("456 Oak Ave", "San Francisco", "94102"))),
  Person("Diana", 28, None)
)

// Extract cities safely
val cities = for {
  person <- people
  address <- person.address
} yield address.city

println(s"Cities: $cities")  // List(New York, San Francisco)

// Complex validations with Options
def parseInt(s: String): Option[Int] = {
  try Some(s.toInt) catch { case _: NumberFormatException => None }
}

def validateAge(age: Int): Option[Int] = {
  if (age >= 0 && age <= 150) Some(age) else None
}

def validateEmail(email: String): Option[String] = {
  if (email.contains("@")) Some(email.toLowerCase) else None
}

// Chain validations using for-comprehensions
def createUser(name: String, ageStr: String, email: String): Option[Person] = {
  for {
    age <- parseInt(ageStr)
    validAge <- validateAge(age)
    _ <- validateEmail(email)  // checked only; this Person has no email field
  } yield Person(name, validAge, None)
}

val userInputs = List(
  ("Alice", "25", "alice@example.com"),
  ("Bob", "not-a-number", "bob@test.com"),
  ("Charlie", "200", "charlie@example.com"),
  ("Diana", "30", "invalid-email")
)

userInputs.foreach { case (name, age, email) =>
  createUser(name, age, email) match {
    case Some(user) => println(s"✓ Created user: ${user.name}")
    case None => println(s"✗ Failed to create user: $name")
  }
}

// Combining multiple Options
def findPersonByName(name: String): Option[Person] = {
  people.find(_.name == name)
}

def getPersonCity(name: String): Option[String] = {
  for {
    person <- findPersonByName(name)
    address <- person.address
  } yield address.city
}

def getPersonInfo(name: String): Option[String] = {
  for {
    person <- findPersonByName(name)
    address <- person.address
  } yield s"${person.name}, age ${person.age}, lives in ${address.city}"
}

val nameQueries = List("Alice", "Bob", "Charlie", "Unknown")
nameQueries.foreach { name =>
  getPersonCity(name) match {
    case Some(city) => println(s"$name lives in $city")
    case None => println(s"City for $name not found")
  }
}

// Working with nested Options
val nestedOptions = List(Some(Some(1)), Some(None), None, Some(Some(5)))

val flattenedValues = for {
  outer <- nestedOptions   // Option[Option[Int]]
  inner <- outer           // Option[Int]
  value <- inner           // Int
} yield value

println(s"Flattened values: $flattenedValues")  // List(1, 5)

// Option combinations with guards
def safeDivide(a: Double, b: Double): Option[Double] = {
  if (b != 0) Some(a / b) else None
}

def safeSquareRoot(x: Double): Option[Double] = {
  if (x >= 0) Some(math.sqrt(x)) else None
}

def complexCalculation(a: Double, b: Double): Option[Double] = {
  for {
    divided <- safeDivide(a, b)
    sqrt <- safeSquareRoot(divided)
    if sqrt > 1.0  // Guard condition
  } yield sqrt * 2
}

val calculations = List(
  (16.0, 4.0),  // sqrt(16/4) = 2.0, * 2 = 4.0
  (4.0, 4.0),   // sqrt(4/4) = 1.0, fails guard
  (16.0, 0.0),  // Division by zero
  (-16.0, 4.0)  // Negative square root
)

calculations.foreach { case (a, b) =>
  complexCalculation(a, b) match {
    case Some(result) => println(f"calc($a, $b) = $result%.2f")
    case None => println(s"calc($a, $b) = failed")
  }
}
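
What makes these chains safe is that an Option comprehension stops at the first None: later generators are not evaluated at all. The sketch below makes this visible by recording which steps ran. The helper step and the buffer trace are hypothetical and exist only for this demonstration.

```scala
import scala.collection.mutable.ListBuffer

// Records which steps actually ran; purely for demonstration.
val trace = ListBuffer.empty[String]

// Hypothetical step: logs its label, then returns the given result.
def step(label: String, result: Option[Int]): Option[Int] = {
  trace += label
  result
}

val outcome = for {
  a <- step("first", Some(1))
  b <- step("second", None)     // evaluates to None: the chain stops here
  c <- step("third", Some(3))   // never evaluated
} yield a + b + c

println(outcome)        // None
println(trace.toList)   // List(first, second)
```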

For-Comprehensions with Either

Error Handling Composition

// Working with Either for error handling
type Result[T] = Either[String, T]

def validateName(name: String): Result[String] = {
  if (name.trim.nonEmpty) Right(name.trim)
  else Left("Name cannot be empty")
}

def validateAge(age: Int): Result[Int] = {
  if (age >= 0 && age <= 150) Right(age)
  else Left(s"Invalid age: $age")
}

def validateEmail(email: String): Result[String] = {
  if (email.contains("@")) Right(email.toLowerCase)
  else Left("Invalid email format")
}

// Chain validations using for-comprehensions
def validateUser(name: String, age: Int, email: String): Result[Person] = {
  for {
    validName <- validateName(name)
    validAge <- validateAge(age)
    _ <- validateEmail(email)  // checked only; this Person has no email field
  } yield Person(validName, validAge, None)
}

val userValidations = List(
  ("Alice", 25, "alice@example.com"),
  ("", 30, "bob@test.com"),
  ("Charlie", 200, "charlie@example.com"),
  ("Diana", 28, "invalid-email")
)

userValidations.foreach { case (name, age, email) =>
  validateUser(name, age, email) match {
    case Right(user) => println(s"✓ Valid user: ${user.name}")
    case Left(error) => println(s"✗ Validation error: $error")
  }
}

// Multiple error scenarios
def parseConfig(data: Map[String, String]): Result[Map[String, Any]] = {
  for {
    host <- data.get("host").toRight("Missing host")
    portStr <- data.get("port").toRight("Missing port")
    // toIntOption (Scala 2.13+) also rejects empty strings and values that overflow Int
    port <- portStr.toIntOption.toRight("Invalid port format")
    _ <- Either.cond(port > 0 && port <= 65535, (), "Port out of range")
  } yield Map("host" -> host, "port" -> port)
}

val configTests = List(
  Map("host" -> "localhost", "port" -> "8080"),
  Map("host" -> "localhost"),  // Missing port
  Map("host" -> "localhost", "port" -> "abc"),  // Invalid port
  Map("host" -> "localhost", "port" -> "99999")  // Port out of range
)

configTests.foreach { config =>
  parseConfig(config) match {
    case Right(parsed) => println(s"✓ Config: $parsed")
    case Left(error) => println(s"✗ Config error: $error")
  }
}

// Combining Either results
def fetchUserData(userId: Int): Result[String] = {
  if (userId > 0) Right(s"User data for $userId")
  else Left(s"Invalid user ID: $userId")
}

def fetchUserPreferences(userId: Int): Result[Map[String, Any]] = {
  if (userId > 0) Right(Map("theme" -> "dark", "lang" -> "en"))
  else Left(s"No preferences for user $userId")
}

def getUserProfile(userId: Int): Result[(String, Map[String, Any])] = {
  for {
    userData <- fetchUserData(userId)
    preferences <- fetchUserPreferences(userId)
  } yield (userData, preferences)
}

List(1, 0, 5, -1).foreach { userId =>
  getUserProfile(userId) match {
    case Right((data, prefs)) => 
      println(s"✓ User $userId: $data, preferences: $prefs")
    case Left(error) => 
      println(s"✗ Failed to get profile for user $userId: $error")
  }
}
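
Two properties of Either comprehensions are worth keeping in mind. First, they fail fast: only the first Left is reported, even when several inputs are invalid. Second, Either defines no withFilter, so an if guard inside an Either comprehension does not compile; Either.cond, as used in parseConfig, is the usual substitute. The sketch below illustrates the fail-fast behaviour and one simple way to collect every error instead. The alias Checked and the validators nonEmpty and inRange are hypothetical, and partitionMap assumes Scala 2.13 or later.

```scala
type Checked[A] = Either[String, A]

// Hypothetical validators for illustration.
def nonEmpty(field: String, value: String): Checked[String] =
  Either.cond(value.trim.nonEmpty, value.trim, s"$field is empty")

def inRange(field: String, value: Int, min: Int, max: Int): Checked[Int] =
  Either.cond(value >= min && value <= max, value, s"$field out of range: $value")

// Fail-fast: both inputs are invalid, but only the first Left is reported.
val failFast: Checked[(String, Int)] = for {
  name <- nonEmpty("name", "")
  age  <- inRange("age", 200, 0, 150)
} yield (name, age)

println(failFast)  // Left(name is empty)

// To see every problem, run the checks independently and split the results.
val checks: List[Checked[Any]] =
  List(nonEmpty("name", ""), inRange("age", 200, 0, 150))
val (errors, _) = checks.partitionMap(e => e)

println(errors)  // List(name is empty, age out of range: 200)
```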

Advanced For-Comprehension Patterns

Custom Monadic Types

// Creating a custom monadic type
case class Logged[A](value: A, log: List[String]) {
  def map[B](f: A => B): Logged[B] = 
    Logged(f(value), log)

  def flatMap[B](f: A => Logged[B]): Logged[B] = {
    val result = f(value)
    Logged(result.value, log ++ result.log)
  }

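  // Logged always holds exactly one value, so a failed guard cannot drop it;
  // this version only records the failure and lets the computation continue.
  def withFilter(p: A => Boolean): Logged[A] =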
    if (p(value)) this 
    else Logged(value, log :+ s"Filter failed for $value")
}

object Logged {
  def apply[A](value: A, message: String): Logged[A] = 
    Logged(value, List(message))
}

// Functions that return Logged values
def loggedAdd(a: Int, b: Int): Logged[Int] = 
  Logged(a + b, s"Added $a + $b = ${a + b}")

def loggedMultiply(a: Int, b: Int): Logged[Int] = 
  Logged(a * b, s"Multiplied $a * $b = ${a * b}")

def loggedDivide(a: Int, b: Int): Logged[Option[Int]] = {
  if (b != 0) Logged(Some(a / b), s"Divided $a / $b = ${a / b}")
  else Logged(None, s"Cannot divide $a by $b")
}

// Using for-comprehensions with custom monadic type
val loggedComputation = for {
  x <- Logged(5, "Starting with 5")
  y <- loggedAdd(x, 3)
  z <- loggedMultiply(y, 2)
  if z > 10
} yield z * z

println(s"Result: ${loggedComputation.value}")
println("Log:")
loggedComputation.log.foreach(entry => println(s"  $entry"))

// Complex logged computation
def complexCalculation(a: Int, b: Int): Logged[Option[Int]] = {
  for {
    sum <- loggedAdd(a, b)
    product <- loggedMultiply(sum, 2)
    division <- loggedDivide(product, 4)
  } yield division  // division is already an Option[Int]; Option has no .value member
}

val result = complexCalculation(3, 7)
println(s"\nComplex calculation result: ${result.value}")
println("Computation log:")
result.log.foreach(entry => println(s"  $entry"))

// State monad example
case class State[S, A](run: S => (S, A)) {
  def map[B](f: A => B): State[S, B] = 
    State(s => {
      val (newState, a) = run(s)
      (newState, f(a))
    })

  def flatMap[B](f: A => State[S, B]): State[S, B] = 
    State(s => {
      val (newState, a) = run(s)
      f(a).run(newState)
    })
}

object State {
  def get[S]: State[S, S] = State(s => (s, s))
  def put[S](newState: S): State[S, Unit] = State(_ => (newState, ()))
  def modify[S](f: S => S): State[S, Unit] = State(s => (f(s), ()))
}

// Counter example using State monad
type Counter[A] = State[Int, A]

def increment: Counter[Int] = State(count => (count + 1, count + 1))
def decrement: Counter[Int] = State(count => (count - 1, count - 1))
def reset: Counter[Unit] = State(_ => (0, ()))

val counterProgram = for {
  _ <- increment
  _ <- increment
  a <- increment
  _ <- decrement
  b <- State.get[Int]
} yield (a, b)

val (finalState, (a, b)) = counterProgram.run(0)
println(s"\nCounter program: a=$a, b=$b, final state=$finalState")
// Counter program: a=3, b=2, final state=2
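
Since none of the State comprehensions use a guard, map and flatMap are all the type needs. As a rough check, the sketch below writes a counter program both as a comprehension and as the nested flatMap/map calls it stands for. MiniState, inc, dec, and current are cut-down, renamed stand-ins so that the block is self-contained; they are not additional library types.

```scala
// A cut-down copy of the State type, renamed so the block stands alone.
case class MiniState[S, A](run: S => (S, A)) {
  def map[B](f: A => B): MiniState[S, B] =
    MiniState { s =>
      val (s1, a) = run(s)
      (s1, f(a))
    }

  def flatMap[B](f: A => MiniState[S, B]): MiniState[S, B] =
    MiniState { s =>
      val (s1, a) = run(s)
      f(a).run(s1)   // thread the updated state into the next step
    }
}

def inc: MiniState[Int, Int] = MiniState(c => (c + 1, c + 1))
def dec: MiniState[Int, Int] = MiniState(c => (c - 1, c - 1))
def current: MiniState[Int, Int] = MiniState(c => (c, c))

val sugaredProgram = for {
  _ <- inc
  _ <- inc
  a <- inc
  _ <- dec
  b <- current
} yield (a, b)

// The same program with the nesting written out by hand.
val handWritten =
  inc.flatMap(_ =>
    inc.flatMap(_ =>
      inc.flatMap(a =>
        dec.flatMap(_ =>
          current.map(b => (a, b))))))

println(sugaredProgram.run(0))  // (2,(3,2))
println(handWritten.run(0))     // (2,(3,2))
```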

For-Comprehensions with Multiple Types

// Combining different monadic types
import scala.util.{Try, Success, Failure}

// Helper to convert Try to Either
def tryToEither[T](t: Try[T]): Either[String, T] = t match {
  case Success(value) => Right(value)
  case Failure(exception) => Left(exception.getMessage)
}

// Mixed computations
def parseAndValidate(input: String): Either[String, Int] = {
  for {
    parsed <- tryToEither(Try(input.toInt))
    validated <- if (parsed > 0) Right(parsed) else Left("Number must be positive")
  } yield validated
}

def processNumbers(inputs: List[String]): List[Either[String, Int]] = {
  for {
    input <- inputs
  } yield parseAndValidate(input)
}

val numberInputs = List("42", "-5", "not-a-number", "100", "0")
val processed = processNumbers(numberInputs)

processed.zip(numberInputs).foreach { case (result, input) =>
  result match {
    case Right(value) => println(s"'$input' -> $value")
    case Left(error) => println(s"'$input' -> Error: $error")
  }
}

// Combining Option and List
val maybeNumbers = List(Some(1), None, Some(3), Some(4), None)
val multipliers = List(2, 3, 5)

val products = for {
  maybeNum <- maybeNumbers
  num <- maybeNum.toList  // Convert Option to List
  multiplier <- multipliers
} yield num * multiplier

println(s"Products: $products")  // List(2, 3, 5, 6, 9, 15, 8, 12, 20)

// Traversing nested data with multiple generators
case class Department(name: String, employees: List[Person])
case class Company(name: String, departments: List[Department])

val company = Company("TechCorp", List(
  Department("Engineering", List(
    Person("Alice", 30, None),
    Person("Bob", 35, None)
  )),
  Department("Marketing", List(
    Person("Charlie", 28, None),
    Person("Diana", 32, None)
  ))
))

val allEmployees = for {
  dept <- company.departments
  employee <- dept.employees
} yield (dept.name, employee.name)

val youngEmployees = for {
  dept <- company.departments
  employee <- dept.employees
  if employee.age < 30
} yield s"${employee.name} from ${dept.name}"

println("All employees:")
allEmployees.foreach { case (dept, name) => 
  println(s"  $name ($dept)")
}

println(s"Young employees: $youngEmployees")

// Enumerating elements together with the index of their list
val lists = List(
  List(1, 2, 3),
  List(4, 5, 6),
  List(7, 8, 9)
)

val indexedElements = for {
  (list, index) <- lists.zipWithIndex
  element <- list
} yield (index, element)

println(s"Indexed elements: $indexedElements")
// List((0,1), (0,2), (0,3), (1,4), (1,5), (1,6), (2,7), (2,8), (2,9))
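
All generators in one comprehension must share a compatible container type, so mixing Option, Try, and Either usually means converting to one of them first, as tryToEither does above. The standard library already provides the common conversions. The sketch below uses a made-up settings map and a hypothetical readInt helper to show them.

```scala
import scala.util.Try

// Hypothetical lookup table for illustration.
val settings = Map("timeout" -> "30", "retries" -> "three")

def readInt(key: String): Either[String, Int] =
  for {
    // Option -> Either: supply the error to use when the key is missing.
    raw <- settings.get(key).toRight(s"missing key: $key")
    // Try -> Either: toEither keeps the Throwable, so map the left side to a message.
    n <- Try(raw.toInt).toEither.left.map(_ => s"not a number: $raw")
  } yield n

println(readInt("timeout"))  // Right(30)
println(readInt("retries"))  // Left(not a number: three)
println(readInt("port"))     // Left(missing key: port)

// Either -> Option: drop the error when only success matters.
val timeoutOpt: Option[Int] = readInt("timeout").toOption
println(timeoutOpt)          // Some(30)
```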

Performance and Best Practices

Efficient For-Comprehensions

import scala.util.Random

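// Performance considerations
// Note: a single System.nanoTime measurement on the JVM is only a rough indication
// (JIT warm-up and GC skew it); use a benchmarking harness such as JMH for real numbers.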
def timeForComprehension[T](name: String)(operation: => T): T = {
  val start = System.nanoTime()
  val result = operation
  val end = System.nanoTime()
  println(f"$name: ${(end - start) / 1e6}%.2f ms")
  result
}

val largeList = (1 to 100000).toList

// Efficient filtering and mapping
timeForComprehension("For-comprehension with early filtering") {
  val result = for {
    n <- largeList
    if n % 1000 == 0  // Filter early
    doubled = n.toLong * 2   // Intermediate value; Long avoids Int overflow when squared
  } yield doubled * doubled
  result.length
}

timeForComprehension("Traditional approach") {
  val result = largeList
    .filter(_ % 1000 == 0)
    .map(n => {
      val doubled = n.toLong * 2  // Long avoids Int overflow when squared
      doubled * doubled
    })
  result.length
}

// View for lazy evaluation
timeForComprehension("For-comprehension with view") {
  val result = for {
    n <- largeList.view
    if n % 1000 == 0
    doubled = n.toLong * 2  // Long avoids Int overflow when squared
  } yield doubled * doubled
  result.take(10).toList.length
}

// Best practices demonstration
case class Product(id: Int, name: String, price: Double, category: String)
case class Order(id: Int, productId: Int, quantity: Int, customerId: Int)
case class Customer(id: Int, name: String, city: String)

val products = (1 to 1000).map(i => 
  Product(i, s"Product $i", Random.nextDouble() * 100, 
    List("Electronics", "Books", "Clothing")(i % 3))
).toList

val orders = (1 to 5000).map(i => 
  Order(i, Random.nextInt(1000) + 1, Random.nextInt(10) + 1, Random.nextInt(100) + 1)
).toList

val customers = (1 to 100).map(i => 
  Customer(i, s"Customer $i", List("New York", "San Francisco", "Chicago")(i % 3))
).toList

// Efficient join operations
timeForComprehension("For-comprehension join") {
  val result = for {
    order <- orders.view
    product <- products
    if product.id == order.productId
    if product.category == "Electronics"
    customer <- customers
    if customer.id == order.customerId
  } yield (order.id, product.name, customer.name, order.quantity * product.price)

  result.take(100).toList.length
}

// More efficient with Maps for lookups: each order does two keyed lookups
// instead of scanning the full products and customers lists
val productMap = products.map(p => p.id -> p).toMap
val customerMap = customers.map(c => c.id -> c).toMap

timeForComprehension("Optimized with Maps") {
  val result = for {
    order <- orders
    product <- productMap.get(order.productId)
    if product.category == "Electronics"
    customer <- customerMap.get(order.customerId)
  } yield (order.id, product.name, customer.name, order.quantity * product.price)

  result.length
}

// Guidelines for readable for-comprehensions
def processComplexData(): List[String] = {
  for {
    // Use meaningful variable names
    customer <- customers

    // Group related conditions
    if customer.city == "New York"

    // Extract intermediate values when complex
    customerOrders = orders.filter(_.customerId == customer.id)

    // Continue with clear, single-purpose steps
    order <- customerOrders
    product <- productMap.get(order.productId).toList

    // Final transformations
    orderValue = order.quantity * product.price
    if orderValue > 50
  } yield f"${customer.name}: ${product.name} ($$${orderValue}%.2f)"
}

val complexResults = processComplexData()
println(s"Complex query results: ${complexResults.take(5)}")
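
One further refinement follows the same idea as the Map lookups: a filter over the full order list inside the comprehension runs once per customer, whereas grouping the orders up front builds the index a single time. The sketch below shows the pattern on a tiny data set. Buyer, Purchase, and the values built from them are minimal stand-ins invented for this example, not the lesson's types.

```scala
// Minimal stand-ins for the lesson's Customer and Order types.
case class Buyer(id: Int, name: String)
case class Purchase(id: Int, buyerId: Int, amount: Double)

val buyers = List(Buyer(1, "Ann"), Buyer(2, "Ben"), Buyer(3, "Cy"))
val purchases = List(
  Purchase(10, 1, 20.0),
  Purchase(11, 2, 75.0),
  Purchase(12, 1, 60.0)
)

// Build the index once: buyerId -> that buyer's purchases.
val purchasesByBuyer: Map[Int, List[Purchase]] = purchases.groupBy(_.buyerId)

val largePurchases = for {
  buyer <- buyers
  // getOrElse covers buyers with no purchases (Cy here).
  purchase <- purchasesByBuyer.getOrElse(buyer.id, Nil)
  if purchase.amount > 50
} yield s"${buyer.name}: ${purchase.id}"

println(largePurchases)  // List(Ann: 12, Ben: 11)
```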

Summary

In this lesson, you've worked through for-comprehensions and monadic composition:

Basic For-Comprehensions: Syntax and simple transformations
Pattern Matching: Deconstructing data within for-expressions
Option Composition: Safe chaining of operations that may return no value
Either Composition: Elegant error handling pipelines
Custom Monads: Creating your own composable types
Performance: Writing efficient for-comprehensions
Best Practices: Readable and maintainable code patterns

For-comprehensions provide a powerful abstraction for working with monadic types, making complex data transformations readable and composable.

What's Next

In the next lesson, we'll explore lazy evaluation and streams, learning how to work with potentially infinite data structures and optimize performance through deferred computation.