Skip to content
This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

Commit

Permalink
chore: improve unused imports handling
Browse files Browse the repository at this point in the history
  • Loading branch information
gtema committed Feb 8, 2024
1 parent a8f0fa3 commit 11ce53b
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 20 deletions.
32 changes: 26 additions & 6 deletions codegenerator/common/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,19 @@ def type_hint(self):

@property
def imports(self):
imports: set[str] = set(["serde::Deserialize"])
for field in self.fields.values():
imports.update(field.data_type.imports)
imports: set[str] = set([])
field_types = [x.data_type for x in self.fields.values()]
if len(field_types) > 1 or (
len(field_types) == 1
and not isinstance(field_types[0], Null)
and not isinstance(field_types[0], Dictionary)
and not isinstance(field_types[0], Array)
):
# We use structure only if it is not consisting from only Null
imports.add("serde::Deserialize")
imports.add("serde::Serialize")
for field_type in field_types:
imports.update(field_type.imports)
if self.additional_fields_type:
imports.add("std::collections::BTreeMap")
imports.update(self.additional_fields_type.imports)
Expand Down Expand Up @@ -305,6 +315,8 @@ def type_hint(self):
@property
def imports(self):
imports: set[str] = set()
imports.add("serde::Deserialize")
imports.add("serde::Serialize")
for kind in self.kinds.values():
imports.update(kind.data_type.imports)
return imports
Expand All @@ -325,7 +337,7 @@ def clap_macros(self) -> set[str]:
class StringEnum(BaseCompoundType):
base_type: str = "enum"
variants: dict[str, set[str]] = {}
imports: set[str] = set([])
imports: set[str] = set(["serde::Deserialize", "serde::Serialize"])
lifetimes: set[str] = set()
derive_container_macros: str = (
"#[derive(Debug, Deserialize, Clone, Serialize)]"
Expand Down Expand Up @@ -765,6 +777,9 @@ def _simplify_oneof_combinations(self, type_model, kinds):
elif string_klass in kinds_classes and dict_klass in kinds_classes:
# oneOf [string, dummy object] => JsonValue
# Simple string can be easily represented by JsonValue
for c in kinds:
# Discard dict
self.ignored_models.append(c["model"])
kinds.clear()
jsonval_klass = self.primitive_type_mapping[model.PrimitiveAny]
kinds.append({"local": jsonval_klass(), "class": jsonval_klass})
Expand Down Expand Up @@ -880,8 +895,11 @@ def get_root_data_type(self):
def get_imports(self):
"""Get complete set of additional imports required by all models in scope"""
imports: set[str] = set()
for item in self.refs.values():
imports.update(item.imports)
imports.update(self.get_root_data_type().imports)
for subt in self.get_subtypes():
imports.update(subt.imports)
# for item in self.refs.values():
# imports.update(item.imports)
for param in self.parameters.values():
imports.update(param.data_type.imports)
return imports
Expand All @@ -904,6 +922,8 @@ def subtype_requires_private_builders(self, subtype) -> bool:
for field in subtype.fields.values():
if "private" in field.builder_macros:
return True
if isinstance(subtype, Struct) and subtype.additional_fields_type:
return True
return False

def set_parameters(self, parameters: list[model.RequestParameter]) -> None:
Expand Down
5 changes: 5 additions & 0 deletions codegenerator/rust_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,7 @@ def generate(
)
if args.operation_type == "download":
additional_imports.add("crate::common::download_file")

if args.operation_type == "upload":
additional_imports.add(
"crate::common::build_upload_asyncread"
Expand Down Expand Up @@ -1212,11 +1213,15 @@ def generate(
]
)
)
# Discard unnecessry imports
additional_imports.discard("http::Response")
additional_imports.discard("bytes::Bytes")

additional_imports.update(type_manager.get_imports())
additional_imports.update(response_type_manager.get_imports())
# Deserialize is already in template since it is uncoditionally required
additional_imports.discard("serde::Deserialize")
additional_imports.discard("serde::Serialize")

