Traits: The Building Blocks of Behavior

Introduction

Traits are Scala's mechanism for sharing behavior across classes. A class can extend only one superclass, but it can mix in any number of traits, which makes traits the natural building blocks for reusable, modular design. Think of a trait as an interface that can contain abstract method declarations, concrete implementations, and fields.

Traits avoid the ambiguity of the "diamond problem" through linearization: a class's ancestors are arranged in a single, well-defined order, so every super call has exactly one target. Traits also provide clean separation of concerns and enable design patterns such as strategy, decorator, and mixin composition. Understanding traits is essential for writing flexible, maintainable Scala code.

Basic Trait Definition

Simple Traits

// Basic trait with abstract methods
trait Drawable {
  def draw(): String
  def color: String

  // Concrete method with default implementation
  def display(): Unit = {
    println(s"Drawing ${color} shape: ${draw()}")
  }

  // Method with implementation that uses abstract methods
  def description: String = s"A ${color} drawable object"
}

// Trait with only concrete methods
trait Timestamped {
  private val _createdAt: java.time.LocalDateTime = java.time.LocalDateTime.now()

  def createdAt: java.time.LocalDateTime = _createdAt
  def age: java.time.Duration = java.time.Duration.between(_createdAt, java.time.LocalDateTime.now())

  def isOlderThan(minutes: Int): Boolean = {
    age.toMinutes > minutes
  }
}

// Class implementing a trait
class Circle(radius: Double, val color: String) extends Drawable {
  def draw(): String = s"Circle with radius $radius"

  def area: Double = math.Pi * radius * radius
}

// Class mixing in multiple traits
class Rectangle(width: Double, height: Double, val color: String) 
  extends Drawable with Timestamped {

  def draw(): String = s"Rectangle ${width}x${height}"

  def area: Double = width * height
}

// Usage
val circle = new Circle(5.0, "red")
circle.display()  // Drawing red shape: Circle with radius 5.0
println(circle.description)  // A red drawable object

val rect = new Rectangle(10.0, 5.0, "blue")
rect.display()  // Drawing blue shape: Rectangle 10.0x5.0
println(s"Created: ${rect.createdAt}")
println(s"Area: ${rect.area}")

Thread.sleep(1000)
println(s"Is older than 0 minutes: ${rect.isOlderThan(0)}")  // false: only about 1 second has passed, and toMinutes truncates that to 0
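Traits can also be mixed into a single instance when it is created, rather than into a class declaration. The minimal sketch below assumes simplified copies of `Drawable`, `Timestamped`, and `Circle` (using `java.time.Instant` instead of `LocalDateTime`) so that it runs on its own; `stampedCircle` and `plainCircle` are illustrative names.

```scala
import java.time.{Duration, Instant}

// Simplified stand-ins for the lesson's traits (assumed, not the originals).
trait Drawable {
  def draw(): String
  def color: String
  def description: String = s"A $color drawable object"
}

trait Timestamped {
  private val createdAt: Instant = Instant.now()
  def age: Duration = Duration.between(createdAt, Instant.now())
}

class Circle(radius: Double, val color: String) extends Drawable {
  def draw(): String = s"Circle with radius $radius"
}

// Mix Timestamped into this one instance only; the Circle class is unchanged.
val stampedCircle = new Circle(2.0, "green") with Timestamped
val plainCircle = new Circle(2.0, "green")

println(stampedCircle.description)               // A green drawable object
println(stampedCircle.isInstanceOf[Timestamped]) // true
println(plainCircle.isInstanceOf[Timestamped])   // false
```

Only `stampedCircle` carries the timestamp behavior; other `Circle` instances are unaffected.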

Traits with Type Parameters

trait Container[T] {
  def add(item: T): Unit
  def remove(item: T): Boolean
  def contains(item: T): Boolean
  def size: Int
  def isEmpty: Boolean = size == 0
  def nonEmpty: Boolean = !isEmpty

  // Abstract method that subclasses must implement
  def items: List[T]

  // Default implementation using abstract method
  def foreach(f: T => Unit): Unit = items.foreach(f)
  def map[U](f: T => U): List[U] = items.map(f)
  def filter(predicate: T => Boolean): List[T] = items.filter(predicate)
}

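// Note: this hand-written trait shadows scala.math.Ordered, which already
// provides the same comparison operators.
trait Ordered[T] {
  def compare(other: T): Int

  def <(other: T): Boolean = compare(other) < 0
  def <=(other: T): Boolean = compare(other) <= 0
  def >(other: T): Boolean = compare(other) > 0
  def >=(other: T): Boolean = compare(other) >= 0

  // == and != are final members of Any. An overload taking a generic T erases
  // to the same signature and fails to compile, so equality gets its own name.
  def equiv(other: T): Boolean = compare(other) == 0
}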

// Implementation
class SimpleList[T] extends Container[T] {
  private var _items: List[T] = List.empty

  def add(item: T): Unit = _items = _items :+ item
  def remove(item: T): Boolean = {
    if (_items.contains(item)) {
      _items = _items.filterNot(_ == item)
      true
    } else false
  }
  def contains(item: T): Boolean = _items.contains(item)
  def size: Int = _items.length
  def items: List[T] = _items
}

class Priority(val value: Int, val description: String) extends Ordered[Priority] {
  def compare(other: Priority): Int = value.compare(other.value)

  override def toString: String = s"Priority($value, $description)"
}

// Usage
val list = new SimpleList[String]()
list.add("apple")
list.add("banana")
list.add("cherry")

println(s"Size: ${list.size}")
println(s"Contains banana: ${list.contains("banana")}")
list.foreach(println)

val mapped = list.map(_.toUpperCase)
println(s"Mapped: $mapped")

val priority1 = new Priority(1, "High")
val priority2 = new Priority(3, "Low")
val priority3 = new Priority(2, "Medium")

println(s"${priority1} > ${priority2}: ${priority1 > priority2}")  // false: compare uses the raw values, and 1 < 3
println(s"${priority2} <= ${priority3}: ${priority2 <= priority3}")  // false
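The standard library already ships a trait with these operators, `scala.math.Ordered`. The sketch below assumes a standalone variant of `Priority` built on the library trait instead of the hand-written one. Because `scala.math.Ordered` extends `java.lang.Comparable`, the standard library can derive an `Ordering[Priority]`, so collection methods such as `sorted` and `max` work without extra code.

```scala
// Standalone variant of Priority using the standard library's Ordered trait.
class Priority(val value: Int, val description: String) extends scala.math.Ordered[Priority] {
  def compare(that: Priority): Int = value.compare(that.value)
  override def toString: String = s"Priority($value, $description)"
}

val high = new Priority(1, "High")
val medium = new Priority(2, "Medium")
val low = new Priority(3, "Low")

println(high < low)  // true: 1 < 3

// An Ordering[Priority] is derived from Ordered, so sorted and max just work.
println(List(low, high, medium).sorted.map(_.description))  // List(High, Medium, Low)
println(List(low, high, medium).max)                        // Priority(3, Low)
```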

Mixin Composition

Stacking Traits

trait Logger {
  def log(message: String): Unit = {
    println(s"[LOG] $message")
  }
}

trait TimestampedLogger extends Logger {
  override def log(message: String): Unit = {
    val timestamp = java.time.LocalDateTime.now()
    super.log(s"[$timestamp] $message")
  }
}

trait EncryptedLogger extends Logger {
  override def log(message: String): Unit = {
    // Toy "encryption" for illustration: shifts each character by one code point
    val encrypted = s"ENCRYPTED[${message.map(c => (c + 1).toChar).mkString}]"
    super.log(encrypted)
  }
}

trait FileLogger extends Logger {
  private val logFile = "application.log"

  override def log(message: String): Unit = {
    super.log(message)  // Still call the parent logger
    // A real implementation would write to the file here
    println(s"[FILE] Writing to $logFile: $message")
  }
}

class Application

// Different mixing orders create different behavior
class App1 extends Application with Logger with TimestampedLogger with EncryptedLogger {
  def start(): Unit = {
    log("Application starting")
  }
}

class App2 extends Application with Logger with EncryptedLogger with TimestampedLogger {
  def start(): Unit = {
    log("Application starting")
  }
}

class App3 extends Application with Logger with TimestampedLogger with FileLogger {
  def start(): Unit = {
    log("Application starting")
  }
}

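// Test different linearizations. The rightmost trait's log runs first and
// passes its result to the trait on its left through super.log.
println("=== App1 (mixin order: Logger, TimestampedLogger, EncryptedLogger) ===")
new App1().start()  // encrypts first, then prefixes the timestamp

println("\n=== App2 (mixin order: Logger, EncryptedLogger, TimestampedLogger) ===")
new App2().start()  // prefixes the timestamp first, then encrypts the whole line

println("\n=== App3 (mixin order: Logger, TimestampedLogger, FileLogger) ===")
new App3().start()  // logs with a timestamp, then reports the file write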
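Because the timestamped output changes on every run, the sketch below uses two hypothetical, deterministic loggers, `Tagged` and `Reversed`, that return the final message instead of printing it. It shows that the rightmost trait's `log` runs first and hands its result to the trait on its left through `super`.

```scala
// Hypothetical base logger that returns the message it would write.
class PlainLogger {
  def log(message: String): String = message
}

trait Tagged extends PlainLogger {
  override def log(message: String): String = super.log(s"tag:$message")
}

trait Reversed extends PlainLogger {
  override def log(message: String): String = super.log(message.reverse)
}

// Linearization: anonymous class -> Reversed -> Tagged -> PlainLogger
val reverseThenTag = new PlainLogger with Tagged with Reversed
// Linearization: anonymous class -> Tagged -> Reversed -> PlainLogger
val tagThenReverse = new PlainLogger with Reversed with Tagged

println(reverseThenTag.log("abc"))  // tag:cba
println(tagThenReverse.log("abc"))  // cba:gat
```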

Diamond Problem Resolution

trait A {
  def method(): String = "A"
}

trait B extends A {
  override def method(): String = "B calls " + super.method()
}

trait C extends A {
  override def method(): String = "C calls " + super.method()
}

// Diamond inheritance: D inherits from both B and C, which both inherit from A
class D extends B with C {
  override def method(): String = "D calls " + super.method()
}

// Scala resolves this through linearization
val d = new D()
println(d.method())  // D calls C calls B calls A

// Linearization order: D -> C -> B -> A (right-to-left mixin order)
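Linearization also means a shared ancestor appears only once, so its initialization code runs once even in a diamond. The sketch below uses hypothetical traits `Root`, `LeftBranch`, and `RightBranch` and a class `Leaf` (mirroring `A`, `B`, `C`, and `D`) that record when their bodies run. Trait and class bodies execute in reverse linearization order.

```scala
import scala.collection.mutable.ListBuffer

// Records the order in which trait and class bodies (constructors) run.
object InitLog {
  val steps: ListBuffer[String] = ListBuffer.empty
}

trait Root { InitLog.steps += "Root" }
trait LeftBranch extends Root { InitLog.steps += "LeftBranch" }
trait RightBranch extends Root { InitLog.steps += "RightBranch" }
class Leaf extends LeftBranch with RightBranch { InitLog.steps += "Leaf" }

val leaf = new Leaf()
// Linearization is Leaf -> RightBranch -> LeftBranch -> Root; bodies run in
// reverse, and Root's body runs exactly once.
println(InitLog.steps.mkString(" -> "))  // Root -> LeftBranch -> RightBranch -> Leaf
```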

// More complex example
trait Database {
  def save(data: String): String = s"Saving '$data' to database"
}

trait Cached extends Database {
  private val cache = scala.collection.mutable.Map[String, String]()

  override def save(data: String): String = {
    if (cache.contains(data)) {
      s"Found '$data' in cache"
    } else {
      val result = super.save(data)
      cache(data) = result
      result
    }
  }
}

trait Logged extends Database {
  override def save(data: String): String = {
    println(s"[LOG] Attempting to save: $data")
    val result = super.save(data)
    println(s"[LOG] Save result: $result")
    result
  }
}

trait Validated extends Database {
  override def save(data: String): String = {
    if (data.trim.isEmpty) {
      "Error: Cannot save empty data"
    } else {
      super.save(data)
    }
  }
}

class Repository extends Database with Cached with Logged with Validated

val repo = new Repository()
println("=== First save ===")
println(repo.save("user-data"))

println("\n=== Second save (should hit cache) ===")
println(repo.save("user-data"))

println("\n=== Invalid save ===")
println(repo.save(""))

// Linearization: Repository -> Validated -> Logged -> Cached -> Database

Self Types and Dependencies

Self Type Annotations

trait Database {
  def save(key: String, value: String): Unit
  def load(key: String): Option[String]
}

trait Cache {
  def put(key: String, value: String): Unit
  def get(key: String): Option[String]
}

// UserService requires both Database and Cache to be mixed in
trait UserService { 
  self: Database with Cache =>  // Self type annotation

  def saveUser(id: String, userData: String): Unit = {
    // Can use methods from Database and Cache
    save(s"user:$id", userData)  // from Database
    put(s"cache:user:$id", userData)  // from Cache
    println(s"User $id saved")
  }

  def getUser(id: String): Option[String] = {
    // Try cache first, then database
    get(s"cache:user:$id") match {
      case Some(data) => 
        println(s"User $id found in cache")
        Some(data)
      case None => 
        load(s"user:$id") match {
          case Some(data) =>
            println(s"User $id loaded from database, caching...")
            put(s"cache:user:$id", data)
            Some(data)
          case None =>
            println(s"User $id not found")
            None
        }
    }
  }
}

// Implementations
class SimpleDatabase extends Database {
  private val storage = scala.collection.mutable.Map[String, String]()

  def save(key: String, value: String): Unit = {
    storage(key) = value
    println(s"[DB] Saved $key")
  }

  def load(key: String): Option[String] = {
    val result = storage.get(key)
    println(s"[DB] Loading $key: ${result.isDefined}")
    result
  }
}

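// A trait rather than a class: UserManager below already extends the class
// SimpleDatabase, and only traits can be mixed in with `with`
trait SimpleCache extends Cache {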
  private val cache = scala.collection.mutable.Map[String, String]()

  def put(key: String, value: String): Unit = {
    cache(key) = value
    println(s"[CACHE] Cached $key")
  }

  def get(key: String): Option[String] = {
    val result = cache.get(key)
    println(s"[CACHE] Looking up $key: ${result.isDefined}")
    result
  }
}

// This works - provides both required traits
class UserManager extends SimpleDatabase with SimpleCache with UserService

// This would fail to compile - missing Cache requirement
// class IncompleteUserManager extends SimpleDatabase with UserService

val userManager = new UserManager()
userManager.saveUser("123", "John Doe")
userManager.getUser("123")  // Cache hit
userManager.getUser("456")  // Not found
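A self type states a requirement without making `UserService` a subtype of `Database` or `Cache`, so any implementation that satisfies the requirement can be plugged in. The sketch below assumes a hypothetical `Clock` dependency and two interchangeable clock traits; the fixed hours are illustrative values.

```scala
// Hypothetical dependency expressed through a self type.
trait Clock {
  def hour: Int
}

trait Greeter {
  self: Clock =>  // any concrete Greeter must also be a Clock
  def greeting: String = if (hour < 12) "Good morning" else "Good afternoon"
}

// Two interchangeable ways to satisfy the same requirement.
trait MorningClock extends Clock { def hour: Int = 9 }
trait AfternoonClock extends Clock { def hour: Int = 15 }

val morningGreeter = new Greeter with MorningClock {}
val afternoonGreeter = new Greeter with AfternoonClock {}

println(morningGreeter.greeting)    // Good morning
println(afternoonGreeter.greeting)  // Good afternoon
```

Swapping the clock trait changes behavior without touching `Greeter`, which is the same property that lets `UserService` work with any `Database` and `Cache` pair.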

Abstract Types in Traits

trait Serializer {
  type Data
  type Serialized

  def serialize(data: Data): Serialized
  def deserialize(serialized: Serialized): Data

  // Default implementation using abstract types
  def roundTrip(data: Data): Data = {
    deserialize(serialize(data))
  }
}

trait Compressor {
  type Input
  type Output

  def compress(input: Input): Output
  def decompress(output: Output): Input
}

// JSON serializer implementation
class JsonSerializer extends Serializer {
  type Data = Map[String, Any]
  type Serialized = String

  def serialize(data: Data): String = {
    // Simplified JSON serialization
    data.map { case (k, v) => s""""$k": "$v"""" }.mkString("{", ", ", "}")
  }

  def deserialize(json: String): Map[String, Any] = {
    // Simplified JSON deserialization
    val pattern = """"([^"]+)":\s*"([^"]+)"""".r
    pattern.findAllMatchIn(json).map { m =>
      m.group(1) -> m.group(2)
    }.toMap
  }
}

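// String compressor implementation. It is a trait so that it can be mixed into
// SerializingCompressor below; despite the name, it does not use real gzip.
trait GzipCompressor extends Compressor {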
  type Input = String
  type Output = Array[Byte]

  def compress(input: String): Array[Byte] = {
    // Simplified compression (just convert to bytes)
    s"COMPRESSED[${input}]".getBytes
  }

  def decompress(output: Array[Byte]): String = {
    // Simplified decompression
    new String(output).stripPrefix("COMPRESSED[").stripSuffix("]")
  }
}

// Combining serialization and compression
class SerializingCompressor extends JsonSerializer with GzipCompressor {
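  // No conflicts to resolve: Serializer's type members (Data, Serialized) and
  // Compressor's (Input, Output) have distinct names. These aliases only add readability.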
  type DataToSerialize = Map[String, Any]
  type SerializedData = String
  type CompressedOutput = Array[Byte]

  def process(data: DataToSerialize): CompressedOutput = {
    val json = serialize(data)
    compress(json)
  }

  def restore(compressed: CompressedOutput): DataToSerialize = {
    val json = decompress(compressed)
    deserialize(json)
  }
}

val processor = new SerializingCompressor()
val data = Map("name" -> "Alice", "age" -> "30", "city" -> "San Francisco")

val compressed = processor.process(data)
println(s"Compressed: ${new String(compressed)}")

val restored = processor.restore(compressed)
println(s"Restored: $restored")
println(s"Round trip successful: ${data == restored}")
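Abstract type members can also be referred to from outside a trait through a path-dependent type such as `c.Value`, where `c` is a specific instance. The sketch assumes a hypothetical `Codec` trait in the spirit of `Serializer`; the compiler ties each call's argument and result types to the codec that is passed in.

```scala
// Hypothetical codec: each implementation fixes its own Value type.
trait Codec {
  type Value
  def encode(value: Value): String
  def decode(text: String): Value
}

object IntCodec extends Codec {
  type Value = Int
  def encode(value: Int): String = value.toString
  def decode(text: String): Int = text.toInt
}

object FlagCodec extends Codec {
  type Value = Boolean
  def encode(value: Boolean): String = if (value) "1" else "0"
  def decode(text: String): Boolean = text == "1"
}

// The parameter and result type c.Value depend on which codec is passed in.
def roundTrip(c: Codec)(value: c.Value): c.Value = c.decode(c.encode(value))

val n: Int = roundTrip(IntCodec)(42)
val flag: Boolean = roundTrip(FlagCodec)(true)
println(n)     // 42
println(flag)  // true
// roundTrip(IntCodec)(true) would not compile: IntCodec.Value is Int
```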

Advanced Trait Patterns

Strategy Pattern with Traits

trait SortingStrategy[T] {
  def sort(items: List[T])(implicit ordering: Ordering[T]): List[T]
  def name: String
}

trait BubbleSort[T] extends SortingStrategy[T] {
  def sort(items: List[T])(implicit ordering: Ordering[T]): List[T] = {
    println(s"Using $name")
    // Simplified bubble sort implementation
    def bubble(list: List[T]): List[T] = list match {
      case Nil => Nil
      case x :: Nil => List(x)
      case x :: y :: rest if ordering.gt(x, y) => y :: bubble(x :: rest)
      case x :: rest => x :: bubble(rest)
    }

    def bubbleSort(list: List[T], passes: Int): List[T] = {
      if (passes <= 0) list
      else bubbleSort(bubble(list), passes - 1)
    }

    bubbleSort(items, items.length)
  }

  def name: String = "Bubble Sort"
}

trait QuickSort[T] extends SortingStrategy[T] {
  def sort(items: List[T])(implicit ordering: Ordering[T]): List[T] = {
    println(s"Using $name")

    def quickSort(list: List[T]): List[T] = list match {
      case Nil => Nil
      case head :: tail =>
        val (smaller, larger) = tail.partition(ordering.lt(_, head))
        quickSort(smaller) ++ List(head) ++ quickSort(larger)
    }

    quickSort(items)
  }

  def name: String = "Quick Sort"
}

trait MergeSort[T] extends SortingStrategy[T] {
  def sort(items: List[T])(implicit ordering: Ordering[T]): List[T] = {
    println(s"Using $name")

    def merge(left: List[T], right: List[T]): List[T] = (left, right) match {
      case (Nil, r) => r
      case (l, Nil) => l
      case (l :: ls, r :: rs) =>
        if (ordering.lteq(l, r)) l :: merge(ls, right)
        else r :: merge(left, rs)
    }

    def mergeSort(list: List[T]): List[T] = {
      if (list.length <= 1) list
      else {
        val (left, right) = list.splitAt(list.length / 2)
        merge(mergeSort(left), mergeSort(right))
      }
    }

    mergeSort(items)
  }

  def name: String = "Merge Sort"
}

// Sorter that can switch strategies
class AdaptiveSorter[T] {
  def sortWith(items: List[T], strategy: SortingStrategy[T])(implicit ordering: Ordering[T]): List[T] = {
    strategy.sort(items)
  }
}

// Different sorting implementations
object BubbleSorter extends BubbleSort[Int]
object QuickSorter extends QuickSort[Int]  
object MergeSorter extends MergeSort[Int]

val numbers = List(64, 34, 25, 12, 22, 11, 90)
val sorter = new AdaptiveSorter[Int]()

println("Original list: " + numbers)
println()

val bubbleResult = sorter.sortWith(numbers, BubbleSorter)
println(s"Bubble sort result: $bubbleResult")

val quickResult = sorter.sortWith(numbers, QuickSorter)
println(s"Quick sort result: $quickResult")

val mergeResult = sorter.sortWith(numbers, MergeSorter)
println(s"Merge sort result: $mergeResult")
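`AdaptiveSorter` takes the strategy as an argument, but a sorter can also choose one itself at runtime. The sketch below is self-contained and hypothetical: `SortStrategy`, `InsertionSort`, `LibrarySort`, and the threshold of 16 elements are illustrative assumptions, not tuned values. `LibrarySort` simply delegates to the standard library's `sorted`.

```scala
// Hypothetical strategy interface, specialized to Int for brevity.
trait SortStrategy {
  def name: String
  def sort(items: List[Int]): List[Int]
}

object InsertionSort extends SortStrategy {
  val name = "Insertion Sort"
  def sort(items: List[Int]): List[Int] = {
    // Insert each element into an already-sorted accumulator.
    def insert(x: Int, sorted: List[Int]): List[Int] = sorted match {
      case head :: tail if head < x => head :: insert(x, tail)
      case _                        => x :: sorted
    }
    items.foldLeft(List.empty[Int])((acc, x) => insert(x, acc))
  }
}

object LibrarySort extends SortStrategy {
  val name = "Library Sort"
  def sort(items: List[Int]): List[Int] = items.sorted  // standard library sort
}

object AdaptiveChooser {
  val SmallInputThreshold = 16  // hypothetical cutoff, not a tuned value

  def choose(items: List[Int]): SortStrategy =
    if (items.length <= SmallInputThreshold) InsertionSort else LibrarySort

  def sort(items: List[Int]): List[Int] = choose(items).sort(items)
}

val small = List(64, 34, 25, 12, 22, 11, 90)
println(AdaptiveChooser.choose(small).name)  // Insertion Sort
println(AdaptiveChooser.sort(small))         // List(11, 12, 22, 25, 34, 64, 90)
```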

Decorator Pattern with Traits

trait Component {
  def operation(): String
}

class ConcreteComponent extends Component {
  def operation(): String = "Basic operation"
}

trait Decorator extends Component {
  protected val component: Component
  def operation(): String = component.operation()
}

trait BorderDecorator extends Decorator {
  abstract override def operation(): String = {
    val result = super.operation()
    s"[$result]"
  }
}

trait TimestampDecorator extends Decorator {
  abstract override def operation(): String = {
    val timestamp = java.time.LocalTime.now().toString
    val result = super.operation()
    s"$timestamp: $result"
  }
}

trait UpperCaseDecorator extends Decorator {
  abstract override def operation(): String = {
    super.operation().toUpperCase
  }
}

trait ColorDecorator extends Decorator {
  def color: String

  abstract override def operation(): String = {
    val result = super.operation()
    s"$color($result)"
  }
}

// Creating decorated components
class RedColorDecorator(val component: Component) extends ColorDecorator {
  def color: String = "RED"
}

class BlueColorDecorator(val component: Component) extends ColorDecorator {
  def color: String = "BLUE"
}

// Usage - different decoration combinations
val basic = new ConcreteComponent()
println(s"Basic: ${basic.operation()}")

// Single decorations. The stackable traits are mixed into Decorator, which
// delegates to `component`. Mixing them into ConcreteComponent would not compile:
// Decorator's operation() would then override ConcreteComponent's without `override`.
val bordered = new Decorator with BorderDecorator {
  val component = new ConcreteComponent()
}
println(s"Bordered: ${bordered.operation()}")  // Bordered: [Basic operation]

// Multiple decorations - order matters: the rightmost trait runs first
val component1 = new Decorator
  with BorderDecorator
  with TimestampDecorator
  with UpperCaseDecorator {
  val component = new ConcreteComponent()
}
// UpperCase(Timestamp(Border(...))) produces "<time>: [BASIC OPERATION]"
println(s"Multi-decorated 1: ${component1.operation()}")

val component2 = new Decorator
  with TimestampDecorator
  with BorderDecorator
  with UpperCaseDecorator {
  val component = new ConcreteComponent()
}
// UpperCase(Border(Timestamp(...))) produces "[<time>: BASIC OPERATION]"
println(s"Multi-decorated 2: ${component2.operation()}")

// Using concrete decorator classes
val redDecorated = new RedColorDecorator(basic)
println(s"Red decorated: ${redDecorated.operation()}")  // Red decorated: RED(Basic operation)

val complexDecorated = new BlueColorDecorator(
  new Decorator with BorderDecorator with UpperCaseDecorator {
    val component = new ConcreteComponent()
  }
)
println(s"Complex decorated: ${complexDecorated.operation()}")  // Complex decorated: BLUE([BASIC OPERATION])
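The decorator traits above still delegate to a separate `component`. A common simplification, often called the stackable trait pattern, drops that field: `abstract override` methods call `super` directly and are mixed into a concrete base class. The sketch below assumes hypothetical `Exclaim` and `Repeat` traits whose outputs make the effect of mixin order easy to see.

```scala
trait Component {
  def operation(): String
}

// The concrete base that the stackable traits ultimately delegate to.
class BasicComponent extends Component {
  def operation(): String = "basic"
}

// abstract override: these traits need a concrete operation() to their left.
trait Exclaim extends Component {
  abstract override def operation(): String = super.operation() + "!"
}

trait Repeat extends Component {
  abstract override def operation(): String = super.operation() * 2
}

// The rightmost trait is the outermost wrapper.
val exclaimThenRepeat = new BasicComponent with Exclaim with Repeat
val repeatThenExclaim = new BasicComponent with Repeat with Exclaim

println(exclaimThenRepeat.operation())  // basic!basic!
println(repeatThenExclaim.operation())  // basicbasic!
```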

Practical Examples

Plugin System with Traits

trait Plugin {
  def name: String
  def version: String
  def initialize(): Unit
  def shutdown(): Unit
  def isEnabled: Boolean = true
}

trait ConfigurablePlugin extends Plugin {
  type Config
  def defaultConfig: Config
  def configure(config: Config): Unit

  private var _config: Config = defaultConfig

  def currentConfig: Config = _config

  // abstract override: super.initialize() is resolved at mix-in time and must
  // reach a concrete initialize() from a class to the left of this trait
  abstract override def initialize(): Unit = {
    super.initialize()
    configure(_config)
  }
}

trait LoggingPlugin extends Plugin {
  abstract override def initialize(): Unit = {
    println(s"[PLUGIN] Initializing ${name} v${version}")
    super.initialize()
    println(s"[PLUGIN] ${name} initialized successfully")
  }

  abstract override def shutdown(): Unit = {
    println(s"[PLUGIN] Shutting down ${name}")
    super.shutdown()
    println(s"[PLUGIN] ${name} shutdown complete")
  }
}

trait MetricsPlugin extends Plugin {
  private var _startTime: Long = 0
  private var _operationCount: Long = 0

  abstract override def initialize(): Unit = {
    _startTime = System.currentTimeMillis()
    super.initialize()
  }

  def recordOperation(): Unit = {
    _operationCount += 1
  }

  def uptime: Long = System.currentTimeMillis() - _startTime
  def operationCount: Long = _operationCount

  def metrics: String = s"${name}: uptime=${uptime}ms, operations=${operationCount}"
}

// Concrete plugin implementations
case class EmailConfig(smtpHost: String, port: Int, username: String)

// The core behavior lives in a base class so that the stackable traits mixed in
// after it can wrap initialize() and shutdown() through super calls. Defining
// initialize() directly in EmailPlugin would override, and bypass, the traits.
abstract class EmailCore extends Plugin {
  def name: String = "Email Service"
  def version: String = "1.2.0"

  def initialize(): Unit = {
    println("Starting email service...")
  }

  def shutdown(): Unit = {
    println("Stopping email service...")
  }
}

class EmailPlugin extends EmailCore with ConfigurablePlugin with LoggingPlugin with MetricsPlugin {
  type Config = EmailConfig

  def defaultConfig: EmailConfig = EmailConfig("localhost", 587, "user")

  def configure(config: EmailConfig): Unit = {
    println(s"Configuring email: ${config.smtpHost}:${config.port}")
  }

  def sendEmail(to: String, subject: String, body: String): Boolean = {
    recordOperation()
    println(s"Sending email to $to: $subject")
    true  // Simulate success
  }
}

abstract class DatabaseCore extends Plugin {
  def name: String = "Database Connection"
  def version: String = "2.0.1"

  def initialize(): Unit = {
    println("Connecting to database...")
  }

  def shutdown(): Unit = {
    println("Closing database connection...")
  }
}

class DatabasePlugin extends DatabaseCore with LoggingPlugin with MetricsPlugin {
  def query(sql: String): List[Map[String, Any]] = {
    recordOperation()
    println(s"Executing query: $sql")
    List(Map("result" -> "data"))
  }
}

// Plugin manager
class PluginManager {
  private var plugins: List[Plugin] = List.empty

  def register(plugin: Plugin): Unit = {
    plugins = plugins :+ plugin
    println(s"Registered plugin: ${plugin.name}")
  }

  def initializeAll(): Unit = {
    plugins.filter(_.isEnabled).foreach { plugin =>
      try {
        plugin.initialize()
      } catch {
        case e: Exception => 
          println(s"Failed to initialize ${plugin.name}: ${e.getMessage}")
      }
    }
  }

  def shutdownAll(): Unit = {
    plugins.reverse.foreach { plugin =>
      try {
        plugin.shutdown()
      } catch {
        case e: Exception => 
          println(s"Failed to shutdown ${plugin.name}: ${e.getMessage}")
      }
    }
  }

  def getMetrics(): List[String] = {
    plugins.collect { case plugin: MetricsPlugin => plugin.metrics }
  }
}

// Usage
val pluginManager = new PluginManager()
val emailPlugin = new EmailPlugin()
val dbPlugin = new DatabasePlugin()

pluginManager.register(emailPlugin)
pluginManager.register(dbPlugin)

println("=== Initializing Plugins ===")
pluginManager.initializeAll()

Thread.sleep(100)  // Simulate some uptime

println("\n=== Using Plugins ===")
emailPlugin.sendEmail("user@example.com", "Test", "Hello World")
dbPlugin.query("SELECT * FROM users")

println("\n=== Metrics ===")
pluginManager.getMetrics().foreach(println)

println("\n=== Shutting Down ===")
pluginManager.shutdownAll()
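The call order through stacked plugin traits follows the same linearization rule as the loggers earlier. The sketch below uses a hypothetical `Lifecycle` trait that records each step in a buffer instead of printing, mirroring how a metrics layer, a logging layer, and a core `initialize()` are stacked; the event names are illustrative.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical lifecycle that records steps so the order can be checked.
trait Lifecycle {
  def events: ListBuffer[String]
  def initialize(): Unit
}

// Concrete core implementation, to the left of the stackable traits.
class CoreService extends Lifecycle {
  val events: ListBuffer[String] = ListBuffer.empty
  def initialize(): Unit = { events += "core" }
}

trait LoggingStep extends Lifecycle {
  abstract override def initialize(): Unit = {
    events += "logging:before"
    super.initialize()
    events += "logging:after"
  }
}

trait MetricsStep extends Lifecycle {
  abstract override def initialize(): Unit = {
    events += "metrics:start-timer"
    super.initialize()
  }
}

// Linearization: anonymous class -> MetricsStep -> LoggingStep -> CoreService
val service = new CoreService with LoggingStep with MetricsStep
service.initialize()
println(service.events.mkString(", "))
// metrics:start-timer, logging:before, core, logging:after
```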

Summary

In this lesson, you've explored traits and their powerful composition capabilities:

Basic Traits: Abstract and concrete methods, type parameters
Mixin Composition: Stacking traits and linearization
Self Types: Expressing dependencies between traits
Design Patterns: Strategy, decorator, and plugin patterns
Abstract Types: Flexible type definitions within traits
Diamond Problem: Understanding Scala's linearization resolution

Traits are fundamental to flexible Scala design, enabling clean separation of concerns and powerful composition patterns.

What's Next

In the next lesson, we'll explore "Inheritance: Building Class Hierarchies." You'll learn how class inheritance works with traits, method overriding, super calls, and creating well-designed inheritance hierarchies.

This will complete your understanding of Scala's object-oriented features and prepare you for more advanced topics.

Ready to master inheritance? Let's continue!