Security in Scala Applications: Authentication, Authorization, and Best Practices
Security is a core requirement in enterprise Scala applications. This lesson covers authentication systems, authorization patterns, JWT tokens, OAuth 2.0 integration, cryptography, input validation, and security best practices for building applications that resist common vulnerabilities and threats.
Authentication Systems
JWT (JSON Web Token) Authentication
// JWTAuthentication.scala - JWT-based authentication system
package com.example.security.jwt
import io.jsonwebtoken._
import io.jsonwebtoken.security.Keys
import java.security.Key
import java.time.{Instant, Duration}
import java.util.{Date, UUID}
import scala.util.{Try, Success, Failure}
import cats.effect.IO
import cats.implicits._
case class User(
id: UUID,
username: String,
email: String,
roles: Set[String],
isActive: Boolean = true,
createdAt: Instant = Instant.now()
) {
def hasRole(role: String): Boolean = roles.contains(role)
def hasAnyRole(requiredRoles: Set[String]): Boolean = requiredRoles.intersect(roles).nonEmpty
}
case class JWTClaims(
userId: UUID,
username: String,
email: String,
roles: Set[String],
issuedAt: Instant,
expiresAt: Instant,
jti: String = UUID.randomUUID().toString
)
case class RefreshToken(
token: String,
userId: UUID,
expiresAt: Instant,
isRevoked: Boolean = false
)
sealed trait AuthenticationError extends Exception
case object InvalidCredentials extends AuthenticationError
case object TokenExpired extends AuthenticationError
case object TokenInvalid extends AuthenticationError
case object UserNotFound extends AuthenticationError
case object UserNotActive extends AuthenticationError
case object InsufficientPermissions extends AuthenticationError
trait TokenStorage {
def storeRefreshToken(token: RefreshToken): IO[Unit]
def getRefreshToken(token: String): IO[Option[RefreshToken]]
def revokeRefreshToken(token: String): IO[Unit]
def revokeAllUserTokens(userId: UUID): IO[Unit]
def addToBlacklist(jti: String): IO[Unit]
def isBlacklisted(jti: String): IO[Boolean]
}
class InMemoryTokenStorage extends TokenStorage {
private val refreshTokens = scala.collection.concurrent.TrieMap[String, RefreshToken]()
private val blacklistedTokens = scala.collection.concurrent.TrieMap[String, Instant]()
def storeRefreshToken(token: RefreshToken): IO[Unit] = IO {
refreshTokens.put(token.token, token)
()
}
def getRefreshToken(token: String): IO[Option[RefreshToken]] = IO {
refreshTokens.get(token).filter(!_.isRevoked)
}
def revokeRefreshToken(token: String): IO[Unit] = IO {
refreshTokens.updateWith(token)(_.map(_.copy(isRevoked = true)))
()
}
def revokeAllUserTokens(userId: UUID): IO[Unit] = IO {
// mapValuesInPlace replaces transform, which is deprecated since Scala 2.13
refreshTokens.mapValuesInPlace { (_, token) =>
if (token.userId == userId) token.copy(isRevoked = true)
else token
}
()
}
def addToBlacklist(jti: String): IO[Unit] = IO {
blacklistedTokens.put(jti, Instant.now())
()
}
def isBlacklisted(jti: String): IO[Boolean] = IO {
blacklistedTokens.contains(jti)
}
}
class JWTService(
secretKey: String,
issuer: String = "scalatut-app",
accessTokenExpiry: Duration = Duration.ofHours(1),
refreshTokenExpiry: Duration = Duration.ofDays(7),
tokenStorage: TokenStorage
) {
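// HS256 needs a key of at least 256 bits (32 bytes); hmacShaKeyFor throws WeakKeyException for shorter input.
// Encode with an explicit charset so the key does not depend on the platform default.
private val key: Key = Keys.hmacShaKeyFor(secretKey.getBytes(java.nio.charset.StandardCharsets.UTF_8))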
def generateTokens(user: User): IO[(String, String)] = {
for {
now <- IO(Instant.now()) // read the clock inside IO so every run sees the current time
accessExpiresAt = now.plus(accessTokenExpiry)
refreshExpiresAt = now.plus(refreshTokenExpiry)
claims = JWTClaims(
userId = user.id,
username = user.username,
email = user.email,
roles = user.roles,
issuedAt = now,
expiresAt = accessExpiresAt,
jti = UUID.randomUUID().toString
)
accessToken = createAccessToken(claims)
refreshToken = createRefreshToken(user.id, refreshExpiresAt)
_ <- tokenStorage.storeRefreshToken(RefreshToken(refreshToken, user.id, refreshExpiresAt))
} yield (accessToken, refreshToken)
}
private def createAccessToken(claims: JWTClaims): String = {
Jwts.builder()
.setSubject(claims.userId.toString)
.setIssuer(issuer)
.setIssuedAt(Date.from(claims.issuedAt))
.setExpiration(Date.from(claims.expiresAt))
.setId(claims.jti)
.claim("username", claims.username)
.claim("email", claims.email)
.claim("roles", claims.roles.mkString(","))
.signWith(key, SignatureAlgorithm.HS256)
.compact()
}
private def createRefreshToken(userId: UUID, expiresAt: Instant): String = {
Jwts.builder()
.setSubject(userId.toString)
.setIssuer(issuer)
.setIssuedAt(Date.from(Instant.now()))
.setExpiration(Date.from(expiresAt))
.setId(UUID.randomUUID().toString)
.claim("type", "refresh")
.signWith(key, SignatureAlgorithm.HS256)
.compact()
}
def validateAccessToken(token: String): IO[Either[AuthenticationError, JWTClaims]] = {
Try {
val claims = Jwts.parserBuilder()
.setSigningKey(key)
.requireIssuer(issuer) // reject tokens minted by a different issuer
.build()
.parseClaimsJws(token)
.getBody
val userId = UUID.fromString(claims.getSubject)
val username = claims.get("username", classOf[String])
val email = claims.get("email", classOf[String])
val roles = claims.get("roles", classOf[String]).split(",").toSet.filter(_.nonEmpty)
val issuedAt = claims.getIssuedAt.toInstant
val expiresAt = claims.getExpiration.toInstant
val jti = claims.getId
JWTClaims(userId, username, email, roles, issuedAt, expiresAt, jti)
} match {
case Success(claims) =>
tokenStorage.isBlacklisted(claims.jti).map { isBlacklisted =>
if (isBlacklisted) Left(TokenInvalid)
else if (claims.expiresAt.isBefore(Instant.now())) Left(TokenExpired)
else Right(claims)
}
case Failure(_: ExpiredJwtException) =>
IO.pure(Left(TokenExpired))
case Failure(_) =>
IO.pure(Left(TokenInvalid))
}
}
def refreshAccessToken(refreshToken: String): IO[Either[AuthenticationError, (String, String)]] = {
for {
tokenOpt <- tokenStorage.getRefreshToken(refreshToken)
result <- tokenOpt match {
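case Some(token) if !token.isRevoked && token.expiresAt.isAfter(Instant.now()) =>
// Placeholder principal: a real implementation must load username, email, and roles
// from the user service here, otherwise the new access token carries no roles
val user = User(
id = token.userId,
username = "",
email = "",
roles = Set.empty
)
// Rotate: revoke the presented refresh token before issuing a new pair
tokenStorage.revokeRefreshToken(refreshToken) *> generateTokens(user).map(Right(_))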
case Some(_) =>
IO.pure(Left(TokenExpired))
case None =>
IO.pure(Left(TokenInvalid))
}
} yield result
}
def revokeToken(jti: String): IO[Unit] = {
tokenStorage.addToBlacklist(jti)
}
def revokeAllUserTokens(userId: UUID): IO[Unit] = {
tokenStorage.revokeAllUserTokens(userId)
}
}
// Password hashing and verification
object PasswordSecurity {
import org.mindrot.jbcrypt.BCrypt
def hashPassword(password: String): String = {
BCrypt.hashpw(password, BCrypt.gensalt(12))
}
def verifyPassword(password: String, hashedPassword: String): Boolean = {
Try(BCrypt.checkpw(password, hashedPassword)).getOrElse(false)
}
def isStrongPassword(password: String): Boolean = {
val minLength = 8
val hasUppercase = password.exists(_.isUpper)
val hasLowercase = password.exists(_.isLower)
val hasDigit = password.exists(_.isDigit)
val hasSpecialChar = password.exists(c => !c.isLetterOrDigit)
password.length >= minLength &&
hasUppercase &&
hasLowercase &&
hasDigit &&
hasSpecialChar
}
def generateSecurePassword(length: Int = 16): String = {
require(length >= 4, "length must be at least 4 to include every character category")
val uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
val lowercase = "abcdefghijklmnopqrstuvwxyz"
val digits = "0123456789"
val special = "!@#$%^&*()_+-=[]{}|;:,.<>?"
val allChars = uppercase + lowercase + digits + special
// Back scala.util.Random with SecureRandom; the default generator is predictable
val random = new scala.util.Random(new java.security.SecureRandom())
// Ensure at least one character from each category
val password = new StringBuilder()
password.append(uppercase(random.nextInt(uppercase.length)))
password.append(lowercase(random.nextInt(lowercase.length)))
password.append(digits(random.nextInt(digits.length)))
password.append(special(random.nextInt(special.length)))
// Fill the rest randomly
for (_ <- 4 until length) {
password.append(allChars(random.nextInt(allChars.length)))
}
// Shuffle so the guaranteed characters are not always in the first four positions
random.shuffle(password.toString.toList).mkString
}
}
// Two-Factor Authentication (2FA)
class TwoFactorAuthService {
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import java.nio.ByteBuffer
import java.security.SecureRandom
import java.util.Base64
private val TOTP_WINDOW = 30 // 30 seconds
private val TOTP_DIGITS = 6
def generateSecret(): String = {
val random = new SecureRandom()
val secret = new Array[Byte](20) // 160 bits
random.nextBytes(secret)
Base64.getEncoder.encodeToString(secret)
}
def generateTOTP(secret: String, timeStep: Long = System.currentTimeMillis() / 1000 / TOTP_WINDOW): String = {
val secretBytes = Base64.getDecoder.decode(secret)
val timeBytes = ByteBuffer.allocate(8).putLong(timeStep).array()
val mac = Mac.getInstance("HmacSHA1")
mac.init(new SecretKeySpec(secretBytes, "HmacSHA1"))
val hash = mac.doFinal(timeBytes)
val offset = hash(hash.length - 1) & 0x0F
val truncatedHash = ByteBuffer.wrap(hash).getInt(offset) & 0x7FFFFFFF
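val otp = truncatedHash % math.pow(10, TOTP_DIGITS).toInt
// The f interpolator cannot take the width from a variable, so build the format string first
s"%0${TOTP_DIGITS}d".format(otp)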
}
def verifyTOTP(secret: String, userToken: String, allowedWindow: Int = 1): Boolean = {
val currentTimeStep = System.currentTimeMillis() / 1000 / TOTP_WINDOW
val presented = userToken.getBytes(java.nio.charset.StandardCharsets.UTF_8)
// Check the current time step and adjacent ones for clock drift tolerance.
// MessageDigest.isEqual compares in constant time, unlike String ==
(-allowedWindow to allowedWindow).exists { i =>
val expected = generateTOTP(secret, currentTimeStep + i).getBytes(java.nio.charset.StandardCharsets.UTF_8)
java.security.MessageDigest.isEqual(expected, presented)
}
}
def generateBackupCodes(count: Int = 10): List[String] = {
val random = new SecureRandom()
(1 to count).map { _ =>
val code = new Array[Byte](8)
random.nextBytes(code)
code.map("%02x".format(_)).mkString.toUpperCase
}.toList
}
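def generateQRCodeURI(secret: String, accountName: String, issuer: String): String = {
// Percent-encode label and issuer; URLEncoder emits '+' for spaces, so convert those to %20
def enc(s: String): String = java.net.URLEncoder.encode(s, "UTF-8").replace("+", "%20")
// Note: the otpauth Key URI format expects an unpadded Base32 secret, so the Base64 value
// returned by generateSecret must be re-encoded as Base32 before authenticator apps can use it
val encodedSecret = secret.replace("=", "")
s"otpauth://totp/${enc(issuer)}:${enc(accountName)}?secret=$encodedSecret&issuer=${enc(issuer)}"
}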
}
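The TOTP service above stores and exposes its secret as Base64, while authenticator apps that scan an otpauth:// URI generally expect an unpadded Base32 secret. As a minimal sketch of how the pieces fit together, the block below pairs a hand-written RFC 4648 Base32 encoder with the same HMAC-SHA1 truncation used in generateTOTP, plus a constant-time comparison. The object name TotpSketch and its method names are hypothetical and not part of the lesson's code.

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.security.MessageDigest
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec

// Hypothetical helper for illustration; not part of TwoFactorAuthService
object TotpSketch {
  private val Alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"

  // RFC 4648 Base32 without '=' padding, the form used in otpauth:// URIs
  def base32Encode(bytes: Array[Byte]): String = {
    val out = new StringBuilder
    var buffer = 0 // only the low `bits` bits are meaningful
    var bits = 0
    bytes.foreach { b =>
      buffer = (buffer << 8) | (b & 0xFF)
      bits += 8
      while (bits >= 5) {
        out.append(Alphabet.charAt((buffer >> (bits - 5)) & 0x1F))
        bits -= 5
      }
    }
    // Left-align any leftover bits into a final 5-bit group
    if (bits > 0) out.append(Alphabet.charAt((buffer << (5 - bits)) & 0x1F))
    out.toString
  }

  // Dynamic truncation from RFC 4226 over HMAC-SHA1; timeStep is the moving counter
  def totp(secret: Array[Byte], timeStep: Long, digits: Int = 6): String = {
    val mac = Mac.getInstance("HmacSHA1")
    mac.init(new SecretKeySpec(secret, "HmacSHA1"))
    val hash = mac.doFinal(ByteBuffer.allocate(8).putLong(timeStep).array())
    val offset = hash(hash.length - 1) & 0x0F
    val binary = ByteBuffer.wrap(hash).getInt(offset) & 0x7FFFFFFF
    val otp = binary % math.pow(10, digits).toInt
    s"%0${digits}d".format(otp)
  }

  // Constant-time comparison of the expected and presented codes
  def matches(expected: String, presented: String): Boolean =
    MessageDigest.isEqual(
      expected.getBytes(StandardCharsets.UTF_8),
      presented.getBytes(StandardCharsets.UTF_8)
    )
}
```

With the RFC 4226 test secret 12345678901234567890 (ASCII bytes), TotpSketch.totp(secret, 1L) should give 287082, and base32Encode should turn the same bytes into GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ, which is the form an authenticator app would accept.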
OAuth 2.0 Integration
// OAuth2Integration.scala - OAuth 2.0 provider integration
package com.example.security.oauth
import cats.effect.IO
import org.http4s._
import org.http4s.client.Client
import org.http4s.circe.CirceEntityCodec._
import io.circe.generic.auto._
import io.circe.syntax._
import java.net.URI
import java.security.SecureRandom
import java.util.Base64
import scala.concurrent.duration._
case class OAuth2Config(
clientId: String,
clientSecret: String,
authorizationUrl: String,
tokenUrl: String,
userInfoUrl: String,
redirectUri: String,
scopes: List[String] = List("openid", "profile", "email")
)
case class OAuth2State(
state: String,
codeVerifier: String,
codeChallenge: String,
createdAt: Long = System.currentTimeMillis()
) {
def isExpired(maxAge: Duration = 10.minutes): Boolean = {
System.currentTimeMillis() - createdAt > maxAge.toMillis
}
}
case class OAuth2TokenResponse(
access_token: String,
token_type: String,
expires_in: Option[Int],
refresh_token: Option[String],
scope: Option[String],
id_token: Option[String]
)
case class OAuth2UserInfo(
sub: String,
name: Option[String],
email: Option[String],
picture: Option[String],
email_verified: Option[Boolean]
)
case class OAuth2AuthorizeRequest(
response_type: String = "code",
client_id: String,
redirect_uri: String,
scope: String,
state: String,
code_challenge: String,
code_challenge_method: String = "S256"
)
trait OAuth2StateStorage {
def store(state: OAuth2State): IO[Unit]
def get(stateParam: String): IO[Option[OAuth2State]]
def remove(stateParam: String): IO[Unit]
}
class InMemoryOAuth2StateStorage extends OAuth2StateStorage {
private val states = scala.collection.concurrent.TrieMap[String, OAuth2State]()
def store(state: OAuth2State): IO[Unit] = IO {
states.put(state.state, state)
()
}
def get(stateParam: String): IO[Option[OAuth2State]] = IO {
states.get(stateParam).filter(!_.isExpired())
}
def remove(stateParam: String): IO[Unit] = IO {
states.remove(stateParam)
()
}
}
class OAuth2Service(
config: OAuth2Config,
stateStorage: OAuth2StateStorage,
httpClient: Client[IO]
) {
private val random = new SecureRandom()
def generateAuthorizationUrl(): IO[(String, String)] = {
val state = generateRandomString(32)
val codeVerifier = generateCodeVerifier()
val codeChallenge = generateCodeChallenge(codeVerifier)
val oauth2State = OAuth2State(state, codeVerifier, codeChallenge)
val authRequest = OAuth2AuthorizeRequest(
client_id = config.clientId,
redirect_uri = config.redirectUri,
scope = config.scopes.mkString(" "),
state = state,
code_challenge = codeChallenge
)
val params = Map(
"response_type" -> authRequest.response_type,
"client_id" -> authRequest.client_id,
"redirect_uri" -> authRequest.redirect_uri,
"scope" -> authRequest.scope,
"state" -> authRequest.state,
"code_challenge" -> authRequest.code_challenge,
"code_challenge_method" -> authRequest.code_challenge_method
)
val queryString = params.map { case (k, v) => s"$k=${java.net.URLEncoder.encode(v, "UTF-8")}" }.mkString("&")
val authUrl = s"${config.authorizationUrl}?$queryString"
stateStorage.store(oauth2State).map(_ => (authUrl, state))
}
def exchangeCodeForToken(code: String, state: String): IO[Either[String, OAuth2TokenResponse]] = {
for {
stateOpt <- stateStorage.get(state)
result <- stateOpt match {
case Some(oauth2State) if !oauth2State.isExpired() =>
val tokenRequest = UrlForm(
"grant_type" -> "authorization_code",
"client_id" -> config.clientId,
"client_secret" -> config.clientSecret,
"code" -> code,
"redirect_uri" -> config.redirectUri,
"code_verifier" -> oauth2State.codeVerifier
)
val request = Request[IO](
method = Method.POST,
uri = Uri.unsafeFromString(config.tokenUrl),
headers = Headers("Content-Type" -> "application/x-www-form-urlencoded")
).withEntity(tokenRequest)
httpClient.expect[OAuth2TokenResponse](request)
.map(Right(_))
.handleErrorWith(error => IO.pure(Left(s"Token exchange failed: ${error.getMessage}")))
.flatTap(_ => stateStorage.remove(state))
case Some(_) =>
IO.pure(Left("OAuth2 state expired"))
case None =>
IO.pure(Left("Invalid OAuth2 state"))
}
} yield result
}
def getUserInfo(accessToken: String): IO[Either[String, OAuth2UserInfo]] = {
val request = Request[IO](
method = Method.GET,
uri = Uri.unsafeFromString(config.userInfoUrl),
headers = Headers("Authorization" -> s"Bearer $accessToken")
)
httpClient.expect[OAuth2UserInfo](request)
.map(Right(_))
.handleErrorWith(error => IO.pure(Left(s"Failed to fetch user info: ${error.getMessage}")))
}
def refreshToken(refreshToken: String): IO[Either[String, OAuth2TokenResponse]] = {
val tokenRequest = UrlForm(
"grant_type" -> "refresh_token",
"client_id" -> config.clientId,
"client_secret" -> config.clientSecret,
"refresh_token" -> refreshToken
)
val request = Request[IO](
method = Method.POST,
uri = Uri.unsafeFromString(config.tokenUrl),
headers = Headers("Content-Type" -> "application/x-www-form-urlencoded")
).withEntity(tokenRequest)
httpClient.expect[OAuth2TokenResponse](request)
.map(Right(_))
.handleErrorWith(error => IO.pure(Left(s"Token refresh failed: ${error.getMessage}")))
}
private def generateRandomString(length: Int): String = {
val bytes = new Array[Byte](length)
random.nextBytes(bytes)
Base64.getUrlEncoder.withoutPadding().encodeToString(bytes)
}
private def generateCodeVerifier(): String = {
generateRandomString(32)
}
private def generateCodeChallenge(verifier: String): String = {
import java.security.MessageDigest
val digest = MessageDigest.getInstance("SHA-256")
val hash = digest.digest(verifier.getBytes("UTF-8"))
Base64.getUrlEncoder.withoutPadding().encodeToString(hash)
}
}
// OAuth2 middleware for http4s
import cats.data.{Kleisli, OptionT}
import org.http4s.headers.Authorization
import org.http4s.server.AuthMiddleware
import com.example.security.jwt.User
class OAuth2Middleware(oauth2Service: OAuth2Service) {
// An AuthMiddleware turns AuthedRoutes into HttpRoutes; requests without a valid token fall through (OptionT.none)
def authenticate: AuthMiddleware[IO, User] = { routes =>
Kleisli { (request: Request[IO]) =>
extractToken(request) match {
case Some(token) =>
OptionT(oauth2Service.getUserInfo(token).map(_.toOption)).flatMap { userInfo =>
val user = User(
// The OIDC "sub" claim is an opaque string, not necessarily a UUID, so derive a stable UUID from it
id = java.util.UUID.nameUUIDFromBytes(userInfo.sub.getBytes("UTF-8")),
username = userInfo.name.getOrElse(""),
email = userInfo.email.getOrElse(""),
roles = Set("user") // Default role
)
routes(AuthedRequest(user, request))
}
case None =>
OptionT.none[IO, Response[IO]]
}
}
}
// Accept only "Authorization: Bearer <token>" credentials
private def extractToken(request: Request[IO]): Option[String] = {
request.headers.get[Authorization].collect {
case Authorization(Credentials.Token(AuthScheme.Bearer, token)) => token
}
}
}
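The PKCE steps in OAuth2Service (generate a verifier, send its S256 challenge, present the verifier at the token endpoint) can be checked in isolation. The following is a minimal sketch that uses only the JDK; PkceSketch and its method names are hypothetical, and the verify method only imitates what an authorization server is expected to do when it receives the verifier.

```scala
import java.nio.charset.StandardCharsets
import java.security.{MessageDigest, SecureRandom}
import java.util.Base64

// Hypothetical helper for illustration; mirrors generateCodeVerifier / generateCodeChallenge
object PkceSketch {
  private val random = new SecureRandom()
  private val encoder = Base64.getUrlEncoder.withoutPadding()

  // 32 random bytes encode to a 43-character verifier, the minimum length RFC 7636 allows
  def newVerifier(): String = {
    val bytes = new Array[Byte](32)
    random.nextBytes(bytes)
    encoder.encodeToString(bytes)
  }

  // S256 challenge: BASE64URL(SHA-256(ASCII(verifier))) without padding
  def challenge(verifier: String): String = {
    val hash = MessageDigest.getInstance("SHA-256")
      .digest(verifier.getBytes(StandardCharsets.US_ASCII))
    encoder.encodeToString(hash)
  }

  // Server-side check at the token endpoint: recompute the challenge and compare
  def verify(verifier: String, storedChallenge: String): Boolean =
    MessageDigest.isEqual(
      challenge(verifier).getBytes(StandardCharsets.US_ASCII),
      storedChallenge.getBytes(StandardCharsets.US_ASCII)
    )
}
```

The example verifier from RFC 7636 Appendix B, dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk, should map to the challenge E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM, which gives a quick way to confirm that an implementation hashes and encodes in the expected order.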
Authorization and Access Control
Role-Based Access Control (RBAC)
// AuthorizationSystem.scala - Comprehensive authorization framework
package com.example.security.authorization
import cats.effect.IO
import cats.implicits._
import java.util.UUID
import scala.annotation.tailrec
// Domain models for authorization
case class Permission(
id: String,
name: String,
description: String,
resource: String,
action: String
) {
def matches(resource: String, action: String): Boolean = {
this.resource == resource && this.action == action
}
}
case class Role(
id: String,
name: String,
description: String,
permissions: Set[Permission],
inheritsFrom: Set[String] = Set.empty
) {
def hasPermission(permission: Permission): Boolean = permissions.contains(permission)
def hasPermission(resource: String, action: String): Boolean =
permissions.exists(_.matches(resource, action))
}
case class UserPrincipal(
userId: UUID,
username: String,
roles: Set[Role],
directPermissions: Set[Permission] = Set.empty,
attributes: Map[String, String] = Map.empty
) {
def hasRole(roleName: String): Boolean = roles.exists(_.name == roleName)
def hasPermission(permission: Permission): Boolean = {
directPermissions.contains(permission) ||
roles.exists(_.hasPermission(permission))
}
def hasPermission(resource: String, action: String): Boolean = {
directPermissions.exists(_.matches(resource, action)) ||
roles.exists(_.hasPermission(resource, action))
}
def getAllPermissions: Set[Permission] = {
directPermissions ++ roles.flatMap(_.permissions)
}
}
// Attribute-Based Access Control (ABAC)
case class AccessRequest(
principal: UserPrincipal,
resource: String,
action: String,
context: Map[String, Any] = Map.empty
)
sealed trait PolicyEffect
case object Allow extends PolicyEffect
case object Deny extends PolicyEffect
case class Policy(
id: String,
name: String,
effect: PolicyEffect,
condition: AccessRequest => Boolean,
priority: Int = 0
)
trait PolicyRepository {
def findApplicablePolicies(request: AccessRequest): IO[List[Policy]]
}
class InMemoryPolicyRepository extends PolicyRepository {
private val policies = scala.collection.mutable.ListBuffer[Policy]()
def addPolicy(policy: Policy): Unit = {
policies += policy
}
def findApplicablePolicies(request: AccessRequest): IO[List[Policy]] = IO {
policies.toList.sortBy(-_.priority)
}
}
// Authorization service
class AuthorizationService(
policyRepository: PolicyRepository,
defaultDeny: Boolean = true
) {
def authorize(request: AccessRequest): IO[Boolean] = {
for {
policies <- policyRepository.findApplicablePolicies(request)
result <- evaluatePolicies(request, policies)
} yield result
}
private def evaluatePolicies(request: AccessRequest, policies: List[Policy]): IO[Boolean] = IO {
val denyPolicies = policies.filter(_.effect == Deny)
val allowPolicies = policies.filter(_.effect == Allow)
// Deny overrides: if any deny policy matches, access is denied
if (denyPolicies.exists(_.condition(request))) {
false
} else if (defaultDeny) {
// Default deny: at least one allow policy must match
allowPolicies.exists(_.condition(request))
} else {
// Default allow: nothing denied the request, so it passes
true
}
}
def authorizeWithDetails(request: AccessRequest): IO[AuthorizationResult] = {
for {
policies <- policyRepository.findApplicablePolicies(request)
result <- evaluatePoliciesWithDetails(request, policies)
} yield result
}
private def evaluatePoliciesWithDetails(request: AccessRequest, policies: List[Policy]): IO[AuthorizationResult] = IO {
val matchedPolicies = policies.filter(_.condition(request))
val denyPolicies = matchedPolicies.filter(_.effect == Deny)
val allowPolicies = matchedPolicies.filter(_.effect == Allow)
if (denyPolicies.nonEmpty) {
AuthorizationResult(
allowed = false,
reason = s"Access denied by policy: ${denyPolicies.head.name}",
matchedPolicies = denyPolicies.map(_.id)
)
} else if (allowPolicies.nonEmpty) {
AuthorizationResult(
allowed = true,
reason = s"Access allowed by policy: ${allowPolicies.head.name}",
matchedPolicies = allowPolicies.map(_.id)
)
} else {
AuthorizationResult(
allowed = !defaultDeny,
reason = if (defaultDeny) "No matching allow policy found" else "No matching deny policy found",
matchedPolicies = List.empty
)
}
}
}
case class AuthorizationResult(
allowed: Boolean,
reason: String,
matchedPolicies: List[String]
)
// Common authorization policies
object StandardPolicies {
def ownershipPolicy(resourceOwnerField: String = "owner"): Policy = Policy(
id = "ownership-policy",
name = "Resource Ownership Policy",
effect = Allow,
condition = { request =>
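// Compare as strings so an owner stored as a UUID (or any other type) still matches
request.context.get(resourceOwnerField).map(_.toString).contains(request.principal.userId.toString)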
},
priority = 100
)
def adminPolicy(): Policy = Policy(
id = "admin-policy",
name = "Administrator Policy",
effect = Allow,
condition = { request =>
request.principal.hasRole("admin")
},
priority = 1000
)
def timeBasedPolicy(allowedHours: Range): Policy = Policy(
id = "time-based-policy",
name = "Time-Based Access Policy",
effect = Allow,
condition = { request =>
val currentHour = java.time.LocalTime.now().getHour
allowedHours.contains(currentHour)
},
priority = 50
)
def ipWhitelistPolicy(allowedIPs: Set[String]): Policy = Policy(
id = "ip-whitelist-policy",
name = "IP Whitelist Policy",
effect = Allow,
condition = { request =>
request.context.get("clientIP") match {
case Some(ip: String) => allowedIPs.contains(ip)
case _ => false
}
},
priority = 75
)
def rateLimitPolicy(maxRequests: Int, windowMinutes: Int): Policy = {
val requestCounts = scala.collection.concurrent.TrieMap[String, (Long, Int)]()
Policy(
id = "rate-limit-policy",
name = "Rate Limiting Policy",
effect = Deny,
condition = { request =>
val userId = request.principal.userId.toString
val now = System.currentTimeMillis()
// Use Long arithmetic so large windows cannot overflow Int
val windowMillis = windowMinutes * 60L * 1000L
// updateWith does the read-modify-write per key in one step, avoiding lost updates under concurrency
val updated = requestCounts.updateWith(userId) {
case Some((windowStart, count)) if now - windowStart <= windowMillis =>
Some((windowStart, count + 1))
case _ =>
Some((now, 1)) // first request, or the previous window has elapsed
}
updated.exists { case (_, count) => count > maxRequests } // Deny if over limit
},
priority = 200
)
}
def departmentPolicy(requiredDepartment: String): Policy = Policy(
id = s"department-$requiredDepartment-policy",
name = s"Department $requiredDepartment Policy",
effect = Allow,
condition = { request =>
request.principal.attributes.get("department").contains(requiredDepartment)
},
priority = 60
)
}
// Method-level authorization annotations (markers only; enforcing them needs a macro or interceptor that is not shown here)
import scala.annotation.StaticAnnotation
case class RequiresPermission(resource: String, action: String) extends StaticAnnotation
case class RequiresRole(role: String) extends StaticAnnotation
case class RequiresAnyRole(roles: String*) extends StaticAnnotation
case class AllowOwner(ownerField: String = "owner") extends StaticAnnotation
// Authorization middleware for http4s
import cats.data.{Kleisli, OptionT}
import org.http4s._
import org.typelevel.ci.CIString
class AuthorizationMiddleware(authService: AuthorizationService) {
// Wraps already-authenticated routes: runs them only when the policy check passes, otherwise answers 403
def authorize(resource: String, action: String)(routes: AuthedRoutes[UserPrincipal, IO]): AuthedRoutes[UserPrincipal, IO] =
Kleisli { (authed: AuthedRequest[IO, UserPrincipal]) =>
val accessRequest = AccessRequest(
principal = authed.context,
resource = resource,
action = action,
context = extractContext(authed.req)
)
guard(authService.authorize(accessRequest), routes, authed)
}
def conditionalAuthorize(condition: (UserPrincipal, Request[IO]) => IO[Boolean])(routes: AuthedRoutes[UserPrincipal, IO]): AuthedRoutes[UserPrincipal, IO] =
Kleisli { (authed: AuthedRequest[IO, UserPrincipal]) =>
guard(condition(authed.context, authed.req), routes, authed)
}
private def guard(
check: IO[Boolean],
routes: AuthedRoutes[UserPrincipal, IO],
authed: AuthedRequest[IO, UserPrincipal]
): OptionT[IO, Response[IO]] =
OptionT.liftF(check).flatMap { allowed =>
if (allowed) routes(authed)
else OptionT.pure[IO](Response[IO](Status.Forbidden))
}
private def extractContext(request: Request[IO]): Map[String, Any] = {
def header(name: String): Option[String] =
request.headers.get(CIString(name)).map(_.head.value)
// X-Forwarded-For is client-controlled unless a trusted proxy sets it; take the first hop
val clientIP = header("X-Forwarded-For").map(_.split(",").head.trim)
.orElse(header("X-Real-IP"))
.getOrElse("unknown")
// Add userAgent only when the header is present, so the context never holds an Option
Map[String, Any](
"method" -> request.method.name,
"uri" -> request.uri.toString(),
"clientIP" -> clientIP
) ++ header("User-Agent").map("userAgent" -> _)
}
}
// Resource-based authorization helper
class ResourceAuthorization(authService: AuthorizationService) {
def checkAccess[T](
principal: UserPrincipal,
resource: String,
action: String,
resourceData: T,
context: Map[String, Any] = Map.empty
): IO[Boolean] = {
val enrichedContext = context ++ extractResourceContext(resourceData)
val request = AccessRequest(principal, resource, action, enrichedContext)
authService.authorize(request)
}
def filterAccessible[T](
principal: UserPrincipal,
resources: List[T],
action: String,
resourceType: String
)(implicit extractId: T => String): IO[List[T]] = {
resources.filterA { resource =>
val context = Map("resourceId" -> extractId(resource))
checkAccess(principal, resourceType, action, resource, context)
}
}
private def extractResourceContext[T](resource: T): Map[String, Any] = {
// Use reflection to extract common fields
try {
val fields = resource.getClass.getDeclaredFields
fields.flatMap { field =>
field.setAccessible(true)
val name = field.getName
val value = field.get(resource)
if (value != null) Some(name -> value) else None
}.toMap
} catch {
case _: Exception => Map.empty
}
}
}
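The evaluation order used by AuthorizationService (a matching deny always wins, then any matching allow, then the default) is easier to see on a reduced model. The sketch below is a simplified stand-in, assuming string-only context values; PolicySketch, Req, Rule, and the three sample rules are hypothetical names and do not replace the lesson's AccessRequest and Policy types.

```scala
// Hypothetical, simplified model for illustration
object PolicySketch {
  sealed trait Effect
  case object Allow extends Effect
  case object Deny extends Effect

  final case class Req(userId: String, roles: Set[String], context: Map[String, String])
  final case class Rule(name: String, effect: Effect, condition: Req => Boolean)

  // Deny overrides allow; with no matching rule, fall back to the default
  def decide(rules: List[Rule], req: Req, defaultDeny: Boolean = true): Boolean = {
    val matched = rules.filter(_.condition(req))
    if (matched.exists(_.effect == Deny)) false
    else if (matched.exists(_.effect == Allow)) true
    else !defaultDeny
  }

  // Sample rules modelled on adminPolicy, ownershipPolicy, and a deny rule
  val adminRule: Rule = Rule("admin", Allow, _.roles.contains("admin"))
  val ownerRule: Rule = Rule("owner", Allow, r => r.context.get("owner").contains(r.userId))
  val suspendedRule: Rule = Rule("suspended", Deny, _.context.get("suspended").contains("true"))

  val rules: List[Rule] = List(adminRule, ownerRule, suspendedRule)
}
```

Under these rules an admin is allowed, a user is allowed on a resource they own, a stranger is denied by default, and a suspended admin is denied because the deny rule takes precedence over the admin allow rule.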
Input Validation and Security
Comprehensive Input Validation Framework
// InputValidation.scala - Security-focused input validation
package com.example.security.validation
import cats.data.{NonEmptyList, Validated, ValidatedNel}
import cats.implicits._
import scala.util.matching.Regex
import java.net.{URL, MalformedURLException}
import java.time.{LocalDate, LocalDateTime}
import java.time.format.DateTimeParseException
import java.util.UUID
sealed trait ValidationError {
def message: String
}
case class FieldError(field: String, message: String) extends ValidationError
case class SecurityError(message: String) extends ValidationError
case class FormatError(field: String, expectedFormat: String, message: String) extends ValidationError
// A top-level type alias compiles in Scala 3; in Scala 2, move it into a package object
type ValidationResult[A] = ValidatedNel[ValidationError, A]
// Security-focused validators
object SecurityValidators {
// SQL injection heuristic (defense in depth only; parameterized queries remain the primary protection)
def sqlSafeString(field: String, value: String): ValidationResult[String] = {
val sqlKeywords = Set(
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER",
"UNION", "OR", "AND", "WHERE", "EXEC", "EXECUTE"
)
// Match keywords case-insensitively and only as whole words, so "victor" or "Sandy" do not trip on OR / AND
val keywordPattern = ("(?i)\\b(" + sqlKeywords.mkString("|") + ")\\b").r
val suspiciousPatterns = List(
"--".r, // SQL comments
"/\\*.*\\*/".r, // SQL block comments
"'.*'".r, // String literals
";".r, // Statement separators
"(?i)\\b(SP|XP)_".r // SQL Server system procedure prefixes
)
val hasSqlKeywords = keywordPattern.findFirstIn(value).isDefined
val hasSuspiciousPatterns = suspiciousPatterns.exists(_.findFirstIn(value).isDefined)
if (hasSqlKeywords || hasSuspiciousPatterns) {
SecurityError(s"Field $field contains potentially dangerous SQL content").invalidNel
} else {
value.validNel
}
}
// XSS protection
def xssSafeString(field: String, value: String): ValidationResult[String] = {
val xssPatterns = List(
"<script.*?>".r,
"javascript:".r,
"on\\w+\\s*=".r, // Event handlers like onclick=
"<iframe.*?>".r,
"<object.*?>".r,
"<embed.*?>".r,
"<form.*?>".r
)
val hasXssPatterns = xssPatterns.exists(_.findFirstIn(value.toLowerCase).isDefined)
if (hasXssPatterns) {
SecurityError(s"Field $field contains potentially dangerous content").invalidNel
} else {
value.validNel
}
}
// Path traversal protection
def pathSafeString(field: String, value: String): ValidationResult[String] = {
val dangerousPatterns = List(
"../",
"..\\",
"/etc/",
"/proc/",
"/sys/",
"C:\\",
"D:\\"
)
val hasDangerousPath = dangerousPatterns.exists(value.contains)
if (hasDangerousPath) {
SecurityError(s"Field $field contains potentially dangerous path").invalidNel
} else {
value.validNel
}
}
// HTML sanitization
def sanitizeHtml(field: String, value: String): ValidationResult[String] = {
// Basic HTML sanitization - in production, use a proper library like jsoup
val allowedTags = Set("p", "br", "strong", "em", "u", "ol", "ul", "li")
val tagPattern = "<(/?)([a-zA-Z]+)[^>]*>".r
val sanitized = tagPattern.replaceAllIn(value, { m =>
val isClosing = m.group(1) == "/"
val tagName = m.group(2).toLowerCase
if (allowedTags.contains(tagName)) {
if (isClosing) s"</$tagName>" else s"<$tagName>"
} else {
"" // Remove disallowed tags
}
})
sanitized.validNel
}
// File upload validation
def validateFileUpload(
field: String,
filename: String,
contentType: String,
maxSize: Long,
allowedExtensions: Set[String],
allowedMimeTypes: Set[String]
): ValidationResult[String] = {
// The extension is the text after the last dot; a name without a dot has no extension
val extension = filename.lastIndexOf('.') match {
case -1 => ""
case i => filename.substring(i + 1).toLowerCase
}
// Note: maxSize is not enforced here because the size of the upload is not passed in
val validations: List[ValidationResult[Unit]] = List(
if (allowedExtensions.contains(extension)) ().validNel
else FieldError(field, s"File extension .$extension is not allowed").invalidNel,
if (allowedMimeTypes.contains(contentType.toLowerCase)) ().validNel
else FieldError(field, s"Content type $contentType is not allowed").invalidNel,
pathSafeString(field, filename).map(_ => ()),
// Check for double extensions (potential bypass attempt)
if (filename.count(_ == '.') <= 1) ().validNel
else SecurityError(s"File $filename has suspicious multiple extensions").invalidNel
)
validations.sequence.map(_ => filename)
}
}
// Standard field validators
object FieldValidators {
def nonEmpty(field: String, value: String): ValidationResult[String] = {
if (value.trim.nonEmpty) value.validNel
else FieldError(field, "cannot be empty").invalidNel
}
def minLength(field: String, value: String, min: Int): ValidationResult[String] = {
if (value.length >= min) value.validNel
else FieldError(field, s"must be at least $min characters long").invalidNel
}
def maxLength(field: String, value: String, max: Int): ValidationResult[String] = {
if (value.length <= max) value.validNel
else FieldError(field, s"must be at most $max characters long").invalidNel
}
def matches(field: String, value: String, pattern: Regex, errorMessage: String): ValidationResult[String] = {
if (pattern.matches(value)) value.validNel
else FieldError(field, errorMessage).invalidNel
}
def email(field: String, value: String): ValidationResult[String] = {
val emailPattern = """^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$""".r
matches(field, value, emailPattern, "must be a valid email address")
}
def url(field: String, value: String): ValidationResult[String] = {
try {
new URL(value)
value.validNel
} catch {
case _: MalformedURLException =>
FieldError(field, "must be a valid URL").invalidNel
}
}
def uuid(field: String, value: String): ValidationResult[UUID] = {
try {
UUID.fromString(value).validNel
} catch {
case _: IllegalArgumentException =>
FieldError(field, "must be a valid UUID").invalidNel
}
}
def positiveInt(field: String, value: Int): ValidationResult[Int] = {
if (value > 0) value.validNel
else FieldError(field, "must be positive").invalidNel
}
def range[T: Ordering](field: String, value: T, min: T, max: T): ValidationResult[T] = {
val ord = implicitly[Ordering[T]]
if (ord.gteq(value, min) && ord.lteq(value, max)) value.validNel
else FieldError(field, s"must be between $min and $max").invalidNel
}
def oneOf[T](field: String, value: T, allowed: Set[T]): ValidationResult[T] = {
if (allowed.contains(value)) value.validNel
else FieldError(field, s"must be one of: ${allowed.mkString(", ")}").invalidNel
}
def pastDate(field: String, value: LocalDate): ValidationResult[LocalDate] = {
if (value.isBefore(LocalDate.now())) value.validNel
else FieldError(field, "must be in the past").invalidNel
}
def futureDate(field: String, value: LocalDate): ValidationResult[LocalDate] = {
if (value.isAfter(LocalDate.now())) value.validNel
else FieldError(field, "must be in the future").invalidNel
}
}
// Validation builder for complex objects
class ValidationBuilder[T] {
private var validators: List[T => ValidationResult[Unit]] = List.empty
// name is handed to the validator so errors identify the field; it defaults to the previous "field" label
def field[F](extract: T => F, name: String = "field")(validate: (String, F) => ValidationResult[F]): ValidationBuilder[T] = {
val validator: T => ValidationResult[Unit] = { obj =>
validate(name, extract(obj)).map(_ => ())
}
validators = validator :: validators
this
}
def custom(validate: T => ValidationResult[Unit]): ValidationBuilder[T] = {
validators = validate :: validators
this
}
def build: T => ValidationResult[T] = { obj =>
// Validators are prepended as they are added, so reverse to report errors in declaration order
validators.reverse.map(_(obj)).sequence.map(_ => obj)
}
}
object ValidationBuilder {
def apply[T]: ValidationBuilder[T] = new ValidationBuilder[T]
}
// Example usage with case classes
case class UserRegistration(
  username: String,
  email: String,
  password: String,
  confirmPassword: String,
  age: Int,
  website: Option[String],
  bio: String
)
object UserRegistrationValidator {
  def validate(registration: UserRegistration): ValidationResult[UserRegistration] = {
    (
      validateUsername(registration.username),
      validateEmail(registration.email),
      validatePassword(registration.password),
      validatePasswordMatch(registration.password, registration.confirmPassword),
      validateAge(registration.age),
      validateWebsite(registration.website),
      validateBio(registration.bio)
    ).mapN { (username, email, password, _, age, website, bio) =>
      registration.copy(
        username = username,
        email = email,
        password = password,
        age = age,
        website = website,
        bio = bio
      )
    }
  }

  private def validateUsername(username: String): ValidationResult[String] = {
    (
      FieldValidators.nonEmpty("username", username),
      FieldValidators.minLength("username", username, 3),
      FieldValidators.maxLength("username", username, 30),
      FieldValidators.matches("username", username, "^[a-zA-Z0-9_]+$".r, "can only contain letters, numbers, and underscores"),
      SecurityValidators.sqlSafeString("username", username),
      SecurityValidators.xssSafeString("username", username)
    ).mapN((_, _, _, _, _, _) => username)
  }

  private def validateEmail(email: String): ValidationResult[String] = {
    (
      FieldValidators.nonEmpty("email", email),
      FieldValidators.email("email", email),
      FieldValidators.maxLength("email", email, 100),
      SecurityValidators.sqlSafeString("email", email)
    ).mapN((_, _, _, _) => email)
  }

  private def validatePassword(password: String): ValidationResult[String] = {
    (
      FieldValidators.nonEmpty("password", password),
      FieldValidators.minLength("password", password, 8),
      FieldValidators.maxLength("password", password, 128),
      validatePasswordStrength(password)
    ).mapN((_, _, _, _) => password)
  }

  private def validatePasswordStrength(password: String): ValidationResult[String] = {
    val hasUpper = password.exists(_.isUpper)
    val hasLower = password.exists(_.isLower)
    val hasDigit = password.exists(_.isDigit)
    val hasSpecial = password.exists(c => !c.isLetterOrDigit)
    if (hasUpper && hasLower && hasDigit && hasSpecial) {
      password.validNel
    } else {
      FieldError("password", "must contain uppercase, lowercase, digit, and special character").invalidNel
    }
  }

  private def validatePasswordMatch(password: String, confirmPassword: String): ValidationResult[Unit] = {
    if (password == confirmPassword) ().validNel
    else FieldError("confirmPassword", "passwords do not match").invalidNel
  }

  private def validateAge(age: Int): ValidationResult[Int] = {
    FieldValidators.range("age", age, 13, 120)
  }

  private def validateWebsite(website: Option[String]): ValidationResult[Option[String]] = {
    website match {
      case Some(url) =>
        (
          FieldValidators.url("website", url),
          SecurityValidators.xssSafeString("website", url)
        ).mapN((_, _) => Some(url))
      case None => None.validNel
    }
  }

  private def validateBio(bio: String): ValidationResult[String] = {
    (
      FieldValidators.maxLength("bio", bio, 500),
      SecurityValidators.xssSafeString("bio", bio),
      SecurityValidators.sqlSafeString("bio", bio)
    ).mapN((_, _, _) => bio)
  }
}
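The validator above combines its checks with mapN so that a single call reports every failing field instead of stopping at the first one. The following is a minimal, dependency-free sketch of that accumulation idea using only the Scala standard library. It is not how cats implements Validated, and every name in it (AccumulationSketch, Result, both, minLength, inRange, validateSignup) is hypothetical and chosen purely for illustration.

```scala
// Dependency-free sketch of error accumulation (standard library only).
// All names in this object are hypothetical and used for illustration.
object AccumulationSketch {
  // Left carries every error collected so far; Right carries the validated value.
  type Result[A] = Either[List[String], A]

  // Combine two results, keeping the errors from BOTH sides when both fail.
  def both[A, B](a: Result[A], b: Result[B]): Result[(A, B)] =
    (a, b) match {
      case (Right(x), Right(y)) => Right((x, y))
      case (Left(e1), Left(e2)) => Left(e1 ++ e2)
      case (Left(e), _)         => Left(e)
      case (_, Left(e))         => Left(e)
    }

  def minLength(field: String, value: String, min: Int): Result[String] =
    if (value.length >= min) Right(value)
    else Left(List(s"$field: must be at least $min characters"))

  def inRange(field: String, value: Int, min: Int, max: Int): Result[Int] =
    if (value >= min && value <= max) Right(value)
    else Left(List(s"$field: must be between $min and $max"))

  // Both checks are evaluated before combining, so the caller sees all problems at once.
  def validateSignup(username: String, age: Int): Result[(String, Int)] =
    both(minLength("username", username, 3), inRange("age", age, 13, 120))
}
```

Calling AccumulationSketch.validateSignup("ab", 7) yields a Left holding both the username and the age error. Chaining the same two checks in a for-comprehension over Either would instead stop at the first Left, which is why the lesson's validators use an accumulating type rather than plain Either.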
Conclusion
Security in Scala applications requires a comprehensive approach covering multiple layers of protection. Key concepts include:
Authentication Systems:
- JWT token-based authentication
- Password hashing and verification
- Two-factor authentication (2FA)
- OAuth 2.0 integration
- Session management and token storage
Authorization Frameworks:
- Role-Based Access Control (RBAC)
- Attribute-Based Access Control (ABAC)
- Policy-based authorization
- Resource ownership patterns
- Method-level security annotations
Input Validation and Security:
- SQL injection prevention
- Cross-site scripting (XSS) protection
- Path traversal prevention
- File upload validation
- HTML sanitization
Security Best Practices:
- Principle of least privilege
- Defense in depth
- Secure by default
- Input validation at boundaries
- Output encoding and sanitization
Cryptographic Operations:
- Secure password hashing (bcrypt)
- Token generation and validation
- HMAC-based authentication
- Time-based one-time passwords (TOTP)
HTTP Security:
- HTTPS enforcement
- Security headers
- CORS configuration
- Request rate limiting
- Circuit breaker patterns
Monitoring and Auditing:
- Authentication event logging
- Authorization decision tracking
- Security metric collection
- Intrusion detection
- Compliance reporting
Infrastructure Security:
- Environment-based configuration
- Secret management
- Database security
- API security
- Network security
Implementing these measures helps Scala applications withstand common attacks and protect sensitive data without sacrificing usability or performance. Treat security as a concern from the design phase through deployment and ongoing maintenance.