diff --git a/tzrec/main.py b/tzrec/main.py index b176d4a..f722442 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -1010,6 +1010,7 @@ def predict( data_config: DataConfig = pipeline_config.data_config data_config.ClearField("label_fields") + data_config.ClearField("sample_weight_fields") data_config.drop_remainder = False # Build feature features = _create_features(list(pipeline_config.feature_configs), data_config)