For-Comprehensions and Monads: Elegant Composition
Introduction
For-comprehensions (also called for-yield expressions) are one of Scala's most elegant features for composing operations across collections and monadic types. They provide a readable, imperative-style syntax for functional operations like map, flatMap, and filter.
Behind the scenes, the compiler rewrites a for-comprehension into calls to map, flatMap, and withFilter (plus foreach when there is no yield), so it works with any type that defines the methods a given expression needs. This lesson will teach you to use for-comprehensions effectively and understand the monadic patterns that power them.
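To make that rewriting concrete, here is a small sketch that pairs each for-comprehension with a hand-written equivalent. The names (`xs`, `ys`, `a1`, and so on) are illustrative only, and the method chains show the general shape of the translation rather than the exact code the compiler emits.

```scala
// Illustrative inputs (not part of the lesson's own examples)
val xs = List(1, 2, 3)
val ys = List(10, 20)

// One generator with yield: roughly a call to map
val a1 = for (x <- xs) yield x + 1
val a2 = xs.map(x => x + 1)

// Generator with a guard: withFilter first, then map
val b1 = for (x <- xs if x % 2 == 1) yield x * 10
val b2 = xs.withFilter(x => x % 2 == 1).map(x => x * 10)

// Two generators: flatMap on the outer one, map on the innermost
val c1 = for { x <- xs; y <- ys } yield x + y
val c2 = xs.flatMap(x => ys.map(y => x + y))
```

Each pair produces the same list, which is why any type offering these methods can appear on the right-hand side of `<-`.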
Basic For-Comprehensions
Simple Transformations
// Basic for-comprehension with collections
val numbers = List(1, 2, 3, 4, 5)
// Traditional approach
val doubled = numbers.map(_ * 2)
val evens = numbers.filter(_ % 2 == 0)
val doubledEvens = numbers.filter(_ % 2 == 0).map(_ * 2)
// For-comprehension approach
val forDoubled = for (n <- numbers) yield n * 2
val forEvens = for (n <- numbers if n % 2 == 0) yield n
val forDoubledEvens = for {
n <- numbers
if n % 2 == 0
} yield n * 2
println(s"Traditional doubled: $doubled")
println(s"For-comp doubled: $forDoubled")
println(s"Traditional evens: $evens")
println(s"For-comp evens: $forEvens")
println(s"Traditional doubled evens: $doubledEvens")
println(s"For-comp doubled evens: $forDoubledEvens")
// Multiple generators (Cartesian product)
val letters = List('a', 'b', 'c')
val digits = List(1, 2, 3)
val combinations = for {
letter <- letters
digit <- digits
} yield s"$letter$digit"
println(s"Combinations: $combinations")
// List(a1, a2, a3, b1, b2, b3, c1, c2, c3)
// Filtered combinations
val filteredCombinations = for {
letter <- letters
digit <- digits
if digit % 2 == 1 // Only odd digits
} yield s"$letter$digit"
println(s"Filtered combinations: $filteredCombinations")
// List(a1, a3, b1, b3, c1, c3)
// Nested transformations
val matrix = List(
List(1, 2, 3),
List(4, 5, 6),
List(7, 8, 9)
)
val flattened = for {
row <- matrix
element <- row
} yield element
val evenElements = for {
row <- matrix
element <- row
if element % 2 == 0
} yield element
val squaredOdds = for {
row <- matrix
element <- row
if element % 2 == 1
} yield element * element
println(s"Flattened: $flattened") // List(1, 2, 3, 4, 5, 6, 7, 8, 9)
println(s"Even elements: $evenElements") // List(2, 4, 6, 8)
println(s"Squared odds: $squaredOdds") // List(1, 9, 25, 49, 81)
Pattern Matching in For-Comprehensions
// For-comprehensions with pattern matching
case class Person(name: String, age: Int, city: String)
val people = List(
Person("Alice", 25, "New York"),
Person("Bob", 30, "San Francisco"),
Person("Charlie", 35, "New York"),
Person("Diana", 28, "Chicago"),
Person("Eve", 32, "San Francisco")
)
// Extract names from people in specific cities
val newYorkers = for {
Person(name, _, "New York") <- people
} yield name
val youngPeople = for {
Person(name, age, city) <- people
if age < 30
} yield s"$name from $city"
val adultNewYorkers = for {
person@Person(_, age, "New York") <- people
if age >= 25
} yield person
println(s"New Yorkers: $newYorkers")
println(s"Young people: $youngPeople")
println(s"Adult New Yorkers: $adultNewYorkers")
// Working with tuples
val coordinates = List((1, 2), (3, 4), (5, 6), (7, 8))
val sumCoordinates = for {
(x, y) <- coordinates
} yield x + y
val positiveQuadrant = for {
(x, y) <- coordinates
if x > 0 && y > 0
} yield (x, y)
println(s"Sum coordinates: $sumCoordinates") // List(3, 7, 11, 15)
println(s"Positive quadrant: $positiveQuadrant") // List((1,2), (3,4), (5,6), (7,8))
// Deconstructing complex data
case class Order(id: Int, items: List[String], total: Double)
val orders = List(
Order(1, List("laptop", "mouse"), 1050.0),
Order(2, List("book", "pen"), 25.0),
Order(3, List("monitor", "keyboard", "mouse"), 450.0)
)
val expensiveItems = for {
Order(id, items, total) <- orders
if total > 100
item <- items
} yield (id, item, total)
println("Expensive order items:")
expensiveItems.foreach { case (id, item, total) =>
println(s" Order $id: $item (total: $$$total)")
}
// Extracting values from a list of Options with pattern matching
val maybeData = List(Some("apple"), None, Some("banana"), Some("cherry"), None)
val uppercasedData = for {
Some(data) <- maybeData
} yield data.toUpperCase
val lengthyData = for {
Some(data) <- maybeData
if data.length > 5
} yield data
println(s"Uppercased data: $uppercasedData") // List(APPLE, BANANA, CHERRY)
println(s"Lengthy data: $lengthyData") // List(banana, cherry)
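A pattern on the left of `<-` behaves like a filter: elements that do not match are skipped rather than causing a MatchError. The sketch below, which assumes Scala 2.13 semantics (Scala 3 may ask for `case` before a refutable pattern), compares a pattern generator with the equivalent `collect` call. The sample data is a shortened, illustrative copy of the list used above.

```scala
case class Person(name: String, age: Int, city: String)

val people = List(
  Person("Alice", 25, "New York"),
  Person("Bob", 30, "San Francisco"),
  Person("Charlie", 35, "New York")
)

// Pattern generator: non-matching people are silently dropped
val viaFor = for {
  Person(name, _, "New York") <- people
} yield name

// The same selection written with collect and a partial function
val viaCollect = people.collect { case Person(name, _, "New York") => name }
```

Both values contain the same names, so the choice between the two forms is mostly one of readability.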
For-Comprehensions with Options
Safe Operations
// Working with Options to avoid null pointer exceptions
case class Address(street: String, city: String, zipCode: String)
case class Person(name: String, age: Int, address: Option[Address])
val people = List(
Person("Alice", 25, Some(Address("123 Main St", "New York", "10001"))),
Person("Bob", 30, None),
Person("Charlie", 35, Some(Address("456 Oak Ave", "San Francisco", "94102"))),
Person("Diana", 28, None)
)
// Extract cities safely
val cities = for {
person <- people
address <- person.address
} yield address.city
println(s"Cities: $cities") // List(New York, San Francisco)
// Complex validations with Options
def parseInt(s: String): Option[Int] = {
try Some(s.toInt) catch { case _: NumberFormatException => None }
}
def validateAge(age: Int): Option[Int] = {
if (age >= 0 && age <= 150) Some(age) else None
}
def validateEmail(email: String): Option[String] = {
if (email.contains("@")) Some(email.toLowerCase) else None
}
// Chain validations using for-comprehensions
def createUser(name: String, ageStr: String, email: String): Option[Person] = {
for {
age <- parseInt(ageStr)
validAge <- validateAge(age)
validEmail <- validateEmail(email)
} yield Person(name, validAge, None)
}
val userInputs = List(
("Alice", "25", "alice@example.com"),
("Bob", "not-a-number", "bob@test.com"),
("Charlie", "200", "charlie@example.com"),
("Diana", "30", "invalid-email")
)
userInputs.foreach { case (name, age, email) =>
createUser(name, age, email) match {
case Some(user) => println(s"✓ Created user: ${user.name}")
case None => println(s"✗ Failed to create user: $name")
}
}
// Combining multiple Options
def findPersonByName(name: String): Option[Person] = {
people.find(_.name == name)
}
def getPersonCity(name: String): Option[String] = {
for {
person <- findPersonByName(name)
address <- person.address
} yield address.city
}
def getPersonInfo(name: String): Option[String] = {
for {
person <- findPersonByName(name)
address <- person.address
} yield s"${person.name}, age ${person.age}, lives in ${address.city}"
}
val nameQueries = List("Alice", "Bob", "Charlie", "Unknown")
nameQueries.foreach { name =>
getPersonCity(name) match {
case Some(city) => println(s"$name lives in $city")
case None => println(s"City for $name not found")
}
}
// Working with nested Options
val nestedOptions = List(Some(Some(1)), Some(None), None, Some(Some(5)))
val flattenedValues = for {
outer <- nestedOptions
inner <- outer
value <- inner // Third generator unwraps the inner Option
} yield value
println(s"Flattened values: $flattenedValues") // List(1, 5)
// Option combinations with guards
def safeDivide(a: Double, b: Double): Option[Double] = {
if (b != 0) Some(a / b) else None
}
def safeSquareRoot(x: Double): Option[Double] = {
if (x >= 0) Some(math.sqrt(x)) else None
}
def complexCalculation(a: Double, b: Double): Option[Double] = {
for {
divided <- safeDivide(a, b)
sqrt <- safeSquareRoot(divided)
if sqrt > 1.0 // Guard condition
} yield sqrt * 2
}
val calculations = List(
(16.0, 4.0), // sqrt(16/4) = 2.0, * 2 = 4.0
(4.0, 4.0), // sqrt(4/4) = 1.0, fails guard
(16.0, 0.0), // Division by zero
(-16.0, 4.0) // Negative square root
)
calculations.foreach { case (a, b) =>
complexCalculation(a, b) match {
case Some(result) => println(f"calc($a, $b) = $result%.2f")
case None => println(s"calc($a, $b) = failed")
}
}
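The Option examples above rely on short-circuiting: once a generator produces None, the remaining generators are never evaluated. Here is a minimal sketch of that behavior. The `step` helper and the `calls` log are hypothetical names introduced only to record which steps actually run.

```scala
// Records the name of every step that gets evaluated
var calls = List.empty[String]

def step(name: String, result: Option[Int]): Option[Int] = {
  calls = calls :+ name
  result
}

val outcome = for {
  a <- step("first", Some(1))
  b <- step("second", None)     // the chain stops here
  c <- step("third", Some(3))   // never evaluated
} yield a + b + c
```

After this runs, `outcome` is empty and `calls` contains only the first two step names, because `flatMap` on None does not invoke the function it is given.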
For-Comprehensions with Either
Error Handling Composition
// Working with Either for error handling
type Result[T] = Either[String, T]
def validateName(name: String): Result[String] = {
if (name.trim.nonEmpty) Right(name.trim)
else Left("Name cannot be empty")
}
def validateAge(age: Int): Result[Int] = {
if (age >= 0 && age <= 150) Right(age)
else Left(s"Invalid age: $age")
}
def validateEmail(email: String): Result[String] = {
if (email.contains("@")) Right(email.toLowerCase)
else Left("Invalid email format")
}
// Chain validations using for-comprehensions
def validateUser(name: String, age: Int, email: String): Result[Person] = {
for {
validName <- validateName(name)
validAge <- validateAge(age)
validEmail <- validateEmail(email)
} yield Person(validName, validAge, None)
}
val userValidations = List(
("Alice", 25, "alice@example.com"),
("", 30, "bob@test.com"),
("Charlie", 200, "charlie@example.com"),
("Diana", 28, "invalid-email")
)
userValidations.foreach { case (name, age, email) =>
validateUser(name, age, email) match {
case Right(user) => println(s"✓ Valid user: ${user.name}")
case Left(error) => println(s"✗ Validation error: $error")
}
}
// Multiple error scenarios
def parseConfig(data: Map[String, String]): Result[Map[String, Any]] = {
for {
host <- data.get("host").toRight("Missing host")
portStr <- data.get("port").toRight("Missing port")
port <- Either.cond(portStr.nonEmpty && portStr.length <= 5 && portStr.forall(_.isDigit), portStr.toInt, "Invalid port format") // 1 to 5 digits, so toInt cannot throw
_ <- Either.cond(port > 0 && port <= 65535, (), "Port out of range")
} yield Map("host" -> host, "port" -> port)
}
val configTests = List(
Map("host" -> "localhost", "port" -> "8080"),
Map("host" -> "localhost"), // Missing port
Map("host" -> "localhost", "port" -> "abc"), // Invalid port
Map("host" -> "localhost", "port" -> "99999") // Port out of range
)
configTests.foreach { config =>
parseConfig(config) match {
case Right(parsed) => println(s"✓ Config: $parsed")
case Left(error) => println(s"✗ Config error: $error")
}
}
// Combining Either results
def fetchUserData(userId: Int): Result[String] = {
if (userId > 0) Right(s"User data for $userId")
else Left(s"Invalid user ID: $userId")
}
def fetchUserPreferences(userId: Int): Result[Map[String, Any]] = {
if (userId > 0) Right(Map("theme" -> "dark", "lang" -> "en"))
else Left(s"No preferences for user $userId")
}
def getUserProfile(userId: Int): Result[(String, Map[String, Any])] = {
for {
userData <- fetchUserData(userId)
preferences <- fetchUserPreferences(userId)
} yield (userData, preferences)
}
List(1, 0, 5, -1).foreach { userId =>
getUserProfile(userId) match {
case Right((data, prefs)) =>
println(s"✓ User $userId: $data, preferences: $prefs")
case Left(error) =>
println(s"✗ Failed to get profile for user $userId: $error")
}
}
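One consequence of chaining Either values this way is that the comprehension is fail-fast: it stops at the first Left, so later errors are never reported. The sketch below, assuming the right-biased Either of Scala 2.12 and later, shows that behavior next to a hand-rolled way of gathering every error. The validators are simplified copies of the ones above, and `allErrors` is an illustrative name, not a standard library feature.

```scala
def validateName(name: String): Either[String, String] =
  if (name.trim.nonEmpty) Right(name.trim) else Left("Name cannot be empty")

def validateAge(age: Int): Either[String, Int] =
  if (age >= 0 && age <= 150) Right(age) else Left(s"Invalid age: $age")

// Both inputs are invalid, but only the first error survives
val failFast: Either[String, (String, Int)] = for {
  n <- validateName("")
  a <- validateAge(200)
} yield (n, a)

// Running the checks independently keeps every error message
val allErrors: List[String] =
  List(validateName("").swap.toOption, validateAge(200).swap.toOption).flatten
```

If you need all validation messages at once, something like the second form (or a dedicated validation type from a library) is a better fit than a plain for-comprehension.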
Advanced For-Comprehension Patterns
Custom Monadic Types
// Creating a custom monadic type
case class Logged[A](value: A, log: List[String]) {
def map[B](f: A => B): Logged[B] =
Logged(f(value), log)
def flatMap[B](f: A => Logged[B]): Logged[B] = {
val result = f(value)
Logged(result.value, log ++ result.log)
}
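// A Logged always holds exactly one value, so a failed guard cannot drop it;
// the value is kept and the failure is only recorded in the log
def withFilter(p: A => Boolean): Logged[A] =
if (p(value)) this
else Logged(value, log :+ s"Filter failed for $value")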
}
object Logged {
def apply[A](value: A, message: String): Logged[A] =
Logged(value, List(message))
}
// Functions that return Logged values
def loggedAdd(a: Int, b: Int): Logged[Int] =
Logged(a + b, s"Added $a + $b = ${a + b}")
def loggedMultiply(a: Int, b: Int): Logged[Int] =
Logged(a * b, s"Multiplied $a * $b = ${a * b}")
def loggedDivide(a: Int, b: Int): Logged[Option[Int]] = {
if (b != 0) Logged(Some(a / b), s"Divided $a / $b = ${a / b}")
else Logged(None, s"Cannot divide $a by $b")
}
// Using for-comprehensions with custom monadic type
val loggedComputation = for {
x <- Logged(5, "Starting with 5")
y <- loggedAdd(x, 3)
z <- loggedMultiply(y, 2)
if z > 10
} yield z * z
println(s"Result: ${loggedComputation.value}")
println("Log:")
loggedComputation.log.foreach(entry => println(s" $entry"))
// Complex logged computation
def complexCalculation(a: Int, b: Int): Logged[Option[Int]] = {
for {
sum <- loggedAdd(a, b)
product <- loggedMultiply(sum, 2)
division <- loggedDivide(product, 4)
} yield division // division is already an Option[Int]; Option has no .value member
}
val result = complexCalculation(3, 7)
println(s"\nComplex calculation result: ${result.value}")
println("Computation log:")
result.log.foreach(entry => println(s" $entry"))
// State monad example
case class State[S, A](run: S => (S, A)) {
def map[B](f: A => B): State[S, B] =
State(s => {
val (newState, a) = run(s)
(newState, f(a))
})
def flatMap[B](f: A => State[S, B]): State[S, B] =
State(s => {
val (newState, a) = run(s)
f(a).run(newState)
})
}
object State {
def get[S]: State[S, S] = State(s => (s, s))
def put[S](newState: S): State[S, Unit] = State(_ => (newState, ()))
def modify[S](f: S => S): State[S, Unit] = State(s => (f(s), ()))
}
// Counter example using State monad
type Counter[A] = State[Int, A]
def increment: Counter[Int] = State(count => (count + 1, count + 1))
def decrement: Counter[Int] = State(count => (count - 1, count - 1))
def reset: Counter[Unit] = State(_ => (0, ()))
val counterProgram = for {
_ <- increment
_ <- increment
a <- increment
_ <- decrement
b <- State.get[Int]
} yield (a, b)
val (finalState, (a, b)) = counterProgram.run(0)
println(s"\nCounter program: a=$a, b=$b, final state=$finalState")
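A custom type needs only map and flatMap to take part in a for-comprehension without guards. As a rough illustration, the sketch below repeats a minimal State definition so that it is self-contained, then writes the same two-step program once with for syntax and once with explicit method calls. The names `viaFor` and `viaCalls` are illustrative.

```scala
// Minimal State type, repeated here so the sketch stands alone
case class State[S, A](run: S => (S, A)) {
  def map[B](f: A => B): State[S, B] =
    State { s =>
      val (next, a) = run(s)
      (next, f(a))
    }
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (next, a) = run(s)
      f(a).run(next)
    }
}

// Adds one to the counter and returns the new count
def increment: State[Int, Int] = State(n => (n + 1, n + 1))

// For-comprehension form
val viaFor: State[Int, (Int, Int)] = for {
  a <- increment
  b <- increment
} yield (a, b)

// Equivalent explicit calls: flatMap for the first step, map for the last
val viaCalls: State[Int, (Int, Int)] =
  increment.flatMap(a => increment.map(b => (a, b)))
```

Running either program from an initial state of 0 threads the counter through both steps and yields the same final state and result pair.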
For-Comprehensions with Multiple Types
// Combining different monadic types
import scala.util.{Try, Success, Failure}
// Helper to convert Try to Either
def tryToEither[T](t: Try[T]): Either[String, T] = t match {
case Success(value) => Right(value)
case Failure(exception) => Left(exception.getMessage)
}
// Mixed computations
def parseAndValidate(input: String): Either[String, Int] = {
for {
parsed <- tryToEither(Try(input.toInt))
validated <- if (parsed > 0) Right(parsed) else Left("Number must be positive")
} yield validated
}
def processNumbers(inputs: List[String]): List[Either[String, Int]] = {
for {
input <- inputs
} yield parseAndValidate(input)
}
val numberInputs = List("42", "-5", "not-a-number", "100", "0")
val processed = processNumbers(numberInputs)
processed.zip(numberInputs).foreach { case (result, input) =>
result match {
case Right(value) => println(s"'$input' -> $value")
case Left(error) => println(s"'$input' -> Error: $error")
}
}
// Combining Option and List
val maybeNumbers = List(Some(1), None, Some(3), Some(4), None)
val multipliers = List(2, 3, 5)
val products = for {
maybeNum <- maybeNumbers
num <- maybeNum.toList // Convert Option to List
multiplier <- multipliers
} yield num * multiplier
println(s"Products: $products") // List(2, 3, 5, 6, 9, 15, 8, 12, 20)
// Nested for-comprehensions
case class Department(name: String, employees: List[Person])
case class Company(name: String, departments: List[Department])
val company = Company("TechCorp", List(
Department("Engineering", List(
Person("Alice", 30, None),
Person("Bob", 35, None)
)),
Department("Marketing", List(
Person("Charlie", 28, None),
Person("Diana", 32, None)
))
))
val allEmployees = for {
dept <- company.departments
employee <- dept.employees
} yield (dept.name, employee.name)
val youngEmployees = for {
dept <- company.departments
employee <- dept.employees
if employee.age < 30
} yield s"${employee.name} from ${dept.name}"
println("All employees:")
allEmployees.foreach { case (dept, name) =>
println(s" $name ($dept)")
}
println(s"Young employees: $youngEmployees")
// Indexed enumeration: tag each element with the index of its row
val lists = List(
List(1, 2, 3),
List(4, 5, 6),
List(7, 8, 9)
)
val indexedElements = for {
(list, index) <- lists.zipWithIndex
element <- list
} yield (index, element)
println(s"Indexed enumeration: $indexedElements")
// List((0,1), (0,2), (0,3), (1,4), (1,5), (1,6), (2,7), (2,8), (2,9))
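All generators in one for-comprehension must share a compatible container type, so Option and Try values are usually converted before they are mixed with Either. As an alternative to a hand-written helper, the standard library offers `Option#toRight` and `Try#toEither` (the latter since Scala 2.12). The sketch below uses both; `settings` and `readInt` are hypothetical names made up for this example.

```scala
import scala.util.Try

// Illustrative configuration data
val settings = Map("port" -> "8080", "retries" -> "three")

def readInt(key: String): Either[String, Int] = for {
  // Option => Either: supply the error to use when the key is absent
  raw <- settings.get(key).toRight(s"Missing key: $key")
  // Try => Either: replace the Throwable on the left with a message
  parsed <- Try(raw.toInt).toEither.left.map(_ => s"Not a number: $raw")
} yield parsed
```

Here `readInt("port")` succeeds, `readInt("retries")` fails at the parsing step, and `readInt("host")` fails at the lookup step, each with its own message.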
Performance and Best Practices
Efficient For-Comprehensions
import scala.util.Random
// Performance considerations
def timeForComprehension[T](name: String)(operation: => T): T = {
val start = System.nanoTime()
val result = operation
val end = System.nanoTime()
println(f"$name: ${(end - start) / 1e6}%.2f ms")
result
}
val largeList = (1 to 100000).toList
// Efficient filtering and mapping
timeForComprehension("For-comprehension with early filtering") {
val result = for {
n <- largeList
if n % 1000 == 0 // Filter early
doubled = n * 2 // Intermediate value
} yield doubled * doubled
result.length
}
timeForComprehension("Traditional approach") {
val result = largeList
.filter(_ % 1000 == 0)
.map(n => {
val doubled = n * 2
doubled * doubled
})
result.length
}
// View for lazy evaluation
timeForComprehension("For-comprehension with view") {
val result = for {
n <- largeList.view
if n % 1000 == 0
doubled = n * 2
} yield doubled * doubled
result.take(10).toList.length
}
// Best practices demonstration
case class Product(id: Int, name: String, price: Double, category: String)
case class Order(id: Int, productId: Int, quantity: Int, customerId: Int)
case class Customer(id: Int, name: String, city: String)
val products = (1 to 1000).map(i =>
Product(i, s"Product $i", Random.nextDouble() * 100,
List("Electronics", "Books", "Clothing")(i % 3))
).toList
val orders = (1 to 5000).map(i =>
Order(i, Random.nextInt(1000) + 1, Random.nextInt(10) + 1, Random.nextInt(100) + 1)
).toList
val customers = (1 to 100).map(i =>
Customer(i, s"Customer $i", List("New York", "San Francisco", "Chicago")(i % 3))
).toList
// Naive join: nested generators scan products and customers for every order
timeForComprehension("For-comprehension join") {
val result = for {
order <- orders.view
product <- products
if product.id == order.productId
if product.category == "Electronics"
customer <- customers
if customer.id == order.customerId
} yield (order.id, product.name, customer.name, order.quantity * product.price)
result.take(100).toList.length
}
// More efficient with Maps for lookups
val productMap = products.map(p => p.id -> p).toMap
val customerMap = customers.map(c => c.id -> c).toMap
timeForComprehension("Optimized with Maps") {
val result = for {
order <- orders
product <- productMap.get(order.productId)
if product.category == "Electronics"
customer <- customerMap.get(order.customerId)
} yield (order.id, product.name, customer.name, order.quantity * product.price)
result.length
}
// Guidelines for readable for-comprehensions
def processComplexData(): List[String] = {
for {
// Use meaningful variable names
customer <- customers
// Group related conditions
if customer.city == "New York"
// Extract intermediate values when complex
customerOrders = orders.filter(_.customerId == customer.id)
// Continue with clear, single-purpose steps
order <- customerOrders
product <- productMap.get(order.productId).toList
// Final transformations
orderValue = order.quantity * product.price
if orderValue > 50
} yield f"${customer.name}: ${product.name} ($$${orderValue}%.2f)"
}
val complexResults = processComplexData()
println(s"Complex query results: ${complexResults.take(5)}")
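Timings like the ones above vary from run to run, so counting work can be a more dependable way to see what a view saves. The sketch below counts how often a guard is evaluated with and without `.view` when only the first two results are needed. The helper `countingPredicate` and the counters are hypothetical names introduced for this illustration.

```scala
import java.util.concurrent.atomic.AtomicInteger

// Guard that records how many times it has been evaluated
def countingPredicate(counter: AtomicInteger)(n: Int): Boolean = {
  counter.incrementAndGet()
  n % 100 == 0
}

val source = (1 to 1000).toList
val strictChecks = new AtomicInteger(0)
val lazyChecks = new AtomicInteger(0)

// Strict: the whole list is filtered and mapped before take(2)
val strictResult = (for {
  n <- source
  if countingPredicate(strictChecks)(n)
} yield n * 2).take(2)

// Lazy: the view evaluates only as many elements as take(2) demands
val lazyResult = (for {
  n <- source.view
  if countingPredicate(lazyChecks)(n)
} yield n * 2).take(2).toList
```

Both results are identical, but the strict version checks all 1000 elements while the view stops soon after the second match. Views pay off when you consume only part of the output; for full traversals the strict form is usually simpler and just as fast.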
Summary
In this lesson, you've mastered for-comprehensions and monadic composition:
✅ Basic For-Comprehensions: Syntax and simple transformations
✅ Pattern Matching: Deconstructing data within for-expressions
✅ Option Composition: Safe chaining of operations on optional values
✅ Either Composition: Elegant error handling pipelines
✅ Custom Monads: Creating your own composable types
✅ Performance: Writing efficient for-comprehensions
✅ Best Practices: Readable and maintainable code patterns
For-comprehensions provide a powerful abstraction for working with monadic types, making complex data transformations readable and composable.
What's Next
In the next lesson, we'll explore lazy evaluation and streams, learning how to work with potentially infinite data structures and optimize performance through deferred computation.