Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkgs/core/swarmauri_core/crypto/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class JWAAlg(str, Enum):
RSA_OAEP = "RSA-OAEP"
RSA_OAEP_256 = "RSA-OAEP-256"
ECDH_ES = "ECDH-ES"
ECDH_ES_X25519_MLKEM768 = "ECDH-ES+X25519MLKEM768"
DIR = "dir"
A128GCM = "A128GCM"
A192GCM = "A192GCM"
Expand Down
1 change: 1 addition & 0 deletions pkgs/standards/swarmauri_crypto_jwe/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"swarmauri_core",
"swarmauri_base",
"cryptography>=41",
"pqcrypto>=0.3.1",
]
keywords = [
'swarmauri',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import base64
import binascii
import json
import os
import zlib
Expand All @@ -22,6 +23,8 @@
load_pem_private_key,
)

from pqcrypto.kem import kyber768

from swarmauri_core.crypto.types import JWAAlg


Expand Down Expand Up @@ -168,6 +171,38 @@ def _load_ecdh_recipient_public(jwk_or_pem: Any) -> Tuple[str, Any]:
raise TypeError("Unsupported recipient public key format for ECDH-ES.")


def _bytes_from_any(value: Any, *, allow_mapping_key: str | None = None) -> bytes:
if allow_mapping_key and isinstance(value, Mapping):
if allow_mapping_key not in value:
raise ValueError(
f"Mapping is missing required key '{allow_mapping_key}' for ML-KEM-768"
)
return _bytes_from_any(value[allow_mapping_key])
if isinstance(value, (bytes, bytearray)):
return bytes(value)
if isinstance(value, str):
decoders = (_b64u_dec, lambda s: base64.b64decode(s, validate=False))
for decoder in decoders:
try:
return decoder(value)
except (binascii.Error, ValueError):
continue
raise ValueError("Failed to decode ML-KEM key material from string.")
raise TypeError("Unsupported key material type; expected bytes or str.")


def _load_mlkem768_public(value: Any) -> bytes:
if value is None:
raise ValueError("ML-KEM-768 public key is required for hybrid encryption.")
return _bytes_from_any(value, allow_mapping_key="pub")


def _load_mlkem768_private(value: Any) -> bytes:
if value is None:
raise ValueError("ML-KEM-768 private key is required for hybrid decryption.")
return _bytes_from_any(value, allow_mapping_key="priv")


