Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RNN example to perform DGA detection with LSTMs #240

Open
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

rcurtin
Copy link
Member

@rcurtin rcurtin commented Jan 9, 2025

This PR adds an example that uses two RNNs to solve the DGA detection problem. In short it is a classification problem to determine whether or not a domain was generated by a domain generation algorithm.

The model uses two RNNs: one trained on benign domains, and one trained on malicious domains. For prediction, we compute the likelihood of a domain coming from each model, and then we take the more likely model as the class. This is the generalized likelihood ratio test or something very much like it. We used this strategy at Symantec and a more complicated version of it in a paper.

The network structure is simple: it uses float data instead of double data and consists of 50 LSTM units followed by a 39-element linear layer and log-softmax.

I plan to use this as a demonstration of how to deploy a predictive model inside of a Docker container.

When the model is trained and run, here is the output:

$ ./lstm_dga_detection_train dga_domains.csv 
File 'dga_domains.csv' has 80000 benign domains with a maximum length of 64, and 80000 malicious domains with a maximum length of 36.
Epoch 1/5
72000/72000 [====================================================================================================] 100% - 381.067s/epoch; 5ms/step; loss: 28.8829
Epoch 2/5
72000/72000 [====================================================================================================] 100% - 457.332s/epoch; 6ms/step; loss: 26.0513
Epoch 3/5
72000/72000 [====================================================================================================] 100% - 508.525s/epoch; 7ms/step; loss: 25.8233
Epoch 4/5
72000/72000 [====================================================================================================] 100% - 535.384s/epoch; 7ms/step; loss: 25.6701
Epoch 5/5
72000/72000 [====================================================================================================] 100% - 759.608s/epoch; 10ms/step; loss: 25.4866
Epoch 1/5
72000/72000 [====================================================================================================] 100% - 759.118s/epoch; 10ms/step; loss: 58.4937
Epoch 2/5
72000/72000 [====================================================================================================] 100% - 729.818s/epoch; 10ms/step; loss: 49.2438
Epoch 3/5
72000/72000 [====================================================================================================] 100% - 843.702s/epoch; 11ms/step; loss: 49.1188
Epoch 4/5
72000/72000 [====================================================================================================] 100% - 932.335s/epoch; 12ms/step; loss: 49.368
Epoch 5/5
72000/72000 [====================================================================================================] 100% - 903.086s/epoch; 12ms/step; loss: 49.4524
Model performance:
  Training accuracy: 141157 of 144000 correct (98.0257%).
  Test accuracy:     15683 of 16000 correct (98.0187%).

The size of the resulting models is small:

$ ls -lh *.bin
-rw-rw-r-- 1 ryan ryan 80K Jan  8 23:21 lstm_dga_detector_benign.bin
-rw-rw-r-- 1 ryan ryan 80K Jan  8 23:21 lstm_dga_detector_malicious.bin

And the prediction program works like below, where I write a domain name and then the prediction is printed (or an error if the domain name was invalid):

$ ./lstm_dga_detection_predict lstm_dga_detector_benign.bin lstm_dga_detector_malicious.bin 
www.mlpack.org
benign
asd98udvsa908usad98uf234.org
malicious
this IS a domain with invalid characters
Domain 'this IS a domain with invalid characters' has invalid character ' '!
$

For the code to run correctly, the following PRs must first be merged in mlpack:

Copy link

github-actions bot commented Jan 9, 2025

Binder 👈 Launch a binder notebook on branch rcurtin/examples/dga-detection

@@ -0,0 +1,123 @@
/**
* @file lstm_dga_detection_train.cpp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The file name is not identical to this one.

What is the point of this one as well ? given that we have the train file above ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack, I thought I removed it! Sorry about that. Fixed in 6169ab1.

test_labels = requests.get(
"https://datasets.mlpack.org/mnist/t10k-labels-idx1-ubyte.gz")
progress_bar("test_labels.gz", test_labels)
ungzip("test_labels.gz", "test_labels.ubytes")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks 👍

@rcurtin
Copy link
Member Author

rcurtin commented Jan 9, 2025

This should wait for merge on those other 3 PRs.

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second approval provided automatically after 24 hours. 👍

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants