diff --git a/src/java/src/main/java/triton/client/InferResult.java b/src/java/src/main/java/triton/client/InferResult.java index 91712079d..2ec3249c7 100644 --- a/src/java/src/main/java/triton/client/InferResult.java +++ b/src/java/src/main/java/triton/client/InferResult.java @@ -269,6 +269,50 @@ public double[] getOutputAsDouble(String output) { return (double[])getOutputImpl(out, double.class, ByteBuffer::getDouble); } + public float[] getOutputFp16AsFloat(String output) { + IOTensor out = this.response.getOutputByName(output); + if (out == null) { + return null; + } + Preconditions.checkArgument(out.getDatatype() == DataType.FP16, + "Could not get double[] from data of type %s on output %s.", out.getDatatype(), out.getName()); + return (float[])getOutputImpl(out, float.class, this::fromFP16); + } + + /** + * Half-precision floating-point in Java + * Ref: https://stackoverflow.com/a/6162687 + * + * @param byteBuffer fp16 bytes + * @return float32 + */ + public float fromFP16(ByteBuffer byteBuffer) { + short hbits = byteBuffer.getShort(); + int mant = hbits & 0x03ff; // 10 bits mantissa + int exp = hbits & 0x7c00; // 5 bits exponent + if( exp == 0x7c00 ) // NaN/Inf + exp = 0x3fc00; // -> NaN/Inf + else if( exp != 0 ) // normalized value + { + exp += 0x1c000; // exp - 15 + 127 + if( mant == 0 && exp > 0x1c400 ) // smooth transition + return Float.intBitsToFloat( ( hbits & 0x8000 ) << 16 + | exp << 13 | 0x3ff ); + } + else if( mant != 0 ) // && exp==0 -> subnormal + { + exp = 0x1c400; // make it normal + do { + mant <<= 1; // mantissa * 2 + exp -= 0x400; // decrease exp by 1 + } while( ( mant & 0x400 ) == 0 ); // while not normal + mant &= 0x3ff; // discard subnormal bit + } // else +/-0 -> +/-0 + return Float.intBitsToFloat( // combine all parts + ( hbits & 0x8000 ) << 16 // sign << ( 31 - 15 ) + | ( exp | mant ) << 13 ); // value << ( 23 - 10 ) + } + private Object getOutputImpl(IOTensor out, Class clazz, Function getter) { Index idx = this.nameToBinaryIdx.get(out.getName()); if (idx != null) { // Output in binary format. diff --git a/src/java/src/main/java/triton/client/InferenceServerClient.java b/src/java/src/main/java/triton/client/InferenceServerClient.java index 0fd46af00..ff5376fce 100644 --- a/src/java/src/main/java/triton/client/InferenceServerClient.java +++ b/src/java/src/main/java/triton/client/InferenceServerClient.java @@ -226,7 +226,7 @@ private static CloseableHttpAsyncClient createHttpClient(HttpConfig httpConfig) } public void setRetryCnt(int retryCnt) { - Preconditions.checkArgument(retryCnt > 0, "Invalid retryCount: %s", retryCnt); + Preconditions.checkArgument(retryCnt >= 0, "Invalid retryCount: %s", retryCnt); this.retryCnt = retryCnt; }