diff --git a/.github/workflows/on-pull-req.yaml b/.github/workflows/on-pull-req.yaml index 153f40b7..e6874174 100644 --- a/.github/workflows/on-pull-req.yaml +++ b/.github/workflows/on-pull-req.yaml @@ -3,7 +3,7 @@ on: pull_request: branches: [main] jobs: - prettier: + check: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/on-push-server.yaml b/.github/workflows/on-push-server.yaml deleted file mode 100644 index 1d740cf1..00000000 --- a/.github/workflows/on-push-server.yaml +++ /dev/null @@ -1,32 +0,0 @@ -name: On Push Server -on: - push: {branches: [server]} -jobs: - push-docker-image: - runs-on: ubuntu-latest - env: - DOCKER_CLI_EXPERIMENTAL: enabled - steps: - - uses: actions/checkout@v4 - - name: metadata - id: metadata - run: | - git fetch --all --tags - TAG=$(git describe --tags) - echo "tag=${TAG/-*}" >> $GITHUB_OUTPUT - - uses: docker/setup-qemu-action@v3 - - uses: docker/setup-buildx-action@v3 - with: - install: true - - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: Build & push prem - run: >- - docker buildx build --push - --file Dockerfile - --tag ghcr.io/premai-io/prem-app:latest - --tag ghcr.io/premai-io/prem-app:${{ steps.metadata.outputs.tag }} - --platform linux/arm64,linux/amd64 . diff --git a/.github/workflows/on-workflow-dispatch-docker.yaml b/.github/workflows/on-workflow-dispatch-docker.yaml new file mode 100644 index 00000000..01d5002d --- /dev/null +++ b/.github/workflows/on-workflow-dispatch-docker.yaml @@ -0,0 +1,52 @@ +name: 🚀 Docker Image + +on: + workflow_dispatch: + inputs: + version: + description: 'Version tag for the Docker image (optional, will use latest Git tag if empty)' + required: false + type: string + tag_as_latest: + description: 'Also tag as latest?' + required: false + default: false + type: boolean + +jobs: + push-docker-image: + runs-on: ubuntu-latest + env: + DOCKER_CLI_EXPERIMENTAL: enabled + steps: + - uses: actions/checkout@v4 + - name: Determine tag + id: tag + run: | + if [ -z "${{ github.event.inputs.version }}" ]; then + git fetch --all --tags + TAG=$(git describe --tags `git rev-list --tags --max-count=1`) + else + TAG=${{ github.event.inputs.version }} + fi + echo "VERSION_TAG=$TAG" >> $GITHUB_ENV + - uses: docker/setup-qemu-action@v3 + - uses: docker/setup-buildx-action@v3 + with: + install: true + - uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Build & push + run: | + docker buildx build --push \ + --file Dockerfile \ + --tag ghcr.io/premai-io/prem-app:$VERSION_TAG \ + ${{ github.event.inputs.tag_as_latest == 'true' && '--tag ghcr.io/premai-io/prem-app:latest' || '' }} \ + --platform linux/arm64,linux/amd64 . + shell: /usr/bin/bash -e {0} + env: + DOCKER_CLI_EXPERIMENTAL: enabled + diff --git a/.github/workflows/on-workflow-dispatch-tauri.yaml b/.github/workflows/on-workflow-dispatch-tauri.yaml new file mode 100644 index 00000000..91cd07aa --- /dev/null +++ b/.github/workflows/on-workflow-dispatch-tauri.yaml @@ -0,0 +1,132 @@ +name: 🚀 Tauri Dekstop App + +on: + workflow_dispatch: + inputs: + branchName: + description: 'Branch Name you are releasing from' + required: true + version: + description: 'Version tag for the Github Release and the .dmg for MacOS' + required: true + release_as_draft: + description: 'Release as Draft' + required: false + default: true + type: boolean + +jobs: + publish-tauri: + permissions: write-all + strategy: + fail-fast: false + matrix: + platform: [macos-latest] + + runs-on: ${{ matrix.platform }} + + steps: + + - name: View branch name + run: | + echo "Branch name: ${{ github.event.inputs.branchName }}" + + - name: View version + run: | + echo "Tag: ${{ github.event.inputs.version }}" + + - name: Version as Number + id: next_version + run: | + tag=${{ github.event.inputs.version }} + echo "version=${tag:1}" >> $GITHUB_OUTPUT + + - name: Checkout code + uses: actions/checkout@v3 + + - name: Rust setup + uses: dtolnay/rust-toolchain@stable + + - name: install dependencies (ubuntu only) + if: matrix.platform == 'ubuntu-20.04' + run: | + sudo apt-get update + sudo apt-get install -y libgtk-3-dev libwebkit2gtk-4.0-dev libappindicator3-dev librsvg2-dev patchelf + + - name: Install missing Rust target for universal Mac build + if: matrix.platform == 'macos-latest' + run: rustup target add aarch64-apple-darwin + + - name: Rust cache + uses: swatinem/rust-cache@v2 + with: + workspaces: "./src-tauri -> target" + + - name: Sync node version and setup cache + uses: actions/setup-node@v3 + with: + node-version: "lts/*" + cache: "npm" + + - name: Install frontend dependencies + run: npm install + + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: ^1.20.1 + id: go + + - name: Install dasel + run: | + go install github.com/tomwright/dasel/cmd/dasel@latest + + + - name: Update with latest branch + run: | + git config --local user.email "$(git log --format='%ae' HEAD^!)" + git config --local user.name "$(git log --format='%an' HEAD^!)" + git config pull.rebase true + git stash + git fetch origin ${{ github.event.inputs.branchName }} + git pull origin ${{ github.event.inputs.branchName }} + git stash pop || true + + - name: Increment version + run: | + dasel put string -f package.json ".version" "${{ steps.next_version.outputs.version }}" + dasel put string -f src-tauri/tauri.conf.json ".package.version" "${{ steps.next_version.outputs.version }}" + dasel put string -f src-tauri/Cargo.toml ".package.version" "${{ steps.next_version.outputs.version }}" + + - name: Build the app + uses: tauri-apps/tauri-action@dev + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TAURI_PRIVATE_KEY: ${{ secrets.TAURI_PRIVATE_KEY }} + ENABLE_CODE_SIGNING: ${{ secrets.APPLE_CERTIFICATE }} + APPLE_CERTIFICATE: ${{ secrets.APPLE_CERTIFICATE }} + APPLE_CERTIFICATE_PASSWORD: ${{ secrets.APPLE_CERTIFICATE_PASSWORD }} + APPLE_SIGNING_IDENTITY: ${{ secrets.APPLE_SIGNING_IDENTITY }} + APPLE_ID: ${{ secrets.APPLE_ID }} + APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }} + APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} + with: + tagName: ${{ github.event.inputs.version }} + releaseName: "PremAI App ${{ github.event.inputs.version }}" + releaseBody: "See the assets to download and install this version." + releaseDraft: ${{ github.event.inputs.release_as_draft }} + includeDebug: true + updaterJsonKeepUniversal: true + args: ${{matrix.platform == 'ubuntu-20.04' && '--target x86_64-unknown-linux-gnu' || '--target universal-apple-darwin'}} + + + # Commit package.json, tauri.conf.json and Cargo.toml to master + - name: Commit & Push + continue-on-error: true + run: | + git add . + git commit -m "${{ github.event.inputs.version }}" + git push origin HEAD:${{ github.event.inputs.branchName }} + + diff --git a/docker-compose.gateway.yml b/docker-compose.gateway.yml new file mode 100644 index 00000000..6fedfd79 --- /dev/null +++ b/docker-compose.gateway.yml @@ -0,0 +1,114 @@ +version: '3.7' +services: + + premapp: + container_name: premapp + build: . + environment: + - VITE_DESTINATION=browser + - VITE_IS_PACKAGED=true + - VITE_PROXY_ENABLED=true + labels: + - "traefik.enable=true" + - "traefik.http.routers.premapp-http.rule=PathPrefix(`/`)" + - "traefik.http.routers.premapp-http.entrypoints=web" + - "traefik.http.services.premapp.loadbalancer.server.port=8080" + ports: + - "8085:8080" + restart: unless-stopped + + premd: + container_name: premd + image: ghcr.io/premai-io/premd:latest + volumes: + - /var/run/docker.sock:/var/run/docker.sock + environment: + - PREM_REGISTRY_URL=https://raw.githubusercontent.com/premAI-io/prem-registry/main/manifests.json + - PROXY_ENABLED=True + labels: + - "traefik.enable=true" + - "traefik.http.routers.premd.rule=PathPrefix(`/premd`)" + - "traefik.http.middlewares.premd-strip-prefix.stripprefix.prefixes=/premd" + - "traefik.http.routers.premd.middlewares=premd-strip-prefix" + ports: + - "8084:8000" + restart: unless-stopped + + + + traefik: + container_name: traefik + image: traefik:v2.4 + command: + - "--providers.docker=true" + - "--providers.docker.exposedbydefault=false" + - "--accesslog=true" + - "--ping" + - "--entrypoints.web.address=:80" + ports: + - "80:80" + - "8080:8080" + - "443:443" + volumes: + - /var/run/docker.sock:/var/run/docker.sock + - traefik-letsencrypt:/letsencrypt + depends_on: + - dnsd + restart: unless-stopped + + dnsd: + container_name: dnsd + image: ghcr.io/premai-io/dnsd:latest + labels: + - "traefik.enable=true" + - "traefik.http.routers.dnsd.rule=PathPrefix(`/dnsd`)" + - "traefik.http.middlewares.dnsd-strip-prefix.stripprefix.prefixes=/dnsd" + - "traefik.http.routers.dnsd.middlewares=dnsd-strip-prefix" + depends_on: + - dnsd-db-pg + - authd + environment: + PREM_GATEWAY_DNS_DB_USER: root + PREM_GATEWAY_DNS_DB_PASS: secret + PREM_GATEWAY_DNS_DB_NAME: dnsd-db + PREM_GATEWAY_DNS_DB_HOST: dnsd-db-pg + ports: + - "8082:8080" + restart: unless-stopped + + dnsd-db-pg: + container_name: dnsd-db-pg + image: postgres:14.7 + ports: + - "5432:5432" + environment: + POSTGRES_USER: root + POSTGRES_PASSWORD: secret + POSTGRES_DB: dnsd-db + volumes: + - dnsd-pg-data:/var/lib/postgresql/data + restart: unless-stopped + + authd: + container_name: authd + image: ghcr.io/premai-io/authd:latest + ports: + - "8081:8080" + restart: unless-stopped + + controllerd: + container_name: controllerd + image: ghcr.io/premai-io/controllerd:latest + ports: + - "8083:8080" + volumes: + - /var/run/docker.sock:/var/run/docker.sock + user: root + environment: + LETSENCRYPT_PROD: false + SERVICES: premd,premapp + restart: unless-stopped + +volumes: + dnsd-pg-data: + traefik-letsencrypt: \ No newline at end of file diff --git a/latest.json b/latest.json index 14086336..a31a3696 100644 --- a/latest.json +++ b/latest.json @@ -1,11 +1,19 @@ { - "version": "0.1.2", + "version": "0.2.0", "notes": "See the assets to download and install this version.", - "pub_date": "2023-10-26T17:02:17.290Z", + "pub_date": "2023-11-08T12:51:18.968Z", "platforms": { - "linux-x86_64": { - "signature": "dW50cnVzdGVkIGNvbW1lbnQ6IHNpZ25hdHVyZSBmcm9tIHRhdXJpIHNlY3JldCBrZXkKUlVTck91dEJzS1kvNnhxSDVuUGVZQ3diN0owNUd1YWd0UldlVmRTWjJsNWFWVHZMbFNxNDE1M2k2Qnd5aEMyNUJJZFlqbFc5VEdQaDlKMzFRejlMc2ZhM25Eb2ozbTJtWXdRPQp0cnVzdGVkIGNvbW1lbnQ6IHRpbWVzdGFtcDoxNjk4MzM5MjA0CWZpbGU6cHJlbV8wLjEuMl9hbWQ2NC5BcHBJbWFnZS50YXIuZ3oKWkt6RE5sc3F3Y0dSS1dGdGx2dzR6N3JrVzdkODliQ1l5QVkyTnJSbzZpdkNLSE1Xa08zeXNQaVNTTkpleWk1dkQvdHdLV2wvVmxaalovZHRqSXR3REE9PQo=", - "url": "https://github.com/premAI-io/prem-app/releases/download/v0.1.2/prem_0.1.2_amd64.AppImage.tar.gz" + "darwin-aarch64": { + "signature": "dW50cnVzdGVkIGNvbW1lbnQ6IHNpZ25hdHVyZSBmcm9tIHRhdXJpIHNlY3JldCBrZXkKUlVTck91dEJzS1kvNjJKRnU2c200a0Y2cy9TRnF6N3M4UmZrU0pRYnVCNjRQUlJLZUpacnpoekM1akl0K1l3VFBxT0hZand1SU1wY0dOb3NRUDEzMHdSeU00S1lLU2w0OUFRPQp0cnVzdGVkIGNvbW1lbnQ6IHRpbWVzdGFtcDoxNjk5NDQ3NzI0CWZpbGU6UHJlbS5hcHAudGFyLmd6CkdqTU8yNHRObmNublhWZGVVVlZZdXhYNTF5OG1TdE9BdEdYcWRyZnZGZU9sRWpOMkd2YldCWkRhdmtWeFZTNWNyMkFEVEFSSGUzOWhMeTRUQTFUOURBPT0K", + "url": "https://github.com/premAI-io/prem-app/releases/download/v0.2.0/Prem_universal.app.tar.gz" + }, + "darwin-x86_64": { + "signature": "dW50cnVzdGVkIGNvbW1lbnQ6IHNpZ25hdHVyZSBmcm9tIHRhdXJpIHNlY3JldCBrZXkKUlVTck91dEJzS1kvNjJKRnU2c200a0Y2cy9TRnF6N3M4UmZrU0pRYnVCNjRQUlJLZUpacnpoekM1akl0K1l3VFBxT0hZand1SU1wY0dOb3NRUDEzMHdSeU00S1lLU2w0OUFRPQp0cnVzdGVkIGNvbW1lbnQ6IHRpbWVzdGFtcDoxNjk5NDQ3NzI0CWZpbGU6UHJlbS5hcHAudGFyLmd6CkdqTU8yNHRObmNublhWZGVVVlZZdXhYNTF5OG1TdE9BdEdYcWRyZnZGZU9sRWpOMkd2YldCWkRhdmtWeFZTNWNyMkFEVEFSSGUzOWhMeTRUQTFUOURBPT0K", + "url": "https://github.com/premAI-io/prem-app/releases/download/v0.2.0/Prem_universal.app.tar.gz" + }, + "darwin-universal": { + "signature": "dW50cnVzdGVkIGNvbW1lbnQ6IHNpZ25hdHVyZSBmcm9tIHRhdXJpIHNlY3JldCBrZXkKUlVTck91dEJzS1kvNjJKRnU2c200a0Y2cy9TRnF6N3M4UmZrU0pRYnVCNjRQUlJLZUpacnpoekM1akl0K1l3VFBxT0hZand1SU1wY0dOb3NRUDEzMHdSeU00S1lLU2w0OUFRPQp0cnVzdGVkIGNvbW1lbnQ6IHRpbWVzdGFtcDoxNjk5NDQ3NzI0CWZpbGU6UHJlbS5hcHAudGFyLmd6CkdqTU8yNHRObmNublhWZGVVVlZZdXhYNTF5OG1TdE9BdEdYcWRyZnZGZU9sRWpOMkd2YldCWkRhdmtWeFZTNWNyMkFEVEFSSGUzOWhMeTRUQTFUOURBPT0K", + "url": "https://github.com/premAI-io/prem-app/releases/download/v0.2.0/Prem_universal.app.tar.gz" } } } \ No newline at end of file diff --git a/package-lock.json b/package-lock.json index bf4114d7..130aaf4d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "prem-app", - "version": "0.1.2", + "version": "0.2.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "prem-app", - "version": "0.1.2", + "version": "0.2.0", "dependencies": { "@microsoft/fetch-event-source": "^2.0.1", "@tanstack/react-query": "^5.4.3", diff --git a/package.json b/package.json index 266216f0..1237fac8 100644 --- a/package.json +++ b/package.json @@ -74,5 +74,5 @@ "tauri": "tauri" }, "type": "module", - "version": "0.1.2" + "version": "0.2.0" } diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index e7f8eb13..b25a45c0 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -457,6 +457,30 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + [[package]] name = "crossbeam-utils" version = "0.8.16" @@ -513,6 +537,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ctrlc" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82e95fbd621905b854affdc67943b043a0fbb6ed7385fd5a25650d19a8a6cfdf" +dependencies = [ + "nix", + "windows-sys 0.48.0", +] + [[package]] name = "darling" version = "0.20.3" @@ -639,6 +673,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + [[package]] name = "embed-resource" version = "2.4.0" @@ -1975,6 +2015,15 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -2447,9 +2496,10 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" [[package]] name = "prem-app" -version = "0.1.2" +version = "0.2.0" dependencies = [ "chrono", + "ctrlc", "futures", "log", "pretty_env_logger", @@ -2458,8 +2508,10 @@ dependencies = [ "serde", "serde_json", "sys-info", + "sysinfo", "tauri", "tauri-build", + "tauri-plugin-store", "thiserror", "tokio", ] @@ -2647,6 +2699,26 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +[[package]] +name = "rayon" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -3402,6 +3474,21 @@ dependencies = [ "libc", ] +[[package]] +name = "sysinfo" +version = "0.29.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a18d114d420ada3a891e6bc8e96a2023402203296a47cdd65083377dad18ba5" +dependencies = [ + "cfg-if", + "core-foundation-sys", + "libc", + "ntapi", + "once_cell", + "rayon", + "winapi", +] + [[package]] name = "system-configuration" version = "0.5.1" @@ -3641,6 +3728,18 @@ dependencies = [ "tauri-utils", ] +[[package]] +name = "tauri-plugin-store" +version = "0.0.0" +source = "git+https://github.com/tauri-apps/plugins-workspace?branch=v1#8d6045421a553330e9da8b9e1e4405d419c5ea88" +dependencies = [ + "log", + "serde", + "serde_json", + "tauri", + "thiserror", +] + [[package]] name = "tauri-runtime" version = "0.14.1" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 677800c2..adbc8daf 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -7,12 +7,15 @@ [dependencies] chrono = "0.4.31" + ctrlc = "3.4.1" log = "0.4.20" pretty_env_logger = "0.5.0" sentry-tauri = "0.2" serde_json = "1.0" sys-info = "0.9.1" + sysinfo = "0.29.10" thiserror = "1.0.49" + tauri-plugin-store = { git = "https://github.com/tauri-apps/plugins-workspace", branch = "v1" } [dependencies.futures] default-features = false @@ -44,4 +47,4 @@ license = "" name = "prem-app" repository = "" - version = "0.1.2" + version = "0.2.0" diff --git a/src-tauri/petals/create_env.sh b/src-tauri/petals/create_env.sh index b9bc1738..f4b2be0f 100755 --- a/src-tauri/petals/create_env.sh +++ b/src-tauri/petals/create_env.sh @@ -1,6 +1,5 @@ #!/usr/bin/env bash prem_appdir="${PREM_APPDIR:-.}" -requirements="$(dirname "$0")/requirements.txt" conda_prefix="$prem_appdir/miniconda" if test ! -x "$conda_prefix/bin/mamba"; then @@ -27,7 +26,4 @@ echo "Using Conda at '$conda_prefix'" export PREM_PYTHON="$conda_prefix/envs/prem_app/bin/python" echo "Ensuring env 'prem_app' exists" -test -x "$PREM_PYTHON" || "$conda_prefix/bin/mamba" create -y -n prem_app python=3.11 - -echo "Installing requirements" -"$PREM_PYTHON" -m pip install -r "$requirements" +test -x "$PREM_PYTHON" || "$conda_prefix/bin/mamba" create -y -n prem_app python=3.11 \ No newline at end of file diff --git a/src-tauri/petals/run_swarm.sh b/src-tauri/petals/run_swarm.sh new file mode 100755 index 00000000..b44d183d --- /dev/null +++ b/src-tauri/petals/run_swarm.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +# Usage: ./run_swarm.sh +# +# Parameters: +# num_blocks: The number of blocks to use. +# public_name: The public name to use. +# model: The model to use. + +prem_python="${PREM_PYTHON:-python}" +requirements="$(dirname "$0")/requirements.txt" + +echo "Installing requirements" +"$prem_python" -m pip install -r "$requirements" + +echo "Running petals" +"$prem_python" -m petals.cli.run_server \ + --num_blocks "$1" \ + --public_name "$2" \ + "$3" \ No newline at end of file diff --git a/src-tauri/src/controller_binaries.rs b/src-tauri/src/controller_binaries.rs index 8c6eecc5..3c427273 100644 --- a/src-tauri/src/controller_binaries.rs +++ b/src-tauri/src/controller_binaries.rs @@ -1,18 +1,30 @@ // Prevents additional console window on Windows in release, DO NOT REMOVE!! #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] -use crate::download::Downloader; -use crate::errors::{Context, Result}; -use crate::{Service, SharedState}; -use std::collections::HashMap; +use crate::{ + download::Downloader, + err, + errors::{Context, Result}, + logerr, + swarm::{create_environment, Config}, + utils, Registry, Service, SharedState, +}; + use std::path::PathBuf; use std::time::Duration; -use sys_info::mem_info; +use std::{collections::HashMap, sync::Arc}; use futures::future; + +use sys_info::mem_info; +use sysinfo::{Pid, ProcessExt, ProcessRefreshKind, RefreshKind, SystemExt}; + use tauri::{AppHandle, Runtime, State, Window}; +use tauri_plugin_store::StoreBuilder; + +use tokio::process::{Child, Command}; use tokio::time::interval; -use tokio::{fs, process::Command}; +use tokio::{fs, sync::Mutex}; #[tauri::command(async)] pub async fn download_service( @@ -49,8 +61,9 @@ pub async fn download_service( #[tauri::command(async)] pub async fn start_service( + handle: tauri::AppHandle, service_id: String, - state: State<'_, SharedState>, + state: State<'_, Arc>, app_handle: AppHandle, ) -> Result<()> { let service_dir = app_handle @@ -71,7 +84,7 @@ pub async fn start_service( // Check if service is already running let running_services_guard = state.running_services.lock().await; if running_services_guard.contains_key(&service_id) { - Err(format!("Service with `{service_id}` already exist"))? + err!("Service with `{service_id}` already exist") } drop(running_services_guard); @@ -88,6 +101,11 @@ pub async fn start_service( ) })? .as_str(); + let is_petals_model = services_guard + .get(&service_id) + .with_context(|| format!("service_id {} doesn't exist in registry", service_id))? + .petals + .unwrap_or_default(); log::info!("serve_command: {}", serve_command); let serve_command_vec: Vec<&str> = serve_command.split_whitespace().collect(); @@ -95,7 +113,7 @@ pub async fn start_service( let binary_path = PathBuf::from(&service_dir).join(&serve_command_vec[0]); log::info!("binary_path: {:?}", binary_path); if !binary_path.exists() { - Err(format!("invalid binary for `{service_id}`"))? + err!("invalid binary for `{service_id}`") } // Extract the arguments with different delimiters let args: Vec = serve_command_vec[1..] @@ -114,9 +132,20 @@ pub async fn start_service( }) .collect(); log::info!("args: {:?}", args); + + let config = Config::new(); + let mut env_vars = HashMap::new(); + env_vars.insert("PREM_APPDIR".to_string(), config.app_data_dir); + env_vars.insert("PREM_PYTHON".to_string(), config.python); + + if is_petals_model { + create_environment(handle); + } + let child = Command::new(&binary_path) .current_dir(service_dir) .args(args) + .envs(env_vars) .stdout(std::process::Stdio::from( log_file .try_clone() @@ -126,31 +155,42 @@ pub async fn start_service( .spawn() .map_err(|e| format!("Failed to spawn child process: {}", e))?; - // Check if the service is running calling /v1 endpoint every 500ms - let interval_duration = Duration::from_millis(500); - let mut interval = interval(interval_duration); - loop { - interval.tick().await; - let base_url = get_base_url(&services_guard[&service_id])?; - let url = format!("{}/v1", base_url); - let client = reqwest::Client::new(); - let res = client.get(&url).send().await; - match res { - Ok(response) => { - // If /v1 is not implemented by the service, it will return 400 Bad Request, consider it as success - if response.status().is_success() - || response.status() == reqwest::StatusCode::BAD_REQUEST - { - let mut running_services_guard = state.running_services.lock().await; - running_services_guard.insert(service_id.clone(), child); - log::info!("Service started: {}", service_id); - break; - } else { - log::error!("Service failed to start: {}", service_id); + let skip_service_check = services_guard + .get(&service_id) + .map(|service| service.skip_health_check.unwrap_or(false)) + .unwrap(); + + if skip_service_check { + let mut running_services_guard = state.running_services.lock().await; + running_services_guard.insert(service_id.clone(), child); + log::info!("Service started: {}", service_id); + } else { + // Check if the service is running calling /v1 endpoint every 500ms + let interval_duration = Duration::from_millis(500); + let mut interval = interval(interval_duration); + loop { + interval.tick().await; + let base_url = get_base_url(&services_guard[&service_id])?; + let url = format!("{}/v1", base_url); + let client = reqwest::Client::new(); + let res = client.get(&url).send().await; + match res { + Ok(response) => { + // If /v1 is not implemented by the service, it will return 400 Bad Request, consider it as success + if response.status().is_success() + || response.status() == reqwest::StatusCode::BAD_REQUEST + { + let mut running_services_guard = state.running_services.lock().await; + running_services_guard.insert(service_id.clone(), child); + log::info!("Service started: {}", service_id); + break; + } else { + log::error!("Service failed to start: {}", service_id); + } + } + Err(e) => { + log::error!("Failed to send request: {}", e); } - } - Err(e) => { - log::error!("Failed to send request: {}", e); } } } @@ -158,32 +198,78 @@ pub async fn start_service( } #[tauri::command(async)] -pub async fn stop_service(service_id: String, state: State<'_, SharedState>) -> Result<()> { - let mut running_services_guard = state.running_services.lock().await; - if let Some(mut child) = running_services_guard.remove(&service_id) { - match child.kill().await { - Ok(_) => { - log::info!("Child process killed: {}", service_id); - } - Err(e) => { - log::error!("Failed to kill child process: {}", e); - } - } +pub async fn stop_service(service_id: String, state: State<'_, Arc>) -> Result<()> { + let running_services = &state.running_services; + let services = &state.services; + _stop_service(service_id.as_str(), running_services, services).await +} + +async fn _stop_service( + service_id: &str, + running_services: &Mutex>, + services: &Mutex>, +) -> Result<()> { + log::info!("stopping service service_id = {service_id}"); + let mut running_services_guard = running_services.lock().await; + let Some(mut child) = running_services_guard.remove(service_id) else { + err!("Service not running") + }; + // kill the process gracefully using SIGTERM/SIGINT + let Some(pid) = child.id() else { + err!("Service couldn't be stopped: {}", service_id) + }; + log::info!("service pid = {pid}"); + let system = sysinfo::System::new_with_specifics( + RefreshKind::new().with_processes(ProcessRefreshKind::new()), + ); + let process = system + .process(Pid::from(pid as usize)) + .with_context(|| format!("Process pid({}) invalid", pid))?; + log::info!( + "terminating service: process_name({}) process_id({})", + process.name(), + process.pid() + ); + if !process + .kill_with(sysinfo::Signal::Term) + .with_context(|| format!("Couldn't send terminate signal to process(pid: {}).", pid))? + { + err!("Failed to kill the process"); } - let mut registry_lock = state.services.lock().await; - if let Some(service) = registry_lock.get_mut(&service_id) { + let mut registry_lock = services.lock().await; + if let Some(service) = registry_lock.get_mut(service_id) { service.running = Some(false); } + // wait for process to properly exit + if let Ok(_exit_code) = child.try_wait() { + log::info!("service stopped!"); + } else if let Ok(_exit_code) = child.wait().await { + log::info!("service stopped!"); + } Ok(()) } -pub fn stop_all_services(state: tauri::State<'_, SharedState>) { +pub fn stop_all_services(state: Arc) { + log::info!("Stopping all services"); tauri::async_runtime::block_on(async move { - let services = state.running_services.lock().await; - for service_id in services.keys() { - _ = stop_service(service_id.clone(), state.clone()).await; + let keys = state + .running_services + .lock() + .await + .keys() + .cloned() + .collect::>(); + for service_id in keys { + logerr!( + _stop_service( + service_id.as_str(), + &state.running_services, + &state.services + ) + .await + ); } - }) + }); } #[tauri::command(async)] @@ -213,7 +299,7 @@ pub async fn get_logs_for_service(service_id: String, app_handle: AppHandle) -> #[tauri::command] pub async fn get_services( - state: State<'_, SharedState>, + state: State<'_, Arc>, app_handle: AppHandle, ) -> Result> { let mut services = Vec::new(); @@ -254,7 +340,7 @@ pub async fn get_services( #[tauri::command(async)] pub async fn get_service_by_id( service_id: String, - state: State<'_, SharedState>, + state: State<'_, Arc>, app_handle: AppHandle, ) -> Result { let services = state.services.lock().await; @@ -265,12 +351,15 @@ pub async fn get_service_by_id( // Dynamic service state #[tauri::command(async)] -pub async fn get_running_services(state: State<'_, SharedState>) -> Result> { +pub async fn get_running_services(state: State<'_, Arc>) -> Result> { let services = state.running_services.lock().await; return Ok(services.keys().cloned().collect()); } -pub async fn is_service_running(service_id: &str, state: &State<'_, SharedState>) -> Result { +pub async fn is_service_running( + service_id: &str, + state: &State<'_, Arc>, +) -> Result { let running_services = state.running_services.lock().await; return Ok(running_services.contains_key(service_id)); } @@ -279,9 +368,9 @@ pub async fn is_service_downloaded(service: &Service, app_handle: &AppHandle) -> let service_dir = app_handle .path_resolver() .app_data_dir() - .expect("failed to resolve app data dir") + .with_context(|| "Failed to resolve app data dir")? .join("models") - .join(&service.id.as_ref().unwrap()); + .join(service.get_id_ref()?); let mut downloaded = true; for file in &service.weights_files.clone().unwrap() { let file_path = service_dir.join(file); @@ -304,11 +393,19 @@ pub async fn is_service_downloaded(service: &Service, app_handle: &AppHandle) -> .head(url) .send() .await - .map_err(|_| format!("Failed HEAD request: {}", url)) - .unwrap(); + .map_err(|_| format!("Failed HEAD request: {}", url))?; if response.status().is_success() { if let Some(remote_file_size) = response.headers().get("Content-Length") { - let parsed_size = remote_file_size.to_str().unwrap().parse::().unwrap(); + let parsed_size = remote_file_size + .to_str() + .with_context(|| "Header value remote_file_size not utf-8")? + .parse::() + .with_context(|| { + format!( + "Failed to parse header-value({:?}) as u64", + remote_file_size.to_str() + ) + })?; // println!("Content-Length: {:?}", parsed_size); if file_size_on_disk != parsed_size { downloaded = false; @@ -359,7 +456,7 @@ pub fn get_base_url(service: &Service) -> Result { pub async fn update_service_with_dynamic_state( service: &mut Service, - state: &State<'_, SharedState>, + state: &State<'_, Arc>, app_handle: &AppHandle, ) -> Result { let is_service_downloaded = is_service_downloaded(service, &app_handle).await?; @@ -367,7 +464,7 @@ pub async fn update_service_with_dynamic_state( let has_enough_total_memory = has_enough_total_memory(service).await?; let has_enough_storage = has_enough_storage().await?; let is_supported = is_supported().await?; - let is_service_running = is_service_running(&service.id.as_ref().unwrap(), &state).await?; + let is_service_running = is_service_running(service.get_id_ref()?, &state).await?; let base_url = get_base_url(service)?; service.downloaded = Some(is_service_downloaded); service.enough_memory = Some(has_enough_free_memory); @@ -383,9 +480,9 @@ pub async fn update_service_with_dynamic_state( pub async fn get_system_stats() -> Result> { Ok(HashMap::new()) } + #[tauri::command(async)] -pub async fn get_service_stats(service_id: String) -> Result> { - log::info!("service_id: {}", service_id); +pub async fn get_service_stats(_service_id: String) -> Result> { Ok(HashMap::new()) } #[tauri::command(async)] @@ -394,8 +491,148 @@ pub async fn get_gpu_stats() -> Result> { } #[tauri::command(async)] -pub async fn add_service(service: Service, state: State<'_, SharedState>) -> Result<()> { +pub async fn add_service(service: Service, state: State<'_, Arc>) -> Result<()> { let mut services_guard = state.services.lock().await; - services_guard.insert(service.id.clone().unwrap(), service.clone()); + services_guard.insert(service.get_id()?, service); + Ok(()) +} + +#[tauri::command(async)] +pub async fn add_registry( + registry: Registry, + app_handle: AppHandle, + state: State<'_, Arc>, +) -> Result<()> { + let store_path = app_handle + .path_resolver() + .app_data_dir() + .with_context(|| "Failed to resolve app data dir")? + .join("store.json"); + let mut store = StoreBuilder::new(app_handle, store_path).build(); + store.load().with_context(|| "Failed to load store")?; + if let Some(registries) = store.get("registries").cloned() { + match serde_json::from_value::>(registries) { + Ok(mut registries) => { + registries.push(registry); + store + .insert( + "registries".to_string(), + serde_json::to_value(®istries).unwrap(), + ) + .with_context(|| "Failed to insert into store")?; + utils::fetch_all_services_manifests(®istries, &state) + .await + .expect("failed to fetch services") + } + Err(e) => println!("Error unwrapping registries: {:?}", e), + } + } else { + let new_registry = [registry]; + store + .insert( + "registries".to_string(), + serde_json::to_value(&new_registry).unwrap(), + ) + .with_context(|| "Failed to insert into store")?; + utils::fetch_all_services_manifests(&new_registry, &state) + .await + .expect("failed to fetch services") + } + store.save().expect("failed to save store"); + Ok(()) +} + +#[tauri::command(async)] +pub async fn delete_registry( + registry: Registry, + app_handle: AppHandle, + state: State<'_, Arc>, +) -> Result<()> { + let store_path = app_handle + .path_resolver() + .app_data_dir() + .with_context(|| "Failed to resolve app data dir")? + .join("store.json"); + let mut store = StoreBuilder::new(app_handle, store_path).build(); + store.load().with_context(|| "Failed to load store")?; + if let Some(registries) = store.get("registries").cloned() { + let mut registries = serde_json::from_value::>(registries) + .with_context(|| "Failed to deserialize")?; + registries.retain(|r| r.url != registry.url); + store + .insert( + "registries".to_string(), + serde_json::to_value(registries).with_context(|| "Failed to serialize")?, + ) + .with_context(|| "Failed to insert into store")?; + store.save().with_context(|| "Failed to save store")?; + + // Reset services state and refetch all registries + let mut services_guard = state.services.lock().await; + services_guard.clear(); + drop(services_guard); + if let Some(registries) = store.get("registries").cloned() { + match serde_json::from_value::>(registries) { + Ok(registries) => utils::fetch_all_services_manifests(®istries, &state) + .await + .with_context(|| "Failed to fetch services")?, + Err(e) => log::error!("Error unwrapping registries: {:?}", e), + } + } else { + println!("No registries found"); + } + } + Ok(()) +} + +#[tauri::command(async)] +pub async fn fetch_registries(app_handle: AppHandle) -> Result> { + let store_path = app_handle + .path_resolver() + .app_data_dir() + .with_context(|| "Failed to resolve app data dir")? + .join("store.json"); + let mut store = StoreBuilder::new(app_handle, store_path).build(); + match store.load() { + Ok(_) => { + if let Some(registries) = store.get("registries").cloned() { + match serde_json::from_value::>(registries) { + Ok(registries) => Ok(registries), + Err(e) => { + log::error!("Error unwrapping registries: {:?}", e); + Ok(Vec::new()) + } + } + } else { + log::error!("No registries found"); + Ok(Vec::new()) + } + } + Err(e) => { + log::error!("Error loading store: {:?}", e); + Ok(Vec::new()) + } + } +} + +#[tauri::command(async)] +pub async fn reset_default_registry( + app_handle: AppHandle, + state: State<'_, Arc>, +) -> Result<()> { + let store_path = app_handle + .path_resolver() + .app_data_dir() + .with_context(|| "Failed to resolve app data dir")? + .join("store.json"); + let mut store = StoreBuilder::new(app_handle.clone(), store_path).build(); + store.load().with_context(|| "Failed to load store")?; + store + .delete("registries") + .with_context(|| "Failed to delete registries")?; + store.save().with_context(|| "Failed to save store")?; + add_registry(Registry::default(), app_handle.clone(), state) + .await + .with_context(|| "Failed to add default registry")?; Ok(()) } diff --git a/src-tauri/src/download.rs b/src-tauri/src/download.rs index ad8210f8..c42bfbd9 100644 --- a/src-tauri/src/download.rs +++ b/src-tauri/src/download.rs @@ -1,7 +1,8 @@ use crate::errors::{Context, Result}; use std::collections::HashMap; +use std::sync::Arc; -use crate::{utils, SharedState}; +use crate::{err, utils, SharedState}; use futures::StreamExt; use serde::Serialize; use tauri::{Manager, Runtime, Window}; @@ -144,7 +145,7 @@ impl Downloader { total_file_size: u64, size_on_disk: u64, ) -> Result<()> { - let state = self.window.state::(); + let state = self.window.state::>(); let mut downloading_files_guard = state.downloading_files.lock().await; if downloading_files_guard.contains(&output_path.as_ref().to_string()) { log::warn!("File already downloading: {}", output_path.as_ref()); @@ -175,11 +176,7 @@ impl Downloader { // Check the status for errors. if !res.status().is_success() { - Err(format!( - "GET Request: ({}): ({})", - res.status(), - url.as_ref() - ))? + err!("GET Request: ({}): ({})", res.status(), url.as_ref()); } // Prepare the destination directories @@ -211,7 +208,7 @@ impl Downloader { // Retrieve chunk. let mut chunk = match item { Ok(chunk) => chunk, - Err(e) => Err(format!("Error while downloading: {:?}", e))?, + Err(e) => err!("Error while downloading: {e:?}"), }; let chunk_size = chunk.len() as u64; diff --git a/src-tauri/src/errors.rs b/src-tauri/src/errors.rs index a657e7fe..0e37cd0e 100644 --- a/src-tauri/src/errors.rs +++ b/src-tauri/src/errors.rs @@ -40,3 +40,34 @@ impl serde::Serialize for Error { serializer.serialize_str(self.to_string().as_ref()) } } + +#[macro_export] +macro_rules! err { + ($($t:tt)*) => {{ + Err(format!($($t)*))? + }}; +} + +#[macro_export] +macro_rules! logerr { + ($ar:expr $(,)+ $($t:tt)+) => {{ + if let Err(e) = $ar { + log::error!("{e:?}"); + log::error!($($t)*); + } + }}; + ($ar:expr) => {{ + if let Err(e) = $ar { + log::error!("{e:?}"); + } + }}; +} + +#[macro_export] +macro_rules! logsome { + ($ar:expr, $($t:tt)*) => {{ + if ($ar).is_none() { + log::error!($($t)*); + } + }}; +} diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 77864a7f..99250cbd 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -7,14 +7,19 @@ mod errors; mod swarm; mod utils; +use crate::controller_binaries::stop_all_services; + +use std::{collections::HashMap, env, ops::Deref, str, sync::Arc}; + use sentry_tauri::sentry; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, env, str}; use tauri::{ AboutMetadata, CustomMenuItem, Manager, Menu, MenuItem, RunEvent, Submenu, SystemTray, SystemTrayEvent, SystemTrayMenu, SystemTrayMenuItem, WindowEvent, }; -use tokio::{process::Child, sync::Mutex}; +use tauri_plugin_store::StoreBuilder; +use tokio::process::Child; +use tokio::sync::Mutex; #[derive(Debug, Default)] pub struct SharedState { @@ -24,6 +29,32 @@ pub struct SharedState { services: Mutex>, } +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct Registry { + url: String, +} + +impl Default for Registry { + fn default() -> Self { + // Determine the URL based on whether the app is in debug or release mode + let url = if cfg!(debug_assertions) { + // Debug mode URL + "https://raw.githubusercontent.com/premAI-io/prem-registry/dev/manifests.json" + } else { + // Release mode URL + "https://raw.githubusercontent.com/premAI-io/prem-registry/v1/manifests.json" + }; + Registry { + url: url.to_string(), + } + } +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct Store { + registries: Vec, +} + #[derive(Debug, Deserialize, Serialize, Clone)] pub struct Service { // Static state from registry manifest @@ -50,6 +81,9 @@ pub struct Service { running_port: Option, #[serde(rename = "serviceType")] service_type: Option, + petals: Option, + #[serde(rename = "skipHealthCheck")] + skip_health_check: Option, version: Option, #[serde(rename = "weightsDirectoryUrl")] weights_directory_url: Option, @@ -73,6 +107,22 @@ pub struct Service { supported: Option, } +impl Service { + fn get_id(&self) -> errors::Result { + use errors::Context; + self.id + .clone() + .with_context(|| format!("Service doesn't contain a valid id\n{:#?}", self)) + } + // ref to String is used as it's more generally coerce-able + fn get_id_ref(&self) -> errors::Result<&String> { + use errors::Context; + self.id + .as_ref() + .with_context(|| format!("Service doesn't contain a valid id\n{:#?}", self)) + } +} + #[derive(Debug, Deserialize, Serialize, Clone)] struct ModelInfo { #[serde(rename = "inferenceTime")] @@ -91,12 +141,11 @@ struct ModelInfo { } fn main() { - // Sentry let client = sentry::init(( "https://b98405fd0e4cc275b505645d293d23a5@o4506111848808448.ingest.sentry.io/4506111925223424", sentry::ClientOptions { release: sentry::release_name!(), - debug: true, + debug: false, // this outputs dsn to stdout on sentry init ..Default::default() }, )); @@ -105,8 +154,11 @@ fn main() { let _guard = sentry_tauri::minidump::init(&client); // Everything after here runs in only the app process + // TODO: consider directly pushing logs to sentry (sentry-sdk provides + // log integration) for release builds + // initialize logger - pretty_env_logger::formatted_timed_builder() + pretty_env_logger::formatted_builder() .format(|buf, record| { use std::io::Write; writeln!( @@ -164,11 +216,12 @@ fn main() { let system_tray = SystemTray::new().with_menu(tray_menu); - let state = SharedState::default(); - #[allow(unused_mut)] - let mut app = tauri::Builder::default() + let state = std::sync::Arc::new(SharedState::default()); + + let app = tauri::Builder::default() .plugin(sentry_tauri::plugin()) - .manage(state) + .plugin(tauri_plugin_store::Builder::default().build()) + .manage(state.clone()) .invoke_handler(tauri::generate_handler![ controller_binaries::start_service, controller_binaries::stop_service, @@ -182,17 +235,25 @@ fn main() { controller_binaries::get_service_stats, controller_binaries::get_gpu_stats, controller_binaries::add_service, + controller_binaries::add_registry, + controller_binaries::delete_registry, + controller_binaries::fetch_registries, + controller_binaries::reset_default_registry, swarm::is_swarm_supported, swarm::get_username, swarm::get_petals_models, - swarm::run_swarm_mode, + swarm::create_environment, + swarm::delete_environment, + swarm::run_swarm, swarm::stop_swarm_mode, swarm::is_swarm_mode_running ]) .menu(menu) .on_menu_event(|event| match event.menu_item_id() { "quit" => { - controller_binaries::stop_all_services(event.window().state()); + controller_binaries::stop_all_services( + event.window().state::>().deref().clone(), + ); event.window().close().unwrap(); } "close" => { @@ -204,17 +265,25 @@ fn main() { .on_system_tray_event(|app, event| match event { SystemTrayEvent::MenuItemClick { id, .. } => match id.as_str() { "hide" => { - let window = app.get_window("main").unwrap(); - window.hide().unwrap(); + let Some(window) = app.get_window("main") else { + log::error!("Couldn't get window from for label 'main'"); + return; + }; + logerr!(window.hide()); } "quit" => { - controller_binaries::stop_all_services(app.state()); + controller_binaries::stop_all_services( + app.state::>().deref().clone(), + ); app.exit(0); } "show" => { - let window = app.get_window("main").unwrap(); - window.set_focus().unwrap(); - window.show().unwrap(); + let Some(window) = app.get_window("main") else { + log::error!("Couldn't get window from for label 'main'"); + return; + }; + logerr!(window.set_focus()); + logerr!(window.show()); } _ => {} }, @@ -222,24 +291,80 @@ fn main() { }) .setup(|app| { tauri::async_runtime::block_on(async move { - utils::fetch_services_manifests( - "https://raw.githubusercontent.com/premAI-io/prem-registry/dev/manifests.json", - &app.state::(), - ) - .await - .expect("Failed to fetch and save services manifests"); + //Create a store with default registry if doesn't exist + let store_path = app + .path_resolver() + .app_data_dir() + .expect("failed to resolve app data dir") + .join("store.json"); + if !store_path.exists() { + let mut registries: Vec = Vec::new(); + registries.push(Registry::default()); + let mut default_store = HashMap::new(); + default_store.insert( + "registries".to_string(), + serde_json::to_value(registries).unwrap(), + ); + let store = StoreBuilder::new(app.handle(), store_path.clone()) + .defaults(default_store) + .build(); + store.save().expect("failed to save store"); + log::info!("Store created"); + } + // Fetch all registries + let mut store = StoreBuilder::new(app.handle(), store_path.clone()).build(); + store.load().expect("Failed to load store"); + if let Some(registries) = store.get("registries").cloned() { + match serde_json::from_value::>(registries) { + Ok(registries) => utils::fetch_all_services_manifests( + ®istries, + &app.state::>().clone(), + ) + .await + .expect("failed to fetch services"), + Err(e) => println!("Error unwrapping registries: {:?}", e), + } + } }); Ok(()) }) + .on_window_event(|ev| { + if matches!(ev.event(), WindowEvent::Destroyed) { + stop_all_services(ev.window().state::>().deref().clone()); + } + }) .build(tauri::generate_context!()) - .expect("error while running tauri application"); + .expect("Error while building tauri application"); + + { + let app_handle = app.handle(); + let s = state.clone(); + std::panic::set_hook(Box::new(move |_| { + stop_all_services(s.clone()); + app_handle.exit(-1); + })); + + let app_handle = app.handle(); + let s = state.clone(); + ctrlc::set_handler(move || { + stop_all_services(s.clone()); + app_handle.exit(-1); + }) + .expect("Error setting Ctrl-C handler"); + } app.run(|app_handle, e| match e { // Triggered when a window is trying to close RunEvent::WindowEvent { label, event, .. } => { match event { WindowEvent::CloseRequested { api, .. } => { - app_handle.get_window(&label).unwrap().hide().unwrap(); + logsome!( + app_handle.get_window(&label).map(|e| logerr!( + e.hide(), + "Failed to hide window with label({label:?})" + )), + "Failed to get app window with label({label:?})" + ); // use the exposed close api, and prevent the event loop to close api.prevent_close(); } @@ -247,5 +372,5 @@ fn main() { } } _ => {} - }) + }); } diff --git a/src-tauri/src/swarm.rs b/src-tauri/src/swarm.rs index c7d6340e..9cb1b184 100644 --- a/src-tauri/src/swarm.rs +++ b/src-tauri/src/swarm.rs @@ -1,6 +1,13 @@ use reqwest::get; use serde::Deserialize; -use std::{collections::HashMap, env, str}; +use std::{ + collections::HashMap, + env, + path::PathBuf, + str, + thread::{self, JoinHandle}, + time::Duration, +}; use tauri::api::process::Command; #[derive(Deserialize)] @@ -18,6 +25,36 @@ pub fn is_swarm_supported() -> bool { } } +pub struct Config { + pub app_data_dir: String, + pub python: String, +} + +impl Config { + pub fn new() -> Self { + let mut app_data_dir = + tauri::api::path::home_dir().expect("🙈 Failed to get app data directory"); + app_data_dir.push(".config/prem"); + let app_data_dir = app_data_dir + .to_str() + .expect("🙈 Failed to convert app data dir path to str") + .to_string(); + + let python = PathBuf::from(format!( + "{}/miniconda/envs/prem_app/bin/python", + app_data_dir + )) + .to_str() + .unwrap_or("python") + .to_string(); + + Config { + app_data_dir, + python, + } + } +} + #[tauri::command] pub fn get_username() -> String { let output = Command::new("whoami").output(); @@ -58,76 +95,81 @@ pub async fn get_petals_models() -> Result, String> { #[tauri::command] pub fn is_swarm_mode_running() -> bool { - let output_value = get_swarm_processes(); + let processes = get_swarm_processes(); - if !output_value.is_empty() { - println!( - "🏃‍♀️ Processeses running: {}", - output_value.replace("\n", " ") - ); + if processes.len() > 0 { + println!("🏃‍♀️ Processeses running: {:?}", processes); return true; } return false; } -pub fn create_environment(handle: tauri::AppHandle) -> String { - // Get the application data directory - let app_data_dir = tauri::api::path::home_dir().expect("🙈 Failed to get app data directory"); - let app_data_dir = app_data_dir.join(".config/prem"); - let app_data_dir = app_data_dir - .to_str() - .expect("🙈 Failed to convert app data dir path to str"); - - // Get create env path - let binding = handle - .path_resolver() - .resolve_resource("petals") - .expect("🙈 Failed to find `create_env.sh`"); - let petals_path = binding - .to_str() - .expect("🙈 Failed to convert petals path to str"); - - let python = format!("{app_data_dir}/miniconda/envs/prem_app/bin/python"); +#[tauri::command(async)] +pub fn create_environment(handle: tauri::AppHandle) { + println!("🐍 Creating the environment..."); + let petals_path = get_petals_path(handle); + let config = Config::new(); - // Set env variables let mut env = HashMap::new(); - env.insert("PREM_APPDIR".to_string(), app_data_dir.to_string()); - env.insert("PREM_PYTHON".to_string(), python.clone()); + env.insert("PREM_APPDIR".to_string(), config.app_data_dir); + env.insert("PREM_PYTHON".to_string(), config.python); - // Run the bash script let _ = Command::new("sh") .args([format!("{petals_path}/create_env.sh")]) .envs(env) .output() .expect("🙈 Failed to create env"); - python } #[tauri::command(async)] -pub fn run_swarm_mode( - handle: tauri::AppHandle, - num_blocks: i32, - model: String, - public_name: String, -) { - let python: String = create_environment(handle); - println!("🚀 Starting the Swarm..."); +pub fn delete_environment(handle: tauri::AppHandle) { + println!("❌ Deleting the environment..."); + let petals_path = get_petals_path(handle); + let config = Config::new(); + + let mut env = HashMap::new(); + env.insert("PREM_APPDIR".to_string(), config.app_data_dir); - let _ = Command::new(&python) - .args(&[ - "-m", - "petals.cli.run_server", - "--num_blocks", + let _ = Command::new("sh") + .args([format!("{petals_path}/delete_env.sh")]) + .envs(env.clone()) + .output() + .expect("🙈 Failed to delete env"); +} + +#[tauri::command(async)] +pub fn run_swarm(handle: tauri::AppHandle, num_blocks: i32, model: String, public_name: String) { + let petals_path = get_petals_path(handle.clone()); + let config = Config::new(); + + let mut env = HashMap::new(); + env.insert("PREM_PYTHON".to_string(), config.python); + + println!("🚀 Starting the Swarm..."); + let _ = Command::new("sh") + .args([ + format!("{petals_path}/run_swarm.sh").as_str(), &num_blocks.to_string(), - "--public_name", &public_name, &model, ]) + .envs(env) .spawn() .expect("🙈 Failed to run swarm"); } -pub fn get_swarm_processes() -> String { +fn get_petals_path(handle: tauri::AppHandle) -> String { + let binding = handle + .path_resolver() + .resolve_resource("petals") + .expect("🙈 Failed to find `create_env.sh`"); + let petals_path = binding + .to_str() + .expect("🙈 Failed to convert petals path to str"); + petals_path.to_string() +} + +pub fn get_swarm_processes() -> Vec { // Check if create_env.sh is running let output = Command::new("/usr/bin/pgrep") .args(&["-f", "create_env.sh|(mamba|conda).*create.*prem_app"]) @@ -141,12 +183,24 @@ pub fn get_swarm_processes() -> String { // If create_env.sh is running, return an empty string if !output_value.is_empty() { - return "".to_string(); + return vec![]; } + let config = Config::new(); + let python_path = PathBuf::from(config.python); + let prem_app_env = python_path + .parent() + .unwrap() + .parent() + .unwrap() + .to_str() + .unwrap(); + + let regex = format!("https://github.com/bigscience-workshop/petals|{prem_app_env}.*(petals.cli.run_server|multiprocessing.resource_tracker|from multiprocessing.spawn)"); + // If create_env.sh is not running, get the processes from petals let output = Command::new("/usr/bin/pgrep") - .args(&["-f", "https://github.com/bigscience-workshop/petals|petals.cli.run_server|multiprocessing.resource_tracker|from multiprocessing.spawn"]) + .args(&["-f", ®ex]) .output() .map_err(|e| { println!("🙈 Failed to execute command: {}", e); @@ -154,25 +208,53 @@ pub fn get_swarm_processes() -> String { }); let output_value = output.unwrap().stdout; - output_value + let processes: Vec = output_value + .replace("\n", " ") + .split(" ") + .collect::>() + .into_iter() + .filter_map(|s| s.parse::().ok()) + .collect(); + processes } #[tauri::command] pub fn stop_swarm_mode() { println!("🛑 Stopping the Swarm..."); - let processes = get_swarm_processes().replace("\n", " "); - println!("🛑 Stopping Processes: {}", processes); - let processes = processes.split(" ").collect::>(); + let processes = get_swarm_processes(); + println!("🛑 Stopping Processes: {:?}", processes); for process in processes { - println!("🛑 Stopping Process: {}", process); let _ = Command::new("kill") - .args(&[process.to_string()]) - .output() - .map_err(|e| { - println!("🙈 Failed to execute command: {}", e); - e - }); + .args(["-s", "SIGTERM", &process.to_string()]) + .spawn() + .expect("🙈 Failed to execute kill command with SIGTERM"); + + let handle: JoinHandle<_> = thread::spawn(move || { + thread::sleep(Duration::from_millis(50)); + match Command::new("ps") + .args(["-p", &process.to_string()]) + .output() + { + Ok(output) => match output.status.code() { + Some(0) => true, + _ => false, + }, + Err(e) => { + eprintln!("Error executing ps command: {}", e); + false + } + } + }); + if handle.join().unwrap() { + let _ = Command::new("kill") + .args(["-s", "SIGKILL", &process.to_string()]) + .output() + .expect("🙈 Failed to execute kill command with SIGKILL"); + println!("🛑 Stopping Process with SIGKILL: {}", process); + } else { + println!("🛑 Stopping Process with SIGTERM: {}", process); + } } println!("🛑 Stopped all the Swarm Processes."); } diff --git a/src-tauri/src/utils.rs b/src-tauri/src/utils.rs index 55fee848..2a10a545 100644 --- a/src-tauri/src/utils.rs +++ b/src-tauri/src/utils.rs @@ -1,51 +1,156 @@ -use crate::errors::Result; -use crate::{Service, SharedState}; +use crate::errors::{Context, Result}; +use crate::{err, Registry, Service, SharedState}; +use futures::future; use reqwest::get; use std::collections::HashMap; -use tauri::State; +use std::sync::Arc; -pub async fn fetch_services_manifests(url: &str, state: &State<'_, SharedState>) -> Result<()> { - let response = get(url).await.expect("Failed to fetch registry"); - let services = response.json::>().await.unwrap(); +pub async fn fetch_all_services_manifests( + registries: &[Registry], + state: &Arc, +) -> Result<()> { + let mut handlers = vec![]; + for registry in registries { + let handler = async move { + if let Err(err) = fetch_services_manifests(registry.url.as_str(), state).await { + println!( + "Failed to fetch {} and save services manifests: {}", + registry.url, err + ); + } + }; + handlers.push(handler); + } + future::join_all(handlers).await; + Ok(()) +} + +async fn fetch_services_manifests(url: &str, state: &Arc) -> Result<()> { + let response = get(url) + .await + .with_context(|| format!("Couldn't fetch the manifest from {url:?}"))?; + let services = response + .json::>() + .await + .with_context(|| "Failed to parse response to list of services")?; let mut services_guard = state.services.lock().await; - let service_ids = services_guard.keys().cloned().collect::>(); for service in services { - if !service_ids.contains(&service.id.clone().unwrap()) { - services_guard.insert(service.id.clone().unwrap(), service); + if !service + .get_id_ref() + .map(|id| services_guard.contains_key(id)) + .unwrap_or_default() + { + services_guard.insert(service.get_id()?, service); } } Ok(()) } -pub fn is_aarch64() -> bool { +pub const fn is_aarch64() -> bool { cfg!(target_arch = "aarch64") } -pub fn is_x86_64() -> bool { +pub const fn is_x86_64() -> bool { cfg!(target_arch = "x86_64") } +pub const fn is_macos() -> bool { + cfg!(target_os = "macos") +} + +pub const fn is_unix() -> bool { + cfg!(target_family = "unix") +} + pub fn get_binary_url(binaries_url: &HashMap>) -> Result { - let mut binary_url = "".to_string(); - if is_aarch64() { - binary_url = binaries_url - .get("aarch64-apple-darwin") - .unwrap() - .clone() - .unwrap_or_else(|| "Unsupported architecture".to_string()) - } else if is_x86_64() { - binary_url = binaries_url - .get("x86_64-apple-darwin") - .unwrap() - .clone() - .unwrap_or_else(|| "Unsupported architecture".to_string()) - } - if binary_url == "Unsupported architecture" { - binary_url = binaries_url - .get("universal-apple-darwin") - .unwrap() - .clone() - .unwrap_or_else(|| Err("Unsupported architecture").unwrap()) - } - Ok(binary_url) + if !is_unix() && !is_macos() { + err!("Unsupported OS") + } else if is_macos() { + if is_x86_64() { + binaries_url.get("x86_64-apple-darwin").cloned().flatten() + } else if is_aarch64() { + binaries_url.get("aarch64-apple-darwin").cloned().flatten() + } else { + err!("Unsupported architecture") + } + .or_else(|| { + binaries_url + .get("universal-apple-darwin") + .cloned() + .flatten() + }) + .with_context(|| "Service not supported on your platform") + } else { + err!("Unsupported platform") + } +} + +#[cfg(all(test, target_os = "macos"))] +mod tests { + use super::get_binary_url; + use std::collections::HashMap; + + #[test] + fn binary_url_test_universal_only() { + let url = get_binary_url(&HashMap::from_iter([( + "universal-apple-darwin".to_string(), + Some("randome_url.com".to_string()), + )])) + .unwrap(); + assert_eq!(url, "randome_url.com"); + } + + #[test] + fn binary_url_test_universal_only_option() { + let url = get_binary_url(&HashMap::from_iter([ + ( + "universal-apple-darwin".to_string(), + Some("randome_url.com".to_string()), + ), + ("aarch64-apple-darwin".to_string(), None), + ("x86_64-apple-darwin".to_string(), None), + ])) + .unwrap(); + assert_eq!(url, "randome_url.com"); + } + + #[test] + fn binary_url_test_with_non_universal() { + let url = get_binary_url(&HashMap::from_iter([ + ( + "universal-apple-darwin".to_string(), + Some("randome_url.com//universal".to_string()), + ), + ( + "aarch64-apple-darwin".to_string(), + Some("randome_url.com//not_universal".to_string()), + ), + ( + "x86_64-apple-darwin".to_string(), + Some("randome_url.com//not_universal".to_string()), + ), + ])) + .unwrap(); + assert_eq!(url, "randome_url.com//not_universal"); + } + + #[test] + #[should_panic] + fn binary_url_test_empty() { + let url = get_binary_url(&HashMap::from_iter([])).unwrap(); + assert_eq!(url, "randome_url.com"); + } + + #[test] + fn binary_url_all_none() { + let err = get_binary_url(&HashMap::from_iter([ + ("universal-apple-darwin".to_string(), None), + ("aarch64-apple-darwin".to_string(), None), + ("x86_64-apple-darwin".to_string(), None), + ])); + assert_eq!( + err.err().map(|err| err.to_string()).unwrap_or_default(), + "Service not supported on your platform" + ); + } } diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json index ae5a8a57..af3b2080 100644 --- a/src-tauri/tauri.conf.json +++ b/src-tauri/tauri.conf.json @@ -8,7 +8,7 @@ }, "package": { "productName": "Prem", - "version": "0.1.2" + "version": "0.2.0" }, "tauri": { "allowlist": { @@ -27,9 +27,6 @@ }, "bundle": { "active": true, - "resources": [ - "petals/*" - ], "icon": [ "icons/32x32.png", "icons/128x128.png", @@ -38,6 +35,9 @@ "icons/icon.ico" ], "identifier": "io.premai.prem-app", + "resources": [ + "petals/*" + ], "targets": "all" }, "security": { diff --git a/src/App.tsx b/src/App.tsx index 93200849..945850dd 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -23,6 +23,7 @@ function App() { useSettingStore.getState().setIsIP(hostIsIP); useSettingStore.getState().setBackendUrl(getBackendUrl()); useSettingStore.getState().removeAllServiceAsDownloading(); + useSettingStore.getState().resetSwarm(); // Download progress if (isDesktopEnv()) { diff --git a/src/AppRouter.tsx b/src/AppRouter.tsx index c7da5ad5..45260e3b 100644 --- a/src/AppRouter.tsx +++ b/src/AppRouter.tsx @@ -18,7 +18,7 @@ const AppRouter = () => { } /> - } /> + } /> } /> } /> } /> diff --git a/src/assets/css/index.css b/src/assets/css/index.css index e0d70f1c..0159150f 100644 --- a/src/assets/css/index.css +++ b/src/assets/css/index.css @@ -2,6 +2,7 @@ @import url("./documentation.css"); @import url("./fonts.css"); @import url("./forms.css"); +@import url("./left-sidebar.css"); @import url("./markdown.css"); @import url("./misc.css"); @import url("./modal.css"); diff --git a/src/assets/css/left-sidebar.css b/src/assets/css/left-sidebar.css new file mode 100644 index 00000000..6bb3e558 --- /dev/null +++ b/src/assets/css/left-sidebar.css @@ -0,0 +1,28 @@ +/* Playground sidebar */ +.left-sidebar { + @apply max-md:fixed max-md:z-[11] max-md:w-full md:w-[259px] max-md:p-5 max-md:top-0 max-md:bottom-0 max-md:overflow-hidden flex flex-col md:relative md:pt-7 bg-grey-800 pt-[22px] pb-[10px] px-2; +} + +.left-sidebar li.active { + @apply max-md:bg-grey-900; +} + +.left-sidebar__logo { + @apply max-md:max-w-[100px] max-md:ml-3; +} + +.left-sidebar__search { + @apply max-md:h-[38px]; +} + +.left-sidebar li { + @apply max-md:px-2; +} + +.left-sidebar li a { + @apply max-md:text-[12px]; +} + +.left-sidebar li a span { + @apply whitespace-nowrap w-[180px] text-ellipsis overflow-hidden; +} diff --git a/src/assets/css/prem-chat.css b/src/assets/css/prem-chat.css index 9c0969b3..107e82a1 100644 --- a/src/assets/css/prem-chat.css +++ b/src/assets/css/prem-chat.css @@ -7,26 +7,6 @@ @apply bg-grey-800 flex-col px-4 flex h-screen w-[260px] max-md:absolute max-md:z-10 right-0 max-md:top-0; } -.prem-chat-sidebar .sidebar { - @apply max-md:fixed max-md:z-[11] max-md:w-full max-md:p-5 max-md:top-0 max-md:bottom-0 max-md:overflow-hidden max-md:bg-grey-800; -} - -.prem-chat-sidebar .sidebar li.active { - @apply max-md:bg-grey-900; -} - -.prem-chat-sidebar .sidebar__logo { - @apply max-md:max-w-[100px] max-md:ml-3; -} - -.prem-chat-sidebar .sidebar__search { - @apply max-md:h-[38px]; -} - -.prem-chat-sidebar .sidebar li a { - @apply max-md:text-[12px]; -} - .prem-chat { @apply mt-8 border-2 border-grey-400 rounded-md; } @@ -35,25 +15,29 @@ @apply bg-transparent cursor-pointer border border-grey-400 text-grey-200 w-full z-10 px-4 py-[6px] appearance-none; } -.prem-chat-input input { - @apply bg-grey-700 text-white md:text-sm text-[13px] outline-none rounded-[29px] h-[57px] mt-[14px] py-2 md:pl-[33px] md:pr-16 pr-12 pl-4 w-full; +.autosize-textarea { + @apply bg-grey-700 text-white text-sm outline-none rounded-[29px] h-14 py-4 md:pl-[33px] md:pr-16 pr-12 pl-6 w-full min-h-14 max-h-40; } -.prem-chat-input input:-webkit-autofill, -.prem-chat-input input:-webkit-autofill:hover, -.prem-chat-input input:-webkit-autofill:active, -.prem-chat-input input:-webkit-autofill:focus { +.autosize-textarea:-webkit-autofill, +.autosize-textarea:-webkit-autofill:hover, +.autosize-textarea:-webkit-autofill:active, +.autosize-textarea:-webkit-autofill:focus { -webkit-text-fill-color: #302f32; -webkit-box-shadow: 0 0 0 1000px #302f32 inset; transition: background-color 5000s ease-in-out 0s; } -.prem-chat-input img { +.autosize-textarea-container { + @apply relative mt-4; +} + +.autosize-textarea-container img { @apply bg-primary-light p-2 w-[35px] h-[35px] rounded-full stroke-white; } -.prem-chat-input button { - @apply absolute md:right-[21px] right-[12px] top-[25px]; +.autosize-textarea-container button { + @apply absolute md:right-[21px] right-[12px] bottom-4; } .prem-chat-bottom { @@ -96,7 +80,7 @@ @apply !text-white; } -.bot-reply > p { +.bot-reply .markdown-body { @apply p-4 w-full text-white; max-width: 660px; border-radius: 0 10px 10px 10px; @@ -113,19 +97,11 @@ @apply flex items-end mb-[40px]; } -.user-reply p { - @apply bg-grey-100 border py-4 px-[22px] ml-auto max-w-[660px] text-base text-grey-700; +.user-reply pre { + @apply bg-grey-100 border py-4 px-[22px] ml-auto max-w-[660px] font-sans text-sm text-grey-700 whitespace-pre-wrap; box-shadow: 10.4294px 10.4294px 26.0736px rgba(112, 124, 151, 0.05), 15.6442px 15.6442px 36.5031px rgba(112, 124, 151, 0.05); border-radius: 10px 10px 0 10px; border-color: rgba(112, 124, 151, 0.25); } - -.prem-chat-sidebar .sidebar li a span { - @apply whitespace-nowrap w-[180px] text-ellipsis overflow-hidden; -} - -.prem-chat-sidebar .sidebar li { - @apply max-md:px-2; -} diff --git a/src/assets/images/network.png b/src/assets/images/network.png new file mode 100644 index 00000000..f3310886 Binary files /dev/null and b/src/assets/images/network.png differ diff --git a/src/controller/abstractServiceController.ts b/src/controller/abstractServiceController.ts index 98b93e04..8c09d9f0 100644 --- a/src/controller/abstractServiceController.ts +++ b/src/controller/abstractServiceController.ts @@ -1,4 +1,5 @@ import type { Service } from "../modules/service/types"; +import type { Registry } from "../modules/settings/types"; import type { Interface } from "../shared/helpers/interfaces"; import type { DownloadArgs } from "./serviceController"; @@ -23,6 +24,10 @@ abstract class AbstractServiceController { abstract getGPUStats(): Promise>; abstract getInterfaces(): Promise; abstract addService(service: Service): Promise; + abstract addRegistry(registry: Registry): Promise; + abstract deleteRegistry(registry: Registry): Promise; + abstract fetchRegistries(): Promise; + abstract resetDefaultRegistry(): Promise; } export default AbstractServiceController; diff --git a/src/controller/binariesController.ts b/src/controller/binariesController.ts index bccffa7a..46dfd8de 100644 --- a/src/controller/binariesController.ts +++ b/src/controller/binariesController.ts @@ -1,4 +1,5 @@ import { invoke } from "@tauri-apps/api/tauri"; +import type { Registry } from "modules/settings/types"; import type { Service, ServiceBinary } from "../modules/service/types"; import type { Interface } from "../shared/helpers/interfaces"; @@ -13,10 +14,7 @@ class BinariesController extends AbstractServiceController { async restart(serviceId: string): Promise { await this.stop(serviceId); - const services: string[] = await invoke("get_running_services"); - while (!services.includes(serviceId)) { - await this.start(serviceId); - } + await this.start(serviceId); } async stop(serviceId: string): Promise { @@ -80,6 +78,22 @@ class BinariesController extends AbstractServiceController { async addService(service: Service): Promise { await invoke("add_service", { service }); } + + async addRegistry(registry: Registry): Promise { + await invoke("add_registry", { registry }); + } + + async deleteRegistry(registry: Registry): Promise { + await invoke("delete_registry", { registry }); + } + + async fetchRegistries(): Promise { + return await invoke("fetch_registries"); + } + + async resetDefaultRegistry(): Promise { + return await invoke("reset_default_registry"); + } } export default BinariesController; diff --git a/src/controller/dockerController.ts b/src/controller/dockerController.ts index e1b82fd1..8feb7786 100644 --- a/src/controller/dockerController.ts +++ b/src/controller/dockerController.ts @@ -1,3 +1,5 @@ +import type { Registry } from "modules/settings/types"; + import downloadServiceStream from "../modules/service/api/downloadServiceStream"; import type { Service, ServiceDocker } from "../modules/service/types"; import api from "../shared/api/v1"; @@ -31,6 +33,7 @@ class DockerController extends AbstractServiceController { afterSuccess?: () => void; }): Promise { console.log(`Downloading service ${serviceId}`); + useSettingStore.getState().setServiceDownloadProgress(serviceId, "docker", 0); await downloadServiceStream( serviceId, (error) => { @@ -62,7 +65,7 @@ class DockerController extends AbstractServiceController { } async getServices(): Promise { - const response = await api().get("v1/services"); + const response = await api().get("v1/services/"); return response.data; } @@ -76,23 +79,40 @@ class DockerController extends AbstractServiceController { } async getSystemStats(): Promise> { - const response = await api().get("v1/stats-all"); + const response = await api().get("v1/stats-all/"); return response.data; } async getGPUStats(): Promise> { - const response = await api().get("v1/gpu-stats-all"); + const response = await api().get("v1/gpu-stats-all/"); return response.data; } async getInterfaces(): Promise { - const response = await api().get("v1/interfaces"); + const response = await api().get("v1/interfaces/"); return response.data; } async addService(service: Service): Promise { await api().post("v1/services/", service); } + + async addRegistry(registry: Registry): Promise { + await api().post(`v1/registries/`, registry); + } + + async deleteRegistry(registry: Registry): Promise { + await api().delete(`v1/registries/?url=${registry.url}`); + } + + async fetchRegistries(): Promise { + const response = await api().get(`v1/registries/`); + return response.data; + } + + async resetDefaultRegistry(): Promise { + throw new Error("Method not implemented."); + } } export default DockerController; diff --git a/src/controller/serviceController.ts b/src/controller/serviceController.ts index 6989cd92..aaa3a9d6 100644 --- a/src/controller/serviceController.ts +++ b/src/controller/serviceController.ts @@ -1,4 +1,5 @@ import type { Service } from "../modules/service/types"; +import type { Registry } from "../modules/settings/types"; import type { Interface } from "../shared/helpers/interfaces"; import useSettingStore from "../shared/store/setting"; @@ -33,6 +34,10 @@ interface IServiceController { getGPUStats(serviceType: Service["serviceType"]): Promise; getInterfaces(serviceType: Service["serviceType"]): Promise; addService(service: Service, serviceType: Service["serviceType"]): Promise; + addRegistry(registry: Registry, serviceType: Service["serviceType"]): Promise; + deleteRegistry(registry: Registry, serviceType: Service["serviceType"]): Promise; + fetchRegistries(serviceType: Service["serviceType"]): Promise; + resetDefaultRegistry(serviceType: Service["serviceType"]): Promise; } class ServiceController implements IServiceController { @@ -104,7 +109,6 @@ class ServiceController implements IServiceController { if (serviceType === "docker") { useSettingStore.getState().addServiceAsDownloading(serviceId); await this.dockerController.download({ serviceId, afterSuccess }); - useSettingStore.getState().removeServiceAsDownloading(serviceId); } else if (serviceType === "binary") { useSettingStore.getState().addServiceAsDownloading(serviceId); await this.binariesController.download({ @@ -208,6 +212,40 @@ class ServiceController implements IServiceController { await this.binariesController.addService(service); } } + + async addRegistry(registry: Registry, serviceType: Service["serviceType"]): Promise { + if (serviceType === "docker") { + await this.dockerController.addRegistry(registry); + } else if (serviceType === "binary") { + await this.binariesController.addRegistry(registry); + } + } + + async deleteRegistry(registry: Registry, serviceType: Service["serviceType"]): Promise { + if (serviceType === "docker") { + await this.dockerController.deleteRegistry(registry); + } else if (serviceType === "binary") { + await this.binariesController.deleteRegistry(registry); + } + } + + async fetchRegistries(serviceType: Service["serviceType"]): Promise { + if (serviceType === "docker") { + return await this.dockerController.fetchRegistries(); + } else if (serviceType === "binary") { + return await this.binariesController.fetchRegistries(); + } else { + return []; + } + } + + async resetDefaultRegistry(serviceType: Service["serviceType"]): Promise { + if (serviceType === "docker") { + return await this.dockerController.resetDefaultRegistry(); + } else if (serviceType === "binary") { + return await this.binariesController.resetDefaultRegistry(); + } + } } export default ServiceController; diff --git a/src/modules/prem-audio/components/PremAudio.tsx b/src/modules/prem-audio/components/PremAudio.tsx index 21a703a0..002fe486 100644 --- a/src/modules/prem-audio/components/PremAudio.tsx +++ b/src/modules/prem-audio/components/PremAudio.tsx @@ -23,8 +23,8 @@ const PremAudio = () => { ); }; diff --git a/src/modules/prem-audio/components/PremAudioContainer.tsx b/src/modules/prem-audio/components/PremAudioContainer.tsx index 8d9b11c3..c2bf8e04 100644 --- a/src/modules/prem-audio/components/PremAudioContainer.tsx +++ b/src/modules/prem-audio/components/PremAudioContainer.tsx @@ -1,6 +1,4 @@ -import clsx from "clsx"; import { useState } from "react"; -import { useMediaQuery } from "usehooks-ts"; import type { PremAudioContainerProps } from "../types"; @@ -16,19 +14,18 @@ const PremAudioContainer = ({ historyId, }: PremAudioContainerProps) => { const [rightSidebar, setRightSidebar] = useState(false); - const responsiveMatches = useMediaQuery("(max-width: 767px)"); const [hamburgerMenuOpen, setHamburgerMenu] = useState(true); return (
-
- -
+
diff --git a/src/modules/prem-audio/components/PremAudioLeftSidebar.tsx b/src/modules/prem-audio/components/PremAudioLeftSidebar.tsx index 76329260..5fd4f2d1 100644 --- a/src/modules/prem-audio/components/PremAudioLeftSidebar.tsx +++ b/src/modules/prem-audio/components/PremAudioLeftSidebar.tsx @@ -1,6 +1,6 @@ import clsx from "clsx"; import { orderBy } from "lodash"; -import { Link, useNavigate, useParams } from "react-router-dom"; +import { Link, useNavigate } from "react-router-dom"; import DeleteIcon from "shared/components/DeleteIcon"; import LeftSidebar from "shared/components/LeftSidebar"; import NoPrompts from "shared/components/NoPrompts"; @@ -9,9 +9,14 @@ import type { HamburgerMenuProps } from "shared/types"; import { useMediaQuery } from "usehooks-ts"; import { shallow } from "zustand/shallow"; -const PremAudioLeftSidebar = ({ setHamburgerMenu }: HamburgerMenuProps) => { +const PremAudioLeftSidebar = ({ + hamburgerMenuOpen, + setHamburgerMenu, + serviceId, + serviceType, + historyId, +}: HamburgerMenuProps) => { const navigate = useNavigate(); - const { serviceId, serviceType, historyId } = useParams(); const responsiveMatches = useMediaQuery("(max-width: 767px)"); const { history, deleteHistory } = usePremAudioStore( (state) => ({ @@ -29,7 +34,12 @@ const PremAudioLeftSidebar = ({ setHamburgerMenu }: HamburgerMenuProps) => { }; return ( - + {history.length === 0 && }
    diff --git a/src/modules/prem-chat/components/Header.tsx b/src/modules/prem-chat/components/Header.tsx index c466cc24..f39c994c 100644 --- a/src/modules/prem-chat/components/Header.tsx +++ b/src/modules/prem-chat/components/Header.tsx @@ -1,52 +1,63 @@ import hamburgerMenu from "assets/images/hamburger-menu.svg"; import setting from "assets/images/setting.svg"; import WarningModal from "modules/service/components/WarningModal"; -import { useState } from "react"; +import { type ForwardedRef, forwardRef, useState } from "react"; import WarningIcon from "shared/components/WarningIcon"; import type { HeaderProps } from "shared/types"; -const Header = ({ setRightSidebar, hamburgerMenuOpen, setHamburgerMenu, title }: HeaderProps) => { - const [open, setIsOpen] = useState(false); +import PetalsBanner from "../../../shared/components/PetalsBanner"; - const closeModal = () => { - setIsOpen(false); - }; +const Header = forwardRef( + ( + { setRightSidebar, hamburgerMenuOpen, setHamburgerMenu, title, isPetals = false }: HeaderProps, + ref: ForwardedRef, + ) => { + const [open, setIsOpen] = useState(false); - const openModal = () => { - setIsOpen(true); - }; + const closeModal = () => { + setIsOpen(false); + }; - return ( - <> -
    -
    - -

    {title}

    -
    -
    - - + const openModal = () => { + setIsOpen(true); + }; + + return ( + <> + {isPetals ? : null} +
    +
    + +

    {title}

    +
    +
    + + +
    -
    - } - /> - - ); -}; + } + /> + + ); + }, +); export default Header; diff --git a/src/modules/prem-chat/components/InputBox.tsx b/src/modules/prem-chat/components/InputBox.tsx deleted file mode 100644 index e6dc59bc..00000000 --- a/src/modules/prem-chat/components/InputBox.tsx +++ /dev/null @@ -1,29 +0,0 @@ -import Send from "assets/images/send.svg"; -import type { ForwardedRef } from "react"; -import { forwardRef } from "react"; - -import type { InputBoxProps } from "../types"; - -const InputBox = forwardRef((props: InputBoxProps, ref: ForwardedRef) => { - const { question, setQuestion, disabled, placeholder } = props; - return ( -
    - setQuestion(e.target.value)} - disabled={disabled} - placeholder={placeholder} - ref={ref} - autoFocus - /> - -
    - ); -}); - -export default InputBox; diff --git a/src/modules/prem-chat/components/PremChat.tsx b/src/modules/prem-chat/components/PremChat.tsx index 83ec04b9..9859f489 100644 --- a/src/modules/prem-chat/components/PremChat.tsx +++ b/src/modules/prem-chat/components/PremChat.tsx @@ -7,8 +7,8 @@ import type { Service } from "../../service/types"; import PremChatContainer from "./PremChatContainer"; function PremChat() { - const { chatId, serviceId, serviceType } = useParams<{ - chatId: string; + const { historyId, serviceId, serviceType } = useParams<{ + historyId: string; serviceId: string; serviceType: Service["serviceType"]; }>(); @@ -21,10 +21,11 @@ function PremChat() { return ( ); } diff --git a/src/modules/prem-chat/components/PremChatContainer.tsx b/src/modules/prem-chat/components/PremChatContainer.tsx index 797d5813..6cda7803 100644 --- a/src/modules/prem-chat/components/PremChatContainer.tsx +++ b/src/modules/prem-chat/components/PremChatContainer.tsx @@ -1,31 +1,35 @@ -import clsx from "clsx"; +import Send from "assets/images/send.svg"; import { useEffect, useRef, useState } from "react"; import BotReply from "shared/components/BotReply"; import UserReply from "shared/components/UserReply"; import usePremChatStream from "shared/hooks/usePremChatStream"; import { useMediaQuery, useWindowSize } from "usehooks-ts"; +import Spinner from "../../../shared/components/Spinner"; +import useAutosizeTextArea from "../../../shared/hooks/useAutosizeTextarea"; import type { Message, PremChatContainerProps } from "../types"; import Header from "./Header"; -import InputBox from "./InputBox"; import PremChatSidebar from "./PremChatSidebar"; import RegenerateButton from "./RegenerateButton"; import RightSidebar from "./RightSidebar"; const PremChatContainer = ({ - chatId, + historyId, serviceId, serviceType, serviceName, + isPetals, }: PremChatContainerProps) => { const model = serviceId; const [rightSidebar, setRightSidebar] = useState(false); const [hamburgerMenuOpen, setHamburgerMenu] = useState(true); const chatMessageListRef = useRef(null); - const inputRef = useRef(null); + const textAreaRef = useRef(null); const { height } = useWindowSize(); const responsiveMatches = useMediaQuery("(min-width: 768px)"); + const headerRef = useRef(null); + const [headerVisibleHeight, setHeaderVisibleHeight] = useState(0); const { chatMessages, @@ -38,7 +42,25 @@ const PremChatContainer = ({ resetPromptTemplate, resetChatServiceUrl, abort, - } = usePremChatStream(serviceId, serviceType, chatId || null); + } = usePremChatStream(serviceId, serviceType, historyId || null); + + const handleScroll = () => { + if (headerRef.current) { + const rect = (headerRef.current as any).getBoundingClientRect(); + const windowHeight = window.innerHeight; + const visibleHeight = Math.min(rect.bottom, windowHeight) - Math.max(rect.top, 0); + setHeaderVisibleHeight(visibleHeight); + } + }; + + useEffect(() => { + window.addEventListener("scroll", handleScroll); + return () => { + window.removeEventListener("scroll", handleScroll); + }; + }, []); + + useAutosizeTextArea(textAreaRef.current, question); useEffect(() => { if (chatMessageListRef.current) { @@ -47,8 +69,8 @@ const PremChatContainer = ({ }, [chatMessages]); useEffect(() => { - if (!isLoading && inputRef.current) { - inputRef.current.focus(); + if (!isLoading && textAreaRef.current) { + textAreaRef.current.focus(); } }, [isLoading]); @@ -57,32 +79,48 @@ const PremChatContainer = ({ return () => { abort(); }; - }, [abort, chatId]); + }, [abort, historyId]); + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + e.preventDefault(); + onSubmit(e); + } + }; return (
    -
    -
    - -
    +
    +
    {chatMessages.map((message: Message, index: number) => ( @@ -97,26 +135,37 @@ const PremChatContainer = ({
    -
    +
    {chatMessages.length > 0 && !isLoading && !isError && ( -
    - -
    + )} -
    - + +
    +