Skip to content

Commit

Permalink
fixed bug where headers are written for every batch #2 #6
Browse files Browse the repository at this point in the history
  • Loading branch information
bbawj committed Aug 1, 2023
1 parent c40e240 commit fe89dae
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/file_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ impl FileProcessor {
Ok(())
}

pub async fn write_embedding_csv(&self, embeddings: Vec<EmbeddingRow>) -> Result<()> {
let mut wtr = csv::Writer::from_writer(vec![]);
pub async fn write_embedding_csv(&self, embeddings: Vec<EmbeddingRow>, with_header: bool) -> Result<()> {
let mut wtr = csv::WriterBuilder::new().has_headers(with_header).from_writer(vec![]);
for row in embeddings {
wtr.serialize(WrittenEmbeddingRow {
name: &row.name,
Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ impl GenerateEmbeddingsCommand {
let num_records = modified_input.len();
debug!("Found {} records.", num_records);
let batch_size = (num_records as f64 / num_batches as f64).ceil() as usize;
let mut with_headers = true;

while num_processed < num_records {
let num_to_process = if batch == num_batches {
Expand Down Expand Up @@ -98,10 +99,11 @@ impl GenerateEmbeddingsCommand {
});

embedding_rows.append(&mut reusable_embeddings);
self.file_processor.write_embedding_csv(embedding_rows).await?;
self.file_processor.write_embedding_csv(embedding_rows, with_headers).await?;

num_processed += num_to_process;
batch += 1;
with_headers = false;
}

debug!("Saved embeddings to {}", EMBEDDING_FILE_PATH);
Expand Down

0 comments on commit fe89dae

Please sign in to comment.