Commit

add client-local module

cmodi-meta committed Nov 5, 2024
1 parent 5de5a03 commit 9108f7c

Showing 5 changed files with 267 additions and 0 deletions.
@@ -0,0 +1,90 @@
// File generated from our OpenAPI spec by Stainless.

package com.llama_stack_client.api.client.local

import com.llama_stack_client.api.core.RequestOptions
import com.llama_stack_client.api.models.InferenceChatCompletionParams
import com.llama_stack_client.api.models.InferenceChatCompletionResponse
import com.llama_stack_client.api.models.InferenceCompletionParams
import com.llama_stack_client.api.models.InferenceCompletionResponse
import com.llama_stack_client.api.services.blocking.InferenceService
import com.llama_stack_client.api.services.blocking.inference.EmbeddingService
import org.pytorch.executorch.LlamaCallback
import org.pytorch.executorch.LlamaModule

class InferenceServiceLocalImpl
constructor(
    private val clientOptions: LocalClientOptions,
) : InferenceService, LlamaCallback {

    // Invoked by ExecuTorch for each generated token.
    override fun onResult(p0: String?) {
        TODO("Not yet implemented")
    }

    // Invoked by ExecuTorch with generation statistics (tokens/sec).
    override fun onStats(p0: Float) {
        TODO("Not yet implemented")
    }

    override fun embeddings(): EmbeddingService {
        TODO("Not yet implemented")
    }

    override fun chatCompletion(
        params: InferenceChatCompletionParams,
        requestOptions: RequestOptions
    ): InferenceChatCompletionResponse {
        val mModule = clientOptions.llamaModule

        // Placeholder prompt and token budget; the real call should be driven by `params`.
        mModule.generate("what is the capital of France", 64, this, false)

        TODO("Build an InferenceChatCompletionResponse from the onResult() callbacks")

        // val request =
        //     HttpRequest.builder()
        //         .method(HttpMethod.POST)
        //         .addPathSegments("inference", "chat_completion")
        //         .putAllQueryParams(clientOptions.queryParams)
        //         .putAllQueryParams(params.getQueryParams())
        //         .putAllHeaders(clientOptions.headers)
        //         .putAllHeaders(params.getHeaders())
        //         .body(json(clientOptions.jsonMapper, params.getBody()))
        //         .build()
        // return clientOptions.httpClient.execute(request, requestOptions).let { response ->
        //     response
        //         .use { chatCompletionHandler.handle(it) }
        //         .apply {
        //             if (requestOptions.responseValidation ?: clientOptions.responseValidation) {
        //                 validate()
        //             }
        //         }
        // }
    }

    // private val completionHandler: Handler<InferenceCompletionResponse> =
    //     jsonHandler<InferenceCompletionResponse>(clientOptions.jsonMapper)
    //         .withErrorHandler(errorHandler)

    override fun completion(
        params: InferenceCompletionParams,
        requestOptions: RequestOptions
    ): InferenceCompletionResponse {
        TODO("Implement ExecuTorch logic here")
        // val request =
        //     HttpRequest.builder()
        //         .method(HttpMethod.POST)
        //         .addPathSegments("inference", "completion")
        //         .putAllQueryParams(clientOptions.queryParams)
        //         .putAllQueryParams(params.getQueryParams())
        //         .putAllHeaders(clientOptions.headers)
        //         .putAllHeaders(params.getHeaders())
        //         .body(json(clientOptions.jsonMapper, params.getBody()))
        //         .build()
        // return clientOptions.httpClient.execute(request, requestOptions).let { response ->
        //     response
        //         .use { completionHandler.handle(it) }
        //         .apply {
        //             if (requestOptions.responseValidation ?: clientOptions.responseValidation) {
        //                 validate()
        //             }
        //         }
        // }
    }
}
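
The onResult() and onStats() callbacks above are still TODO()s. Below is a minimal sketch of one way a blocking chatCompletion() could collect streamed tokens, assuming onStats() fires once generation completes (as in the ExecuTorch Llama demo app); the PromptCollector name and the CountDownLatch completion signal are illustrative assumptions, not part of this commit.

import org.pytorch.executorch.LlamaCallback
import java.util.concurrent.CountDownLatch

// Hypothetical helper: gathers tokens emitted by LlamaModule.generate()
// and blocks until onStats() signals that generation has finished.
class PromptCollector : LlamaCallback {
    private val builder = StringBuilder()
    private val done = CountDownLatch(1)

    // ExecuTorch invokes this once per generated token.
    override fun onResult(p0: String?) {
        p0?.let { builder.append(it) }
    }

    // ExecuTorch reports tokens/sec here after generation completes.
    override fun onStats(p0: Float) {
        done.countDown()
    }

    // Blocks the caller until generation is done, then returns the full text.
    fun awaitCompletion(): String {
        done.await()
        return builder.toString()
    }
}

If LlamaModule.generate() already blocks on the calling thread, the latch is redundant; it covers the case where generation is dispatched to a worker thread.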
@@ -0,0 +1,78 @@
// File generated from our OpenAPI spec by Stainless.

package com.llama_stack_client.api.client.local

import com.llama_stack_client.api.client.LlamaStackClientClient
import com.llama_stack_client.api.client.LlamaStackClientClientAsync
import com.llama_stack_client.api.models.*
import com.llama_stack_client.api.services.blocking.*

class LlamaStackClientClientLocalImpl
constructor(
    private val clientOptions: LocalClientOptions,
) : LlamaStackClientClient {

    private val inference: InferenceService by lazy { InferenceServiceLocalImpl(clientOptions) }

    override fun inference(): InferenceService = inference

    override fun async(): LlamaStackClientClientAsync {
        TODO("Not yet implemented")
    }

    override fun telemetry(): TelemetryService {
        TODO("Not yet implemented")
    }

    override fun agents(): AgentService {
        TODO("Not yet implemented")
    }

    override fun datasets(): DatasetService {
        TODO("Not yet implemented")
    }

    override fun evaluate(): EvaluateService {
        TODO("Not yet implemented")
    }

    override fun evaluations(): EvaluationService {
        TODO("Not yet implemented")
    }

    override fun safety(): SafetyService {
        TODO("Not yet implemented")
    }

    override fun memory(): MemoryService {
        TODO("Not yet implemented")
    }

    override fun postTraining(): PostTrainingService {
        TODO("Not yet implemented")
    }

    override fun rewardScoring(): RewardScoringService {
        TODO("Not yet implemented")
    }

    override fun syntheticDataGeneration(): SyntheticDataGenerationService {
        TODO("Not yet implemented")
    }

    override fun batchInference(): BatchInferenceService {
        TODO("Not yet implemented")
    }

    override fun models(): ModelService {
        TODO("Not yet implemented")
    }

    override fun memoryBanks(): MemoryBankService {
        TODO("Not yet implemented")
    }

    override fun shields(): ShieldService {
        TODO("Not yet implemented")
    }
}
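
Only inference() is backed by a real implementation here; every other accessor throws NotImplementedError via TODO(). A quick illustration of the caller's view (the demo function is hypothetical):

// Hypothetical usage: only the inference service is functional on the local client.
fun demo(client: LlamaStackClientClient) {
    val inference = client.inference() // InferenceServiceLocalImpl under the hood
    // client.agents()                 // would throw kotlin.NotImplementedError
}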
@@ -0,0 +1,43 @@
package com.llama_stack_client.api.client.local

import com.llama_stack_client.api.client.LlamaStackClientClient

class LlamaStackClientLocalClient private constructor() {

    companion object {
        fun builder() = Builder()
    }

    class Builder {

        private var clientOptions: LocalClientOptions.Builder = LocalClientOptions.builder()

        private var modelPath: String? = null
        private var tokenizerPath: String? = null

        fun modelPath(modelPath: String) = apply { this.modelPath = modelPath }

        fun tokenizerPath(tokenizerPath: String) = apply { this.tokenizerPath = tokenizerPath }

        fun fromEnv() = apply { clientOptions.fromEnv() }

        fun build(): LlamaStackClientClient {
            // Paths set on this builder take precedence; LocalClientOptions.build()
            // validates that both were provided.
            modelPath?.let { clientOptions.modelPath(it) }
            tokenizerPath?.let { clientOptions.tokenizerPath(it) }

            return LlamaStackClientClientLocalImpl(
                clientOptions
                    .temperature(0.0F)
                    .build()
            )
        }
    }
}
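
Putting the builder together, constructing a local client might look like this. The device paths are placeholders, and the .pte model layout follows ExecuTorch convention rather than anything in this commit:

// Hypothetical wiring; paths must point at ExecuTorch artifacts on the device.
val client: LlamaStackClientClient =
    LlamaStackClientLocalClient.builder()
        .modelPath("/data/local/tmp/llama/model.pte")
        .tokenizerPath("/data/local/tmp/llama/tokenizer.model")
        .build()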
@@ -0,0 +1,55 @@
// File generated from our OpenAPI spec by Stainless.

package com.llama_stack_client.api.client.local

import org.pytorch.executorch.LlamaModule

class LocalClientOptions
private constructor(
    val modelPath: String,
    val tokenizerPath: String,
    val temperature: Float,
    val llamaModule: LlamaModule,
) {

    companion object {
        fun builder() = Builder()
    }

    class Builder {
        private var modelPath: String? = null
        private var tokenizerPath: String? = null
        private var temperature: Float = 0.0F

        fun modelPath(modelPath: String) = apply { this.modelPath = modelPath }

        fun tokenizerPath(tokenizerPath: String) = apply { this.tokenizerPath = tokenizerPath }

        fun temperature(temperature: Float) = apply { this.temperature = temperature }

        // Stub: environment-based configuration is not implemented yet.
        fun fromEnv() = apply {}

        fun build(): LocalClientOptions {
            val modelPath = checkNotNull(modelPath) { "`modelPath` is required but not set" }
            val tokenizerPath = checkNotNull(tokenizerPath) { "`tokenizerPath` is required but not set" }

            // Construct the ExecuTorch module eagerly so load failures surface at build time.
            val llamaModule = LlamaModule(1, modelPath, tokenizerPath, temperature)

            return LocalClientOptions(modelPath, tokenizerPath, temperature, llamaModule)
        }
    }
}
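
fromEnv() is an empty stub in this file. One plausible implementation, assuming MODEL_PATH and TOKENIZER_PATH as the variable names (an assumption; this commit never defines them):

// Hypothetical fromEnv(); the environment variable names are assumptions.
fun fromEnv() = apply {
    System.getenv("MODEL_PATH")?.let { modelPath(it) }
    System.getenv("TOKENIZER_PATH")?.let { tokenizerPath(it) }
}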
1 change: 1 addition & 0 deletions settings.gradle.kts
@@ -4,3 +4,4 @@ include("llama-stack-client-kotlin")
include("llama-stack-client-kotlin-client-okhttp")
include("llama-stack-client-kotlin-core")
include("llama-stack-client-kotlin-example")
include("llama-stack-client-kotlin-client-local")
