Skip to content

Commit

Permalink
Simplify type registry
Browse files Browse the repository at this point in the history
  • Loading branch information
afsalthaj committed Oct 2, 2024
1 parent ae1d669 commit c981d38
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 54 deletions.
8 changes: 4 additions & 4 deletions golem-rib/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,10 @@ impl Expr {
self.bind_types();
self.name_binding_pattern_match_variables();
self.name_binding_local_variables();
self.infer_function_types(function_type_registry)
.map_err(|x| vec![x])?;
self.infer_variants(function_type_registry);
self.infer_enums(function_type_registry);
self.infer_call_arguments_type(function_type_registry)
.map_err(|x| vec![x])?;
type_inference::type_inference_fix_point(Self::inference_scan, self)
.map_err(|x| vec![x])?;
self.unify_types()?;
Expand Down Expand Up @@ -452,11 +452,11 @@ impl Expr {
}

// At this point we simply update the types to the parameter type expressions and the call expression itself.
pub fn infer_function_types(
pub fn infer_call_arguments_type(
&mut self,
function_type_registry: &FunctionTypeRegistry,
) -> Result<(), String> {
type_inference::infer_function_types(self, function_type_registry)
type_inference::infer_call_arguments_type(self, function_type_registry)
}

pub fn push_types_down(&mut self) -> Result<(), String> {
Expand Down
30 changes: 21 additions & 9 deletions golem-rib/src/inferred_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,25 @@ impl InferredType {
}
}
}

pub fn from_variant_cases(type_variant: &TypeVariant) -> InferredType {
let cases = type_variant
.cases
.iter()
.map(|name_type_pair| {
(
name_type_pair.name.clone(),
name_type_pair.typ.clone().map(|t| t.into()),
)
})
.collect();

InferredType::Variant(cases)
}

pub fn from_enum_cases(type_enum: &TypeEnum) -> InferredType {
InferredType::Enum(type_enum.cases.clone())
}
}

