Skip to content

Commit

Permalink
Fix/download (#35)
Browse files Browse the repository at this point in the history
* fix: copy bind from models index

* fix(download): settings

* fix: better regex

* fix: version check
  • Loading branch information
ido-pluto authored Sep 2, 2023
1 parent e396a8e commit dcdec37
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 45 deletions.
1 change: 1 addition & 0 deletions server/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions server/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
"ora": "^7.0.1",
"progress-stream": "^2.0.0",
"prompts": "^2.4.2",
"semver": "^7.5.4",
"sirv": "^2.0.2",
"uuid": "^9.0.0",
"wretch": "^2.6.0",
Expand Down
9 changes: 6 additions & 3 deletions server/src/cli/commands/install.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ installCommand.description('Install any GGML/GGUF model')
download: model,
tag,
latest,
settings: {
bindClass: bind,
apiKey: key,
model: {
settings: {
bind,
key,
}

}
});
await installer.startDownload();
Expand Down
24 changes: 17 additions & 7 deletions server/src/cli/commands/postinstall/migration/migtation.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
import {$} from 'execa';
import {packageJSON} from '../../../../storage/config.js';
import semver from 'semver';

import v0313 from './versions/v0.3.13.js';
import v209 from './versions/v2.0.9.js';

const MIGRATIONS = [v0313, v209];

