Advanced Concurrency Patterns: Actors, STM, and Parallel Collections
Modern applications require sophisticated concurrency patterns to achieve high performance and scalability. This comprehensive lesson explores advanced concurrency techniques in Scala — actor systems, Software Transactional Memory (STM), parallel collections, and lock-free programming — for building robust concurrent applications. The examples assume Akka Typed, ScalaSTM, and (on Scala 2.13+) the separate scala-parallel-collections module are on the classpath.
Advanced Actor Patterns and Coordination
Hierarchical Actor Systems and Supervision
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.typed.{ActorRef, ActorSystem, Behavior, SupervisorStrategy}
import akka.actor.typed.scaladsl.adapter._
import scala.concurrent.duration._
import scala.util.{Failure, Success}
import java.time.Instant
import java.util.UUID
// Advanced actor system with hierarchical supervision
object WorkflowManagement {
  // ---------------------------------------------------------------------------
  // Coordinator protocol
  // ---------------------------------------------------------------------------

  /** Messages accepted by [[WorkflowCoordinator]]. */
  sealed trait WorkflowMessage

  /** Starts `steps` under `workflowId`; the outcome is reported to `replyTo`. */
  case class StartWorkflow(
    workflowId: String,
    steps: List[WorkflowStep],
    replyTo: ActorRef[WorkflowResponse]
  ) extends WorkflowMessage

  /** Part of the protocol; the coordinator currently leaves it unhandled. */
  case class StepCompleted(
    workflowId: String,
    stepId: String,
    result: StepResult,
    replyTo: ActorRef[WorkflowResponse]
  ) extends WorkflowMessage

  /** Part of the protocol; the coordinator currently leaves it unhandled. */
  case class WorkflowFailed(
    workflowId: String,
    error: String,
    replyTo: ActorRef[WorkflowResponse]
  ) extends WorkflowMessage

  /** Asks for a [[WorkflowStatus]] snapshot of `workflowId`. */
  case class GetWorkflowStatus(
    workflowId: String,
    replyTo: ActorRef[WorkflowStatusResponse]
  ) extends WorkflowMessage

  // Responses delivered to the requester of StartWorkflow
  sealed trait WorkflowResponse
  case class WorkflowStarted(workflowId: String) extends WorkflowResponse
  case class WorkflowCompleted(workflowId: String, results: Map[String, StepResult]) extends WorkflowResponse
  case class WorkflowError(workflowId: String, error: String) extends WorkflowResponse

  sealed trait WorkflowStatusResponse

  /** Snapshot of a workflow; `endTime` is set once it has reached Completed or Failed. */
  case class WorkflowStatus(
    workflowId: String,
    state: WorkflowState,
    completedSteps: List[String],
    totalSteps: Int,
    startTime: Instant,
    endTime: Option[Instant]
  ) extends WorkflowStatusResponse

  // ---------------------------------------------------------------------------
  // Domain types
  // ---------------------------------------------------------------------------

  /** One unit of work; it starts once every id in `dependencies` has finished.
    *
    * NOTE(review): `timeout` is not enforced by StepExecutor below.
    */
  case class WorkflowStep(
    id: String,
    name: String,
    stepType: StepType,
    dependencies: List[String] = List.empty,
    timeout: FiniteDuration = 30.seconds,
    retryConfig: RetryConfig = RetryConfig.default
  )

  sealed trait StepType
  case class DataProcessing(inputSource: String, outputTarget: String) extends StepType
  case class ExternalAPICall(endpoint: String, payload: String) extends StepType
  case class DatabaseOperation(query: String, parameters: Map[String, Any]) extends StepType
  case class FileOperation(operation: String, path: String) extends StepType

  /** Outcome of one step; `attempts` counts executions including the final one. */
  case class StepResult(
    stepId: String,
    success: Boolean,
    data: Map[String, Any] = Map.empty,
    error: Option[String] = None,
    executionTime: FiniteDuration,
    attempts: Int = 1
  )

  /** Retry policy for a step.
    *
    * @param maxRetries     StepExecutor makes at most this many attempts in total; it is also
    *                       used as the supervisor restart limit for the step actor
    * @param retryCondition decides whether a thrown exception is worth another attempt
    */
  case class RetryConfig(
    maxRetries: Int = 3,
    backoffStrategy: BackoffStrategy = BackoffStrategy.Exponential(1.second, 2.0),
    retryCondition: Throwable => Boolean = _ => true
  )

  object RetryConfig {
    val default: RetryConfig = RetryConfig()
  }

  sealed trait BackoffStrategy
  object BackoffStrategy {
    case class Fixed(delay: FiniteDuration) extends BackoffStrategy
    case class Exponential(initialDelay: FiniteDuration, multiplier: Double) extends BackoffStrategy
    case class Linear(baseDelay: FiniteDuration, increment: FiniteDuration) extends BackoffStrategy
  }

  sealed trait WorkflowState
  case object Pending extends WorkflowState
  case object Running extends WorkflowState
  case object Completed extends WorkflowState
  case object Failed extends WorkflowState
  case object Cancelled extends WorkflowState

  /** Builds a child actor name from an arbitrary id.
    *
    * URL-encoding leaves letters, digits, '.', '-', '*' and '_' unchanged and escapes
    * everything else. Ids containing spaces or '/' therefore no longer make `spawn` throw.
    */
  private def actorName(prefix: String, id: String): String =
    s"$prefix-${java.net.URLEncoder.encode(id, "UTF-8")}"

  // ---------------------------------------------------------------------------
  // Workflow Coordinator - top-level supervisor
  // ---------------------------------------------------------------------------

  /** Spawns one supervised [[WorkflowExecutor]] child per workflow and routes status queries. */
  object WorkflowCoordinator {
    def apply(): Behavior[WorkflowMessage] = {
      Behaviors.setup { context =>
        context.log.info("Workflow Coordinator started")
        // Only touched from this actor's message handler, so a local mutable map is safe.
        val workflowExecutors = scala.collection.mutable.Map[String, ActorRef[WorkflowExecutor.Protocol]]()

        Behaviors.receive { (context, message) =>
          message match {
            case StartWorkflow(workflowId, _, replyTo) if workflowExecutors.contains(workflowId) =>
              // Spawning a second child under the same name would throw and stop the coordinator.
              context.log.warn(s"Rejecting duplicate workflow: $workflowId")
              replyTo ! WorkflowError(workflowId, s"Workflow $workflowId already exists")
              Behaviors.same

            case StartWorkflow(workflowId, steps, replyTo) if steps.map(_.id).distinct.size != steps.size =>
              // Step ids become child actor names, so duplicates would make the executor's
              // setup throw repeatedly and the requester would never receive an answer.
              context.log.warn(s"Rejecting workflow with duplicate step ids: $workflowId")
              replyTo ! WorkflowError(workflowId, s"Workflow $workflowId has duplicate step ids")
              Behaviors.same

            case StartWorkflow(workflowId, steps, replyTo) =>
              context.log.info(s"Starting workflow: $workflowId")
              val executorBehavior = Behaviors.supervise(WorkflowExecutor(workflowId, steps, replyTo))
                .onFailure[Exception](SupervisorStrategy.restart.withLimit(3, 1.minute))
              val executor = context.spawn(executorBehavior, actorName("workflow-executor", workflowId))
              workflowExecutors(workflowId) = executor
              executor ! WorkflowExecutor.Execute()
              replyTo ! WorkflowStarted(workflowId)
              Behaviors.same

            case GetWorkflowStatus(workflowId, replyTo) =>
              workflowExecutors.get(workflowId) match {
                case Some(executor) =>
                  executor ! WorkflowExecutor.GetStatus(replyTo)
                case None =>
                  // Unknown workflow: answered as an already-failed, empty workflow.
                  replyTo ! WorkflowStatus(
                    workflowId = workflowId,
                    state = Failed,
                    completedSteps = List.empty,
                    totalSteps = 0,
                    startTime = Instant.now(),
                    endTime = Some(Instant.now())
                  )
              }
              Behaviors.same

            case _: StepCompleted | _: WorkflowFailed =>
              Behaviors.unhandled
          }
        }
      }
    }
  }

