diff --git a/butane_codegen/src/lib.rs b/butane_codegen/src/lib.rs index 525e1a93..3e81e655 100644 --- a/butane_codegen/src/lib.rs +++ b/butane_codegen/src/lib.rs @@ -34,8 +34,13 @@ mod test_field_type; /// * `#[pk]` on a field to specify that it is the primary key. /// * `#[unique]` on a field indicates that the field's value must be unique /// (perhaps implemented as the SQL UNIQUE constraint by some backends). -/// * `#[default]` should be used on fields added by later migrations to avoid errors on existing objects. -/// Unnecessary if the new field is an `Option<>` +/// * `#[default]` or `#[default = value]` should be used on fields added by later migrations to avoid errors on existing objects. +/// - `#[default = value]`: Explicitly specify a default value (e.g., `#[default = false]`, `#[default = "draft"]`) +/// - `#[serde(default)]`: For fields with primitive types or custom types backed by primitive SQL types, +/// Butane will automatically infer the appropriate default value for migrations. This works for: +/// bool, integers, floats, String, Vec, and custom types that wrap these primitives. +/// - For custom types not backed by primitives, use an explicit `#[default = value]` instead. +/// - Unnecessary if the new field is an `Option<>` /// /// For example /// ```ignore diff --git a/butane_core/src/codegen/mod.rs b/butane_core/src/codegen/mod.rs index e5f76b41..976af4b7 100644 --- a/butane_core/src/codegen/mod.rs +++ b/butane_core/src/codegen/mod.rs @@ -9,9 +9,10 @@ use desynt::{create_static_resolver, PathResolver, StripRaw}; use phf::{phf_map, Map}; use proc_macro2::TokenStream as TokenStream2; use proc_macro2::{Ident, Span, TokenTree}; -use quote::{quote, ToTokens}; +use quote::{quote, quote_spanned, ToTokens}; use regex::Regex; use syn::parse_quote; +use syn::spanned::Spanned; use syn::{ punctuated::Punctuated, Attribute, Field, ItemEnum, ItemStruct, ItemType, Lit, LitStr, Meta, MetaNameValue, @@ -449,26 +450,153 @@ pub fn get_deferred_sql_type(path: &syn::Path) -> DeferredSqlType { }) } +/// Check if a field has `#[serde(default)]` attribute. +fn has_serde_default_attribute(field: &Field) -> bool { + for attr in &field.attrs { + if attr.path().is_ident("serde") { + // Try to parse as a meta list (for attributes like #[serde(default)]) + if let Ok(meta) = attr.meta.require_list() { + for nested in meta.tokens.clone() { + // Check if any token is the identifier "default" + if let TokenTree::Ident(ident) = nested { + if ident == "default" { + return true; + } + } + } + } + } + } + false +} + +/// Infer the default SqlVal for primitive types and simple custom types. +/// For custom types, attempts to look up their underlying SQL type and use that type's default. +/// Returns None for types where we cannot infer the default. +fn infer_primitive_default(path: &syn::Path) -> Option { + let path = path.strip_raw(); + + // Check built-in primitives first + if path == parse_quote!(bool) { + Some(SqlVal::Bool(false)) + } else if path == parse_quote!(i8) + || path == parse_quote!(u8) + || path == parse_quote!(i16) + || path == parse_quote!(u16) + || path == parse_quote!(i32) + { + Some(SqlVal::Int(0)) + } else if path == parse_quote!(u32) || path == parse_quote!(i64) || path == parse_quote!(u64) { + Some(SqlVal::Int(0)) + } else if path == parse_quote!(f32) || path == parse_quote!(f64) { + Some(SqlVal::Real(0.0)) + } else if path == parse_quote!(String) + || path == parse_quote!(std::string::String) + || path == parse_quote!(::std::string::String) + { + Some(SqlVal::Text(String::new())) + } else if path == parse_quote!(Vec) + || path == parse_quote!(std::vec::Vec) + || path == parse_quote!(::std::vec::Vec) + { + Some(SqlVal::Blob(Vec::new())) + } else { + // Check if it's a custom type by looking up its SQL type + infer_default_from_custom_type(&path) + } +} + +/// Try to infer default for a custom type by looking up its underlying SQL type. +fn infer_default_from_custom_type(path: &syn::Path) -> Option { + let deferred_sql_type = get_deferred_sql_type(path); + + match deferred_sql_type { + DeferredSqlType::KnownId(TypeIdentifier::Ty(sqltype)) => { + // We know the SQL type, infer default based on it + infer_default_from_sqltype(sqltype) + } + DeferredSqlType::KnownId(TypeIdentifier::Name(_)) => { + // Named type (e.g., Postgres custom type), can't infer + None + } + DeferredSqlType::Known(sqltype) => infer_default_from_sqltype(sqltype), + DeferredSqlType::Deferred(_) => { + // Type not yet resolved, can't infer + None + } + } +} + +/// Infer default SqlVal for a given SqlType. +fn infer_default_from_sqltype(sqltype: SqlType) -> Option { + match sqltype { + SqlType::Bool => Some(SqlVal::Bool(false)), + SqlType::Int => Some(SqlVal::Int(0)), + SqlType::BigInt => Some(SqlVal::Int(0)), + SqlType::Real => Some(SqlVal::Real(0.0)), + SqlType::Text => Some(SqlVal::Text(String::new())), + SqlType::Blob => Some(SqlVal::Blob(Vec::new())), + #[cfg(feature = "json")] + SqlType::Json => Some(SqlVal::Json(serde_json::Value::Null)), + #[cfg(feature = "datetime")] + SqlType::Timestamp => Some(SqlVal::Timestamp( + chrono::DateTime::from_timestamp(0, 0).unwrap().naive_utc(), + )), + #[cfg(feature = "datetime")] + SqlType::Date => Some(SqlVal::Date( + chrono::naive::NaiveDate::from_ymd_opt(1, 1, 1).unwrap(), + )), + #[cfg(feature = "pg")] + SqlType::Custom(_) => None, // Can't infer custom Postgres types + } +} + /// Defaults are used for fields added by later migrations. /// -/// Example: -/// `#[default = 42]` +/// Examples: +/// * `#[default = 42]` - Explicit default value +/// * `#[serde(default)]` - Inferred from type (primitives only) fn get_default(field: &Field) -> std::result::Result, CompilerErrorMsg> { - let attr: Option<&Attribute> = field + // First, check for explicit #[default = value] attribute (highest priority) + let explicit_attr: Option<&Attribute> = field .attrs .iter() .find(|attr| attr.path().is_ident("default")); - let lit: Lit = match attr { - None => return Ok(None), - Some(attr) => match &attr.meta { + + if let Some(attr) = explicit_attr { + let lit: Lit = match &attr.meta { Meta::NameValue(MetaNameValue { value: syn::Expr::Lit(expr_lit), .. }) => expr_lit.lit.clone(), _ => return Err(make_compile_error!("malformed default value").into()), - }, - }; - Ok(Some(sqlval_from_lit(lit)?)) + }; + return Ok(Some(sqlval_from_lit(lit)?)); + } + + // Second, check for #[serde(default)] and try to infer the default + if has_serde_default_attribute(field) { + let field_type = extract_path_from_type(&field.ty); + if let Some(default_val) = infer_primitive_default(field_type) { + return Ok(Some(default_val)); + } else { + let field_name = field + .ident + .as_ref() + .map(|i| i.to_string()) + .unwrap_or_else(|| "unnamed".to_string()); + return Err(make_compile_error!( + field.span() => + "Field '{}' has #[serde(default)] but Butane cannot infer the default value for this type. \ + Please add an explicit #[default = value] attribute for migration support.", + field_name + ) + .into()); + } + } + + // No default specified + Ok(None) } fn some_id(ty: SqlType) -> Option { diff --git a/butane_core/tests/migration.rs b/butane_core/tests/migration.rs index 04f6879d..8ab6d5fa 100644 --- a/butane_core/tests/migration.rs +++ b/butane_core/tests/migration.rs @@ -96,6 +96,131 @@ fn current_migration_default_attribute() { assert_eq!(*barcol.default(), Some(SqlVal::Text("turtle".to_string()))); } +#[test] +fn current_migration_serde_default_bool() { + let tokens = quote! { + #[derive(PartialEq, Eq, Debug, Clone)] + struct Foo { + id: i64, + #[serde(default)] + published: bool, + } + }; + + let mut ms = MemMigrations::new(); + model_with_migrations(tokens, &mut ms); + let m = ms.current(); + let db = m.db().unwrap(); + let table = db.get_table("Foo").expect("No Foo table"); + let col = table.column("published").unwrap(); + assert_eq!(*col.default(), Some(SqlVal::Bool(false))); +} + +#[test] +fn current_migration_serde_default_int() { + let tokens = quote! { + #[derive(PartialEq, Eq, Debug, Clone)] + struct Foo { + id: i64, + #[serde(default)] + count: i32, + } + }; + + let mut ms = MemMigrations::new(); + model_with_migrations(tokens, &mut ms); + let m = ms.current(); + let db = m.db().unwrap(); + let table = db.get_table("Foo").expect("No Foo table"); + let col = table.column("count").unwrap(); + assert_eq!(*col.default(), Some(SqlVal::Int(0))); +} + +#[test] +fn current_migration_serde_default_string() { + let tokens = quote! { + #[derive(PartialEq, Eq, Debug, Clone)] + struct Foo { + id: i64, + #[serde(default)] + name: String, + } + }; + + let mut ms = MemMigrations::new(); + model_with_migrations(tokens, &mut ms); + let m = ms.current(); + let db = m.db().unwrap(); + let table = db.get_table("Foo").expect("No Foo table"); + let col = table.column("name").unwrap(); + assert_eq!(*col.default(), Some(SqlVal::Text(String::new()))); +} + +#[test] +fn current_migration_explicit_default_overrides_serde() { + // Explicit #[default = value] should take precedence over #[serde(default)] + let tokens = quote! { + #[derive(PartialEq, Eq, Debug, Clone)] + struct Foo { + id: i64, + #[serde(default)] + #[default = true] + published: bool, + } + }; + + let mut ms = MemMigrations::new(); + model_with_migrations(tokens, &mut ms); + let m = ms.current(); + let db = m.db().unwrap(); + let table = db.get_table("Foo").expect("No Foo table"); + let col = table.column("published").unwrap(); + // Should use explicit value (true), not inferred (false) + assert_eq!(*col.default(), Some(SqlVal::Bool(true))); +} + +#[test] +fn current_migration_serde_default_json_custom_type() { + // Test that custom JSON types can have their defaults inferred + #[cfg(feature = "json")] + { + let tokens = quote! { + #[derive(PartialEq, Eq, Debug, Clone, serde::Serialize, serde::Deserialize)] + struct CustomData { + value: String, + } + + #[derive(PartialEq, Eq, Debug, Clone)] + struct Foo { + id: i64, + #[serde(default)] + data: CustomData, + } + }; + + let mut ms = MemMigrations::new(); + // Register CustomData as a JSON type + use butane_core::migrations::adb::{DeferredSqlType, TypeIdentifier, TypeKey}; + use butane_core::SqlType; + ms.current() + .add_type( + TypeKey::CustomType("CustomData".to_string()), + DeferredSqlType::KnownId(TypeIdentifier::Ty(SqlType::Json)), + ) + .unwrap(); + + // Now create the model + model_with_migrations(tokens, &mut ms); + + let m = ms.current(); + let db = m.db().unwrap(); + let table = db.get_table("Foo").expect("No Foo table"); + let col = table.column("data").unwrap(); + // Should infer default from the underlying Json type + assert_eq!(*col.default(), Some(SqlVal::Json(serde_json::Value::Null))); + } +} + #[test] fn current_migration_auto_attribute() { let tokens = quote! {