Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/cryptography/hazmat/asn1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
GeneralizedTime,
Implicit,
PrintableString,
Size,
UtcTime,
decode_der,
encode_der,
Expand All @@ -20,6 +21,7 @@
"GeneralizedTime",
"Implicit",
"PrintableString",
"Size",
"UtcTime",
"decode_der",
"encode_der",
Expand Down
12 changes: 11 additions & 1 deletion src/cryptography/hazmat/asn1/asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _extract_annotation(
) -> declarative_asn1.Annotation:
default = None
encoding = None
size = None
for raw_annotation in metadata:
if isinstance(raw_annotation, Default):
if default is not None:
Expand All @@ -79,10 +80,18 @@ def _extract_annotation(
f"'{field_name}'"
)
encoding = raw_annotation
elif isinstance(raw_annotation, declarative_asn1.Size):
if size is not None:
raise TypeError(
f"multiple SIZE annotations found in field '{field_name}'"
)
size = raw_annotation
else:
raise TypeError(f"unsupported annotation: {raw_annotation}")

return declarative_asn1.Annotation(default=default, encoding=encoding)
return declarative_asn1.Annotation(
default=default, encoding=encoding, size=size
)


def _normalize_field_type(
Expand Down Expand Up @@ -217,6 +226,7 @@ class Default(typing.Generic[U]):

Explicit = declarative_asn1.Encoding.Explicit
Implicit = declarative_asn1.Encoding.Implicit
Size = declarative_asn1.Size

PrintableString = declarative_asn1.PrintableString
UtcTime = declarative_asn1.UtcTime
Expand Down
10 changes: 10 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ class Type:
class Annotation:
default: typing.Any | None
encoding: Encoding | None
size: Size | None
def __new__(
cls,
default: typing.Any | None = None,
encoding: Encoding | None = None,
size: Size | None = None,
) -> Annotation: ...
def is_empty(self) -> bool: ...

Expand All @@ -36,6 +38,14 @@ class Encoding:
Implicit: typing.ClassVar[type]
Explicit: typing.ClassVar[type]

class Size:
min: int
max: int | None

def __new__(cls, min: int, max: int | None) -> Size: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess if you want no min you just do min=0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup

@staticmethod
def exact(n: int) -> Size: ...

class AnnotatedType:
inner: Type
annotation: Annotation
Expand Down
13 changes: 13 additions & 0 deletions src/rust/src/declarative_asn1/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,19 @@ pub(crate) fn decode_annotated_type<'a>(
let val = decode_annotated_type(py, d, inner_ann_type)?;
list.append(val)?;
}
if let Some(size) = &ann_type.annotation.get().size {
let list_len = list.len();
let min = size.get().min;
let max = size.get().max.unwrap_or(usize::MAX);
if list_len < min || list_len > max {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is like, 2% easier to read as if !(min..=max).contains(&list_len) since thats it unavoidable that it's an inclusive range

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

return Err(CryptographyError::Py(
pyo3::exceptions::PyValueError::new_err(format!(
"SEQUENCE OF has size {0}, expected size in [{1}, {2}]",
list_len, min, max
)),
));
}
}
Ok(list.into_any())
})?
}
Expand Down
8 changes: 8 additions & 0 deletions src/rust/src/declarative_asn1/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
value: e,
})
.collect();

