Skip to content

Commit 7622f40

Browse files
committed
fix: refinement type assert cast bug
1 parent 1762588 commit 7622f40

File tree

4 files changed

+22
-1
lines changed

4 files changed

+22
-1
lines changed

crates/erg_common/triple.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ impl<T> Triple<T, T> {
141141
Triple::Ok(a) | Triple::Err(a) => Some(a),
142142
}
143143
}
144+
145+
pub fn merge_or(self, default: T) -> T {
146+
match self {
147+
Triple::None => default,
148+
Triple::Ok(ok) => ok,
149+
Triple::Err(err) => err,
150+
}
151+
}
144152
}
145153

146154
impl<T, E: std::error::Error> Triple<T, E> {

crates/erg_compiler/context/inquire.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3657,6 +3657,7 @@ impl Context {
36573657
/// ```erg
36583658
/// recover_typarams(Int, Nat) == Nat
36593659
/// recover_typarams(Array!(Int, _), Array(Nat, 2)) == Array!(Nat, 2)
3660+
/// recover_typarams(Str or NoneType, {"a", "b"}) == {"a", "b"}
36603661
/// ```
36613662
/// ```erg
36623663
/// # REVIEW: should be?
@@ -3667,7 +3668,8 @@ impl Context {
36673668
let is_never =
36683669
self.subtype_of(&intersec, &Type::Never) && guard.to.as_ref() != &Type::Never;
36693670
if !is_never {
3670-
return Ok(intersec);
3671+
let min = self.min(&intersec, &guard.to).merge_or(&intersec);
3672+
return Ok(min.clone());
36713673
}
36723674
if guard.to.is_monomorphic() {
36733675
if self.related(base, &guard.to) {

crates/erg_compiler/tests/infer.er

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,11 @@ c_new x, y = C.new x, y
3030
C = Class Int
3131
C.
3232
new x, y = Self x + y
33+
34+
val!() =
35+
for! [{ "a": "b" }], (pkg as {Str: Str}) =>
36+
x = pkg.get("a", "c")
37+
assert x in {"b"}
38+
val!::return x
39+
"d"
40+
val = val!()

crates/erg_compiler/tests/test.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ fn _test_infer_types() -> Result<(), ()> {
8787
let c_new_t = func2(add_r, r, c.clone()).quantify();
8888
module.context.assert_var_type("c_new", &c_new_t)?;
8989
module.context.assert_attr_type(&c, "new", &c_new_t)?;
90+
module
91+
.context
92+
.assert_var_type("val", &v_enum(set! { "b".into(), "d".into() }))?;
9093
Ok(())
9194
}
9295

0 commit comments

Comments
 (0)