Skip to content

Commit

Permalink
Refine returned model data types & add initial examples
Browse files Browse the repository at this point in the history
  • Loading branch information
tadayosi committed Sep 18, 2024
1 parent cc03c8d commit f7307ef
Show file tree
Hide file tree
Showing 18 changed files with 545 additions and 66 deletions.
28 changes: 28 additions & 0 deletions examples/mnist.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
///usr/bin/env jbang "$0" "$@" ; exit $?
//DEPS com.github.tadayosi.torchserve:torchserve-client:0.1-SNAPSHOT

import java.nio.file.Files;
import java.nio.file.Path;

import com.github.tadayosi.torchserve.client.impl.DefaultInference;
import com.github.tadayosi.torchserve.client.inference.invoker.ApiException;

/**
 * Example script that sends two sample digit images to the {@code mnist_v2} model through
 * {@code DefaultInference} and prints each prediction.
 *
 * <p>The image paths are relative, so they are resolved against the current working directory.
 */
public class mnist {

    /** Name of the model passed to every {@code predictions} call. */
    private static final String MNIST_MODEL = "mnist_v2";

    /**
     * Reads the two sample images and prints the prediction returned for each of them.
     *
     * @param args command-line arguments; not used
     * @throws Exception if an image cannot be read, or on any failure other than {@code ApiException}
     */
    public static void main(String... args) throws Exception {
        // Both images are loaded before any request is made, so a missing file fails fast.
        var zero = Files.readAllBytes(Path.of("src/test/resources/data/0.png"));
        var one = Files.readAllBytes(Path.of("src/test/resources/data/1.png"));
        try {
            var inference = new DefaultInference();
            var result0 = inference.predictions(MNIST_MODEL, zero);
            System.out.println("Answer> " + result0);
            var result1 = inference.predictions(MNIST_MODEL, one);
            System.out.println("Answer> " + result1);
        } catch (ApiException e) {
            // Report the response body carried by the exception, then the full trace.
            System.err.println(e.getResponseBody());
            e.printStackTrace();
        }
    }
}
28 changes: 28 additions & 0 deletions examples/register_mnist.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
///usr/bin/env jbang "$0" "$@" ; exit $?
//DEPS com.github.tadayosi.torchserve:torchserve-client:0.1-SNAPSHOT

import com.github.tadayosi.torchserve.client.impl.DefaultManagement;
import com.github.tadayosi.torchserve.client.management.invoker.ApiException;
import com.github.tadayosi.torchserve.client.model.RegisterModelOptions;
import com.github.tadayosi.torchserve.client.model.SetAutoScaleOptions;

/**
 * Example script that registers the {@code mnist_v2} model archive through
 * {@code DefaultManagement} and then sets its minimum and maximum worker count to one.
 */
public class register_mnist {

    /** Location of the model archive passed to {@code registerModel}. */
    private static final String MNIST_URL = "https://torchserve.pytorch.org/mar_files/mnist_v2.mar";

    /** Name of the model passed to {@code setAutoScale}. */
    private static final String MNIST_MODEL = "mnist_v2";

