Skip to content

Commit

Permalink
September collaborate (#193)
Browse files Browse the repository at this point in the history
* add get service state

* update get service state

* update get service state

* update get service state

* update get service state

* update get service state

* update get service state

* fix mpc sdk

* fix psi

* fix psi

* fix psi

* fix psi

* fix fitTransform

* fix psi

* fix psi

* fix psi

* fix psi

* fix psi

* fix psi

* fix psi

* fix psi

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix pir

* fix psi

* fix pir
  • Loading branch information
likehabits authored Sep 28, 2023
1 parent ddc448c commit fabdd98
Show file tree
Hide file tree
Showing 47 changed files with 779 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@
import java.util.Map;

public enum ModelTypeEnum {
V_XGBOOST(2,"taskModel-v_xgboost",0,"hetero_xgb.ftl","hetero_xgb_infer.ftl"),
TRANSVERSE_LR(3,"HFL_logistic_regression",1,"homo_lr.ftl","homo_lr_infer.ftl"),
MPC_LR(4,"taskModel-mpc_lr",1,null,null),
HETERO_LR(5,"VFL_logistic_regression",0,"hetero_lr.ftl","hetero_lr_infer.ftl"),
CLASSIFICATION_BINARY(6,"taskModel-nn_classification",1,"homo_nn_binary.ftl","homo_nn_binary_infer.ftl"),
REGRESSION_BINARY(7,"taskModel-nn_regression",1,"homo_nn_binary.ftl","homo_nn_binary_infer.ftl"),
HFL_LINEAR_REGRESSION(8,"HFL_linear_regression",1,"homo_lr.ftl","homo_lr_infer.ftl"),
VFL_LINEAR_REGRESSION(9,"VFL_linear_regression",0,"hetero_lr.ftl","hetero_lr_infer.ftl"),
V_XGBOOST(2,"taskModel-v_xgboost",0,"hetero_xgb.ftl","hetero_xgb_infer.ftl","hetero_fitTransform.ftl"),
TRANSVERSE_LR(3,"HFL_logistic_regression",1,"homo_lr.ftl","homo_lr_infer.ftl","homo_fitTransform.ftl"),
MPC_LR(4,"taskModel-mpc_lr",1,null,null,null),
HETERO_LR(5,"VFL_logistic_regression",0,"hetero_lr.ftl","hetero_lr_infer.ftl","hetero_fitTransform.ftl"),
CLASSIFICATION_BINARY(6,"taskModel-nn_classification",1,"homo_nn_binary.ftl","homo_nn_binary_infer.ftl","homo_fitTransform.ftl"),
REGRESSION_BINARY(7,"taskModel-nn_regression",1,"homo_nn_binary.ftl","homo_nn_binary_infer.ftl","homo_fitTransform.ftl"),
HFL_LINEAR_REGRESSION(8,"HFL_linear_regression",1,"homo_lr.ftl","homo_lr_infer.ftl","homo_fitTransform.ftl"),
VFL_LINEAR_REGRESSION(9,"VFL_linear_regression",0,"hetero_lr.ftl","hetero_lr_infer.ftl","hetero_fitTransform.ftl"),
;
private Integer type;
private Integer trainType;
private String typeName;
private String modelFtlPath;
private String inferFtlPath;
private String fitTransformFtlPath;

public static Map<Integer, ModelTypeEnum> MODEL_TYPE_MAP=new HashMap(){
{
Expand All @@ -27,12 +28,13 @@ public enum ModelTypeEnum {
}
};

ModelTypeEnum(Integer type, String typeName, Integer trainType, String modelFtlPath, String inferFtlPath) {
ModelTypeEnum(Integer type, String typeName, Integer trainType, String modelFtlPath, String inferFtlPath,String fitTransformFtlPath) {
this.type = type;
this.typeName = typeName;
this.trainType = trainType;
this.modelFtlPath = modelFtlPath;
this.inferFtlPath = inferFtlPath;
this.fitTransformFtlPath = fitTransformFtlPath;
}

public String getTypeName() {
Expand Down Expand Up @@ -74,4 +76,12 @@ public String getInferFtlPath() {
public void setInferFtlPath(String inferFtlPath) {
this.inferFtlPath = inferFtlPath;
}

public String getFitTransformFtlPath() {
return fitTransformFtlPath;
}

public void setFitTransformFtlPath(String fitTransformFtlPath) {
this.fitTransformFtlPath = fitTransformFtlPath;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@ private void runComponentTask(Channel channel, TaskParam<TaskComponentParam> tas
taskContentParam.getFreemarkerMap().put("model",taskContentParam.getModelType().getTypeName());
}
String freemarkerContent;
if (StringUtils.isEmpty(taskContentParam.getTemplatesContent())){
freemarkerContent = FreemarkerTemplate.getInstance().generateTemplateStr(taskContentParam.getFreemarkerMap(),taskContentParam.isInfer()?taskContentParam.getModelType().getInferFtlPath():taskContentParam.getModelType().getModelFtlPath());
if (taskContentParam.isFitTransform()){
freemarkerContent = FreemarkerTemplate.getInstance().generateTemplateStr(taskContentParam.getFreemarkerMap(),taskContentParam.getModelType().getFitTransformFtlPath());
}else {
freemarkerContent = taskContentParam.isUntreated()?FreemarkerTemplate.getInstance().generateTemplateStrFreemarkerContent(taskContentParam.isInfer()?taskContentParam.getModelType().getInferFtlPath():taskContentParam.getModelType().getModelFtlPath(),taskContentParam.getTemplatesContent(),taskContentParam.getFreemarkerMap()):taskContentParam.getTemplatesContent();
if (StringUtils.isEmpty(taskContentParam.getTemplatesContent())){
freemarkerContent = FreemarkerTemplate.getInstance().generateTemplateStr(taskContentParam.getFreemarkerMap(),taskContentParam.isInfer()?taskContentParam.getModelType().getInferFtlPath():taskContentParam.getModelType().getModelFtlPath());
}else {
freemarkerContent = taskContentParam.isUntreated()?FreemarkerTemplate.getInstance().generateTemplateStrFreemarkerContent(taskContentParam.isInfer()?taskContentParam.getModelType().getInferFtlPath():taskContentParam.getModelType().getModelFtlPath(),taskContentParam.getTemplatesContent(),taskContentParam.getFreemarkerMap()):taskContentParam.getTemplatesContent();
}
}
log.info("start taskParam:{} - freemarkerContent:{}",taskParam,freemarkerContent);
Common.ParamValue componentParamsParamValue = Common.ParamValue.newBuilder().setValueString(ByteString.copyFrom(JSONObject.toJSONString(JSONObject.parseObject(freemarkerContent), SerializerFeature.WriteMapNullValue).getBytes(StandardCharsets.UTF_8))).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,13 @@ public void continuouslyObtainTaskStatus(Channel channel,Common.TaskContext task
}
param.setError(sb.toString());
isContinue = false;
}else {
long success = getNumberOfSuccessfulTasks(key,cacheService);
log.info("taskid:{} - requestId:{} - num:{} - success:{}",param.getTaskId(),param.getRequestId(),partyCount,success);
if (partyCount <= success){
isContinue = false;
}
}
}
long success = getNumberOfSuccessfulTasks(key,cacheService);
log.info("taskid:{} - requestId:{} - num:{} - success:{}",param.getTaskId(),param.getRequestId(),partyCount,success);
if (partyCount <= success){
isContinue = false;
}
}
Thread.sleep(1000L);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void execute(Channel channel, TaskParam taskParam) {
taskParam.setSuccess(true);
}else {
taskParam.setSuccess(false);
taskParam.setError(response.getMsgInfo());
taskParam.setError(response.getMsgInfoBytes().toStringUtf8());
}
log.info("kill end response:{}",response.toString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import primihub.rpc.Common;

import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;


public class AbstractPsiGRPCExecute extends AbstractGRPCExecuteFactory {
Expand Down Expand Up @@ -67,15 +69,22 @@ private void runPsi(Channel channel, TaskParam<TaskPSIParam> param){
paramsBuilder.putParamMap("sync_result_to_server",syncResultToServerParamValue).putParamMap("server_outputFullFilname",serverOutputFullFilnameParamValue);
}
Common.TaskContext taskBuild = assembleTaskContext(param);
Map<String, Common.Dataset> datasetMap = new HashMap<>();
datasetMap.put("SERVER",Common.Dataset.newBuilder().putData("SERVER", param.getTaskContentParam().getServerData()).build());
datasetMap.put("CLIENT",Common.Dataset.newBuilder().putData("CLIENT", param.getTaskContentParam().getClientData()).build());
String code = "";
if (param.getTaskContentParam().getPsiTag().equals(2)){
datasetMap.put("TEE_COMPUTE",Common.Dataset.newBuilder().putData("TEE_COMPUTE", param.getTaskContentParam().getTeeData()).build());
code = "psi";
}
Common.Task task= Common.Task.newBuilder()
.setType(Common.TaskType.PSI_TASK)
.setParams(paramsBuilder.build())
.setName("psiTask")
.setTaskInfo(taskBuild)
.setLanguage(Common.Language.PROTO)
.setCode(ByteString.copyFrom("".getBytes(StandardCharsets.UTF_8)))
.putPartyDatasets("SERVER",Common.Dataset.newBuilder().putData("SERVER", param.getTaskContentParam().getServerData()).build())
.putPartyDatasets("CLIENT",Common.Dataset.newBuilder().putData("CLIENT", param.getTaskContentParam().getClientData()).build())
.setCode(ByteString.copyFrom(code.getBytes(StandardCharsets.UTF_8)))
.putAllPartyDatasets(datasetMap)
.build();
log.info("grpc Common.Task : \n{}",task.toString());
PushTaskRequest request=PushTaskRequest.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public class TaskComponentParam{

private boolean infer = false;

private boolean fitTransform = false;

public ModelTypeEnum getModelType() {
return modelType;
}
Expand Down Expand Up @@ -63,6 +65,14 @@ public void setInfer(boolean infer) {
this.infer = infer;
}

public boolean isFitTransform() {
return fitTransform;
}

public void setFitTransform(boolean fitTransform) {
this.fitTransform = fitTransform;
}

@Override
public String toString() {
return "ComponentTaskParam{" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.primihub.sdk.task.factory.AbstractPsiGRPCExecute;

import java.util.Arrays;
import java.util.List;

/**
* psi 隐私求交组装类
Expand All @@ -25,6 +26,7 @@ public class TaskPSIParam {
/**
* 0、ECDH
* 1、KKRT
* 2、TEE
* 默认0
*/
private Integer psiTag = 0;
Expand All @@ -49,6 +51,8 @@ public class TaskPSIParam {
*/
private String serverOutputFullFilname;

private String teeData;

public String getClientData() {
return clientData;
}
Expand Down Expand Up @@ -121,6 +125,14 @@ public void setServerOutputFullFilname(String serverOutputFullFilname) {
this.serverOutputFullFilname = serverOutputFullFilname;
}

public String getTeeData() {
return teeData;
}

public void setTeeData(String teeData) {
this.teeData = teeData;
}

@Override
public String toString() {
return "PsiTaskParam{" +
Expand Down
76 changes: 76 additions & 0 deletions primihub-sdk/src/main/resources/templates/hetero_fitTransform.ftl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
{
"roles": {
"host": "Bob",
"guest": [
"Charlie"
]
},
"common_params": {
"model": "FL_Preprocess",
"process": "fit_transform",
"FL_type": "V",
"task_name": "VFL_simpleimpute_fit_transform",
"task": "classification"
},
"role_params": {
"Bob": {
"data_set": "${label_dataset}",
"selected_column": null,
"id": "id",
"label": "y",
"preprocess_column": null,
"preprocess_dataset_id": "${new_label_dataset}",
"preprocess_dataset_path": "${new_label_dataset_path}",
"preprocess_module_path": "${new_label_dataset_path}.pkl",
"preprocess_module": {
"SimpleImputer_string": {
"column": null,
"missing_values": "np.nan",
"strategy": "${simpleImputerString}",
"fill_value": null,
"copy": true,
"add_indicator": false,
"keep_empty_features": false
},
"SimpleImputer_numeric": {
"column": null,
"missing_values": "np.nan",
"strategy": "${simpleImputerNumeric}",
"fill_value": null,
"copy": true,
"add_indicator": false,
"keep_empty_features": false
}
}
},
"Charlie": {
"data_set": "${guest_dataset}",
"selected_column": null,
"id": "id",
"preprocess_column": null,
"preprocess_dataset_id": "${new_guest_dataset}",
"preprocess_dataset_path": "${new_guest_dataset_path}",
"preprocess_module_path": "${new_guest_dataset_path}.pkl",
"preprocess_module": {
"SimpleImputer_string": {
"column": null,
"missing_values": "np.nan",
"strategy": "${simpleImputerString}",
"fill_value": null,
"copy": true,
"add_indicator": false,
"keep_empty_features": false
},
"SimpleImputer_numeric": {
"column": null,
"missing_values": "np.nan",
"strategy": "${simpleImputerNumeric}",
"fill_value": null,
"copy": true,
"add_indicator": false,
"keep_empty_features": false
}
}
}
}
}
57 changes: 57 additions & 0 deletions primihub-sdk/src/main/resources/templates/homo_fitTransform.ftl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"roles": {
"server": "Alice",
"client": [
"Bob",
"Charlie"
]
},
"common_params": {
"model": "FL_Preprocess",
"process": "fit_transform",
"FL_type": "H",
"task_name": "HFL_simpleimpute_fit_transform",
"task": "classification",
"selected_column": null,
"id": "id",
"label": "y",
"preprocess_column": null,
"preprocess_module": {
"SimpleImputer_string": {
"column": null,
"missing_values": "np.nan",
"strategy": "${simpleImputerString}",
"fill_value": null,
"copy": true,
"add_indicator": false,
"keep_empty_features": false
},
"SimpleImputer_numeric": {
"column": null,
"missing_values": "np.nan",
"strategy": "${simpleImputerNumeric}",
"fill_value": null,
"copy": true,
"add_indicator": false,
"keep_empty_features": false
}
}
},
"role_params": {
"Bob": {
"data_set": "${label_dataset}",
"preprocess_dataset_id": "${new_label_dataset}",
"preprocess_dataset_path": "${new_label_dataset_path}",
"preprocess_module_path": "${new_label_dataset_path}.pkl"
},
"Charlie": {
"data_set": "${guest_dataset}",
"preprocess_dataset_id": "${new_guest_dataset}",
"preprocess_dataset_path": "${new_guest_dataset_path}",
"preprocess_module_path": "${new_guest_dataset_path}.pkl"
},
"Alice": {
"data_set": "${arbiter_dataset!""}"
}
}
}
2 changes: 1 addition & 1 deletion primihub-sdk/src/main/resources/templates/homo_lr.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
},
"Alice": {
"data_set": "${arbiter_dataset}",
"metric_path": "${indicatorFileName}"
"metric_path": "/data${indicatorFileName}"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
},
"Alice": {
"data_set": "${arbiter_dataset}",
"metric_path": "${indicatorFileName}"
"metric_path": "/data${indicatorFileName}"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.springframework.cloud.openfeign.EnableFeignClients;
import org.springframework.cloud.stream.annotation.EnableBinding;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling;

@NacosPropertySources({
// @NacosPropertySource(dataId = "test", autoRefreshed = true),
Expand All @@ -22,6 +23,7 @@
@ServletComponentScan(basePackages = {"com.primihub.biz.filter"})
@EnableBinding({SingleTaskChannel.class})
@EnableFeignClients(basePackages = {"com.primihub"})
@EnableScheduling
public class PlatformApplication {

public static void main(String[] args) {
Expand Down
Loading

0 comments on commit fabdd98

Please sign in to comment.