Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions butane_codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,13 @@ mod test_field_type;
/// * `#[pk]` on a field to specify that it is the primary key.
/// * `#[unique]` on a field indicates that the field's value must be unique
/// (perhaps implemented as the SQL UNIQUE constraint by some backends).
/// * `#[default]` should be used on fields added by later migrations to avoid errors on existing objects.
/// Unnecessary if the new field is an `Option<>`
/// * `#[default]` or `#[default = value]` should be used on fields added by later migrations to avoid errors on existing objects.
/// - `#[default = value]`: Explicitly specify a default value (e.g., `#[default = false]`, `#[default = "draft"]`)
/// - `#[serde(default)]`: For fields with primitive types or custom types backed by primitive SQL types,
/// Butane will automatically infer the appropriate default value for migrations. This works for:
/// bool, integers, floats, String, Vec<u8>, and custom types that wrap these primitives.
/// - For custom types not backed by primitives, use an explicit `#[default = value]` instead.
/// - Unnecessary if the new field is an `Option<>`
///
/// For example
/// ```ignore
Expand Down
148 changes: 138 additions & 10 deletions butane_core/src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use desynt::{create_static_resolver, PathResolver, StripRaw};
use phf::{phf_map, Map};
use proc_macro2::TokenStream as TokenStream2;
use proc_macro2::{Ident, Span, TokenTree};
use quote::{quote, ToTokens};
use quote::{quote, quote_spanned, ToTokens};
use regex::Regex;
use syn::parse_quote;
use syn::spanned::Spanned;
use syn::{
punctuated::Punctuated, Attribute, Field, ItemEnum, ItemStruct, ItemType, Lit, LitStr, Meta,
MetaNameValue,
Expand Down Expand Up @@ -449,26 +450,153 @@ pub fn get_deferred_sql_type(path: &syn::Path) -> DeferredSqlType {
})
}

/// Check if a field has `#[serde(default)]` attribute.
fn has_serde_default_attribute(field: &Field) -> bool {
for attr in &field.attrs {
if attr.path().is_ident("serde") {
// Try to parse as a meta list (for attributes like #[serde(default)])
if let Ok(meta) = attr.meta.require_list() {
for nested in meta.tokens.clone() {
// Check if any token is the identifier "default"
if let TokenTree::Ident(ident) = nested {
if ident == "default" {
return true;
}
}
}
}
}
}
false
}

/// Infer the default SqlVal for primitive types and simple custom types.
/// For custom types, attempts to look up their underlying SQL type and use that type's default.
/// Returns None for types where we cannot infer the default.
fn infer_primitive_default(path: &syn::Path) -> Option<SqlVal> {
let path = path.strip_raw();

// Check built-in primitives first
if path == parse_quote!(bool) {
Some(SqlVal::Bool(false))
} else if path == parse_quote!(i8)
|| path == parse_quote!(u8)
|| path == parse_quote!(i16)
|| path == parse_quote!(u16)
|| path == parse_quote!(i32)
{
Some(SqlVal::Int(0))
} else if path == parse_quote!(u32) || path == parse_quote!(i64) || path == parse_quote!(u64) {
Some(SqlVal::Int(0))
} else if path == parse_quote!(f32) || path == parse_quote!(f64) {
Some(SqlVal::Real(0.0))
} else if path == parse_quote!(String)
|| path == parse_quote!(std::string::String)
|| path == parse_quote!(::std::string::String)
{
Some(SqlVal::Text(String::new()))
} else if path == parse_quote!(Vec<u8>)
|| path == parse_quote!(std::vec::Vec<u8>)
|| path == parse_quote!(::std::vec::Vec<u8>)
{
Some(SqlVal::Blob(Vec::new()))
} else {
// Check if it's a custom type by looking up its SQL type
infer_default_from_custom_type(&path)
}
}

/// Try to infer default for a custom type by looking up its underlying SQL type.
fn infer_default_from_custom_type(path: &syn::Path) -> Option<SqlVal> {
let deferred_sql_type = get_deferred_sql_type(path);

match deferred_sql_type {
DeferredSqlType::KnownId(TypeIdentifier::Ty(sqltype)) => {
// We know the SQL type, infer default based on it
infer_default_from_sqltype(sqltype)
}
DeferredSqlType::KnownId(TypeIdentifier::Name(_)) => {
// Named type (e.g., Postgres custom type), can't infer
None
}
DeferredSqlType::Known(sqltype) => infer_default_from_sqltype(sqltype),
DeferredSqlType::Deferred(_) => {
// Type not yet resolved, can't infer
None
}
}
}

/// Infer default SqlVal for a given SqlType.
fn infer_default_from_sqltype(sqltype: SqlType) -> Option<SqlVal> {
match sqltype {
SqlType::Bool => Some(SqlVal::Bool(false)),
SqlType::Int => Some(SqlVal::Int(0)),
SqlType::BigInt => Some(SqlVal::Int(0)),
SqlType::Real => Some(SqlVal::Real(0.0)),
SqlType::Text => Some(SqlVal::Text(String::new())),
SqlType::Blob => Some(SqlVal::Blob(Vec::new())),
#[cfg(feature = "json")]
SqlType::Json => Some(SqlVal::Json(serde_json::Value::Null)),
#[cfg(feature = "datetime")]
SqlType::Timestamp => Some(SqlVal::Timestamp(
chrono::DateTime::from_timestamp(0, 0).unwrap().naive_utc(),
)),
#[cfg(feature = "datetime")]
SqlType::Date => Some(SqlVal::Date(
chrono::naive::NaiveDate::from_ymd_opt(1, 1, 1).unwrap(),
)),
#[cfg(feature = "pg")]
SqlType::Custom(_) => None, // Can't infer custom Postgres types
}
}

