Skip to content

Commit

Permalink
Extract recursive rendering into maybe_render_recursive_class function
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniosarosi committed Oct 29, 2024
1 parent ec75b0e commit e75f6d4
Showing 1 changed file with 102 additions and 114 deletions.
216 changes: 102 additions & 114 deletions engine/baml-lib/jinja-runtime/src/output_format/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,77 @@ impl OutputFormatContent {
.to_string(options)
}

/// Recursive classes are rendered using their name instead of schema.
///
/// The schema must be hoisted and named, otherwise there's no way to refer
/// to a recursive class.
///
/// This function returns [`Some`] if the given `field_type` is a recursive
/// class otherwise it returns [`None`] which means the given type is not
/// a recursive class so rendering must be handled normally.
fn maybe_render_recursive_class(
&self,
field_type: &FieldType,
render_state: &mut RenderState,
options: &RenderOptions,
) -> Option<String> {
// Hoist recursive classes.
//
// TODO: Some cloning in these functions again, check
// baml-lib/jsonish/src/tests/mod.rs
// there's room for optimization.
//
// TODO: Maybe we can put this somewhere else, it can probably run only
// once before the recursive rendering happens.
if render_state.hoisted_classes.len() < self.recursive_classes.len() {
for recursive_class in self.recursive_classes.iter() {
render_state.hoisted_classes.insert(recursive_class.clone());
}
}

let mut maybe_nested_recursive_class = None;
let mut is_optional = false;
let mut is_list = false;

match field_type {
// Non-optional class, part of a cycle.
FieldType::Class(nested_class) if self.recursive_classes.contains(nested_class) => {
maybe_nested_recursive_class = Some(nested_class);
}

// Optional class, part of a cycle.
FieldType::Optional(boxed_field_type) => {
if let FieldType::Class(nested_class) = boxed_field_type.as_ref() {
if self.recursive_classes.contains(nested_class) {
maybe_nested_recursive_class = Some(nested_class);
is_optional = true;
}
}
}

// List class, part of a cycle.
FieldType::List(boxed_field_type) => {
if let FieldType::Class(nested_class) = boxed_field_type.as_ref() {
if self.recursive_classes.contains(nested_class) {
maybe_nested_recursive_class = Some(nested_class);
is_list = true;
}
}
}
_ => {}
}

maybe_nested_recursive_class.map(|nested_class| {
if is_optional {
format!("{nested_class}{}null", options.or_splitter)
} else if is_list {
format!("{nested_class}[]")
} else {
nested_class.to_string()
}
})
}

fn inner_type_render(
&self,
options: &RenderOptions,
Expand Down Expand Up @@ -404,81 +475,34 @@ impl OutputFormatContent {
));
};

// Hoist recursive classes.
//
// TODO: Some cloning in this function again, check
// baml-lib/jsonish/src/tests/mod.rs
// there's room for optimization.
if render_state.hoisted_classes.len() < self.recursive_classes.len() {
for recursive_class in self.recursive_classes.iter() {
render_state.hoisted_classes.insert(recursive_class.clone());
}
}

