Iris: Track token usage of iris requests #9455

Open
wants to merge 20 commits into
base: develop
Changes from 11 commits
a5dfdbb
Define LLM token usage model
alexjoham Sep 28, 2024
d0bdae3
Update table, save data received from Pyris Exercise chat pipeline
alexjoham Oct 11, 2024
2a08cb2
Implement competency generation tracking, update enum
alexjoham Oct 11, 2024
f85cf46
Add comments to LLMTokenUsageService
alexjoham Oct 11, 2024
65fb259
Fix server test failures by checking if tokens received
alexjoham Oct 12, 2024
188ff22
Update database for cost tracking and trace_id functionality
alexjoham Oct 12, 2024
be85a3b
Update database, add information to competency gen, change traceId calc
alexjoham Oct 12, 2024
e974d59
Implement server Integration tests for token tracking and saving
alexjoham Oct 13, 2024
6337162
Update code based on code-rabbit feedback, fix tests
alexjoham Oct 14, 2024
84a60dc
minor comment changes, remove tokens from frontend
alexjoham Oct 14, 2024
5b0ab48
Merge branch 'develop' into feature/track-usage-of-iris-requests
alexjoham Oct 14, 2024
62dad8b
Fix github test fails
alexjoham Oct 14, 2024
897d643
Change servicetype to type String to prevent failures
alexjoham Oct 14, 2024
1d10860
Change servicetype to type String to prevent failures
alexjoham Oct 14, 2024
8b27861
Merge remote-tracking branch 'origin/feature/track-usage-of-iris-requ…
alexjoham Oct 14, 2024
86294c1
Fix test failure by removing @SpyBean
alexjoham Oct 15, 2024
56b20e7
Update database to save only IDs, fix competency Integration Test user
alexjoham Oct 15, 2024
8a29c82
Implement builder pattern based on feedback
alexjoham Oct 16, 2024
abbd28f
Update database migration with foreign keys and on delete null
alexjoham Oct 16, 2024
8d34428
Rework database, update saveLLMTokens method
alexjoham Oct 18, 2024
@@ -0,0 +1,31 @@
package de.tum.cit.aet.artemis.core.domain;

/**
* Enum representing different types of LLM (Large Language Model) services used in the system.
*/
Comment on lines +3 to +5
🧹 Nitpick (assertive)

Enhance JavaDoc with details about enum constants.

The JavaDoc comment provides a good overview of the enum's purpose. However, consider adding more details about the enum constants to improve documentation.

Here's a suggested enhancement:

/**
 * Enum representing different types of LLM (Large Language Model) services used in the system.
 * <p>
 * IRIS: Represents the IRIS LLM service.
 * ATHENA: Represents the ATHENA LLM service.
 */

public enum LLMServiceType {
/** Athena service for preliminary feedback */
ATHENA_PRELIMINARY_FEEDBACK,
/** Athena service for feedback suggestions */
ATHENA_FEEDBACK_SUGGESTION,
/** Iris service for code feedback */
IRIS_CODE_FEEDBACK,
/** Iris service for course chat messages */
IRIS_CHAT_COURSE_MESSAGE,
/** Iris service for exercise chat messages */
IRIS_CHAT_EXERCISE_MESSAGE,
/** Iris service for interaction suggestions */
IRIS_INTERACTION_SUGGESTION,
/** Iris service for lecture chat messages */
IRIS_CHAT_LECTURE_MESSAGE,
/** Iris service for competency generation */
IRIS_COMPETENCY_GENERATION,
/** Iris service for citation pipeline */
IRIS_CITATION_PIPELINE,
/** Iris service for lecture retrieval pipeline */
IRIS_LECTURE_RETRIEVAL_PIPELINE,
/** Iris service for lecture ingestion */
IRIS_LECTURE_INGESTION,
/** Default value when the service type is not set */
NOT_SET
}
177 changes: 177 additions & 0 deletions src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java
@@ -0,0 +1,177 @@
package de.tum.cit.aet.artemis.core.domain;

import java.time.ZonedDateTime;

import jakarta.annotation.Nullable;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.EnumType;
import jakarta.persistence.Enumerated;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.Table;

import org.hibernate.annotations.Cache;
import org.hibernate.annotations.CacheConcurrencyStrategy;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;

import de.tum.cit.aet.artemis.exercise.domain.Exercise;
import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage;

@Entity
@Table(name = "llm_token_usage")
@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE)
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public class LLMTokenUsage extends DomainObject {
alexjoham marked this conversation as resolved.
🧹 Nitpick (assertive)

Consider adding JavaDoc comments for the class

Adding a class-level JavaDoc comment can improve code readability and help other developers understand the purpose and usage of this entity.


@Column(name = "service")
@Enumerated(EnumType.STRING)
private LLMServiceType serviceType;

@Column(name = "model")
private String model;
🛠️ Refactor suggestion

Consider Using an Enum for model Field

Similar to serviceType, the model field could benefit from being an enum to ensure only valid model names are used.

Apply these changes:

  1. Define an enum for LLMModel:

    package de.tum.cit.aet.artemis.core.domain;
    
    public enum LLMModel {
        GPT_3_5_TURBO,
        GPT_4,
        // Add other models as needed
    }
  2. Update the LLMTokenUsage class:

    + import jakarta.persistence.EnumType;
    + import jakarta.persistence.Enumerated;
    
      @Column(name = "model")
    + @Enumerated(EnumType.STRING)
    - private String model;
    + private LLMModel model;
  3. Update the getter and setter methods:

    - public String getModel() {
    + public LLMModel getModel() {
          return model;
      }
    
    - public void setModel(String model) {
    + public void setModel(LLMModel model) {
          this.model = model;
      }
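
One trade-off of the enum approach is that the LLM backend may report a model name the enum does not list yet. The following is a minimal sketch of mapping raw model strings onto an enum without throwing. The `UNKNOWN` constant and the `fromModelInfo` helper are hypothetical and not part of this PR, and the listed constants are illustrative only.

```java
import java.util.Locale;

/** Hypothetical enum of known models; the constants are illustrative, not the PR's actual list. */
enum LLMModel {

    GPT_3_5_TURBO,
    GPT_4,
    /** Assumed fallback for model names this enum does not know. */
    UNKNOWN;

    /**
     * Maps a raw model string such as "gpt-3.5-turbo" to a constant.
     * Returns UNKNOWN instead of throwing for null, blank, or unrecognized input.
     */
    static LLMModel fromModelInfo(String modelInfo) {
        if (modelInfo == null || modelInfo.isBlank()) {
            return UNKNOWN;
        }
        // Normalize, for example, "gpt-3.5-turbo" to "GPT_3_5_TURBO".
        String normalized = modelInfo.trim().toUpperCase(Locale.ROOT).replaceAll("[^A-Z0-9]+", "_");
        try {
            return LLMModel.valueOf(normalized);
        }
        catch (IllegalArgumentException e) {
            return UNKNOWN;
        }
    }
}

class LLMModelExample {

    public static void main(String[] args) {
        System.out.println(LLMModel.fromModelInfo("gpt-3.5-turbo"));
        System.out.println(LLMModel.fromModelInfo("some-new-model"));
    }
}
```

With a fallback like this, a newly deployed model is still recorded, although its exact name is lost unless the raw string is stored alongside the enum value. That loss is one argument for keeping `model` as a plain `String`.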


@Column(name = "num_input_tokens")
private int numInputTokens;

@Column(name = "cost_per_million_input_tokens")
private float costPerMillionInputTokens;

@Column(name = "num_output_tokens")
private int numOutputTokens;

@Column(name = "cost_per_million_output_tokens")
private float costPerMillionOutputTokens;
⚠️ Potential issue

Use BigDecimal Instead of float for Monetary Values

Using float for monetary values can lead to precision errors due to floating-point arithmetic limitations. Consider using BigDecimal to accurately represent monetary values for costPerMillionInputTokens and costPerMillionOutputTokens.

Apply these changes:

  1. Update the field declarations:

    - private float costPerMillionInputTokens;
    + private BigDecimal costPerMillionInputTokens;
    
    - private float costPerMillionOutputTokens;
    + private BigDecimal costPerMillionOutputTokens;
  2. Update the getter and setter methods:

    - public float getCostPerMillionInputTokens() {
    + public BigDecimal getCostPerMillionInputTokens() {
          return costPerMillionInputTokens;
      }
    
    - public void setCostPerMillionInputTokens(float costPerMillionInputTokens) {
    + public void setCostPerMillionInputTokens(BigDecimal costPerMillionInputTokens) {
          this.costPerMillionInputTokens = costPerMillionInputTokens;
      }
    
    - public float getCostPerMillionOutputTokens() {
    + public BigDecimal getCostPerMillionOutputTokens() {
          return costPerMillionOutputTokens;
      }
    
    - public void setCostPerMillionOutputTokens(float costPerMillionOutputTokens) {
    + public void setCostPerMillionOutputTokens(BigDecimal costPerMillionOutputTokens) {
          this.costPerMillionOutputTokens = costPerMillionOutputTokens;
      }
  3. Update any code that initializes or manipulates these fields to use BigDecimal.

Committable suggestion was skipped due to low confidence.
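
To illustrate the precision concern, here is a minimal sketch that computes the cost of one request from a token count and a per-million-token price, and contrasts repeated `float` addition with `BigDecimal` addition. The class name `TokenCostExample`, the `cost` helper, the rounding scale, and the prices are assumptions for illustration and are not taken from the PR.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

class TokenCostExample {

    private static final BigDecimal ONE_MILLION = BigDecimal.valueOf(1_000_000);

    /** Cost of numTokens tokens at the given price per million tokens, rounded to 10 decimal places. */
    static BigDecimal cost(int numTokens, BigDecimal costPerMillionTokens) {
        return costPerMillionTokens.multiply(BigDecimal.valueOf(numTokens)).divide(ONE_MILLION, 10, RoundingMode.HALF_UP);
    }

    public static void main(String[] args) {
        // float cannot represent 0.1 exactly, so repeated addition can drift.
        float floatSum = 0f;
        BigDecimal exactSum = BigDecimal.ZERO;
        for (int i = 0; i < 10; i++) {
            floatSum += 0.1f;
            exactSum = exactSum.add(new BigDecimal("0.1"));
        }
        System.out.println("float sum:      " + floatSum);
        System.out.println("BigDecimal sum: " + exactSum);

        // 15,000 input tokens at a hypothetical price of 0.50 per million tokens.
        System.out.println("request cost:   " + cost(15_000, new BigDecimal("0.50")));
    }
}
```

Constructing `BigDecimal` values from strings rather than from `float` or `double` literals avoids carrying the binary rounding error into the decimal value.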


@Nullable
@ManyToOne
@JsonIgnore
@JoinColumn(name = "course_id")
private Course course;

@Nullable
@ManyToOne
@JsonIgnore
@JoinColumn(name = "exercise_id")
private Exercise exercise;

@Column(name = "user_id")
private long userId;
🛠️ Refactor suggestion

Consider Mapping userId as a @ManyToOne Relationship

Mapping the userId field as a @ManyToOne relationship to the User entity can enhance data integrity and simplify user-related queries.

Apply these changes:

  1. Update the field declaration:

    - @Column(name = "user_id")
    - private long userId;
    + @ManyToOne
    + @JoinColumn(name = "user_id")
    + private User user;
  2. Update the getter and setter methods:

    - public long getUserId() {
    -     return userId;
    - }
    
    - public void setUserId(long userId) {
    -     this.userId = userId;
    - }
    
    + public User getUser() {
    +     return user;
    + }
    
    + public void setUser(User user) {
    +     this.user = user;
    + }
  3. Ensure that you import the necessary classes:

    import jakarta.persistence.JoinColumn;
    import jakarta.persistence.ManyToOne;
    import de.tum.cit.aet.artemis.core.domain.User;


@Column(name = "time")
private ZonedDateTime time = ZonedDateTime.now();
⚠️ Potential issue

Avoid using reserved SQL keywords for column names

The column name time may conflict with reserved keywords in some SQL databases. To prevent potential issues, consider renaming the column to a non-reserved word like timestamp or escaping it using quotes in the @Column annotation.

Apply this change:

-    @Column(name = "time")
+    @Column(name = "\"time\"")
     private ZonedDateTime time = ZonedDateTime.now();

Or rename the column:

-    @Column(name = "time")
+    @Column(name = "timestamp")
     private ZonedDateTime time = ZonedDateTime.now();

Ensure that any database scripts or queries are updated to reflect this change.



@Column(name = "trace_id")
private String traceId;

@Nullable
@ManyToOne
@JsonIgnore
@JoinColumn(name = "iris_message_id")
private IrisMessage irisMessage;

public LLMServiceType getServiceType() {
return serviceType;
}

public void setServiceType(LLMServiceType serviceType) {
this.serviceType = serviceType;
}

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

public float getCostPerMillionInputTokens() {
return costPerMillionInputTokens;
}

public void setCostPerMillionInputTokens(float costPerMillionInputToken) {
this.costPerMillionInputTokens = costPerMillionInputToken;
}
🧹 Nitpick (assertive)

Synchronize Parameter Names in Setter Methods

The parameter name in setCostPerMillionInputTokens method is singular (costPerMillionInputToken), whereas the field name is plural (costPerMillionInputTokens). For consistency, the parameter name should match the field name.

Apply this diff:

- public void setCostPerMillionInputTokens(float costPerMillionInputToken) {
-     this.costPerMillionInputTokens = costPerMillionInputToken;
+ public void setCostPerMillionInputTokens(float costPerMillionInputTokens) {
+     this.costPerMillionInputTokens = costPerMillionInputTokens;
 }

Similarly, update the setCostPerMillionOutputTokens method:

- public void setCostPerMillionOutputTokens(float costPerMillionOutputToken) {
-     this.costPerMillionOutputTokens = costPerMillionOutputToken;
+ public void setCostPerMillionOutputTokens(float costPerMillionOutputTokens) {
+     this.costPerMillionOutputTokens = costPerMillionOutputTokens;
 }


public float getCostPerMillionOutputTokens() {
return costPerMillionOutputTokens;
}

public void setCostPerMillionOutputTokens(float costPerMillionOutputToken) {
this.costPerMillionOutputTokens = costPerMillionOutputToken;
}

public int getNumInputTokens() {
return numInputTokens;
}

public void setNumInputTokens(int numInputTokens) {
this.numInputTokens = numInputTokens;
}

public int getNumOutputTokens() {
return numOutputTokens;
}

public void setNumOutputTokens(int numOutputTokens) {
this.numOutputTokens = numOutputTokens;
}

public Course getCourse() {
return course;
}

public void setCourse(Course course) {
this.course = course;
}

public Exercise getExercise() {
return exercise;
}

public void setExercise(Exercise exercise) {
this.exercise = exercise;
}

public long getUserId() {
return userId;
}

public void setUserId(long userId) {
this.userId = userId;
}

public ZonedDateTime getTime() {
return time;
}

public void setTime(ZonedDateTime time) {
this.time = time;
}

public String getTraceId() {
return traceId;
}

public void setTraceId(String traceId) {
this.traceId = traceId;
}

public IrisMessage getIrisMessage() {
return irisMessage;
}

public void setIrisMessage(IrisMessage message) {
this.irisMessage = message;
}

@Override
public String toString() {
return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", numInputTokens=" + numInputTokens + ", costPerMillionInputTokens="
+ costPerMillionInputTokens + ", numOutputTokens=" + numOutputTokens + ", costPerMillionOutputTokens=" + costPerMillionOutputTokens + ", course=" + course
+ ", exercise=" + exercise + ", userId=" + userId + ", timestamp=" + time + ", traceId=" + traceId + ", irisMessage=" + irisMessage + '}';
}
🧹 Nitpick (assertive)

Inconsistent field label in toString method

In the toString method, the field time is labeled as timestamp, which is inconsistent with the field's name. For clarity and consistency, consider updating the label to time to match the field.

Apply this diff to correct the label:

-                + ", exercise=" + exercise + ", userId=" + userId + ", timestamp=" + time + ", traceId=" + traceId + ", irisMessage=" + irisMessage + '}';
+                + ", exercise=" + exercise + ", userId=" + userId + ", time=" + time + ", traceId=" + traceId + ", irisMessage=" + irisMessage + '}';

}
@@ -0,0 +1,14 @@
package de.tum.cit.aet.artemis.core.repository;

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_IRIS;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Repository;

import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage;
import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository;

@Repository
@Profile(PROFILE_IRIS)
public interface LLMTokenUsageRepository extends ArtemisJpaRepository<LLMTokenUsage, Long> {
}
@@ -0,0 +1,110 @@
package de.tum.cit.aet.artemis.core.service;

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_IRIS;

import java.util.ArrayList;
import java.util.List;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Service;

import de.tum.cit.aet.artemis.core.domain.Course;
import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage;
import de.tum.cit.aet.artemis.core.domain.User;
import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRepository;
import de.tum.cit.aet.artemis.exercise.domain.Exercise;
import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO;
import de.tum.cit.aet.artemis.iris.service.pyris.job.PyrisJob;

/**
* Service for managing the LLMTokenUsage by all LLMs in Artemis
*/
@Service
@Profile(PROFILE_IRIS)
public class LLMTokenUsageService {

private final LLMTokenUsageRepository llmTokenUsageRepository;

public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) {
this.llmTokenUsageRepository = llmTokenUsageRepository;
}

/**
* method saves the token usage to the database with a link to the IrisMessage
* messages of the same job are grouped together by saving the job id as a trace id
*
* @param job used to create a unique traceId to group multiple LLM calls
* @param message IrisMessage to map the usage to an IrisMessage
* @param exercise to map the token cost to an exercise
* @param user to map the token cost to a user
* @param course to map the token to a course
* @param tokens token cost list of type PyrisLLMCostDTO
* @return list of the saved data
*/
🧹 Nitpick (assertive)

Enhance JavaDoc comments for clarity and consistency

While the JavaDoc provides useful information, improving the grammar and formatting would enhance readability and maintainability. Starting parameter descriptions with uppercase letters and ending sentences with periods can make the documentation more professional.

Apply this diff to improve the JavaDoc comments:

 /**
- * method saves the token usage to the database with a link to the IrisMessage
- * messages of the same job are grouped together by saving the job id as a trace id
+ * Saves the token usage to the database with a link to the IrisMessage.
+ * Messages of the same job are grouped together by saving the job ID as a trace ID.
  *
- * @param job      used to create a unique traceId to group multiple LLM calls
- * @param message  IrisMessage to map the usage to an IrisMessage
- * @param exercise to map the token cost to an exercise
- * @param user     to map the token cost to a user
- * @param course   to map the token to a course
- * @param tokens   token cost list of type PyrisLLMCostDTO
+ * @param job      Used to create a unique traceId to group multiple LLM calls.
+ * @param message  IrisMessage to map the usage to an IrisMessage.
+ * @param exercise To map the token cost to an Exercise.
+ * @param user     To map the token cost to a User.
+ * @param course   To map the token to a Course.
+ * @param tokens   Token cost list of type PyrisLLMCostDTO.
  * @return List of the saved data.
  */

🧹 Nitpick (assertive)

Improve JavaDoc comments for clarity and adherence to standards

The JavaDoc comments can be enhanced for better readability and to conform with JavaDoc standards:

  • Start descriptions with a capital letter and end with a period.
  • Use complete sentences for parameter descriptions.
  • Ensure consistency and correctness in the documentation.

Apply this diff to improve the JavaDoc:

 /**
- * method saves the token usage to the database with a link to the IrisMessage
- * messages of the same job are grouped together by saving the job id as a trace id
+ * Saves the token usage to the database with a link to the IrisMessage.
+ * Messages of the same job are grouped together by saving the job ID as a trace ID.
  *
- * @param job      used to create a unique traceId to group multiple LLM calls
- * @param message  IrisMessage to map the usage to an IrisMessage
- * @param exercise to map the token cost to an exercise
- * @param user     to map the token cost to a user
- * @param course   to map the token to a course
- * @param tokens   token cost list of type PyrisLLMCostDTO
- * @return list of the saved data
+ * @param job      Used to create a unique trace ID to group multiple LLM calls.
+ * @param message  IrisMessage to map the usage to an IrisMessage.
+ * @param exercise To map the token cost to an Exercise.
+ * @param user     To map the token cost to a User.
+ * @param course   To map the token to a Course.
+ * @param tokens   Token cost list of type PyrisLLMCostDTO.
+ * @return List of the saved data.
  */

public List<LLMTokenUsage> saveIrisTokenUsage(PyrisJob job, IrisMessage message, Exercise exercise, User user, Course course, List<PyrisLLMCostDTO> tokens) {
alexjoham marked this conversation as resolved.
List<LLMTokenUsage> tokenUsages = new ArrayList<>();

for (PyrisLLMCostDTO cost : tokens) {
⚠️ Potential issue

Add null check for tokens parameter to prevent NullPointerException

In the saveIrisTokenUsage method, the tokens parameter is not checked for null before use. If tokens is null, this will result in a NullPointerException when attempting to iterate over it. Consider adding a null check to handle this case gracefully.

Apply this diff to add a null check:

 public List<LLMTokenUsage> saveIrisTokenUsage(PyrisJob job, IrisMessage message, Exercise exercise, User user, Course course, List<PyrisLLMCostDTO> tokens) {
+    if (tokens == null || tokens.isEmpty()) {
+        return new ArrayList<>();
+    }
     List<LLMTokenUsage> tokenUsages = new ArrayList<>();
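
The guard can be exercised in isolation. The following is a minimal sketch that uses a simplified stand-in record, `CostEntry`, instead of the real `PyrisLLMCostDTO`, and a hypothetical `describe` helper that is not part of the PR. It shows the same early return for a null or empty list.

```java
import java.util.ArrayList;
import java.util.List;

class NullSafeTokensExample {

    /** Simplified stand-in for PyrisLLMCostDTO; only the fields needed for this sketch. */
    record CostEntry(String modelInfo, int numInputTokens, int numOutputTokens) {
    }

    /** Returns one summary line per entry, or an empty list if tokens is null or empty. */
    static List<String> describe(List<CostEntry> tokens) {
        if (tokens == null || tokens.isEmpty()) {
            return new ArrayList<>();
        }
        List<String> lines = new ArrayList<>();
        for (CostEntry cost : tokens) {
            lines.add(cost.modelInfo() + ": " + cost.numInputTokens() + " in, " + cost.numOutputTokens() + " out");
        }
        return lines;
    }

    public static void main(String[] args) {
        // A null list no longer causes a NullPointerException in the for-each loop.
        System.out.println(describe(null));
        System.out.println(describe(List.of(new CostEntry("gpt-4", 10, 5))));
    }
}
```

Returning an empty list rather than `null` lets callers iterate over the result without a further null check.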

LLMTokenUsage llmTokenUsage = new LLMTokenUsage();
if (message != null) {
llmTokenUsage.setIrisMessage(message);
llmTokenUsage.setTime(message.getSentAt());
}
llmTokenUsage.setServiceType(cost.pipeline());
llmTokenUsage.setExercise(exercise);
if (user != null) {
alexjoham marked this conversation as resolved.
llmTokenUsage.setUserId(user.getId());
}
llmTokenUsage.setCourse(course);
⚠️ Potential issue

Add null check before setting course to prevent potential NullPointerException

The course parameter may be null in some cases, as seen in the overloaded methods where course could be omitted. To ensure robustness, add a null check before setting the course in llmTokenUsage.

Apply this diff to add the null check:

 llmTokenUsage.setServiceType(cost.pipeline());
-llmTokenUsage.setCourse(course);
+if (course != null) {
+    llmTokenUsage.setCourse(course);
+}
 llmTokenUsage.setNumInputTokens(cost.numInputTokens());

llmTokenUsage.setNumInputTokens(cost.numInputTokens());
llmTokenUsage.setCostPerMillionInputTokens(cost.costPerInputToken());
llmTokenUsage.setNumOutputTokens(cost.numOutputTokens());
llmTokenUsage.setCostPerMillionOutputTokens(cost.costPerOutputToken());
llmTokenUsage.setModel(cost.modelInfo());
llmTokenUsage.setTraceId(job.jobId());
tokenUsages.add(llmTokenUsage);
}
llmTokenUsageRepository.saveAll(tokenUsages);
return tokenUsages;
}
alexjoham marked this conversation as resolved.

/**
* Overloaded method to save token usage without message and exercise.
*
* @param job used to create a unique traceId to group multiple LLM calls
* @param user to map the token cost to a user
* @param course to map the token to a course
* @param tokens token cost list of type PyrisLLMCostDTO
* @return list of the saved data
*/
public List<LLMTokenUsage> saveIrisTokenUsage(PyrisJob job, User user, Course course, List<PyrisLLMCostDTO> tokens) {
return saveIrisTokenUsage(job, null, null, user, course, tokens);
}

/**
* Overloaded method to save token usage without exercise.
*
* @param job used to create a unique traceId to group multiple LLM calls
* @param message IrisMessage to map the usage to an IrisMessage
* @param user to map the token cost to a user
* @param course to map the token to a course
* @param tokens token cost list of type PyrisLLMCostDTO
* @return list of the saved data
*/
public List<LLMTokenUsage> saveIrisTokenUsage(PyrisJob job, IrisMessage message, User user, Course course, List<PyrisLLMCostDTO> tokens) {
return saveIrisTokenUsage(job, message, null, user, course, tokens);
}
🛠️ Refactor suggestion

Refactor overloaded methods to reduce code duplication

The overloaded saveIrisTokenUsage methods introduce code duplication by passing null for omitted parameters. This can be refactored to improve readability and maintainability. Consider using optional parameters or redesigning the method signatures to reduce duplication.

For example, you could create a single method that accepts a builder or use method chaining to handle optional parameters more elegantly.
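
A minimal sketch of what such a builder could look like, assuming a hypothetical parameter object `TokenUsageRequest` that stores only IDs. All names here are illustrative, and the sketch does not claim to match the builder that a later commit in this PR introduces.

```java
import java.util.Objects;

/** Hypothetical parameter object; the names are illustrative and not taken from the PR. */
final class TokenUsageRequest {

    private final String traceId;

    private final Long userId; // null when no user is associated

    private final Long courseId; // null when no course is associated

    private final Long exerciseId; // null when no exercise is associated

    private TokenUsageRequest(Builder builder) {
        this.traceId = builder.traceId;
        this.userId = builder.userId;
        this.courseId = builder.courseId;
        this.exerciseId = builder.exerciseId;
    }

    /** Starts a builder; the trace ID is the only required value. */
    static Builder forTrace(String traceId) {
        return new Builder(Objects.requireNonNull(traceId, "traceId"));
    }

    String traceId() {
        return traceId;
    }

    Long userId() {
        return userId;
    }

    Long courseId() {
        return courseId;
    }

    Long exerciseId() {
        return exerciseId;
    }

    static final class Builder {

        private final String traceId;

        private Long userId;

        private Long courseId;

        private Long exerciseId;

        private Builder(String traceId) {
            this.traceId = traceId;
        }

        Builder withUser(Long userId) {
            this.userId = userId;
            return this;
        }

        Builder withCourse(Long courseId) {
            this.courseId = courseId;
            return this;
        }

        Builder withExercise(Long exerciseId) {
            this.exerciseId = exerciseId;
            return this;
        }

        TokenUsageRequest build() {
            return new TokenUsageRequest(this);
        }
    }
}

class TokenUsageRequestExample {

    public static void main(String[] args) {
        // Only the optional values that apply are set; no positional nulls are needed.
        TokenUsageRequest request = TokenUsageRequest.forTrace("job-1").withCourse(42L).build();
        System.out.println(request.traceId() + " course=" + request.courseId() + " user=" + request.userId());
    }
}
```

Compared with overloads that forward `null` arguments, call sites name each optional value they provide, so adding a new optional field does not require another overload.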


/**
* Overloaded method to save token usage without message, exercise and user.
*
* @param job used to create a unique traceId to group multiple LLM calls
* @param course to map the token to a course
* @param tokens token cost list of type PyrisLLMCostDTO
* @return list of the saved data
*/
public List<LLMTokenUsage> saveIrisTokenUsage(PyrisJob job, Course course, List<PyrisLLMCostDTO> tokens) {
return saveIrisTokenUsage(job, null, null, null, course, tokens);
}
🛠️ Refactor suggestion

Refactor Overloaded Methods to Reduce Code Duplication

The overloaded saveIrisTokenUsage methods introduce code duplication by passing null parameters to the primary method. Consider refactoring to handle optional parameters more elegantly, perhaps by using a builder pattern or optional parameters to reduce duplication and improve maintainability.

}
@@ -9,6 +9,7 @@

import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage;
import de.tum.cit.aet.artemis.iris.service.IrisRateLimitService;
+ import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO;

/**
@@ -21,7 +22,7 @@
*/
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public record IrisChatWebsocketDTO(IrisWebsocketMessageType type, IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List<PyrisStageDTO> stages,
- List<String> suggestions) {
+ List<String> suggestions, List<PyrisLLMCostDTO> tokens) {

/**
* Creates a new IrisWebsocketDTO instance with the given parameters
@@ -31,8 +32,9 @@ public record IrisChatWebsocketDTO(IrisWebsocketMessageType type, IrisMessage me
* @param rateLimitInfo the rate limit information
* @param stages the stages of the Pyris pipeline
*/
- public IrisChatWebsocketDTO(@Nullable IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List<PyrisStageDTO> stages, List<String> suggestions) {
- this(determineType(message), message, rateLimitInfo, stages, suggestions);
+ public IrisChatWebsocketDTO(@Nullable IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List<PyrisStageDTO> stages, List<String> suggestions,
+ List<PyrisLLMCostDTO> tokens) {
+ this(determineType(message), message, rateLimitInfo, stages, suggestions, tokens);
alexjoham marked this conversation as resolved.
}

/**