/// Defaults are used for fields added by later migrations.
///
/// Example:
/// `#[default = 42]`
/// Examples:
/// * `#[default = 42]` - Explicit default value
/// * `#[serde(default)]` - Inferred from type (primitives only)
fn get_default(field: &Field) -> std::result::Result<Option<SqlVal>, CompilerErrorMsg> {
let attr: Option<&Attribute> = field
// First, check for explicit #[default = value] attribute (highest priority)
let explicit_attr: Option<&Attribute> = field
.attrs
.iter()
.find(|attr| attr.path().is_ident("default"));
let lit: Lit = match attr {
None => return Ok(None),
Some(attr) => match &attr.meta {

if let Some(attr) = explicit_attr {
let lit: Lit = match &attr.meta {
Meta::NameValue(MetaNameValue {
value: syn::Expr::Lit(expr_lit),
..
}) => expr_lit.lit.clone(),
_ => return Err(make_compile_error!("malformed default value").into()),
},
};
Ok(Some(sqlval_from_lit(lit)?))
};
return Ok(Some(sqlval_from_lit(lit)?));
}

// Second, check for #[serde(default)] and try to infer the default
if has_serde_default_attribute(field) {
let field_type = extract_path_from_type(&field.ty);
if let Some(default_val) = infer_primitive_default(field_type) {
return Ok(Some(default_val));
} else {
let field_name = field
.ident
.as_ref()
.map(|i| i.to_string())
.unwrap_or_else(|| "unnamed".to_string());
return Err(make_compile_error!(
field.span() =>
"Field '{}' has #[serde(default)] but Butane cannot infer the default value for this type. \
Please add an explicit #[default = value] attribute for migration support.",
field_name
)
.into());
}
}

// No default specified
Ok(None)
}

fn some_id(ty: SqlType) -> Option<TypeIdentifier> {
Expand Down
125 changes: 125 additions & 0 deletions butane_core/tests/migration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,131 @@ fn current_migration_default_attribute() {
assert_eq!(*barcol.default(), Some(SqlVal::Text("turtle".to_string())));
}

#[test]
fn current_migration_serde_default_bool() {
let tokens = quote! {
#[derive(PartialEq, Eq, Debug, Clone)]
struct Foo {
id: i64,
#[serde(default)]
published: bool,
}
};

let mut ms = MemMigrations::new();
model_with_migrations(tokens, &mut ms);
let m = ms.current();
let db = m.db().unwrap();
let table = db.get_table("Foo").expect("No Foo table");
let col = table.column("published").unwrap();
assert_eq!(*col.default(), Some(SqlVal::Bool(false)));
}

#[test]
fn current_migration_serde_default_int() {
let tokens = quote! {
#[derive(PartialEq, Eq, Debug, Clone)]
struct Foo {
id: i64,
#[serde(default)]
count: i32,
}
};

let mut ms = MemMigrations::new();
model_with_migrations(tokens, &mut ms);
let m = ms.current();
let db = m.db().unwrap();
let table = db.get_table("Foo").expect("No Foo table");
let col = table.column("count").unwrap();
assert_eq!(*col.default(), Some(SqlVal::Int(0)));
}

#[test]
fn current_migration_serde_default_string() {
let tokens = quote! {
#[derive(PartialEq, Eq, Debug, Clone)]
struct Foo {
id: i64,
#[serde(default)]
name: String,
}
};

let mut ms = MemMigrations::new();
model_with_migrations(tokens, &mut ms);
let m = ms.current();
let db = m.db().unwrap();
let table = db.get_table("Foo").expect("No Foo table");
let col = table.column("name").unwrap();
assert_eq!(*col.default(), Some(SqlVal::Text(String::new())));
}

#[test]
fn current_migration_explicit_default_overrides_serde() {
// Explicit #[default = value] should take precedence over #[serde(default)]
let tokens = quote! {
#[derive(PartialEq, Eq, Debug, Clone)]
struct Foo {
id: i64,
#[serde(default)]
#[default = true]
published: bool,
}
};

let mut ms = MemMigrations::new();
model_with_migrations(tokens, &mut ms);
let m = ms.current();
let db = m.db().unwrap();
let table = db.get_table("Foo").expect("No Foo table");
let col = table.column("published").unwrap();
// Should use explicit value (true), not inferred (false)
assert_eq!(*col.default(), Some(SqlVal::Bool(true)));
}

#[test]
fn current_migration_serde_default_json_custom_type() {
// Test that custom JSON types can have their defaults inferred
#[cfg(feature = "json")]
{
let tokens = quote! {
#[derive(PartialEq, Eq, Debug, Clone, serde::Serialize, serde::Deserialize)]
struct CustomData {
value: String,
}

#[derive(PartialEq, Eq, Debug, Clone)]
struct Foo {
id: i64,
#[serde(default)]
data: CustomData,
}
};

let mut ms = MemMigrations::new();
// Register CustomData as a JSON type
use butane_core::migrations::adb::{DeferredSqlType, TypeIdentifier, TypeKey};
use butane_core::SqlType;
ms.current()
.add_type(
TypeKey::CustomType("CustomData".to_string()),
DeferredSqlType::KnownId(TypeIdentifier::Ty(SqlType::Json)),
)
.unwrap();

// Now create the model
model_with_migrations(tokens, &mut ms);

let m = ms.current();
let db = m.db().unwrap();
let table = db.get_table("Foo").expect("No Foo table");
let col = table.column("data").unwrap();
// Should infer default from the underlying Json type
assert_eq!(*col.default(), Some(SqlVal::Json(serde_json::Value::Null)));
}
}

#[test]
fn current_migration_auto_attribute() {
let tokens = quote! {
Expand Down
Loading