ChatGPTのStreamモードをOkHttp-sse+Gson+Hiltを使って実装
/ 9 min read
Updated:Table of Contents
はじめに
最近ChatGPTが流行ってますね。ブラウザでChatGPTを使うと、レスポンスを一文字ずつ返してくれます。これをAPI経由で実現するためにはどうすればいいか調べてみると、APIにリクエストを送るときにstreamパラメータに対してtrueを設定することで、実現できることがわかりました。この一文字ずつ返すのはServer Sent Events(SSE)として送られているみたいです。
実装
SSEを扱いやすくするために、okhttp-sseを使用します。今回はリクエスト部分などの紹介は省き、SSEで通信する箇所をメインに紹介します。実装したコードは、こちらのリポジトリにアップしています。
実装イメージとしてはこのようなものを想定しています。通信を完了するとSnackBarを表示するようにもしてみました。
完成のイメージ
OkHttpをRepository層にDIするためにHiltのProvidesを使います。
@Module@InstallIn(SingletonComponent::class)class OkHttpModule {
// TODO OpenAIのトークンを貼る private val token = ""
@Singleton @Provides fun providesRequestBuilder(): Request.Builder { return Request.Builder() .url("<https://api.openai.com/v1/chat/completions>") .header("Accept", "application/json") .addHeader("Authorization", "Bearer $token") }
@Singleton @Provides fun providesOkHttpClient(): OkHttpClient { return OkHttpClient.Builder() .readTimeout(10, TimeUnit.MINUTES) .connectTimeout(10, TimeUnit.MINUTES) .build() }}Repository層
実際にOpenAIに対してデータ取得をするRepository層を作成します。okhttp-sseを使う上で特に重要な箇所は、EventSourceListenerクラスのonEventメソッドです。ここに送られてきた文字列(Json形式のときもあれば、ただの文字列のときもある)が入ってきます。
SSEの通信イベントをSSEEventとして定義し、StateFlow
sealed interface SSEEvent { object Empty : SSEEvent object Open : SSEEvent data class Event(val response: GPT35TurboResponse) : SSEEvent data class Failure(val e: Throwable, val response: Response?) : SSEEvent object Closed : SSEEvent}Repository層のコードです。onEventメソッドで送られてきたJsonレスポンスをGsonで変換します。ただし、ChatGPTが生成した文字列がすべて送られると[DONE]という文字列が送られてくるので、このときはJsonの変換を行いようにします。
createFactory()のnewEventSource()を実行したタイミングで通信を開始します。
interface OpenAiRepository { suspend fun postCompletions(gpt35Turbo: GPT35Turbo) val state: StateFlow<SSEEvent>}
@Singletonclass OpenAiRepositoryImpl @Inject constructor( private val requestBuilder: Request.Builder, private val client: OkHttpClient,) : OpenAiRepository {
private val _state = MutableStateFlow<SSEEvent>(SSEEvent.Empty) override val state = _state.asStateFlow()
private val gson: Gson = Gson() private var eventSource: EventSource? = null private val eventSourceListener = object : EventSourceListener() {
override fun onOpen(eventSource: EventSource, response: Response) { super.onOpen(eventSource, response) _state.value = SSEEvent.Open }
override fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) { super.onEvent(eventSource, id, type, data)
if (data != "[DONE]") { val response = gson.fromJson(data, GPT35TurboResponse::class.java) val message = response.choices[0].delta.content _state.value = SSEEvent.Event(response) } }
override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { super.onFailure(eventSource, t, response)
if (t != null) { _state.value = SSEEvent.Failure(t, response) } }
override fun onClosed(eventSource: EventSource) { super.onClosed(eventSource)
_state.value = SSEEvent.Closed } }
override suspend fun postCompletions(gpt35Turbo: GPT35Turbo) { val requestBody = gson.toJson(gpt35Turbo)
val request = requestBuilder .post(requestBody.toRequestBody("application/json; charset=UTF-8".toMediaTypeOrNull())) .build()
withContext(Dispatchers.IO) { eventSource = EventSources.createFactory(client) .newEventSource(request, eventSourceListener) } }}Repository層のHiltの設定です。interfaceを作成しているのでBindsを使います。
@Module@InstallIn(SingletonComponent::class)abstract class RepositoryModule() {
@Binds abstract fun bindOpenAiRepository( openAiRepositoryImpl: OpenAiRepositoryImpl ): OpenAiRepository}UseCase層
Repository層で取得したデータを加工するためにUseCaseを実装します。今回はUseCase層でRepositoryから送られてきたイベントに応じて加工後、StateFlow
Stateの実装はこのように行いました。ほぼSSEEventと一緒ですが、UseCase層では、Repository層から送られてきた、GPT35TurboResponse内にあるChatGPTから送られてきた文字列を取り出し公開するために、StateのEventをStringにしています。(実装後の反省ですが、細かく分けすぎて逆にわかりにくくなった気がするのでStateを改めて定義する必要はなかったかもです)
sealed interface State { object Empty: State object Open : State data class Event(val response: String) : State data class Failure(val e: Throwable, val response: Response?) : State object Closed : State}ほぼRepository層をラップしたものですが、SSEEvent.Eventのところで送られてきた文字列を取り出しています。
interface OpenAiUseCase { suspend fun postCompletions(gpt35Turbo: GPT35Turbo) fun cancelCompletions() val state: StateFlow<State>}
@Singletonclass OpenAiUseCaseImpl @Inject constructor( private val repository: OpenAiRepository) : OpenAiUseCase {
private val _state = MutableStateFlow<State>(State.Empty) override val state = _state.asStateFlow()
override suspend fun postCompletions(gpt35Turbo: GPT35Turbo) { repository.postCompletions(gpt35Turbo) repository.state.collect { event ->
when (event) { is SSEEvent.Empty -> _state.value = State.Empty is SSEEvent.Open -> _state.value = State.Open is SSEEvent.Event -> { val value = event.response.choices.first().delta.content ?: "" _state.value = State.Event(value) } is SSEEvent.Failure -> { _state.value = State.Failure(event.e, event.response) } is SSEEvent.Closed -> _state.value = State.Closed } } }}UseCase層のHiltの設定です。UseCase層もinterfaceを作成しているのでBindsを使います。
@Module@InstallIn(SingletonComponent::class)abstract class UseCaseModule {
@Binds abstract fun bindOpenAiUseCase( openAiUseCaseImpl: OpenAiUseCaseImpl ): OpenAiUseCase}ViewModel層
UseCaseからデータを受取ってView層に渡してあげます。
@HiltViewModelclass MainActivityViewModel @Inject constructor( private val useCase: OpenAiUseCase) : ViewModel() {
data class UiState( val generatedText: String = "", )
sealed interface UiEvent { data class ShowSnackBar(val message: String) : UiEvent object Empty : UiEvent }
private val _state = MutableStateFlow(UiState()) val state = _state.asStateFlow() private val _event = Channel<UiEvent>() val event: Flow<UiEvent> = _event.consumeAsFlow()
init { viewModelScope.launch { useCase.state.collect { state -> when (state) { is State.Event -> { _state.update { UiState(it.generatedText + state.response) } } is State.Closed -> { _event.send(UiEvent.ShowSnackBar("完了しました")) } else -> { } } } } }
fun start() { viewModelScope.launch { _state.update { UiState("") } val messages = listOf( Messages(Messages.Role.SYSTEM, "あなたは生粋の関西人です。"), Messages(Messages.Role.ASSISTANT, "大阪名物について"), Messages(Messages.Role.USER, "150文字以内かつ関西弁で紹介して下さい。"), ) val gpt35Turbo = GPT35Turbo(messages = messages) useCase.postCompletions(gpt35Turbo) } }}View層
Actvityに直接TextViewを作成しViewBindingで送ります。
@AndroidEntryPointclass MainActivity : AppCompatActivity() {
private val viewModel: MainActivityViewModel by viewModels() private lateinit var binding: ActivityMainBinding
override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState)
binding = ActivityMainBinding.inflate(layoutInflater) val view = binding.root setContentView(view)
binding.button.setOnClickListener { viewModel.start() }
lifecycleScope.launch { repeatOnLifecycle(Lifecycle.State.STARTED) { viewModel.state.collect { binding.textView.text = it.generatedText } } } lifecycleScope.launch { repeatOnLifecycle(Lifecycle.State.STARTED) { viewModel.event.collect { event -> when (event) { is MainActivityViewModel.UiEvent.ShowSnackBar -> { Snackbar.make(view, event.message, Snackbar.LENGTH_LONG).show() } else -> {} } } } } }}最後に
ChatGPTのStream APIを組み合わせて使ってみました。通常のRESTのAPIを使っても良いのですが、ChatGPTからの応答時間がかかるため、実際にアプリの機能として落としこむためには、ユーザへの見せ方を工夫する必要があります。しかし、このStreamを使った実装方法では、リアルタイムに結果をユーザへ反映できるため飽きさせない、かつ、わくわくする見せ方を可能とします。可能であればこの実装方法で機能に組込みたいですね。