diff --git a/cli/src/main.rs b/cli/src/main.rs index 6a34b98c..c3eaef28 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -257,12 +257,13 @@ fn main() { .filter_map(|dir_entry| dir_entry.path().to_str().map(String::from)) .collect(); + let supports_flatten = lang.supports_flatten(); let mut generated_contents = vec![]; let parsed_data = glob_paths .par_iter() .map(|filepath| { let data = std::fs::read_to_string(filepath).unwrap(); - let parsed_data = typeshare_core::parser::parse(&data); + let parsed_data = typeshare_core::parser::parse(&data, supports_flatten); if parsed_data.is_err() { panic!("{}", parsed_data.err().unwrap()); } diff --git a/core/src/language/go.rs b/core/src/language/go.rs index 9cd3a9d2..305c6414 100644 --- a/core/src/language/go.rs +++ b/core/src/language/go.rs @@ -154,6 +154,10 @@ impl Language for Go { writeln!(w, "}}") } + + fn supports_flatten(&self) -> bool { + false + } } impl Go { diff --git a/core/src/language/kotlin.rs b/core/src/language/kotlin.rs index fe9703d6..6d7c01f5 100644 --- a/core/src/language/kotlin.rs +++ b/core/src/language/kotlin.rs @@ -181,6 +181,10 @@ impl Language for Kotlin { writeln!(w, "}}\n") } + + fn supports_flatten(&self) -> bool { + false + } } impl Kotlin { diff --git a/core/src/language/mod.rs b/core/src/language/mod.rs index ebf5361a..34514f00 100644 --- a/core/src/language/mod.rs +++ b/core/src/language/mod.rs @@ -294,4 +294,6 @@ pub trait Language { Ok(()) } + /// whether `#[serde(flatten)]` macro attribute is supported or not + fn supports_flatten(&self) -> bool; } diff --git a/core/src/language/scala.rs b/core/src/language/scala.rs index be7ce99d..fb018901 100644 --- a/core/src/language/scala.rs +++ b/core/src/language/scala.rs @@ -209,6 +209,10 @@ impl Language for Scala { self.write_enum_variants(w, e)?; writeln!(w, "}}\n") } + + fn supports_flatten(&self) -> bool { + false + } } impl Scala { diff --git a/core/src/language/swift.rs b/core/src/language/swift.rs index 72c96d47..3dc9efe9 100644 --- a/core/src/language/swift.rs +++ b/core/src/language/swift.rs @@ -506,6 +506,10 @@ impl Language for Swift { writeln!(w, "}}") } + + fn supports_flatten(&self) -> bool { + false + } } impl Swift { diff --git a/core/src/language/typescript.rs b/core/src/language/typescript.rs index 32bf7ae6..5a8f47dc 100644 --- a/core/src/language/typescript.rs +++ b/core/src/language/typescript.rs @@ -110,13 +110,32 @@ impl Language for TypeScript { fn write_struct(&mut self, w: &mut dyn Write, rs: &RustStruct) -> std::io::Result<()> { self.write_comments(w, 0, &rs.comments)?; + let mut inheritance = "".to_string(); + let mut count = 0; + for field in rs.fields.iter() { + if field.flattened { + let ts_ty = self + .format_type(&field.ty, &rs.generic_types) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?; + if count >= 1 { + inheritance.push_str(", "); + } + inheritance.push_str(ts_ty.as_str()); + count += 1; + } + } writeln!( w, - "export interface {}{} {{", + "export interface {}{}{} {{", rs.id.renamed, (!rs.generic_types.is_empty()) .then(|| format!("<{}>", rs.generic_types.join(", "))) - .unwrap_or_default() + .unwrap_or_default(), + if !inheritance.is_empty() { + format!(" extends {inheritance}") + } else { + "".to_string() + } )?; rs.fields @@ -160,6 +179,10 @@ impl Language for TypeScript { } } } + + fn supports_flatten(&self) -> bool { + true + } } impl TypeScript { @@ -230,6 +253,9 @@ impl TypeScript { field: &RustField, generic_types: &[String], ) -> std::io::Result<()> { + if field.flattened { + return Ok(()); + } self.write_comments(w, 1, &field.comments)?; let ts_ty = self .format_type(&field.ty, generic_types) diff --git a/core/src/lib.rs b/core/src/lib.rs index 65a8c8c3..4a57d661 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -30,7 +30,7 @@ pub fn process_input( language: &mut dyn Language, out: &mut dyn Write, ) -> Result<(), ProcessInputError> { - let parsed_data = parser::parse(input)?; + let parsed_data = parser::parse(input, language.supports_flatten())?; language.generate_types(out, &parsed_data)?; Ok(()) } diff --git a/core/src/parser.rs b/core/src/parser.rs index 36e1bc7f..09e283ad 100644 --- a/core/src/parser.rs +++ b/core/src/parser.rs @@ -78,7 +78,7 @@ pub enum ParseError { } /// Parse the given Rust source string into `ParsedData`. -pub fn parse(input: &str) -> Result { +pub fn parse(input: &str, supports_flatten: bool) -> Result { let mut parsed_data = ParsedData::default(); // We will only produce output for files that contain the `#[typeshare]` @@ -94,7 +94,7 @@ pub fn parse(input: &str) -> Result { for item in flatten_items(source.items.iter()) { match item { syn::Item::Struct(s) if has_typeshare_annotation(&s.attrs) => { - parsed_data.push_rust_thing(parse_struct(s)?); + parsed_data.push_rust_thing(parse_struct(s, supports_flatten)?); } syn::Item::Enum(e) if has_typeshare_annotation(&e.attrs) => { parsed_data.push_rust_thing(parse_enum(e)?); @@ -131,7 +131,7 @@ fn flatten_items<'a>( /// /// This function can currently return something other than a struct, which is a /// hack. -fn parse_struct(s: &ItemStruct) -> Result { +fn parse_struct(s: &ItemStruct, supports_flatten: bool) -> Result { let serde_rename_all = serde_rename_all(&s.attrs); let generic_types = s @@ -156,6 +156,7 @@ fn parse_struct(s: &ItemStruct) -> Result { })); } + let mut flattened: bool = false; Ok(match &s.fields { // Structs Fields::Named(f) => { @@ -170,9 +171,14 @@ fn parse_struct(s: &ItemStruct) -> Result { RustType::try_from(&f.ty)? }; - if serde_flatten(&f.attrs) { - return Err(ParseError::SerdeFlattenNotAllowed); - } + flattened = if serde_flatten(&f.attrs) { + if !supports_flatten { + return Err(ParseError::SerdeFlattenNotAllowed); + } + true + } else { + false + }; let has_default = serde_default(&f.attrs); let decorators = get_field_decorators(&f.attrs); @@ -183,6 +189,7 @@ fn parse_struct(s: &ItemStruct) -> Result { comments: parse_comment_attrs(&f.attrs), has_default, decorators, + flattened, }) }) .collect::>()?; @@ -380,6 +387,7 @@ fn parse_enum_variant( comments: parse_comment_attrs(&f.attrs), has_default, decorators, + flattened: false, }) }) .collect::, ParseError>>()?, diff --git a/core/src/rust_types.rs b/core/src/rust_types.rs index 49c55393..a80bf00e 100644 --- a/core/src/rust_types.rs +++ b/core/src/rust_types.rs @@ -77,6 +77,9 @@ pub struct RustField { /// Language-specific decorators assigned to a given field. /// The keys are language names (e.g. SupportedLanguage::TypeScript), the values are decorators (e.g. readonly) pub decorators: HashMap>, + /// Whether the field should be flattened or not, + /// as per `#[serde(flatten)]` definition. + pub flattened: bool, } /// A Rust type. diff --git a/core/tests/snapshot_tests.rs b/core/tests/snapshot_tests.rs index beb80d62..69e41f14 100644 --- a/core/tests/snapshot_tests.rs +++ b/core/tests/snapshot_tests.rs @@ -67,7 +67,7 @@ fn check( )?; let mut typeshare_output: Vec = Vec::new(); - let parsed_data = typeshare_core::parser::parse(&rust_input)?; + let parsed_data = typeshare_core::parser::parse(&rust_input, lang.supports_flatten())?; lang.generate_types(&mut typeshare_output, &parsed_data)?; let typeshare_output = String::from_utf8(typeshare_output)?;