Skip to content

Commit

Permalink
code clean
Browse files Browse the repository at this point in the history
  • Loading branch information
kavicastelo committed Dec 24, 2023
1 parent 446f588 commit 13e2172
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
@Component
public class AIModel {
public String Message() { return "Server Get The Response From Here"; }

public String AirHumidity() {
return "Server Get The Response From Here air Humidity";
}

public static void main(String[] args) {

AIModel aiModel = new AIModel();
Expand All @@ -19,4 +24,8 @@ public double predict(double[] features) {
// Replace this with your actual prediction logic
return 0.0;
}

public double predictAirHumidity(double[] features) {
return predict(features);
}
}

This file was deleted.

1 change: 1 addition & 0 deletions src/main/java/com/api/air_quality/python/AIModelPython.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ def predict(self, features):
gateway = JavaGateway()
msgObjectFromJavaApp = gateway.entry_point
print(msgObjectFromJavaApp.Message())
print(msgObjectFromJavaApp.AirHumidity())
60 changes: 54 additions & 6 deletions src/main/java/com/api/air_quality/python/AirHumidityModelPython.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,67 @@
from pyspark.sql import SparkSession
from joblib import load
from py4j.java_gateway import JavaGateway
import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType


class AirHumidityModelPython:
def __init__(self):
def __init__(self, spark):
# Connect to the Py4J gateway server
self.gateway = JavaGateway()

# Retrieve the Java instance of the model
self.java_model = self.gateway.entry_point.getAirHumidityModel()
self.java_model = self.gateway.entry_point

# Load the PMML model or any other necessary initialization logic
self.model = load("../../../../../../../AI_Models/airHumidity_model.joblib")

# Define the UDF for point_inside_polygon
self.point_inside_polygon_udf = udf(self.point_inside_polygon, BooleanType())

# Example Spark DataFrame
self.df = spark.createDataFrame([(1.0, 2.0)], ["lat", "long"])

def point_inside_polygon(self, lat, long, polygon):
# Implement your point_inside_polygon logic here
# Return True if the point is inside the polygon, else False
pass

def predict_air_humidity(self, features):
# Implement your prediction logic here
return self.java_model.predictAirHumidity(features)
try:
# Explicitly convert NumPy array to Python list
features_list = [float(val) for val in np.array(features)]

# Perform prediction using the loaded model
prediction = self.model.predict([features_list])

# Example usage of UDF for point_inside_polygon
# Replace lat and long with actual features from the model
inside_polygon_result = self.df.where(self.point_inside_polygon_udf('lat', 'long', features_list))

# Pass the prediction and result to the Java side
return self.java_model.receivePrediction(prediction, inside_polygon_result)
except Exception as e:
# Handle any errors during prediction
return str(e)


def main():
# Initialize Spark session
spark = SparkSession.builder.appName("AirHumidityModelApp").getOrCreate()

if __name__ == "__main__":
# Create an instance of the Python class
air_humidity_model = AirHumidityModelPython()
air_humidity_model = AirHumidityModelPython(spark)

# Example prediction
features = [1.0, 2.0, 4.0, 5.0, 7.0, 9.0, 10.0]
result = air_humidity_model.predict_air_humidity(features)
print(result)

# Stop the Spark session
spark.stop()


if __name__ == "__main__":
main()

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
public class PythonIntegrationService {

private final GatewayServer gatewayServer;
private final AIModel aiModel;

@Autowired
@Lazy
public PythonIntegrationService(GatewayServer gatewayServer) {
this.gatewayServer = gatewayServer;
this.aiModel = new AIModel();
}

@Bean
Expand All @@ -42,4 +44,9 @@ public void stopPythonGateway() {
gatewayServer.shutdown();
}
}

public double predictAirHumidity(double[] features) {
// Call the predict method on the Java model
return aiModel.predictAirHumidity(features);
}
}

0 comments on commit 13e2172

Please sign in to comment.