ClassRender {
name: class.name.rendered_name().to_string(),
values: class
.fields
.iter()
.map(|(name, field_type, description)| {
let mut maybe_nested_recursive_class = None;
let mut is_optional = false;
let mut is_list = false;

match field_type {
// Non-optional class, part of a cycle.
FieldType::Class(nested_class)
if self.recursive_classes.contains(nested_class) =>
{
maybe_nested_recursive_class = Some(nested_class);
}

// Optional class, part of a cycle.
FieldType::Optional(boxed_field_type) => {
if let FieldType::Class(nested_class) =
boxed_field_type.as_ref()
{
if self.recursive_classes.contains(nested_class) {
maybe_nested_recursive_class = Some(nested_class);
is_optional = true;
}
}
}

// List class, part of a cycle.
FieldType::List(boxed_field_type) => {
if let FieldType::Class(nested_class) =
boxed_field_type.as_ref()
{
if self.recursive_classes.contains(nested_class) {
maybe_nested_recursive_class = Some(nested_class);
is_list = true;
}
}
}
_ => {}
}

// Terminate recursion. There's no other way to
// refer to a recursive class other than by name,
// and all recursive classes are hoisted so they'll
// be handled at a later stage.
let r#type = if let Some(nested_class) = maybe_nested_recursive_class {
if is_optional {
format!("{nested_class}{}null", options.or_splitter)
} else if is_list {
format!("{nested_class}[]")
} else {
nested_class.to_string()
}
} else {
self.inner_type_render(options, field_type, render_state, false)?
};

Ok(ClassFieldRender {
name: name.rendered_name().to_string(),
description: description.clone(),
r#type,
r#type: match self.maybe_render_recursive_class(
field_type,
render_state,
options,
) {
// Terminate recursion. There's no other way
// to refer to a recursive class other than
// by name, and all recursive classes are
// hoisted so they'll be handled at a later
// stage.
Some(recursive_class) => recursive_class,

None => self.inner_type_render(
options,
field_type,
render_state,
false,
)?,
},
})
})
.collect::<Result<_, minijinja::Error>>()?,
Expand All @@ -503,43 +527,14 @@ impl OutputFormatContent {
}
}
}
// TODO: Extract this into function and reuse.
FieldType::Union(items) => items
.iter()
.map(|t| match t {
FieldType::Class(cls) if self.recursive_classes.contains(cls) => {
for recursive_class in self.recursive_classes.iter() {
render_state.hoisted_classes.insert(recursive_class.clone());
}
Ok(cls.to_string())
}
FieldType::Optional(boxed_field_type) => {
if let FieldType::Class(nested_class) = boxed_field_type.as_ref() {
if self.recursive_classes.contains(nested_class) {
for recursive_class in self.recursive_classes.iter() {
render_state.hoisted_classes.insert(recursive_class.clone());
}
return Ok(format!("{nested_class}{}null", options.or_splitter));
}
}

self.inner_type_render(options, t, render_state, false)
}
// List class, part of a cycle.
FieldType::List(boxed_field_type) => {
if let FieldType::Class(nested_class) = boxed_field_type.as_ref() {
if self.recursive_classes.contains(nested_class) {
for recursive_class in self.recursive_classes.iter() {
render_state.hoisted_classes.insert(recursive_class.clone());
}
return Ok(format!("{nested_class}{}null", options.or_splitter));
}
}

self.inner_type_render(options, t, render_state, false)
}
_ => self.inner_type_render(options, t, render_state, false),
})
.map(
|t| match self.maybe_render_recursive_class(t, render_state, options) {
Some(recursive_class) => Ok(recursive_class),
None => self.inner_type_render(options, t, render_state, false),
},
)
.collect::<Result<Vec<_>, minijinja::Error>>()?
.join(&options.or_splitter),
FieldType::Optional(inner) => {
Expand Down Expand Up @@ -599,14 +594,10 @@ impl OutputFormatContent {
}
}

let enum_definitions = render_state
.hoisted_enums
.iter()
.map(|e| {
let enm = self.enums.get(e).expect("Enum not found");
self.enum_to_string(enm, &options)
})
.collect::<Vec<_>>();
let enum_definitions = Vec::from_iter(render_state.hoisted_enums.iter().map(|e| {
let enm = self.enums.get(e).expect("Enum not found"); // TODO: Jinja Err
self.enum_to_string(enm, &options)
}));

// Yeah we love the borrow checker...
let hoisted_classes = std::mem::replace(&mut render_state.hoisted_classes, IndexSet::new());
Expand Down Expand Up @@ -665,14 +656,11 @@ impl OutputFormatContent {
#[cfg(test)]
impl OutputFormatContent {
pub fn new_array() -> Self {
Self::target(FieldType::List(Box::new(FieldType::Primitive(
TypeValue::String,
))))
.build()
Self::target(FieldType::List(Box::new(FieldType::string()))).build()
}

pub fn new_string() -> Self {
Self::target(FieldType::Primitive(TypeValue::String)).build()
Self::target(FieldType::string()).build()
}
}

Expand Down

0 comments on commit e75f6d4

Please sign in to comment.