Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve test coverage #154

Open
wants to merge 9 commits into
base: dev
Choose a base branch
from
Open
5 changes: 3 additions & 2 deletions precise/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
- Interpretation of the network output to a confidence value
"""
from math import floor
from typing import Optional

import attr
import json
Expand Down Expand Up @@ -68,7 +69,7 @@ class ListenerParams:
use_delta = attr.ib() # type: bool
vectorizer = attr.ib() # type: int
threshold_config = attr.ib() # type: tuple
threshold_center = attr.ib() # type: float
threshold_center = attr.ib() # type: Optional[float]

@property
def buffer_samples(self):
Expand Down Expand Up @@ -140,7 +141,7 @@ class Vectorizer:
pr = ListenerParams(
buffer_t=1.5, window_t=0.1, hop_t=0.05, sample_rate=16000,
sample_depth=2, n_fft=512, n_filt=20, n_mfcc=13, use_delta=False,
threshold_config=((6, 4),), threshold_center=0.2, vectorizer=Vectorizer.mfccs
threshold_config=(), threshold_center=None, vectorizer=Vectorizer.mfccs
)

# Used to fill in old param files without new attributes
Expand Down
Empty file removed precise/pocketsphinx/__init__.py
Empty file.
Empty file.
22 changes: 17 additions & 5 deletions precise/threshold_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ class ThresholdDecoder:
activations using a series of averages and standard deviations to
calculate a cumulative probability distribution

Args:
mu_stds: tuple of pairs of (mean, standard deviation) that model the positive network output
center: proportion of activations that a threshold of 0.5 indicates. Pass as None to disable decoding
resolution: precision of cumulative sum estimation. Increases memory usage
min_z: Minimum z score to generate in distribution map
max_z: Maximum z score to generate in distribution map

Background:
We could simply take the output of the neural network as the confidence of a given
prediction, but this typically jumps quickly between 0.01 and 0.99 even in cases where
Expand All @@ -36,14 +43,17 @@ class ThresholdDecoder:
of 80% means that the network output is greater than roughly 80% of the dataset
"""
def __init__(self, mu_stds: Tuple[Tuple[float, float]], center=0.5, resolution=200, min_z=-4, max_z=4):
self.min_out = int(min(mu + min_z * std for mu, std in mu_stds))
self.max_out = int(max(mu + max_z * std for mu, std in mu_stds))
self.out_range = self.max_out - self.min_out
self.cd = np.cumsum(self._calc_pd(mu_stds, resolution))
self.min_out = self.max_out = self.out_range = 0
self.cd = np.array([])
self.center = center
if center is not None:
self.min_out = int(min([mu + min_z * std for mu, std in mu_stds]))
self.max_out = int(max([mu + max_z * std for mu, std in mu_stds]))
self.out_range = self.max_out - self.min_out
self.cd = np.cumsum(self._calc_pd(mu_stds, resolution))

def decode(self, raw_output: float) -> float:
if raw_output == 1.0 or raw_output == 0.0:
if self.center is None or raw_output == 1.0 or raw_output == 0.0:
return raw_output
if self.out_range == 0:
cp = int(raw_output > self.min_out)
Expand All @@ -57,6 +67,8 @@ def decode(self, raw_output: float) -> float:
return 0.5 + 0.5 * (cp - self.center) / (1 - self.center)

