Inheritance: Building Class Hierarchies

Introduction

Inheritance is a fundamental object-oriented programming concept that allows you to create new classes based on existing ones. In Scala, inheritance provides a way to share code, establish "is-a" relationships, and enable polymorphism. Combined with traits and case classes, inheritance helps you build flexible and maintainable class hierarchies.

Understanding when and how to use inheritance effectively is crucial for designing clean APIs and avoiding common pitfalls like deep inheritance hierarchies and tight coupling. This lesson will show you how to use inheritance wisely in Scala.

Basic Inheritance

Simple Class Inheritance

// Base class
abstract class Animal(val name: String, val species: String) {
  // Abstract method - must be implemented by subclasses
  def makeSound(): String

  // Concrete method with default implementation
  def introduce(): String = s"I am $name, a $species"

  // Method that can be overridden
  def sleep(): String = s"$name is sleeping"

  // Final method - cannot be overridden
  final def breathe(): String = s"$name is breathing"
}

// Concrete subclass
class Dog(name: String, val breed: String) extends Animal(name, "Dog") {
  // Must implement abstract method
  def makeSound(): String = "Woof!"

  // Override concrete method
  override def sleep(): String = s"$name is sleeping in a dog bed"

  // Additional methods specific to Dog
  def fetch(): String = s"$name is fetching the ball"
  def wagTail(): String = s"$name is wagging tail happily"
}

class Cat(name: String, val isIndoor: Boolean) extends Animal(name, "Cat") {
  def makeSound(): String = "Meow!"

  override def sleep(): String = {
    if (isIndoor) s"$name is sleeping on the couch"
    else s"$name is sleeping under the stars"
  }

  def purr(): String = s"$name is purring contentedly"
  def hunt(): String = if (isIndoor) s"$name is hunting toy mice" else s"$name is hunting real mice"
}

class Bird(name: String, val canFly: Boolean) extends Animal(name, "Bird") {
  def makeSound(): String = "Tweet!"

  def fly(): String = {
    if (canFly) s"$name is soaring through the sky"
    else s"$name cannot fly but is running around"
  }
}

// Usage
val dog = new Dog("Buddy", "Golden Retriever")
val cat = new Cat("Whiskers", true)
val bird = new Bird("Tweety", true)

println(dog.introduce())      // I am Buddy, a Dog
println(dog.makeSound())      // Woof!
println(dog.sleep())          // Buddy is sleeping in a dog bed
println(dog.fetch())          // Buddy is fetching the ball

println(cat.introduce())      // I am Whiskers, a Cat
println(cat.makeSound())      // Meow!
println(cat.hunt())           // Whiskers is hunting toy mice

println(bird.introduce())     // I am Tweety, a Bird
println(bird.fly())           // Tweety is soaring through the sky

// Polymorphism - treating all animals the same way
val animals: List[Animal] = List(dog, cat, bird)
animals.foreach { animal =>
  println(s"${animal.name}: ${animal.makeSound()}")
}
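Because `animals` is a `List[Animal]`, only members declared on `Animal` are reachable through it. Subclass-only methods such as `fetch()` and `purr()` are not. A type pattern in `match` is one way to recover the concrete subclass. The sketch below uses trimmed-down versions of the classes above (no `species`, fewer methods) and a hypothetical `describe` helper:

```scala
abstract class Animal(val name: String) {
  def makeSound(): String
}

class Dog(name: String) extends Animal(name) {
  def makeSound(): String = "Woof!"
  def fetch(): String = s"$name is fetching the ball"
}

class Cat(name: String) extends Animal(name) {
  def makeSound(): String = "Meow!"
  def purr(): String = s"$name is purring contentedly"
}

// A type pattern narrows the static type, so subclass-only methods become callable
def describe(animal: Animal): String = animal match {
  case dog: Dog => dog.fetch()
  case cat: Cat => cat.purr()
  case other    => other.makeSound() // any other Animal subclass
}

val animals: List[Animal] = List(new Dog("Buddy"), new Cat("Whiskers"))
val descriptions = animals.map(describe)
descriptions.foreach(println)
// Buddy is fetching the ball
// Whiskers is purring contentedly
```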

Method Overriding with Super

abstract class Vehicle(val make: String, val model: String, val year: Int) {
  def startEngine(): String = s"Starting the $make $model"
  def stopEngine(): String = s"Stopping the $make $model"
  def getInfo(): String = s"$year $make $model"

  // Abstract method
  def fuelType(): String

  // Method that uses abstract method
  def refuel(): String = s"Refueling with ${fuelType()}"
}

class Car(make: String, model: String, year: Int, val doors: Int) 
  extends Vehicle(make, model, year) {

  def fuelType(): String = "gasoline"

  override def startEngine(): String = {
    super.startEngine() + " - Car engine started with key"
  }

  override def getInfo(): String = {
    super.getInfo() + s" ($doors doors)"
  }

  def openTrunk(): String = s"Opening trunk of $make $model"
}

class Motorcycle(make: String, model: String, year: Int, val engineSize: Int) 
  extends Vehicle(make, model, year) {

  def fuelType(): String = "gasoline"

  override def startEngine(): String = {
    super.startEngine() + " - Motorcycle engine started with button"
  }

  override def getInfo(): String = {
    super.getInfo() + s" (${engineSize}cc)"
  }

  def wheelie(): String = s"$make $model is doing a wheelie!"
}

class ElectricCar(make: String, model: String, year: Int, val batteryCapacity: Int) 
  extends Vehicle(make, model, year) {

  def fuelType(): String = "electricity"

  override def startEngine(): String = {
    super.startEngine() + " - Electric motor activated silently"
  }

  override def refuel(): String = s"Charging battery (${batteryCapacity}kWh capacity)"

  override def getInfo(): String = {
    super.getInfo() + s" (${batteryCapacity}kWh battery)"
  }

  def checkBattery(): String = s"$make $model battery level: 85%"
}

// Usage
val car = new Car("Toyota", "Camry", 2023, 4)
val motorcycle = new Motorcycle("Harley-Davidson", "Street 750", 2022, 750)
val electricCar = new ElectricCar("Tesla", "Model 3", 2023, 75)

val vehicles = List(car, motorcycle, electricCar)

vehicles.foreach { vehicle =>
  println(s"=== ${vehicle.getInfo()} ===")
  println(vehicle.startEngine())
  println(vehicle.refuel())
  println(vehicle.stopEngine())
  println()
}
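`super` calls chain through every level of a hierarchy, not just one. As a sketch, assuming a hypothetical `SportsCar` subclass of a trimmed-down `Car` (the make, model, and speed are made-up values), each `getInfo()` override appends its own detail to whatever its parent returns:

```scala
abstract class Vehicle(val make: String, val model: String, val year: Int) {
  def getInfo(): String = s"$year $make $model"
}

class Car(make: String, model: String, year: Int, val doors: Int)
    extends Vehicle(make, model, year) {
  override def getInfo(): String = super.getInfo() + s" ($doors doors)"
}

// Hypothetical third level: super.getInfo() resolves to Car's override,
// which in turn calls Vehicle's implementation
class SportsCar(make: String, model: String, year: Int, val topSpeedKmh: Int)
    extends Car(make, model, year, 2) {
  override def getInfo(): String = super.getInfo() + s", top speed $topSpeedKmh km/h"
}

val info = new SportsCar("Acme", "Roadster", 2023, 220).getInfo()
println(info) // 2023 Acme Roadster (2 doors), top speed 220 km/h
```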

Abstract Classes vs Traits

When to Use Abstract Classes

// Abstract class - good for sharing state and constructor logic
abstract class DatabaseConnection(val host: String, val port: Int, val database: String) {
  // Shared state
  protected var isConnected: Boolean = false
  protected val connectionId: String = java.util.UUID.randomUUID().toString

  // Abstract methods
  def connect(): Boolean
  def disconnect(): Boolean
  def executeQuery(sql: String): List[Map[String, Any]]

  // Concrete methods using abstract methods
  def ensureConnected(): Boolean = {
    if (!isConnected) connect() else true
  }

  def safeExecute(sql: String): Either[String, List[Map[String, Any]]] = {
    if (ensureConnected()) {
      try {
        Right(executeQuery(sql))
      } catch {
        case e: Exception => Left(s"Query failed: ${e.getMessage}")
      }
    } else {
      Left("Could not establish connection")
    }
  }

  // Final method
  final def connectionInfo(): String = s"Connection $connectionId to $host:$port/$database"
}

class MySQLConnection(host: String, port: Int, database: String, username: String, password: String)
  extends DatabaseConnection(host, port, database) {

  def connect(): Boolean = {
    println(s"Connecting to MySQL at $host:$port/$database as $username")
    isConnected = true
    true
  }

  def disconnect(): Boolean = {
    println(s"Disconnecting from MySQL")
    isConnected = false
    true
  }

  def executeQuery(sql: String): List[Map[String, Any]] = {
    println(s"MySQL executing: $sql")
    List(Map("result" -> s"MySQL result for: $sql"))
  }
}

class PostgreSQLConnection(host: String, port: Int, database: String, username: String, password: String)
  extends DatabaseConnection(host, port, database) {

  def connect(): Boolean = {
    println(s"Connecting to PostgreSQL at $host:$port/$database as $username")
    isConnected = true
    true
  }

  def disconnect(): Boolean = {
    println(s"Disconnecting from PostgreSQL")
    isConnected = false
    true
  }

  def executeQuery(sql: String): List[Map[String, Any]] = {
    println(s"PostgreSQL executing: $sql")
    List(Map("result" -> s"PostgreSQL result for: $sql"))
  }
}

// Usage
val mysqlConn = new MySQLConnection("localhost", 3306, "myapp", "user", "pass")
val postgresConn = new PostgreSQLConnection("localhost", 5432, "myapp", "user", "pass")

List(mysqlConn, postgresConn).foreach { conn =>
  println(conn.connectionInfo())
  conn.safeExecute("SELECT * FROM users") match {
    case Right(results) => println(s"Success: $results")
    case Left(error) => println(s"Error: $error")
  }
  conn.disconnect()
  println()
}
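The database example keeps shared state and constructor parameters in an abstract class. Traits are the usual home for reusable capabilities: a class can extend only one class but can mix in any number of traits. In Scala 2 a trait cannot take constructor parameters (Scala 3 adds trait parameters), so a trait typically declares an abstract member that the concrete class fills in. The names `Connection`, `Retryable`, `Describable`, and `HttpConnection` below are hypothetical and only illustrate the split:

```scala
// State that needs constructor arguments fits naturally in an abstract class
abstract class Connection(val host: String, val port: Int) {
  def connect(): String
}

// A reusable capability; the abstract maxRetries is supplied by the mixing class
trait Retryable {
  def maxRetries: Int
  // Calls attempt() until it returns true or maxRetries attempts are used up
  def withRetries(attempt: () => Boolean): Boolean =
    (1 to maxRetries).exists(_ => attempt())
}

trait Describable {
  def describe: String
}

// One superclass, any number of traits
class HttpConnection(host: String, port: Int)
    extends Connection(host, port) with Retryable with Describable {
  val maxRetries: Int = 3
  def connect(): String = s"HTTP connection to $host:$port"
  def describe: String = s"${connect()} (up to $maxRetries retries)"
}

val conn = new HttpConnection("localhost", 8080)
var calls = 0
// Simulated operation that fails twice, then succeeds on the third attempt
val succeeded = conn.withRetries { () => calls += 1; calls >= 3 }
println(conn.describe) // HTTP connection to localhost:8080 (up to 3 retries)
println(s"succeeded=$succeeded after $calls attempts") // succeeded=true after 3 attempts
```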

Combining Abstract Classes with Traits

// Abstract base class for shared state
abstract class Shape(val name: String) {
  // Shared state
  protected var _isVisible: Boolean = true
  protected val _id: String = java.util.UUID.randomUUID().toString.take(8)

  // Abstract methods
  def area(): Double
  def perimeter(): Double

  // Concrete methods
  def id: String = _id
  def isVisible: Boolean = _isVisible
  def show(): Unit = _isVisible = true
  def hide(): Unit = _isVisible = false

  override def toString: String = s"$name[$_id]"
}

// Traits for additional behavior
trait Movable {
  protected var _x: Double = 0.0
  protected var _y: Double = 0.0

  def position: (Double, Double) = (_x, _y)
  def moveTo(x: Double, y: Double): Unit = {
    _x = x
    _y = y
  }
  def moveBy(dx: Double, dy: Double): Unit = {
    _x += dx
    _y += dy
  }
}

trait Colorable {
  private var _color: String = "black"

  def color: String = _color
  def setColor(color: String): Unit = _color = color
}

trait Scalable {
  protected var _scale: Double = 1.0

  def scale: Double = _scale
  def scaleBy(factor: Double): Unit = {
    require(factor > 0, "Scale factor must be positive")
    _scale *= factor
  }
  def scaleTo(newScale: Double): Unit = {
    require(newScale > 0, "Scale must be positive")
    _scale = newScale
  }
}

// Concrete implementations
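// The constructor parameter is named baseRadius so it does not clash with the radius method
class Circle(baseRadius: Double) extends Shape("Circle") with Movable with Colorable with Scalable {
  require(baseRadius > 0, "Radius must be positive")

  // Effective radius reflects the current scale factor from Scalable
  def radius: Double = baseRadius * _scale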

  def area(): Double = math.Pi * radius * radius
  def perimeter(): Double = 2 * math.Pi * radius

  override def toString: String = 
    s"${super.toString} at $position, radius=${radius}, color=$color, visible=$isVisible"
}

class Rectangle(width: Double, height: Double) extends Shape("Rectangle") with Movable with Colorable {
  require(width > 0 && height > 0, "Dimensions must be positive")

  def area(): Double = width * height
  def perimeter(): Double = 2 * (width + height)

  override def toString: String = 
    s"${super.toString} at $position, ${width}x${height}, color=$color, visible=$isVisible"
}

// Usage
val circle = new Circle(5.0)
circle.moveTo(10, 20)
circle.setColor("red")
circle.scaleBy(1.5)

val rectangle = new Rectangle(10, 5)
rectangle.moveTo(0, 0)
rectangle.setColor("blue")

val shapes: List[Shape] = List(circle, rectangle)

shapes.foreach { shape =>
  println(shape)
  println(s"  Area: ${shape.area()}")
  println(s"  Perimeter: ${shape.perimeter()}")

  shape match {
    case movable: Movable => println(s"  Position: ${movable.position}")
    case _ => println("  Not movable")
  }

  shape match {
    case colorable: Colorable => println(s"  Color: ${colorable.color}")
    case _ => println("  Not colorable")
  }

  shape match {
    case scalable: Scalable => println(s"  Scale: ${scalable.scale}")
    case _ => println("  Not scalable")
  }

  println()
}
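The usage code above discovers traits at runtime with `match`. When a function needs a shape that is also movable, a compound type in its signature lets the compiler enforce that instead. A minimal sketch with hypothetical `Square`, `Triangle`, and `nudge` (Scala 3 usually spells the same type `Shape & Movable`):

```scala
abstract class Shape(val name: String) {
  def area(): Double
}

trait Movable {
  private var _x: Double = 0.0
  private var _y: Double = 0.0
  def position: (Double, Double) = (_x, _y)
  def moveBy(dx: Double, dy: Double): Unit = { _x += dx; _y += dy }
}

class Square(side: Double) extends Shape("Square") with Movable {
  def area(): Double = side * side
}

class Triangle(base: Double, height: Double) extends Shape("Triangle") {
  def area(): Double = 0.5 * base * height
}

// The compound type requires an argument that is both a Shape and Movable,
// checked at compile time rather than with a runtime match
def nudge(shape: Shape with Movable): String = {
  shape.moveBy(1.0, 2.0)
  s"${shape.name} (area ${shape.area()}) moved to ${shape.position}"
}

val square = new Square(3.0)
val report = nudge(square)
println(report) // Square (area 9.0) moved to (1.0,2.0)
// nudge(new Triangle(3.0, 4.0)) would not compile: Triangle is not Movable
```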

Sealed Classes and Type Safety

Sealed Hierarchies

// Sealed class - all direct subclasses must be defined in the same source file
sealed abstract class Result[+T]

case class Success[T](value: T) extends Result[T]
case class Failure(error: String) extends Result[Nothing]
case object Loading extends Result[Nothing]

// Compiler ensures exhaustive pattern matching
def handleResult[T](result: Result[T]): String = result match {
  case Success(value) => s"Got value: $value"
  case Failure(error) => s"Error: $error"
  case Loading => "Still loading..."
  // No need for default case - compiler knows all possibilities
}

// Generic operations on Result
object Result {
  def map[A, B](result: Result[A])(f: A => B): Result[B] = result match {
    case Success(value) => Success(f(value))
    case Failure(error) => Failure(error)
    case Loading => Loading
  }

  def flatMap[A, B](result: Result[A])(f: A => Result[B]): Result[B] = result match {
    case Success(value) => f(value)
    case Failure(error) => Failure(error)
    case Loading => Loading
  }

  def getOrElse[A](result: Result[A], default: A): A = result match {
    case Success(value) => value
    case _ => default
  }
}

// Usage
val results = List(
  Success(42),
  Failure("Network error"),
  Loading,
  Success("Hello World")
)

results.foreach { result =>
  println(handleResult(result))
}

// Chaining operations
val greeting: Result[String] = Success("hello")
val upperCaseResult = Result.map(greeting)(_.toUpperCase)
val lengthResult = Result.map(upperCaseResult)(_.length)

println(s"Final result: ${handleResult(lengthResult)}")  // Final result: Got value: 5
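`Result.flatMap` and `Result.getOrElse` are defined above but never exercised. Here is one way to use them, assuming two hypothetical steps, `parse` and `reciprocal`, each of which can fail. `flatMap` stops at the first `Failure`, and `getOrElse` supplies a fallback value:

```scala
sealed abstract class Result[+T]
case class Success[T](value: T) extends Result[T]
case class Failure(error: String) extends Result[Nothing]
case object Loading extends Result[Nothing]

object Result {
  def flatMap[A, B](result: Result[A])(f: A => Result[B]): Result[B] = result match {
    case Success(value) => f(value)
    case Failure(error) => Failure(error)
    case Loading        => Loading
  }

  def getOrElse[A](result: Result[A], default: A): A = result match {
    case Success(value) => value
    case _              => default
  }
}

// Hypothetical step: parse a string as an Int
def parse(s: String): Result[Int] = s.toIntOption match {
  case Some(n) => Success(n)
  case None    => Failure(s"'$s' is not a number")
}

// Hypothetical step: fails on zero
def reciprocal(n: Int): Result[Double] =
  if (n == 0) Failure("Cannot take reciprocal of zero") else Success(1.0 / n)

val ok     = Result.flatMap(parse("4"))(reciprocal)    // Success(0.25)
val badNum = Result.flatMap(parse("four"))(reciprocal) // Failure('four' is not a number)
val zero   = Result.flatMap(parse("0"))(reciprocal)    // Failure(Cannot take reciprocal of zero)

println(Result.getOrElse(ok, -1.0))     // 0.25
println(Result.getOrElse(badNum, -1.0)) // -1.0
```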

Algebraic Data Types

// Sealed hierarchy for expression trees
sealed abstract class Expr

case class Num(value: Double) extends Expr
case class Var(name: String) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Mul(left: Expr, right: Expr) extends Expr
case class Div(left: Expr, right: Expr) extends Expr

object Expr {
  // Evaluation with environment
  def eval(expr: Expr, env: Map[String, Double] = Map.empty): Either[String, Double] = expr match {
    case Num(value) => Right(value)
    case Var(name) => 
      env.get(name).toRight(s"Variable '$name' not found")
    case Add(left, right) =>
      for {
        l <- eval(left, env)
        r <- eval(right, env)
      } yield l + r
    case Mul(left, right) =>
      for {
        l <- eval(left, env)
        r <- eval(right, env)
      } yield l * r
    case Div(left, right) =>
      for {
        l <- eval(left, env)
        r <- eval(right, env)
        result <- if (r != 0) Right(l / r) else Left("Division by zero")
      } yield result
  }

  // Pretty printing
  def show(expr: Expr): String = expr match {
    case Num(value) => value.toString
    case Var(name) => name
    case Add(left, right) => s"(${show(left)} + ${show(right)})"
    case Mul(left, right) => s"(${show(left)} * ${show(right)})"
    case Div(left, right) => s"(${show(left)} / ${show(right)})"
  }

  // Simplification: simplify the children first, then apply the rules to the
  // result, so a rewrite exposed by a child (e.g. 0 * x becoming 0) is caught
  def simplify(expr: Expr): Expr = expr match {
    case Add(left, right) => (simplify(left), simplify(right)) match {
      case (Num(0.0), r) => r
      case (l, Num(0.0)) => l
      case (l, r)        => Add(l, r)
    }

    case Mul(left, right) => (simplify(left), simplify(right)) match {
      case (Num(0.0), _) | (_, Num(0.0)) => Num(0)
      case (Num(1.0), r)                 => r
      case (l, Num(1.0))                 => l
      case (l, r)                        => Mul(l, r)
    }

    case Div(left, right) => (simplify(left), simplify(right)) match {
      case (l, Num(1.0)) => l
      case (Num(0.0), _) => Num(0) // assumes the denominator is non-zero
      case (l, r)        => Div(l, r)
    }

    case other => other
  }
}

// Usage
val expr = Add(
  Mul(Var("x"), Num(2)),
  Div(Num(10), Var("y"))
)

println(s"Expression: ${Expr.show(expr)}")

val env = Map("x" -> 5.0, "y" -> 2.0)
Expr.eval(expr, env) match {
  case Right(result) => println(s"Result: $result")
  case Left(error) => println(s"Error: $error")
}

val complexExpr = Add(Mul(Num(0), Var("x")), Add(Num(0), Mul(Var("y"), Num(1))))
println(s"Before simplification: ${Expr.show(complexExpr)}")
println(s"After simplification: ${Expr.show(Expr.simplify(complexExpr))}")
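Because the hierarchy is sealed, adding a new operation is just another exhaustive `match`; the case classes themselves do not change. As an illustration, a hypothetical `variables` function collects the variable names an expression uses, which lets a caller spot missing bindings before calling `eval`:

```scala
sealed abstract class Expr
case class Num(value: Double) extends Expr
case class Var(name: String) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Mul(left: Expr, right: Expr) extends Expr
case class Div(left: Expr, right: Expr) extends Expr

// New operation over the same ADT: the set of variable names in an expression
def variables(expr: Expr): Set[String] = expr match {
  case Num(_)           => Set.empty
  case Var(name)        => Set(name)
  case Add(left, right) => variables(left) ++ variables(right)
  case Mul(left, right) => variables(left) ++ variables(right)
  case Div(left, right) => variables(left) ++ variables(right)
}

val expr = Add(Mul(Var("x"), Num(2)), Div(Num(10), Var("y")))
val env = Map("x" -> 5.0)

// Variables used by the expression but not bound in the environment
val missing = variables(expr) -- env.keySet
println(missing) // Set(y)
```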

Advanced Inheritance Patterns

Template Method Pattern

abstract class DataProcessor[T] {
  // Template method - defines the algorithm structure
  final def process(data: List[T]): List[T] = {
    val validated = validate(data)
    val filtered = filter(validated)
    val transformed = transform(filtered)
    val sorted = sort(transformed)
    finalize(sorted)
  }

  // Abstract methods - subclasses must implement
  protected def validate(data: List[T]): List[T]
  protected def transform(data: List[T]): List[T]

  // Hook methods - optional steps with pass-through defaults; subclasses may override them
  protected def filter(data: List[T]): List[T] = data
  protected def sort(data: List[T]): List[T] = data
  protected def finalize(data: List[T]): List[T] = data

  // Helper method shared by subclasses for progress logging
  protected def logStep(stepName: String, count: Int): Unit = {
    println(s"$stepName: processed $count items")
  }
}

class NumberProcessor extends DataProcessor[Int] {
  protected def validate(data: List[Int]): List[Int] = {
    val result = data.filter(_ >= 0)  // Keep non-negative numbers (0 included)
    logStep("Validation", result.length)
    result
  }

  protected def transform(data: List[Int]): List[Int] = {
    val result = data.map(_ * 2)  // Double each number
    logStep("Transformation", result.length)
    result
  }

  override protected def filter(data: List[Int]): List[Int] = {
    val result = data.filter(_ < 100)  // Only numbers less than 100
    logStep("Filtering", result.length)
    result
  }

  override protected def sort(data: List[Int]): List[Int] = {
    val result = data.sorted
    logStep("Sorting", result.length)
    result
  }
}

class StringProcessor extends DataProcessor[String] {
  protected def validate(data: List[String]): List[String] = {
    val result = data.filter(_.nonEmpty)  // Only non-empty strings
    logStep("Validation", result.length)
    result
  }

  protected def transform(data: List[String]): List[String] = {
    val result = data.map(_.toLowerCase.capitalize)  // Capitalize first letter
    logStep("Transformation", result.length)
    result
  }

  override protected def filter(data: List[String]): List[String] = {
    val result = data.filter(_.length >= 3)  // Only strings with 3+ characters
    logStep("Filtering", result.length)
    result
  }

  override protected def sort(data: List[String]): List[String] = {
    val result = data.sorted
    logStep("Sorting", result.length)
    result
  }

  override protected def finalize(data: List[String]): List[String] = {
    val result = data.take(10)  // Limit to 10 items
    logStep("Finalization", result.length)
    result
  }
}

// Usage
val numberProcessor = new NumberProcessor()
val stringProcessor = new StringProcessor()

val numbers = List(-5, 10, 25, 150, 30, -2, 75, 200, 15)
val strings = List("", "apple", "banana", "a", "cherry", "date", "elderberry", "fig", "grape")

println("=== Number Processing ===")
val processedNumbers = numberProcessor.process(numbers)
println(s"Original: $numbers")
println(s"Processed: $processedNumbers")

println("\n=== String Processing ===")
val processedStrings = stringProcessor.process(strings)
println(s"Original: $strings")
println(s"Processed: $processedStrings")
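For a one-off pipeline, a named subclass is not required: an anonymous subclass can implement the abstract steps inline. The sketch below assumes a trimmed-down `DataProcessor` with only `validate`, `transform`, and a `sort` hook, plus a hypothetical `trimmer` instance:

```scala
abstract class DataProcessor[T] {
  // Template method: fixed order of steps
  final def process(data: List[T]): List[T] = sort(transform(validate(data)))

  protected def validate(data: List[T]): List[T]
  protected def transform(data: List[T]): List[T]
  protected def sort(data: List[T]): List[T] = data // hook with a pass-through default
}

// Anonymous subclass: implements only the abstract steps, keeps the default sort
val trimmer: DataProcessor[String] = new DataProcessor[String] {
  protected def validate(data: List[String]): List[String] = data.filter(_.trim.nonEmpty)
  protected def transform(data: List[String]): List[String] = data.map(_.trim)
}

val cleaned = trimmer.process(List("  pear ", "", " kiwi", "   "))
println(cleaned) // List(pear, kiwi)
```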

Factory Pattern with Inheritance

abstract class Logger {
  def log(level: String, message: String): Unit
  def close(): Unit = {}  // Default implementation
}

class ConsoleLogger extends Logger {
  def log(level: String, message: String): Unit = {
    val timestamp = java.time.LocalDateTime.now()
    println(s"[$timestamp] [$level] $message")
  }
}

class FileLogger(filename: String) extends Logger {
  private val writer = new java.io.PrintWriter(new java.io.FileWriter(filename, true))

  def log(level: String, message: String): Unit = {
    val timestamp = java.time.LocalDateTime.now()
    writer.println(s"[$timestamp] [$level] $message")
    writer.flush()
  }

  override def close(): Unit = {
    writer.close()
  }
}

class NetworkLogger(host: String, port: Int) extends Logger {
  def log(level: String, message: String): Unit = {
    val timestamp = java.time.LocalDateTime.now()
    // Simulate network logging
    println(s"Sending to $host:$port - [$timestamp] [$level] $message")
  }
}

// Abstract creator - subclasses supply the factory method createLogger()
abstract class LoggerFactory {
  def createLogger(): Logger

  // Template method using factory method
  def createConfiguredLogger(): Logger = {
    val logger = createLogger()
    configureLogger(logger)
    logger
  }

  protected def configureLogger(logger: Logger): Unit = {
    // Default configuration
    logger.log("INFO", "Logger initialized")
  }
}

// Concrete factories
class ConsoleLoggerFactory extends LoggerFactory {
  def createLogger(): Logger = new ConsoleLogger()
}

class FileLoggerFactory(filename: String) extends LoggerFactory {
  def createLogger(): Logger = new FileLogger(filename)

  override protected def configureLogger(logger: Logger): Unit = {
    super.configureLogger(logger)
    logger.log("INFO", s"Logging to file: $filename")
  }
}

class NetworkLoggerFactory(host: String, port: Int) extends LoggerFactory {
  def createLogger(): Logger = new NetworkLogger(host, port)

  override protected def configureLogger(logger: Logger): Unit = {
    super.configureLogger(logger)
    logger.log("INFO", s"Logging to network: $host:$port")
  }
}

// Factory registry
object LoggerFactory {
  def create(loggerType: String, config: Map[String, String] = Map.empty): Option[LoggerFactory] = {
    loggerType.toLowerCase match {
      case "console" => Some(new ConsoleLoggerFactory())
      case "file" => 
        config.get("filename").map(new FileLoggerFactory(_))
      case "network" =>
        for {
          host <- config.get("host")
          port <- config.get("port").flatMap(p => scala.util.Try(p.toInt).toOption)
        } yield new NetworkLoggerFactory(host, port)
      case _ => None
    }
  }
}

// Usage
val factories = List(
  ("console", Map.empty[String, String]),
  ("file", Map("filename" -> "app.log")),
  ("network", Map("host" -> "logger.company.com", "port" -> "9999"))
)

factories.foreach { case (loggerType, config) =>
  LoggerFactory.create(loggerType, config) match {
    case Some(factory) =>
      val logger = factory.createConfiguredLogger()
      logger.log("DEBUG", s"Testing $loggerType logger")
      logger.log("INFO", "Application started")
      logger.log("WARN", "Low memory warning")
      logger.log("ERROR", "Database connection failed")
      logger.close()
      println()
    case None =>
      println(s"Failed to create logger for type: $loggerType")
  }
}
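New logger types plug into the same hierarchy by subclassing `Logger` and `LoggerFactory`, and they inherit `createConfiguredLogger()` unchanged. As a sketch, a hypothetical `InMemoryLogger` that appends to a caller-supplied buffer can be handy in tests. `Logger` and `LoggerFactory` are trimmed down here, and registering the new type with `LoggerFactory.create` would need an extra case in the registry, which is omitted:

```scala
import scala.collection.mutable.ListBuffer

abstract class Logger {
  def log(level: String, message: String): Unit
  def close(): Unit = {}
}

abstract class LoggerFactory {
  def createLogger(): Logger

  // Template method inherited by every concrete factory
  def createConfiguredLogger(): Logger = {
    val logger = createLogger()
    logger.log("INFO", "Logger initialized")
    logger
  }
}

// Hypothetical logger that records formatted messages in a shared buffer
class InMemoryLogger(sink: ListBuffer[String]) extends Logger {
  def log(level: String, message: String): Unit = sink += s"[$level] $message"
}

class InMemoryLoggerFactory(sink: ListBuffer[String]) extends LoggerFactory {
  def createLogger(): Logger = new InMemoryLogger(sink)
}

val sink = ListBuffer.empty[String]
val logger = new InMemoryLoggerFactory(sink).createConfiguredLogger()
logger.log("WARN", "Low memory warning")
println(sink.toList) // List([INFO] Logger initialized, [WARN] Low memory warning)
```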

Summary

In this lesson, you've learned how to build inheritance and class hierarchies in Scala:

Basic Inheritance: Extending classes, method overriding, and super calls
Abstract Classes: Sharing state and constructor logic
Abstract Classes vs Traits: When to use each, and how to combine one abstract base class with several traits
Sealed Classes: Type-safe hierarchies and exhaustive pattern matching
Design Patterns: Template method and factory patterns
Best Practices: Building maintainable inheritance hierarchies

Understanding inheritance helps you create well-structured object-oriented designs while avoiding common pitfalls.

What's Next

In the next lesson, we'll explore "Packages and Imports: Organizing Your Code." You'll learn how to structure large Scala applications using packages, manage imports, and create clean modular architectures.

This will teach you essential skills for building real-world Scala applications with proper organization and maintainability.

Ready to master code organization? Let's continue!