  // ---------------------------------------------------------------------------
  // Workflow Executor - manages individual workflow execution
  // ---------------------------------------------------------------------------

  /** Runs one workflow: dispatches steps as their dependencies finish and reports the outcome. */
  object WorkflowExecutor {
    sealed trait Protocol
    case class Execute() extends Protocol
    case class StepExecuted(stepId: String, result: StepResult) extends Protocol
    case class GetStatus(replyTo: ActorRef[WorkflowStatusResponse]) extends Protocol

    def apply(
      workflowId: String,
      steps: List[WorkflowStep],
      originalReplyTo: ActorRef[WorkflowResponse]
    ): Behavior[Protocol] = {
      Behaviors.setup { context =>
        val startTime = Instant.now()

        // One supervised child per step, keyed by step id.
        val stepActors: Map[String, ActorRef[StepExecutor.Protocol]] = steps.map { step =>
          val stepBehavior = Behaviors.supervise(StepExecutor(step))
            .onFailure[Exception](SupervisorStrategy.restart.withLimit(step.retryConfig.maxRetries, 1.minute))
          step.id -> context.spawn(stepBehavior, actorName("step", step.id))
        }.toMap

        def isTerminal(state: WorkflowState): Boolean =
          state == Completed || state == Failed || state == Cancelled

        // Steps whose dependencies have all finished and that are neither finished nor in flight.
        def readySteps(completed: Set[String], pending: Set[String]): List[WorkflowStep] =
          steps.filter { step =>
            step.dependencies.forall(completed.contains) &&
              !completed.contains(step.id) &&
              !pending.contains(step.id)
          }

        // Completes the workflow when every step is done; otherwise dispatches the newly
        // ready steps. If nothing is running and nothing can start (unknown or cyclic
        // dependencies), the workflow fails instead of staying Running forever.
        def advance(
          completed: Set[String],
          results: Map[String, StepResult],
          pending: Set[String],
          logLabel: String
        ): Behavior[Protocol] = {
          if (completed.size == steps.size) {
            originalReplyTo ! WorkflowCompleted(workflowId, results)
            active(Completed, completed, results, pending, Some(Instant.now()))
          } else {
            val ready = readySteps(completed, pending)
            if (ready.isEmpty && pending.isEmpty) {
              val blocked = steps.map(_.id).filterNot(completed.contains)
              val error = s"Workflow stalled: steps ${blocked.mkString(", ")} have unsatisfiable dependencies"
              context.log.error(s"$error (workflow $workflowId)")
              originalReplyTo ! WorkflowError(workflowId, error)
              active(Failed, completed, results, pending, Some(Instant.now()))
            } else {
              ready.foreach { step =>
                context.log.info(s"$logLabel: ${step.id}")
                stepActors.get(step.id).foreach(_ ! StepExecutor.ExecuteStep(results, context.self))
              }
              active(Running, completed, results, pending ++ ready.map(_.id), None)
            }
          }
        }

        def active(
          state: WorkflowState,
          completedSteps: Set[String],
          stepResults: Map[String, StepResult],
          pendingSteps: Set[String],
          endTime: Option[Instant]
        ): Behavior[Protocol] = {
          Behaviors.receiveMessage[Protocol] {
            case Execute() =>
              if (isTerminal(state)) Behaviors.same
              else {
                context.log.info(s"Executing workflow: $workflowId")
                advance(completedSteps, stepResults, pendingSteps, "Starting step")
              }

            case StepExecuted(stepId, _) if isTerminal(state) =>
              // The outcome was already reported; late results must not dispatch more steps
              // or send a WorkflowCompleted after a WorkflowError.
              context.log.warn(s"Ignoring result of step $stepId: workflow $workflowId is already $state")
              Behaviors.same

            case StepExecuted(stepId, result) =>
              context.log.info(s"Step completed: $stepId, success: ${result.success}")
              val newCompletedSteps = completedSteps + stepId
              val newStepResults = stepResults + (stepId -> result)
              val newPendingSteps = pendingSteps - stepId
              if (!result.success) {
                originalReplyTo ! WorkflowError(workflowId, result.error.getOrElse("Step failed"))
                active(Failed, newCompletedSteps, newStepResults, newPendingSteps, Some(Instant.now()))
              } else {
                advance(newCompletedSteps, newStepResults, newPendingSteps, "Starting dependent step")
              }

            case GetStatus(replyTo) =>
              replyTo ! WorkflowStatus(
                workflowId = workflowId,
                state = state,
                completedSteps = completedSteps.toList,
                totalSteps = steps.size,
                startTime = startTime,
                endTime = endTime
              )
              Behaviors.same
          }
        }

        active(Pending, Set.empty, Map.empty, Set.empty, None)
      }
    }
  }

  // ---------------------------------------------------------------------------
  // Step Executor - executes individual workflow steps
  // ---------------------------------------------------------------------------

  /** Executes a single step, retrying thrown exceptions according to its [[RetryConfig]]. */
  object StepExecutor {
    sealed trait Protocol
    case class ExecuteStep(
      previousResults: Map[String, StepResult],
      replyTo: ActorRef[WorkflowExecutor.Protocol]
    ) extends Protocol

    // Internal: a scheduled retry that carries the attempt number. Re-sending ExecuteStep
    // would restart the count at 1, so maxRetries would never be reached.
    private final case class RetryStep(
      attempt: Int,
      previousResults: Map[String, StepResult],
      replyTo: ActorRef[WorkflowExecutor.Protocol]
    ) extends Protocol

    def apply(step: WorkflowStep): Behavior[Protocol] = {
      Behaviors.setup { context =>
        // Runs attempt number `attempt`. It either replies with a StepResult or schedules
        // the next attempt after the configured backoff.
        def executeWithRetry(
          attempt: Int,
          previousResults: Map[String, StepResult],
          replyTo: ActorRef[WorkflowExecutor.Protocol]
        ): Unit = {
          val startTime = System.nanoTime()
          try {
            val success = executeStepLogic(step, previousResults, context)
            val executionTime = FiniteDuration(System.nanoTime() - startTime, java.util.concurrent.TimeUnit.NANOSECONDS)
            val result = StepResult(
              stepId = step.id,
              success = success,
              data = Map("stepType" -> step.stepType.toString),
              error = None,
              executionTime = executionTime,
              attempts = attempt
            )
            replyTo ! WorkflowExecutor.StepExecuted(step.id, result)
          } catch {
            case exception: Exception =>
              val executionTime = FiniteDuration(System.nanoTime() - startTime, java.util.concurrent.TimeUnit.NANOSECONDS)
              if (attempt < step.retryConfig.maxRetries && step.retryConfig.retryCondition(exception)) {
                context.log.warn(s"Step ${step.id} failed, retrying (attempt $attempt): ${exception.getMessage}")
                val delay = calculateRetryDelay(step.retryConfig.backoffStrategy, attempt)
                context.scheduleOnce(delay, context.self, RetryStep(attempt + 1, previousResults, replyTo))
              } else {
                context.log.error(s"Step ${step.id} failed after $attempt attempts: ${exception.getMessage}")
                val result = StepResult(
                  stepId = step.id,
                  success = false,
                  error = Some(exception.getMessage),
                  executionTime = executionTime,
                  attempts = attempt
                )
                replyTo ! WorkflowExecutor.StepExecuted(step.id, result)
              }
          }
        }

        Behaviors.receiveMessage[Protocol] {
          case ExecuteStep(previousResults, replyTo) =>
            executeWithRetry(1, previousResults, replyTo)
            Behaviors.same
          case RetryStep(attempt, previousResults, replyTo) =>
            executeWithRetry(attempt, previousResults, replyTo)
            Behaviors.same
        }
      }
    }

