diff --git a/src/main/java/com/api/air_quality/model/ai_models/AIModel.java b/src/main/java/com/api/air_quality/model/ai_models/AIModel.java index 675d90c5..9b8c5074 100644 --- a/src/main/java/com/api/air_quality/model/ai_models/AIModel.java +++ b/src/main/java/com/api/air_quality/model/ai_models/AIModel.java @@ -6,6 +6,11 @@ @Component public class AIModel { public String Message() { return "Server Get The Response From Here"; } + + public String AirHumidity() { + return "Server Get The Response From Here air Humidity"; + } + public static void main(String[] args) { AIModel aiModel = new AIModel(); @@ -19,4 +24,8 @@ public double predict(double[] features) { // Replace this with your actual prediction logic return 0.0; } + + public double predictAirHumidity(double[] features) { + return predict(features); + } } diff --git a/src/main/java/com/api/air_quality/model/ai_models/AirHumidityModel.java b/src/main/java/com/api/air_quality/model/ai_models/AirHumidityModel.java deleted file mode 100644 index 132f1c4e..00000000 --- a/src/main/java/com/api/air_quality/model/ai_models/AirHumidityModel.java +++ /dev/null @@ -1,42 +0,0 @@ -package com.api.air_quality.model.ai_models; - -import org.springframework.stereotype.Component; -import org.jpmml.model.PMMLUtil; - -import javax.xml.transform.stream.StreamSource; - -import java.io.File; - -@Component -public class AirHumidityModel { - - private final org.dmg.pmml.PMML pmml; - - public AirHumidityModel() throws Exception { - // Replace with the actual path to your PMML file - String pmmlPath = "../../../../../../../../AI_Models/airHumidity_model.joblib"; - this.pmml = loadPMML(pmmlPath); - } - - public double predict(double[] features) { - // Implement your prediction logic using the loaded PMML model - // Replace the following line with your actual prediction code - return 0.0; - } - - private org.dmg.pmml.PMML loadPMML(String pmmlPath) throws Exception { - File file = new File(pmmlPath); - StreamSource source = new StreamSource(file); -// return PMMLUtil.unmarshal(source.getInputStream()); - try { - return PMMLUtil.unmarshal(source.getInputStream()); - } catch (Exception e) { - e.printStackTrace(); - return null; - } - } - - public double predictAirHumidity(double[] features) { - return predict(features); - } -} diff --git a/src/main/java/com/api/air_quality/python/AIModelPython.py b/src/main/java/com/api/air_quality/python/AIModelPython.py index 06fdc762..82a2b153 100644 --- a/src/main/java/com/api/air_quality/python/AIModelPython.py +++ b/src/main/java/com/api/air_quality/python/AIModelPython.py @@ -14,3 +14,4 @@ def predict(self, features): gateway = JavaGateway() msgObjectFromJavaApp = gateway.entry_point print(msgObjectFromJavaApp.Message()) + print(msgObjectFromJavaApp.AirHumidity()) diff --git a/src/main/java/com/api/air_quality/python/AirHumidityModelPython.py b/src/main/java/com/api/air_quality/python/AirHumidityModelPython.py index 50150f26..4c88ffdd 100644 --- a/src/main/java/com/api/air_quality/python/AirHumidityModelPython.py +++ b/src/main/java/com/api/air_quality/python/AirHumidityModelPython.py @@ -1,19 +1,67 @@ +from pyspark.sql import SparkSession +from joblib import load from py4j.java_gateway import JavaGateway +import numpy as np +from pyspark.sql.functions import udf +from pyspark.sql.types import BooleanType class AirHumidityModelPython: - def __init__(self): + def __init__(self, spark): # Connect to the Py4J gateway server self.gateway = JavaGateway() # Retrieve the Java instance of the model - self.java_model = self.gateway.entry_point.getAirHumidityModel() + self.java_model = self.gateway.entry_point + + # Load the PMML model or any other necessary initialization logic + self.model = load("../../../../../../../AI_Models/airHumidity_model.joblib") + + # Define the UDF for point_inside_polygon + self.point_inside_polygon_udf = udf(self.point_inside_polygon, BooleanType()) + + # Example Spark DataFrame + self.df = spark.createDataFrame([(1.0, 2.0)], ["lat", "long"]) + + def point_inside_polygon(self, lat, long, polygon): + # Implement your point_inside_polygon logic here + # Return True if the point is inside the polygon, else False + pass def predict_air_humidity(self, features): - # Implement your prediction logic here - return self.java_model.predictAirHumidity(features) + try: + # Explicitly convert NumPy array to Python list + features_list = [float(val) for val in np.array(features)] + # Perform prediction using the loaded model + prediction = self.model.predict([features_list]) + + # Example usage of UDF for point_inside_polygon + # Replace lat and long with actual features from the model + inside_polygon_result = self.df.where(self.point_inside_polygon_udf('lat', 'long', features_list)) + + # Pass the prediction and result to the Java side + return self.java_model.receivePrediction(prediction, inside_polygon_result) + except Exception as e: + # Handle any errors during prediction + return str(e) + + +def main(): + # Initialize Spark session + spark = SparkSession.builder.appName("AirHumidityModelApp").getOrCreate() -if __name__ == "__main__": # Create an instance of the Python class - air_humidity_model = AirHumidityModelPython() + air_humidity_model = AirHumidityModelPython(spark) + + # Example prediction + features = [1.0, 2.0, 4.0, 5.0, 7.0, 9.0, 10.0] + result = air_humidity_model.predict_air_humidity(features) + print(result) + + # Stop the Spark session + spark.stop() + + +if __name__ == "__main__": + main() diff --git a/src/main/java/com/api/air_quality/service/ai_services/AirHumidityModelService.java b/src/main/java/com/api/air_quality/service/ai_services/AirHumidityModelService.java deleted file mode 100644 index e2d2da78..00000000 --- a/src/main/java/com/api/air_quality/service/ai_services/AirHumidityModelService.java +++ /dev/null @@ -1,25 +0,0 @@ -package com.api.air_quality.service.ai_services; - -import com.api.air_quality.model.ai_models.AirHumidityModel; -import jakarta.annotation.PostConstruct; -import jakarta.annotation.PreDestroy; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; -import py4j.GatewayServer; - -@Service -public class AirHumidityModelService { - private final PythonIntegrationService pythonIntegrationService; - private final AirHumidityModel airHumidityModel; - - @Autowired - public AirHumidityModelService(PythonIntegrationService pythonIntegrationService, AirHumidityModel airHumidityModel) { - this.pythonIntegrationService = pythonIntegrationService; - this.airHumidityModel = airHumidityModel; - } - - public double predictAirHumidity(double[] features) { - // Call the predict method on the Java model - return airHumidityModel.predictAirHumidity(features); - } -} diff --git a/src/main/java/com/api/air_quality/service/ai_services/PythonIntegrationService.java b/src/main/java/com/api/air_quality/service/ai_services/PythonIntegrationService.java index ebc2c571..a06e3433 100644 --- a/src/main/java/com/api/air_quality/service/ai_services/PythonIntegrationService.java +++ b/src/main/java/com/api/air_quality/service/ai_services/PythonIntegrationService.java @@ -12,11 +12,13 @@ public class PythonIntegrationService { private final GatewayServer gatewayServer; + private final AIModel aiModel; @Autowired @Lazy public PythonIntegrationService(GatewayServer gatewayServer) { this.gatewayServer = gatewayServer; + this.aiModel = new AIModel(); } @Bean @@ -42,4 +44,9 @@ public void stopPythonGateway() { gatewayServer.shutdown(); } } + + public double predictAirHumidity(double[] features) { + // Call the predict method on the Java model + return aiModel.predictAirHumidity(features); + } }