Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic groups inside arrays support. #120

Merged
merged 1 commit into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 71 additions & 39 deletions src/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ struct DeserializeConfig<'a> {
final_exprs: Vec<String>,
/// Overload for the deserializer's name. Defaults to "raw"
deserializer_name_overload: Option<&'a str>,
/// Overload for read_len. This would be a local e.g. for arrays
read_len_overload: Option<String>,
}

impl<'a> DeserializeConfig<'a> {
Expand All @@ -173,6 +175,7 @@ impl<'a> DeserializeConfig<'a> {
optional_field: false,
final_exprs: Vec::new(),
deserializer_name_overload: None,
read_len_overload: None,
}
}

Expand All @@ -194,6 +197,22 @@ impl<'a> DeserializeConfig<'a> {
fn deserializer_name(&self) -> &'a str {
self.deserializer_name_overload.unwrap_or("raw")
}

fn overload_read_len(mut self, overload: String) -> Self {
self.read_len_overload = Some(overload);
self
}

fn pass_read_len(&self) -> String {
if let Some(overload) = &self.read_len_overload {
// the ONLY way to have a name overload is if we have a local variable (e.g. arrays)
format!("&mut {}", overload)
} else if self.in_embedded {
"read_len".to_owned()
} else {
"&mut read_len".to_owned()
}
}
}