    // Simulated step work. Thread.sleep blocks the actor's thread; real I/O should be
    // offloaded (e.g. to a Future on a dedicated dispatcher) instead.
    private def executeStepLogic(
      step: WorkflowStep,
      previousResults: Map[String, StepResult],
      context: ActorContext[Protocol]
    ): Boolean = {
      context.log.info(s"Executing step: ${step.id} of type: ${step.stepType}")
      step.stepType match {
        case DataProcessing(inputSource, outputTarget) =>
          // Simulate data processing
          Thread.sleep(scala.util.Random.nextInt(1000) + 500)
          true
        case ExternalAPICall(endpoint, payload) =>
          // Simulate API call
          Thread.sleep(scala.util.Random.nextInt(2000) + 1000)
          scala.util.Random.nextDouble() > 0.1 // 90% success rate
        case DatabaseOperation(query, parameters) =>
          // Simulate database operation
          Thread.sleep(scala.util.Random.nextInt(800) + 200)
          true
        case FileOperation(operation, path) =>
          // Simulate file operation
          Thread.sleep(scala.util.Random.nextInt(500) + 100)
          true
      }
    }

    // Delay before the retry that follows failed attempt number `attempt` (1-based).
    private def calculateRetryDelay(strategy: BackoffStrategy, attempt: Int): FiniteDuration = {
      strategy match {
        case BackoffStrategy.Fixed(delay) => delay
        case BackoffStrategy.Exponential(initialDelay, multiplier) =>
          FiniteDuration((initialDelay.toMillis * math.pow(multiplier, attempt - 1)).toLong, java.util.concurrent.TimeUnit.MILLISECONDS)
        case BackoffStrategy.Linear(baseDelay, increment) =>
          baseDelay + (increment * (attempt - 1))
      }
    }
  }
}
// Actor-based concurrent cache with TTL and eviction policies
object ConcurrentCache {
  /** Messages accepted by the cache actor. */
  sealed trait CacheMessage[K, V]
  case class Put[K, V](key: K, value: V, ttl: Option[FiniteDuration] = None) extends CacheMessage[K, V]
  case class Get[K, V](key: K, replyTo: ActorRef[Option[V]]) extends CacheMessage[K, V]
  case class Remove[K, V](key: K) extends CacheMessage[K, V]
  case class Clear[K, V]() extends CacheMessage[K, V]
  case class GetStats[K, V](replyTo: ActorRef[CacheStats]) extends CacheMessage[K, V]
  case class EvictExpired[K, V]() extends CacheMessage[K, V]

  /** A stored value plus the metadata the eviction policies rank on. */
  case class CacheEntry[V](
    value: V,
    createdAt: Instant,
    expiresAt: Option[Instant],
    lastAccessed: Instant,
    accessCount: Long
  ) {
    def isExpired: Boolean = expiresAt.exists(_.isBefore(Instant.now()))
    def touch(): CacheEntry[V] = copy(lastAccessed = Instant.now(), accessCount = accessCount + 1)
  }

  case class CacheStats(
    size: Int,
    hitRate: Double,
    missCount: Long,
    hitCount: Long,
    evictionCount: Long
  )

  sealed trait EvictionPolicy
  case object LRU extends EvictionPolicy
  case object LFU extends EvictionPolicy
  case object FIFO extends EvictionPolicy

  // Period of the background sweep that drops expired entries.
  private val CleanupInterval: FiniteDuration = 1.minute
  private case object CleanupTimerKey

  /** Creates a size-bounded cache actor with optional per-entry TTL.
    *
    * @param maxSize        maximum number of entries; must be positive
    * @param defaultTTL     TTL used when a `Put` carries none; `None` means no expiry
    * @param evictionPolicy how a live entry is chosen for eviction when the cache is full
    * @throws IllegalArgumentException if `maxSize <= 0`
    */
  def apply[K, V](
    maxSize: Int = 1000,
    defaultTTL: Option[FiniteDuration] = None,
    evictionPolicy: EvictionPolicy = LRU
  ): Behavior[CacheMessage[K, V]] = {
    // Previously a non-positive size crashed the actor on the first Put (evict from empty).
    require(maxSize > 0, s"maxSize must be positive, got $maxSize")

    Behaviors.withTimers[CacheMessage[K, V]] { timers =>
      // One keyed fixed-delay timer. A client-sent EvictExpired just runs an extra sweep,
      // instead of starting another self-rescheduling cleanup chain.
      timers.startTimerWithFixedDelay(CleanupTimerKey, EvictExpired[K, V](), CleanupInterval)

      def active(
        cache: Map[K, CacheEntry[V]],
        hitCount: Long,
        missCount: Long,
        evictionCount: Long
      ): Behavior[CacheMessage[K, V]] = {
        Behaviors.receive { (context, message) =>
          message match {
            case Put(key, value, ttl) =>
              val now = Instant.now()
              // Explicit lambda: `now.plusMillis(_.toMillis)` expands to
              // `now.plusMillis(d => d.toMillis)`, which does not type-check.
              val expiresAt = ttl.orElse(defaultTTL).map(d => now.plusMillis(d.toMillis))
              val entry = CacheEntry(
                value = value,
                createdAt = now,
                expiresAt = expiresAt,
                lastAccessed = now,
                accessCount = 1
              )
              if (cache.contains(key) || cache.size < maxSize) {
                active(cache + (key -> entry), hitCount, missCount, evictionCount)
              } else {
                // At capacity: reclaim expired entries first. Evict a live entry by policy
                // only if that did not free a slot.
                val live = cache.filterNot { case (_, e) => e.isExpired }
                val expiredCount = cache.size - live.size
                if (live.size < maxSize) {
                  active(live + (key -> entry), hitCount, missCount, evictionCount + expiredCount)
                } else {
                  val (evictedCache, evictedKey) = evictEntry(live, evictionPolicy)
                  context.log.debug(s"Evicted key: $evictedKey due to cache size limit")
                  active(evictedCache + (key -> entry), hitCount, missCount, evictionCount + expiredCount + 1)
                }
              }

            case Get(key, replyTo) =>
              cache.get(key) match {
                case Some(entry) if !entry.isExpired =>
                  replyTo ! Some(entry.value)
                  active(cache + (key -> entry.touch()), hitCount + 1, missCount, evictionCount)
                case Some(_) =>
                  // Expired: drop it lazily and report a miss. A plain `Some(_)` here means
                  // no second clock read can fall through to a MatchError.
                  replyTo ! None
                  active(cache - key, hitCount, missCount + 1, evictionCount)
                case None =>
                  replyTo ! None
                  active(cache, hitCount, missCount + 1, evictionCount)
              }

            case Remove(key) =>
              active(cache - key, hitCount, missCount, evictionCount)

            case Clear() =>
              active(Map.empty, hitCount, missCount, evictionCount)

            case GetStats(replyTo) =>
              val totalRequests = hitCount + missCount
              val hitRate = if (totalRequests > 0) hitCount.toDouble / totalRequests else 0.0
              replyTo ! CacheStats(
                size = cache.size,
                hitRate = hitRate,
                missCount = missCount,
                hitCount = hitCount,
                evictionCount = evictionCount
              )
              Behaviors.same

            case EvictExpired() =>
              val (expired, valid) = cache.partition { case (_, entry) => entry.isExpired }
              if (expired.nonEmpty) {
                context.log.debug(s"Evicted ${expired.size} expired entries")
              }
              active(valid, hitCount, missCount, evictionCount + expired.size)
          }
        }
      }

      active(Map.empty, 0, 0, 0)
    }
  }

