diff --git a/Interpreter/Symbols.h b/Interpreter/Symbols.h new file mode 100644 index 0000000..5781e1f --- /dev/null +++ b/Interpreter/Symbols.h @@ -0,0 +1,107 @@ +#pragma once +#include +#include +#include +#include "../Parser/astnodes.h" +class Symbol +{ +private: +public: + Symbol(std::string n = "", Symbol* t = NULL) { + name = n; + type = t; + } + std::string name; + Symbol* type; + std::string print() { + if (type == NULL) return name; + return "<" + name + ":" + type->print() + ">"; + } +}; +class BuiltinTypeSymbol : public Symbol { +public: + BuiltinTypeSymbol(std::string n) : Symbol(n) { + + } +}; +class VarSymbol : public Symbol +{ +public: + VarSymbol(std::string n, BuiltinTypeSymbol* t) : Symbol(n, t) { + + } + +private: + +}; +class SymbolTable { +public: + std::map symbols; + SymbolTable() { + define(new BuiltinTypeSymbol("int")); + define(new BuiltinTypeSymbol("real")); + } + void define(Symbol* symbol) { + std::cout << "Define: " << symbol->print() << std::endl; + symbols[symbol->name] = symbol; + } + Symbol* lookup(std::string name) { + std::cout << "Lookup: " << name << "\n"; + if (symbols.count(name)) { + return symbols[name]; + } + else { + return new Symbol(); + } + } +}; +class SymbolTableBuilder { +public: + SymbolTable symtab; + void visit(AstNode * node) { + if (node->print() == "Block") visit_Block(*static_cast(node)); + else if (node->print() == "UnOp") visit_UnOp(*static_cast(node)); + else if (node->print() == "Var") visit_Var(*static_cast(node)); + else if (node->print() == "BinOp") return visit_BinOp(*static_cast(node)); + else if (node->print() == "Assign") visit_Assign(*static_cast(node)); + } + void visit_Block(Block block) { + for (VarDecl decl : block.declarations) { + visit_VarDecl(decl); + } + for (AstNode* node : block.children) { + visit(node); + } + } + void visit_UnOp(UnOp unOp) { + visit(unOp.expr); + } + void visit_BinOp(BinOp binOp) { + visit(binOp.left); + visit(binOp.right); + } + void visit_Assign(class Assign assign) { + std::string var_name = assign.var.value; + Symbol* var_symbol = symtab.lookup(var_name); + if (var_symbol->name == "") { + std::string error = "Error: variable not defined: " + var_name; + throw error; + } + visit(assign.right); + } + void visit_Var(Var var) { + std::string var_name = var.value; + Symbol* var_symbol = symtab.lookup(var_name); + if (var_symbol->name == "") { + std::string error = "Error: variable not defined: " + var_name; + throw error; + } + } + void visit_VarDecl(VarDecl varDecl) { + std::string type_name = varDecl.type.type; + BuiltinTypeSymbol* type_symbol = static_cast(symtab.lookup(type_name)); + std::string var_name = varDecl.var.value; + VarSymbol* var = new VarSymbol(var_name, type_symbol); + symtab.define(var); + } +}; \ No newline at end of file diff --git a/Interpreter/interpreter.h b/Interpreter/interpreter.h index 2a80745..b2dc300 100644 --- a/Interpreter/interpreter.h +++ b/Interpreter/interpreter.h @@ -2,72 +2,196 @@ #include "../Parser/parser.h" #include #include +#include +#include "Symbols.h" +#include class Interpreter { +private: + Parser parser; + SymbolTable symTab; public: - std::map GLOBAL_SCOPE; + union val + { + int i; + double d; + }; + std::map GLOBAL_SCOPE; Interpreter(std::string input) : parser(input) { + } ~Interpreter() {} - void visit(AstNode* node) { - if (node->print() == "Block") visit_Block(*static_cast(node)); - else if (node->print() == "Assign") visit_Assign(*static_cast(node)); - else if (node->print() == "NoOp") visit_NoOp(); + template + T visit(AstNode* node) { + if (typeid(T) == typeid(void)) { + if (node->print() == "Block") visit_Block(*static_cast(node)); + else if (node->print() == "Assign") visit_Assign(*static_cast(node)); + else if (node->print() == "NoOp") visit_NoOp(); + else { + std::string error = "Error: void operation not recognized"; + throw error; + } + } + else if (typeid(T) == typeid(int) || typeid(T) == typeid(double)) { + if (node->print() == "BinOp") { + BinOp binOp = *static_cast(node); + return visit_BinOp(binOp); + } + else if (node->print() == "Num") { + return visit_Num(*static_cast(node)); + } + else if (node->print() == "UnOp") { + return visit_UnOp(*static_cast(node)); + } + else if (node->print() == "Var") { + return visit_Var(*static_cast(node)); + } + else { + std::string error = "Error: not recognized"; + throw error; + } + } else { - std::string error = "Error: not recognized"; + std::string error("Type not recognized"); throw error; } } - double visit_Real(AstNode* node) { - if (node->print() == "Num") return visit_Num(*static_cast(node)); - else if (node->print() == "UnOp") return visit_UnOp(*static_cast(node)); - else if (node->print() == "Var") return visit_Var(*static_cast(node)); - else if (node->print() == "BinOp") return visit_BinOp(*static_cast(node)); + template + T visit_BinOp(BinOp binOp) { + if (typeid(T) == typeid(int)) + { + if (binOp.op.type == Div) { + std::string error = "Error: Float division and int are incompatible types"; + throw error; + } + else if (binOp.op.type == Plus) { + return (T)(visit(binOp.left) + visit(binOp.right)); + } + else if (binOp.op.type == Minus) { + return (T)(visit(binOp.left) - visit(binOp.right)); + } + else if (binOp.op.type == Times) return (T)(visit(binOp.left) * visit(binOp.right)); + else if (binOp.op.type == IntDiv) return (T)(visit(binOp.left) / visit(binOp.right)); + else { + std::string error("Error: BinOp operation not recognized."); + throw error; + } + } + else if (typeid(T) == typeid(double)) { + if (binOp.op.type == Plus) { + return (T)(visit(binOp.left) + visit(binOp.right)); + } + else if (binOp.op.type == Minus) { + return (T)(visit(binOp.left) - visit(binOp.right)); + } + else if (binOp.op.type == Times) return (T)(visit(binOp.left) * visit(binOp.right)); + else if (binOp.op.type == IntDiv) return (T)(visit(binOp.left) / visit(binOp.right)); + else if (binOp.op.type == Div) { + return (T)(visit(binOp.left) / visit(binOp.right)); + } + else { + std::string error("Error: BinOp operation not recognized."); + throw error; + } + } else { - std::string error = "Error: not recognized"; + std::string error("Type not recognized"); throw error; } } - double visit_Num(Num num) { - return num.value; + template + T visit_Num(Num num) { + if (typeid(T) == typeid(void)) { + return (T)(1); + } + if (typeid(T) == typeid(int)) { + if (ceil(num.value) == num.value) return (T)num.value; + else { + std::string error = "Error: Wanted int but got double"; + throw error; + } + } + else return (T)num.value; + } - double visit_BinOp(BinOp binOp) { - if (binOp.op.type == Plus) return visit_Real(binOp.left) + visit_Real(binOp.right); - else if (binOp.op.type == Minus) return visit_Real(binOp.left) - visit_Real(binOp.right); - else if (binOp.op.type == Times) return visit_Real(binOp.left) * visit_Real(binOp.right); - else if (binOp.op.type == Div) return (double)(visit_Real(binOp.left) / (double)visit_Real(binOp.right)); - else if (binOp.op.type == IntDiv) return (int)(visit_Real(binOp.left) / visit_Real(binOp.right)); - else { - std::string error("Error: BinOp operation not recognized."); + template + T visit_UnOp(UnOp unOp) { + if (typeid(T) == typeid(void)) { + std::string error = "void unOp? srsly"; throw error; } - } - double visit_UnOp(UnOp unOp) { TokenType op = unOp.op.type; - if (op == Plus) return visit_Real(unOp.expr); - else return 0 - visit_Real(unOp.expr); + if (typeid(T) == typeid(int)) { + if (op == Plus) return (T)(visit(unOp.expr)); + else return (T)(0 - visit(unOp.expr)); + } + if (typeid(T) == typeid(double)) { + if (op == Plus) return (T)(visit(unOp.expr)); + else return (T)(0 - visit(unOp.expr)); + } + else { + std::string error("Type not recognized"); + throw error; + } } void visit_Block(Block block) { for (AstNode* child : block.children) { - visit(child); + visit(child); } } void visit_Assign(class Assign assign) { - string var_name = assign.var.value; - GLOBAL_SCOPE[var_name] = visit_Real(assign.right); - } - double visit_Var(Var var) { - if (GLOBAL_SCOPE.find(var.value) != GLOBAL_SCOPE.end()) return GLOBAL_SCOPE[var.value]; + std::string var_name = assign.var.value; + std::string type = symTab.lookup(var_name)->type->name; + cout << type << endl; + if (type == "int") { + GLOBAL_SCOPE[var_name].i = visit(assign.right); + return; + } + else if (type == "real") { + GLOBAL_SCOPE[var_name].d = visit(assign.right); + return; + } else { - std::string error = "Error: variable not found"; + std::string error("Var type not supported"); throw error; } } + template + T visit_Var(Var var) { + if (GLOBAL_SCOPE.find(var.value) != GLOBAL_SCOPE.end()) { + if (typeid(T) == typeid(int)) { + std::string type = symTab.lookup(var.value)->type->name; + if (type == "int") return (T)GLOBAL_SCOPE[var.value].i; + else { + std::string error("Wanted integer, got " + type); + throw error; + } + } + else if (typeid(T) == typeid(double)) { + std::string type = symTab.lookup(var.value)->type->name; + if (type == "real") return (T)GLOBAL_SCOPE[var.value].d; + else { + std::string error("Wanted real, got " + type); + throw error; + } + } + else { + std::string error("Var type invalid"); + throw error; + } + } + else return (T)0; + } + SymbolTable getSymTab() { + return symTab; + } void visit_NoOp() {} void interpret() { Block block = parser.parseProgram(); + SymbolTableBuilder symtabBuilder; + symtabBuilder.visit(&block); + std::cout << "Finished building symtab...\n"; + symTab = symtabBuilder.symtab; visit_Block(block); } -private: - Parser parser; }; \ No newline at end of file diff --git a/Lexer/Lexer.h b/Lexer/Lexer.h index c2dc083..60d3748 100644 --- a/Lexer/Lexer.h +++ b/Lexer/Lexer.h @@ -21,7 +21,8 @@ class Lexer { {"int", INT}, {"real", REAL}, {"vars", VARS}, - {"program", PROGRAM} + {"program", PROGRAM}, + {"func", FUNCTION} }; public: Lexer(string inputPass) { diff --git a/Lexer/tokens.h b/Lexer/tokens.h index 3a9e395..49b154e 100644 --- a/Lexer/tokens.h +++ b/Lexer/tokens.h @@ -39,6 +39,7 @@ enum TokenType { REAL, VARS, PROGRAM, + FUNCTION, // Miscellaneous Colon, Comma diff --git a/main.cpp b/main.cpp index 26cd1ed..2f66945 100644 --- a/main.cpp +++ b/main.cpp @@ -3,6 +3,7 @@ #include "Parser/parser.h" #include "Lexer/Lexer.h" #include "Interpreter/interpreter.h" +#include "Interpreter/Symbols.h" // Tip: Don't use using namespace, see https://bit.ly/aaron_help_CPP_GUIDELINE_1 using namespace std; @@ -22,26 +23,48 @@ int interpret(string input) { block = parser.parseProgram(); } catch (string error) { - cout << error; + cout << error << "\n"; return 1; } cout << "Parser blocks:" << endl; for (int i = 0; i < block.size(); i++) { cout << block[i]->print() << endl; } + cout << "SymbolTable: " << endl; + SymbolTableBuilder symtabBuilder; + try { + symtabBuilder.visit(&block); + } + catch (std::string error) { + cout << error << "\n"; + return 1; + } Interpreter interpreter(input); + cout << "Interpreting...\n"; try { interpreter.interpret(); } catch (std::string error) { - cout << error; + cout << error << "\n"; return 1; } cout << "Variables: " << endl; + SymbolTable symtab = interpreter.getSymTab(); for (auto const& pair : interpreter.GLOBAL_SCOPE) { - cout << pair.first << " " << pair.second << endl; + cout << pair.first << " "; + std::string type = symtab.lookup(pair.first)->type->name; + if (type == "int") { + cout << pair.second.i; + } + else if (type == "real") { + cout << pair.second.d; + } + cout << endl; } + + + return 0; } int main () { diff --git a/test.txt b/test.txt index e1446d1..0825e8e 100644 --- a/test.txt +++ b/test.txt @@ -1,12 +1,14 @@ -main() { + main() { vars: - int: a, b, c; + int: a, c, d; + real: b; program: - a = 1*1+4/-2; + a = 1*1+4//-2; b = -13+4.5; - c = (a - b)//3; + c = (a - d)//3; + d = 40; { program: - c = (a - b)//3; + c = (a - d)//3; } -} \ No newline at end of file +}