export async function runMigrations() {
const {stdout: oldVersion} = await $`npm -g info catai version`;
const fromVersion = Number(oldVersion.trim()[0] || '-');
const toVersion = Number(packageJSON.version.trim()[0]);

for (let i = fromVersion; i < toVersion; i++) {
try {
const migration = await import(`./versions/v${i}.js`);
await migration.default();
} catch {}
const fromVersion = oldVersion.trim();
const toVersion = packageJSON.version.trim();

if (!fromVersion) return;

for (const migration of MIGRATIONS) {
if (semver.lte(toVersion, migration.version)) continue;

if (semver.gte(migration.version, fromVersion)) {
console.log(`CatAI Migrated to v${migration.version}`);
await migration.migration();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import AppDb from '../../../../../storage/app-db.js';
const EXECUTABLE_DIR = path.join(ENV_CONFIG.CATAI_DIR!, 'executable');
const MIN_SIZE_BYTES = 10485760 * 2; // 20mb

export default async function migrationV0() {
async function migration() {
const models = await fs.readdir(ENV_CONFIG.MODEL_DIR!);

for (const model of models) {
Expand All @@ -21,13 +21,18 @@ export default async function migrationV0() {
},
bindClass: 'node-llama',
createDate: stat.birthtime.getTime(),
compatibleCatAIVersionRange: ['0.3.0', '0.3.12'],
compatibleCatAIVersionRange: ['0.3.0', '0.3.13'],
version: 0,
settings: {},
defaultSettings: {},
};
} as any;
}

await AppDb.saveDB();
await fs.remove(EXECUTABLE_DIR);
}

export default {
version: '0.3.13',
migration
};
15 changes: 15 additions & 0 deletions server/src/cli/commands/postinstall/migration/versions/v2.0.9.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import AppDB from '../../../../../storage/app-db.js';

async function migration() {
for (const [, value] of Object.entries(AppDB.db.models)) {
value.settings ??= {} as any;
value.settings.bind ??= (value as any).bindClass;
value.defaultSettings.bind ??= (value as any).bindClass;
}
await AppDB.saveDB();
}

export default {
version: '2.0.9',
migration
};
19 changes: 11 additions & 8 deletions server/src/manage-models/about-models/fetch-models/fetch-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import fs from 'fs-extra';
import {pathToFileURL} from 'url';
import findBestModelBinding from '../best-model-binding.js';
import ConnectChunksProgress from './connect-chunks-progress.js';
import objectAssignDeep from 'object-assign-deep';

export type DetailedDownloadInfo = {
files: {
Expand All @@ -21,7 +22,7 @@ export type FetchOptions = {
download: string | string[] | DetailedDownloadInfo
tag?: string;
latest?: boolean;
settings?: Partial<ModelSettings<any>>
model?: Partial<ModelSettings<any>>
}

type RemoteFetchModels = {
Expand Down Expand Up @@ -70,7 +71,7 @@ export default class FetchModels {
if (!foundModel)
return this._setDetailedLocalModel();

const {download: modelDownloadDetails, ...settings} = models[foundModel!];
const {download: modelDownloadDetails, ...model} = models[foundModel!];

const branch = this.options.latest ? modelDownloadDetails.branch : modelDownloadDetails.commit;
const downloadLinks = Object.fromEntries(
Expand All @@ -81,8 +82,7 @@ export default class FetchModels {
);

this.options.tag = foundModel;
this.options.settings = {...settings, ...this.options.settings};

this.options.model = objectAssignDeep(this.options.model ?? {}, model);
this._downloadFiles = downloadLinks;
}

Expand Down Expand Up @@ -127,12 +127,15 @@ export default class FetchModels {
downloadedFiles[type] = savePath;
}

const settings = this.options.model?.settings ?? {};
settings.bind ??= findBestModelBinding(downloadedFiles);

AppDb.db.models[this.options.tag!] = {
...this.options.settings,
...this.options.model,
downloadedFiles,
defaultSettings: this.options.settings?.settings ?? {},
settings,
defaultSettings: settings,
createDate: Date.now(),
bindClass: this.options.settings?.bindClass ?? findBestModelBinding(downloadedFiles),
};
await AppDb.saveDB();
}
Expand Down Expand Up @@ -167,6 +170,6 @@ export default class FetchModels {
}

private static _findModelTag(modelPath: string) {
return modelPath.split(/\/|\\/).pop()!;
return modelPath.split(/[\/\\]/).pop()!;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import chalk from 'chalk';
import * as os from 'os';
import FetchModels, {DEFAULT_VERSION} from './fetch-models/fetch-models.js';
import AppDb, {ModelSettings} from '../../storage/app-db.js';
import {calculateVersion} from '../../utils/check-for-update.js';
import {packageJSON} from '../../storage/config.js';
import semver from 'semver';

const GB_IN_BYTES = 1024 * 1024 * 1024;

Expand All @@ -29,21 +29,22 @@ class ModelCompatibilityChecker {
private static readonly availableMemory: number = os.freemem() / GB_IN_BYTES;

public static checkModelCompatibility({hardwareCompatibility, compatibleCatAIVersionRange}: ModelSettings<any>): Compatibility {
if(!compatibleCatAIVersionRange?.[0]){
if (!compatibleCatAIVersionRange?.[0]) {
return {
compatibility: '?',
note: 'Model unknown'
};
}

if(calculateVersion(compatibleCatAIVersionRange[0]) > calculateVersion(packageJSON.version)) {

if (semver.gt(compatibleCatAIVersionRange[0], packageJSON.version)) {
return {
compatibility: '❌',
note: `requires at least CatAI version ${chalk.cyan(compatibleCatAIVersionRange[0])}`
};
}

if(compatibleCatAIVersionRange[1] && calculateVersion(compatibleCatAIVersionRange[1]) < calculateVersion(packageJSON.version)) {
if (compatibleCatAIVersionRange[1] && semver.lt(compatibleCatAIVersionRange[1], packageJSON.version)) {
return {
compatibility: '❌',
note: `requires CatAI version ${chalk.cyan(compatibleCatAIVersionRange[1])} or lower`
Expand Down
16 changes: 10 additions & 6 deletions server/src/manage-models/bind-class/bind-class.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function getActiveModelDetails() {
if (!modelDetails)
throw new Error('No active model');

if(!modelDetails.bindClass)
if(!modelDetails.settings.bind)
throw new Error('No bind class');

return modelDetails;
Expand All @@ -21,8 +21,10 @@ function getActiveModelDetails() {
export function getCacheBindClass(){
const modelDetails = getActiveModelDetails();

if(cachedBinds[modelDetails.bindClass])
return cachedBinds[modelDetails.bindClass];
const bind = modelDetails.settings.bind;

if(cachedBinds[bind])
return cachedBinds[bind];

return null;
}
Expand All @@ -33,11 +35,13 @@ export default async function createChat(){
if(cachedBindClass)
return cachedBindClass.createChat();

const bindClass = ALL_BINDS.find(x => x.shortName === modelDetails.bindClass);
const bind = modelDetails.settings.bind;

const bindClass = ALL_BINDS.find(x => x.shortName === bind);
if (!bindClass)
throw new Error(`Bind class ${modelDetails.bindClass} not found`);
throw new Error(`Bind class ${bind} not found`);

const bindClassInstance = cachedBinds[modelDetails.bindClass] ??= new bindClass(modelDetails);
const bindClassInstance = cachedBinds[bind] ??= new bindClass(modelDetails);
await bindClassInstance.initialize();

return bindClassInstance.createChat();
Expand Down
11 changes: 7 additions & 4 deletions server/src/storage/app-db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ import fs from 'fs-extra';
import path from 'path';
import ENV_CONFIG from './config.js';

export type ModelInnerSettings<T> = T &{
bind: string;
key?: string;
}

export type ModelSettings<T> = {
bindClass: string;
apiKey?: string;
downloadedFiles: {
[fileId: string]: string
},
Expand All @@ -15,8 +18,8 @@ export type ModelSettings<T> = {
"cpuCors": number,
"compressions": number
}
settings?: T;
defaultSettings: T,
settings: ModelInnerSettings<T>;
defaultSettings: ModelInnerSettings<T>,
createDate: number
};

Expand Down
15 changes: 5 additions & 10 deletions server/src/utils/check-for-update.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
import wretch from "wretch";
import {packageJSON} from "../storage/config.js";
import chalk from "chalk";


export function calculateVersion(version: string) {
const [major, minor, patch] = version.split('.').map(x => parseInt(x));
return major * 10000 + minor * 100 + patch;
}
import wretch from 'wretch';
import {packageJSON} from '../storage/config.js';
import chalk from 'chalk';
import semver from 'semver';

async function checkForUpdate() {
const npmPackage: any = await wretch(`https://registry.npmjs.com/${packageJSON.name}`).get().json();
const latestVersion = npmPackage['dist-tags'].latest;

if (calculateVersion(packageJSON.version) >= calculateVersion(latestVersion))
if (semver.gte(packageJSON.version, latestVersion))
return;

console.log(`\n${chalk.green('New version available!')}, some models may not work in older versions`);
Expand Down

0 comments on commit dcdec37

Please sign in to comment.