  // Removes one entry chosen by `policy`: oldest access (LRU), fewest accesses (LFU) or
  // oldest insertion (FIFO). Returns the remaining map and the evicted key.
  private def evictEntry[K, V](
    cache: Map[K, CacheEntry[V]],
    policy: EvictionPolicy
  ): (Map[K, CacheEntry[V]], K) = {
    if (cache.isEmpty) {
      throw new IllegalStateException("Cannot evict from empty cache")
    }
    val keyToEvict = policy match {
      case LRU =>
        cache.minBy { case (_, entry) => entry.lastAccessed }._1
      case LFU =>
        cache.minBy { case (_, entry) => entry.accessCount }._1
      case FIFO =>
        cache.minBy { case (_, entry) => entry.createdAt }._1
    }
    (cache - keyToEvict, keyToEvict)
  }
}
Software Transactional Memory (STM) Patterns
// STM implementation for concurrent data structures
import scala.concurrent.stm._
import scala.concurrent.{Future, ExecutionContext}
import java.util.concurrent.atomic.AtomicReference
import scala.annotation.tailrec
// STM-based concurrent data structures
object STMCollections {

  /** Concurrent map whose whole contents live in a single STM `Ref`.
    *
    * Functions passed to `computeIfAbsent`, `merge` and `atomicUpdate` run inside a
    * transaction and may be re-executed if it retries, so they should be side-effect free.
    */
  class STMMap[K, V] {
    private val data = Ref(Map.empty[K, V])

    def put(key: K, value: V): Unit = {
      atomic { implicit txn =>
        data() = data() + (key -> value)
      }
    }

    def get(key: K): Option[V] = {
      atomic { implicit txn =>
        data().get(key)
      }
    }

    /** Removes `key`, returning the value it had, if any. */
    def remove(key: K): Option[V] = {
      atomic { implicit txn =>
        val current = data()
        val result = current.get(key)
        data() = current - key
        result
      }
    }

    def size: Int = {
      atomic { implicit txn =>
        data().size
      }
    }

    def contains(key: K): Boolean = {
      atomic { implicit txn =>
        data().contains(key)
      }
    }

    /** Inserts only when absent; returns the existing value, or `None` if the insert happened. */
    def putIfAbsent(key: K, value: V): Option[V] = {
      atomic { implicit txn =>
        val current = data()
        current.get(key) match {
          case Some(existingValue) => Some(existingValue)
          case None =>
            data() = current + (key -> value)
            None
        }
      }
    }

    /** Replaces the value only if it currently equals `oldValue`; returns whether it did. */
    def replace(key: K, oldValue: V, newValue: V): Boolean = {
      atomic { implicit txn =>
        val current = data()
        current.get(key) match {
          case Some(value) if value == oldValue =>
            data() = current + (key -> newValue)
            true
          case _ => false
        }
      }
    }

    def computeIfAbsent(key: K, mappingFunction: K => V): V = {
      atomic { implicit txn =>
        val current = data()
        current.get(key) match {
          case Some(value) => value
          case None =>
            val newValue = mappingFunction(key)
            data() = current + (key -> newValue)
            newValue
        }
      }
    }

    /** Stores `value`, or `remappingFunction(existing, value)` if a value exists; returns the stored value. */
    def merge(key: K, value: V, remappingFunction: (V, V) => V): V = {
      atomic { implicit txn =>
        val current = data()
        val newValue = current.get(key) match {
          case Some(existingValue) => remappingFunction(existingValue, value)
          case None => value
        }
        data() = current + (key -> newValue)
        newValue
      }
    }

    /** Replaces the whole map with `f(current)._1` atomically and returns `f(current)._2`. */
    def atomicUpdate[R](f: Map[K, V] => (Map[K, V], R)): R = {
      atomic { implicit txn =>
        val current = data()
        val (newMap, result) = f(current)
        data() = newMap
        result
      }
    }

    def snapshot: Map[K, V] = {
      atomic { implicit txn =>
        data()
      }
    }
  }

  /** FIFO queue built from two immutable lists: items are prepended to `rear` and popped
    * from `front`; `rear` is reversed into `front` when `front` runs empty.
    */
  class STMQueue[T] {
    private val front = Ref(List.empty[T])
    private val rear = Ref(List.empty[T])
    private val sizeRef = Ref(0)

    def enqueue(item: T): Unit = {
      atomic { implicit txn =>
        rear() = item :: rear()
        sizeRef() = sizeRef() + 1
      }
    }

    def dequeue(): Option[T] = {
      atomic { implicit txn =>
        (front(), rear()) match {
          case (Nil, Nil) => None
          case (x :: xs, _) =>
            front() = xs
            sizeRef() = sizeRef() - 1
            Some(x)
          case (Nil, _) =>
            val reversed = rear().reverse
            front() = reversed.tail
            rear() = Nil
            sizeRef() = sizeRef() - 1
            Some(reversed.head)
        }
      }
    }

    def peek: Option[T] = {
      atomic { implicit txn =>
        (front(), rear()) match {
          case (Nil, Nil) => None
          case (x :: _, _) => Some(x)
          case (Nil, _) => rear().reverse.headOption
        }
      }
    }

    def size: Int = sizeRef.single.get

    def isEmpty: Boolean = size == 0

    def clear(): Unit = {
      atomic { implicit txn =>
        front() = Nil
        rear() = Nil
        sizeRef() = 0
      }
    }
  }

  /** Binary min-heap (per `ordering`) stored as an immutable `Vector` inside one `Ref`. */
  class STMPriorityQueue[T](implicit ordering: Ordering[T]) {
    private val heap = Ref(Vector.empty[T])

    def enqueue(item: T): Unit = {
      atomic { implicit txn =>
        val current = heap()
        heap() = insertIntoHeap(current, item)
      }
    }

    /** Removes and returns the smallest element, or `None` when empty. */
    def dequeue(): Option[T] = {
      atomic { implicit txn =>
        val current = heap()
        if (current.isEmpty) {
          None
        } else {
          val result = current.head
          heap() = removeFromHeap(current)
          Some(result)
        }
      }
    }

    def peek: Option[T] = {
      atomic { implicit txn =>
        heap().headOption
      }
    }

    def size: Int = {
      atomic { implicit txn =>
        heap().size
      }
    }

    def isEmpty: Boolean = size == 0

    private def insertIntoHeap(heap: Vector[T], item: T): Vector[T] = {
      val newHeap = heap :+ item
      bubbleUp(newHeap, newHeap.length - 1)
    }

    // Moves the last element into the root slot, then restores the heap property.
    private def removeFromHeap(heap: Vector[T]): Vector[T] = {
      if (heap.length <= 1) {
        Vector.empty
      } else {
        val newHeap = heap.updated(0, heap.last).init
        bubbleDown(newHeap, 0)
      }
    }

    @tailrec
    private def bubbleUp(heap: Vector[T], index: Int): Vector[T] = {
      if (index == 0) heap
      else {
        val parentIndex = (index - 1) / 2
        if (ordering.lt(heap(index), heap(parentIndex))) {
          val swapped = heap.updated(index, heap(parentIndex)).updated(parentIndex, heap(index))
          bubbleUp(swapped, parentIndex)
        } else {
          heap
        }
      }
    }

