Event-Driven AI Response Streaming with Axon Framework: A Practical Guide
Let’s dive into a practical implementation of AI response streaming using event-driven architecture with Axon Framework and Spring. This approach offers real-time updates while maintaining clean architectural patterns.
Core Components
Commands and Events
/** Command carrying a user prompt; Axon routes it to the aggregate whose identifier is [id]. */
data class TestCommand(
    @TargetAggregateIdentifier val id: UUID,
    val message: String,
)

/** Event recording that a prompt was accepted for the aggregate [id]. */
data class TestEvent(
    val id: UUID,
    val message: String,
)

/** Command carrying the complete AI answer back to the aggregate [id]. */
data class TestAiResponseCommand(
    @TargetAggregateIdentifier val id: UUID,
    val message: String,
)

/** Event recording the complete AI answer for the aggregate [id]. */
data class TestAiResponseEvent(
    val id: UUID,
    val message: String,
)
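These simple data classes form the backbone of the event flow. Commands express intent and are routed to an aggregate through their `@TargetAggregateIdentifier` field. Events record what happened and are consumed by event handlers. This separation of commands (the write side) from events and queries (reactions and reads) is the core of the CQRS pattern.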
The Shared Flux Component
/**
 * Application-wide broadcast channel for AI response chunks, keyed by generation id.
 *
 * The sink is a singleton shared by every request, so it is created with `autoCancel = false`.
 * Otherwise the last client disconnecting (cancelling the last subscription) would shut the
 * sink down for everyone. For the same reason, no single request may complete it.
 */
@Component
class SharedFlux {
    // Queues is reactor.util.concurrent.Queues; SMALL_BUFFER_SIZE is the default buffer size.
    val sink: Sinks.Many<Map<UUID, String>> =
        Sinks.many().multicast().onBackpressureBuffer(Queues.SMALL_BUFFER_SIZE, false)

    /** Hot view of [sink] that all streaming endpoints subscribe to. */
    val flux: Flux<Map<UUID, String>> = sink.asFlux()
}
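This Spring singleton is our broadcast mechanism. Every AI response chunk is emitted into one application-wide sink as a map keyed by the request's UUID, so multiple subscribers receive real-time updates and can pick out the conversation they care about. Because the sink is shared by every request, completing it (for example with `tryEmitComplete()`) ends the stream for all current and future subscribers.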
The Implementation
1. Processing AI Responses
/**
 * Turns every [TestEvent] into a streamed OpenAI completion.
 *
 * Each chunk is pushed to [SharedFlux] under the event id as it arrives. When the model stream
 * completes, the concatenated answer is sent back to the aggregate as a [TestAiResponseCommand].
 */
@Component
class Processor(
    private val openAiChatModel: OpenAiChatModel,
    private val sharedFlux: SharedFlux,
    private val commandGateway: CommandGateway
) {
    @EventHandler
    fun on(event: TestEvent) {
        val prompt = Prompt(UserMessage(event.message))
        val finalResponse = StringBuilder()
        openAiChatModel.stream(prompt)
            .doOnNext { response ->
                // NOTE(review): defensive null-safety in case a streamed chunk carries no
                // generation; confirm against the Spring AI version in use.
                val content = response.result?.output?.content ?: ""
                publishChunk(event.id, content)
                finalResponse.append(content)
            }
            .doOnComplete {
                commandGateway.send<TestAiResponseCommand>(
                    TestAiResponseCommand(event.id, finalResponse.toString())
                )
            }
            .subscribe(
                { },
                // Without an error consumer, a failed AI call would only surface as a dropped error.
                { error -> println("AI stream for ${event.id} failed: ${error.message}") }
            )
    }

    /**
     * Best-effort publish of one chunk to the shared sink.
     *
     * Concurrent generations emit from different threads, so [Sinks.EmitResult.FAIL_NON_SERIALIZED]
     * is retried. Any other failure (for example overflow while nobody drains the buffer) drops
     * the chunk, because the live view must never abort persisting the final answer.
     */
    private fun publishChunk(id: UUID, content: String) {
        val chunk = mapOf(id to content)
        var result = sharedFlux.sink.tryEmitNext(chunk)
        while (result == Sinks.EmitResult.FAIL_NON_SERIALIZED) {
            Thread.onSpinWait()
            result = sharedFlux.sink.tryEmitNext(chunk)
        }
    }
}
2. API Endpoints
/**
 * Streaming endpoints: `/generateStream` starts a generation, `/liveResponse` follows one.
 *
 * Every generation gets a fresh id. That id is the key of each map emitted by `/generateStream`,
 * and clients pass it to `/liveResponse` to receive only that generation's text.
 */
@RestController
@RequestMapping("/api")
class SharedFluxController(
    val queryGateway: QueryGateway,
    private val sharedFlux: SharedFlux,
    val commandGateway: CommandGateway,
) {
    @GetMapping("/generateStream")
    fun generateStream(@RequestParam message: String): Flux<Map<UUID, String>> {
        val id = UUID.randomUUID()
        commandGateway.send<TestCommand>(TestCommand(id, message))
        return sharedFlux.flux
    }

    @GetMapping("/liveResponse")
    fun getLiveResponse(@RequestParam id: UUID): Flux<String> {
        return sharedFlux.flux
            .filter { it.containsKey(id) }
            .map { it.getValue(id) }
    }
}
Key Benefits
Separation of Concerns
- Commands handle user requests
- Events process AI responses
- Queries retrieve final results
- Shared Flux manages real-time updates
Real-Time Streaming
- Immediate updates as AI generates responses
- No polling required
- Efficient multicast to all subscribers
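Event Sourcing
- Complete history of AI responses
- Ability to replay events (caveat: `Processor` calls OpenAI from its `TestEvent` handler, so replaying `TestEvent` into that handler would re-issue the AI requests; exclude it from replays, e.g. with Axon's `@DisallowReplay`)
- Query final responses at any time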
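Scalability
- Decoupled components
- Event-driven architecture
- Ready for distributed systems on the command/event side (caveat: `SharedFlux` is an in-memory sink local to one JVM, so live chunks only reach clients connected to the instance whose event processor handled the `TestEvent`)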
Practical Use Cases
This architecture is particularly useful for:
- Chat applications requiring real-time responses
- AI-powered content generation tools
- Interactive AI assistants
- Any system requiring real-time AI interaction with multiple clients
Implementation Notes
- Uses Spring AI’s OpenAiChatModel for AI integration
- Leverages Axon Framework for event sourcing and CQRS
- Implements both streaming and query-based endpoints
- Maintains state through event sourcing
The beauty of this approach lies in its simplicity and separation of concerns, while still providing robust functionality for real-time AI response streaming.
Complete Implementation
package com.versilite.demo.liveSharedAiResponse
import org.axonframework.commandhandling.CommandHandler
import org.axonframework.commandhandling.gateway.CommandGateway
import org.axonframework.eventhandling.EventHandler
import org.axonframework.eventsourcing.EventSourcingHandler
import org.axonframework.eventsourcing.eventstore.EventStore
import org.axonframework.modelling.command.*
import org.axonframework.queryhandling.QueryGateway
import org.axonframework.queryhandling.QueryHandler
import org.axonframework.queryhandling.QueryUpdateEmitter
import org.axonframework.spring.stereotype.Aggregate
import org.springframework.ai.chat.messages.UserMessage
import org.springframework.ai.chat.prompt.Prompt
import org.springframework.ai.openai.OpenAiChatModel
import org.springframework.ai.openai.OpenAiChatOptions
import org.springframework.ai.openai.api.OpenAiApi
import org.springframework.http.MediaType
import org.springframework.stereotype.Component
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.bind.annotation.RestController
import reactor.core.publisher.Flux
import reactor.core.publisher.Sinks
import reactor.util.concurrent.Queues
import java.util.*
/**
 * Stand-alone streaming endpoint that talks to OpenAI directly, without Axon.
 *
 * Each chunk's text is forwarded to the caller as it arrives. Once the model finishes, the
 * concatenated text is printed to stdout.
 */
@RestController
class TestController(
    private val openAiChatModel: OpenAiChatModel
) {
    @GetMapping("/ai/generateStream")
    fun generateStream(
        @RequestParam(value = "message", defaultValue = "Tell me a joke") message: String?
    ): Flux<String> {
        val options = OpenAiChatOptions.builder()
            .withModel(OpenAiApi.ChatModel.GPT_4_O_MINI)
            .build()
        val transcript = StringBuilder()

        return openAiChatModel.stream(Prompt(UserMessage(message), options))
            .map { chunk -> chunk.result.output.content ?: "" }
            .doOnNext { text -> transcript.append(text) }
            .doOnComplete { println("Full response: $transcript") }
    }
}
/** Command carrying a user prompt; Axon routes it to the aggregate whose identifier is [id]. */
data class TestCommand(
    @TargetAggregateIdentifier val id: UUID,
    val message: String,
)

/** Event applied by the aggregate when a prompt is accepted. */
data class TestEvent(
    val id: UUID,
    val message: String,
)

/** Command carrying the complete AI answer back to the aggregate [id]. */
data class TestAiResponseCommand(
    @TargetAggregateIdentifier val id: UUID,
    val message: String,
)

/** Event applied by the aggregate to record the complete AI answer. */
data class TestAiResponseEvent(
    val id: UUID,
    val message: String,
)
/**
 * Event-sourced aggregate for one prompt/answer exchange, identified by a [UUID].
 *
 * [TestCommand] is handled with `AggregateCreationPolicy.ALWAYS` and applies a [TestEvent].
 * [TestAiResponseCommand] applies a [TestAiResponseEvent]. Only [TestEvent] updates sourced
 * state (the identifier).
 */
@Aggregate
class TestAggregate() { // Axon requires a public no-args constructor
    @AggregateIdentifier
    private lateinit var id: UUID

    @EventSourcingHandler
    fun on(event: TestEvent) {
        id = event.id
    }

    @CommandHandler
    @CreationPolicy(AggregateCreationPolicy.ALWAYS)
    fun handle(command: TestCommand) {
        println("Handling command")
        val accepted = TestEvent(command.id, command.message)
        AggregateLifecycle.apply(accepted)
    }

    @CommandHandler
    fun handle(command: TestAiResponseCommand) {
        println("Handling ai response command")
        val answered = TestAiResponseEvent(command.id, command.message)
        AggregateLifecycle.apply(answered)
    }
}
/**
 * Turns every [TestEvent] into a streamed GPT-4o-mini completion.
 *
 * Each chunk is pushed to [SharedFlux] under the event id as it arrives. When the model stream
 * completes, the concatenated answer is printed and sent back to the aggregate as a
 * [TestAiResponseCommand].
 *
 * The shared sink is deliberately never completed here. It is an application-wide singleton,
 * and completing it after one answer would end the stream for every other listener and make
 * all later emissions fail.
 */
@Component
class Processor(
    private val openAiChatModel: OpenAiChatModel,
    private val sharedFlux: SharedFlux,
    private val commandGateway: CommandGateway
) {
    @EventHandler
    fun on(event: TestEvent) {
        val options = OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O_MINI).build()
        val prompt = Prompt(UserMessage(event.message), options)
        val finalResponse = StringBuilder()
        openAiChatModel.stream(prompt)
            .doOnNext { response ->
                // NOTE(review): defensive null-safety in case a streamed chunk carries no
                // generation; confirm against the Spring AI version in use.
                val content = response.result?.output?.content ?: ""
                publishChunk(event.id, content) // Emit each response chunk to the shared sink
                finalResponse.append(content)
            }
            .doOnComplete {
                println(finalResponse)
                commandGateway.send<TestAiResponseCommand>(TestAiResponseCommand(event.id, finalResponse.toString()))
            }
            .subscribe(
                { },
                // Without an error consumer, a failed AI call would only surface as a dropped error.
                { error -> println("AI stream for ${event.id} failed: ${error.message}") }
            )
    }

    /**
     * Best-effort publish of one chunk to the shared sink.
     *
     * Concurrent generations emit from different threads, so [Sinks.EmitResult.FAIL_NON_SERIALIZED]
     * is retried. Any other failure (for example overflow while nobody drains the buffer) drops
     * the chunk. Failing fast here would error the AI stream and skip the final
     * [TestAiResponseCommand], so the live view must never block persisting the answer.
     */
    private fun publishChunk(id: UUID, content: String) {
        val chunk = mapOf(id to content)
        var result = sharedFlux.sink.tryEmitNext(chunk)
        while (result == Sinks.EmitResult.FAIL_NON_SERIALIZED) {
            Thread.onSpinWait()
            result = sharedFlux.sink.tryEmitNext(chunk)
        }
    }
}
/** Asks for the latest AI answer recorded for the aggregate [id]. */
data class TestQuery(
    val id: UUID,
)

/** Read model returned for a [TestQuery]: the answer text or a placeholder message. */
data class TestReadModel(
    val message: String,
)
/**
 * Application-wide broadcast channel for AI response chunks.
 *
 * Every chunk is published as a single-entry map from generation id to chunk text, so
 * subscribers can filter for the conversation they care about.
 *
 * The sink is a singleton shared by all requests, so it is built with `autoCancel = false`.
 * With the default `onBackpressureBuffer()`, the last subscriber cancelling (for example an
 * SSE client disconnecting) would shut the sink down permanently for everyone. For the same
 * reason, no individual request may complete it.
 */
@Component
class SharedFlux {
    val sink: Sinks.Many<Map<UUID, String>> =
        Sinks.many().multicast().onBackpressureBuffer(Queues.SMALL_BUFFER_SIZE, false)

    /** Shared hot Flux that emits every published chunk to its subscribers. */
    val flux: Flux<Map<UUID, String>> = sink.asFlux()
}
/**
 * HTTP entry points for the Axon-backed streaming flow.
 *
 * - `/api/generateStream` starts a new generation and returns the shared chunk feed.
 * - `/api/liveResponse` streams only the chunk text of one generation.
 * - `/api/finalResponse` streams the persisted answer of one generation through an Axon
 *   subscription query (initial result followed by updates).
 *
 * Every `/generateStream` call uses a fresh id. [TestCommand] is handled with
 * `AggregateCreationPolicy.ALWAYS`, so re-sending it for an id whose aggregate already exists
 * would try to create that aggregate a second time.
 *
 * A generation's id is the key of each map emitted by `/generateStream`. `/liveResponse` and
 * `/finalResponse` accept it as an optional `id` parameter and fall back to the most recently
 * started generation.
 */
@RestController
@RequestMapping("/api")
class SharedFluxController(
    val queryGateway: QueryGateway,
    private val sharedFlux: SharedFlux,
    val commandGateway: CommandGateway,
) {
    // Id of the most recently started generation. @Volatile so that request threads see the latest write.
    @Volatile
    private var latestId: UUID = UUID.randomUUID()

    @GetMapping("/generateStream", produces = [MediaType.TEXT_EVENT_STREAM_VALUE])
    fun generateStream(
        @RequestParam(
            value = "message",
            defaultValue = "Tell me a joke"
        ) message: String
    ): Flux<Map<UUID, String>> {
        val id = UUID.randomUUID()
        latestId = id
        println("Our id in generateStream is $id")
        commandGateway.send<TestCommand>(TestCommand(id, message))
            .whenComplete { _, error ->
                // Surface dispatch/handling failures instead of silently dropping them.
                if (error != null) println("TestCommand for $id failed: ${error.message}")
            }
        return sharedFlux.flux // Return the shared flux for streaming to the client
    }

    @GetMapping("/finalResponse", produces = [MediaType.TEXT_EVENT_STREAM_VALUE])
    fun getFinalResponse(
        @RequestParam(value = "id", required = false) id: UUID?
    ): Flux<TestReadModel> {
        val target = id ?: latestId
        println("Our id in finalResponse is $target")
        val query = queryGateway.subscriptionQuery(TestQuery(target), TestReadModel::class.java, TestReadModel::class.java)
        return query.initialResult()
            .concatWith(query.updates())
            // Release the subscription-query registration once the stream ends or the client leaves.
            .doFinally { query.cancel() }
    }

    @GetMapping("/liveResponse", produces = [MediaType.TEXT_EVENT_STREAM_VALUE])
    fun getLiveResponse(
        @RequestParam(value = "id", required = false) id: UUID?
    ): Flux<String> {
        val target = id ?: latestId
        println("Our id in liveResponse is $target")
        return sharedFlux.flux
            .filter { it.containsKey(target) }
            .map { it.getValue(target) }
    }
}
/**
 * Read side for [TestQuery].
 *
 * Initial results are rebuilt from the aggregate's event stream in the event store. Each new
 * [TestAiResponseEvent] is pushed to matching subscription queries.
 *
 * Within this file, the `@QueryHandler` annotation resolves to the explicitly imported Axon
 * annotation, not to this class.
 */
@Component
class QueryHandler(
    private val eventStore: EventStore,
    private val queryUpdateEmitter: QueryUpdateEmitter
) {
    @QueryHandler
    fun handle(query: TestQuery): TestReadModel {
        val latestAnswer = eventStore
            .readEvents(query.id.toString())
            .asStream()
            .map { message -> message.payload }
            .toList()
            .filterIsInstance<TestAiResponseEvent>()
            .lastOrNull()
        return TestReadModel(latestAnswer?.message ?: "No response yet")
    }

    @EventHandler
    fun on(event: TestAiResponseEvent) {
        val update = TestReadModel(event.message)
        queryUpdateEmitter.emit(TestQuery::class.java, { query -> query.id == event.id }, update)
    }
}