diff --git a/examples/mnist.java b/examples/mnist.java new file mode 100755 index 0000000..bc66e7a --- /dev/null +++ b/examples/mnist.java @@ -0,0 +1,28 @@ +///usr/bin/env jbang "$0" "$@" ; exit $? +//DEPS com.github.tadayosi.torchserve:torchserve-client:0.1-SNAPSHOT + +import java.nio.file.Files; +import java.nio.file.Path; + +import com.github.tadayosi.torchserve.client.impl.DefaultInference; +import com.github.tadayosi.torchserve.client.inference.invoker.ApiException; + +public class mnist { + + private static String MNIST_MODEL = "mnist_v2"; + + public static void main(String... args) throws Exception { + var zero = Files.readAllBytes(Path.of("src/test/resources/data/0.png")); + var one = Files.readAllBytes(Path.of("src/test/resources/data/1.png")); + try { + var inference = new DefaultInference(); + var result0 = inference.predictions(MNIST_MODEL, zero); + System.out.println("Answer> " + result0); + var result1 = inference.predictions(MNIST_MODEL, one); + System.out.println("Answer> " + result1); + } catch (ApiException e) { + System.err.println(e.getResponseBody()); + e.printStackTrace(); + } + } +} diff --git a/examples/register_mnist.java b/examples/register_mnist.java new file mode 100755 index 0000000..6a2d932 --- /dev/null +++ b/examples/register_mnist.java @@ -0,0 +1,28 @@ +///usr/bin/env jbang "$0" "$@" ; exit $? +//DEPS com.github.tadayosi.torchserve:torchserve-client:0.1-SNAPSHOT + +import com.github.tadayosi.torchserve.client.impl.DefaultManagement; +import com.github.tadayosi.torchserve.client.management.invoker.ApiException; +import com.github.tadayosi.torchserve.client.model.RegisterModelOptions; +import com.github.tadayosi.torchserve.client.model.SetAutoScaleOptions; + +public class register_mnist { + + private static String MNIST_URL = "https://torchserve.pytorch.org/mar_files/mnist_v2.mar"; + private static String MNIST_MODEL = "mnist_v2"; + + public static void main(String... args) throws Exception { + try { + var management = new DefaultManagement(); + var response = management.registerModel(MNIST_URL, RegisterModelOptions.empty()); + System.out.println("registerModel> " + response.getStatus()); + response = management.setAutoScale(MNIST_MODEL, SetAutoScaleOptions.builder() + .minWorker(1) + .maxWorker(1) + .build()); + System.out.println("setAutoScale> " + response.getStatus()); + } catch (ApiException e) { + System.err.println(e.getResponseBody()); + } + } +} diff --git a/src/main/java/com/github/tadayosi/torchserve/client/Inference.java b/src/main/java/com/github/tadayosi/torchserve/client/Inference.java index 6436a92..33f03ee 100644 --- a/src/main/java/com/github/tadayosi/torchserve/client/Inference.java +++ b/src/main/java/com/github/tadayosi/torchserve/client/Inference.java @@ -1,5 +1,8 @@ package com.github.tadayosi.torchserve.client; +import com.github.tadayosi.torchserve.client.model.API; +import com.github.tadayosi.torchserve.client.model.Response; + /** * Inference API */ @@ -8,12 +11,12 @@ public interface Inference { /** * Get openapi description. */ - Object apiDescription() throws Exception; + API apiDescription() throws Exception; /** * Get TorchServe status. */ - Object ping() throws Exception; + Response ping() throws Exception; /** * Predictions entry point to get inference using default model version. diff --git a/src/main/java/com/github/tadayosi/torchserve/client/Management.java b/src/main/java/com/github/tadayosi/torchserve/client/Management.java index 64f3285..8fcb1a8 100644 --- a/src/main/java/com/github/tadayosi/torchserve/client/Management.java +++ b/src/main/java/com/github/tadayosi/torchserve/client/Management.java @@ -2,7 +2,11 @@ import java.util.List; +import com.github.tadayosi.torchserve.client.model.API; +import com.github.tadayosi.torchserve.client.model.ModelDetail; +import com.github.tadayosi.torchserve.client.model.ModelList; import com.github.tadayosi.torchserve.client.model.RegisterModelOptions; +import com.github.tadayosi.torchserve.client.model.Response; import com.github.tadayosi.torchserve.client.model.SetAutoScaleOptions; import com.github.tadayosi.torchserve.client.model.UnregisterModelOptions; @@ -14,52 +18,52 @@ public interface Management { /** * Register a new model in TorchServe. */ - Object registerModel(String url, RegisterModelOptions options) throws Exception; + Response registerModel(String url, RegisterModelOptions options) throws Exception; /** * Configure number of workers for a default version of a model. This is an asynchronous call by default. Caller need to call describeModel to check if the model workers has been changed. */ - Object setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception; + Response setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception; /** * Configure number of workers for a specified version of a model. This is an asynchronous call by default. Caller need to call describeModel to check if the model workers has been changed. */ - Object setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception; + Response setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception; /** * Provides detailed information about the default version of a model. */ - List describeModel(String modelName) throws Exception; + List describeModel(String modelName) throws Exception; /** * Provides detailed information about the specified version of a model.If "all" is specified as version, returns the details about all the versions of the model. */ - List describeModel(String modelName, String modelVersion) throws Exception; + List describeModel(String modelName, String modelVersion) throws Exception; /** * Unregister the default version of a model from TorchServe if it is the only version available. This is an asynchronous call by default. Caller can call listModels to confirm model is unregistered. */ - Object unregisterModel(String modelName, UnregisterModelOptions options) throws Exception; + Response unregisterModel(String modelName, UnregisterModelOptions options) throws Exception; /** * Unregister the specified version of a model from TorchServe. This is an asynchronous call by default. Caller can call listModels to confirm model is unregistered. */ - Object unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options) throws Exception; + Response unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options) throws Exception; /** * List registered models in TorchServe. */ - Object listModels(Integer limit, String nextPageToken) throws Exception; + ModelList listModels(Integer limit, String nextPageToken) throws Exception; /** * Set default version of a model. */ - Object setDefault(String modelName, String modelVersion) throws Exception; + Response setDefault(String modelName, String modelVersion) throws Exception; /** * Get openapi description. */ - Object apiDescription() throws Exception; + API apiDescription() throws Exception; /** * Not supported yet. diff --git a/src/main/java/com/github/tadayosi/torchserve/client/Metrics.java b/src/main/java/com/github/tadayosi/torchserve/client/Metrics.java index 20cf4d4..296e4f3 100644 --- a/src/main/java/com/github/tadayosi/torchserve/client/Metrics.java +++ b/src/main/java/com/github/tadayosi/torchserve/client/Metrics.java @@ -8,11 +8,11 @@ public interface Metrics { /** * Get TorchServe application metrics in prometheus format. */ - Object metrics() throws Exception; + String metrics() throws Exception; /** * Get TorchServe application metrics in prometheus format. */ - Object metrics(String name) throws Exception; + String metrics(String name) throws Exception; } diff --git a/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultInference.java b/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultInference.java index 9ec8b9e..69ecd94 100644 --- a/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultInference.java +++ b/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultInference.java @@ -3,6 +3,8 @@ import com.github.tadayosi.torchserve.client.Inference; import com.github.tadayosi.torchserve.client.inference.api.DefaultApi; import com.github.tadayosi.torchserve.client.inference.invoker.ApiClient; +import com.github.tadayosi.torchserve.client.model.API; +import com.github.tadayosi.torchserve.client.model.Response; public class DefaultInference implements Inference { @@ -18,13 +20,13 @@ public DefaultInference(int port) { } @Override - public Object apiDescription() throws Exception { - return api.apiDescription(); + public API apiDescription() throws Exception { + return API.from(api.apiDescription()); } @Override - public Object ping() throws Exception { - return api.ping(); + public Response ping() throws Exception { + return Response.from(api.ping()); } @Override diff --git a/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultManagement.java b/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultManagement.java index f296f68..560bc03 100644 --- a/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultManagement.java +++ b/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultManagement.java @@ -1,11 +1,17 @@ package com.github.tadayosi.torchserve.client.impl; import java.util.List; +import java.util.Map; import com.github.tadayosi.torchserve.client.Management; import com.github.tadayosi.torchserve.client.management.api.DefaultApi; import com.github.tadayosi.torchserve.client.management.invoker.ApiClient; +import com.github.tadayosi.torchserve.client.model.API; +import com.github.tadayosi.torchserve.client.model.Model; +import com.github.tadayosi.torchserve.client.model.ModelDetail; +import com.github.tadayosi.torchserve.client.model.ModelList; import com.github.tadayosi.torchserve.client.model.RegisterModelOptions; +import com.github.tadayosi.torchserve.client.model.Response; import com.github.tadayosi.torchserve.client.model.SetAutoScaleOptions; import com.github.tadayosi.torchserve.client.model.UnregisterModelOptions; @@ -23,8 +29,8 @@ public DefaultManagement(int port) { } @Override - public Object registerModel(String url, RegisterModelOptions options) throws Exception { - return api.registerModel(url, null, + public Response registerModel(String url, RegisterModelOptions options) throws Exception { + return Response.from(api.registerModel(url, null, options.getModelName(), options.getHandler(), options.getRuntime(), @@ -33,66 +39,67 @@ public Object registerModel(String url, RegisterModelOptions options) throws Exc options.getResponseTimeout(), options.getInitialWorkers(), options.getSynchronous(), - options.getS3SseKms()); + options.getS3SseKms())); } @Override - public Object setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception { - return api.setAutoScale(modelName, + public Response setAutoScale(String modelName, SetAutoScaleOptions options) throws Exception { + return Response.from(api.setAutoScale(modelName, options.getMinWorker(), options.getMaxWorker(), options.getNumberGpu(), options.getSynchronous(), - options.getTimeout()); + options.getTimeout())); } @Override - public Object setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception { - return api.versionSetAutoScale(modelName, modelVersion, + public Response setAutoScale(String modelName, String modelVersion, SetAutoScaleOptions options) throws Exception { + return Response.from(api.versionSetAutoScale(modelName, modelVersion, options.getMinWorker(), options.getMaxWorker(), options.getNumberGpu(), options.getSynchronous(), - options.getTimeout()); + options.getTimeout())); } @Override - public List describeModel(String modelName) throws Exception { - return List.copyOf(api.describeModel(modelName)); + public List describeModel(String modelName) throws Exception { + return ModelDetail.from(api.describeModel(modelName)); } @Override - public List describeModel(String modelName, String modelVersion) throws Exception { - return List.copyOf(api.versionDescribeModel(modelName, modelVersion)); + public List describeModel(String modelName, String modelVersion) throws Exception { + return ModelDetail.from(api.versionDescribeModel(modelName, modelVersion)); } @Override - public Object unregisterModel(String modelName, UnregisterModelOptions options) throws Exception { - return api.unregisterModel(modelName, + public Response unregisterModel(String modelName, UnregisterModelOptions options) throws Exception { + return Response.from(api.unregisterModel(modelName, options.getSynchronous(), - options.getTimeout()); + options.getTimeout())); } @Override - public Object unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options) throws Exception { - return api.versionUnregisterModel(modelName, modelVersion, + public Response unregisterModel(String modelName, String modelVersion, UnregisterModelOptions options) + throws Exception { + return Response.from(api.versionUnregisterModel(modelName, modelVersion, options.getSynchronous(), - options.getTimeout()); + options.getTimeout())); } @Override - public Object listModels(Integer limit, String nextPageToken) throws Exception { - return api.listModels(limit, nextPageToken); + public ModelList listModels(Integer limit, String nextPageToken) throws Exception { + return ModelList.from(api.listModels(limit, nextPageToken)); } @Override - public Object setDefault(String modelName, String modelVersion) throws Exception { - return api.setDefault(modelName, modelVersion); + public Response setDefault(String modelName, String modelVersion) throws Exception { + return Response.from(api.setDefault(modelName, modelVersion)); } @Override - public Object apiDescription() throws Exception { - return api.apiDescription(); + public API apiDescription() throws Exception { + return API.from(api.apiDescription()); } @Override diff --git a/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultMetrics.java b/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultMetrics.java index 4f211a9..050f821 100644 --- a/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultMetrics.java +++ b/src/main/java/com/github/tadayosi/torchserve/client/impl/DefaultMetrics.java @@ -18,12 +18,12 @@ public DefaultMetrics(int port) { } @Override - public Object metrics() throws Exception { + public String metrics() throws Exception { return metrics(null); } @Override - public Object metrics(String name) throws Exception { + public String metrics(String name) throws Exception { return api.metrics(name); } diff --git a/src/main/java/com/github/tadayosi/torchserve/client/model/API.java b/src/main/java/com/github/tadayosi/torchserve/client/model/API.java new file mode 100644 index 0000000..20b576e --- /dev/null +++ b/src/main/java/com/github/tadayosi/torchserve/client/model/API.java @@ -0,0 +1,56 @@ +package com.github.tadayosi.torchserve.client.model; + +import java.util.HashMap; +import java.util.Map; + +public class API { + + private String openapi = null; + private Map info = new HashMap<>(); + private Map paths = new HashMap<>(); + + public API() { + } + + @SuppressWarnings("unchecked") + public static API from(com.github.tadayosi.torchserve.client.inference.model.InlineResponse200 src) { + API api = new API(); + api.setOpenapi(src.getOpenapi()); + api.setInfo((Map) src.getInfo()); + api.setPaths((Map) src.getPaths()); + return api; + } + + @SuppressWarnings("unchecked") + public static API from(com.github.tadayosi.torchserve.client.management.model.InlineResponse200 src) { + API api = new API(); + api.setOpenapi(src.getOpenapi()); + api.setInfo((Map) src.getInfo()); + api.setPaths((Map) src.getPaths()); + return api; + } + + public String getOpenapi() { + return openapi; + } + + public void setOpenapi(String openapi) { + this.openapi = openapi; + } + + public Map getInfo() { + return info; + } + + public void setInfo(Map info) { + this.info = info; + } + + public Map getPaths() { + return paths; + } + + public void setPaths(Map paths) { + this.paths = paths; + } +} diff --git a/src/main/java/com/github/tadayosi/torchserve/client/model/JobQueueStatus.java b/src/main/java/com/github/tadayosi/torchserve/client/model/JobQueueStatus.java new file mode 100644 index 0000000..6160256 --- /dev/null +++ b/src/main/java/com/github/tadayosi/torchserve/client/model/JobQueueStatus.java @@ -0,0 +1,35 @@ +package com.github.tadayosi.torchserve.client.model; + +import com.github.tadayosi.torchserve.client.management.model.ModelsmodelNameJobQueueStatus; + +public class JobQueueStatus { + + private Integer remainingCapacity = null; + private Integer pendingRequests = null; + + public JobQueueStatus() { + } + + public static JobQueueStatus from(ModelsmodelNameJobQueueStatus src) { + JobQueueStatus status = new JobQueueStatus(); + status.setRemainingCapacity(src.getRemainingCapacity()); + status.setPendingRequests(src.getPendingRequests()); + return status; + } + + public Integer getRemainingCapacity() { + return remainingCapacity; + } + + public void setRemainingCapacity(Integer remainingCapacity) { + this.remainingCapacity = remainingCapacity; + } + + public Integer getPendingRequests() { + return pendingRequests; + } + + public void setPendingRequests(Integer pendingRequests) { + this.pendingRequests = pendingRequests; + } +} diff --git a/src/main/java/com/github/tadayosi/torchserve/client/model/Metrics.java b/src/main/java/com/github/tadayosi/torchserve/client/model/Metrics.java new file mode 100644 index 0000000..2a6f8cd --- /dev/null +++ b/src/main/java/com/github/tadayosi/torchserve/client/model/Metrics.java @@ -0,0 +1,14 @@ +package com.github.tadayosi.torchserve.client.model; + +import com.github.tadayosi.torchserve.client.management.model.ModelsmodelNameMetrics; + +public class Metrics { + + public Metrics() { + } + + public static Metrics from(ModelsmodelNameMetrics src) { + Metrics metrics = new Metrics(); + return metrics; + } +} diff --git a/src/main/java/com/github/tadayosi/torchserve/client/model/Model.java b/src/main/java/com/github/tadayosi/torchserve/client/model/Model.java new file mode 100644 index 0000000..795eb6e --- /dev/null +++ b/src/main/java/com/github/tadayosi/torchserve/client/model/Model.java @@ -0,0 +1,52 @@ +package com.github.tadayosi.torchserve.client.model; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class Model { + + private static final Logger LOG = LoggerFactory.getLogger(Model.class); + + private String modelName = null; + private String modelUrl = null; + + public Model() { + } + + public static Model fromMap(Object src) { + if (!(src instanceof Map)) { + LOG.error("Unexpected model data: {}", src); + return new Model(); + } + @SuppressWarnings("unchecked") + Map map = (Map) src; + Model model = new Model(); + model.setModelName(map.get("modelName")); + model.setModelUrl(map.get("modelUrl")); + return model; + } + + public static List fromMap(List src) { + return src.stream().map(Model::fromMap).collect(Collectors.toList()); + } + + public String getModelName() { + return modelName; + } + + public void setModelName(String modelName) { + this.modelName = modelName; + } + + public String getModelUrl() { + return modelUrl; + } + + public void setModelUrl(String modelUrl) { + this.modelUrl = modelUrl; + } +} diff --git a/src/main/java/com/github/tadayosi/torchserve/client/model/ModelDetail.java b/src/main/java/com/github/tadayosi/torchserve/client/model/ModelDetail.java new file mode 100644 index 0000000..b31659a --- /dev/null +++ b/src/main/java/com/github/tadayosi/torchserve/client/model/ModelDetail.java @@ -0,0 +1,113 @@ +package com.github.tadayosi.torchserve.client.model; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +public class ModelDetail extends Model { + + private String modelVersion = null; + private Integer minWorkers = null; + private Integer maxWorkers = null; + private Integer batchSize = null; + private Integer maxBatchDelay = null; + private String status = null; + private List workers = new ArrayList<>(); + private Metrics metrics = null; + private JobQueueStatus jobQueueStatus = null; + + public ModelDetail() { + } + + public static ModelDetail from(com.github.tadayosi.torchserve.client.management.model.InlineResponse2003 src) { + ModelDetail model = new ModelDetail(); + model.setModelName(src.getModelName()); + model.setModelVersion(src.getModelVersion()); + model.setModelUrl(src.getModelUrl()); + model.setMinWorkers(src.getMinWorkers()); + model.setMaxWorkers(src.getMaxWorkers()); + model.setBatchSize(src.getBatchSize()); + model.setMaxBatchDelay(src.getMaxBatchDelay()); + model.setStatus(src.getStatus()); + model.setWorkers(Worker.from(src.getWorkers())); + model.setMetrics(Metrics.from(src.getMetrics())); + model.setJobQueueStatus(JobQueueStatus.from(src.getJobQueueStatus())); + return model; + } + + public static List from(List src) { + return src.stream().map(ModelDetail::from).collect(Collectors.toList()); + } + + public String getModelVersion() { + return modelVersion; + } + + public void setModelVersion(String modelVersion) { + this.modelVersion = modelVersion; + } + + public Integer getMinWorkers() { + return minWorkers; + } + + public void setMinWorkers(Integer minWorkers) { + this.minWorkers = minWorkers; + } + + public Integer getMaxWorkers() { + return maxWorkers; + } + + public void setMaxWorkers(Integer maxWorkers) { + this.maxWorkers = maxWorkers; + } + + public Integer getBatchSize() { + return batchSize; + } + + public void setBatchSize(Integer batchSize) { + this.batchSize = batchSize; + } + + public Integer getMaxBatchDelay() { + return maxBatchDelay; + } + + public void setMaxBatchDelay(Integer maxBatchDelay) { + this.maxBatchDelay = maxBatchDelay; + } + + public String getStatus() { + return status; + } + + public void setStatus(String status) { + this.status = status; + } + + public List getWorkers() { + return workers; + } + + public void setWorkers(List workers) { + this.workers = workers; + } + + public Metrics getMetrics() { + return metrics; + } + + public void setMetrics(Metrics metrics) { + this.metrics = metrics; + } + + public JobQueueStatus getJobQueueStatus() { + return jobQueueStatus; + } + + public void setJobQueueStatus(JobQueueStatus jobQueueStatus) { + this.jobQueueStatus = jobQueueStatus; + } +} diff --git a/src/main/java/com/github/tadayosi/torchserve/client/model/ModelList.java b/src/main/java/com/github/tadayosi/torchserve/client/model/ModelList.java new file mode 100644 index 0000000..f37efc4 --- /dev/null +++ b/src/main/java/com/github/tadayosi/torchserve/client/model/ModelList.java @@ -0,0 +1,38 @@ +package com.github.tadayosi.torchserve.client.model; + +import java.util.ArrayList; +import java.util.List; + +import com.github.tadayosi.torchserve.client.management.model.InlineResponse2001; + +public class ModelList { + + private String nextPageToken = null; + private List models = new ArrayList<>(); + + public ModelList() { + } + + public static ModelList from(InlineResponse2001 inlineResponse2001) { + ModelList modelList = new ModelList(); + modelList.setNextPageToken(inlineResponse2001.getNextPageToken()); + modelList.setModels(Model.fromMap(inlineResponse2001.getModels())); + return modelList; + } + + public String getNextPageToken() { + return nextPageToken; + } + + public void setNextPageToken(String nextPageToken) { + this.nextPageToken = nextPageToken; + } + + public List getModels() { + return models; + } + + public void setModels(List models) { + this.models = models; + } +} diff --git a/src/main/java/com/github/tadayosi/torchserve/client/model/Response.java b/src/main/java/com/github/tadayosi/torchserve/client/model/Response.java new file mode 100644 index 0000000..20cf50b --- /dev/null +++ b/src/main/java/com/github/tadayosi/torchserve/client/model/Response.java @@ -0,0 +1,29 @@ +package com.github.tadayosi.torchserve.client.model; + +public class Response { + + private String status; + + public Response() { + } + + public static Response from(com.github.tadayosi.torchserve.client.inference.model.InlineResponse2001 src) { + Response response = new Response(); + response.setStatus(src.getStatus()); + return response; + } + + public static Response from(com.github.tadayosi.torchserve.client.management.model.InlineResponse2002 src) { + Response response = new Response(); + response.setStatus(src.getStatus()); + return response; + } + + public String getStatus() { + return status; + } + + public void setStatus(String status) { + this.status = status; + } +} diff --git a/src/main/java/com/github/tadayosi/torchserve/client/model/Worker.java b/src/main/java/com/github/tadayosi/torchserve/client/model/Worker.java new file mode 100644 index 0000000..e4674e8 --- /dev/null +++ b/src/main/java/com/github/tadayosi/torchserve/client/model/Worker.java @@ -0,0 +1,76 @@ +package com.github.tadayosi.torchserve.client.model; + +import java.util.List; +import java.util.stream.Collectors; + +import com.github.tadayosi.torchserve.client.management.model.ModelsmodelNameWorkers; + +public class Worker { + + private String id = null; + private String startTime = null; + private Boolean gpu = null; + private Status status = null; + + public Worker() { + } + + public static Worker from(ModelsmodelNameWorkers src) { + Worker worker = new Worker(); + worker.setId(src.getId()); + worker.setStartTime(src.getStartTime()); + worker.setGpu(src.isGpu()); + worker.setStatus(Status.from(src.getStatus())); + return worker; + } + + public static List from(List src) { + return src.stream().map(Worker::from).collect(Collectors.toList()); + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getStartTime() { + return startTime; + } + + public void setStartTime(String startTime) { + this.startTime = startTime; + } + + public Boolean getGpu() { + return gpu; + } + + public void setGpu(Boolean gpu) { + this.gpu = gpu; + } + + public Status getStatus() { + return status; + } + + public void setStatus(Status status) { + this.status = status; + } + + public enum Status { + READY, + LOADING, + UNLOADING; + + public static Status from(ModelsmodelNameWorkers.StatusEnum status) { + return switch (status) { + case READY -> READY; + case LOADING -> LOADING; + case UNLOADING -> UNLOADING; + }; + } + } +} diff --git a/src/test/java/com/github/tadayosi/torchserve/client/InferenceTest.java b/src/test/java/com/github/tadayosi/torchserve/client/InferenceTest.java index 14f32fd..9d49d19 100644 --- a/src/test/java/com/github/tadayosi/torchserve/client/InferenceTest.java +++ b/src/test/java/com/github/tadayosi/torchserve/client/InferenceTest.java @@ -5,7 +5,6 @@ import java.util.Map; import com.github.tadayosi.torchserve.client.impl.DefaultInference; -import com.github.tadayosi.torchserve.client.inference.model.InlineResponse2001; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.testcontainers.junit.jupiter.Testcontainers; @@ -19,6 +18,7 @@ public class InferenceTest extends TorchServeTestSupport { private static final String DEFAULT_MODEL = "squeezenet1_1"; + private static final String DEFAULT_MODEL_VERSION = "1.0"; private static final String TEST_DATA = "src/test/resources/data/kitten.jpg"; private Inference inference; @@ -36,7 +36,7 @@ public void testApiDescription() throws Exception { @Test public void testPing() throws Exception { - var response = (InlineResponse2001) inference.ping(); + var response = inference.ping(); assertEquals("Healthy", response.getStatus()); } @@ -49,9 +49,8 @@ public void testPredictions() throws Exception { @Test public void testPredictions_version() throws Exception { - var modelVersion = "1.0"; var body = Files.readAllBytes(Path.of(TEST_DATA)); - var response = inference.predictions(DEFAULT_MODEL, modelVersion, body); + var response = inference.predictions(DEFAULT_MODEL, DEFAULT_MODEL_VERSION, body); assertInstanceOf(Map.class, response); } diff --git a/src/test/java/com/github/tadayosi/torchserve/client/ManagementTest.java b/src/test/java/com/github/tadayosi/torchserve/client/ManagementTest.java index 4a4a808..dbe6933 100644 --- a/src/test/java/com/github/tadayosi/torchserve/client/ManagementTest.java +++ b/src/test/java/com/github/tadayosi/torchserve/client/ManagementTest.java @@ -2,15 +2,10 @@ import java.nio.file.Files; import java.nio.file.Path; -import java.util.Map; import com.github.tadayosi.torchserve.client.impl.DefaultInference; import com.github.tadayosi.torchserve.client.impl.DefaultManagement; import com.github.tadayosi.torchserve.client.management.invoker.ApiException; -import com.github.tadayosi.torchserve.client.management.model.InlineResponse200; -import com.github.tadayosi.torchserve.client.management.model.InlineResponse2001; -import com.github.tadayosi.torchserve.client.management.model.InlineResponse2002; -import com.github.tadayosi.torchserve.client.management.model.InlineResponse2003; import com.github.tadayosi.torchserve.client.model.RegisterModelOptions; import com.github.tadayosi.torchserve.client.model.SetAutoScaleOptions; import com.github.tadayosi.torchserve.client.model.UnregisterModelOptions; @@ -48,7 +43,7 @@ public void testRegisterModel() throws Exception { var url = "https://torchserve.pytorch.org/mar_files/mnist_v2.mar"; try { var response = management.registerModel(url, RegisterModelOptions.empty()); - assertTrue(((InlineResponse2002) response).getStatus().contains("registered")); + assertTrue(response.getStatus().contains("registered")); } catch (ApiException e) { e.printStackTrace(); fail(e.getResponseBody()); @@ -71,13 +66,13 @@ public void registerModel() throws Exception { @Test public void testUnregisterModel() throws Exception { var response = management.unregisterModel(ADDED_MODEL, UnregisterModelOptions.empty()); - assertTrue(((InlineResponse2002) response).getStatus().contains("unregistered")); + assertTrue(response.getStatus().contains("unregistered")); } @Test public void testUnregisterModel_version() throws Exception { var response = management.unregisterModel(ADDED_MODEL, ADDED_MODEL_VERSION, UnregisterModelOptions.empty()); - assertTrue(((InlineResponse2002) response).getStatus().contains("unregistered")); + assertTrue(response.getStatus().contains("unregistered")); } @Nested @@ -94,7 +89,7 @@ public void testSetAutoScale() throws Exception { SetAutoScaleOptions.builder() .minWorker(1) .build()); - assertTrue(((InlineResponse2002) response1).getStatus().contains("Processing worker updates")); + assertTrue(response1.getStatus().contains("Processing worker updates")); // Testing inference with MNIST V2 var inference = new DefaultInference(torchServe.getMappedPort(8080)); @@ -110,7 +105,7 @@ public void testSetAutoScale_version() throws Exception { SetAutoScaleOptions.builder() .minWorker(1) .build()); - assertTrue(((InlineResponse2002) response1).getStatus().contains("Processing worker updates")); + assertTrue(response1.getStatus().contains("Processing worker updates")); // Testing inference with MNIST V2 var inference = new DefaultInference(torchServe.getMappedPort(8080)); @@ -126,15 +121,15 @@ public void testSetAutoScale_version() throws Exception { public void testDescribeModel() throws Exception { var response = management.describeModel(DEFAULT_MODEL); assertEquals(1, response.size()); - assertEquals("squeezenet1_1", ((InlineResponse2003) response.get(0)).getModelName()); + assertEquals("squeezenet1_1", response.get(0).getModelName()); } @Test public void testDescribeModel_version() throws Exception { var response = management.describeModel(DEFAULT_MODEL, DEFAULT_MODEL_VERSION); assertEquals(1, response.size()); - assertEquals("squeezenet1_1", ((InlineResponse2003) response.get(0)).getModelName()); - assertEquals("1.0", ((InlineResponse2003) response.get(0)).getModelVersion()); + assertEquals("squeezenet1_1", response.get(0).getModelName()); + assertEquals("1.0", response.get(0).getModelVersion()); } @Test @@ -142,21 +137,21 @@ public void testListModels() throws Exception { int limit = 10; String nextPageToken = null; var response = management.listModels(limit, nextPageToken); - var models = ((InlineResponse2001) response).getModels(); + var models = response.getModels(); assertFalse(models.isEmpty()); - assertEquals(DEFAULT_MODEL, ((Map) models.get(0)).get("modelName")); + assertEquals(DEFAULT_MODEL, models.get(0).getModelName()); } @Test public void testSetDefault() throws Exception { var response = management.setDefault(DEFAULT_MODEL, DEFAULT_MODEL_VERSION); - assertTrue(((InlineResponse2002) response).getStatus().contains("Default vesion succsesfully updated")); + assertTrue(response.getStatus().contains("Default vesion succsesfully updated")); } @Test public void testApiDescription() throws Exception { var response = management.apiDescription(); - assertEquals("TorchServe APIs", ((Map) ((InlineResponse200) response).getInfo()).get("title")); + assertEquals("TorchServe APIs", response.getInfo().get("title")); } @Test