    @tailrec
    private def bubbleDown(heap: Vector[T], index: Int): Vector[T] = {
      val leftChild = 2 * index + 1
      val rightChild = 2 * index + 2
      if (leftChild >= heap.length) {
        heap
      } else {
        val minChildIndex = if (rightChild >= heap.length || ordering.lt(heap(leftChild), heap(rightChild))) {
          leftChild
        } else {
          rightChild
        }
        if (ordering.lt(heap(minChildIndex), heap(index))) {
          val swapped = heap.updated(index, heap(minChildIndex)).updated(minChildIndex, heap(index))
          bubbleDown(swapped, minChildIndex)
        } else {
          heap
        }
      }
    }
  }

  /** Directed graph with edge labels of type `E`; vertices and adjacency live in STM refs. */
  class STMGraph[V, E] {
    private val vertices = Ref(Set.empty[V])
    private val edges = Ref(Map.empty[V, Map[V, E]])

    /** Adds `vertex`; returns false if it was already present. */
    def addVertex(vertex: V): Boolean = {
      atomic { implicit txn =>
        if (vertices().contains(vertex)) {
          false
        } else {
          vertices() = vertices() + vertex
          edges() = edges() + (vertex -> Map.empty[V, E])
          true
        }
      }
    }

    /** Removes `vertex` with its outgoing and incoming edges; returns false if absent. */
    def removeVertex(vertex: V): Boolean = {
      atomic { implicit txn =>
        if (!vertices().contains(vertex)) {
          false
        } else {
          vertices() = vertices() - vertex
          edges() = edges() - vertex
          // Remove edges pointing to this vertex
          edges() = edges().map { case (from, toMap) =>
            from -> (toMap - vertex)
          }
          true
        }
      }
    }

    /** Adds or overwrites the edge `from -> to`; returns false if either endpoint is missing. */
    def addEdge(from: V, to: V, edge: E): Boolean = {
      atomic { implicit txn =>
        if (!vertices().contains(from) || !vertices().contains(to)) {
          false
        } else {
          val currentEdges = edges()
          val fromEdges = currentEdges.getOrElse(from, Map.empty)
          edges() = currentEdges + (from -> (fromEdges + (to -> edge)))
          true
        }
      }
    }

    def removeEdge(from: V, to: V): Boolean = {
      atomic { implicit txn =>
        val currentEdges = edges()
        currentEdges.get(from) match {
          case Some(fromEdges) if fromEdges.contains(to) =>
            edges() = currentEdges + (from -> (fromEdges - to))
            true
          case _ => false
        }
      }
    }

    def getEdge(from: V, to: V): Option[E] = {
      atomic { implicit txn =>
        edges().get(from).flatMap(_.get(to))
      }
    }

    /** Targets of the outgoing edges of `vertex`. */
    def getNeighbors(vertex: V): Set[V] = {
      atomic { implicit txn =>
        edges().get(vertex).map(_.keySet).getOrElse(Set.empty)
      }
    }

    def getVertices: Set[V] = {
      atomic { implicit txn =>
        vertices()
      }
    }

    /** Fewest-edges path from `from` to `to` (both included), computed on a consistent
      * snapshot; `None` if either vertex is missing or `to` is unreachable.
      */
    def shortestPath(from: V, to: V): Option[List[V]] = {
      atomic { implicit txn =>
        if (!vertices().contains(from) || !vertices().contains(to)) {
          None
        } else {
          bfs(from, to)
        }
      }
    }

    // Breadth-first search. Each discovered vertex records the vertex it was reached from,
    // and the path is rebuilt once at the end. This avoids copying a List (O(path length))
    // for every enqueued vertex and avoids `return`.
    private def bfs(start: V, target: V)(implicit txn: InTxn): Option[List[V]] = {
      val adjacency = edges()

      @tailrec
      def pathTo(vertex: V, parents: Map[V, V], acc: List[V]): List[V] =
        parents.get(vertex) match {
          case Some(parent) => pathTo(parent, parents, vertex :: acc)
          case None => vertex :: acc // reached `start`, which has no parent
        }

      @tailrec
      def search(
        frontier: scala.collection.immutable.Queue[V],
        visited: Set[V],
        parents: Map[V, V]
      ): Option[List[V]] =
        frontier.dequeueOption match {
          case None => None
          case Some((current, _)) if current == target =>
            Some(pathTo(current, parents, Nil))
          case Some((current, rest)) =>
            val fresh = adjacency.getOrElse(current, Map.empty[V, E]).keys.filterNot(visited.contains).toList
            search(rest ++ fresh, visited ++ fresh, parents ++ fresh.map(_ -> current))
        }

      search(scala.collection.immutable.Queue(start), Set(start), Map.empty)
    }
  }
}
// STM-based bank account system for demonstration
object BankingSTM {

  /** An account whose balance and history are STM refs, read and written transactionally.
    *
    * NOTE(review): the refs are public, so a holder of an `Account` (e.g. from
    * `Bank.getAccount`) can change the balance without the Bank's validation.
    */
  case class Account(id: String, balance: Ref[BigDecimal], transactions: Ref[List[Transaction]])

  /** Ledger entry. For `Transfer`, the debit side carries a negative `amount`. */
  case class Transaction(
    id: String,
    accountId: String,
    transactionType: TransactionType,
    amount: BigDecimal,
    timestamp: Instant,
    description: String
  )

  sealed trait TransactionType
  case object Deposit extends TransactionType
  case object Withdrawal extends TransactionType
  case object Transfer extends TransactionType

  /** In-memory bank in which every operation runs as a single STM transaction. */
  class Bank {
    private val accounts = Ref(Map.empty[String, Account])

    // Builds a ledger entry. It is called inside `atomic`, so a retried transaction simply
    // draws a fresh id and timestamp.
    private def newTransaction(
      accountId: String,
      transactionType: TransactionType,
      amount: BigDecimal,
      description: String
    ): Transaction =
      Transaction(
        id = java.util.UUID.randomUUID().toString,
        accountId = accountId,
        transactionType = transactionType,
        amount = amount,
        timestamp = Instant.now(),
        description = description
      )

    /** Opens an account.
      *
      * @return `Left` if `initialBalance` is negative or the id is already taken
      */
    def createAccount(accountId: String, initialBalance: BigDecimal = BigDecimal(0)): Either[String, Account] = {
      if (initialBalance < 0) {
        // Every other operation keeps balances non-negative; do not start below zero.
        Left("Initial balance cannot be negative")
      } else {
        atomic { implicit txn =>
          if (accounts().contains(accountId)) {
            Left(s"Account $accountId already exists")
          } else {
            val account = Account(
              id = accountId,
              balance = Ref(initialBalance),
              transactions = Ref(List.empty)
            )
            accounts() = accounts() + (accountId -> account)
            Right(account)
          }
        }
      }
    }

    def getAccount(accountId: String): Option[Account] = {
      atomic { implicit txn =>
        accounts().get(accountId)
      }
    }

    /** Credits a positive `amount` and records a `Deposit` entry. */
    def deposit(accountId: String, amount: BigDecimal, description: String = "Deposit"): Either[String, Transaction] = {
      if (amount <= 0) {
        Left("Deposit amount must be positive")
      } else {
        atomic { implicit txn =>
          accounts().get(accountId) match {
            case Some(account) =>
              val transaction = newTransaction(accountId, Deposit, amount, description)
              account.balance() = account.balance() + amount
              account.transactions() = transaction :: account.transactions()
              Right(transaction)
            case None =>
              Left(s"Account $accountId not found")
          }
        }
      }
    }

    /** Debits a positive `amount` if funds suffice and records a `Withdrawal` entry. */
    def withdraw(accountId: String, amount: BigDecimal, description: String = "Withdrawal"): Either[String, Transaction] = {
      if (amount <= 0) {
        Left("Withdrawal amount must be positive")
      } else {
        atomic { implicit txn =>
          accounts().get(accountId) match {
            case Some(account) =>
              if (account.balance() < amount) {
                Left("Insufficient funds")
              } else {
                val transaction = newTransaction(accountId, Withdrawal, amount, description)
                account.balance() = account.balance() - amount
                account.transactions() = transaction :: account.transactions()
                Right(transaction)
              }
            case None =>
              Left(s"Account $accountId not found")
          }
        }
      }
    }

