From 9108f7ca8f240da3339155ace09ae669eaa07884 Mon Sep 17 00:00:00 2001
From: Chirag Modi
Date: Mon, 4 Nov 2024 12:57:46 -0800
Subject: [PATCH] add client-local module

---
 .../client/local/InferenceServiceLocalImpl.kt | 55 ++++++++++++++
 .../local/LlamaStackClientClientLocalImpl.kt  | 78 +++++++++++++++++++
 .../local/LlamaStackClientLocalClient.kt      | 37 +++++++++
 .../api/client/local/LocalClientOptions.kt    | 44 +++++++++++
 settings.gradle.kts                           |  1 +
 5 files changed, 215 insertions(+)
 create mode 100644 llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/InferenceServiceLocalImpl.kt
 create mode 100644 llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LlamaStackClientClientLocalImpl.kt
 create mode 100644 llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LlamaStackClientLocalClient.kt
 create mode 100644 llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LocalClientOptions.kt

diff --git a/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/InferenceServiceLocalImpl.kt b/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/InferenceServiceLocalImpl.kt
new file mode 100644
index 0000000..4ea1850
--- /dev/null
+++ b/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/InferenceServiceLocalImpl.kt
@@ -0,0 +1,55 @@
+// File generated from our OpenAPI spec by Stainless.
+
+package com.llama_stack_client.api.client.local
+
+import com.llama_stack_client.api.core.RequestOptions
+import com.llama_stack_client.api.models.InferenceChatCompletionParams
+import com.llama_stack_client.api.models.InferenceChatCompletionResponse
+import com.llama_stack_client.api.models.InferenceCompletionParams
+import com.llama_stack_client.api.models.InferenceCompletionResponse
+import com.llama_stack_client.api.services.blocking.InferenceService
+import com.llama_stack_client.api.services.blocking.inference.EmbeddingService
+import org.pytorch.executorch.LlamaCallback
+import org.pytorch.executorch.LlamaModule
+
+class InferenceServiceLocalImpl
+constructor(
+    private val clientOptions: LocalClientOptions,
+) : InferenceService, LlamaCallback {
+
+    override fun onResult(result: String?) {
+        // TODO: buffer generated tokens for chatCompletion()/completion() to return.
+    }
+
+    override fun onStats(tps: Float) {
+        // TODO: surface tokens-per-second stats from the ExecuTorch runtime.
+    }
+
+    override fun embeddings(): EmbeddingService {
+        TODO("Not yet implemented")
+    }
+
+    override fun chatCompletion(
+        params: InferenceChatCompletionParams,
+        requestOptions: RequestOptions
+    ): InferenceChatCompletionResponse {
+        val mModule = clientOptions.llamaModule
+
+        // Placeholder invocation; the real implementation should derive the prompt from `params`.
+        mModule.generate("what is the capital of France", 64, this, false)
+
+        // The remote implementation POSTs to inference/chat_completion and validates the
+        // response; the local path instead needs to assemble the response from the tokens
+        // delivered to onResult().
+        TODO("Assemble InferenceChatCompletionResponse from generated tokens")
+    }
+
+    override fun completion(
+        params: InferenceCompletionParams,
+        requestOptions: RequestOptions
+    ): InferenceCompletionResponse {
+        // Mirror of chatCompletion(): run LlamaModule.generate() on the raw prompt
+        // instead of issuing the remote inference/completion request.
+        TODO("Implement ExecuTorch completion logic here")
+    }
+}
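
A sketch of how the callback half of this service could eventually feed chatCompletion(): buffer tokens from onResult() while generate() runs. This assumes generate() blocks until the sequence is complete and that onStats() fires once at the end, as in the ExecuTorch Llama demo app; the TokenCollector class below is illustrative, not part of the patch.

import org.pytorch.executorch.LlamaCallback

// Sketch only: accumulate tokens emitted through LlamaCallback during generate().
class TokenCollector : LlamaCallback {
    val output = StringBuilder()

    override fun onResult(result: String?) {
        // Each invocation delivers the next generated token (or token chunk).
        if (result != null) output.append(result)
    }

    override fun onStats(tps: Float) {
        // Reported after generation completes; handy for logging throughput.
    }
}

// Assuming generate() blocks until completion:
// val collector = TokenCollector()
// clientOptions.llamaModule.generate(prompt, 64, collector, false)
// val generatedText = collector.output.toString()
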
diff --git a/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LlamaStackClientClientLocalImpl.kt b/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LlamaStackClientClientLocalImpl.kt
new file mode 100644
index 0000000..a28e634
--- /dev/null
+++ b/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LlamaStackClientClientLocalImpl.kt
@@ -0,0 +1,78 @@
+// File generated from our OpenAPI spec by Stainless.
+
+package com.llama_stack_client.api.client.local
+
+import com.llama_stack_client.api.client.LlamaStackClientClient
+import com.llama_stack_client.api.client.LlamaStackClientClientAsync
+import com.llama_stack_client.api.models.*
+import com.llama_stack_client.api.services.blocking.*
+
+class LlamaStackClientClientLocalImpl
+constructor(
+    private val clientOptions: LocalClientOptions,
+) : LlamaStackClientClient {
+
+    private val inference: InferenceService by lazy { InferenceServiceLocalImpl(clientOptions) }
+
+    override fun inference(): InferenceService = inference
+
+    override fun async(): LlamaStackClientClientAsync {
+        TODO("Not yet implemented")
+    }
+
+    override fun telemetry(): TelemetryService {
+        TODO("Not yet implemented")
+    }
+
+    override fun agents(): AgentService {
+        TODO("Not yet implemented")
+    }
+
+    override fun datasets(): DatasetService {
+        TODO("Not yet implemented")
+    }
+
+    override fun evaluate(): EvaluateService {
+        TODO("Not yet implemented")
+    }
+
+    override fun evaluations(): EvaluationService {
+        TODO("Not yet implemented")
+    }
+
+    override fun safety(): SafetyService {
+        TODO("Not yet implemented")
+    }
+
+    override fun memory(): MemoryService {
+        TODO("Not yet implemented")
+    }
+
+    override fun postTraining(): PostTrainingService {
+        TODO("Not yet implemented")
+    }
+
+    override fun rewardScoring(): RewardScoringService {
+        TODO("Not yet implemented")
+    }
+
+    override fun syntheticDataGeneration(): SyntheticDataGenerationService {
+        TODO("Not yet implemented")
+    }
+
+    override fun batchInference(): BatchInferenceService {
+        TODO("Not yet implemented")
+    }
+
+    override fun models(): ModelService {
+        TODO("Not yet implemented")
+    }
+
+    override fun memoryBanks(): MemoryBankService {
+        TODO("Not yet implemented")
+    }
+
+    override fun shields(): ShieldService {
+        TODO("Not yet implemented")
+    }
+}
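
Only inference() is backed by a working local implementation in this facade; every other accessor still throws (Kotlin's TODO() raises NotImplementedError). A caller-side sketch, with the client construction shown after the builder below:

// val client: LlamaStackClientClient = ... // built via LlamaStackClientLocalClient.builder()
// client.inference()   // lazily creates and returns InferenceServiceLocalImpl
// client.agents()      // throws NotImplementedError for now
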
diff --git a/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LlamaStackClientLocalClient.kt b/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LlamaStackClientLocalClient.kt
new file mode 100644
index 0000000..86e3e36
--- /dev/null
+++ b/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LlamaStackClientLocalClient.kt
@@ -0,0 +1,37 @@
+package com.llama_stack_client.api.client.local
+
+import com.llama_stack_client.api.client.LlamaStackClientClient
+
+class LlamaStackClientLocalClient private constructor() {
+
+    companion object {
+        fun builder() = Builder()
+    }
+
+    class Builder {
+
+        private var clientOptions: LocalClientOptions.Builder = LocalClientOptions.builder()
+
+        private var modelPath: String? = null
+        private var tokenizerPath: String? = null
+
+        fun modelPath(modelPath: String) = apply { this.modelPath = modelPath }
+
+        fun tokenizerPath(tokenizerPath: String) = apply { this.tokenizerPath = tokenizerPath }
+
+        fun fromEnv() = apply { clientOptions.fromEnv() }
+
+        fun build(): LlamaStackClientClient {
+            val modelPath = checkNotNull(modelPath) { "`modelPath` is required but not set" }
+            val tokenizerPath = checkNotNull(tokenizerPath) { "`tokenizerPath` is required but not set" }
+
+            return LlamaStackClientClientLocalImpl(
+                clientOptions
+                    .modelPath(modelPath)
+                    .tokenizerPath(tokenizerPath)
+                    .temperature(0.0F) // Default temperature; no builder override yet.
+                    .build()
+            )
+        }
+    }
+}
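
With the builder passing its stored fields through to LocalClientOptions, caller code would look like the following; the on-device file paths are illustrative only:

val client = LlamaStackClientLocalClient.builder()
    .modelPath("/data/local/tmp/llama/model.pte")          // illustrative path
    .tokenizerPath("/data/local/tmp/llama/tokenizer.bin")  // illustrative path
    .build()

// client.inference() is now served by the local ExecuTorch module.
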
diff --git a/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LocalClientOptions.kt b/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LocalClientOptions.kt
new file mode 100644
index 0000000..a0fd788
--- /dev/null
+++ b/llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama_stack_client/api/client/local/LocalClientOptions.kt
@@ -0,0 +1,44 @@
+// File generated from our OpenAPI spec by Stainless.
+
+package com.llama_stack_client.api.client.local
+
+import org.pytorch.executorch.LlamaModule
+
+class LocalClientOptions
+private constructor(
+    val modelPath: String,
+    val tokenizerPath: String,
+    val temperature: Float,
+    val llamaModule: LlamaModule
+) {
+
+    companion object {
+        fun builder() = Builder()
+    }
+
+    class Builder {
+        private var modelPath: String? = null
+        private var tokenizerPath: String? = null
+        private var temperature: Float = 0.0F
+
+        fun modelPath(modelPath: String) = apply { this.modelPath = modelPath }
+
+        fun tokenizerPath(tokenizerPath: String) = apply { this.tokenizerPath = tokenizerPath }
+
+        fun temperature(temperature: Float) = apply { this.temperature = temperature }
+
+        fun fromEnv() = apply {
+            // TODO: read the model and tokenizer paths from the environment.
+        }
+
+        fun build(): LocalClientOptions {
+            val modelPath = checkNotNull(modelPath) { "`modelPath` is required but not set" }
+            val tokenizerPath = checkNotNull(tokenizerPath) { "`tokenizerPath` is required but not set" }
+
+            // Eagerly construct the ExecuTorch module so an invalid path fails at build() time.
+            val llamaModule = LlamaModule(1, modelPath, tokenizerPath, temperature)
+
+            return LocalClientOptions(modelPath, tokenizerPath, temperature, llamaModule)
+        }
+    }
+}
diff --git a/settings.gradle.kts b/settings.gradle.kts
index 8f96c3a..265caa5 100644
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -4,3 +4,4 @@ include("llama-stack-client-kotlin")
 include("llama-stack-client-kotlin-client-okhttp")
 include("llama-stack-client-kotlin-core")
 include("llama-stack-client-kotlin-example")
+include("llama-stack-client-kotlin-client-local")
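
fromEnv() on LocalClientOptions.Builder is still a stub. One possible shape, assuming LLAMA_STACK_MODEL_PATH and LLAMA_STACK_TOKENIZER_PATH as the variable names (hypothetical; the patch defines no environment contract):

fun fromEnv() = apply {
    // Hypothetical variable names; adjust to whatever contract the project settles on.
    System.getenv("LLAMA_STACK_MODEL_PATH")?.let { modelPath(it) }
    System.getenv("LLAMA_STACK_TOKENIZER_PATH")?.let { tokenizerPath(it) }
}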