impl From<AnalysedType> for InferredType {
Expand Down Expand Up @@ -1146,22 +1165,15 @@ impl From<AnalysedType> for InferredType {
.collect(),
),
AnalysedType::Flags(vs) => InferredType::Flags(vs.names),
AnalysedType::Enum(vs) => InferredType::Enum(vs.cases),
AnalysedType::Enum(vs) => InferredType::from_enum_cases(&vs),
AnalysedType::Option(t) => InferredType::Option(Box::new((*t.inner).into())),
AnalysedType::Result(golem_wasm_ast::analysis::TypeResult { ok, err, .. }) => {
InferredType::Result {
ok: ok.map(|t| Box::new((*t).into())),
error: err.map(|t| Box::new((*t).into())),
}
}
AnalysedType::Variant(vs) => InferredType::Variant(
vs.cases
.into_iter()
.map(|name_type_pair| {
(name_type_pair.name, name_type_pair.typ.map(|t| t.into()))
})
.collect(),
),
AnalysedType::Variant(vs) => InferredType::from_variant_cases(&vs),
AnalysedType::Handle(golem_wasm_ast::analysis::TypeHandle { resource_id, mode }) => {
InferredType::Resource {
resource_id: resource_id.0,
Expand Down
2 changes: 1 addition & 1 deletion golem-rib/src/interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ pub async fn interpret_pure(
) -> Result<RibInterpreterResult, String> {
let mut interpreter = Interpreter::pure(rib_input.clone());
interpreter.run(rib.clone()).await
}
}
2 changes: 2 additions & 0 deletions golem-rib/src/interpreter/rib_interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,8 @@ mod internal {
analysed_type: AnalysedType,
interpreter: &mut Interpreter,
) -> Result<(), String> {
dbg!(variant_name.clone());
dbg!(analysed_type.clone());
match analysed_type {
AnalysedType::Variant(variants) => {
let variant = variants
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::type_registry::FunctionTypeRegistry;
use crate::Expr;
use std::collections::VecDeque;

pub fn infer_function_types(
pub fn infer_call_arguments_type(
expr: &mut Expr,
function_type_registry: &FunctionTypeRegistry,
) -> Result<(), String> {
Expand All @@ -25,7 +25,7 @@ pub fn infer_function_types(
while let Some(expr) = queue.pop_back() {
match expr {
Expr::Call(parsed_fn_name, args, inferred_type) => {
internal::resolve_call_expressions(
internal::resolve_call_argument_types(
parsed_fn_name,
function_type_registry,
args,
Expand All @@ -47,7 +47,7 @@ mod internal {
use golem_wasm_ast::analysis::AnalysedType;
use std::fmt::Display;

pub(crate) fn resolve_call_expressions(
pub(crate) fn resolve_call_argument_types(
call_type: &mut CallType,
function_type_registry: &FunctionTypeRegistry,
args: &mut [Expr],
Expand Down Expand Up @@ -112,7 +112,24 @@ mod internal {
}
}

_ => Ok(()),
CallType::EnumConstructor(_) => {
if args.is_empty() {
Ok(())
} else {
Err("Enum constructor does not take any arguments".to_string())
}
}

CallType::VariantConstructor(variant_name) => {
let registry_key = RegistryKey::FunctionName(variant_name.clone());
infer_types(
&FunctionNameInternal::VariantName(variant_name.clone()),
function_type_registry,
registry_key,
args,
inferred_type,
)
}
}
}

Expand All @@ -125,7 +142,7 @@ mod internal {
) -> Result<(), String> {
if let Some(value) = function_type_registry.types.get(&key) {
match value {
RegistryValue::Value(_) => {}
RegistryValue::Value(_) => Ok(()),
RegistryValue::Function {
parameter_types,
return_types,
Expand Down Expand Up @@ -153,26 +170,29 @@ mod internal {
return_types.iter().map(|t| t.clone().into()).collect(),
)
}
}
};

Ok(())
} else {
return Err(format!(
Err(format!(
"Function {} expects {} arguments, but {} were provided",
function_name,
parameter_types.len(),
args.len()
));
))
}
}
}
} else {
Err(format!("Unknown function/variant call {}", function_name))
}

Ok(())
}

enum FunctionNameInternal {
ResourceConstructorName(String),
ResourceMethodName(String),
Fqn(ParsedFunctionName),
VariantName(String),
}

impl Display for FunctionNameInternal {
Expand All @@ -187,6 +207,9 @@ mod internal {
FunctionNameInternal::Fqn(name) => {
write!(f, "{}", name)
}
FunctionNameInternal::VariantName(name) => {
write!(f, "{}", name)
}
}
}
}
Expand Down Expand Up @@ -408,7 +431,8 @@ mod function_parameters_inference_tests {
let function_type_registry = get_function_type_registry();

let mut expr = Expr::from_text(rib_expr).unwrap();
expr.infer_function_types(&function_type_registry).unwrap();
expr.infer_call_arguments_type(&function_type_registry)
.unwrap();

let let_binding = Expr::let_binding("x", Expr::number(1f64));

Expand Down
8 changes: 5 additions & 3 deletions golem-rib/src/type_inference/enum_resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub fn infer_enums(expr: &mut Expr, function_type_registry: &FunctionTypeRegistr
mod internal {
use crate::call_type::CallType;
use crate::{Expr, FunctionTypeRegistry, RegistryKey, RegistryValue};
use golem_wasm_ast::analysis::AnalysedType;
use std::collections::VecDeque;

pub(crate) fn convert_identifiers_to_enum_function_calls(
Expand Down Expand Up @@ -62,12 +63,13 @@ mod internal {
match expr {
Expr::Identifier(variable_id, inferred_type) => {
// Retrieve the possible no-arg variant from the registry
let key = RegistryKey::EnumName(variable_id.name().clone());
if let Some(RegistryValue::Value(analysed_type)) =
let key = RegistryKey::FunctionName(variable_id.name().clone());
if let Some(RegistryValue::Value(AnalysedType::Enum(typed_enum))) =
function_type_registry.types.get(&key)
{
enum_cases.push(variable_id.name());
*inferred_type = inferred_type.merge(analysed_type.clone().into());
*inferred_type = inferred_type
.merge(AnalysedType::Enum(typed_enum.clone()).clone().into());
}
}

Expand Down
4 changes: 2 additions & 2 deletions golem-rib/src/type_inference/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

pub use call_arguments_inference::*;
pub use enum_resolution::*;
pub use expr_visitor::*;
pub use function_type_inference::*;
pub use global_input_inference::*;
pub use identifier_inference::*;
pub use inference_fix_point::*;
Expand All @@ -29,8 +29,8 @@ pub use type_reset::*;
pub use type_unification::*;
pub use variant_resolution::*;

mod call_arguments_inference;
mod expr_visitor;
mod function_type_inference;
mod identifier_inference;
mod name_binding;
mod pattern_match_binding;
Expand Down
23 changes: 12 additions & 11 deletions golem-rib/src/type_inference/variant_resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ pub fn infer_variants(expr: &mut Expr, function_type_registry: &FunctionTypeRegi

mod internal {
use crate::call_type::CallType;
use crate::{Expr, FunctionTypeRegistry, RegistryKey, RegistryValue};
use crate::{Expr, FunctionTypeRegistry, InferredType, RegistryKey, RegistryValue};
use golem_wasm_ast::analysis::AnalysedType;
use std::collections::VecDeque;

pub(crate) fn convert_function_calls_to_variant_calls(
Expand Down Expand Up @@ -88,26 +89,26 @@ mod internal {
while let Some(expr) = queue.pop_back() {
match expr {
Expr::Identifier(variable_id, inferred_type) => {
let key = RegistryKey::VariantName(variable_id.name().clone());
if let Some(RegistryValue::Value(analysed_type)) =
let key = RegistryKey::FunctionName(variable_id.name().clone());
if let Some(RegistryValue::Value(AnalysedType::Variant(type_variant))) =
function_type_registry.types.get(&key)
{
no_arg_variants.push(variable_id.name());
*inferred_type = inferred_type.merge(analysed_type.clone().into());
*inferred_type =
inferred_type.merge(InferredType::from_variant_cases(type_variant));
}
}

Expr::Call(CallType::Function(parsed_function_name), exprs, inferred_type) => {
let key = RegistryKey::VariantName(parsed_function_name.to_string());
if let Some(RegistryValue::Function { return_types, .. }) =
let key = RegistryKey::FunctionName(parsed_function_name.to_string());
if let Some(RegistryValue::Value(AnalysedType::Variant(type_variant))) =
function_type_registry.types.get(&key)
{
variant_with_args.push(parsed_function_name.to_string());
let variant_inferred_type =
InferredType::from_variant_cases(type_variant);
*inferred_type = inferred_type.merge(variant_inferred_type);

// TODO; return type is only 1 in reality for variants - we can make this typed
if let Some(variant_type) = return_types.first() {
*inferred_type = inferred_type.merge(variant_type.clone().into());
}
variant_with_args.push(parsed_function_name.to_string());
}

for expr in exprs {
Expand Down
Loading

0 comments on commit c981d38

Please sign in to comment.