@@ -6738,17 +6738,10 @@ impl<'db> Type<'db> {
67386738 invalid_expressions: smallvec::smallvec_inline![InvalidTypeExpression::Generic],
67396739 fallback_type: Type::unknown(),
67406740 }),
6741- KnownInstanceType::UnionType(list) => {
6742- let mut builder = UnionBuilder::new(db);
6743- let inferred_as = list.inferred_as(db);
6744- for element in list.elements(db) {
6745- builder = builder.add(if inferred_as.type_expression() {
6746- *element
6747- } else {
6748- element.in_type_expression(db, scope_id, typevar_binding_context)?
6749- });
6750- }
6751- Ok(builder.build())
6741+ KnownInstanceType::UnionType(instance) => {
6742+ // Cloning here is cheap if the result is a `Type` (which is `Copy`). It's more
6743+ // expensive if there are errors.
6744+ instance.union_type(db).clone()
67526745 }
67536746 KnownInstanceType::Literal(ty) => Ok(ty.inner(db)),
67546747 KnownInstanceType::Annotated(ty) => Ok(ty.inner(db)),
@@ -8004,9 +7997,9 @@ pub enum KnownInstanceType<'db> {
80047997 /// `ty_extensions.Specialization`.
80057998 Specialization(Specialization<'db>),
80067999
8007- /// A single instance of `types.UnionType`, which stores the left- and
8008- /// right-hand sides of a PEP 604 union.
8009- UnionType(InternedTypes <'db>),
8000+ /// A single instance of `types.UnionType`, which stores the elements of
8001+ /// a PEP 604 union, or a `typing.Union` .
8002+ UnionType(UnionTypeInstance <'db>),
80108003
80118004 /// A single instance of `typing.Literal`
80128005 Literal(InternedType<'db>),
@@ -8052,9 +8045,9 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
80528045 visitor.visit_type(db, default_ty);
80538046 }
80548047 }
8055- KnownInstanceType::UnionType(list ) => {
8056- for element in list.elements (db) {
8057- visitor.visit_type(db, *element );
8048+ KnownInstanceType::UnionType(instance ) => {
8049+ if let Ok(union_type) = instance.union_type (db) {
8050+ visitor.visit_type(db, *union_type );
80588051 }
80598052 }
80608053 KnownInstanceType::Literal(ty)
@@ -8098,7 +8091,7 @@ impl<'db> KnownInstanceType<'db> {
80988091 Self::TypeAliasType(type_alias.normalized_impl(db, visitor))
80998092 }
81008093 Self::Field(field) => Self::Field(field.normalized_impl(db, visitor)),
8101- Self::UnionType(list ) => Self::UnionType(list .normalized_impl(db, visitor)),
8094+ Self::UnionType(instance ) => Self::UnionType(instance .normalized_impl(db, visitor)),
81028095 Self::Literal(ty) => Self::Literal(ty.normalized_impl(db, visitor)),
81038096 Self::Annotated(ty) => Self::Annotated(ty.normalized_impl(db, visitor)),
81048097 Self::TypeGenericAlias(ty) => Self::TypeGenericAlias(ty.normalized_impl(db, visitor)),
@@ -8430,7 +8423,7 @@ impl<'db> TypeAndQualifiers<'db> {
84308423/// Error struct providing information on type(s) that were deemed to be invalid
84318424/// in a type expression context, and the type we should therefore fallback to
84328425/// for the problematic type expression.
8433- #[derive(Debug, PartialEq, Eq)]
8426+ #[derive(Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize )]
84348427pub struct InvalidTypeExpressionError<'db> {
84358428 fallback_type: Type<'db>,
84368429 invalid_expressions: smallvec::SmallVec<[InvalidTypeExpression<'db>; 1]>,
@@ -8461,7 +8454,7 @@ impl<'db> InvalidTypeExpressionError<'db> {
84618454}
84628455
84638456/// Enumeration of various types that are invalid in type-expression contexts
8464- #[derive(Debug, Copy, Clone, PartialEq, Eq)]
8457+ #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, get_size2::GetSize )]
84658458enum InvalidTypeExpression<'db> {
84668459 /// Some types always require exactly one argument when used in a type expression
84678460 RequiresOneArgument(Type<'db>),
@@ -9399,39 +9392,106 @@ impl InferredAs {
93999392 }
94009393}
94019394
9402- /// A salsa-interned list of types.
9395+ /// Contains information about a `types.UnionType` instance built from a PEP 604
9396+ /// union or a legacy `typing.Union[…]` annotation in a value expression context,
9397+ /// e.g. `IntOrStr = int | str` or `IntOrStr = Union[int, str]`.
94039398///
94049399/// # Ordering
94059400/// Ordering is based on the context's salsa-assigned id and not on its values.
94069401/// The id may change between runs, or when the context was garbage collected and recreated.
94079402#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
94089403#[derive(PartialOrd, Ord)]
9409- pub struct InternedTypes<'db> {
9410- #[returns(deref)]
9411- elements: Box<[Type<'db>]>,
9412- inferred_as: InferredAs,
9404+ pub struct UnionTypeInstance<'db> {
9405+ /// The types of the elements of this union, as they were inferred in a value
9406+ /// expression context. For `int | str`, this would contain `<class 'int'>` and
9407+ /// `<class 'str'>`. For `Union[int, str]`, this field is `None`, as we infer
9408+ /// the elements as type expressions. Use `value_expression_types` to get the
9409+ /// corresponding value expression types.
9410+ #[expect(clippy::ref_option)]
9411+ #[returns(ref)]
9412+ _value_expr_types: Option<Box<[Type<'db>]>>,
9413+
9414+ /// The type of the full union, which can be used when this `UnionType` instance
9415+ /// is used in a type expression context. For `int | str`, this would contain
9416+ /// `Ok(int | str)`. If any of the element types could not be converted, this
9417+ /// contains the first encountered error.
9418+ #[returns(ref)]
9419+ union_type: Result<Type<'db>, InvalidTypeExpressionError<'db>>,
94139420}
94149421
9415- impl get_size2::GetSize for InternedTypes<'_> {}
9422+ impl get_size2::GetSize for UnionTypeInstance<'_> {}
9423+
9424+ impl<'db> UnionTypeInstance<'db> {
9425+ pub(crate) fn from_value_expression_types(
9426+ db: &'db dyn Db,
9427+ value_expr_types: impl IntoIterator<Item = Type<'db>>,
9428+ scope_id: ScopeId<'db>,
9429+ typevar_binding_context: Option<Definition<'db>>,
9430+ ) -> Type<'db> {
9431+ let value_expr_types = value_expr_types.into_iter().collect::<Box<_>>();
9432+
9433+ let mut builder = UnionBuilder::new(db);
9434+ for ty in &value_expr_types {
9435+ match ty.in_type_expression(db, scope_id, typevar_binding_context) {
9436+ Ok(ty) => builder.add_in_place(ty),
9437+ Err(error) => {
9438+ return Type::KnownInstance(KnownInstanceType::UnionType(
9439+ UnionTypeInstance::new(db, Some(value_expr_types), Err(error)),
9440+ ));
9441+ }
9442+ }
9443+ }
9444+
9445+ Type::KnownInstance(KnownInstanceType::UnionType(UnionTypeInstance::new(
9446+ db,
9447+ Some(value_expr_types),
9448+ Ok(builder.build()),
9449+ )))
9450+ }
94169451
9417- impl<'db> InternedTypes<'db> {
9418- pub(crate) fn from_elements(
9452+ /// Get the types of the elements of this union as they would appear in a value
9453+ /// expression context. For a PEP 604 union, we return the actual types that were
9454+ /// inferred when we encountered the union in a value expression context. For a
9455+ /// legacy `typing.Union[…]` annotation, we turn the type-expression types into
9456+ /// their corresponding value-expression types, i.e. we turn instances like `int`
9457+ /// into class literals like `<class 'int'>`. This operation is potentially lossy.
9458+ pub(crate) fn value_expression_types(
9459+ self,
94199460 db: &'db dyn Db,
9420- elements: impl IntoIterator<Item = Type<'db>>,
9421- inferred_as: InferredAs,
9422- ) -> InternedTypes<'db> {
9423- InternedTypes::new(db, elements.into_iter().collect::<Box<[_]>>(), inferred_as)
9461+ ) -> Result<impl Iterator<Item = Type<'db>> + 'db, InvalidTypeExpressionError<'db>> {
9462+ let to_class_literal = |ty: Type<'db>| {
9463+ ty.as_nominal_instance()
9464+ .map(|instance| Type::ClassLiteral(instance.class(db).class_literal(db).0))
9465+ .unwrap_or_else(Type::unknown)
9466+ };
9467+
9468+ if let Some(value_expr_types) = self._value_expr_types(db) {
9469+ Ok(Either::Left(value_expr_types.iter().copied()))
9470+ } else {
9471+ match self.union_type(db).clone()? {
9472+ Type::Union(union) => Ok(Either::Right(Either::Left(
9473+ union.elements(db).iter().copied().map(to_class_literal),
9474+ ))),
9475+ ty => Ok(Either::Right(Either::Right(std::iter::once(
9476+ to_class_literal(ty),
9477+ )))),
9478+ }
9479+ }
94249480 }
94259481
94269482 pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
9427- InternedTypes::new(
9428- db,
9429- self.elements(db)
9483+ let value_expr_types = self._value_expr_types(db).as_ref().map(|types| {
9484+ types
94309485 .iter()
94319486 .map(|ty| ty.normalized_impl(db, visitor))
9432- .collect::<Box<[_]>>(),
9433- self.inferred_as(db),
9434- )
9487+ .collect::<Box<_>>()
9488+ });
9489+ let union_type = self
9490+ .union_type(db)
9491+ .clone()
9492+ .map(|ty| ty.normalized_impl(db, visitor));
9493+
9494+ Self::new(db, value_expr_types, union_type)
94359495 }
94369496}
94379497
0 commit comments