fn concat_files(paths: Vec<&str>) -> std::io::Result<String> {
Expand Down Expand Up @@ -944,7 +963,23 @@ impl GenerationScope {
}
},
SerializingRustType::Root(ConceptualRustType::Array(ty)) => {
start_len(body, Representation::Array, serializer_use, &encoding_var, &format!("{}.len() as u64", config.expr));
let len_expr = match &ty.conceptual_type {
ConceptualRustType::Rust(elem_ident) if types.is_plain_group(elem_ident) => {
// you should not be able to indiscriminately encode a plain group like this as it
// could be multiple elements. This would require special handling if it's even permitted in CDDL.
assert!(ty.encodings.is_empty());
if let Some(fixed_elem_size) = ty.conceptual_type.expanded_field_count(types) {
format!("{} * {}.len() as u64", fixed_elem_size, config.expr)
} else {
format!(
"{}.iter().map(|e| {}).sum()",
config.expr,
ty.conceptual_type.definite_info("e", types))
}
},
_ => format!("{}.len() as u64", config.expr)
};
start_len(body, Representation::Array, serializer_use, &encoding_var, &len_expr);
let elem_var_name = format!("{}_elem", config.var_name);
let elem_encs = if CLI_ARGS.preserve_encodings {
encoding_fields(&elem_var_name, &ty.clone().resolve_aliases(), false)
Expand Down Expand Up @@ -1305,17 +1340,12 @@ impl GenerationScope {
// a parameter whether it was an optional field, and if so, read_len.read_elems(embedded mandatory fields)?;
// since otherwise it'd only length check the optional fields within the type.
assert!(!config.optional_field);
let pass_read_len = if config.in_embedded {
"read_len"
} else {
"&mut read_len"
};
deser_code.read_len_used = true;
let final_expr_value = format!(
"{}::deserialize_as_embedded_group({}, {}, len)",
ident,
deserializer_name,
pass_read_len);
config.pass_read_len());

deser_code.content.line(&final_result_expr_complete(&mut deser_code.throws, config.final_exprs, &final_expr_value));
} else {
Expand Down Expand Up @@ -1451,20 +1481,39 @@ impl GenerationScope {
if CLI_ARGS.preserve_encodings {
deser_code.content
.line(&format!("let len = {}.array_sz()?;", deserializer_name))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore this added mut. If we need to make any other changes to this PR I''ll remove it in that commit, otherwise we can remove it when we fix the other warnings in cddl-codegen.

.line(&format!("let {}_encoding = len.into();", config.var_name));
.line(&format!("let mut {}_encoding = len.into();", config.var_name));
if !elem_encs.is_empty() {
deser_code.content.line(&format!("let mut {}_elem_encodings = Vec::new();", config.var_name));
}
} else {
deser_code.content.line(&format!("let len = {}.array()?;", deserializer_name));
}
let mut deser_loop = make_deser_loop("len", &format!("{}.len()", arr_var_name));
let mut elem_config = DeserializeConfig::new(&elem_var_name);
let (mut deser_loop, plain_len_check) = match &ty.conceptual_type {
ConceptualRustType::Rust(ty_ident) if types.is_plain_group(&*ty_ident) => {
// two things that must be done differently for embedded plain groups:
// 1) We can't directly read the CBOR len's number of items since it could be >1
// 2) We need a different cbor read len var to pass into embedded deserialize
let read_len_overload = format!("{}_read_len", config.var_name);
deser_code.content.line(&format!("let mut {} = CBORReadLen::new(len);", read_len_overload));
// inside of deserialize_as_embedded_group we only modify read_len for things we couldn't
// statically know beforehand. This was done for other areas that use plain groups in order
// to be able to do static length checks for statically sized groups that contain plain groups
// at the start of deserialization instead of many checks for every single field.
let plain_len_check = match ty.conceptual_type.expanded_mandatory_field_count(types) {
0 => None,
n => Some(format!("{}.read_elems({})?;", read_len_overload, n)),
};
elem_config = elem_config.overload_read_len(read_len_overload);
let deser_loop = make_deser_loop("len", &format!("{}_read_len.read", config.var_name));
(deser_loop, plain_len_check)
},
_ => (make_deser_loop("len", &format!("({}.len() as u64)", arr_var_name)), None)
};
deser_loop.push_block(make_deser_loop_break_check());
if let ConceptualRustType::Rust(ty_ident) = &ty.conceptual_type {
// TODO: properly handle which read_len would be checked here.
assert!(!types.is_plain_group(&*ty_ident));
if let Some(plain_len_check) = plain_len_check {
deser_loop.line(plain_len_check);
}
let mut elem_config = DeserializeConfig::new(&elem_var_name);
elem_config.deserializer_name_overload = config.deserializer_name_overload;
if !elem_encs.is_empty() {
let elem_var_names_str = encoding_var_names_str(&elem_var_name, ty);
Expand Down Expand Up @@ -1543,7 +1592,7 @@ impl GenerationScope {
} else {
deser_code.content.line(&format!("let {} = {}.map()?;", len_var, deserializer_name));
}
let mut deser_loop = make_deser_loop(&len_var, &format!("{}.len()", table_var));
let mut deser_loop = make_deser_loop(&len_var, &format!("({}.len() as u64)", table_var));
deser_loop.push_block(make_deser_loop_break_check());
let mut key_config = DeserializeConfig::new(&key_var_name);
key_config.deserializer_name_overload = config.deserializer_name_overload;
Expand Down Expand Up @@ -2145,15 +2194,13 @@ fn create_base_wasm_wrapper<'a>(gen_scope: &GenerationScope, ident: &'a RustIden
// Alway creates directly just Serialize impl. Shortcut for create_serialize_impls when
// we know we won't need the SerializeEmbeddedGroup impl.
// See comments for create_serialize_impls for usage.
fn create_serialize_impl(ident: &RustIdent, rep: Option<Representation>, tag: Option<usize>, definite_len: Option<String>, use_this_encoding: Option<&str>) -> (codegen::Function, codegen::Impl) {
fn create_serialize_impl(ident: &RustIdent, rep: Option<Representation>, tag: Option<usize>, definite_len: &str, use_this_encoding: Option<&str>) -> (codegen::Function, codegen::Impl) {
match create_serialize_impls(ident, rep, tag, definite_len, use_this_encoding, false) {
(ser_func, ser_impl, None) => (ser_func, ser_impl),
(_ser_func, _ser_impl, Some(_embedded_impl)) => unreachable!(),
}
}

// If definite_len is provided, it will use that expression as the definite length.
// Otherwise indefinite will be used and the user should remember to write a Special::Break at the end.
// Returns (serialize, Serialize, Some(SerializeEmbeddedGroup)) impls for structs that require embedded, in which case
// the serialize calls the embedded serialize and you implement the embedded serialize
// Otherwise returns (serialize Serialize, None) impls and you implement the serialize.
Expand All @@ -2164,15 +2211,12 @@ fn create_serialize_impl(ident: &RustIdent, rep: Option<Representation>, tag: Op
// In the second case (no embedded), only the array/map tag + length are written and the user will
// want to write the rest of serialize() after that.
// * `use_this_encoding` - If present, references a variable (must be bool and in this scope) to toggle definite vs indefinite (e.g. for PRESERVE_ENCODING)
fn create_serialize_impls(ident: &RustIdent, rep: Option<Representation>, tag: Option<usize>, definite_len: Option<String>, use_this_encoding: Option<&str>, generate_serialize_embedded: bool) -> (codegen::Function, codegen::Impl, Option<codegen::Impl>) {
fn create_serialize_impls(ident: &RustIdent, rep: Option<Representation>, tag: Option<usize>, definite_len: &str, use_this_encoding: Option<&str>, generate_serialize_embedded: bool) -> (codegen::Function, codegen::Impl, Option<codegen::Impl>) {
if generate_serialize_embedded {
// This is not necessarily a problem but we should investigate this case to ensure we're not calling
// (de)serialize_as_embedded without (de)serializing the tag
assert_eq!(tag, None);
}
if use_this_encoding.is_some() && definite_len.is_none() {
panic!("definite_len is required for use_this_encoding or else we'd only be able to serialize indefinite no matter what");
}
let name = &ident.to_string();
let ser_impl = make_serialization_impl(name);
let mut ser_func = make_serialization_function("serialize");
Expand All @@ -2183,28 +2227,16 @@ fn create_serialize_impls(ident: &RustIdent, rep: Option<Representation>, tag: O
// TODO: do definite length encoding for optional fields too
if let Some (rep) = rep {
if let Some(definite) = use_this_encoding {
start_len(&mut ser_func, rep, "serializer", definite, definite_len.as_ref().unwrap());
start_len(&mut ser_func, rep, "serializer", definite, definite_len);
} else {
let len = match &definite_len {
Some(fixed_field_count) => cbor_event_len_n(fixed_field_count),
None => {
assert!(!CLI_ARGS.canonical_form);
cbor_event_len_indef().to_owned()
},
};
let len = cbor_event_len_n(definite_len);
match rep {
Representation::Array => ser_func.line(format!("serializer.write_array({})?;", len)),
Representation::Map => ser_func.line(format!("serializer.write_map({})?;", len)),
};
}
if generate_serialize_embedded {
match definite_len {
Some(_) => ser_func.line(format!("self.serialize_as_embedded_group(serializer{})", canonical_param())),
None => {
ser_func.line(format!("self.serialize_as_embedded_group(serializer{})?;", canonical_param()));
ser_func.line("serializer.write_special(CBORSpecial::Break)")
},
};
ser_func.line(format!("self.serialize_as_embedded_group(serializer{})", canonical_param()));
}
} else {
// not array or map, generate serialize directly
Expand Down Expand Up @@ -2387,7 +2419,7 @@ fn make_err_annotate_block(annotation: &str, before: &str, after: &str) -> Block
fn make_deser_loop(len_var: &str, len_expr: &str) -> Block {
Block::new(
&format!(
"while match {} {{ {} => {} < n as usize, {} => true, }}",
"while match {} {{ {} => {} < n, {} => true, }}",
len_var,
cbor_event_len_n("n"),
len_expr,
Expand Down Expand Up @@ -2857,7 +2889,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na
name,
Some(record.rep),
tag,
record.definite_info(types),
&record.definite_info(types),
len_encoding_var.map(|var| format!("self.encodings.as_ref().map(|encs| encs.{}).unwrap_or_default()", var)).as_deref(),
types.is_plain_group(name));
let mut ser_func = match ser_embedded_impl {
Expand Down Expand Up @@ -3210,7 +3242,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na
ser_func.line(format!(
"let deser_order = self.encodings.as_ref().filter(|encs| {}encs.orig_deser_order.len() == {}).map(|encs| encs.orig_deser_order.clone()).unwrap_or_else(|| {});",
check_canonical,
record.definite_info(types).expect("cannot fail for maps"),
record.definite_info(types),
serialization_order));
let mut ser_loop = codegen::Block::new("for field_index in deser_order");
let mut ser_loop_match = codegen::Block::new("match field_index");
Expand Down
Loading