Companion Objects: The Perfect Partnership
Introduction
Companion objects are one of Scala's most elegant features. When a class and an object share the same name and are defined in the same file, they become companions. This special relationship allows them to access each other's private members and creates powerful patterns for clean API design, factory methods, and data manipulation.
Understanding companion objects is essential for writing idiomatic Scala code. Scala has no static keyword, so members that belong to a type rather than to any one instance (factory methods, constants, utility methods, and pattern matching extractors) live in the type's companion object instead.
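A minimal sketch of the two-way access described above, using a hypothetical `Counter` class (not part of the lesson's examples): the class reads its companion's private `step` field, while the companion calls the private constructor and reads the instance's private `count`. No other code can see either member.

```scala
// Hypothetical example: Counter and its companion read each other's private members.
class Counter private (private var count: Int) {
  // The class can read the companion's private field.
  def increment(): Unit = count += Counter.step
}

object Counter {
  // Private to the companion, but visible inside class Counter.
  private val step: Int = 2

  // The companion can call the private constructor...
  def start(): Counter = new Counter(0)

  // ...and read the instance's private field.
  def current(c: Counter): Int = c.count
}

val c = Counter.start()
c.increment()
c.increment()
println(Counter.current(c)) // 4
```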
The Companion Relationship
Basic Companion Setup
class User private(val id: Long, val username: String, val email: String,
private var _isActive: Boolean) {
// Private methods accessible to companion
private def setActive(active: Boolean): Unit = {
_isActive = active
}
def isActive: Boolean = _isActive
def deactivate(): User = {
setActive(false)
this
}
// Method that uses companion object
def toJson: String = User.toJsonString(this)
override def toString: String = s"User($id, $username, $email, active=$_isActive)"
}
// Companion object - same name, same file
object User {
// Companion can access private constructor and methods
private var nextId: Long = 1
// Factory method using private constructor
def create(username: String, email: String): User = {
val user = new User(nextId, username, email, true)
nextId += 1
user
}
// Factory method with validation
def createIfValid(username: String, email: String): Either[String, User] = {
if (username.length < 3) {
Left("Username must be at least 3 characters")
} else if (!email.contains("@")) {
Left("Invalid email format")
} else {
Right(create(username, email))
}
}
// Access private methods of class instances
def reactivate(user: User): User = {
user.setActive(true) // Can access private method
user
}
// Utility method that reads private state (simplified: field values are not JSON-escaped)
def toJsonString(user: User): String = {
s"""{"id": ${user.id}, "username": "${user.username}", "email": "${user.email}", "active": ${user._isActive}}"""
}
// Constants and utilities
val MIN_USERNAME_LENGTH: Int = 3
val MAX_USERNAME_LENGTH: Int = 50
def isValidUsername(username: String): Boolean = {
username.length >= MIN_USERNAME_LENGTH &&
username.length <= MAX_USERNAME_LENGTH &&
username.matches("[a-zA-Z0-9_]+")
}
}
// Usage
val user1 = User.create("alice", "alice@example.com")
println(user1) // User(1, alice, alice@example.com, active=true)
user1.deactivate()
println(user1.isActive) // false
User.reactivate(user1) // Companion can access private methods
println(user1.isActive) // true
User.createIfValid("ab", "invalid-email") match {
case Right(user) => println(s"Created: $user")
case Left(error) => println(s"Error: $error")
}
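The nextId counter in User is a plain var, so two threads calling create at the same time could read the same value before either increments it. A minimal sketch of one common alternative, assuming a hypothetical `Account` class and the JDK's `java.util.concurrent.atomic.AtomicLong`:

```scala
import java.util.concurrent.atomic.AtomicLong

// Hypothetical class: only the companion can construct it.
class Account private (val id: Long, val owner: String) {
  override def toString: String = s"Account($id, $owner)"
}

object Account {
  // getAndIncrement reads and bumps the counter atomically,
  // so concurrent callers never receive the same id.
  private val nextId = new AtomicLong(1)

  def create(owner: String): Account = new Account(nextId.getAndIncrement(), owner)
}

val a1 = Account.create("alice")
val a2 = Account.create("bob")
println(a1) // Account(1, alice)
println(a2) // Account(2, bob)
```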
Apply and Unapply Methods
class Money private(val amount: BigDecimal, val currency: String) {
require(amount >= 0, "Amount cannot be negative")
def +(other: Money): Money = {
require(currency == other.currency, "Currency mismatch")
Money(amount + other.amount, currency)
}
def -(other: Money): Money = {
require(currency == other.currency, "Currency mismatch")
require(amount >= other.amount, "Insufficient funds")
Money(amount - other.amount, currency)
}
def *(multiplier: Double): Money = {
Money(amount * multiplier, currency)
}
def formatted: String = f"$amount%.2f $currency"
override def toString: String = formatted
override def equals(obj: Any): Boolean = obj match {
case other: Money => amount == other.amount && currency == other.currency
case _ => false
}
override def hashCode(): Int = (amount, currency).hashCode()
}
object Money {
// Apply method: Money(29.99, "USD") is shorthand for Money.apply(29.99, "USD")
def apply(amount: BigDecimal, currency: String): Money = {
new Money(amount, currency.toUpperCase)
}
// Convenience apply methods
def apply(amount: Double, currency: String): Money = {
apply(BigDecimal(amount), currency)
}
def apply(amount: Int, currency: String): Money = {
apply(BigDecimal(amount), currency)
}
// Unapply method - enables pattern matching (extractor)
def unapply(money: Money): Option[(BigDecimal, String)] = {
Some((money.amount, money.currency))
}
// Common currency constructors
def usd(amount: Double): Money = apply(amount, "USD")
def eur(amount: Double): Money = apply(amount, "EUR")
def gbp(amount: Double): Money = apply(amount, "GBP")
def jpy(amount: Double): Money = apply(amount, "JPY")
// Zero values for different currencies
val ZERO_USD: Money = usd(0)
val ZERO_EUR: Money = eur(0)
val ZERO_GBP: Money = gbp(0)
// Parsing from string
def parse(str: String): Option[Money] = {
val pattern = """(\d+(?:\.\d{2})?)\s+([A-Z]{3})""".r
str.trim match {
case pattern(amount, currency) =>
try {
Some(Money(BigDecimal(amount), currency))
} catch {
case _: NumberFormatException => None
}
case _ => None
}
}
// Conversion utilities (simplified rates)
private val exchangeRates: Map[String, Map[String, Double]] = Map(
"USD" -> Map("EUR" -> 0.85, "GBP" -> 0.73, "JPY" -> 110.0),
"EUR" -> Map("USD" -> 1.18, "GBP" -> 0.86, "JPY" -> 129.0),
"GBP" -> Map("USD" -> 1.37, "EUR" -> 1.16, "JPY" -> 150.0)
)
def convert(money: Money, toCurrency: String): Option[Money] = {
if (money.currency == toCurrency.toUpperCase) {
Some(money)
} else {
exchangeRates.get(money.currency)
.flatMap(_.get(toCurrency.toUpperCase))
.map(rate => Money(money.amount * rate, toCurrency))
}
}
}
// Usage with apply method
val price1 = Money(29.99, "USD") // Calls apply method
val price2 = Money.usd(49.99) // Convenience method
val price3 = Money.eur(39.99)
// Pattern matching with unapply
def analyzePrice(money: Money): String = money match {
case Money(amount, "USD") if amount > 100 => "Expensive in USD"
case Money(amount, "EUR") if amount > 100 => "Expensive in EUR"
case Money(amount, currency) if amount == 0 => s"Free in $currency"
case Money(amount, currency) => s"$amount $currency"
}
println(analyzePrice(Money.usd(150))) // "Expensive in USD"
println(analyzePrice(Money.ZERO_EUR)) // "Free in EUR"
// Parsing and conversion
Money.parse("25.50 USD") match {
case Some(money) =>
Money.convert(money, "EUR") match {
case Some(converted) => println(s"$money = $converted")
case None => println("Conversion failed")
}
case None => println("Parsing failed")
}
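Money(29.99, "USD") compiles to Money.apply(29.99, "USD"), and a pattern such as case Money(amount, currency) calls Money.unapply and binds the tuple's fields when it returns Some. A stripped-down sketch with a hypothetical `Temperature` class makes both calls explicit:

```scala
// Hypothetical class used only to show what apply/unapply sugar expands to.
class Temperature private (val degrees: Double, val scale: String)

object Temperature {
  def apply(degrees: Double, scale: String): Temperature =
    new Temperature(degrees, scale.toUpperCase)

  def unapply(t: Temperature): Option[(Double, String)] =
    Some((t.degrees, t.scale))
}

// These two lines are equivalent: the compiler inserts `.apply`.
val t1 = Temperature(21.5, "c")
val t2 = Temperature.apply(21.5, "c")

// A constructor pattern calls unapply and binds the tuple fields on Some.
val described = t1 match {
  case Temperature(d, "C")   => s"$d degrees Celsius"
  case Temperature(d, scale) => s"$d $scale"
}

// The explicit form of the same extraction.
val explicit = Temperature.unapply(t2).map { case (d, scale) => s"$d $scale" }

println(described) // 21.5 degrees Celsius
println(explicit)  // Some(21.5 C)
```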
Advanced Factory Patterns
sealed trait DatabaseConnection {
def query(sql: String): List[Map[String, Any]]
def close(): Unit
}
object DatabaseConnection {
// Implementations are private classes inside the companion, so create() is the only way
// to obtain one. A top-level class with a private constructor could not be instantiated
// here, because this object is not that class's companion.
private class MySQLConnection(host: String, port: Int, database: String) extends DatabaseConnection {
private var isOpen: Boolean = true
def query(sql: String): List[Map[String, Any]] = {
require(isOpen, "Connection is closed")
println(s"MySQL: Executing '$sql' on $host:$port/$database")
// Simulate query result
List(Map("result" -> s"MySQL result for: $sql"))
}
def close(): Unit = {
isOpen = false
println(s"MySQL connection to $host:$port/$database closed")
}
}
private class PostgreSQLConnection(host: String, port: Int, database: String) extends DatabaseConnection {
private var isOpen: Boolean = true
def query(sql: String): List[Map[String, Any]] = {
require(isOpen, "Connection is closed")
println(s"PostgreSQL: Executing '$sql' on $host:$port/$database")
List(Map("result" -> s"PostgreSQL result for: $sql"))
}
def close(): Unit = {
isOpen = false
println(s"PostgreSQL connection to $host:$port/$database closed")
}
}
// Configuration case classes
case class ConnectionConfig(
host: String,
port: Int,
database: String,
username: String,
password: String,
poolSize: Int = 10,
timeout: Int = 30
)
// Factory method with different database types
def create(dbType: String, config: ConnectionConfig): Either[String, DatabaseConnection] = {
dbType.toLowerCase match {
case "mysql" =>
validateConfig(config) match {
case Some(error) => Left(error)
case None => Right(new MySQLConnection(config.host, config.port, config.database))
}
case "postgresql" | "postgres" =>
validateConfig(config) match {
case Some(error) => Left(error)
case None => Right(new PostgreSQLConnection(config.host, config.port, config.database))
}
case unsupported => Left(s"Unsupported database type: $unsupported")
}
}
// Validation
private def validateConfig(config: ConnectionConfig): Option[String] = {
if (config.host.isEmpty) Some("Host cannot be empty")
else if (config.port <= 0 || config.port > 65535) Some("Invalid port number")
else if (config.database.isEmpty) Some("Database name cannot be empty")
else if (config.username.isEmpty) Some("Username cannot be empty")
else None
}
// Predefined configurations
def localMySQL(database: String, username: String, password: String): ConnectionConfig = {
ConnectionConfig("localhost", 3306, database, username, password)
}
def localPostgreSQL(database: String, username: String, password: String): ConnectionConfig = {
ConnectionConfig("localhost", 5432, database, username, password)
}
// Connection pool management
private val connectionPools: scala.collection.mutable.Map[String, List[DatabaseConnection]] =
scala.collection.mutable.Map()
def createPool(poolName: String, dbType: String, config: ConnectionConfig): Either[String, Unit] = {
val connections = (1 to config.poolSize).map { _ =>
create(dbType, config)
}.toList
val (errors, validConnections) = connections.foldLeft((List.empty[String], List.empty[DatabaseConnection])) {
case ((errs, conns), Left(error)) => (errs :+ error, conns)
case ((errs, conns), Right(conn)) => (errs, conns :+ conn)
}
if (errors.nonEmpty) {
Left(s"Failed to create pool: ${errors.mkString(", ")}")
} else {
connectionPools(poolName) = validConnections
Right(())
}
}
def getFromPool(poolName: String): Option[DatabaseConnection] = {
connectionPools.get(poolName).flatMap(_.headOption)
}
def closePool(poolName: String): Unit = {
connectionPools.get(poolName).foreach { connections =>
connections.foreach(_.close())
connectionPools.remove(poolName)
}
}
}
// Usage
val config = DatabaseConnection.localMySQL("myapp", "user", "password")
DatabaseConnection.create("mysql", config) match {
case Right(connection) =>
val results = connection.query("SELECT * FROM users")
println(results)
connection.close()
case Left(error) =>
println(s"Failed to create connection: $error")
}
// Create connection pool
DatabaseConnection.createPool("main", "postgresql",
DatabaseConnection.localPostgreSQL("myapp", "user", "password")) match {
case Right(()) =>
println("Connection pool created successfully")
DatabaseConnection.getFromPool("main") match {
case Some(conn) =>
conn.query("SELECT COUNT(*) FROM products")
// Don't close it: getFromPool does not remove connections, so a closed one would be handed out again
case None => println("No connections available")
}
case Left(error) => println(s"Pool creation failed: $error")
}
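The pool above returns headOption without tracking which connections are in use. A minimal sketch of borrow and release kept inside a companion object, assuming a hypothetical `Conn` class; it is single-threaded and omits the timeouts, validation, and locking a real pool needs:

```scala
import scala.collection.mutable

// Hypothetical connection type; only the companion can create instances.
class Conn private (val id: Int) {
  def query(sql: String): String = s"conn $id ran: $sql"
}

object Conn {
  // Idle connections waiting to be borrowed (not thread-safe).
  private val idle = mutable.Queue.empty[Conn]

  def initPool(size: Int): Unit = {
    idle.clear()
    (1 to size).foreach(i => idle.enqueue(new Conn(i)))
  }

  // Take an idle connection, or None if all are in use.
  def borrow(): Option[Conn] =
    if (idle.isEmpty) None else Some(idle.dequeue())

  // Put a connection back so later callers can reuse it.
  def release(c: Conn): Unit = idle.enqueue(c)

  // Borrow, run the body, and always release, even if the body throws.
  def withConnection[A](body: Conn => A): Option[A] =
    borrow().map { c =>
      try body(c) finally release(c)
    }

  def available: Int = idle.size
}

Conn.initPool(2)
val first = Conn.borrow()
println(Conn.available) // 1
first.foreach(Conn.release)
println(Conn.available) // 2
println(Conn.withConnection(_.query("SELECT 1"))) // Some(conn 2 ran: SELECT 1)
```

Wrapping borrow and release in withConnection keeps callers from forgetting to return a connection, which is the gap the lesson's usage comment points at.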
Extractors and Pattern Matching
Custom Extractors
class EmailAddress private(val localPart: String, val domain: String) {
def address: String = s"$localPart@$domain"
def isGmail: Boolean = domain.toLowerCase == "gmail.com"
def isCorporate: Boolean = !List("gmail.com", "yahoo.com", "hotmail.com", "outlook.com")
.contains(domain.toLowerCase)
override def toString: String = address
}
object EmailAddress {
def apply(address: String): Option[EmailAddress] = {
parse(address)
}
private def parse(address: String): Option[EmailAddress] = {
val trimmed = address.trim
if (trimmed.contains("@")) {
val parts = trimmed.split("@")
if (parts.length == 2 && parts(0).nonEmpty && parts(1).nonEmpty) {
Some(new EmailAddress(parts(0), parts(1)))
} else None
} else None
}
// Standard unapply for pattern matching
def unapply(email: EmailAddress): Option[(String, String)] = {
Some((email.localPart, email.domain))
}
// Custom extractors for specific patterns
object Gmail {
def unapply(email: EmailAddress): Option[String] = {
if (email.domain.toLowerCase == "gmail.com") Some(email.localPart)
else None
}
}
object Corporate {
def unapply(email: EmailAddress): Option[(String, String)] = {
if (email.isCorporate) Some((email.localPart, email.domain))
else None
}
}
object Domain {
def unapply(email: EmailAddress): Option[String] = Some(email.domain)
}
// Boolean extractor
object ValidForNewsletter {
def unapply(email: EmailAddress): Boolean = {
!email.domain.toLowerCase.contains("temp") &&
!email.localPart.toLowerCase.startsWith("noreply")
}
}
}
// Usage with pattern matching
def categorizeEmail(emailStr: String): String = {
EmailAddress(emailStr) match {
case Some(EmailAddress.Gmail(localPart)) =>
s"Gmail user: $localPart"
case Some(EmailAddress.Corporate(localPart, domain)) =>
s"Corporate email: $localPart at $domain"
case Some(EmailAddress.ValidForNewsletter()) =>
"Valid for newsletter"
case Some(email) =>
s"Regular email: $email"
case None =>
"Invalid email format"
}
}
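// Test the pattern matching. Cases are tried in order, so newsletter@temp.com and noreply@service.com match Corporate (their domains are not free-mail providers) before ValidForNewsletter is checked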
val emails = List(
"john.doe@gmail.com",
"alice.smith@company.com",
"newsletter@temp.com",
"noreply@service.com",
"invalid-email",
"bob@yahoo.com"
)
emails.foreach { email =>
println(s"$email -> ${categorizeEmail(email)}")
}
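Since case clauses are tried top to bottom, the order of extractors is part of the matching logic. A minimal sketch with hypothetical string extractors (`FreeMail`, which returns an Option, and `NoReply`, a Boolean extractor) shows the same two cases producing different results when reordered:

```scala
// Hypothetical extractors that work directly on raw address strings.
object FreeMail {
  private val providers = Set("gmail.com", "yahoo.com", "hotmail.com", "outlook.com")

  // Returns the local part when the domain is a free-mail provider.
  def unapply(address: String): Option[String] = address.split("@") match {
    case Array(local, domain) if providers.contains(domain.toLowerCase) => Some(local)
    case _ => None
  }
}

object NoReply {
  // Boolean extractor: matches with no bound values, written `NoReply()`.
  def unapply(address: String): Boolean = address.toLowerCase.startsWith("noreply@")
}

def freeMailFirst(address: String): String = address match {
  case FreeMail(local) => s"free mail: $local"
  case NoReply()       => "no-reply"
  case _               => "other"
}

def noReplyFirst(address: String): String = address match {
  case NoReply()       => "no-reply"
  case FreeMail(local) => s"free mail: $local"
  case _               => "other"
}

println(freeMailFirst("noreply@gmail.com")) // free mail: noreply
println(noReplyFirst("noreply@gmail.com"))  // no-reply
```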
Multi-Pattern Extractors
class PhoneNumber private(val countryCode: String, val areaCode: String, val number: String) {
def formatted: String = s"+$countryCode ($areaCode) $number"
def national: String = s"($areaCode) $number"
def international: String = s"+$countryCode$areaCode$number"
override def toString: String = formatted
}
object PhoneNumber {
// Validating factory: returns None unless all components are well-formed
def apply(countryCode: String, areaCode: String, number: String): Option[PhoneNumber] = {
if (isValidComponents(countryCode, areaCode, number)) {
Some(new PhoneNumber(countryCode, areaCode, number))
} else None
}
def parse(phoneStr: String): Option[PhoneNumber] = {
val cleaned = phoneStr.replaceAll("[^0-9+]", "")
// Try different patterns
parseInternational(cleaned)
.orElse(parseNational(cleaned))
.orElse(parseLocal(cleaned))
}
private def parseInternational(phone: String): Option[PhoneNumber] = {
// Pattern: + then a 1-3 digit country code, 3-digit area code, 7-digit number (e.g. +15551234567)
val pattern = """\+(\d{1,3})(\d{3})(\d{7})""".r
phone match {
case pattern(country, area, number) => PhoneNumber(country, area, number)
case _ => None
}
}
private def parseNational(phone: String): Option[PhoneNumber] = {
// Pattern: 1234567890 (assume US)
val pattern = """(\d{3})(\d{7})""".r
phone match {
case pattern(area, number) => PhoneNumber("1", area, number)
case _ => None
}
}
private def parseLocal(phone: String): Option[PhoneNumber] = {
// Pattern: 7-digit local number; assumes US and fills in 555 as a placeholder area code
val pattern = """(\d{7})""".r
phone match {
case pattern(number) => PhoneNumber("1", "555", number) // Default area code
case _ => None
}
}
private def isValidComponents(countryCode: String, areaCode: String, number: String): Boolean = {
countryCode.matches("\\d{1,3}") &&
areaCode.matches("\\d{3}") &&
number.matches("\\d{7}")
}
// Standard extractor: enables case PhoneNumber(country, area, number)
def unapply(phone: PhoneNumber): Option[(String, String, String)] = {
Some((phone.countryCode, phone.areaCode, phone.number))
}
// Extractor for US numbers
object US {
def unapply(phone: PhoneNumber): Option[(String, String)] = {
if (phone.countryCode == "1") Some((phone.areaCode, phone.number))
else None
}
}
// Extractor for formatted patterns
object Formatted {
def unapply(phoneStr: String): Option[PhoneNumber] = parse(phoneStr)
}
// Boolean extractors
object IsUS {
def unapply(phone: PhoneNumber): Boolean = phone.countryCode == "1"
}
object IsTollFree {
def unapply(phone: PhoneNumber): Boolean = {
phone.countryCode == "1" && Set("800", "833", "844", "855", "866", "877", "888").contains(phone.areaCode)
}
}
}
def analyzePhoneNumber(phoneStr: String): String = {
phoneStr match {
case PhoneNumber.Formatted(PhoneNumber.IsTollFree()) =>
"Toll-free number"
case PhoneNumber.Formatted(PhoneNumber.US(area, number)) =>
s"US number in area code $area: $number"
case PhoneNumber.Formatted(PhoneNumber(country, area, number)) =>
s"International: +$country ($area) $number"
case _ =>
"Invalid phone number format"
}
}
// Test phone number parsing and matching
val phoneNumbers = List(
"+15551234567",
"(555) 123-4567",
"555-123-4567",
"1234567",
"+447911123456",
"800-555-1234",
"invalid-phone"
)
phoneNumbers.foreach { phone =>
println(s"$phone -> ${analyzePhoneNumber(phone)}")
}
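The parse helpers rely on a property of Scala's Regex extractor: used as a pattern, a regex must match the entire string, not just part of it. This sketch reuses the shape of parseNational's pattern to contrast that default with `Regex.unanchored`, which matches a substring and can silently pick the wrong digits:

```scala
import scala.util.matching.Regex

// Same shape as parseNational's pattern: 3-digit area code + 7-digit number.
val national: Regex = """(\d{3})(\d{7})""".r

def describe(s: String): String = s match {
  // As a pattern, a Regex must match the entire input.
  case national(area, number) => s"area=$area number=$number"
  case _                      => "no full match"
}

// .unanchored lets the same pattern match a substring of the input.
val nationalAnywhere = national.unanchored

def describeAnywhere(s: String): String = s match {
  case nationalAnywhere(area, number) => s"area=$area number=$number"
  case _                              => "no match"
}

println(describe("5551234567"))          // area=555 number=1234567
println(describe("15551234567"))         // no full match
println(describeAnywhere("15551234567")) // area=155 number=5123456
```

The anchored default is what lets the lesson chain parseInternational, parseNational, and parseLocal with orElse: each pattern either accounts for every digit or declines.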
Serialization and Deserialization Patterns
JSON Serialization with Companion
import scala.util.{Try, Success, Failure}
case class Person(id: Long, name: String, age: Int, email: String, isActive: Boolean)
object Person {
// JSON serialization
def toJson(person: Person): String = {
s"""{
| "id": ${person.id},
| "name": "${person.name}",
| "age": ${person.age},
| "email": "${person.email}",
| "isActive": ${person.isActive}
|}""".stripMargin
}
// JSON deserialization
def fromJson(json: String): Try[Person] = {
Try {
// Allow optional whitespace around ':' rather than deleting all whitespace first,
// which would also remove spaces inside values ("Alice Johnson" -> "AliceJohnson").
// Simplified parser: escaped quotes inside string values are not supported.
val idPattern = """"id"\s*:\s*(\d+)""".r
val namePattern = """"name"\s*:\s*"([^"]+)"""".r
val agePattern = """"age"\s*:\s*(\d+)""".r
val emailPattern = """"email"\s*:\s*"([^"]+)"""".r
val activePattern = """"isActive"\s*:\s*(true|false)""".r
val id = idPattern.findFirstMatchIn(json).map(_.group(1).toLong)
.getOrElse(throw new IllegalArgumentException("Missing or invalid id"))
val name = namePattern.findFirstMatchIn(json).map(_.group(1))
.getOrElse(throw new IllegalArgumentException("Missing name"))
val age = agePattern.findFirstMatchIn(json).map(_.group(1).toInt)
.getOrElse(throw new IllegalArgumentException("Missing or invalid age"))
val email = emailPattern.findFirstMatchIn(json).map(_.group(1))
.getOrElse(throw new IllegalArgumentException("Missing email"))
val isActive = activePattern.findFirstMatchIn(json).map(_.group(1).toBoolean)
.getOrElse(throw new IllegalArgumentException("Missing or invalid isActive"))
Person(id, name, age, email, isActive)
}
}
// CSV serialization
def toCsv(person: Person): String = {
s"${person.id},${person.name},${person.age},${person.email},${person.isActive}"
}
def fromCsv(csv: String): Try[Person] = {
Try {
val parts = csv.split(",")
if (parts.length != 5) {
throw new IllegalArgumentException("CSV must have exactly 5 fields")
}
Person(
id = parts(0).trim.toLong,
name = parts(1).trim,
age = parts(2).trim.toInt,
email = parts(3).trim,
isActive = parts(4).trim.toBoolean
)
}
}
// Batch operations
def toJsonArray(people: List[Person]): String = {
val jsonObjects = people.map(toJson).mkString(",\n")
s"[\n$jsonObjects\n]"
}
def fromCsvLines(csvLines: List[String]): (List[Person], List[String]) = {
csvLines.foldLeft((List.empty[Person], List.empty[String])) {
case ((people, errors), line) =>
fromCsv(line) match {
case Success(person) => (people :+ person, errors)
case Failure(exception) => (people, errors :+ s"Line '$line': ${exception.getMessage}")
}
}
}
// Factory methods with validation
def create(name: String, age: Int, email: String): Either[String, Person] = {
if (name.trim.isEmpty) {
Left("Name cannot be empty")
} else if (age < 0 || age > 150) {
Left("Age must be between 0 and 150")
} else if (!email.contains("@")) {
Left("Invalid email format")
} else {
Right(Person(
id = System.currentTimeMillis(), // Simple ID generation; not unique if two calls land in the same millisecond
name = name.trim,
age = age,
email = email.trim.toLowerCase,
isActive = true
))
}
}
}
// Usage examples
val person = Person(1, "Alice Johnson", 30, "alice@example.com", true)
// JSON serialization
val json = Person.toJson(person)
println("JSON:")
println(json)
// JSON deserialization
Person.fromJson(json) match {
case Success(parsed) => println(s"Parsed from JSON: $parsed")
case Failure(error) => println(s"JSON parsing failed: $error")
}
// CSV operations
val csvData = List(
"2,Bob Smith,25,bob@example.com,true",
"3,Carol Davis,35,carol@example.com,false",
"invalid,line,here",
"4,David Wilson,40,david@example.com,true"
)
val (parsedPeople, errors) = Person.fromCsvLines(csvData)
println(s"\nParsed ${parsedPeople.length} people from CSV")
println(s"Errors: ${errors.length}")
errors.foreach(println)
// Batch JSON export
val jsonArray = Person.toJsonArray(parsedPeople)
println("\nJSON Array:")
println(jsonArray)
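fromCsvLines threads two accumulators through foldLeft and appends with :+, which copies the list on every step. On Scala 2.13 or later, `partitionMap` expresses the same split in one pass. A minimal sketch, assuming Scala 2.13+ and a hypothetical two-field `Item` record standing in for Person; note it returns (errors, items), the reverse of the lesson's order, because partitionMap puts Left values first:

```scala
import scala.util.{Failure, Success, Try}

// Hypothetical record type, standing in for Person.
case class Item(id: Long, name: String)

object Item {
  def fromCsv(line: String): Try[Item] = Try {
    line.split(",").map(_.trim) match {
      case Array(id, name) => Item(id.toLong, name)
      case _ => throw new IllegalArgumentException("CSV must have exactly 2 fields")
    }
  }

  // partitionMap sends Left values to the first list and Right values to the second,
  // preserving input order.
  def fromCsvLines(lines: List[String]): (List[String], List[Item]) =
    lines.partitionMap[String, Item] { line =>
      fromCsv(line) match {
        case Success(item) => Right(item)
        case Failure(e)    => Left(s"Line '$line': ${e.getMessage}")
      }
    }
}

val (errs, items) = Item.fromCsvLines(List("1,apple", "x,pear", "2,plum", "bad"))
println(items)       // List(Item(1,apple), Item(2,plum))
println(errs.length) // 2
```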
Summary
In this lesson, you've mastered companion objects and their powerful capabilities:
✅ Companion Relationship: Classes and objects sharing names and private access
✅ Factory Patterns: Clean object creation with validation and configuration
✅ Apply/Unapply Methods: Function-like syntax and pattern matching extractors
✅ Custom Extractors: Sophisticated pattern matching for domain-specific logic
✅ Serialization Patterns: Converting objects to and from JSON and CSV
✅ Advanced Patterns: Connection pools, registries, and configuration management
Companion objects provide a clean, idiomatic way to organize related functionality and create elegant APIs that feel natural to use.
What's Next
In the next lesson, we'll explore "Case Classes: Data Made Simple." You'll learn about Scala's case classes, which automatically provide many of the patterns we've been implementing manually, including factory methods, pattern matching, and immutable data structures.
This will show you how Scala's language features can dramatically reduce boilerplate while providing powerful functionality.
Ready to discover case classes? Let's continue!