Skip to content

Commit

Permalink
update config
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad committed Dec 4, 2024
1 parent 4094fda commit 52b14e7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
15 changes: 11 additions & 4 deletions aria/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from aria.inference import TransformerLM
from aria.tokenizer import Tokenizer
from aria.data.midi import MidiDict

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
Expand Down Expand Up @@ -193,8 +194,8 @@ def sample_top_p(probs, p):

# TODO: Clean up a bit and get rid of footguns
def get_inst_prompt(
tokenizer,
midi_dict,
tokenizer: Tokenizer,
midi_dict: MidiDict,
truncate_len: int,
noise: bool,
):
Expand All @@ -217,12 +218,15 @@ def get_inst_prompt(
print("No notes found in prompt region")
prompt_seq = prompt_seq[: prompt_seq.index(tokenizer.bos_tok) + 1]

if tokenizer.dim_tok in prompt_seq:
prompt_seq.remove(tokenizer.dim_tok)

return prompt_seq


def get_pt_prompt(
tokenizer,
midi_dict,
tokenizer: Tokenizer,
midi_dict: MidiDict,
truncate_len: int,
):
prompt_seq = tokenizer.tokenize(midi_dict=midi_dict)
Expand All @@ -231,4 +235,7 @@ def get_pt_prompt(
trunc_time_ms=truncate_len * 1e3,
)

if tokenizer.dim_tok in prompt_seq:
prompt_seq.remove(tokenizer.dim_tok)

return prompt_seq
16 changes: 8 additions & 8 deletions config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,33 @@
"data": {
"tests": {
"max_programs":{
"run": true,
"run": false,
"args": {
"max": 12
}
},
"max_instruments":{
"run": true,
"run": false,
"args": {
"max": 7
}
},
"total_note_frequency":{
"run": true,
"run": false,
"args": {
"min_per_second": 1.5,
"max_per_second": 30
}
},
"note_frequency_per_instrument":{
"run": true,
"run": false,
"args": {
"min_per_second": 1.0,
"max_per_second": 25
}
},
"min_length":{
"run": true,
"run": false,
"args": {
"min_seconds": 30
}
Expand Down Expand Up @@ -78,7 +78,7 @@
}
},
"maestro_json": {
"run": false,
"run": true,
"args": {
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"],
"form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"]
Expand All @@ -91,7 +91,7 @@
}
},
"abs_path": {
"run": false,
"run": true,
"args": {}
}
},
Expand Down Expand Up @@ -226,7 +226,7 @@
},
"drum_velocity": 60,
"velocity_quantization": {
"step": 15
"step": 10
},
"abs_time_step_ms": 5000,
"max_dur_ms": 5000,
Expand Down

0 comments on commit 52b14e7

Please sign in to comment.