Skip to content

Commit 1100a66

Browse files
authored
chore: combine block building and verification functions (#43)
1 parent af44308 commit 1100a66

File tree

1 file changed

+40
-48
lines changed

1 file changed

+40
-48
lines changed

src/sha256.nr

Lines changed: 40 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -142,24 +142,14 @@ pub(crate) fn process_full_blocks<let N: u32>(
142142

143143
for i in 0..num_blocks {
144144
let msg_start = BLOCK_SIZE * i;
145-
let new_msg_block =
146-
// Safety: separate verification function
147-
unsafe { build_msg_block(msg, message_size, msg_start) };
148-
149-
// Verify the block we are compressing was appropriately constructed
150-
verify_msg_block(msg, message_size, new_msg_block, msg_start);
145+
let new_msg_block = build_msg_block(msg, message_size, msg_start);
151146

152147
blocks[i] = new_msg_block;
153148
states[i + 1] = sha256_compression(new_msg_block, states[i]);
154149
}
155150
// If message_size/BLOCK_SIZE == N/BLOCK_SIZE, and there is a remainder, we need to process the last block.
156151
if N % BLOCK_SIZE != 0 {
157-
let new_msg_block =
158-
// Safety: separate verification function
159-
unsafe { build_msg_block(msg, message_size, BLOCK_SIZE * num_blocks) };
160-
161-
// Verify the block we are compressing was appropriately constructed
162-
verify_msg_block(msg, message_size, new_msg_block, BLOCK_SIZE * num_blocks);
152+
let new_msg_block = build_msg_block(msg, message_size, BLOCK_SIZE * num_blocks);
163153

164154
blocks[num_blocks] = new_msg_block;
165155
}
@@ -171,7 +161,7 @@ pub(crate) fn process_full_blocks<let N: u32>(
171161
}
172162

173163
// Take `BLOCK_SIZE` number of bytes from `msg` starting at `msg_start` and pack them into a `MSG_BLOCK`.
174-
pub(crate) unconstrained fn build_msg_block<let N: u32>(
164+
pub(crate) unconstrained fn build_msg_block_helper<let N: u32>(
175165
msg: [u8; N],
176166
message_size: u32,
177167
msg_start: u32,
@@ -213,47 +203,49 @@ pub(crate) unconstrained fn build_msg_block<let N: u32>(
213203
msg_block
214204
}
215205

216-
// Verify the block we are compressing was appropriately constructed by `build_msg_block`
217-
// and matches the input data.
206+
// Build a message block from the input message starting at `msg_start`.
207+
//
218208
// If `message_size` is less than `msg_start` then this is called with the old non-empty block;
219209
// in that case we can skip verification, ie. no need to check that everything is zero.
220-
fn verify_msg_block<let N: u32>(
221-
msg: [u8; N],
222-
message_size: u32,
223-
msg_block: MSG_BLOCK,
224-
msg_start: u32,
225-
) {
226-
let mut msg_end = msg_start + BLOCK_SIZE;
227-
if msg_end > N {
228-
msg_end = N;
229-
}
230-
// We might have to go beyond the input to pad the fields.
231-
if msg_end % INT_SIZE != 0 {
232-
msg_end = msg_end + INT_SIZE - msg_end % INT_SIZE;
233-
}
210+
fn build_msg_block<let N: u32>(msg: [u8; N], message_size: u32, msg_start: u32) -> MSG_BLOCK {
211+
let msg_block =
212+
// Safety: We constrain the block below by reconstructing each `u32` word from the input bytes.
213+
unsafe { build_msg_block_helper(msg, message_size, msg_start) };
214+
215+
if !is_unconstrained() {
216+
let mut msg_end = msg_start + BLOCK_SIZE;
217+
if msg_end > N {
218+
msg_end = N;
219+
}
220+
// We might have to go beyond the input to pad the fields.
221+
if msg_end % INT_SIZE != 0 {
222+
msg_end = msg_end + INT_SIZE - msg_end % INT_SIZE;
223+
}
234224

235-
// Reconstructed packed item.
236-
let mut msg_item: u32 = 0;
237-
238-
// Inclusive at the end so that we can compare the last item.
239-
let mut i: u32 = 0;
240-
for k in msg_start..=msg_end {
241-
if k % INT_SIZE == 0 {
242-
// If we consumed some input we can compare against the block.
243-
if (msg_start < message_size) & (k > msg_start) {
244-
println(f"i is {i}");
245-
assert_eq(msg_block[i], msg_item as u32);
246-
i = i + 1;
247-
msg_item = 0;
225+
// Reconstructed packed item.
226+
let mut msg_item: u32 = 0;
227+
228+
// Inclusive at the end so that we can compare the last item.
229+
let mut i: u32 = 0;
230+
for k in msg_start..=msg_end {
231+
if k % INT_SIZE == 0 {
232+
// If we consumed some input we can compare against the block.
233+
if (msg_start < message_size) & (k > msg_start) {
234+
assert_eq(msg_block[i], msg_item as u32);
235+
i = i + 1;
236+
msg_item = 0;
237+
}
238+
}
239+
// Shift the accumulator
240+
msg_item = msg_item << 8;
241+
// If we have input to consume, add it at the rightmost position.
242+
if k < message_size & k < msg_end {
243+
msg_item = msg_item + msg[k] as u32;
248244
}
249-
}
250-
// Shift the accumulator
251-
msg_item = msg_item << 8;
252-
// If we have input to consume, add it at the rightmost position.
253-
if k < message_size & k < msg_end {
254-
msg_item = msg_item + msg[k] as u32;
255245
}
256246
}
247+
248+
msg_block
257249
}
258250

259251
// Verify that a region of ints in the message block are (partially) zeroed,

0 commit comments

Comments
 (0)