diff --git a/codegenerator/common/rust.py b/codegenerator/common/rust.py index 86a1f2a..dffc54c 100644 --- a/codegenerator/common/rust.py +++ b/codegenerator/common/rust.py @@ -252,9 +252,19 @@ def type_hint(self): @property def imports(self): - imports: set[str] = set(["serde::Deserialize"]) - for field in self.fields.values(): - imports.update(field.data_type.imports) + imports: set[str] = set([]) + field_types = [x.data_type for x in self.fields.values()] + if len(field_types) > 1 or ( + len(field_types) == 1 + and not isinstance(field_types[0], Null) + and not isinstance(field_types[0], Dictionary) + and not isinstance(field_types[0], Array) + ): + # We use structure only if it is not consisting from only Null + imports.add("serde::Deserialize") + imports.add("serde::Serialize") + for field_type in field_types: + imports.update(field_type.imports) if self.additional_fields_type: imports.add("std::collections::BTreeMap") imports.update(self.additional_fields_type.imports) @@ -305,6 +315,8 @@ def type_hint(self): @property def imports(self): imports: set[str] = set() + imports.add("serde::Deserialize") + imports.add("serde::Serialize") for kind in self.kinds.values(): imports.update(kind.data_type.imports) return imports @@ -325,7 +337,7 @@ def clap_macros(self) -> set[str]: class StringEnum(BaseCompoundType): base_type: str = "enum" variants: dict[str, set[str]] = {} - imports: set[str] = set([]) + imports: set[str] = set(["serde::Deserialize", "serde::Serialize"]) lifetimes: set[str] = set() derive_container_macros: str = ( "#[derive(Debug, Deserialize, Clone, Serialize)]" @@ -765,6 +777,9 @@ def _simplify_oneof_combinations(self, type_model, kinds): elif string_klass in kinds_classes and dict_klass in kinds_classes: # oneOf [string, dummy object] => JsonValue # Simple string can be easily represented by JsonValue + for c in kinds: + # Discard dict + self.ignored_models.append(c["model"]) kinds.clear() jsonval_klass = self.primitive_type_mapping[model.PrimitiveAny] kinds.append({"local": jsonval_klass(), "class": jsonval_klass}) @@ -880,8 +895,11 @@ def get_root_data_type(self): def get_imports(self): """Get complete set of additional imports required by all models in scope""" imports: set[str] = set() - for item in self.refs.values(): - imports.update(item.imports) + imports.update(self.get_root_data_type().imports) + for subt in self.get_subtypes(): + imports.update(subt.imports) + # for item in self.refs.values(): + # imports.update(item.imports) for param in self.parameters.values(): imports.update(param.data_type.imports) return imports @@ -904,6 +922,8 @@ def subtype_requires_private_builders(self, subtype) -> bool: for field in subtype.fields.values(): if "private" in field.builder_macros: return True + if isinstance(subtype, Struct) and subtype.additional_fields_type: + return True return False def set_parameters(self, parameters: list[model.RequestParameter]) -> None: diff --git a/codegenerator/rust_cli.py b/codegenerator/rust_cli.py index 3b83c10..6b49446 100644 --- a/codegenerator/rust_cli.py +++ b/codegenerator/rust_cli.py @@ -1152,6 +1152,7 @@ def generate( ) if args.operation_type == "download": additional_imports.add("crate::common::download_file") + if args.operation_type == "upload": additional_imports.add( "crate::common::build_upload_asyncread" @@ -1212,11 +1213,15 @@ def generate( ] ) ) + # Discard unnecessry imports + additional_imports.discard("http::Response") + additional_imports.discard("bytes::Bytes") additional_imports.update(type_manager.get_imports()) additional_imports.update(response_type_manager.get_imports()) # Deserialize is already in template since it is uncoditionally required additional_imports.discard("serde::Deserialize") + additional_imports.discard("serde::Serialize") command_description: str = spec.get("description") command_summary: str = spec.get("summary") diff --git a/codegenerator/templates/rust_sdk/find.rs.j2 b/codegenerator/templates/rust_sdk/find.rs.j2 index cb739d7..4c47859 100644 --- a/codegenerator/templates/rust_sdk/find.rs.j2 +++ b/codegenerator/templates/rust_sdk/find.rs.j2 @@ -17,15 +17,13 @@ {% import 'rust_macros.j2' as macros with context -%} use derive_builder::Builder; use http::{HeaderMap, HeaderName, HeaderValue}; -use serde::de::DeserializeOwned; -use tracing::trace; -use crate::api::common::CommaSeparatedList; use crate::api::find::Findable; use crate::api::rest_endpoint_prelude::*; -use crate::api::ParamValue; - -use crate::api::{ApiError, Client, Pageable, Query, RestClient}; +{%- if not name_filter_supported %} +use crate::api::{ApiError, RestClient}; +use tracing::trace; +{%- endif %} use crate::api::{{ mod_path | join("::") }}::{ get as Get, diff --git a/codegenerator/templates/rust_sdk/impl.rs.j2 b/codegenerator/templates/rust_sdk/impl.rs.j2 index 1dd0dd6..dd91c67 100644 --- a/codegenerator/templates/rust_sdk/impl.rs.j2 +++ b/codegenerator/templates/rust_sdk/impl.rs.j2 @@ -20,7 +20,6 @@ use derive_builder::Builder; use http::{HeaderMap, HeaderName, HeaderValue}; use crate::api::rest_endpoint_prelude::*; -use serde::Serialize; {% for mod in type_manager.get_imports() | sort %} use {{ mod }}; @@ -30,7 +29,7 @@ use {{ mod }}; use json_patch::Patch; {%- endif %} -{%- if operation_type == "list" %} +{%- if operation_type == "list" and "limit" in type_manager.parameters.keys() or "marker" in type_manager.parameters.keys() %} use crate::api::Pageable; {%- endif %} @@ -266,14 +265,17 @@ impl{{ type_manager.get_request_static_lifetimes(request) }} Pageable for Reques #[cfg(test)] mod tests { + #![allow(unused_imports)] use super::*; - use crate::api::{self, Query, RawQuery}; +{%- if method.upper() == "HEAD" %} + use crate::api::RawQuery; +{%- else %} + use crate::api::Query; + use serde_json::json; +{%- endif %} use crate::types::ServiceType; use crate::test::client::MockServerClient; use http::{HeaderName, HeaderValue}; - use serde::Deserialize; - use serde::Serialize; - use serde_json::json; {%- if is_json_patch %} use serde_json::from_value; use json_patch::Patch; @@ -370,7 +372,7 @@ mod tests { .header("not_foo", "not_bar") .build() .unwrap(); - {%- if method.upper() != "HEAD" %} +{%- if method.upper() != "HEAD" %} let _: serde_json::Value = endpoint.query(&client).unwrap(); {%- else %} let _ = endpoint.raw_query(&client).unwrap(); diff --git a/codegenerator/templates/rust_sdk/subtypes.j2 b/codegenerator/templates/rust_sdk/subtypes.j2 index af780ce..65940fe 100644 --- a/codegenerator/templates/rust_sdk/subtypes.j2 +++ b/codegenerator/templates/rust_sdk/subtypes.j2 @@ -32,6 +32,12 @@ pub {{ subtype.base_type }} {{ subtype.name }}{{ ("<" + ",".join(subtype.lifetim {{ k }}, {%- endfor %} {%- endif %} + + {%- if subtype.base_type == "struct" and subtype.additional_fields_type %} + + #[builder(setter(name = "_properties"), default, private)] + _properties: BTreeMap, {{ subtype.additional_fields_type.type_hint }}>, + {%- endif %} } {% if type_manager.subtype_requires_private_builders(subtype) %} @@ -42,6 +48,22 @@ impl{{ ("<" + ",".join(subtype.lifetimes) + ">") if subtype.lifetimes else ""}} {{ macros.sdk_builder_setter(field)}} {%- endif %} {%- endfor %} + + {% if subtype.additional_fields_type is defined and subtype.additional_fields_type %} + pub fn properties(&mut self, iter: I) -> &mut Self + where + I: Iterator, + K: Into>, + V: Into<{{ subtype.additional_fields_type.type_hint }}>, + { + self._properties + .get_or_insert_with(BTreeMap::new) + .extend(iter.map(|(k, v)| (k.into(), v.into()))); + self + } + + {%- endif %} + } -{%- endif %} +{% endif %} {%- endfor %}