
Commit

Merge branch 'add_tensornet_tool_pip_package' into 'master'
add merge embedding

See merge request deep-learning/tensornet!14
jiangxinglei committed Jul 22, 2024
2 parents c036438 + e86bc40 commit ebbb751
Showing 11 changed files with 374 additions and 25 deletions.
61 changes: 61 additions & 0 deletions .github/workflows/tensornet-tools.yml
@@ -0,0 +1,61 @@
name: Build Tensornet Tools

on:
  push:
    tags:
      - '*tool'
  pull_request:

jobs:
  tn_tools_build:
    runs-on: ubuntu-latest
    steps:
      - name: checkout repository
        uses: actions/checkout@v4

      - uses: mamba-org/setup-micromamba@v1
        with:
          micromamba-version: '1.5.8-0'
          environment-file: config/tn_build.yaml
          init-shell: bash
          cache-downloads: true
          post-cleanup: 'none'
      - name: Run custom command in micromamba environment
        run: |
          python setup-tn-tools.py bdist_wheel
          twine check dist/*
        shell: micromamba-shell {0}

      - name: Store wheels
        uses: actions/upload-artifact@v4
        with:
          path: dist/
          retention-days: 7

  publish-to-pypi:
    name: Upload to pypi
    if: github.repository == 'Qihoo360/tensornet' && startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
    needs:
      - tn_tools_build
    runs-on: ubuntu-latest
    permissions:
      id-token: write # mandatory for PyPI trusted publishing
      contents: write # mandatory for creating a GitHub release
    steps:
      - name: Download wheels
        uses: actions/download-artifact@v4
        with:
          merge-multiple: true
          path: dist/
      - name: Publish wheels to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          # push v*.rc* tags to test.pypi.org
          repository-url: ${{ contains(github.ref, '.rc') && 'https://test.pypi.org/legacy/' || 'https://upload.pypi.org/legacy/' }}
          print-hash: true
      - name: Create a draft release
        uses: softprops/action-gh-release@v2
        with:
          draft: true
          prerelease: ${{ contains(github.ref, '.rc') }}
          generate_release_notes: true
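
For reference, a minimal sketch of how a release would be driven under this workflow; the tag names below are illustrative only, not taken from the repository:

# A tag matching '*tool' triggers tn_tools_build; a tag containing '.rc' is routed
# by the publish step to test.pypi.org instead of upload.pypi.org.
git tag 1.0.0-tool && git push origin 1.0.0-tool          # production wheel -> pypi.org
git tag 1.0.1.rc1-tool && git push origin 1.0.1.rc1-tool  # release candidate -> test.pypi.org
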
18 changes: 13 additions & 5 deletions manager
@@ -144,7 +144,8 @@ bump_dev_version() {

 bump_release_version() {
 hash bumpversion >/dev/null || die "cannot find bumpversion command"
-local mode=${1:-prod} release_part=''
+local mode=${1:-prod} release_part='' tag_name_option='' tool_deploy=false
+[[ $# -gt 1 && $2 == 'tool' ]] && tool_deploy=true
 release_part=$(bumpversion --allow-dirty build --dry-run --list | awk -vFS=. '/^current_version=/ && $NF ~ "^[a-z]" { print $NF }')

 case "$mode" in
@@ -160,11 +161,18 @@ bump_release_version() {
 case "${release_part-}" in
 (dev*)
 local release_version=''
-release_version=$(bumpversion --allow-dirty build --dry-run --list | grep '^current_version=' | cut -s -d - -f 1 | cut -s -d = -f 2)
-bumpversion --commit --tag --new-version "$release_version" release # prod release, skip rc
+release_version=$(bumpversion --allow-dirty release --dry-run --list | grep '^current_version=' | cut -s -d = -f 2 | sed -E 's/\.[^0-9]+([0-9]*)$//')
+[[ $tool_deploy == true ]] && tag_name_option="--tag-name ${release_version}-tool"
+bumpversion --commit --tag ${tag_name_option} --new-version "$release_version" release # prod release, skip rc
 ;;
-(rc*) bumpversion --commit --tag release ;;
-(''|post*) bumpversion --commit --tag build ;;
+(rc*)
+[[ $tool_deploy == true ]] && tag_name_option="--tag-name $(bumpversion --allow-dirty release --dry-run --list | grep '^new_version=' | cut -s -d = -f 2)-tool"
+bumpversion --commit --tag ${tag_name_option} release
+;;
+(''|post*)
+[[ $tool_deploy == true ]] && tag_name_option="--tag-name $(bumpversion --allow-dirty build --dry-run --list | grep '^new_version=' | cut -s -d = -f 2)-tool"
+bumpversion --commit --tag ${tag_name_option} build
+;;
 (*) die "Unknown release part ($release_part) for bump_release_version" ;;
 esac
 ;;
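
To make the new tool_deploy branch concrete, here is a hedged sketch of the bumpversion call it ends up issuing; the version 2.1.0.dev3 is assumed purely for illustration:

# With current_version=2.1.0.dev3 and the script invoked with the extra 'tool' argument,
# the dev* branch strips the pre-release suffix and tags the release commit "<version>-tool":
bumpversion --commit --tag --tag-name 2.1.0-tool --new-version 2.1.0 release
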
28 changes: 28 additions & 0 deletions setup-tn-tools.py
@@ -0,0 +1,28 @@
import os
from setuptools import setup, find_packages
from importlib.machinery import SourceFileLoader

# load version.py directly via importlib to avoid importing the package's compiled .so files
_version = SourceFileLoader('version', 'tensornet/version.py').load_module()
version = _version.VERSION


setup(
    name='qihoo-tensornet-tools',
    version=version,
    description='tools for tensornet',
    long_description='Tools for tensornet, e.g. merging/resizing sparse or dense tables, including external embeddings.',
    long_description_content_type='text/markdown',
    author='jiangxinglei',
    author_email='[email protected]',
    url='https://github.com/Qihoo360/tensornet',
    packages=["tensornet-tools"],
    package_data={'tensornet-tools': ['bin/*', 'config/*', 'python/*']},
    python_requires='>=3.7',
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python :: 3.7'
    ],
)
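
Assuming the same environment as the CI job above, the tools wheel can be built, checked, and installed locally with:

# Build the qihoo-tensornet-tools wheel and verify its metadata (mirrors the workflow steps)
python setup-tn-tools.py bdist_wheel
twine check dist/*
pip install dist/qihoo_tensornet_tools-*.whl   # install the freshly built wheel
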
File renamed without changes.
94 changes: 94 additions & 0 deletions tensornet-tools/bin/tn_tools.sh
@@ -0,0 +1,94 @@
#!/usr/bin/env bash

WORKSPACE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PYTHON_DIR=${WORKSPACE_DIR}/../python
TMP_PACKAGE_DIR=$(mktemp -d)
TN_TOOL_ENV_NAME=tn_tool
TN_TOOL_TGZ=${TMP_PACKAGE_DIR}/${TN_TOOL_ENV_NAME}.tar.gz

: ${SPARK_HOME:=/opt/spark3}

_die() {
local err=$? err_fmt=
(( err )) && err_fmt=" (err=$err)" || err=1
printf >&2 "[ERROR]$err_fmt %s\n" "$*"
exit $err
}

_check_spark_env(){

if [[ ! -d ${SPARK_HOME} ]] || [[ ! -e ${SPARK_HOME}/bin/spark-submit ]];then
_die "no valid spark path, should export valid SPARK_HOME"
fi

spark_major_version=$($SPARK_HOME/bin/spark-submit --version 2>&1 | grep version | awk -F"version" '{print $2}' | head -1 | sed 's/ //g' | awk -F. '{print $1}')

if [[ -z ${spark_major_version} ]] || [[ $spark_major_version -lt 3 ]];then
_die "invalid spark version. should be >= 3"
fi

}

_prepare_mamba_env(){
if ! type micromamba >/dev/null 2>&1;then
HTTPS_PROXY=${PROXY_URL:=${HTTPS_PROXY}} "${SHELL}" <(curl -L micro.mamba.pm/install.sh)
fi
_mamba_source
[[ -z ${NEXUS3_HEADER} ]] || {
${MAMBA_EXE} config set --file "${MAMBA_ROOT_PREFIX}/.mambarc" channel_alias "${NEXUS3_HEADER}/conda"
}
micromamba create -y -f "${WORKSPACE_DIR}/config/tn_tool_env.yaml"
micromamba activate ${TN_TOOL_ENV_NAME}
TN_TOOL_ENV_DIR=$(micromamba env list | grep "${TN_TOOL_ENV_NAME}" | awk '{print $NF}')
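# pack the activated env into a tarball so Spark can ship it to executors via spark.archives (used by the resize_* functions below)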
conda-pack --prefix ${TN_TOOL_ENV_DIR} -o ${TN_TOOL_TGZ}
}

start_merge_sparse(){

${SPARK_HOME}/bin/spark-submit --executor-memory 8g --driver-memory 10g --py-files ${PYTHON_DIR}/utils.py ${PYTHON_DIR}/merge_sparse.py "$@"

}

start_resize_sparse(){
_prepare_mamba_env

${SPARK_HOME}/bin/spark-submit --conf spark.executor.memory=10g --conf spark.archives=file://${TN_TOOL_TGZ}#envs --conf spark.pyspark.driver.python=${TN_TOOL_ENV_DIR}/bin/python --conf spark.pyspark.python=envs/bin/python --py-files ${PYTHON_DIR}/utils.py ${PYTHON_DIR}/resize_sparse.py "$@"
}

start_resize_dense(){
_prepare_mamba_env

${SPARK_HOME}/bin/spark-submit --conf spark.executor.memory=10g --conf spark.archives=file://${TN_TOOL_TGZ}#envs --conf spark.pyspark.driver.python=${TN_TOOL_ENV_DIR}/bin/python --conf spark.pyspark.python=envs/bin/python --py-files ${PYTHON_DIR}/utils.py ${PYTHON_DIR}/resize_dense.py "$@"
}

_check_spark_env

case "${1-}" in
(merge-sparse)
shift 1
start_merge_sparse "$@"
;;
(resize-sparse)
shift 1
start_resize_sparse "$@"
;;
(resize-dense)
shift 1
start_resize_dense "$@"
;;
(''|help)
cmd=$(basename -- "$0")
cat <<-END
Usage:
$cmd [help] - Print this help.
$cmd merge-sparse [-i/--input input_path] [-o/--output output_path] [-f/--format file_format] [-n/--number number] [-b/--bracket] - merge all tensornet-generated sparse files into one HDFS directory.
$cmd resize-sparse [-i/--input input_path] [-o/--output output_path] [-f/--format file_format] [-n/--number number] - change current sparse parallelism to another size.
$cmd resize-dense [-i/--input input_path] [-o/--output output_path] [-n/--number number] - change current dense parallelism to another size.
END
;;
(*) _die "Unknown command: $1" ;;
esac
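
A hypothetical invocation of the wrapper, assuming SPARK_HOME points at a Spark 3 installation; the HDFS paths below are placeholders:

export SPARK_HOME=/opt/spark3
./tn_tools.sh merge-sparse  -i hdfs://nn/user/tn/sparse_table -o hdfs://nn/user/tn/merged  -f bin
./tn_tools.sh resize-sparse -i hdfs://nn/user/tn/sparse_table -o hdfs://nn/user/tn/resized -f bin -n 8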

@@ -1,8 +1,7 @@
-name: tn_build
+name: tn_tool
 channels:
   - conda-forge
 dependencies:
   - python=3.8
-  - nomkl
-  - openssl>=3
+  - hdfs3
+  - pyarrow==12.0.1
54 changes: 54 additions & 0 deletions tensornet-tools/python/merge_extra_embedding.py
@@ -0,0 +1,54 @@
#coding=utf-8
import sys
import argparse
import os
from pyspark import SparkContext, SparkConf
from pyspark.sql import *
from pyspark.sql.functions import col, udf, lit
from pyspark.sql import functions as F
from pyspark.sql.types import *
from utils import *


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str, help="sparse table input path")
    parser.add_argument("-o", "--output", type=str, help="merged file output path")
    parser.add_argument("-f", "--format", type=str, help="input file format, 'txt' or 'bin'")
    parser.add_argument("-e", "--extra", type=str, help="extra embedding file path")
    args = parser.parse_args()
    return args


def main(args):
    spark = SparkSession.builder \
        .appName("[spark][merge extra embedding]") \
        .master('yarn') \
        .enableHiveSupport() \
        .getOrCreate()

    sc = spark.sparkContext
    output_bc_value = sc.broadcast(args.output)
    format_bc_value = sc.broadcast(args.format)
    path_info = SparseTablePathInfo(args.input)
    source_rank_num = path_info.total_rank_num
    handle_names = path_info.handles
    sparse_table_parent = path_info.sparse_parent
    handle_names_bc_value = sc.broadcast(handle_names)
    number_bc_value = sc.broadcast(source_rank_num)
    get_sign_partition_key_udf = udf(get_sign_partition_key, IntegerType())

    dims_df = load_sparse_table_to_df(sc, args.input, args.format) \
        .withColumn('par_key', get_sign_partition_key_udf(col('sign'), lit(source_rank_num)))

    # Each line of the extra file appears to be "<handle>,<src_sign>:<new_sign>" (split on ',' then ':');
    # both signs are mapped to partition keys and the rows are keyed by (handle, src partition key).
    extra_data_rdd = (
        sc.textFile(args.extra)
        .map(lambda x: (x.split(',')[0],
                        x.split(',')[1].split(':')[0],
                        get_sign_partition_key(x.split(',')[1].split(':')[0], source_rank_num),
                        x.split(',')[1].split(':')[1],
                        get_sign_partition_key(x.split(',')[1].split(':')[1], source_rank_num)))
        .map(lambda x: ((x[0], x[2]), x))
    )

    distinct_key_list = extra_data_rdd.keys().distinct().collect()
    repartition_num = len(distinct_key_list)

    dims_df.unionAll(
        extra_data_rdd.map(lambda x: (distinct_key_list.index(x[0]), x[1]))
                      .partitionBy(repartition_num)
                      .mapPartitions(lambda p: get_weight_for_extra_embedding(p, source_rank_num, sparse_table_parent))
                      .toDF(sparse_df_schema)
    ).rdd.map(lambda row: (row[7], row)) \
        .partitionBy(source_rank_num * BLOCK_NUM) \
        .foreachPartition(lambda p: resize_partition(p, output_bc_value, format_bc_value, number_bc_value, handle_names_bc_value))


if __name__ == '__main__':
    args = parse_args()
    main(args)
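
The new merge_extra_embedding.py script is not dispatched by the tn_tools.sh wrapper in this change; a hypothetical direct submission, mirroring how the other tools are launched, could look like this (paths are placeholders):

# -e points at the extra embedding file read by the script above
${SPARK_HOME}/bin/spark-submit \
    --py-files tensornet-tools/python/utils.py \
    tensornet-tools/python/merge_extra_embedding.py \
    -i hdfs://nn/user/tn/sparse_table -o hdfs://nn/user/tn/merged \
    -f bin -e hdfs://nn/user/tn/extra_embedding.txt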
File renamed without changes.
File renamed without changes.
File renamed without changes.
