Add RNN example to perform DGA detection with LSTMs #240
base: master
Conversation
cpp/lstm/dga_detection/lstm_test.cpp
Outdated
@@ -0,0 +1,123 @@
/**
 * @file lstm_dga_detection_train.cpp
The file name in this comment is not identical to this file's name.
Also, what is the point of this one, given that we have the train file above?
Ack, I thought I removed it! Sorry about that. Fixed in 6169ab1.
test_labels = requests.get(
    "https://datasets.mlpack.org/mnist/t10k-labels-idx1-ubyte.gz")
progress_bar("test_labels.gz", test_labels)
ungzip("test_labels.gz", "test_labels.ubytes")
Thanks 👍
This should wait for those other 3 PRs to be merged.
Second approval provided automatically after 24 hours. 👍
This PR adds an example that uses two RNNs to solve the DGA detection problem. In short, it is a classification problem: determine whether or not a domain was generated by a domain generation algorithm.
The model uses two RNNs: one trained on benign domains and one trained on malicious domains. For prediction, we compute the likelihood of a domain under each model and take the more likely model as the predicted class. This is the generalized likelihood ratio test, or something very much like it. We used this strategy at Symantec, and a more complicated version of it in a paper.
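To make the decision rule concrete, here is a minimal sketch of how the two models' scores could be combined. It is not the code from this PR: the per-character log-probability vectors, the zero threshold, and all names below are illustrative assumptions.

```cpp
// Sketch of the two-model likelihood-ratio decision rule.  Each vector is
// assumed to hold, for every character position in the domain, the
// log-probability that the corresponding character-level model assigned to
// the character that was actually observed (i.e. the relevant entry of that
// model's log-softmax output at that time step).
#include <numeric>
#include <vector>

enum class DomainClass { Benign, Malicious };

DomainClass ClassifyDomain(const std::vector<double>& benignStepLogProbs,
                           const std::vector<double>& maliciousStepLogProbs)
{
  // The log-likelihood of the whole domain under a character-level model is
  // the sum of the per-character log-probabilities.
  const double benignLL = std::accumulate(benignStepLogProbs.begin(),
                                          benignStepLogProbs.end(), 0.0);
  const double maliciousLL = std::accumulate(maliciousStepLogProbs.begin(),
                                             maliciousStepLogProbs.end(), 0.0);

  // Likelihood-ratio decision: pick whichever model explains the domain
  // better.  A threshold other than 0 could be used to trade off false
  // positives against false negatives.
  return (benignLL >= maliciousLL) ? DomainClass::Benign
                                   : DomainClass::Malicious;
}
```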
The network structure is simple: it uses `float` data instead of `double` data and consists of 50 LSTM units followed by a 39-element linear layer and log-softmax.

I plan to use this as a demonstration of how to deploy a predictive model inside of a Docker container.
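For illustration, a rough sketch of that layer stack using the mlpack 4 RNN API is below. It is not the example code itself: it uses the default `double` (`arma::mat`) element type, whereas the actual example uses `float` matrices (which is part of why the mlpack PRs listed at the end of this description are needed), and the exact constructor arguments here are assumptions.

```cpp
// Rough sketch of the network described above: a character-level RNN with
// 50 LSTM units, a 39-dimensional linear layer, and log-softmax output.
#include <mlpack.hpp>

using namespace mlpack;

int main()
{
  // Unroll backpropagation through time over at most `bpttSteps` characters;
  // 100 is just a placeholder for the maximum domain length.
  const size_t bpttSteps = 100;
  RNN<NegativeLogLikelihood> model(bpttSteps);

  model.Add<LSTM>(50);      // 50 LSTM units.
  model.Add<Linear>(39);    // One output per character in the alphabet.
  model.Add<LogSoftMax>();  // Log-probabilities over the next character.

  return 0;
}
```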
When the model is trained and run, here is the output:
The size of the resulting models is small:
And the prediction program works like below: I type a domain name, and the prediction is printed (or an error if the domain name was invalid):
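As a rough idea of what such an interactive prediction loop can look like, here is a hypothetical sketch; it is not the program in this PR, and `ScoreWithModels()`, `IsValidDomain()`, and the assumed alphabet are all placeholders.

```cpp
// Hypothetical sketch of an interactive prediction loop: read one domain
// name per line, reject domains containing characters outside the assumed
// alphabet, and otherwise print the class chosen by the likelihood-ratio
// rule.
#include <iostream>
#include <string>

// Stand-in for running the two trained RNNs: in the real program this would
// compute both models' log-likelihoods for `domain` and compare them.
std::string ScoreWithModels(const std::string& domain)
{
  (void) domain;
  return "benign";  // Placeholder result.
}

// Accept only lowercase letters, digits, '-', and '.' (an assumed alphabet).
bool IsValidDomain(const std::string& domain)
{
  if (domain.empty())
    return false;
  for (const char c : domain)
  {
    const bool ok = (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') ||
        c == '-' || c == '.';
    if (!ok)
      return false;
  }
  return true;
}

int main()
{
  std::string domain;
  while (std::getline(std::cin, domain))
  {
    if (!IsValidDomain(domain))
      std::cerr << "error: invalid domain name '" << domain << "'" << std::endl;
    else
      std::cout << domain << ": " << ScoreWithModels(domain) << std::endl;
  }
  return 0;
}
```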
For the code to run correctly, the following PRs must first be merged in mlpack: