From c981d38179c569135c65fce5b23d2b41f856678a Mon Sep 17 00:00:00 2001 From: Afsal Thaj Date: Wed, 2 Oct 2024 12:37:07 +1000 Subject: [PATCH] Simplify type registry --- golem-rib/src/expr.rs | 8 +-- golem-rib/src/inferred_type.rs | 30 +++++++---- golem-rib/src/interpreter/mod.rs | 2 +- golem-rib/src/interpreter/rib_interpreter.rs | 2 + ...ference.rs => call_arguments_inference.rs} | 46 ++++++++++++---- .../src/type_inference/enum_resolution.rs | 8 +-- golem-rib/src/type_inference/mod.rs | 4 +- .../src/type_inference/variant_resolution.rs | 23 ++++---- golem-rib/src/type_registry.rs | 52 ++++++++++++++----- 9 files changed, 121 insertions(+), 54 deletions(-) rename golem-rib/src/type_inference/{function_type_inference.rs => call_arguments_inference.rs} (92%) diff --git a/golem-rib/src/expr.rs b/golem-rib/src/expr.rs index f33ed518e..cbdefb7ab 100644 --- a/golem-rib/src/expr.rs +++ b/golem-rib/src/expr.rs @@ -419,10 +419,10 @@ impl Expr { self.bind_types(); self.name_binding_pattern_match_variables(); self.name_binding_local_variables(); - self.infer_function_types(function_type_registry) - .map_err(|x| vec![x])?; self.infer_variants(function_type_registry); self.infer_enums(function_type_registry); + self.infer_call_arguments_type(function_type_registry) + .map_err(|x| vec![x])?; type_inference::type_inference_fix_point(Self::inference_scan, self) .map_err(|x| vec![x])?; self.unify_types()?; @@ -452,11 +452,11 @@ impl Expr { } // At this point we simply update the types to the parameter type expressions and the call expression itself. - pub fn infer_function_types( + pub fn infer_call_arguments_type( &mut self, function_type_registry: &FunctionTypeRegistry, ) -> Result<(), String> { - type_inference::infer_function_types(self, function_type_registry) + type_inference::infer_call_arguments_type(self, function_type_registry) } pub fn push_types_down(&mut self) -> Result<(), String> { diff --git a/golem-rib/src/inferred_type.rs b/golem-rib/src/inferred_type.rs index 6577def69..bee7f8269 100644 --- a/golem-rib/src/inferred_type.rs +++ b/golem-rib/src/inferred_type.rs @@ -1117,6 +1117,25 @@ impl InferredType { } } } + + pub fn from_variant_cases(type_variant: &TypeVariant) -> InferredType { + let cases = type_variant + .cases + .iter() + .map(|name_type_pair| { + ( + name_type_pair.name.clone(), + name_type_pair.typ.clone().map(|t| t.into()), + ) + }) + .collect(); + + InferredType::Variant(cases) + } + + pub fn from_enum_cases(type_enum: &TypeEnum) -> InferredType { + InferredType::Enum(type_enum.cases.clone()) + } } impl From for InferredType { @@ -1146,7 +1165,7 @@ impl From for InferredType { .collect(), ), AnalysedType::Flags(vs) => InferredType::Flags(vs.names), - AnalysedType::Enum(vs) => InferredType::Enum(vs.cases), + AnalysedType::Enum(vs) => InferredType::from_enum_cases(&vs), AnalysedType::Option(t) => InferredType::Option(Box::new((*t.inner).into())), AnalysedType::Result(golem_wasm_ast::analysis::TypeResult { ok, err, .. }) => { InferredType::Result { @@ -1154,14 +1173,7 @@ impl From for InferredType { error: err.map(|t| Box::new((*t).into())), } } - AnalysedType::Variant(vs) => InferredType::Variant( - vs.cases - .into_iter() - .map(|name_type_pair| { - (name_type_pair.name, name_type_pair.typ.map(|t| t.into())) - }) - .collect(), - ), + AnalysedType::Variant(vs) => InferredType::from_variant_cases(&vs), AnalysedType::Handle(golem_wasm_ast::analysis::TypeHandle { resource_id, mode }) => { InferredType::Resource { resource_id: resource_id.0, diff --git a/golem-rib/src/interpreter/mod.rs b/golem-rib/src/interpreter/mod.rs index 270f33c35..cf78699c1 100644 --- a/golem-rib/src/interpreter/mod.rs +++ b/golem-rib/src/interpreter/mod.rs @@ -42,4 +42,4 @@ pub async fn interpret_pure( ) -> Result { let mut interpreter = Interpreter::pure(rib_input.clone()); interpreter.run(rib.clone()).await -} \ No newline at end of file +} diff --git a/golem-rib/src/interpreter/rib_interpreter.rs b/golem-rib/src/interpreter/rib_interpreter.rs index 3588504e5..9aac657ff 100644 --- a/golem-rib/src/interpreter/rib_interpreter.rs +++ b/golem-rib/src/interpreter/rib_interpreter.rs @@ -535,6 +535,8 @@ mod internal { analysed_type: AnalysedType, interpreter: &mut Interpreter, ) -> Result<(), String> { + dbg!(variant_name.clone()); + dbg!(analysed_type.clone()); match analysed_type { AnalysedType::Variant(variants) => { let variant = variants diff --git a/golem-rib/src/type_inference/function_type_inference.rs b/golem-rib/src/type_inference/call_arguments_inference.rs similarity index 92% rename from golem-rib/src/type_inference/function_type_inference.rs rename to golem-rib/src/type_inference/call_arguments_inference.rs index 484d195ec..42f2154cd 100644 --- a/golem-rib/src/type_inference/function_type_inference.rs +++ b/golem-rib/src/type_inference/call_arguments_inference.rs @@ -16,7 +16,7 @@ use crate::type_registry::FunctionTypeRegistry; use crate::Expr; use std::collections::VecDeque; -pub fn infer_function_types( +pub fn infer_call_arguments_type( expr: &mut Expr, function_type_registry: &FunctionTypeRegistry, ) -> Result<(), String> { @@ -25,7 +25,7 @@ pub fn infer_function_types( while let Some(expr) = queue.pop_back() { match expr { Expr::Call(parsed_fn_name, args, inferred_type) => { - internal::resolve_call_expressions( + internal::resolve_call_argument_types( parsed_fn_name, function_type_registry, args, @@ -47,7 +47,7 @@ mod internal { use golem_wasm_ast::analysis::AnalysedType; use std::fmt::Display; - pub(crate) fn resolve_call_expressions( + pub(crate) fn resolve_call_argument_types( call_type: &mut CallType, function_type_registry: &FunctionTypeRegistry, args: &mut [Expr], @@ -112,7 +112,24 @@ mod internal { } } - _ => Ok(()), + CallType::EnumConstructor(_) => { + if args.is_empty() { + Ok(()) + } else { + Err("Enum constructor does not take any arguments".to_string()) + } + } + + CallType::VariantConstructor(variant_name) => { + let registry_key = RegistryKey::FunctionName(variant_name.clone()); + infer_types( + &FunctionNameInternal::VariantName(variant_name.clone()), + function_type_registry, + registry_key, + args, + inferred_type, + ) + } } } @@ -125,7 +142,7 @@ mod internal { ) -> Result<(), String> { if let Some(value) = function_type_registry.types.get(&key) { match value { - RegistryValue::Value(_) => {} + RegistryValue::Value(_) => Ok(()), RegistryValue::Function { parameter_types, return_types, @@ -153,26 +170,29 @@ mod internal { return_types.iter().map(|t| t.clone().into()).collect(), ) } - } + }; + + Ok(()) } else { - return Err(format!( + Err(format!( "Function {} expects {} arguments, but {} were provided", function_name, parameter_types.len(), args.len() - )); + )) } } } + } else { + Err(format!("Unknown function/variant call {}", function_name)) } - - Ok(()) } enum FunctionNameInternal { ResourceConstructorName(String), ResourceMethodName(String), Fqn(ParsedFunctionName), + VariantName(String), } impl Display for FunctionNameInternal { @@ -187,6 +207,9 @@ mod internal { FunctionNameInternal::Fqn(name) => { write!(f, "{}", name) } + FunctionNameInternal::VariantName(name) => { + write!(f, "{}", name) + } } } } @@ -408,7 +431,8 @@ mod function_parameters_inference_tests { let function_type_registry = get_function_type_registry(); let mut expr = Expr::from_text(rib_expr).unwrap(); - expr.infer_function_types(&function_type_registry).unwrap(); + expr.infer_call_arguments_type(&function_type_registry) + .unwrap(); let let_binding = Expr::let_binding("x", Expr::number(1f64)); diff --git a/golem-rib/src/type_inference/enum_resolution.rs b/golem-rib/src/type_inference/enum_resolution.rs index 11226529f..58937995e 100644 --- a/golem-rib/src/type_inference/enum_resolution.rs +++ b/golem-rib/src/type_inference/enum_resolution.rs @@ -23,6 +23,7 @@ pub fn infer_enums(expr: &mut Expr, function_type_registry: &FunctionTypeRegistr mod internal { use crate::call_type::CallType; use crate::{Expr, FunctionTypeRegistry, RegistryKey, RegistryValue}; + use golem_wasm_ast::analysis::AnalysedType; use std::collections::VecDeque; pub(crate) fn convert_identifiers_to_enum_function_calls( @@ -62,12 +63,13 @@ mod internal { match expr { Expr::Identifier(variable_id, inferred_type) => { // Retrieve the possible no-arg variant from the registry - let key = RegistryKey::EnumName(variable_id.name().clone()); - if let Some(RegistryValue::Value(analysed_type)) = + let key = RegistryKey::FunctionName(variable_id.name().clone()); + if let Some(RegistryValue::Value(AnalysedType::Enum(typed_enum))) = function_type_registry.types.get(&key) { enum_cases.push(variable_id.name()); - *inferred_type = inferred_type.merge(analysed_type.clone().into()); + *inferred_type = inferred_type + .merge(AnalysedType::Enum(typed_enum.clone()).clone().into()); } } diff --git a/golem-rib/src/type_inference/mod.rs b/golem-rib/src/type_inference/mod.rs index 9f2b96182..d0d2aa7a2 100644 --- a/golem-rib/src/type_inference/mod.rs +++ b/golem-rib/src/type_inference/mod.rs @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub use call_arguments_inference::*; pub use enum_resolution::*; pub use expr_visitor::*; -pub use function_type_inference::*; pub use global_input_inference::*; pub use identifier_inference::*; pub use inference_fix_point::*; @@ -29,8 +29,8 @@ pub use type_reset::*; pub use type_unification::*; pub use variant_resolution::*; +mod call_arguments_inference; mod expr_visitor; -mod function_type_inference; mod identifier_inference; mod name_binding; mod pattern_match_binding; diff --git a/golem-rib/src/type_inference/variant_resolution.rs b/golem-rib/src/type_inference/variant_resolution.rs index 298b214d3..dc109c337 100644 --- a/golem-rib/src/type_inference/variant_resolution.rs +++ b/golem-rib/src/type_inference/variant_resolution.rs @@ -24,7 +24,8 @@ pub fn infer_variants(expr: &mut Expr, function_type_registry: &FunctionTypeRegi mod internal { use crate::call_type::CallType; - use crate::{Expr, FunctionTypeRegistry, RegistryKey, RegistryValue}; + use crate::{Expr, FunctionTypeRegistry, InferredType, RegistryKey, RegistryValue}; + use golem_wasm_ast::analysis::AnalysedType; use std::collections::VecDeque; pub(crate) fn convert_function_calls_to_variant_calls( @@ -88,26 +89,26 @@ mod internal { while let Some(expr) = queue.pop_back() { match expr { Expr::Identifier(variable_id, inferred_type) => { - let key = RegistryKey::VariantName(variable_id.name().clone()); - if let Some(RegistryValue::Value(analysed_type)) = + let key = RegistryKey::FunctionName(variable_id.name().clone()); + if let Some(RegistryValue::Value(AnalysedType::Variant(type_variant))) = function_type_registry.types.get(&key) { no_arg_variants.push(variable_id.name()); - *inferred_type = inferred_type.merge(analysed_type.clone().into()); + *inferred_type = + inferred_type.merge(InferredType::from_variant_cases(type_variant)); } } Expr::Call(CallType::Function(parsed_function_name), exprs, inferred_type) => { - let key = RegistryKey::VariantName(parsed_function_name.to_string()); - if let Some(RegistryValue::Function { return_types, .. }) = + let key = RegistryKey::FunctionName(parsed_function_name.to_string()); + if let Some(RegistryValue::Value(AnalysedType::Variant(type_variant))) = function_type_registry.types.get(&key) { - variant_with_args.push(parsed_function_name.to_string()); + let variant_inferred_type = + InferredType::from_variant_cases(type_variant); + *inferred_type = inferred_type.merge(variant_inferred_type); - // TODO; return type is only 1 in reality for variants - we can make this typed - if let Some(variant_type) = return_types.first() { - *inferred_type = inferred_type.merge(variant_type.clone().into()); - } + variant_with_args.push(parsed_function_name.to_string()); } for expr in exprs { diff --git a/golem-rib/src/type_registry.rs b/golem-rib/src/type_registry.rs index dad0f6a04..55d7ecf9c 100644 --- a/golem-rib/src/type_registry.rs +++ b/golem-rib/src/type_registry.rs @@ -29,8 +29,6 @@ use std::collections::{HashMap, HashSet}; // then the RegistryValue is simply an AnalysedType representing the variant type itself. #[derive(Hash, Eq, PartialEq, Clone, Debug)] pub enum RegistryKey { - VariantName(String), - EnumName(String), FunctionName(String), FunctionNameWithInterface { interface_name: String, @@ -51,9 +49,9 @@ impl RegistryKey { pub fn from_invocation_name(invocation_name: &CallType) -> RegistryKey { match invocation_name { CallType::VariantConstructor(variant_name) => { - RegistryKey::VariantName(variant_name.clone()) + RegistryKey::FunctionName(variant_name.clone()) } - CallType::EnumConstructor(enum_name) => RegistryKey::EnumName(enum_name.clone()), + CallType::EnumConstructor(enum_name) => RegistryKey::FunctionName(enum_name.clone()), CallType::Function(function_name) => match function_name.site.interface_name() { None => RegistryKey::FunctionName(function_name.function_name()), Some(interface_name) => RegistryKey::FunctionNameWithInterface { @@ -117,16 +115,16 @@ impl FunctionTypeRegistry { }) .collect::>(); - let registry_value = RegistryValue::Function { - parameter_types, - return_types, - }; - let registry_key = RegistryKey::FunctionNameWithInterface { interface_name: interface_name.clone(), function_name: function_name.clone(), }; + let registry_value = RegistryValue::Function { + parameter_types, + return_types, + }; + map.insert(registry_key, registry_value); } } @@ -179,7 +177,7 @@ impl FunctionTypeRegistry { mod internal { use crate::{RegistryKey, RegistryValue}; - use golem_wasm_ast::analysis::AnalysedType; + use golem_wasm_ast::analysis::{AnalysedType, TypeResult}; use std::collections::HashMap; pub(crate) fn update_registry( @@ -189,7 +187,7 @@ mod internal { match ty.clone() { AnalysedType::Variant(variant) => { for name_type_pair in variant.cases { - registry.insert(RegistryKey::VariantName(name_type_pair.name.clone()), { + registry.insert(RegistryKey::FunctionName(name_type_pair.name.clone()), { name_type_pair.typ.map_or( RegistryValue::Value(ty.clone()), |variant_parameter_typ| RegistryValue::Function { @@ -204,7 +202,7 @@ mod internal { AnalysedType::Enum(type_enum) => { for name_type_pair in type_enum.cases { registry.insert( - RegistryKey::EnumName(name_type_pair.clone()), + RegistryKey::FunctionName(name_type_pair.clone()), RegistryValue::Value(ty.clone()), ); } @@ -226,7 +224,35 @@ mod internal { } } - _ => {} + AnalysedType::Result(TypeResult {ok: Some(ok_type), err: Some(err_type)}) => { + update_registry(ok_type.as_ref(), registry); + update_registry(err_type.as_ref(), registry); + } + AnalysedType::Result(TypeResult {ok: None, err: Some(err_type)}) => { + update_registry(err_type.as_ref(), registry); + } + AnalysedType::Result(TypeResult {ok: Some(ok_type), err: None}) => { + update_registry(ok_type.as_ref(), registry); + } + AnalysedType::Option(type_option) => { + update_registry(type_option.inner.as_ref(), registry); + } + AnalysedType::Result(_) => {} + AnalysedType::Flags(_) => {} + AnalysedType::Str(_) => {} + AnalysedType::Chr(_) => {} + AnalysedType::F64(_) => {} + AnalysedType::F32(_) => {} + AnalysedType::U64(_) => {} + AnalysedType::S64(_) => {} + AnalysedType::U32(_) => {} + AnalysedType::S32(_) => {} + AnalysedType::U16(_) => {} + AnalysedType::S16(_) => {} + AnalysedType::U8(_) => {} + AnalysedType::S8(_) => {} + AnalysedType::Bool(_) => {} + AnalysedType::Handle(_) => {} } } }