Skip to content

Commit

Permalink
feat: pricing page support multi groups #487
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Sep 22, 2024
1 parent c6ff785 commit ed972ee
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 40 deletions.
11 changes: 2 additions & 9 deletions controller/pricing.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,11 @@ import (
)

func GetPricing(c *gin.Context) {
userId := c.GetInt("id")
// if no login, get default group ratio
groupRatio := common.GetGroupRatio("default")
group, err := model.CacheGetUserGroup(userId)
if err == nil {
groupRatio = common.GetGroupRatio(group)
}
pricing := model.GetPricing(group)
pricing := model.GetPricing()
c.JSON(200, gin.H{
"success": true,
"data": pricing,
"group_ratio": groupRatio,
"group_ratio": common.GroupRatio,
})
}

Expand Down
6 changes: 6 additions & 0 deletions model/ability.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ func GetEnabledModels() []string {
return models
}

func GetAllEnableAbilities() []Ability {
var abilities []Ability
DB.Find(&abilities, "enabled = ?", true)
return abilities
}

func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`"
trueVal := "1"
Expand Down
48 changes: 27 additions & 21 deletions model/pricing.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ import (
)

type Pricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"`
EnableGroup []string `json:"enable_groups,omitempty"`
}

var (
Expand All @@ -23,40 +22,47 @@ var (
updatePricingLock sync.Mutex
)

func GetPricing(group string) []Pricing {
func GetPricing() []Pricing {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()

if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing()
}
if group != "" {
userPricingMap := make([]Pricing, 0)
models := GetGroupModels(group)
for _, pricing := range pricingMap {
if !common.StringsContains(models, pricing.ModelName) {
pricing.Available = false
}
userPricingMap = append(userPricingMap, pricing)
}
return userPricingMap
}
//if group != "" {
// userPricingMap := make([]Pricing, 0)
// models := GetGroupModels(group)
// for _, pricing := range pricingMap {
// if !common.StringsContains(models, pricing.ModelName) {
// pricing.Available = false
// }
// userPricingMap = append(userPricingMap, pricing)
// }
// return userPricingMap
//}
return pricingMap
}

func updatePricing() {
//modelRatios := common.GetModelRatios()
enabledModels := GetEnabledModels()
allModels := make(map[string]int)
for i, model := range enabledModels {
allModels[model] = i
enableAbilities := GetAllEnableAbilities()
modelGroupsMap := make(map[string][]string)
for _, ability := range enableAbilities {
groups := modelGroupsMap[ability.Model]
if groups == nil {
groups = make([]string, 0)
}
if !common.StringsContains(groups, ability.Group) {
groups = append(groups, ability.Group)
}
modelGroupsMap[ability.Model] = groups
}

pricingMap = make([]Pricing, 0)
for model, _ := range allModels {
for model, groups := range modelGroupsMap {
pricing := Pricing{
Available: true,
ModelName: model,
ModelName: model,
EnableGroup: groups,
}
modelPrice, findPrice := common.GetModelPrice(model, false)
if findPrice {
Expand Down
2 changes: 1 addition & 1 deletion web/src/components/HeaderBar.js
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ let buttons = [
text: '首页',
itemKey: 'home',
to: '/',
icon: <IconHomeStroked />,
// icon: <IconHomeStroked />,
},
// {
// text: '模型价格',
Expand Down
60 changes: 51 additions & 9 deletions web/src/components/ModelPricing.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import React, { useContext, useEffect, useRef, useMemo, useState } from 'react';
import { API, copy, showError, showSuccess } from '../helpers';
import { API, copy, showError, showInfo, showSuccess } from '../helpers';

import {
Banner,
Expand Down Expand Up @@ -87,6 +87,7 @@ const ModelPricing = () => {
const [selectedRowKeys, setSelectedRowKeys] = useState([]);
const [modalImageUrl, setModalImageUrl] = useState('');
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
const [selectedGroup, setSelectedGroup] = useState('default');

const rowSelection = useMemo(
() => ({
Expand Down Expand Up @@ -120,7 +121,8 @@ const ModelPricing = () => {
title: '可用性',
dataIndex: 'available',
render: (text, record, index) => {
return renderAvailable(text);
// if record.enable_groups contains selectedGroup, then available is true
return renderAvailable(record.enable_groups.includes(selectedGroup));
},
sorter: (a, b) => a.available - b.available,
},
Expand Down Expand Up @@ -166,6 +168,43 @@ const ModelPricing = () => {
},
sorter: (a, b) => a.quota_type - b.quota_type,
},
{
title: '可用分组',
dataIndex: 'enable_groups',
render: (text, record, index) => {
// enable_groups is a string array
return (
<Space>
{text.map((group) => {
if (group === selectedGroup) {
return (
<Tag
color='blue'
size='large'
prefixIcon={<IconVerify />}
>
{group}
</Tag>
);
} else {
return (
<Tag
color='blue'
size='large'
onClick={() => {
setSelectedGroup(group);
showInfo('当前查看的分组为:' + group + ',倍率为:' + groupRatio[group]);
}}
>
{group}
</Tag>
);
}
})}
</Space>
);
},
},
{
title: () => (
<span style={{'display':'flex','alignItems':'center'}}>
Expand Down Expand Up @@ -201,6 +240,8 @@ const ModelPricing = () => {
<Text>模型:{record.quota_type === 0 ? text : '无'}</Text>
<br />
<Text>补全:{record.quota_type === 0 ? completionRatio : '无'}</Text>
<br />
<Text>分组:{groupRatio[selectedGroup]}</Text>
</>
);
return <div>{content}</div>;
Expand All @@ -213,11 +254,11 @@ const ModelPricing = () => {
let content = text;
if (record.quota_type === 0) {
// 这里的 *2 是因为 1倍率=0.002刀,请勿删除
let inputRatioPrice = record.model_ratio * 2 * record.group_ratio;
let inputRatioPrice = record.model_ratio * 2 * groupRatio[selectedGroup];
let completionRatioPrice =
record.model_ratio *
record.completion_ratio * 2 *
record.group_ratio;
groupRatio[selectedGroup];
content = (
<>
<Text>提示 ${inputRatioPrice} / 1M tokens</Text>
Expand All @@ -226,7 +267,7 @@ const ModelPricing = () => {
</>
);
} else {
let price = parseFloat(text) * record.group_ratio;
let price = parseFloat(text) * groupRatio[selectedGroup];
content = <>模型价格:${price}</>;
}
return <div>{content}</div>;
Expand All @@ -237,12 +278,12 @@ const ModelPricing = () => {
const [models, setModels] = useState([]);
const [loading, setLoading] = useState(true);
const [userState, userDispatch] = useContext(UserContext);
const [groupRatio, setGroupRatio] = useState(1);
const [groupRatio, setGroupRatio] = useState({});

const setModelsFormat = (models, groupRatio) => {
for (let i = 0; i < models.length; i++) {
models[i].key = models[i].model_name;
models[i].group_ratio = groupRatio;
models[i].group_ratio = groupRatio[models[i].model_name];
}
// sort by quota_type
models.sort((a, b) => {
Expand Down Expand Up @@ -275,6 +316,7 @@ const ModelPricing = () => {
const { success, message, data, group_ratio } = res.data;
if (success) {
setGroupRatio(group_ratio);
setSelectedGroup(userState.user ? userState.user.group : 'default')
setModelsFormat(data, group_ratio);
} else {
showError(message);
Expand Down Expand Up @@ -307,14 +349,14 @@ const ModelPricing = () => {
type="success"
fullMode={false}
closeIcon="null"
description={`您的分组为${userState.user.group},分组倍率为:${groupRatio}`}
description={`您的默认分组为${userState.user.group},分组倍率为:${groupRatio[userState.user.group]}`}
/>
) : (
<Banner
type='warning'
fullMode={false}
closeIcon="null"
description={`您还未登陆,显示的价格为默认分组倍率: ${groupRatio}`}
description={`您还未登陆,显示的价格为默认分组倍率: ${groupRatio['default']}`}
/>
)}
<br/>
Expand Down

0 comments on commit ed972ee

Please sign in to comment.