Skip to content

Commit e4b2063

Browse files
authored
Adding support for Reals (#50)
This PR would add support for `Real` values. In more detail: - It adds the type `real`, with the helper method `real_sort` - It adds the method `con_to_real` for converting `Int` values to reals - It adds the method `rdiv` which corresponds to division for reals (i.e. `/`) - It adds support for parsing real valued solutions (see below) ## On Parsing `Real` Solutions As it turns out, parsing the satisfying assignment for solutions to real values is not as straightforward. In many cases, the SMT solver will represent a solution as a fraction, e.g. if it assigns the value `2.5` to a variable, when asked for the solution, it will return it as the expression `/ 5.0 2.0`. For my project, I am fine with the loss of precision when evaluating this expression to an `f32` or `f64`. The `get_f32` and `get_f64` methods in this PR will evaluate this expression and return the result. However, if this is not desired, the next section discusses possible alternatives. I have written a small test that should check all possible variations of solutions that can be returned by the SMT solver: ```rust #[test] fn test_real_numbers() { let mut ctx = ContextBuilder::new() .solver("z3") .solver_args(["-smt2", "-in"]) .build() .unwrap(); let x = ctx.declare_const("x", ctx.real_sort()).unwrap(); // x == 2.0 ctx.assert(ctx.eq(x, ctx.decimal(2.0))).unwrap(); assert_eq!(ctx.check().unwrap(), Response::Sat); let solution = ctx.get_value(vec![x]).unwrap(); let sol = ctx.get_f64(solution[0].1).unwrap(); // Z3 returns `2.0` assert!(sol == 2.0, "Expected solution to be 2.5, got {}", sol); let y = ctx.declare_const("y", ctx.real_sort()).unwrap(); // y == -2.0 ctx.assert(ctx.eq(y, ctx.decimal(-2.0))).unwrap(); assert_eq!(ctx.check().unwrap(), Response::Sat); let solution = ctx.get_value(vec![y]).unwrap(); let sol = ctx.get_f64(solution[0].1).unwrap(); // Z3 returns `- 2.0` assert!(sol == -2.0, "Expected solution to be 2.5, got {}", sol); let z = ctx.declare_const("z", ctx.real_sort()).unwrap(); // z == 2.5 / 1.0 ctx.assert(ctx.eq(z, ctx.rdiv(ctx.decimal(2.5), ctx.decimal(1.0)))) .unwrap(); assert_eq!(ctx.check().unwrap(), Response::Sat); let solution = ctx.get_value(vec![z]).unwrap(); let sol = ctx.get_f64(solution[0].1).unwrap(); // Z3 returns `(/ 5.0 2.0)` assert!(sol == 2.5, "Expected solution to be 2.5, got {}", sol); let a = ctx.declare_const("a", ctx.real_sort()).unwrap(); // a == 2.5 / -1.0 ctx.assert(ctx.eq(a, ctx.rdiv(ctx.decimal(2.5), ctx.decimal(-1.0)))) .unwrap(); assert_eq!(ctx.check().unwrap(), Response::Sat); let solution = ctx.get_value(vec![a]).unwrap(); let sol = ctx.get_f64(solution[0].1).unwrap(); // Z3 returns `(- (/ 5.0 2.0))` assert!(sol == -2.5, "Expected solution to be -2.5, got {}", sol); } ``` ### Possible Alternatives If the loss of precision that occurs when parsing the solution is not desired, I see two alternatives: - Explicitly have functions like `get_numerator` and `get_denominator` - Defining / importing a fractional type (possibly hidden behind a feature flag) Let me know your thoughts!
1 parent 3dd756e commit e4b2063

File tree

4 files changed

+177
-14
lines changed

4 files changed

+177
-14
lines changed

src/context.rs

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ impl Context {
704704
/// `Context`. Failure to do so is safe, but may trigger a panic or return
705705
/// invalid data.
706706
pub fn get_u8(&self, expr: SExpr) -> Option<u8> {
707-
self.arena.get_t(expr)
707+
self.arena.get_i(expr)
708708
}
709709

710710
/// Get the data for the given s-expression as an `u16`.
@@ -716,7 +716,7 @@ impl Context {
716716
/// `Context`. Failure to do so is safe, but may trigger a panic or return
717717
/// invalid data.
718718
pub fn get_u16(&self, expr: SExpr) -> Option<u16> {
719-
self.arena.get_t(expr)
719+
self.arena.get_i(expr)
720720
}
721721

722722
/// Get the data for the given s-expression as an `u32`.
@@ -728,7 +728,7 @@ impl Context {
728728
/// `Context`. Failure to do so is safe, but may trigger a panic or return
729729
/// invalid data.
730730
pub fn get_u32(&self, expr: SExpr) -> Option<u32> {
731-
self.arena.get_t(expr)
731+
self.arena.get_i(expr)
732732
}
733733

734734
/// Get the data for the given s-expression as an `u64`.
@@ -740,7 +740,7 @@ impl Context {
740740
/// `Context`. Failure to do so is safe, but may trigger a panic or return
741741
/// invalid data.
742742
pub fn get_u64(&self, expr: SExpr) -> Option<u64> {
743-
self.arena.get_t(expr)
743+
self.arena.get_i(expr)
744744
}
745745

746746
/// Get the data for the given s-expression as an `u128`.
@@ -752,7 +752,7 @@ impl Context {
752752
/// `Context`. Failure to do so is safe, but may trigger a panic or return
753753
/// invalid data.
754754
pub fn get_u128(&self, expr: SExpr) -> Option<u128> {
755-
self.arena.get_t(expr)
755+
self.arena.get_i(expr)
756756
}
757757

758758
/// Get the data for the given s-expression as an `usize`.
@@ -764,7 +764,7 @@ impl Context {
764764
/// `Context`. Failure to do so is safe, but may trigger a panic or return
765765
/// invalid data.
766766
pub fn get_usize(&self, expr: SExpr) -> Option<usize> {
767-
self.arena.get_t(expr)
767+
self.arena.get_i(expr)
768768
}
769769

770770
/// Get the data for the given s-expression as an `i8`.
@@ -776,7 +776,7 @@ impl Context {
776776
/// `Context`. Failure to do so is safe, but may trigger a panic or return
777777
/// invalid data.
778778
pub fn get_i8(&self, expr: SExpr) -> Option<i8> {
779-
self.arena.get_t(expr)
779+
self.arena.get_i(expr)
780780
}
781781

782782
/// Get the data for the given s-expression as an `i16`.
@@ -788,7 +788,7 @@ impl Context {
788788
/// `Context`. Failure to do so is safe, but may trigger a panic or return
789789
/// invalid data.
790790
pub fn get_i16(&self, expr: SExpr) -> Option<i16> {
791-
self.arena.get_t(expr)
791+
self.arena.get_i(expr)
792792
}
793793

794794
/// Get the data for the given s-expression as an `i32`.
@@ -800,7 +800,7 @@ impl Context {
800800
/// `Context`. Failure to do so is safe, but may trigger a panic or return
801801
/// invalid data.
802802
pub fn get_i32(&self, expr: SExpr) -> Option<i32> {
803-
self.arena.get_t(expr)
803+
self.arena.get_i(expr)
804804
}
805805

806806
/// Get the data for the given s-expression as an `i64`.
@@ -812,7 +812,7 @@ impl Context {
812812
/// `Context`. Failure to do so is safe, but may trigger a panic or return
813813
/// invalid data.
814814
pub fn get_i64(&self, expr: SExpr) -> Option<i64> {
815-
self.arena.get_t(expr)
815+
self.arena.get_i(expr)
816816
}
817817

818818
/// Get the data for the given s-expression as an `i128`.
@@ -824,7 +824,7 @@ impl Context {
824824
/// `Context`. Failure to do so is safe, but may trigger a panic or return
825825
/// invalid data.
826826
pub fn get_i128(&self, expr: SExpr) -> Option<i128> {
827-
self.arena.get_t(expr)
827+
self.arena.get_i(expr)
828828
}
829829

830830
/// Get the data for the given s-expression as an `isize`.
@@ -836,7 +836,30 @@ impl Context {
836836
/// `Context`. Failure to do so is safe, but may trigger a panic or return
837837
/// invalid data.
838838
pub fn get_isize(&self, expr: SExpr) -> Option<isize> {
839-
self.arena.get_t(expr)
839+
self.arena.get_i(expr)
840+
}
841+
842+
/// Get the data for the given s-expression as a `f32`.
843+
///
844+
/// This allows you to inspect s-expressions. If the s-expression is not an
845+
/// cannot be parsed into an `f32` this function returns `None`.
846+
///
847+
/// You may only pass in `SExpr`s that were created by this
848+
/// `Context`. Failure to do so is safe, but may trigger a panic or return
849+
/// invalid data.
850+
pub fn get_f32(&self, expr: SExpr) -> Option<f32> {
851+
self.arena.get_f(expr)
852+
}
853+
854+
/// Get the data for the given s-expression as a `f64`.
855+
///
856+
/// This allows you to inspect s-expressions. If the s-expression is not an
857+
/// cannot be parsed into an `f64` this function returns `None`.
858+
/// You may only pass in `SExpr`s that were created by this
859+
/// `Context`. Failure to do so is safe, but may trigger a panic or return
860+
/// invalid data.
861+
pub fn get_f64(&self, expr: SExpr) -> Option<f64> {
862+
self.arena.get_f(expr)
840863
}
841864

842865
/// Access "known" atoms.
@@ -997,6 +1020,19 @@ impl Context {
9971020
chainable!(gte, gte_many, gte);
9981021
}
9991022

1023+
/// # Real Helpers
1024+
///
1025+
/// These methods help you construct s-expressions for various real operations.
1026+
impl Context {
1027+
/// The `Real` sort.
1028+
pub fn real_sort(&self) -> SExpr {
1029+
self.atoms.real
1030+
}
1031+
1032+
left_assoc!(rdiv, rdiv_many, slash);
1033+
unary!(conv_to_real, to_real);
1034+
}
1035+
10001036
/// # Array Helpers
10011037
///
10021038
/// These methods help you construct s-expressions for various array operations.

src/known_atoms.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ macro_rules! for_each_known_atom {
4848
lt: "<";
4949
gte: ">=";
5050
gt: ">";
51+
real: "Real";
52+
slash: "/";
53+
to_real: "to_real";
5154
array: "Array";
5255
select: "select";
5356
store: "store";

src/sexpr.rs

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{cell::RefCell, collections::HashMap};
1+
use std::{cell::RefCell, collections::HashMap, ops::Div};
22

33
#[cfg(debug_assertions)]
44
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -250,7 +250,7 @@ impl Arena {
250250
return Some(data);
251251
}
252252

253-
pub(crate) fn get_t<T: TryParseInt>(&self, expr: SExpr) -> Option<T> {
253+
pub(crate) fn get_i<T: TryParseInt>(&self, expr: SExpr) -> Option<T> {
254254
let inner = self.0.borrow();
255255

256256
if expr.is_atom() {
@@ -278,6 +278,58 @@ impl Arena {
278278

279279
None
280280
}
281+
282+
pub(crate) fn get_f<T: TryParseFloat + Div<Output = T>>(&self, expr: SExpr) -> Option<T> {
283+
let inner = self.0.borrow();
284+
285+
if expr.is_atom() {
286+
let data = inner.strings[expr.index()].as_str();
287+
return T::try_parse_t(data, false);
288+
}
289+
290+
if expr.is_list() {
291+
let mut data = inner.lists[expr.index()].as_slice();
292+
293+
if !([1, 2, 3].contains(&data.len())) || !data[0].is_atom() {
294+
return None;
295+
}
296+
297+
let mut index = 0;
298+
let is_negated = match inner.strings[data[0].index()].as_str() {
299+
"-" => {
300+
index += 1;
301+
true
302+
}
303+
"+" => {
304+
index += 1;
305+
false
306+
}
307+
_ => false,
308+
};
309+
310+
// Solution could be of the form `(- (/ 1.0 2.0))`
311+
if data.len() == 2 && !data[1].is_atom() {
312+
data = inner.lists[data[1].index()].as_slice();
313+
index = 0;
314+
}
315+
316+
let data = &data[index..];
317+
318+
if data.len() == 1 {
319+
return T::try_parse_t(inner.strings[data[0].index()].as_str(), is_negated);
320+
}
321+
322+
// Solution returned is a fraction of the form `(/ 1.0 2.0)`
323+
if data.len() == 3 && inner.strings[data[0].index()].as_str() == "/" {
324+
let numerator =
325+
T::try_parse_t(inner.strings[data[1].index()].as_str(), is_negated)?;
326+
let denominator = T::try_parse_t(inner.strings[data[2].index()].as_str(), false)?;
327+
return Some(numerator / denominator);
328+
}
329+
}
330+
331+
None
332+
}
281333
}
282334

283335
pub(crate) trait TryParseInt: Sized {
@@ -310,6 +362,30 @@ macro_rules! impl_get_int {
310362

311363
impl_get_int!(u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize);
312364

365+
pub(crate) trait TryParseFloat: Sized {
366+
fn try_parse_t(a: &str, negate: bool) -> Option<Self>;
367+
}
368+
369+
macro_rules! impl_get_float {
370+
( $( $ty:ty )* ) => {
371+
$(
372+
impl TryParseFloat for $ty {
373+
fn try_parse_t(a: &str, negate: bool) -> Option<Self> {
374+
let mut x = a.parse::<$ty>().ok()?;
375+
376+
if negate {
377+
x = -x;
378+
}
379+
380+
Some(x)
381+
}
382+
}
383+
)*
384+
};
385+
}
386+
387+
impl_get_float!(f32 f64);
388+
313389
/// The data contents of an [`SExpr`][crate::SExpr].
314390
///
315391
/// ## Converting `SExprData` to an Integer

tests/real_numbers.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
use easy_smt::{ContextBuilder, Response};
2+
3+
#[test]
4+
fn test_real_numbers() {
5+
let mut ctx = ContextBuilder::new()
6+
.solver("z3")
7+
.solver_args(["-smt2", "-in"])
8+
.build()
9+
.unwrap();
10+
11+
let x = ctx.declare_const("x", ctx.real_sort()).unwrap();
12+
// x == 2.0
13+
ctx.assert(ctx.eq(x, ctx.decimal(2.0))).unwrap();
14+
assert_eq!(ctx.check().unwrap(), Response::Sat);
15+
let solution = ctx.get_value(vec![x]).unwrap();
16+
let sol = ctx.get_f64(solution[0].1).unwrap();
17+
// Z3 returns `2.0`
18+
assert_eq!(sol, 2.0);
19+
20+
let y = ctx.declare_const("y", ctx.real_sort()).unwrap();
21+
// y == -2.0
22+
ctx.assert(ctx.eq(y, ctx.decimal(-2.0))).unwrap();
23+
assert_eq!(ctx.check().unwrap(), Response::Sat);
24+
let solution = ctx.get_value(vec![y]).unwrap();
25+
let sol = ctx.get_f64(solution[0].1).unwrap();
26+
// Z3 returns `- 2.0`
27+
assert_eq!(sol, -2.0);
28+
29+
let z = ctx.declare_const("z", ctx.real_sort()).unwrap();
30+
// z == 2.5 / 1.0
31+
ctx.assert(ctx.eq(z, ctx.rdiv(ctx.decimal(2.5), ctx.decimal(1.0))))
32+
.unwrap();
33+
assert_eq!(ctx.check().unwrap(), Response::Sat);
34+
let solution = ctx.get_value(vec![z]).unwrap();
35+
let sol = ctx.get_f64(solution[0].1).unwrap();
36+
// Z3 returns `(/ 5.0 2.0)`
37+
assert_eq!(sol, 2.5);
38+
39+
let a = ctx.declare_const("a", ctx.real_sort()).unwrap();
40+
// a == 2.5 / -1.0
41+
ctx.assert(ctx.eq(a, ctx.rdiv(ctx.decimal(2.5), ctx.decimal(-1.0))))
42+
.unwrap();
43+
assert_eq!(ctx.check().unwrap(), Response::Sat);
44+
let solution = ctx.get_value(vec![a]).unwrap();
45+
let sol = ctx.get_f64(solution[0].1).unwrap();
46+
// Z3 returns `(- (/ 5.0 2.0))`
47+
assert_eq!(sol, -2.5);
48+
}

0 commit comments

Comments
 (0)