Skip to content

Commit

Permalink
ENH add support for different feature names in HFL (primihub#777)
Browse files Browse the repository at this point in the history
* ENH add support for different feature names in HFL

* use get for dictionary
  • Loading branch information
xuefeng-xu authored May 28, 2024
1 parent daea4dd commit 97b40f5
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 15 deletions.
4 changes: 3 additions & 1 deletion python/primihub/FL/linear_regression/hfl_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def train(self):
task_info=self.task_info)

# load dataset
selected_column = self.common_params['selected_column']
selected_column = self.common_params.get('selected_column')
if selected_column is None:
selected_column = self.role_params.get('selected_column')
id = self.common_params['id']
x = read_data(data_info=self.role_params['data'],
selected_column=selected_column,
Expand Down
4 changes: 3 additions & 1 deletion python/primihub/FL/logistic_regression/hfl_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def train(self):
task_info=self.task_info)

# load dataset
selected_column = self.common_params['selected_column']
selected_column = self.common_params.get('selected_column')
if selected_column is None:
selected_column = self.role_params.get('selected_column')
id = self.common_params['id']
x = read_data(data_info=self.role_params['data'],
selected_column=selected_column,
Expand Down
4 changes: 3 additions & 1 deletion python/primihub/FL/neural_network/hfl_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def train(self):
task_info=self.task_info)

# load dataset
selected_column = self.common_params['selected_column']
selected_column = self.common_params.get('selected_column')
if selected_column is None:
selected_column = self.role_params.get('selected_column')
id = self.common_params['id']
x = read_data(data_info=self.role_params['data'],
selected_column=selected_column,
Expand Down
17 changes: 5 additions & 12 deletions python/primihub/FL/preprocessing/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def fit_transform(self):

# load dataset
if FL_type == "H":
selected_column = self.common_params["selected_column"]
selected_column = self.common_params.get('selected_column')
if selected_column is None:
selected_column = self.role_params.get('selected_column')
id = self.common_params["id"]
else:
selected_column = self.role_params["selected_column"]
Expand Down Expand Up @@ -91,6 +93,8 @@ def fit_transform(self):
# preprocessing
if FL_type == "H":
module_params = self.common_params.get("preprocess_module")
if module_params is None:
module_params = self.role_params.get("preprocess_module")
else:
module_params = self.role_params.get("preprocess_module")
if module_params is None:
Expand Down Expand Up @@ -130,21 +134,10 @@ def fit_transform(self):
elif "Scaler" in module_name:
column = data[column].select_dtypes(include=num_type).columns

if role == "client":
channel.send("column", column)
column = channel.recv("column")
if role == "server":
client_column = channel.recv_all("column")
column = list(set(chain.from_iterable(client_column)))
channel.send_all("column", column)

if isinstance(column, pd.Index):
column = column.tolist()
if column:
logger.info(f"column: {column}, # {len(column)}")
else:
logger.info(f"column is empty, {module_name} is skipped")
continue

module = select_module(module_name, params, FL_type, role, channel)

Expand Down

0 comments on commit 97b40f5

Please sign in to comment.