def encode(self, threshold: float) -> float:
if self.center is None:
return threshold
threshold = 0.5 * threshold / self.center
if threshold < 0.5:
cp = threshold * self.center * 2
Expand Down
21 changes: 15 additions & 6 deletions precise/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,24 @@ def chunk_audio(audio: np.ndarray, chunk_size: int) -> Generator[np.ndarray, Non
yield audio[i - chunk_size:i]


def float_audio_to_int(audio: np.ndarray) -> np.ndarray:
    """Converts [-1.0, 1.0] -> [-32768, 32767]

    Args:
        audio: array of float samples, nominally within [-1.0, 1.0]

    Returns:
        Array of little-endian 16 bit integer samples ('<i2')
    """
    scaled = audio.astype(np.float32, order='C') * (0x7FFF + 0.5) - 0.5
    # Round to the nearest integer rather than letting the integer cast
    # truncate toward zero: float32 error in the scaled value would otherwise
    # shift some samples by one step on an int -> float -> int round trip.
    # Clip so samples outside [-1.0, 1.0] saturate instead of wrapping around.
    return np.clip(np.rint(scaled), -0x8000, 0x7FFF).astype('<i2')


def int_audio_to_float(int_audio: np.ndarray) -> np.ndarray:
    """Converts [-32768, 32767] -> [-1.0, 1.0]

    Each sample is shifted up by half a step and divided by 32767.5, so that
    -32768 maps to exactly -1.0 and 32767 maps to exactly 1.0.
    """
    half_step = 0.5
    full_scale = 0x7FFF + half_step
    shifted = int_audio + half_step
    return shifted / full_scale


def buffer_to_audio(buffer: bytes) -> np.ndarray:
"""Convert a raw mono audio byte string to numpy array of floats"""
return np.fromstring(buffer, dtype='<i2').astype(np.float32, order='C') / 32768.0
return int_audio_to_float(np.frombuffer(buffer, dtype='<i2'))


def audio_to_buffer(audio: np.ndarray) -> bytes:
"""Convert a numpy array of floats to raw mono audio"""
return (audio * 32768).astype('<i2').tostring()
return float_audio_to_int(audio).tostring()


def load_audio(file: Any) -> np.ndarray:
Expand All @@ -61,15 +71,14 @@ def load_audio(file: Any) -> np.ndarray:
if wav.rate != pr.sample_rate:
raise InvalidAudio('Unsupported sample rate: ' + str(wav.rate))

data = np.squeeze(wav.data)
return data.astype(np.float32) / float(np.iinfo(data.dtype).max)
return int_audio_to_float(np.squeeze(wav.data))


def save_audio(filename: str, audio: np.ndarray):
"""Save loaded audio to file using the configured audio parameters"""
import wavio
save_audio = (audio * np.iinfo(np.int16).max).astype(np.int16)
wavio.write(filename, save_audio, pr.sample_rate, sampwidth=pr.sample_depth, scale='none')
int_audio = float_audio_to_int(audio)
wavio.write(filename, int_audio, pr.sample_rate, sampwidth=pr.sample_depth, scale='none')


def play_audio(filename: str):
Expand Down
22 changes: 17 additions & 5 deletions runner/precise_runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import time
from subprocess import PIPE, Popen
from threading import Thread, Event
from threading import Thread, Event, current_thread


class Engine(object):
Expand Down Expand Up @@ -82,6 +82,7 @@ def __init__(self, s=b'', chop_samples=-1):
self.buffer = s
self.write_event = Event()
self.chop_samples = chop_samples
self.eof = False

def __len__(self):
return len(self.buffer)
Expand All @@ -95,14 +96,18 @@ def read(self, n=-1, timeout=None):
return_time = 1e10 if timeout is None else (
timeout + time.time()
)
while len(self.buffer) < n:
while len(self.buffer) < n and not self.eof:
self.write_event.clear()
if not self.write_event.wait(return_time - time.time()):
return b''
chunk = self.buffer[:n]
self.buffer = self.buffer[n:]
return chunk

def close(self):
    """Mark the stream as ended and wake any reader blocked in read().

    The EOF flag is published before the event is set. A reader that wakes
    up re-checks ``self.eof`` immediately; if it could still observe
    ``False`` it would clear the event and block again with nothing left
    to wake it.
    """
    self.eof = True
    self.write_event.set()

def write(self, s):
self.buffer += s
self.write_event.set()
Expand Down Expand Up @@ -210,10 +215,12 @@ def start(self):
def stop(self):
"""Stop listening and close stream"""
if self.thread:
self.running = False
if isinstance(self.stream, ReadWriteStream):
self.stream.write(b'\0' * self.chunk_size)
self.thread.join()
self.stream.close()
else:
self.running = False
if current_thread() is not self.thread:
self.thread.join()
self.thread = None

self.engine.stop()
Expand All @@ -234,6 +241,11 @@ def _handle_predictions(self):
while self.running:
chunk = self.stream.read(self.chunk_size)

if len(chunk) < self.chunk_size: # EOF
self.stop()
self.running = False
return

if self.is_paused:
continue

Expand Down
3 changes: 3 additions & 0 deletions test/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ COPY requirements/test.txt mycroft-precise/requirements/
RUN pip install -r mycroft-precise/requirements/test.txt
COPY requirements/prod.txt mycroft-precise/requirements/
RUN pip install -r mycroft-precise/requirements/prod.txt
RUN ls
RUN pwd
RUN pip install runner
COPY . mycroft-precise

# Clone the devops repository, which contains helper scripts for some continuous
Expand Down
38 changes: 34 additions & 4 deletions test/scripts/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,51 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil

import pytest

from precise.scripts.train import TrainScript
from test.scripts.test_train import DummyTrainFolder
from test.scripts.test_utils.temp_folder import TempFolder
from test.scripts.test_utils.dummy_train_folder import DummyTrainFolder


@pytest.fixture()
def train_folder():
folder = DummyTrainFolder(10)
folder = DummyTrainFolder()
folder.generate_default()
try:
yield folder
finally:
folder.cleanup()


@pytest.fixture()
def train_script(train_folder):
return TrainScript.create(model=train_folder.model, folder=train_folder.root, epochs=1)
def temp_folder():
folder = TempFolder()
try:
yield folder
finally:
folder.cleanup()


@pytest.fixture(scope='session')
def _trained_model():
    """Session wide model that gets trained once

    Yields:
        The ``model`` attribute of the dummy training folder, after
        ``TrainScript`` has been run against that folder for 100 epochs.
    """
    folder = DummyTrainFolder()
    # Everything after construction sits inside the try block so that
    # cleanup() still runs when sample generation or training raises,
    # not only once the dependent tests have finished.
    try:
        folder.generate_default()
        script = TrainScript.create(model=folder.model, folder=folder.root, epochs=100)
        script.run()
        yield folder.model
    finally:
        folder.cleanup()


@pytest.fixture()
def trained_model(_trained_model, temp_folder):
    """Per-test copy of the session wide model

    Copies the model file and its companion '.params' file into the test's
    temporary folder and returns the location of the copied model.
    """
    model = temp_folder.path('trained_model.net')
    # Same order as before: the model itself first, then its params file.
    for suffix in ('', '.params'):
        shutil.copy(_trained_model + suffix, model + suffix)
    return model
22 changes: 5 additions & 17 deletions test/scripts/test_add_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from precise.scripts.add_noise import AddNoiseScript

from test.scripts.dummy_audio_folder import DummyAudioFolder


class DummyNoiseFolder(DummyAudioFolder):
def __init__(self, count=10):
super().__init__(count)
self.source = self.subdir('source')
self.noise = self.subdir('noise')
self.output = self.subdir('output')

self.generate_samples(self.subdir('source', 'wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2))
self.generate_samples(self.subdir('source', 'not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2))
self.generate_samples(self.noise, 'noise-{}.wav', 0.5, self.rand(10, 20))
from test.scripts.test_utils.dummy_noise_folder import DummyNoiseFolder


class TestAddNoise:
def get_base_data(self, count):
folders = DummyNoiseFolder(count)
folders = DummyNoiseFolder()
folders.generate_default(count)
base_args = dict(
folder=folders.source, noise_folder=folders.noise,
output_folder=folders.output
Expand All @@ -42,10 +30,10 @@ def test_run_basic(self):
folders, base_args = self.get_base_data(10)
script = AddNoiseScript.create(inflation_factor=1, **base_args)
script.run()
assert folders.count_files(folders.output) == 20
assert folders.count_files(folders.output) == 40

def test_run_basic_2(self):
folders, base_args = self.get_base_data(10)
script = AddNoiseScript.create(inflation_factor=2, **base_args)
script.run()
assert folders.count_files(folders.output) == 40
assert folders.count_files(folders.output) == 80
14 changes: 7 additions & 7 deletions test/scripts/test_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,35 @@
from precise.scripts.calc_threshold import CalcThresholdScript
from precise.scripts.eval import EvalScript
from precise.scripts.graph import GraphScript
from test.scripts.test_utils.dummy_train_folder import DummyTrainFolder


def read_content(filename):
    """Return the entire text content of the file at ``filename``."""
    with open(filename) as handle:
        content = handle.read()
    return content


def test_combined(train_folder, train_script):
def test_combined(train_folder: DummyTrainFolder, trained_model: str):
"""Test a "normal" development cycle, train, evaluate and calc threshold.
"""
train_script.run()
params_file = train_folder.model + '.params'
assert isfile(train_folder.model)
params_file = trained_model + '.params'
assert isfile(trained_model)
assert isfile(params_file)

EvalScript.create(folder=train_folder.root,
models=[train_folder.model]).run()
models=[trained_model]).run()

# Ensure that the graph script generates a numpy savez file
out_file = train_folder.path('outputs.npz')
graph_script = GraphScript.create(folder=train_folder.root,
models=[train_folder.model],
models=[trained_model],
output_file=out_file)
graph_script.run()
assert isfile(out_file)

# Ensure the params are updated after threshold is calculated
params_before = read_content(params_file)
CalcThresholdScript.create(folder=train_folder.root,
model=train_folder.model,
model=trained_model,
input_file=out_file).run()
assert params_before != read_content(params_file)
10 changes: 5 additions & 5 deletions test/scripts/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from os.path import isfile

from precise.scripts.convert import ConvertScript
from test.scripts.test_utils.temp_folder import TempFolder


def test_convert(train_folder, train_script):
train_script.run()

ConvertScript.create(model=train_folder.model, out=train_folder.model + '.pb').run()
assert isfile(train_folder.model + '.pb')
def test_convert(temp_folder: TempFolder, trained_model: str):
pb_model = temp_folder.path('model.pb')
ConvertScript.create(model=trained_model, out=pb_model).run()
assert isfile(pb_model)
9 changes: 5 additions & 4 deletions test/scripts/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from precise.scripts.engine import EngineScript
from runner.precise_runner import ReadWriteStream

from test.scripts.test_utils.dummy_train_folder import DummyTrainFolder


class FakeStdin:
def __init__(self, data: bytes):
Expand All @@ -35,18 +37,17 @@ def __init__(self):
self.buffer = ReadWriteStream()


def test_engine(train_folder, train_script):
def test_engine(train_folder: DummyTrainFolder, trained_model: str):
"""
Test t hat the output format of the engina matches a decimal form in the
Test that the output format of the engine matches a decimal form in the
range 0.0 - 1.0.
"""
train_script.run()
with open(glob.glob(join(train_folder.root, 'wake-word', '*.wav'))[0], 'rb') as f:
data = f.read()
try:
sys.stdin = FakeStdin(data)
sys.stdout = FakeStdout()
EngineScript.create(model_name=train_folder.model).run()
EngineScript.create(model_name=trained_model).run()
assert re.match(rb'[01]\.[0-9]+', sys.stdout.buffer.buffer)
finally:
sys.stdin = sys.__stdin__
Expand Down
Loading