Skip to content

Commit

Permalink
for test function added
Browse files Browse the repository at this point in the history
  • Loading branch information
Mayank2184 committed Jul 5, 2024
1 parent 449aaa4 commit 45b35db
Showing 1 changed file with 84 additions and 2 deletions.
86 changes: 84 additions & 2 deletions src/freesas/test/test_dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,42 @@
import unittest
import logging
import os
import io
import numpy as np
from .utilstest import get_datafile
from ..resources import resource_filename
from ..sasio import load_scattering_data
from ..dnn import preprocess
from ..dnn import *

logger = logging.getLogger(__name__)

class TestDNN(unittest.TestCase):

def test_activation_functions(self):
"""
Test for the activation functions
"""
x = np.array([-1, 0, 1])

# Test tanh
expected_tanh = np.tanh(x)
self.assertTrue(np.allclose(tanh(x), expected_tanh), msg="tanh function failed")

# Test relu
expected_relu = np.maximum(0, x)
self.assertTrue(np.allclose(relu(x), expected_relu), msg="relu function failed")

# Test sigmoid
expected_sigmoid = 1 / (1 + np.exp(-x))
self.assertTrue(np.allclose(sigmoid(x), expected_sigmoid), msg="sigmoid function failed")

# Test linear
expected_linear = x
self.assertTrue(np.allclose(linear(x), expected_linear), msg="linear function failed")

logger.info("test_activation_functions ran successfully")


def test_preprocess(self):
"""
Test for the preprocessing function
Expand All @@ -48,11 +76,65 @@ def test_preprocess(self):
Iprep = preprocess(q, I)
self.assertEqual(Iprep.max(), 1, msg="range 0-1")
self.assertEqual(Iprep.shape, (1024,), msg="size 1024")



def test_forward_propagation(self):
"""
Test for the forward_propagation function
"""
try:
X = np.random.rand(1, 10)
params = [np.random.rand(10, 20), np.random.rand(20), np.random.rand(20, 10), np.random.rand(10)]
activations = [np.tanh, np.tanh]
output = forward_propagation(X, params, activations)
self.assertEqual(output.shape, (1, 10))
logger.info("test_forward_propogation ran successfully")
except Exception as e:
logger.error(f"test_forward_propagation failed: {e}")
raise


def test_DenseLayer(self):
"""
Test for the DenseLayer class
"""
try :
weights = np.random.rand(10, 20)
bias = np.random.rand(20)
layer = DenseLayer(weights, bias, 'tanh')
self.assertEqual(layer.input_size, 10)
self.assertEqual(layer.output_size, 20)
output = layer.forward(np.random.rand(1, 10))
self.assertEqual(output.shape, (1, 20))
logger.info("test_DenseLayer ran successfully")
except Exception as e:
logger.error(f"test_DenseLayer failed: {e}")
raise



def test_DNN(self):
"""
Test for the DNN class
"""
try:
layers = [DenseLayer(np.random.rand(10, 20), np.random.rand(20), 'tanh'),
DenseLayer(np.random.rand(20, 10), np.random.rand(10), 'tanh')]
dnn = DNN(*layers)
output = dnn.infer(np.random.rand(1, 10))
self.assertEqual(output.shape, (1, 10))
logger.info("test_DNN ran successfully")
except Exception as e:
logger.error(f"test_DNN failed: {e}")
raise


def suite():
loader = unittest.defaultTestLoader.loadTestsFromTestCase
test_suite = unittest.TestSuite()
test_suite.addTest(TestDNN("test_preprocess"))
test_suite.addTest(loader(TestDNN))
# test_suite.addTest(TestDNN("test_preprocess"))
return test_suite


Expand Down

0 comments on commit 45b35db

Please sign in to comment.