    /** Moves `amount` between two distinct accounts in one transaction.
      *
      * @return the (debit, credit) ledger entries on success
      */
    def transfer(fromAccountId: String, toAccountId: String, amount: BigDecimal, description: String = "Transfer"): Either[String, (Transaction, Transaction)] = {
      if (amount <= 0) {
        Left("Transfer amount must be positive")
      } else if (fromAccountId == toAccountId) {
        Left("Cannot transfer to the same account")
      } else {
        atomic { implicit txn =>
          (accounts().get(fromAccountId), accounts().get(toAccountId)) match {
            case (Some(fromAccount), Some(toAccount)) =>
              if (fromAccount.balance() < amount) {
                Left("Insufficient funds")
              } else {
                val withdrawalTransaction =
                  newTransaction(fromAccountId, Transfer, -amount, s"Transfer to $toAccountId: $description")
                val depositTransaction =
                  newTransaction(toAccountId, Transfer, amount, s"Transfer from $fromAccountId: $description")
                fromAccount.balance() = fromAccount.balance() - amount
                toAccount.balance() = toAccount.balance() + amount
                fromAccount.transactions() = withdrawalTransaction :: fromAccount.transactions()
                toAccount.transactions() = depositTransaction :: toAccount.transactions()
                Right((withdrawalTransaction, depositTransaction))
              }
            case (None, _) =>
              Left(s"Source account $fromAccountId not found")
            case (_, None) =>
              Left(s"Destination account $toAccountId not found")
          }
        }
      }
    }

    def getBalance(accountId: String): Either[String, BigDecimal] = {
      atomic { implicit txn =>
        accounts().get(accountId) match {
          case Some(account) => Right(account.balance())
          case None => Left(s"Account $accountId not found")
        }
      }
    }

    /** Ledger entries oldest-first (they are stored newest-first). */
    def getTransactionHistory(accountId: String): Either[String, List[Transaction]] = {
      atomic { implicit txn =>
        accounts().get(accountId) match {
          case Some(account) => Right(account.transactions().reverse)
          case None => Left(s"Account $accountId not found")
        }
      }
    }

    /** Sum of all balances, read in one consistent transaction. */
    def getTotalBalance: BigDecimal = {
      atomic { implicit txn =>
        accounts().values.map(_.balance()).sum
      }
    }

    def getAllAccounts: Map[String, BigDecimal] = {
      atomic { implicit txn =>
        accounts().map { case (id, account) =>
          id -> account.balance()
        }
      }
    }
  }
}
Parallel Collections and Lock-Free Programming
Advanced Parallel Collection Patterns
import scala.collection.parallel.CollectionConverters._
import scala.collection.parallel.ParSeq
import java.util.concurrent.{ForkJoinPool, ThreadLocalRandom}
import java.util.concurrent.atomic.{AtomicLong, AtomicReference, AtomicBoolean}
import scala.annotation.tailrec
// Advanced parallel collection operations
object ParallelCollectionPatterns {

  /** Maps every element in parallel on a dedicated pool of `parallelism` threads, then folds.
    *
    * `reduceFunction` should be associative and `zeroValue` neutral for it: a parallel
    * `fold` may apply the zero once per partition.
    *
    * NOTE(review): the folded value has type `U` but is returned through an unchecked
    * `asInstanceOf[R]`. Because of erasure that cast does nothing here, so a mismatched `R`
    * only fails (ClassCastException) where the caller uses the result. Call it with `R = U`.
    */
  def parallelMapReduce[T, U, R](
    collection: Seq[T],
    parallelism: Int = Runtime.getRuntime.availableProcessors()
  )(
    mapFunction: T => U,
    reduceFunction: (U, U) => U,
    zeroValue: U
  ): R = {
    val customPool = new ForkJoinPool(parallelism)
    try {
      val parCollection = collection.par
      parCollection.tasksupport = new scala.collection.parallel.ForkJoinTaskSupport(customPool)
      parCollection
        .map(mapFunction)
        .fold(zeroValue)(reduceFunction)
        .asInstanceOf[R]
    } finally {
      customPool.shutdown()
    }
  }

  /** Runs data through a list of untyped stages, each processed batch-parallel. */
  class ParallelDataPipeline[T] {
    /** A pipeline stage; `parallelism` sizes the pool its batches run on. */
    case class PipelineStage[A, B](
      name: String,
      processor: A => B,
      parallelism: Int = Runtime.getRuntime.availableProcessors(),
      batchSize: Int = 1000
    )

    /** Applies each stage in order to the whole data set, preserving element order.
      *
      * Within a stage, batches of `batchSize` elements run in parallel on a pool of
      * `parallelism` threads; elements inside a batch are mapped sequentially.
      */
    def processBatched[U](
      data: Seq[T],
      stages: List[PipelineStage[_, _]]
    ): Seq[Any] = {
      stages.foldLeft(data.asInstanceOf[Seq[Any]]) { (currentData, stage) =>
        val stageProcessor = stage.processor.asInstanceOf[Any => Any]
        val stagePool = new ForkJoinPool(stage.parallelism)
        try {
          val batches = currentData.grouped(stage.batchSize).toSeq.par
          batches.tasksupport = new scala.collection.parallel.ForkJoinTaskSupport(stagePool)
          // A nested `batch.par` is not IterableOnce in Scala 2.13, so `flatMap` would not
          // accept it. The batches already provide the parallelism.
          batches.flatMap(batch => batch.map(stageProcessor)).seq
        } finally {
          stagePool.shutdown()
        }
      }
    }

    /** Runs `processBatched` on consecutive chunks of `bufferSize` elements, lazily. */
    def processStreaming[U](
      data: Iterator[T],
      stages: List[PipelineStage[_, _]],
      bufferSize: Int = 10000
    ): Iterator[Any] = {
      data.grouped(bufferSize).flatMap { batch =>
        processBatched(batch.toSeq, stages)
      }
    }
  }

  // Parallel aggregation operations
  object ParallelAggregations {
    /** Grouped result; `processingTime` is in nanoseconds. */
    case class AggregationResult[K, V](
      groups: Map[K, V],
      totalProcessed: Long,
      processingTime: Long
    )

    /** Groups `data` by key in parallel and reduces each group's values with `combiner`. */
    def parallelGroupBy[T, K, V](
      data: Seq[T],
      keyExtractor: T => K,
      valueExtractor: T => V,
      combiner: (V, V) => V,
      parallelism: Int = Runtime.getRuntime.availableProcessors()
    ): AggregationResult[K, V] = {
      val startTime = System.nanoTime()
      val customPool = new ForkJoinPool(parallelism)
      try {
        val parData = data.par
        parData.tasksupport = new scala.collection.parallel.ForkJoinTaskSupport(customPool)
        val groups = parData
          .groupBy(keyExtractor)
          .map { case (key, values) =>
            // groupBy never yields an empty group, so `reduce` is safe here.
            key -> values.map(valueExtractor).reduce(combiner)
          }
          .seq
        val endTime = System.nanoTime()
        AggregationResult(
          groups = groups,
          totalProcessed = data.length,
          processingTime = endTime - startTime
        )
      } finally {
        customPool.shutdown()
      }
    }