if let Some(size) = &annotated_type.annotation.get().size {
let min = size.get().min;
let max = size.get().max.unwrap_or(usize::MAX);
if values.len() < min || values.len() > max {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

return Err(asn1::WriteError::AllocationError);
}
}
write_value(writer, &asn1::SequenceOfWriter::new(values), encoding)
}
Type::Option(cls) => {
Expand Down
38 changes: 35 additions & 3 deletions src/rust/src/declarative_asn1/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,29 @@ pub struct Annotation {
pub(crate) default: Option<pyo3::Py<pyo3::types::PyAny>>,
#[pyo3(get)]
pub(crate) encoding: Option<pyo3::Py<Encoding>>,
#[pyo3(get)]
pub(crate) size: Option<pyo3::Py<Size>>,
}

#[pyo3::pymethods]
impl Annotation {
#[new]
#[pyo3(signature = (default = None, encoding = None))]
#[pyo3(signature = (default = None, encoding = None, size = None))]
fn new(
default: Option<pyo3::Py<pyo3::types::PyAny>>,
encoding: Option<pyo3::Py<Encoding>>,
size: Option<pyo3::Py<Size>>,
) -> Self {
Self { default, encoding }
Self {
default,
encoding,
size,
}
}

#[pyo3(signature = ())]
fn is_empty(&self) -> bool {
self.default.is_none() && self.encoding.is_none()
self.default.is_none() && self.encoding.is_none() && self.size.is_none()
}
}

Expand All @@ -99,6 +106,28 @@ pub enum Encoding {
Explicit(u32),
}

#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.asn1")]
pub struct Size {
pub min: usize,
pub max: Option<usize>,
}

#[pyo3::pymethods]
impl Size {
#[new]
fn new(min: usize, max: Option<usize>) -> Self {
Size { min, max }
}

#[staticmethod]
fn exact(n: usize) -> Self {
Size {
min: n,
max: Some(n),
}
}
}

#[derive(pyo3::FromPyObject)]
#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.asn1")]
pub struct PrintableString {
Expand Down Expand Up @@ -263,6 +292,7 @@ fn non_root_type_to_annotated<'p>(
annotation: Annotation {
default: None,
encoding: None,
size: None,
}
.into_pyobject(py)?
.unbind(),
Expand Down Expand Up @@ -328,6 +358,7 @@ mod tests {
annotation: Annotation {
default: None,
encoding: None,
size: None,
}
.into_pyobject(py)
.unwrap()
Expand All @@ -342,6 +373,7 @@ mod tests {
annotation: Annotation {
default: None,
encoding: None,
size: None,
}
.into_pyobject(py)
.unwrap()
Expand Down
2 changes: 1 addition & 1 deletion src/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ mod _rust {
#[pymodule_export]
use crate::declarative_asn1::types::{
non_root_python_to_rust, AnnotatedType, Annotation, Encoding, GeneralizedTime,
PrintableString, Type, UtcTime,
PrintableString, Size, Type, UtcTime,
};
}

Expand Down
12 changes: 12 additions & 0 deletions tests/hazmat/asn1/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ def test_fail_multiple_explicit_annotations(self) -> None:
class Example:
invalid: Annotated[int, asn1.Explicit(0), asn1.Explicit(1)]

def test_fail_multiple_size_annotations(self) -> None:
with pytest.raises(
TypeError,
match="multiple SIZE annotations found in field 'invalid'",
):

@asn1.sequence
class Example:
invalid: Annotated[
int, asn1.Size(min=1, max=2), asn1.Size(min=1, max=2)
]

def test_fail_optional_with_default_field(self) -> None:
with pytest.raises(
TypeError,
Expand Down
105 changes: 105 additions & 0 deletions tests/hazmat/asn1/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,111 @@ class Example:
]
)

def test_ok_sequenceof_size_restriction(self) -> None:
@asn1.sequence
@_comparable_dataclass
class Example:
a: Annotated[typing.List[int], asn1.Size(min=1, max=4)]

assert_roundtrips(
[
(
Example(a=[1, 2, 3, 4]),
b"\x30\x0e\x30\x0c\x02\x01\x01\x02\x01\x02\x02\x01\x03\x02\x01\x04",
)
]
)

def test_ok_sequenceof_size_restriction_no_max(self) -> None:
@asn1.sequence
@_comparable_dataclass
class Example:
a: Annotated[typing.List[int], asn1.Size(min=1, max=None)]

assert_roundtrips(
[
(
Example(a=[1, 2, 3, 4]),
b"\x30\x0e\x30\x0c\x02\x01\x01\x02\x01\x02\x02\x01\x03\x02\x01\x04",
)
]
)

def test_ok_sequenceof_size_restriction_exact(self) -> None:
@asn1.sequence
@_comparable_dataclass
class Example:
a: Annotated[typing.List[int], asn1.Size.exact(4)]

assert_roundtrips(
[
(
Example(a=[1, 2, 3, 4]),
b"\x30\x0e\x30\x0c\x02\x01\x01\x02\x01\x02\x02\x01\x03\x02\x01\x04",
)
]
)

def test_fail_sequenceof_size_too_big(self) -> None:
@asn1.sequence
@_comparable_dataclass
class Example:
a: Annotated[typing.List[int], asn1.Size(min=1, max=2)]

with pytest.raises(
ValueError,
match=re.escape("SEQUENCE OF has size 4, expected size in [1, 2]"),
):
asn1.decode_der(
Example,
b"\x30\x0e\x30\x0c\x02\x01\x01\x02\x01\x02\x02\x01\x03\x02\x01\x04",
)

with pytest.raises(
ValueError,
):
asn1.encode_der(Example(a=[1, 2, 3, 4]))

def test_fail_sequenceof_size_too_small(self) -> None:
@asn1.sequence
@_comparable_dataclass
class Example:
a: Annotated[typing.List[int], asn1.Size(min=5, max=6)]

with pytest.raises(
ValueError,
match=re.escape("SEQUENCE OF has size 4, expected size in [5, 6]"),
):
asn1.decode_der(
Example,
b"\x30\x0e\x30\x0c\x02\x01\x01\x02\x01\x02\x02\x01\x03\x02\x01\x04",
)

with pytest.raises(
ValueError,
):
asn1.encode_der(Example(a=[1, 2, 3, 4]))

def test_fail_sequenceof_size_not_exact(self) -> None:
@asn1.sequence
@_comparable_dataclass
class Example:
a: Annotated[typing.List[int], asn1.Size.exact(5)]

with pytest.raises(
ValueError,
match=re.escape("SEQUENCE OF has size 4, expected size in [5, 5]"),
):
asn1.decode_der(
Example,
b"\x30\x0e\x30\x0c\x02\x01\x01\x02\x01\x02\x02\x01\x03\x02\x01\x04",
)

with pytest.raises(
ValueError,
):
asn1.encode_der(Example(a=[1, 2, 3, 4]))

def test_ok_sequence_with_optionals(self) -> None:
@asn1.sequence
@_comparable_dataclass
Expand Down