Reactive scalable chat on Kotlin + Spring + WebSockets

Content

  1. Project configuration





    1. Logger





    2. Domain





    3. Mapper





  2. Configuring Spring Security





  3. Web Sockets Configuration





  4. Solution architecture





  5. Implementation





    1. Integration with Redis





    2. Service implementation





  6. Conclusion





Foreword

In this tutorial, we will build a scalable application that clients connect to and communicate with over WebSockets. We will also tackle the problem of delivering messages between application instances by using a message broker; Redis will serve as the broker.





Project configuration

Let's start with the most important part: configuring the logger!

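The logger is declared as a prototype-scoped bean, so every class that injects it gets its own logger named after that class.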





@Configuration
class LoggingConfig {

    /**
     * Produces a [Logger] named after the class that asks for it.
     *
     * The bean is prototype-scoped, so a new logger is resolved for every injection point. The owner class is
     * the class containing the injected constructor/method parameter or, for field injection, the class that
     * declares the field.
     */
    @Bean
    @Scope("prototype")
    fun logger(injectionPoint: InjectionPoint): Logger {
        val owner = injectionPoint.methodParameter?.containingClass
            ?: injectionPoint.field?.declaringClass
        return LoggerFactory.getLogger(owner)
    }
}
      
      



It can now be injected through the constructor of any bean:





/**
 * Receives its [Logger] through constructor injection.
 */
@Component
class ChatWebSocketHandlerService(private val logger: Logger)
      
      



Next, the domain model, starting with the chat itself:





/**
 * A chat identified by [chatId].
 *
 * @property chatMembers participants of the chat
 * @property createdDate creation timestamp, (de)serialized with Jackson's `LocalDateTime` serializer/deserializer
 * @property lastMessage most recent message of the chat, or `null`; mutable so it can be replaced when a new
 * message arrives
 */
data class Chat(
    val chatId: UUID,
    val chatMembers: List<ChatMember>,
    @JsonSerialize(using = LocalDateTimeSerializer::class) @JsonDeserialize(using = LocalDateTimeDeserializer::class)
    val createdDate: LocalDateTime,
    var lastMessage: CommonMessage?
)
      
      



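ChatMember describes a chat participant. The deletedChat flag indicates whether this member has deleted the chat, and userId identifies the user.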





/**
 * A participant of a chat.
 *
 * @property userId id of the user
 * @property fullName display name of the user
 * @property avatar avatar reference (format not visible here — presumably a URL or file name)
 * @property deletedChat per-member flag, presumably set when this member has deleted the chat
 */
data class ChatMember(
    val userId: UUID,
    var fullName: String,
    var avatar: String,
    var deletedChat: Boolean
)
      
      



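All message types share a common base class. @JsonTypeInfo makes Jackson write an @type property into the JSON, which identifies the concrete subclass when the message is deserialized.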





/**
 * Base class of all chat messages.
 *
 * `@JsonTypeInfo(use = NAME, include = PROPERTY)` makes Jackson write the subtype name into an `@type`
 * property, so the concrete subclass can be restored on deserialization (the subtype names have to be
 * registered with the ObjectMapper).
 *
 * @property messageDate (de)serialized through field-targeted `LocalDateTime` serializer/deserializer
 * @property seen mutable read flag
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY)
open class CommonMessage(
    val messageId: UUID,
    val chatId: UUID,
    val sender: ChatMember,
    @field:JsonSerialize(using = LocalDateTimeSerializer::class)
    @field:JsonDeserialize(using = LocalDateTimeDeserializer::class)
    val messageDate: LocalDateTime,
    var seen: Boolean
)
      
      



TextMessage is a plain-text message:





/**
 * A plain-text [CommonMessage]; [content] holds the text.
 *
 * All other properties are forwarded unchanged to [CommonMessage]'s constructor, which takes
 * `(messageId, chatId, sender, messageDate, seen)` — there is no message-type argument (the JSON subtype is
 * carried by the `@type` property instead).
 */
class TextMessage(
    messageId: UUID,
    chatId: UUID,
    sender: ChatMember,
    var content: String,
    messageDate: LocalDateTime,
    seen: Boolean
) : CommonMessage(messageId, chatId, sender, messageDate, seen)
      
      



ObjectMapper





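registerSubtypes registers the concrete subclasses under the names used in the JSON @type property, so Jackson knows which class to instantiate: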





@Configuration
class ObjectMapperConfig {

    /**
     * Application-wide Jackson [ObjectMapper].
     *
     * Registers the Java time, JDK 8, parameter-names and Kotlin modules, writes dates as text instead of
     * numeric timestamps, and registers the polymorphic event/message subtypes under the names that appear in
     * the JSON `@type` property.
     */
    @Bean
    fun objectMapper(): ObjectMapper {
        val mapper = ObjectMapper()
        mapper.registerModule(JavaTimeModule())
        mapper.registerModule(Jdk8Module())
        mapper.registerModule(ParameterNamesModule())
        mapper.registerModule(KotlinModule())
        mapper.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS)
        mapper.registerSubtypes(
            NamedType(NewMessageEvent::class.java, "NewMessageEvent"),
            NamedType(MarkMessageAsRead::class.java, "MarkMessageAsRead"),
            NamedType(TextMessage::class.java, "TextMessage"),
            NamedType(ImageMessage::class.java, "ImageMessage")
        )
        return mapper
    }
}
      
      



Spring Security

To plug JWT authentication into Spring Security, we implement a ReactiveAuthenticationManager and a SecurityContextRepository. Let's start with JwtAuthenticationManager:





@Component
class JwtAuthenticationManager(val jwtUtil: JwtUtil) : ReactiveAuthenticationManager {

    /**
     * Authenticates a request whose credentials carry a raw JWT.
     *
     * @return a [UsernamePasswordAuthenticationToken] with the token's username as principal, authorities built
     * from the `roles` claim, and the claims as details; or an empty [Mono] when the token fails validation, the
     * username cannot be extracted, or the `roles` claim is missing / not a list.
     */
    override fun authenticate(authentication: Authentication): Mono<Authentication> {
        val token = authentication.credentials.toString()
        val validToken = jwtUtil.validateToken(token)
        val username: String? = try {
            jwtUtil.extractUsername(token)
        } catch (e: Exception) {
            // A malformed token is treated as unauthenticated.
            // NOTE(review): reported on stdout as before; consider an injected logger.
            println(e)
            null
        }
        if (username == null || !validToken) {
            return Mono.empty()
        }

        val claims = jwtUtil.getClaimsFromToken(token)
        // An unchecked `as List<String>` threw (synchronously, outside the Mono) for tokens without a
        // list-valued "roles" claim; reject such tokens instead of crashing the filter chain.
        val roles = claims["roles"] as? List<*> ?: return Mono.empty()
        val authorities = roles.filterIsInstance<String>().map(::SimpleGrantedAuthority)

        val authenticationToken = UsernamePasswordAuthenticationToken(username, null, authorities)
        authenticationToken.details = claims
        return Mono.just(authenticationToken)
    }
}
      
      



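If the token is valid, we build an authentication object carrying the user's roles and place it in the security context; the token claims are stored in its details (line 25).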





Next comes the SecurityContextRepository. The token can be passed in one of two ways:





  1. In the Authorization header: Authorization: Bearer ${JWT_TOKEN}





  2. As a GET query parameter: ?access_token=${JWT_TOKEN}





@Component
class SecurityContextRepository(val authenticationManager: ReactiveAuthenticationManager) : ServerSecurityContextRepository {

    /** Authentication is stateless, so persisting a context is not supported and always fails. */
    override fun save(exchange: ServerWebExchange, context: SecurityContext): Mono<Void> =
        Mono.error { IllegalStateException("Save method not supported") }

    /**
     * Builds the security context from the request's JWT.
     *
     * The token comes from an `Authorization: Bearer <token>` header or, if that header is absent or not a
     * Bearer header, from the `access_token` query parameter. Without a token the result is empty.
     */
    override fun load(exchange: ServerWebExchange): Mono<SecurityContext> {
        val request = exchange.request
        val bearerToken = request.headers
            .getFirst(HttpHeaders.AUTHORIZATION)
            ?.takeIf { it.startsWith("Bearer ") }
            ?.substring(7)
        val accessToken = bearerToken
            ?: request.queryParams.getFirst("access_token")
            ?: return Mono.empty()

        return authenticationManager
            .authenticate(UsernamePasswordAuthenticationToken(accessToken, accessToken))
            .map { authentication -> SecurityContextImpl(authentication) }
    }
}
      
      



Now the Spring Security configuration itself:





/**
 * WebFlux security setup: stateless JWT authentication through [reactiveAuthenticationManager] and
 * [securityContextRepository], with CSRF, CORS handling, form login and HTTP Basic disabled.
 *
 * `@Configuration` is declared explicitly because, since Spring Security 6, the `@Enable*Security` annotations
 * no longer carry it; without it this class would not be registered and the chain below would not apply.
 */
@Configuration
@EnableWebFluxSecurity
@EnableReactiveMethodSecurity
class SecurityConfig(
    val reactiveAuthenticationManager: ReactiveAuthenticationManager,
    val securityContextRepository: SecurityContextRepository
) {

    /**
     * Unauthenticated requests get 401 and denied ones 403. `/actuator/**` is public, GET `/ws/**` requires
     * the `ROLE_USER` authority, and every other exchange requires authentication.
     */
    @Bean
    fun securityWebFilterChain(httpSecurity: ServerHttpSecurity): SecurityWebFilterChain {
        return httpSecurity
            .exceptionHandling()
            .authenticationEntryPoint { swe: ServerWebExchange, _: AuthenticationException ->
                Mono.fromRunnable { swe.response.statusCode = HttpStatus.UNAUTHORIZED }
            }
            .accessDeniedHandler { swe: ServerWebExchange, _: AccessDeniedException ->
                Mono.fromRunnable { swe.response.statusCode = HttpStatus.FORBIDDEN }
            }
            .and()
            .csrf().disable()
            .cors().disable()
            .formLogin().disable()
            .httpBasic().disable()
            .authenticationManager(reactiveAuthenticationManager)
            .securityContextRepository(securityContextRepository)
            .authorizeExchange()
            .pathMatchers("/actuator/**").permitAll()
            .pathMatchers(HttpMethod.GET, "/ws/**").hasAuthority("ROLE_USER")
            .anyExchange().authenticated()
            .and()
            .build()
    }
}
      
      



Note: GET requests to /ws/** require the ROLE_USER authority.





That completes the Security configuration.





Web Sockets Configuration

The WebSocket configuration does two things:





  1. It maps URIs to handlers: the key is the URI and the value is the WebSocketHandler that serves it.





  2. It configures CORS.





@Configuration
class ReactiveWebSocketConfig {

    /**
     * Maps WebSocket URIs to their handlers: `/ws/chat` is served by [chatWebSocketHandler].
     *
     * A permissive CORS configuration ([CorsConfiguration.applyPermitDefaultValues]) is registered for every
     * path. The keys of [SimpleUrlHandlerMapping.setCorsConfigurations] are URL path patterns, so `"/**"` is
     * used: the bare `"*"` used previously never matches a multi-segment path such as `/ws/chat`, which left
     * the CORS configuration unused.
     */
    @Bean
    fun webSocketHandlerMapping(chatWebSocketHandler: ChatWebSocketHandler): HandlerMapping {
        val handlerMapping = SimpleUrlHandlerMapping()
        handlerMapping.setCorsConfigurations(mapOf("/**" to CorsConfiguration().applyPermitDefaultValues()))
        handlerMapping.order = 1
        handlerMapping.urlMap = mapOf<String, WebSocketHandler>("/ws/chat" to chatWebSocketHandler)
        return handlerMapping
    }

    /** Adapter that lets the DispatcherHandler invoke [WebSocketHandler]s. */
    @Bean
    fun handlerAdapter(): WebSocketHandlerAdapter = WebSocketHandlerAdapter()
}
      
      



Connections to /ws/chat are handled by chatWebSocketHandler. For now it is an empty implementation of the WebSocketHandler interface, whose only method is handle(session: WebSocketSession): Mono<Void>:





/**
 * Skeleton of the `/ws/chat` handler: [handle] is not implemented yet and throws [NotImplementedError].
 */
@Component
class ChatWebSocketHandler : WebSocketHandler {
    override fun handle(session: WebSocketSession): Mono<Void> = TODO("Not yet implemented")
}

      
      



Solution architecture





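A single instance can only deliver messages to the users connected to it. Once the application is scaled to several instances, the members of one chat may be connected to different instances, so messages have to be passed between them. For this we use a message broker: every instance publishes incoming messages to it and receives the messages published by the others.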





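For example, User 1 and User 2 share a chat. User 1 is connected to Chat-Instance-0 and User 2 to Chat-Instance-1. When User 1 sends a message, Chat-Instance-0 receives it and publishes it to the message broker, which delivers it to all instances. Chat-Instance-1 receives it and forwards it to User 2.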





ChatWebSocketHandler





We keep a userId => sessions mapping, since one user can have several open sessions at the same time; it maps the user's UUID to that user's WebSocketSessions.





A session is registered when session.receive is subscribed to and removed again in doFinally.





getReceiverStream processes incoming messages: each text payload is converted into a WebSocketEvent and handled according to the event type.





getSenderStream forwards the events addressed to this user from the shared sink to the session:





@Component
class ChatWebSocketHandler(
    val objectMapper: ObjectMapper,
    val logger: Logger,
    val chatService: ChatService,
    val objectStringConverter: ObjectStringConverter,
    val sinkWrapper: SinkWrapper
) : WebSocketHandler {

    /**
     * Open WebSocket sessions per user id; one user may have several sessions at once.
     *
     * Updates go through [ConcurrentHashMap.compute] / [ConcurrentHashMap.computeIfPresent], so they are atomic
     * per user even when the same user connects and disconnects concurrently. The entry is removed when the
     * user's last session closes, so the map does not accumulate empty entries.
     * This map is not read within this class.
     */
    private val userIdToSession: ConcurrentHashMap<UUID, MutableSet<WebSocketSession>> = ConcurrentHashMap()

    /**
     * Serves one WebSocket connection for the authenticated user.
     *
     * The user id comes from the `id` claim stored in the authentication details. The outgoing and incoming
     * streams run side by side; [Mono.zip] completes, cancelling the other side, as soon as either completes.
     */
    override fun handle(session: WebSocketSession): Mono<Void> {
        return ReactiveSecurityContextHolder.getContext()
            .flatMap { ctx ->
                val userId = UUID.fromString((ctx.authentication.details as Claims)["id"].toString())
                val sender = getSenderStream(session, userId)
                val receiver = getReceiverStream(session, userId)
                Mono.zip(sender, receiver).then()
            }
    }

    /**
     * Reads text frames from [session], converts them to [WebSocketEvent]s and dispatches them to [chatService].
     * Per-message errors are logged and skipped. The session is registered on subscription and unregistered
     * when the stream terminates for any reason.
     */
    private fun getReceiverStream(session: WebSocketSession, userId: UUID): Mono<Void> {
        return session.receive()
            .filter { it.type == WebSocketMessage.Type.TEXT }
            .map(WebSocketMessage::getPayloadAsText)
            .flatMap {
                objectStringConverter.stringToObject(it, WebSocketEvent::class.java)
            }
            .flatMap { convertedEvent ->
                when (convertedEvent) {
                    is NewMessageEvent -> chatService.handleNewMessageEvent(userId, convertedEvent)
                    is MarkMessageAsRead -> chatService.markPreviousMessagesAsRead(convertedEvent.messageId)
                    else -> Mono.error(RuntimeException())
                }
            }
            .onErrorContinue { t, _ -> logger.error("Error occurred with receiver stream", t) }
            .doOnSubscribe { registerSession(userId, session) }
            .doFinally { unregisterSession(userId, session) }
            .then()
    }

    /** Atomically adds [session] to [userId]'s sessions, creating the set on first use. */
    private fun registerSession(userId: UUID, session: WebSocketSession) {
        userIdToSession.compute(userId) { _, sessions ->
            (sessions ?: ConcurrentHashMap.newKeySet<WebSocketSession>()).apply { add(session) }
        }
    }

    /** Atomically removes [session]; drops the user's entry once no session is left (returning null removes it). */
    private fun unregisterSession(userId: UUID, session: WebSocketSession) {
        userIdToSession.computeIfPresent(userId) { _, sessions ->
            sessions.remove(session)
            sessions.takeIf { it.isNotEmpty() }
        }
    }

    /**
     * Sends to [session] every event from the shared sink addressed to [userId], serialized as JSON text.
     */
    private fun getSenderStream(session: WebSocketSession, userId: UUID): Mono<Void> {
        val sendMessage = sinkWrapper.sinks.asFlux()
            .filter { sendTo -> sendTo.userId == userId }
            .map { sendTo -> objectMapper.writeValueAsString(sendTo.event) }
            .map { stringObject -> session.textMessage(stringObject) }
            .doOnError { logger.error("", it) }
        return session.send(sendMessage)
    }
}
      
      



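Events must be pushed to a WebSocket session from other parts of the application. Since Reactor 3.4 this is done with Sinks.Many, which we wrap in SinkWrapper: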





@Component
class SinkWrapper {
    /**
     * Shared hot stream of [SendTo] events. Each WebSocket sender stream subscribes to it and keeps only the
     * events addressed to its own user.
     *
     * `autoCancel` is disabled on purpose. The no-arg `onBackpressureBuffer()` uses `autoCancel = true`, which
     * shuts the sink down for good once its last subscriber cancels. That happens as soon as every WebSocket
     * session on this instance has closed, and after that no new session would ever receive events.
     */
    val sinks: Sinks.Many<SendTo> =
        Sinks.many().multicast().onBackpressureBuffer(reactor.util.concurrent.Queues.SMALL_BUFFER_SIZE, false)
}
      
      



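Anything emitted into this sink reaches every getSenderStream subscription, and each one keeps only the events addressed to its own user.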





Integration with Redis

Redis PUB/SUB is used to pass messages between instances.





The integration consists of four classes:





  1. RedisChatMessageListener - subscribes to the message channel





  2. RedisChatMessagePublisher - publishes messages to the channel





  3. RedisConfig - configures the Redis connection





  4. RedisListenerStarter - starts the listener when the application starts





Let's look at each of them:





RedisConfig creates the reactive Redis connection factory and a ReactiveStringRedisTemplate:





@Configuration
class RedisConfig {

    /**
     * Lettuce-based connection factory for a standalone Redis server, using the host, port and password from
     * [RedisProperties].
     */
    @Bean
    fun reactiveRedisConnectionFactory(redisProperties: RedisProperties): ReactiveRedisConnectionFactory =
        LettuceConnectionFactory(
            RedisStandaloneConfiguration(redisProperties.host, redisProperties.port).apply {
                setPassword(redisProperties.password)
            }
        )

    /** String-based reactive template used for publishing to and listening on Redis channels. */
    @Bean
    fun template(reactiveRedisConnectionFactory: ReactiveRedisConnectionFactory): ReactiveStringRedisTemplate =
        ReactiveStringRedisTemplate(reactiveRedisConnectionFactory)
}
      
      



RedisChatMessageListener

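Listens to the channel named after the CommonMessage class. Each received message is deserialized (line 13) and passed to chatService.sendMessage, which delivers it to the chat members.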





@Component
class RedisChatMessageListener(
    private val logger: Logger,
    private val reactiveStringRedisTemplate: ReactiveStringRedisTemplate,
    private val objectStringConverter: ObjectStringConverter,
    private val chatService: ChatService
) {

    /**
     * Listens on the Redis topic named after [CommonMessage]'s class.
     *
     * Each payload is logged, deserialized into a [CommonMessage], and text/image messages are handed to
     * [ChatService.sendMessage]. Any other message type yields an error for that element.
     */
    fun subscribeOnCommonMessageTopic(): Mono<Void> {
        val topic = PatternTopic(CommonMessage::class.java.name)
        return reactiveStringRedisTemplate.listenTo(topic)
            .map { it.message }
            .doOnNext { payload -> logger.info("Receive new message: $payload") }
            .flatMap { payload -> objectStringConverter.stringToObject(payload, CommonMessage::class.java) }
            .flatMap { received ->
                when (received) {
                    is TextMessage -> chatService.sendMessage(received)
                    is ImageMessage -> chatService.sendMessage(received)
                    else -> Mono.error(RuntimeException())
                }
            }
            .then()
    }
}
      
      



RedisChatMessagePublisher

Serializes a CommonMessage and publishes it to the channel named after the CommonMessage class.





@Component
class RedisChatMessagePublisher(
    val logger: Logger,
    val reactiveStringRedisTemplate: ReactiveStringRedisTemplate,
    val objectStringConverter: ObjectStringConverter
) {
    /**
     * Serializes [commonMessage] and publishes it on the Redis channel named after [CommonMessage]'s class.
     * Completes once the publish has been issued.
     */
    fun broadcastMessage(commonMessage: CommonMessage): Mono<Void> {
        val channel = CommonMessage::class.java.name
        return objectStringConverter.objectToString(commonMessage)
            .flatMap { payload ->
                logger.info("Broadcast message $payload to channel $channel")
                reactiveStringRedisTemplate.convertAndSend(channel, payload)
            }
            .then()
    }
}
      
      



RedisListenerStarter

Starts RedisChatMessageListener on application startup by subscribing to subscribeOnCommonMessageTopic.





@Component
class RedisListenerStarter(
    val logger: Logger,
    val redisChatMessageListener: RedisChatMessageListener
) {

    /**
     * On application startup, subscribes (fire-and-forget) to the common-message Redis topic. Errors raised
     * while processing individual messages are logged and skipped.
     */
    @Bean
    fun newMessageEventChannelListenerStarter(): ApplicationRunner = ApplicationRunner {
        redisChatMessageListener.subscribeOnCommonMessageTopic()
            .doOnSubscribe { logger.info("Start NewMessageEvent channel listener") }
            .onErrorContinue { throwable, _ ->
                logger.error("Error occurred while listening NewMessageEvent channel", throwable)
            }
            .subscribe()
    }
}
      
      



Now the chat service. It relies on chatRepository to load and store chats; the repository itself is not shown here.





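handleNewMessageEvent is called from the WebSocketHandler with the sender's userId and the NewMessageEvent. If the sender is a member of the chat, it creates a TextMessage, stores it as the chat's last message and broadcasts it to all instances through Redis.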





@Service
class DefaultChatService(
    val logger: Logger,
    val sinkWrapper: SinkWrapper,
    val chatRepository: ChatRepository,
    val redisChatPublisher: RedisChatMessagePublisher
) : ChatService {

    /**
     * Handles a [NewMessageEvent] sent over WebSocket by [senderId].
     *
     * If the sender is a member of the target chat, a [TextMessage] is created and stored as the chat's
     * `lastMessage`. The chat is then saved and the message is broadcast to all instances. Unknown chats and
     * non-member senders complete empty.
     */
    override fun handleNewMessageEvent(senderId: UUID, newMessageEvent: NewMessageEvent): Mono<Void> {
        logger.info("Receive NewMessageEvent from $senderId: $newMessageEvent")
        return chatRepository.findById(newMessageEvent.chatId)
            .filter { chat -> chat.chatMembers.any { it.userId == senderId } }
            .flatMap { chat ->
                val sender = chat.chatMembers.first { it.userId == senderId }
                val textMessage = TextMessage(UUID.randomUUID(), chat.chatId, sender, newMessageEvent.content, LocalDateTime.now(), false)
                chat.lastMessage = textMessage
                Mono.zip(chatRepository.save(chat), Mono.just(textMessage))
            }
            .flatMap { broadcastMessage(it.t2) }
    }

    /**
     * Broadcast the message between instances
     */
    override fun broadcastMessage(commonMessage: CommonMessage): Mono<Void> {
        return redisChatPublisher.broadcastMessage(commonMessage)
    }

    /**
     * Send the message to all of chatMembers of message chat direct
     */
    override fun sendMessage(message: CommonMessage): Mono<Void> {
        return chatRepository.findById(message.chatId)
            .map { it.chatMembers }
            .flatMapMany { Flux.fromIterable(it) }
            .flatMap { member -> sendEventToUserId(member.userId, ChatMessageEvent(message.chatId, message)) }
            .then()
    }

    /**
     * Emits [webSocketEvent] for [userId] into the shared sink when subscribed.
     *
     * Emissions may come from several threads at once (e.g. concurrent repository callbacks in the Redis
     * listener). `FAIL_FAST` would throw on the resulting [Sinks.EmitResult.FAIL_NON_SERIALIZED], so that case
     * is retried. Any other emit failure still errors the returned [Mono].
     */
    override fun sendEventToUserId(userId: UUID, webSocketEvent: WebSocketEvent): Mono<Void> {
        return Mono.fromRunnable<Void> {
            sinkWrapper.sinks.emitNext(SendTo(userId, webSocketEvent), RETRY_ON_CONCURRENT_EMISSION)
        }
    }

    private companion object {
        /** Retries only when another thread was emitting at the same time; other failures propagate. */
        val RETRY_ON_CONCURRENT_EMISSION = Sinks.EmitFailureHandler { _, result ->
            result == Sinks.EmitResult.FAIL_NON_SERIALIZED
        }
    }
}
      
      



That's all. To support a new kind of event, add a WebSocketEvent subtype and a matching event => handler branch.





GitHub








All Articles