    /** Descriptive statistics (population variance) of `extractor` over `data`.
      *
      * @throws IllegalArgumentException if `data` is empty
      */
    def parallelStatistics[T](
      data: Seq[T],
      extractor: T => Double,
      parallelism: Int = Runtime.getRuntime.availableProcessors()
    ): StatisticsResult = {
      // Checked before the pool is created; empty input used to fail deep inside with an
      // IndexOutOfBoundsException from the median lookup.
      require(data.nonEmpty, "parallelStatistics requires at least one element")
      val customPool = new ForkJoinPool(parallelism)
      try {
        val parData = data.par
        parData.tasksupport = new scala.collection.parallel.ForkJoinTaskSupport(customPool)
        val values = parData.map(extractor)
        val count = values.length
        val sum = values.sum
        val mean = sum / count
        val sortedValues = values.seq.sorted
        val variance = values.map(v => math.pow(v - mean, 2)).sum / count
        val stdDev = math.sqrt(variance)
        val median = if (count % 2 == 0) {
          (sortedValues(count / 2 - 1) + sortedValues(count / 2)) / 2.0
        } else {
          sortedValues(count / 2)
        }
        StatisticsResult(
          count = count,
          sum = sum,
          mean = mean,
          median = median,
          min = values.min,
          max = values.max,
          variance = variance,
          standardDeviation = stdDev
        )
      } finally {
        customPool.shutdown()
      }
    }
  }

  case class StatisticsResult(
    count: Int,
    sum: Double,
    mean: Double,
    median: Double,
    min: Double,
    max: Double,
    variance: Double,
    standardDeviation: Double
  )
}
// Lock-free data structures
object LockFreeDataStructures {
// Lock-free stack using CAS operations
class LockFreeStack[T] {
  // Immutable cell of the CAS-based stack; `below` is null for the bottom cell.
  private final class Cell(val item: T, val below: Cell)

  // Current top cell, or null when the stack is empty. Every mutation is one CAS on it.
  private val top = new AtomicReference[Cell](null)

  /** Pushes `value`, retrying until the CAS on the top pointer succeeds. */
  final def push(value: T): Unit = {
    @tailrec
    def attempt(): Unit = {
      val observed = top.get()
      val pushed = new Cell(value, observed)
      if (!top.compareAndSet(observed, pushed)) attempt()
    }
    attempt()
  }

  /** Removes and returns the top value, or `None` when the stack is empty. */
  final def pop(): Option[T] = {
    @tailrec
    def attempt(): Option[T] =
      Option(top.get()) match {
        case None => None
        case Some(observed) =>
          if (top.compareAndSet(observed, observed.below)) Some(observed.item)
          else attempt()
      }
    attempt()
  }

  /** Returns the top value without removing it, or `None` when empty. */
  def peek: Option[T] = Option(top.get()).map(_.item)

  def isEmpty: Boolean = top.get() eq null

  /** Number of cells reachable from the top at the moment of the call; O(n). */
  def size: Int = Iterator.iterate(top.get())(_.below).takeWhile(_ ne null).size
}
// Lock-free queue using CAS operations
class LockFreeQueue[T] {
  // Linked node. `next` goes from null to its successor at most once, via CAS.
  private final class Node(val value: T, val next: AtomicReference[Node])

  // Michael–Scott layout: `head` always points at a sentinel node whose successor holds the
  // oldest element; `tail` points at the last node or, transiently, at its predecessor.
  // Both start at the same fresh sentinel. The sentinel is deliberately not stored in its
  // own field, because that would keep every node ever enqueued reachable.
  private val head = new AtomicReference[Node](new Node(null.asInstanceOf[T], new AtomicReference[Node](null)))
  private val tail = new AtomicReference[Node](head.get())
  // Updated after each successful link/unlink, so it can briefly lag the structure.
  private val size = new AtomicLong(0)

  /** Appends `value` at the tail. */
  def enqueue(value: T): Unit = {
    val node = new Node(value, new AtomicReference[Node](null))

    @tailrec
    def link(): Unit = {
      val last = tail.get()
      val next = last.next.get()
      if (next != null) {
        // Tail lags behind a concurrent enqueue: help advance it, then retry.
        tail.compareAndSet(last, next)
        link()
      } else if (last.next.compareAndSet(null, node)) {
        // Linked. Swinging the tail may fail if another thread already helped; that is fine.
        tail.compareAndSet(last, node)
        ()
      } else {
        link()
      }
    }

    link()
    size.incrementAndGet()
  }

  /** Removes and returns the oldest element, or `None` if the queue is empty. */
  def dequeue(): Option[T] = {
    @tailrec
    def unlink(): Option[T] = {
      val first = head.get()
      val last = tail.get()
      val next = first.next.get()
      if (first ne head.get()) {
        unlink() // snapshot went stale; start over
      } else if (next == null) {
        None // only the sentinel remains
      } else if (first eq last) {
        // Tail lags behind an enqueue that already linked `next`: help, then retry.
        tail.compareAndSet(last, next)
        unlink()
      } else {
        // Read the value before the CAS; once `head` moves, `next` becomes the new sentinel.
        val value = next.value
        if (head.compareAndSet(first, next)) Some(value) else unlink()
      }
    }

    val result = unlink()
    result.foreach(_ => size.decrementAndGet())
    result
  }

  /** True when no element is linked after the sentinel at the moment of the call. */
  def isEmpty: Boolean = head.get().next.get() == null

  /** Approximate element count; clamped because the counter can briefly dip below zero. */
  def length: Long = math.max(0L, size.get())
}
/** A `Long` counter updated through compare-and-set retry loops.
  *
  * `increment` and `decrement` refuse to move within 1000 of the `Long` limits.
  * `addAndGet` rejects only a result that would actually wrap around.
  *
  * @param initialValue starting value of the counter
  */
class LockFreeCounter(initialValue: Long = 0) {
  private val value = new AtomicLong(initialValue)
  // Safety margins applied by increment/decrement only.
  private val upperLimit = Long.MaxValue - 1000
  private val lowerLimit = Long.MinValue + 1000

  // Reads the current value and derives its successor with `step`, which may throw
  // to reject the update. Publishes the successor via CAS, retrying on contention,
  // and returns the value that was stored.
  @tailrec
  private def update(step: Long => Long): Long = {
    val seen = value.get()
    val next = step(seen)
    if (value.compareAndSet(seen, next)) next else update(step)
  }

  /** Adds one and returns the new value.
    * @throws ArithmeticException if the current value is at or above `Long.MaxValue - 1000`
    */
  final def increment(): Long = update { seen =>
    if (seen >= upperLimit) throw new ArithmeticException("Counter overflow")
    seen + 1
  }

  /** Subtracts one and returns the new value.
    * @throws ArithmeticException if the current value is at or below `Long.MinValue + 1000`
    */
  final def decrement(): Long = update { seen =>
    if (seen <= lowerLimit) throw new ArithmeticException("Counter underflow")
    seen - 1
  }

  /** Adds `delta` and returns the new value.
    * @throws ArithmeticException if the addition would wrap around the `Long` range
    */
  final def addAndGet(delta: Long): Long = update { seen =>
    val next = seen + delta
    // Two's-complement wrap: adding a positive delta went down, or a negative one went up.
    val wrapped = if (delta > 0) next < seen else delta < 0 && next > seen
    if (wrapped) throw new ArithmeticException("Counter overflow/underflow")
    next
  }

  /** Current value. */
  def get(): Long = value.get()

  /** Sets the counter to zero and returns the value it held before. */
  def reset(): Long = value.getAndSet(0L)