    /**
     * Registers the model, scales it to a single worker and prints the status of each call.
     *
     * @param args command-line arguments; not used
     * @throws Exception on any failure other than {@code ApiException}
     */
    public static void main(String... args) throws Exception {
        try {
            var management = new DefaultManagement();
            var response = management.registerModel(MNIST_URL, RegisterModelOptions.empty());
            System.out.println("registerModel> " + response.getStatus());
            response = management.setAutoScale(MNIST_MODEL, SetAutoScaleOptions.builder()
                    .minWorker(1)
                    .maxWorker(1)
                    .build());
            System.out.println("setAutoScale> " + response.getStatus());
        } catch (ApiException e) {
            System.err.println(e.getResponseBody());
            // NOTE(review): the response body may be absent when the failure is not an HTTP
            // error response (confirm against ApiException); print the trace as well so the
            // cause is still visible.
            e.printStackTrace();
        }
    }
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package com.github.tadayosi.torchserve.client;

import com.github.tadayosi.torchserve.client.model.API;
import com.github.tadayosi.torchserve.client.model.Response;

/**
* Inference API
*/
Expand All @@ -8,12 +11,12 @@ public interface Inference {
/**
* Get openapi description.
*/
Object apiDescription() throws Exception;
API apiDescription() throws Exception;

/**
* Get TorchServe status.
*/
Object ping() throws Exception;
Response ping() throws Exception;

/**
* Predictions entry point to get inference using default model version.
Expand Down
24 changes: 14 additions & 10 deletions src/main/java/com/github/tadayosi/torchserve/client/Management.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import java.util.List;

import com.github.tadayosi.torchserve.client.model.API;
import com.github.tadayosi.torchserve.client.model.ModelDetail;
import com.github.tadayosi.torchserve.client.model.ModelList;
import com.github.tadayosi.torchserve.client.model.RegisterModelOptions;
import com.github.tadayosi.torchserve.client.model.Response;
import com.github.tadayosi.torchserve.client.model.SetAutoScaleOptions;
import com.github.tadayosi.torchserve.client.model.UnregisterModelOptions;

Expand All @@ -14,52 +18,52 @@ public interface Management {
/**
* Register a new model in TorchServe.
*/
Object registerModel(String url, RegisterModelOptions options) throws Exception;
Response registerModel(String url, RegisterModelOptions options) throws Exception;

/**
* Configure the number of workers for the default version of a model. This is an asynchronous call by default. The caller needs to call describeModel to check whether the model workers have been changed.
*/
Object setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception;
Response setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception;

/**
* Configure the number of workers for a specified version of a model. This is an asynchronous call by default. The caller needs to call describeModel to check whether the model workers have been changed.
*/
Object setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception;
Response setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception;

/**
* Provides detailed information about the default version of a model.
*/
List<Object> describeModel(String modelName) throws Exception;
List<ModelDetail> describeModel(String modelName) throws Exception;

/**
* Provides detailed information about the specified version of a model. If "all" is specified as the version, returns the details about all the versions of the model.
*/
List<Object> describeModel(String modelName, String modelVersion) throws Exception;
List<ModelDetail> describeModel(String modelName, String modelVersion) throws Exception;

/**
* Unregister the default version of a model from TorchServe if it is the only version available. This is an asynchronous call by default. Caller can call listModels to confirm model is unregistered.
*/
Object unregisterModel(String modelName, UnregisterModelOptions options) throws Exception;
Response unregisterModel(String modelName, UnregisterModelOptions options) throws Exception;

/**
* Unregister the specified version of a model from TorchServe. This is an asynchronous call by default. Caller can call listModels to confirm model is unregistered.
*/
Object unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options) throws Exception;
Response unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options) throws Exception;

/**
* List registered models in TorchServe.
*/
Object listModels(Integer limit, String nextPageToken) throws Exception;
ModelList listModels(Integer limit, String nextPageToken) throws Exception;

/**
* Set default version of a model.
*/
Object setDefault(String modelName, String modelVersion) throws Exception;
Response setDefault(String modelName, String modelVersion) throws Exception;

/**
* Get openapi description.
*/
Object apiDescription() throws Exception;
API apiDescription() throws Exception;

/**
* Not supported yet.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ public interface Metrics {
/**
* Get TorchServe application metrics in prometheus format.
*/
Object metrics() throws Exception;
String metrics() throws Exception;

/**
* Get TorchServe application metrics in prometheus format.
*/
Object metrics(String name) throws Exception;
String metrics(String name) throws Exception;

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import com.github.tadayosi.torchserve.client.Inference;
import com.github.tadayosi.torchserve.client.inference.api.DefaultApi;
import com.github.tadayosi.torchserve.client.inference.invoker.ApiClient;
import com.github.tadayosi.torchserve.client.model.API;
import com.github.tadayosi.torchserve.client.model.Response;

