diff --git a/aria/sample.py b/aria/sample.py index f39217b..397255c 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -9,6 +9,7 @@ from aria.inference import TransformerLM from aria.tokenizer import Tokenizer +from aria.data.midi import MidiDict torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.triton.unique_kernel_names = True @@ -193,8 +194,8 @@ def sample_top_p(probs, p): # TODO: Clean up a bit and get rid of footguns def get_inst_prompt( - tokenizer, - midi_dict, + tokenizer: Tokenizer, + midi_dict: MidiDict, truncate_len: int, noise: bool, ): @@ -217,12 +218,15 @@ def get_inst_prompt( print("No notes found in prompt region") prompt_seq = prompt_seq[: prompt_seq.index(tokenizer.bos_tok) + 1] + if tokenizer.dim_tok in prompt_seq: + prompt_seq.remove(tokenizer.dim_tok) + return prompt_seq def get_pt_prompt( - tokenizer, - midi_dict, + tokenizer: Tokenizer, + midi_dict: MidiDict, truncate_len: int, ): prompt_seq = tokenizer.tokenize(midi_dict=midi_dict) @@ -231,4 +235,7 @@ def get_pt_prompt( trunc_time_ms=truncate_len * 1e3, ) + if tokenizer.dim_tok in prompt_seq: + prompt_seq.remove(tokenizer.dim_tok) + return prompt_seq diff --git a/config/config.json b/config/config.json index 2d01e70..f94d753 100644 --- a/config/config.json +++ b/config/config.json @@ -2,33 +2,33 @@ "data": { "tests": { "max_programs":{ - "run": true, + "run": false, "args": { "max": 12 } }, "max_instruments":{ - "run": true, + "run": false, "args": { "max": 7 } }, "total_note_frequency":{ - "run": true, + "run": false, "args": { "min_per_second": 1.5, "max_per_second": 30 } }, "note_frequency_per_instrument":{ - "run": true, + "run": false, "args": { "min_per_second": 1.0, "max_per_second": 25 } }, "min_length":{ - "run": true, + "run": false, "args": { "min_seconds": 30 } @@ -78,7 +78,7 @@ } }, "maestro_json": { - "run": false, + "run": true, "args": { "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"] @@ -91,7 +91,7 @@ } }, "abs_path": { - "run": false, + "run": true, "args": {} } }, @@ -226,7 +226,7 @@ }, "drum_velocity": 60, "velocity_quantization": { - "step": 15 + "step": 10 }, "abs_time_step_ms": 5000, "max_dur_ms": 5000,