Skip to content

Commit

Permalink
Optimized PQA keygen code
Browse files Browse the repository at this point in the history
  • Loading branch information
aabmets committed Feb 16, 2024
1 parent 3b349db commit 9e27ad8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 18 deletions.
34 changes: 24 additions & 10 deletions quantcrypt/internal/pqa/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from enum import Enum
from abc import ABC, abstractmethod
from types import ModuleType
from typing import Literal
from typing import Literal, Type
from functools import lru_cache
from ..errors import InvalidArgsError
from .. import utils
Expand All @@ -26,13 +26,23 @@
__all__ = [
"PQAVariant",
"BasePQAParamSizes",
"BasePQAlgorithm",
"BasePQAlgorithm"
]


class PQAVariant(Enum):
CLEAN = "clean"
AVX2 = "avx2"
"""
Available binaries:
REF - Clean reference binaries for the x86_64 architecture.
OPT - Speed-optimized binaries for the x86_64 architecture.
ARM - Binaries for the aarch64 architecture.
"""
REF = "clean"
OPT = "avx2"
ARM = "aarch64"


class BasePQAParamSizes:
Expand Down Expand Up @@ -71,18 +81,18 @@ def _import(self, variant: PQAVariant) -> ModuleType:
def __init__(self, variant: PQAVariant = None):
# variant is None -> auto-select mode
try:
_var = variant or PQAVariant.AVX2
_var = variant or PQAVariant.OPT
self._lib = self._import(_var)
self.variant = _var
except ModuleNotFoundError as ex:
if variant is None:
try:
self._lib = self._import(PQAVariant.CLEAN)
self.variant = PQAVariant.CLEAN
self._lib = self._import(PQAVariant.REF)
self.variant = PQAVariant.REF
return
except ModuleNotFoundError: # pragma: no cover
pass
elif variant == PQAVariant.AVX2: # pragma: no cover
elif variant == PQAVariant.OPT: # pragma: no cover
raise ex
raise SystemExit( # pragma: no cover
"Quantcrypt Fatal Error:\n"
Expand All @@ -95,15 +105,19 @@ def _upper_name(self) -> str:
pattern='.[^A-Z]*'
)).upper()

def _keygen(self, algo_type: Literal["kem", "sign"]) -> tuple[bytes, bytes]:
def _keygen(
self,
algo_type: Literal["kem", "sign"],
error_cls: Type[errors.PQAError]
) -> tuple[bytes, bytes]:
ffi, params = FFI(), self.param_sizes
public_key = ffi.new(f"uint8_t [{params.pk_size}]")
secret_key = ffi.new(f"uint8_t [{params.sk_size}]")

name = f"_crypto_{algo_type}_keypair"
func = getattr(self._lib, self._namespace + name)
if func(public_key, secret_key) != 0: # pragma: no cover
return tuple()
raise error_cls

pk = ffi.buffer(public_key, params.pk_size)
sk = ffi.buffer(secret_key, params.sk_size)
Expand Down
5 changes: 1 addition & 4 deletions quantcrypt/internal/pqa/dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ def keygen(self) -> tuple[bytes, bytes]:
library has failed to generate the keys for the current
DSS algorithm for any reason.
"""
result = self._keygen("sign")
if not result: # pragma: no cover
raise errors.DSSKeygenFailedError
return result
return self._keygen("sign", errors.DSSKeygenFailedError)

def sign(self, secret_key: bytes, message: bytes) -> bytes:
"""
Expand Down
5 changes: 1 addition & 4 deletions quantcrypt/internal/pqa/kem.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,7 @@ def keygen(self) -> tuple[bytes, bytes]:
library has failed to generate the keys for the current
KEM algorithm for any reason.
"""
result = self._keygen("kem")
if not result: # pragma: no cover
raise errors.KEMKeygenFailedError
return result
return self._keygen("kem", errors.KEMKeygenFailedError)

def encaps(self, public_key: bytes) -> tuple[bytes, bytes]:
"""
Expand Down

0 comments on commit 9e27ad8

Please sign in to comment.