public class DefaultInference implements Inference {

Expand All @@ -18,13 +20,13 @@ public DefaultInference(int port) {
}

@Override
public Object apiDescription() throws Exception {
return api.apiDescription();
public API apiDescription() throws Exception {
return API.from(api.apiDescription());
}

@Override
public Object ping() throws Exception {
return api.ping();
public Response ping() throws Exception {
return Response.from(api.ping());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
package com.github.tadayosi.torchserve.client.impl;

import java.util.List;
import java.util.Map;

import com.github.tadayosi.torchserve.client.Management;
import com.github.tadayosi.torchserve.client.management.api.DefaultApi;
import com.github.tadayosi.torchserve.client.management.invoker.ApiClient;
import com.github.tadayosi.torchserve.client.model.API;
import com.github.tadayosi.torchserve.client.model.Model;
import com.github.tadayosi.torchserve.client.model.ModelDetail;
import com.github.tadayosi.torchserve.client.model.ModelList;
import com.github.tadayosi.torchserve.client.model.RegisterModelOptions;
import com.github.tadayosi.torchserve.client.model.Response;
import com.github.tadayosi.torchserve.client.model.SetAutoScaleOptions;
import com.github.tadayosi.torchserve.client.model.UnregisterModelOptions;

Expand All @@ -23,8 +29,8 @@ public DefaultManagement(int port) {
}

@Override
public Object registerModel(String url, RegisterModelOptions options) throws Exception {
return api.registerModel(url, null,
public Response registerModel(String url, RegisterModelOptions options) throws Exception {
return Response.from(api.registerModel(url, null,
options.getModelName(),
options.getHandler(),
options.getRuntime(),
Expand All @@ -33,66 +39,67 @@ public Object registerModel(String url, RegisterModelOptions options) throws Exc
options.getResponseTimeout(),
options.getInitialWorkers(),
options.getSynchronous(),
options.getS3SseKms());
options.getS3SseKms()));
}

@Override
public Object setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception {
return api.setAutoScale(modelName,
public Response setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception {
return Response.from(api.setAutoScale(modelName,
options.getMinWorker(),
options.getMaxWorker(),
options.getNumberGpu(),
options.getSynchronous(),
options.getTimeout());
options.getTimeout()));
}

@Override
public Object setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception {
return api.versionSetAutoScale(modelName, modelVersion,
public Response setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception {
return Response.from(api.versionSetAutoScale(modelName, modelVersion,
options.getMinWorker(),
options.getMaxWorker(),
options.getNumberGpu(),
options.getSynchronous(),
options.getTimeout());
options.getTimeout()));
}

@Override
public List<Object> describeModel(String modelName) throws Exception {
return List.copyOf(api.describeModel(modelName));
public List<ModelDetail> describeModel(String modelName) throws Exception {
return ModelDetail.from(api.describeModel(modelName));
}

@Override
public List<Object> describeModel(String modelName, String modelVersion) throws Exception {
return List.copyOf(api.versionDescribeModel(modelName, modelVersion));
public List<ModelDetail> describeModel(String modelName, String modelVersion) throws Exception {
return ModelDetail.from(api.versionDescribeModel(modelName, modelVersion));
}

@Override
public Object unregisterModel(String modelName, UnregisterModelOptions options) throws Exception {
return api.unregisterModel(modelName,
public Response unregisterModel(String modelName, UnregisterModelOptions options) throws Exception {
return Response.from(api.unregisterModel(modelName,
options.getSynchronous(),
options.getTimeout());
options.getTimeout()));
}

@Override
public Object unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options) throws Exception {
return api.versionUnregisterModel(modelName, modelVersion,
public Response unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options)
throws Exception {
return Response.from(api.versionUnregisterModel(modelName, modelVersion,
options.getSynchronous(),
options.getTimeout());
options.getTimeout()));
}

@Override
public Object listModels(Integer limit, String nextPageToken) throws Exception {
return api.listModels(limit, nextPageToken);
public ModelList listModels(Integer limit, String nextPageToken) throws Exception {
return ModelList.from(api.listModels(limit, nextPageToken));
}

@Override
public Object setDefault(String modelName, String modelVersion) throws Exception {
return api.setDefault(modelName, modelVersion);
public Response setDefault(String modelName, String modelVersion) throws Exception {
return Response.from(api.setDefault(modelName, modelVersion));
}

@Override
public Object apiDescription() throws Exception {
return api.apiDescription();
public API apiDescription() throws Exception {
return API.from(api.apiDescription());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ public DefaultMetrics(int port) {
}

@Override
public Object metrics() throws Exception {
public String metrics() throws Exception {
return metrics(null);
}

@Override
public Object metrics(String name) throws Exception {
public String metrics(String name) throws Exception {
return api.metrics(name);
}

Expand Down
56 changes: 56 additions & 0 deletions src/main/java/com/github/tadayosi/torchserve/client/model/API.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.github.tadayosi.torchserve.client.model;

import java.util.HashMap;
import java.util.Map;

/**
 * Client-side view of an API description: the {@code openapi} value plus the {@code info}
 * and {@code paths} maps.
 *
 * <p>Instances are created empty or converted from the {@code InlineResponse200} type of
 * either the inference or the management client via {@link #from}.
 */
public class API {

    private String openapi = null;
    private Map<String, String> info = new HashMap<>();
    private Map<String, Object> paths = new HashMap<>();

    /** Creates a description with no {@code openapi} value and empty {@code info}/{@code paths} maps. */
    public API() {
    }

    /**
     * Converts the inference client's response into an {@code API}.
     *
     * @param src response to convert; must not be {@code null}
     * @return a new {@code API} holding the values read from {@code src}
     */
    public static API from(com.github.tadayosi.torchserve.client.inference.model.InlineResponse200 src) {
        return create(src.getOpenapi(), src.getInfo(), src.getPaths());
    }

    /**
     * Converts the management client's response into an {@code API}.
     *
     * @param src response to convert; must not be {@code null}
     * @return a new {@code API} holding the values read from {@code src}
     */
    public static API from(com.github.tadayosi.torchserve.client.management.model.InlineResponse200 src) {
        return create(src.getOpenapi(), src.getInfo(), src.getPaths());
    }

    /**
     * Shared conversion used by both {@code from} overloads.
     *
     * <p>A {@code null} {@code info} or {@code paths} leaves the corresponding empty default map
     * in place, so the getters of a converted instance do not return {@code null} for them.
     */
    @SuppressWarnings("unchecked")
    private static API create(String openapi, Object info, Object paths) {
        API api = new API();
        api.setOpenapi(openapi);
        if (info != null) {
            // Unchecked: the value is stored with the same cast the conversion has always applied.
            api.setInfo((Map<String, String>) info);
        }
        if (paths != null) {
            api.setPaths((Map<String, Object>) paths);
        }
        return api;
    }

    /** Returns the {@code openapi} value, or {@code null} if none has been set. */
    public String getOpenapi() {
        return openapi;
    }

    /** Sets the {@code openapi} value. */
    public void setOpenapi(String openapi) {
        this.openapi = openapi;
    }

    /** Returns the {@code info} map; the stored instance itself, not a copy. */
    public Map<String, String> getInfo() {
        return info;
    }

    /** Replaces the {@code info} map with the given instance (stored as-is, not copied). */
    public void setInfo(Map<String, String> info) {
        this.info = info;
    }

    /** Returns the {@code paths} map; the stored instance itself, not a copy. */
    public Map<String, Object> getPaths() {
        return paths;
    }

    /** Replaces the {@code paths} map with the given instance (stored as-is, not copied). */
    public void setPaths(Map<String, Object> paths) {
        this.paths = paths;
    }
}
Loading

0 comments on commit f7307ef

Please sign in to comment.