"""Code generated mostly with ChatGPT"""
import chepy.core
import logging
import lazy_import
try:
# TODO 🔥 move to lazy resources to speed up
# import torch.nn as nn
torch = lazy_import.lazy_module("torch")
# import torch
np = lazy_import.lazy_module("numpy")
# import numpy as np
import json
import pkg_resources
except ImportError:
logging.warning("Could not import pytorch or numpy. Use pip install torch numpy")


# Define the model architecture that matches the one used for training
class EncoderClassifier(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(EncoderClassifier, self).__init__()
        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.fc = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.embedding(x)
        x = torch.mean(x, dim=1)  # Average pooling over sequence length
        x = self.fc(x)
        return x
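
# A quick shape sketch (illustrative; num_classes=10 is a hypothetical value):
#   x (1, 16) int64 -> embedding (1, 16, 64) -> mean over dim=1 (1, 64) -> fc (1, 10)
# model = EncoderClassifier(input_size=1024, hidden_size=64, num_classes=10)
# logits = model(torch.zeros(1, 16, dtype=torch.long))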


# Define a function to load the saved model
def load_model(model_filename, input_size, hidden_size, num_classes):
    model = EncoderClassifier(input_size, hidden_size, num_classes)
    model.load_state_dict(torch.load(model_filename, map_location=torch.device("cpu")))
    model.eval()
    return model


# Define a function to make predictions
def predict_encoding(model, input_string, top_k=3):
    # Iterating over bytes yields integer byte values, which index the embedding
    encoded_string = [char for char in input_string]
    padded_input = torch.tensor(encoded_string, dtype=torch.long).unsqueeze(0)
    # Make a prediction
    with torch.no_grad():
        output = model(padded_input)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0].numpy()
    # Get the top k predicted labels and their probabilities
    top_k_labels = np.argsort(-probabilities)[:top_k]
    top_k_probabilities = probabilities[top_k_labels]
    return top_k_labels, top_k_probabilities
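
# A minimal usage sketch for the helpers above (illustrative; the weight and
# label files ship with the plugin as package data):
#   model = load_model("data/ml_detect_encoding.pth", 1024, 64, num_classes)
#   labels, probs = predict_encoding(model, b"aGVsbG8gd29ybGQ=", top_k=3)
#   # labels index into data/ml_labels.json; probs are softmax confidences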


class Chepy_ML(chepy.core.ChepyCore):
    """This plugin helps run various ML models against the state"""

    def ml_magic(self, depth: int = 3, verbose=False):
        """Automatically try to decode the state based on the detected encoding.
        Stops at the first exception.

        Args:
            depth (int, optional): Number of iterations. Defaults to 3.
            verbose (bool, optional): Include detection weights. Defaults to False.

        Returns:
            Chepy: The Chepy object.
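
        Examples:
            >>> # Illustrative; assumes the plugin is enabled in Chepy
            >>> Chepy("NDIgNDIgNDI=").ml_magic().o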
"""
hold = []
for _ in range(depth):
try:
data = self.state
detect_methods = self.ml_detect().o
self.state = data
method = next(iter(detect_methods))
out = getattr(self, method)().o
hold.append({"method": method, "detected": detect_methods, "out": out})
self.state = out
except Exception:
break
if verbose:
self.state = hold
else:
self.state = [h["out"] for h in hold]
return self

    # @chepy.core.ChepyDecorators.call_stack
    def ml_detect(self, num_top_labels: int = 5):
        """Detect the encoding type of the state

        Args:
            num_top_labels (int, optional): Number of labels to return. Defaults to 5.

        Returns:
            ChepyPlugin: The Chepy object.
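
        Examples:
            >>> # Illustrative; label names come from data/ml_labels.json
            >>> Chepy("NDIgNDIgNDI=").ml_detect().o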
"""
# Load the trained model
model_filename = pkg_resources.resource_filename(
__name__, "data/ml_detect_encoding.pth"
)
input_size = 1024
hidden_size = 64
with open(
pkg_resources.resource_filename(__name__, "data/ml_labels.json"), "r"
) as f:
class_labels = json.loads(f.read())
num_classes = len(
class_labels
) # Update with the actual number of encoding types
loaded_model = load_model(model_filename, input_size, hidden_size, num_classes)
top_labels, top_probabilities = predict_encoding(
loaded_model, self._convert_to_bytes(), num_top_labels
)
# Display the top predicted labels and their probabilities
res = {
class_labels[str(label_idx)]: round(probability, 5)
for _, (label_idx, probability) in enumerate(
zip(top_labels, top_probabilities)
)
}
self.state = res
return self
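

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch; assumes the bundled model and
    # label files are present). Chepy_ML is normally mixed into Chepy via the
    # plugin loader; the decode methods ml_magic resolves (e.g. from_base64)
    # live on the full Chepy class, so only ml_detect is exercised here.
    # "NDIgNDIgNDI=" is a hypothetical base64 sample.
    print(Chepy_ML("NDIgNDIgNDI=").ml_detect().o)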