command_description: str = spec.get("description")
command_summary: str = spec.get("summary")
Expand Down
10 changes: 4 additions & 6 deletions codegenerator/templates/rust_sdk/find.rs.j2
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@
{% import 'rust_macros.j2' as macros with context -%}
use derive_builder::Builder;
use http::{HeaderMap, HeaderName, HeaderValue};
use serde::de::DeserializeOwned;
use tracing::trace;

use crate::api::common::CommaSeparatedList;
use crate::api::find::Findable;
use crate::api::rest_endpoint_prelude::*;
use crate::api::ParamValue;

use crate::api::{ApiError, Client, Pageable, Query, RestClient};
{%- if not name_filter_supported %}
use crate::api::{ApiError, RestClient};
use tracing::trace;
{%- endif %}

use crate::api::{{ mod_path | join("::") }}::{
get as Get,
Expand Down
16 changes: 9 additions & 7 deletions codegenerator/templates/rust_sdk/impl.rs.j2
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ use derive_builder::Builder;
use http::{HeaderMap, HeaderName, HeaderValue};

use crate::api::rest_endpoint_prelude::*;
use serde::Serialize;

{% for mod in type_manager.get_imports() | sort %}
use {{ mod }};
Expand All @@ -30,7 +29,7 @@ use {{ mod }};
use json_patch::Patch;
{%- endif %}

{%- if operation_type == "list" %}
{%- if operation_type == "list" and "limit" in type_manager.parameters.keys() or "marker" in type_manager.parameters.keys() %}
use crate::api::Pageable;
{%- endif %}

Expand Down Expand Up @@ -266,14 +265,17 @@ impl{{ type_manager.get_request_static_lifetimes(request) }} Pageable for Reques
#[cfg(test)]
mod tests {
#![allow(unused_imports)]
use super::*;
use crate::api::{self, Query, RawQuery};
{%- if method.upper() == "HEAD" %}
use crate::api::RawQuery;
{%- else %}
use crate::api::Query;
use serde_json::json;
{%- endif %}
use crate::types::ServiceType;
use crate::test::client::MockServerClient;
use http::{HeaderName, HeaderValue};
use serde::Deserialize;
use serde::Serialize;
use serde_json::json;
{%- if is_json_patch %}
use serde_json::from_value;
use json_patch::Patch;
Expand Down Expand Up @@ -370,7 +372,7 @@ mod tests {
.header("not_foo", "not_bar")
.build()
.unwrap();
{%- if method.upper() != "HEAD" %}
{%- if method.upper() != "HEAD" %}
let _: serde_json::Value = endpoint.query(&client).unwrap();
{%- else %}
let _ = endpoint.raw_query(&client).unwrap();
Expand Down
24 changes: 23 additions & 1 deletion codegenerator/templates/rust_sdk/subtypes.j2
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ pub {{ subtype.base_type }} {{ subtype.name }}{{ ("<" + ",".join(subtype.lifetim
{{ k }},
{%- endfor %}
{%- endif %}

{%- if subtype.base_type == "struct" and subtype.additional_fields_type %}

#[builder(setter(name = "_properties"), default, private)]
_properties: BTreeMap<Cow<'a, str>, {{ subtype.additional_fields_type.type_hint }}>,
{%- endif %}
}
{% if type_manager.subtype_requires_private_builders(subtype) %}
Expand All @@ -42,6 +48,22 @@ impl{{ ("<" + ",".join(subtype.lifetimes) + ">") if subtype.lifetimes else ""}}
{{ macros.sdk_builder_setter(field)}}
{%- endif %}
{%- endfor %}
{% if subtype.additional_fields_type is defined and subtype.additional_fields_type %}
pub fn properties<I, K, V>(&mut self, iter: I) -> &mut Self
where
I: Iterator<Item = (K, V)>,
K: Into<Cow<'a, str>>,
V: Into<{{ subtype.additional_fields_type.type_hint }}>,
{
self._properties
.get_or_insert_with(BTreeMap::new)
.extend(iter.map(|(k, v)| (k.into(), v.into())));
self
}

{%- endif %}

}
{%- endif %}
{% endif %}
{%- endfor %}

0 comments on commit 11ce53b

Please sign in to comment.