Skip to content

Commit

Permalink
Move main fun to kt file
Browse files Browse the repository at this point in the history
  • Loading branch information
madroidmaq committed Aug 10, 2023
1 parent ea35cf0 commit da65855
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 57 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ tasks.register<Copy>("createJarIfNeeded") {

tasks.register<JavaExec>("executeMain") {
dependsOn("createJarIfNeeded")
mainClass = "Llama2"
mainClass = "Llama2Kt"
val checkPoint = project.findProperty("cp")?.toString()
args = listOf(checkPoint)
classpath = sourceSets.main.get().runtimeClasspath
Expand Down
105 changes: 49 additions & 56 deletions src/main/kotlin/Llama2.kt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.FloatBuffer
Expand Down Expand Up @@ -443,13 +442,6 @@ class Llama2(private val checkpoint: String) : Model {
matmul(s.logits, s.x, w.wcls, 0, dim, p.vocabSize)
}

// ----------------------------------------------------------------------------
// utilities
private fun timeInMs(): Long {
// return time in milliseconds, for benchmarking the model speed
return System.nanoTime() / 1000000
}

private var rngSeed: Long = 0
private fun randomU32(): Int {
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
Expand Down Expand Up @@ -540,61 +532,62 @@ class Llama2(private val checkpoint: String) : Model {
pos++
}
}
}

companion object {

// ----------------------------------------------------------------------------
@Throws(IOException::class)
@JvmStatic
fun main(args: Array<String>) {
fun main(args: Array<String>) {

// poor man's C argparse
var checkPoint: String? = null // e.g. out/model.bin
var temperature = 0.0f // 0.9f; // e.g. 1.0, or 0.0
var steps = 256 // max number of steps to run for, 0: use seq_len
var prompt: String? = null // prompt string
// poor man's C argparse
var checkPoint: String? = null // e.g. out/model.bin
var temperature = 0.0f // 0.9f; // e.g. 1.0, or 0.0
var steps = 256 // max number of steps to run for, 0: use seq_len
var prompt: String? = null // prompt string

// 'checkpoint' is necessary arg
if (args.isEmpty()) {
println("Usage: java -jar Llama2.jar <checkpoint_file> [temperature] [steps] [prompt]\n")
System.exit(1)
}
if (args.isNotEmpty()) {
checkPoint = args[0]
}
if (args.size >= 2) {
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
temperature = args[1].toFloat()
}
if (args.size >= 3) {
steps = args[2].toInt()
}
if (args.size >= 4) {
prompt = args[3]
}
// 'checkpoint' is necessary arg
if (args.isEmpty()) {
println("Usage: java -jar Llama2.jar <checkpoint_file> [temperature] [steps] [prompt]\n")
System.exit(1)
}
if (args.isNotEmpty()) {
checkPoint = args[0]
}
if (args.size >= 2) {
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
temperature = args[1].toFloat()
}
if (args.size >= 3) {
steps = args[2].toInt()
}
if (args.size >= 4) {
prompt = args[3]
}

val model = Llama2(checkPoint!!)
val model = Llama2(checkPoint!!)


val tokenize = Tokenizer.from(model.config.vocabSize)
val tokenize = Tokenizer.from(model.config.vocabSize)

// process the prompt, if any
val promptTokens: IntArray? = if (prompt != null) {
tokenize.encode(prompt)
} else {
null
}
// process the prompt, if any
val promptTokens: IntArray? = if (prompt != null) {
tokenize.encode(prompt)
} else {
null
}

val start = model.timeInMs()
model.generate(promptTokens, steps, temperature) { next ->
// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR#89)
val tokenStr = tokenize.decode(next)
System.out.printf("%s", tokenStr)
System.out.flush()
}
// report achieved tok/s
val end = model.timeInMs()
System.out.printf("\nachieved tok/s: %f\n", (steps) / (end - start).toDouble() * 1000)
}
val start = timeInMs()
model.generate(promptTokens, steps, temperature) { next ->
// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR#89)
val tokenStr = tokenize.decode(next)
System.out.printf("%s", tokenStr)
System.out.flush()
}
// report achieved tok/s
val end = timeInMs()
System.out.printf("\nachieved tok/s: %f\n", (steps) / (end - start).toDouble() * 1000)
}

// ----------------------------------------------------------------------------
// utilities
/**
 * Reads the JVM's high-resolution timer and converts it to whole milliseconds.
 *
 * Used to benchmark model speed: take one reading before and one after the
 * work being timed, then subtract. The value comes from [System.nanoTime],
 * so a single reading is not a wall-clock timestamp.
 *
 * @return the current `System.nanoTime()` reading divided by 1,000,000.
 */
private fun timeInMs(): Long = System.nanoTime() / 1_000_000

0 comments on commit da65855

Please sign in to comment.