diff --git a/src/semantic/mod.rs b/src/semantic/mod.rs index ed460f6..055283d 100644 --- a/src/semantic/mod.rs +++ b/src/semantic/mod.rs @@ -1,3 +1,6 @@ +use crate::ast_types::Expression; +use crate::symbol_table::{_STRING, _BOOLEAN}; + use super::symbol_table::{SymbolTable, _NUMBER}; use super::ast_types::{ModuleAST, Binding}; @@ -7,44 +10,49 @@ pub fn check_ast<'a>(ast: &'a mut ModuleAST, symbol_table: &'a mut SymbolTable) match binding { Binding::Val(val_binding) => { // TODO: create a function to get the datatype, instead of a hardcoded value - symbol_table.add(val_binding.identifier, _NUMBER); + symbol_table.add( + val_binding.identifier, + get_expression_type(&val_binding.expression).as_str() + ); } Binding::Var(var_binding) => { // TODO: create a function to get the datatype, instead of a hardcoded value - symbol_table.add(var_binding.identifier, _NUMBER); + symbol_table.add( + var_binding.identifier, + get_expression_type(&var_binding.expression).as_str(), + ); } } } } +fn get_expression_type(exp: &Expression) -> String { + match exp { + Expression::Number(_) => String::from(_NUMBER), + Expression::String(_) => String::from(_STRING), + Expression::Boolean(_) => String::from(_BOOLEAN), + } +} + #[cfg(test)] mod tests { + use crate::symbol_table::_BOOLEAN; + use crate::symbol_table::_STRING; use crate::syntax; use crate::lexic; use super::*; - /* - val identifier = 20 + fn test_type(input: String, datatype: &str) -> bool { + let tokens = lexic::get_tokens(&input).unwrap(); + let mut table = SymbolTable::new(); + let mut ast = syntax::construct_ast(&tokens).unwrap(); - [Binding] - | identifier - | [Expression] - | [Number] - | 20 + check_ast(&mut ast, &mut table); - - Check [Expression] is valid - - Check type of [Expression] - - Check if `identifier` already exists in the symbol table - - Create entry in symbol table - - -> - - SymbolTable { - identifier: Num + table.check_type("a", datatype) } - */ - + #[test] fn should_update_symbol_table() { let tokens = lexic::get_tokens(&String::from("val identifier = 20")).unwrap(); @@ -56,4 +64,16 @@ mod tests { let result = table.test("identifier"); assert_eq!(true, result); } + + #[test] + fn should_get_correct_type() { + assert!(test_type(String::from("val a = 322"), _NUMBER)); + assert!(test_type(String::from("var a = 322"), _NUMBER)); + + assert!(test_type(String::from("val a = \"str\" "), _STRING)); + assert!(test_type(String::from("var a = \"str\" "), _STRING)); + + assert!(test_type(String::from("val a = false"), _BOOLEAN)); + assert!(test_type(String::from("var a = true"), _BOOLEAN)); + } }