Skip to content

Commit 0761ea4

Browse files
authored
[ty] Eagerly evaluate types.UnionType elements as type expressions (#21531)
## Summary

Eagerly evaluate the elements of a PEP 604 union in value position (e.g. `IntOrStr = int | str`) as type expressions and store the result (the corresponding `Type::Union` if all elements are valid type expressions, or the first encountered `InvalidTypeExpressionError`) on the `UnionTypeInstance`, such that the `Type::Union(…)` does not need to be recomputed every time the implicit type alias is used in a type annotation.

This might lead to performance improvements for large unions, but is also necessary for correctness, because the elements of the union might refer to type variables that need to be looked up in the scope of the type alias, not at the usage site.

## Test Plan

New Markdown tests
1 parent 416e226 commit 0761ea4

File tree

7 files changed

+200
-100
lines changed

7 files changed

+200
-100
lines changed

crates/ty_python_semantic/resources/mdtest/implicit_type_aliases.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,13 @@ def _(
191191
reveal_type(int_or_callable) # revealed: int | ((str, /) -> bytes)
192192
reveal_type(callable_or_int) # revealed: ((str, /) -> bytes) | int
193193
# TODO should be Unknown | int
194-
reveal_type(type_var_or_int) # revealed: T@_ | int
194+
reveal_type(type_var_or_int) # revealed: typing.TypeVar | int
195195
# TODO should be int | Unknown
196-
reveal_type(int_or_type_var) # revealed: int | T@_
196+
reveal_type(int_or_type_var) # revealed: int | typing.TypeVar
197197
# TODO should be Unknown | None
198-
reveal_type(type_var_or_none) # revealed: T@_ | None
198+
reveal_type(type_var_or_none) # revealed: typing.TypeVar | None
199199
# TODO should be None | Unknown
200-
reveal_type(none_or_type_var) # revealed: None | T@_
200+
reveal_type(none_or_type_var) # revealed: None | typing.TypeVar
201201
```
202202

203203
If a type is unioned with itself in a value expression, the result is just that type. No

crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,43 @@ IntOrStr = Union[int, str]
159159
reveal_type(IntOrStr) # revealed: types.UnionType
160160

161161
def _(x: int | str | bytes | memoryview | range):
162-
# TODO: no error
163-
# error: [invalid-argument-type]
164162
if isinstance(x, IntOrStr):
165-
# TODO: Should be `int | str`
166-
reveal_type(x) # revealed: int | str | bytes | memoryview[int] | range
167-
# TODO: no error
168-
# error: [invalid-argument-type]
163+
reveal_type(x) # revealed: int | str
169164
elif isinstance(x, Union[bytes, memoryview]):
170-
# TODO: Should be `bytes | memoryview[int]`
171-
reveal_type(x) # revealed: int | str | bytes | memoryview[int] | range
165+
reveal_type(x) # revealed: bytes | memoryview[int]
172166
else:
173-
# TODO: Should be `range`
174-
reveal_type(x) # revealed: int | str | bytes | memoryview[int] | range
167+
reveal_type(x) # revealed: range
168+
169+
def _(x: int | str | None):
170+
if isinstance(x, Union[int, None]):
171+
reveal_type(x) # revealed: int | None
172+
else:
173+
reveal_type(x) # revealed: str
174+
175+
ListStrOrInt = Union[list[str], int]
176+
177+
def _(x: dict[int, str] | ListStrOrInt):
178+
# TODO: this should ideally be an error
179+
if isinstance(x, ListStrOrInt):
180+
# TODO: this should not be narrowed
181+
reveal_type(x) # revealed: list[str] | int
182+
183+
# TODO: this should ideally be an error
184+
if isinstance(x, Union[list[str], int]):
185+
# TODO: this should not be narrowed
186+
reveal_type(x) # revealed: list[str] | int
187+
```
188+
189+
## `Optional` as `classinfo`
190+
191+
```py
192+
from typing import Optional
193+
194+
def _(x: int | str | None):
195+
if isinstance(x, Optional[int]):
196+
reveal_type(x) # revealed: int | None
197+
else:
198+
reveal_type(x) # revealed: str
175199
```
176200

177201
## `classinfo` is a `typing.py` special form
@@ -289,6 +313,23 @@ def _(flag: bool):
289313
reveal_type(x) # revealed: Literal[1, "a"]
290314
```
291315

316+
## Generic aliases are not supported as second argument
317+
318+
The `classinfo` argument cannot be a generic alias:
319+
320+
```py
321+
def _(x: list[str] | list[int] | list[bytes]):
322+
# TODO: Ideally, this would be an error (requires https://github.com/astral-sh/ty/issues/116)
323+
if isinstance(x, list[int]):
324+
# No narrowing here:
325+
reveal_type(x) # revealed: list[str] | list[int] | list[bytes]
326+
327+
# error: [invalid-argument-type] "Invalid second argument to `isinstance`"
328+
if isinstance(x, list[int] | list[str]):
329+
# No narrowing here:
330+
reveal_type(x) # revealed: list[str] | list[int] | list[bytes]
331+
```
332+
292333
## `type[]` types are narrowed as well as class-literal types
293334

294335
```py

crates/ty_python_semantic/resources/mdtest/narrow/issubclass.md

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -212,19 +212,12 @@ IntOrStr = Union[int, str]
212212
reveal_type(IntOrStr) # revealed: types.UnionType
213213

214214
def f(x: type[int | str | bytes | range]):
215-
# TODO: No error
216-
# error: [invalid-argument-type]
217215
if issubclass(x, IntOrStr):
218-
# TODO: Should be `type[int] | type[str]`
219-
reveal_type(x) # revealed: type[int] | type[str] | type[bytes] | <class 'range'>
220-
# TODO: No error
221-
# error: [invalid-argument-type]
216+
reveal_type(x) # revealed: type[int] | type[str]
222217
elif issubclass(x, Union[bytes, memoryview]):
223-
# TODO: Should be `type[bytes]`
224-
reveal_type(x) # revealed: type[int] | type[str] | type[bytes] | <class 'range'>
218+
reveal_type(x) # revealed: type[bytes]
225219
else:
226-
# TODO: Should be `<class 'range'>`
227-
reveal_type(x) # revealed: type[int] | type[str] | type[bytes] | <class 'range'>
220+
reveal_type(x) # revealed: <class 'range'>
228221
```
229222

230223
## Special cases

crates/ty_python_semantic/src/types.rs

Lines changed: 98 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6738,17 +6738,10 @@ impl<'db> Type<'db> {
67386738
invalid_expressions: smallvec::smallvec_inline![InvalidTypeExpression::Generic],
67396739
fallback_type: Type::unknown(),
67406740
}),
6741-
KnownInstanceType::UnionType(list) => {
6742-
let mut builder = UnionBuilder::new(db);
6743-
let inferred_as = list.inferred_as(db);
6744-
for element in list.elements(db) {
6745-
builder = builder.add(if inferred_as.type_expression() {
6746-
*element
6747-
} else {
6748-
element.in_type_expression(db, scope_id, typevar_binding_context)?
6749-
});
6750-
}
6751-
Ok(builder.build())
6741+
KnownInstanceType::UnionType(instance) => {
6742+
// Cloning here is cheap if the result is a `Type` (which is `Copy`). It's more
6743+
// expensive if there are errors.
6744+
instance.union_type(db).clone()
67526745
}
67536746
KnownInstanceType::Literal(ty) => Ok(ty.inner(db)),
67546747
KnownInstanceType::Annotated(ty) => Ok(ty.inner(db)),
@@ -8004,9 +7997,9 @@ pub enum KnownInstanceType<'db> {
80047997
/// `ty_extensions.Specialization`.
80057998
Specialization(Specialization<'db>),
80067999

8007-
/// A single instance of `types.UnionType`, which stores the left- and
8008-
/// right-hand sides of a PEP 604 union.
8009-
UnionType(InternedTypes<'db>),
8000+
/// A single instance of `types.UnionType`, which stores the elements of
8001+
/// a PEP 604 union, or a `typing.Union`.
8002+
UnionType(UnionTypeInstance<'db>),
80108003

80118004
/// A single instance of `typing.Literal`
80128005
Literal(InternedType<'db>),
@@ -8052,9 +8045,9 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
80528045
visitor.visit_type(db, default_ty);
80538046
}
80548047
}
8055-
KnownInstanceType::UnionType(list) => {
8056-
for element in list.elements(db) {
8057-
visitor.visit_type(db, *element);
8048+
KnownInstanceType::UnionType(instance) => {
8049+
if let Ok(union_type) = instance.union_type(db) {
8050+
visitor.visit_type(db, *union_type);
80588051
}
80598052
}
80608053
KnownInstanceType::Literal(ty)
@@ -8098,7 +8091,7 @@ impl<'db> KnownInstanceType<'db> {
80988091
Self::TypeAliasType(type_alias.normalized_impl(db, visitor))
80998092
}
81008093
Self::Field(field) => Self::Field(field.normalized_impl(db, visitor)),
8101-
Self::UnionType(list) => Self::UnionType(list.normalized_impl(db, visitor)),
8094+
Self::UnionType(instance) => Self::UnionType(instance.normalized_impl(db, visitor)),
81028095
Self::Literal(ty) => Self::Literal(ty.normalized_impl(db, visitor)),
81038096
Self::Annotated(ty) => Self::Annotated(ty.normalized_impl(db, visitor)),
81048097
Self::TypeGenericAlias(ty) => Self::TypeGenericAlias(ty.normalized_impl(db, visitor)),
@@ -8430,7 +8423,7 @@ impl<'db> TypeAndQualifiers<'db> {
84308423
/// Error struct providing information on type(s) that were deemed to be invalid
84318424
/// in a type expression context, and the type we should therefore fallback to
84328425
/// for the problematic type expression.
8433-
#[derive(Debug, PartialEq, Eq)]
8426+
#[derive(Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
84348427
pub struct InvalidTypeExpressionError<'db> {
84358428
fallback_type: Type<'db>,
84368429
invalid_expressions: smallvec::SmallVec<[InvalidTypeExpression<'db>; 1]>,
@@ -8461,7 +8454,7 @@ impl<'db> InvalidTypeExpressionError<'db> {
84618454
}
84628455

84638456
/// Enumeration of various types that are invalid in type-expression contexts
8464-
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
8457+
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, get_size2::GetSize)]
84658458
enum InvalidTypeExpression<'db> {
84668459
/// Some types always require exactly one argument when used in a type expression
84678460
RequiresOneArgument(Type<'db>),
@@ -9399,39 +9392,106 @@ impl InferredAs {
93999392
}
94009393
}
94019394

9402-
/// A salsa-interned list of types.
9395+
/// Contains information about a `types.UnionType` instance built from a PEP 604
9396+
/// union or a legacy `typing.Union[…]` annotation in a value expression context,
9397+
/// e.g. `IntOrStr = int | str` or `IntOrStr = Union[int, str]`.
94039398
///
94049399
/// # Ordering
94059400
/// Ordering is based on the context's salsa-assigned id and not on its values.
94069401
/// The id may change between runs, or when the context was garbage collected and recreated.
94079402
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
94089403
#[derive(PartialOrd, Ord)]
9409-
pub struct InternedTypes<'db> {
9410-
#[returns(deref)]
9411-
elements: Box<[Type<'db>]>,
9412-
inferred_as: InferredAs,
9404+
pub struct UnionTypeInstance<'db> {
9405+
/// The types of the elements of this union, as they were inferred in a value
9406+
/// expression context. For `int | str`, this would contain `<class 'int'>` and
9407+
/// `<class 'str'>`. For `Union[int, str]`, this field is `None`, as we infer
9408+
/// the elements as type expressions. Use `value_expression_types` to get the
9409+
/// corresponding value expression types.
9410+
#[expect(clippy::ref_option)]
9411+
#[returns(ref)]
9412+
_value_expr_types: Option<Box<[Type<'db>]>>,
9413+
9414+
/// The type of the full union, which can be used when this `UnionType` instance
9415+
/// is used in a type expression context. For `int | str`, this would contain
9416+
/// `Ok(int | str)`. If any of the element types could not be converted, this
9417+
/// contains the first encountered error.
9418+
#[returns(ref)]
9419+
union_type: Result<Type<'db>, InvalidTypeExpressionError<'db>>,
94139420
}
94149421

9415-
impl get_size2::GetSize for InternedTypes<'_> {}
9422+
impl get_size2::GetSize for UnionTypeInstance<'_> {}
9423+
9424+
impl<'db> UnionTypeInstance<'db> {
9425+
pub(crate) fn from_value_expression_types(
9426+
db: &'db dyn Db,
9427+
value_expr_types: impl IntoIterator<Item = Type<'db>>,
9428+
scope_id: ScopeId<'db>,
9429+
typevar_binding_context: Option<Definition<'db>>,
9430+
) -> Type<'db> {
9431+
let value_expr_types = value_expr_types.into_iter().collect::<Box<_>>();
9432+
9433+
let mut builder = UnionBuilder::new(db);
9434+
for ty in &value_expr_types {
9435+
match ty.in_type_expression(db, scope_id, typevar_binding_context) {
9436+
Ok(ty) => builder.add_in_place(ty),
9437+
Err(error) => {
9438+
return Type::KnownInstance(KnownInstanceType::UnionType(
9439+
UnionTypeInstance::new(db, Some(value_expr_types), Err(error)),
9440+
));
9441+
}
9442+
}
9443+
}
9444+
9445+
Type::KnownInstance(KnownInstanceType::UnionType(UnionTypeInstance::new(
9446+
db,
9447+
Some(value_expr_types),
9448+
Ok(builder.build()),
9449+
)))
9450+
}
94169451

9417-
impl<'db> InternedTypes<'db> {
9418-
pub(crate) fn from_elements(
9452+
/// Get the types of the elements of this union as they would appear in a value
9453+
/// expression context. For a PEP 604 union, we return the actual types that were
9454+
/// inferred when we encountered the union in a value expression context. For a
9455+
/// legacy `typing.Union[…]` annotation, we turn the type-expression types into
9456+
/// their corresponding value-expression types, i.e. we turn instances like `int`
9457+
/// into class literals like `<class 'int'>`. This operation is potentially lossy.
9458+
pub(crate) fn value_expression_types(
9459+
self,
94199460
db: &'db dyn Db,
9420-
elements: impl IntoIterator<Item = Type<'db>>,
9421-
inferred_as: InferredAs,
9422-
) -> InternedTypes<'db> {
9423-
InternedTypes::new(db, elements.into_iter().collect::<Box<[_]>>(), inferred_as)
9461+
) -> Result<impl Iterator<Item = Type<'db>> + 'db, InvalidTypeExpressionError<'db>> {
9462+
let to_class_literal = |ty: Type<'db>| {
9463+
ty.as_nominal_instance()
9464+
.map(|instance| Type::ClassLiteral(instance.class(db).class_literal(db).0))
9465+
.unwrap_or_else(Type::unknown)
9466+
};
9467+
9468+
if let Some(value_expr_types) = self._value_expr_types(db) {
9469+
Ok(Either::Left(value_expr_types.iter().copied()))
9470+
} else {
9471+
match self.union_type(db).clone()? {
9472+
Type::Union(union) => Ok(Either::Right(Either::Left(
9473+
union.elements(db).iter().copied().map(to_class_literal),
9474+
))),
9475+
ty => Ok(Either::Right(Either::Right(std::iter::once(
9476+
to_class_literal(ty),
9477+
)))),
9478+
}
9479+
}
94249480
}
94259481

94269482
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
9427-
InternedTypes::new(
9428-
db,
9429-
self.elements(db)
9483+
let value_expr_types = self._value_expr_types(db).as_ref().map(|types| {
9484+
types
94309485
.iter()
94319486
.map(|ty| ty.normalized_impl(db, visitor))
9432-
.collect::<Box<[_]>>(),
9433-
self.inferred_as(db),
9434-
)
9487+
.collect::<Box<_>>()
9488+
});
9489+
let union_type = self
9490+
.union_type(db)
9491+
.clone()
9492+
.map(|ty| ty.normalized_impl(db, visitor));
9493+
9494+
Self::new(db, value_expr_types, union_type)
94359495
}
94369496
}
94379497

crates/ty_python_semantic/src/types/function.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,14 +1790,21 @@ impl KnownFunction {
17901790
// `Any` can be used in `issubclass()` calls but not `isinstance()` calls
17911791
Type::SpecialForm(SpecialFormType::Any)
17921792
if function == KnownFunction::IsSubclass => {}
1793-
Type::KnownInstance(KnownInstanceType::UnionType(union)) => {
1794-
for element in union.elements(db) {
1795-
find_invalid_elements(
1796-
db,
1797-
function,
1798-
*element,
1799-
invalid_elements,
1800-
);
1793+
Type::KnownInstance(KnownInstanceType::UnionType(instance)) => {
1794+
match instance.value_expression_types(db) {
1795+
Ok(value_expression_types) => {
1796+
for element in value_expression_types {
1797+
find_invalid_elements(
1798+
db,
1799+
function,
1800+
element,
1801+
invalid_elements,
1802+
);
1803+
}
1804+
}
1805+
Err(_) => {
1806+
invalid_elements.push(ty);
1807+
}
18011808
}
18021809
}
18031810
_ => invalid_elements.push(ty),

0 commit comments

Comments
 (0)