  /** Atomically replaces `expected` with `newValue`; returns whether it succeeded. */
  def compareAndSet(expected: Long, newValue: Long): Boolean =
    value.compareAndSet(expected, newValue)
}
/** A Bloom filter whose bits can be set and read concurrently without locks.
  *
  * Sizing uses the standard formulas m = -n ln p / (ln 2)^2 bits and
  * k = (m / n) ln 2 probes. The k bit positions for an element are derived from its
  * single `hash` value by double hashing. Bits are packed 64 per word in an
  * `AtomicLongArray` and set with compare-and-set.
  *
  * @tparam T                element type
  * @param expectedElements  anticipated number of insertions; must be positive
  * @param falsePositiveRate target false-positive probability; must lie in (0, 1)
  * @param hash              hash function for elements
  * @throws IllegalArgumentException if either sizing parameter is out of range
  */
class LockFreeBloomFilter[T](
  expectedElements: Int,
  falsePositiveRate: Double = 0.01
)(hash: T => Int) {
  require(expectedElements > 0, s"expectedElements must be positive, got $expectedElements")
  require(
    falsePositiveRate > 0.0 && falsePositiveRate < 1.0,
    s"falsePositiveRate must be in (0, 1), got $falsePositiveRate"
  )

  private val bitArraySize = optimalBitArraySize(expectedElements, falsePositiveRate)
  private val hashFunctions = optimalHashFunctions(bitArraySize, expectedElements)
  // 64 bits per word. An Array[AtomicBoolean] would cost an object per bit, which
  // defeats the memory savings a Bloom filter exists for. Long arithmetic avoids
  // Int overflow for very large sizes.
  private val words =
    new java.util.concurrent.atomic.AtomicLongArray(((bitArraySize.toLong + 63) / 64).toInt)

  /** Records `element`; afterwards `mightContain(element)` returns true. */
  def add(element: T): Unit = bitIndices(element).foreach(setBit)

  /** Returns false only if `element` was definitely never added. */
  def mightContain(element: T): Boolean = bitIndices(element).forall(testBit)

  /** Estimates the current false-positive probability as (fraction of set bits)^k. */
  def estimatedFalsePositiveRate: Double = {
    val setBits = (0 until words.length()).foldLeft(0L) { (acc, i) =>
      acc + java.lang.Long.bitCount(words.get(i))
    }
    math.pow(setBits.toDouble / bitArraySize, hashFunctions)
  }

  // Kirsch–Mitzenmacher double hashing: index_i = h1 + i * h2 (mod m).
  // h2 is a scrambled copy of the hash, forced odd (hence non-zero). Unlike a fixed
  // stride, elements whose hashes merely differ by a constant do not share their probes.
  private def bitIndices(element: T): Iterator[Int] = {
    val base = hash(element)
    val h1 = base.toLong
    val h2 = (scramble(base) | 1).toLong
    val m = bitArraySize.toLong
    (0 until hashFunctions).iterator.map(i => Math.floorMod(h1 + i * h2, m).toInt)
  }

  // Sets one bit with a CAS loop; skips the write when the bit is already set.
  @tailrec
  private def setBit(index: Int): Unit = {
    val word = index >>> 6
    val mask = 1L << (index & 63)
    val old = words.get(word)
    if ((old & mask) == 0L && !words.compareAndSet(word, old, old | mask)) setBit(index)
  }

  // Reads one bit.
  private def testBit(index: Int): Boolean =
    (words.get(index >>> 6) & (1L << (index & 63))) != 0L

  // MurmurHash3 32-bit finalizer: spreads every input bit across the output.
  private def scramble(h: Int): Int = {
    val a = (h ^ (h >>> 16)) * 0x85ebca6b
    val b = (a ^ (a >>> 13)) * 0xc2b2ae35
    b ^ (b >>> 16)
  }

  // m = -n ln p / (ln 2)^2, never less than one bit.
  private def optimalBitArraySize(n: Int, p: Double): Int =
    math.max(1, math.ceil(-n * math.log(p) / (math.log(2) * math.log(2))).toInt)

  // k = (m / n) ln 2, never less than one probe.
  private def optimalHashFunctions(m: Int, n: Int): Int =
    math.max(1, math.round(m.toDouble / n * math.log(2)).toInt)
}
/** A bounded cache that evicts the least-recently-used entry once it would hold
  * more than `maxSize` entries. Both `get` and `put` count as a use of the key.
  *
  * Despite its name, this class is lock-based. A correct lock-free LRU needs far
  * more than an independent CAS per list pointer, because unlinking and relinking
  * a node touches several pointers that must change together. So every operation
  * holds the monitor of an access-ordered `java.util.LinkedHashMap`, which
  * implements the LRU policy directly.
  *
  * @param maxSize maximum number of entries retained; a value <= 0 retains nothing
  */
class LockFreeLRUCache[K, V](maxSize: Int) {
  // accessOrder = true: iteration runs from least- to most-recently used. Both get
  // and put on an existing key move that entry to the most-recent end.
  private val entries: java.util.LinkedHashMap[K, V] =
    new java.util.LinkedHashMap[K, V](16, 0.75f, true) {
      // Invoked by put after a new key is inserted; dropping the eldest entry keeps
      // the map at no more than maxSize entries.
      override protected def removeEldestEntry(eldest: java.util.Map.Entry[K, V]): Boolean =
        size() > maxSize
    }

  /** Returns the value for `key`, marking it most recently used, or `None` if absent. */
  def get(key: K): Option[V] = entries.synchronized {
    val found = entries.get(key) // also refreshes the entry's recency
    // containsKey distinguishes "absent" from "present with a null value".
    if (found != null || entries.containsKey(key)) Some(found) else None
  }

  /** Inserts or replaces the value for `key`, making it the most recently used entry.
    * Inserting a new key may evict the least recently used entry.
    */
  def put(key: K, value: V): Unit = entries.synchronized {
    entries.put(key, value)
    ()
  }

  /** Removes `key`, returning its previous value if it was present. */
  def remove(key: K): Option[V] = entries.synchronized {
    if (entries.containsKey(key)) Some(entries.remove(key)) else None
  }

  /** Number of entries currently cached. */
  def currentSize: Long = entries.synchronized(entries.size.toLong)
}
}
Conclusion
Advanced concurrency patterns in Scala provide powerful tools for building high-performance, scalable applications. Key concepts include:
Actor System Patterns:
- Hierarchical supervision for fault tolerance
- Message-driven architecture for loose coupling
- Workflow coordination with complex dependencies
- Resource isolation and failure containment
Software Transactional Memory (STM):
- Composable transactions for concurrent operations
- Deadlock-free synchronization mechanisms
- Memory consistency without explicit locking
- Atomic operations on complex data structures
Parallel Collections:
- Data parallelism with automatic work distribution
- Custom thread pools for fine-grained control
- Pipeline processing for streaming data
- Aggregation operations with efficient combining
Lock-Free Programming:
- Compare-and-swap (CAS) based algorithms
- Non-blocking data structures for high throughput
- Memory-efficient concurrent collections
- Freedom from deadlocks and blocking (contention still exists, surfacing as CAS retries)
Best Practices:
- Choose appropriate concurrency model for the problem
- Measure performance impact of different approaches
- Consider memory allocation patterns in concurrent code
- Plan for failure scenarios and recovery mechanisms
- Monitor contention and blocking in production systems
Performance Considerations:
- Actor message passing overhead vs. shared memory
- STM retry costs under high contention
- Parallel collection overhead for small datasets
- Lock-free algorithm complexity vs. throughput gains
Design Principles:
- Immutability reduces synchronization needs
- Message passing eliminates shared mutable state
- Backpressure mechanisms prevent resource exhaustion
- Circuit breakers provide graceful degradation
Testing Strategies:
- Stress testing under concurrent load
- Property-based testing for concurrent invariants
- Chaos engineering for resilience validation
- Performance benchmarking across concurrency models
Modern concurrent applications benefit from combining these patterns strategically, using actors for coordination, STM for complex shared state, parallel collections for data processing, and lock-free structures for high-performance scenarios. The key is understanding the trade-offs and choosing the right tool for each specific concurrency challenge.
Comments
Be the first to comment on this lesson!