def _concat_kdf(
z: bytes,
enc: JWAAlg,
Expand Down Expand Up @@ -312,6 +347,54 @@ async def encrypt_compact(
)
cek = _concat_kdf(z, enc, hashes.SHA256(), apu_b, apv_b)
protected["epk"] = epk_header
elif alg == JWAAlg.ECDH_ES_X25519_MLKEM768:
x25519_info = key.get("x25519")
if x25519_info is None:
raise ValueError(
"Hybrid alg requires 'x25519' entry containing recipient public key."
)
crv, rpk = _load_ecdh_recipient_public(x25519_info)
if crv != "X25519":
raise ValueError(
"Hybrid alg requires an X25519 recipient public key for the classical component."
)
esk = x25519.X25519PrivateKey.generate()
epk = esk.public_key()
z_classical = esk.exchange(rpk) # type: ignore[arg-type]
epk_header = _x25519_jwk_from_public_key(epk)

mlkem_pub = _load_mlkem768_public(
key.get("mlkem768")
or key.get("mlkem768_pub")
or key.get("pqc")
or key.get("mlkem")
)
pqc_ciphertext, pqc_secret = kyber768.encapsulate(mlkem_pub)

apu_b = None
apv_b = None
if "apu" in (header_extra or {}):
apu_b = (
_b64u_dec(header_extra["apu"])
if isinstance(header_extra["apu"], str)
else header_extra["apu"]
)
if "apv" in (header_extra or {}):
apv_b = (
_b64u_dec(header_extra["apv"])
if isinstance(header_extra["apv"], str)
else header_extra["apv"]
)

cek = _concat_kdf(
z_classical + pqc_secret, enc, hashes.SHA256(), apu_b, apv_b
)
protected["epk"] = epk_header
protected["pqc"] = {
"kty": "ML-KEM",
"kem": "ML-KEM-768",
"ct": _b64u(pqc_ciphertext),
}
else:
raise ValueError(f"Unsupported alg '{alg.value}'")

Expand Down Expand Up @@ -343,6 +426,7 @@ async def decrypt_compact(
rsa_private_pem: Optional[Union[str, bytes]] = None,
rsa_private_password: Optional[Union[str, bytes]] = None,
ecdh_private_key: Optional[Any] = None,
mlkem_private_key: Optional[Any] = None,
expected_algs: Optional[Iterable[JWAAlg]] = None,
expected_encs: Optional[Iterable[JWAAlg]] = None,
aad: Optional[Union[bytes, str]] = None,
Expand All @@ -355,6 +439,8 @@ async def decrypt_compact(
RSA-OAEP algorithms.
rsa_private_password (Union[str, bytes]): Password for the RSA key.
ecdh_private_key (Any): Private key for ECDH-ES.
mlkem_private_key (Any): Private key for ML-KEM-768 when using the hybrid
algorithm.
expected_algs (Iterable[JWAAlg]): Allowed algorithm values.
expected_encs (Iterable[JWAAlg]): Allowed encryption values.
aad (Union[bytes, str]): Additional authenticated data.
Expand Down Expand Up @@ -439,6 +525,39 @@ async def decrypt_compact(
apu_b = _b64u_dec(header["apu"]) if "apu" in header else None
apv_b = _b64u_dec(header["apv"]) if "apv" in header else None
cek = _concat_kdf(z, enc, hashes.SHA256(), apu_b, apv_b)
elif alg == JWAAlg.ECDH_ES_X25519_MLKEM768:
if not isinstance(ecdh_private_key, x25519.X25519PrivateKey):
raise TypeError(
"Hybrid alg requires an X25519PrivateKey for the classical component."
)
if mlkem_private_key is None:
raise ValueError(
"mlkem_private_key is required for hybrid ECDH-ES+X25519MLKEM768 decryption."
)
epk = header.get("epk")
if not (isinstance(epk, Mapping) and epk.get("kty") == "OKP"):
raise ValueError("Missing/invalid 'epk' for hybrid ECDH-ES header.")
if epk.get("crv") != "X25519":
raise ValueError("Hybrid 'epk' must declare crv='X25519'.")
z_classical = ecdh_private_key.exchange(
x25519.X25519PublicKey.from_public_bytes(_b64u_dec(epk["x"]))
)

pqc_info = header.get("pqc")
if not isinstance(pqc_info, Mapping):
raise ValueError("Missing 'pqc' object in hybrid header.")
ct_b64 = pqc_info.get("ct")
if not isinstance(ct_b64, str):
raise ValueError("Hybrid header 'pqc.ct' must be a base64url string.")
pqc_ciphertext = _b64u_dec(ct_b64)
mlkem_sk = _load_mlkem768_private(mlkem_private_key)
pqc_secret = kyber768.decapsulate(pqc_ciphertext, mlkem_sk)

apu_b = _b64u_dec(header["apu"]) if "apu" in header else None
apv_b = _b64u_dec(header["apv"]) if "apv" in header else None
cek = _concat_kdf(
z_classical + pqc_secret, enc, hashes.SHA256(), apu_b, apv_b
)
else:
raise ValueError(f"Unsupported alg '{alg.value}'")

Expand Down
65 changes: 65 additions & 0 deletions pkgs/standards/swarmauri_crypto_jwe/tests/test_hybrid_alg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import asyncio
import base64
import json

import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x25519
from pqcrypto.kem import kyber768

from swarmauri_core.crypto.types import JWAAlg
from swarmauri_crypto_jwe import JweCrypto


def _b64u(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


@pytest.mark.unit
@pytest.mark.test
def test_ecdh_es_x25519_mlkem768_round_trip() -> None:
crypto = JweCrypto()

recipient_x_priv = x25519.X25519PrivateKey.generate()
recipient_x_pub = recipient_x_priv.public_key()
recipient_x_pub_bytes = recipient_x_pub.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
mlkem_pub, mlkem_priv = kyber768.generate_keypair()

jwe = asyncio.run(
crypto.encrypt_compact(
payload=b"hybrid",
alg=JWAAlg.ECDH_ES_X25519_MLKEM768,
enc=JWAAlg.A256GCM,
key={
"x25519": {
"kty": "OKP",
"crv": "X25519",
"x": _b64u(recipient_x_pub_bytes),
},
"mlkem768": base64.b64encode(mlkem_pub).decode("ascii"),
},
)
)

protected_b64 = jwe.split(".")[0]
padding = "=" * ((4 - len(protected_b64) % 4) % 4)
protected = json.loads(base64.urlsafe_b64decode(protected_b64 + padding))

assert protected["alg"] == JWAAlg.ECDH_ES_X25519_MLKEM768.value
assert protected["enc"] == JWAAlg.A256GCM.value
assert protected["epk"]["crv"] == "X25519"
assert protected["pqc"]["kem"] == "ML-KEM-768"
assert isinstance(protected["pqc"]["ct"], str)

result = asyncio.run(
crypto.decrypt_compact(
jwe,
ecdh_private_key=recipient_x_priv,
mlkem_private_key=mlkem_priv,
)
)

assert result.plaintext == b"hybrid"
Loading