From 2d3d8ccbd1c3ff1fe53903fba255d6234ce92ea3 Mon Sep 17 00:00:00 2001 From: monoid Date: Mon, 17 Feb 2025 00:10:55 +0900 Subject: [PATCH] feat: implement type checking for expressions and add type scope management --- lib/eval.ml | 2 +- lib/parser.ml | 62 +++++++------- lib/typecheck.ml | 215 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 245 insertions(+), 34 deletions(-) create mode 100644 lib/typecheck.ml diff --git a/lib/eval.ml b/lib/eval.ml index a4dd1f3..8ed2751 100644 --- a/lib/eval.ml +++ b/lib/eval.ml @@ -63,7 +63,7 @@ and eval_fun_expr scope (ftree: Parser.fun_expr_tree) = Fun { argname = ftree.name; body = ftree.body_expr; scope = scope } and eval_bin_op_expr scope op left_expr right_expr = let left = eval_expr scope left_expr in - let right = eval_expr scope right_expr in + let right = eval_expr scope right_expr in (match op with | Add -> ( match (left, right) with diff --git a/lib/parser.ml b/lib/parser.ml index fd35a75..e1486e3 100644 --- a/lib/parser.ml +++ b/lib/parser.ml @@ -371,50 +371,46 @@ let get_expr_tree_from_tokens (tokens: (Token.t * Lexer.lexer_context) Seq.t): e | Some (e, _) -> Some e | None -> None +let normalize_calc_string (s: string): string = + Lexer.lex_tokens_seq s |> get_expr_tree_from_tokens |> Option.map expr2str |> Option.value ~default:"" + let%test "test get_expr_tree_from_tokens 1" = - let tokens = Lexer.lex_tokens_seq "let x = 1 in\n x" in - match get_expr_tree_from_tokens tokens with - | Some e -> expr2str e = "let x = 1 in\nx" - | None -> false + let actual = normalize_calc_string "let x = 1 in\n x" in + let expected = "let x = 1 in\nx" in + actual = expected let%test "test get_expr_tree_from_tokens 2" = - let tokens = Lexer.lex_tokens_seq "fun x -> x" in - match get_expr_tree_from_tokens tokens with - | Some e -> expr2str e = "fun x ->\nx" - | None -> false + let actual = normalize_calc_string "fun x -> x" in + let expected = "fun x ->\nx" in + actual = expected let%test "test 
get_expr_tree_from_tokens 3" = - let tokens = Lexer.lex_tokens_seq "if 1 then 2 else 3" in - match get_expr_tree_from_tokens tokens with - | Some e -> expr2str e = "if 1 then 2 else 3" - | None -> false + let actual = normalize_calc_string "if 1 then 2 else 3" in + let expected = "if 1 then 2 else 3" in + actual = expected let%test "test get_expr_tree_from_tokens 4" = - let tokens = Lexer.lex_tokens_seq "1 + 2 * 3" in - match get_expr_tree_from_tokens tokens with - | Some e -> expr2str e = "1 + 2 * 3" - | None -> false + let actual = normalize_calc_string "1 + 2 * 3" in + let expected = "1 + 2 * 3" in + actual = expected let%test "test get_expr_tree_from_tokens 5" = - let tokens = Lexer.lex_tokens_seq "x 1 2" in - match get_expr_tree_from_tokens tokens with - | Some e -> expr2str e = "x(1)(2)" - | None -> false - + let actual = normalize_calc_string "x 1 2" in + let expected = "x(1)(2)" in + actual = expected + let%test "test get_expr_tree_from_tokens 6 with type" = - let tokens = Lexer.lex_tokens_seq "let x: int = 1 in\n x" in - match get_expr_tree_from_tokens tokens with - | Some e -> expr2str e = "let x: int = 1 in\nx" - | None -> false + let actual = normalize_calc_string "let x: int = 1 in\n x" in + let expected = "let x: int = 1 in\nx" in + actual = expected let%test "test get_expr_tree_from_tokens 7 with type" = - let tokens = Lexer.lex_tokens_seq "fun (x: int) -> x" in - match get_expr_tree_from_tokens tokens with - | Some e -> expr2str e = "fun (x: int) ->\nx" - | None -> false + let actual = normalize_calc_string "fun (x: int) -> x" in + let expected = "fun (x: int) ->\nx" in + actual = expected let%test "test get_expr_tree_from_tokens 8" = - let tokens = Lexer.lex_tokens_seq "fun (x) -> x" in - match get_expr_tree_from_tokens tokens with - | Some e -> expr2str e = "fun x ->\nx" - | None -> false \ No newline at end of file + let actual = normalize_calc_string "fun (x) -> x" in + let expected = "fun x ->\nx" in + actual = expected + diff --git 
a/lib/typecheck.ml b/lib/typecheck.ml
new file mode 100644
index 0000000..f45333b
--- /dev/null
+++ b/lib/typecheck.ml
@@ -0,0 +1,264 @@
+(* Type checker for the expression trees produced by [Parser]. *)
+
+(* Types known to the checker. [Universal] is the neutral element of
+   [intersect_type_v]; [Nothing] is what an impossible meet produces. *)
+type type_v =
+  | Int
+  | Fun of {
+      arg: type_v;
+      ret: type_v;
+    }
+  | Generic of string
+  | Universal
+  | Nothing
+
+(* One lexical scope of variable types. A child scope shares the
+   [generics_count] cell of its parent, so generated generic names never
+   repeat within a scope chain. *)
+type type_scope = {
+  parent: type_scope option;
+  bindings: (string, type_v) Hashtbl.t;
+  generics_count: int ref;
+}
+
+(* Open an empty child scope of [parent]. *)
+let make_type_scope (parent: type_scope): type_scope = {
+  parent = Some parent;
+  bindings = Hashtbl.create 10;
+  generics_count = parent.generics_count;
+}
+
+(* Create an empty root scope with a fresh generic-name counter. *)
+let make_top_type_scope (): type_scope = {
+  parent = None;
+  bindings = Hashtbl.create 10;
+  generics_count = ref 0;
+}
+
+(* Convert a parsed type annotation into a [type_v].
+   Raises [Failure] for every type identifier other than "int". *)
+let rec typetree2type_v (t: Parser.type_tree): type_v =
+  match t with
+  | Parser.TypeIdentifier "int" -> Int
+  | Parser.TypeIdentifier _ ->
+    failwith "not implemented (type alias is not supported yet)"
+  | Parser.TypeArrow (arg, ret) ->
+    Fun { arg = typetree2type_v arg; ret = typetree2type_v ret }
+
+(* Render a type as text. *)
+let rec type_v2str (t: type_v): string =
+  match t with
+  | Int -> "int"
+  | Fun { arg; ret } ->
+    Printf.sprintf "(%s -> %s)" (type_v2str arg) (type_v2str ret)
+  | Generic s -> "'" ^ s
+  | Universal -> "universal"
+  | Nothing -> "nothing"
+
+(* meet *)
+(* [Universal] is neutral and a [Generic] gives way to whatever it meets,
+   on either side. Any other pair that does not match structurally meets
+   to [Nothing]. *)
+let rec intersect_type_v (a: type_v) (b: type_v): type_v =
+  match a, b with
+  | Universal, _ -> b
+  | _, Universal -> a
+  | Int, Int -> Int
+  | Fun { arg = arg1; ret = ret1 }, Fun { arg = arg2; ret = ret2 } ->
+    (* TODO: arguments should be contravariant (joined, not met). Both
+       sides are met here because no join is implemented yet. *)
+    let arg = intersect_type_v arg1 arg2 in
+    let ret = intersect_type_v ret1 ret2 in
+    Fun { arg; ret }
+  | Generic s1, Generic s2 when s1 = s2 -> Generic s1
+  | Generic _, _ -> b
+  (* Mirror of the case above. Without it [Int] met with a generic gave
+     [Nothing], so a bound variable could not be passed to a generic
+     parameter. *)
+  | _, Generic _ -> a
+  | _ -> Nothing
+(* join *)
+(* and union_type_v (a: type_v) (b: type_v): type_v =
+  match a, b with
+  | Universal, _ -> Universal
+  | _, Universal -> Universal
+  | Int, Int -> Int
+  | Fun { arg = arg1; ret = ret1 }, Fun { arg = arg2; ret = ret2 } ->
+
+  | Generic s1, Generic s2 when s1 = s2 -> Generic s1
+  | Generic _, _ -> b
+  | _ -> Nothing *)
+
+(* Look [name] up in [scope] and then in each enclosing scope, innermost
+   first. Returns [None] when no scope binds it. *)
+let find_type_v_opt (scope: type_scope) (name: string): type_v option =
+  let rec find_binding scope =
+    match scope with
+    | None -> None
+    | Some s ->
+      (match Hashtbl.find_opt s.bindings name with
+       | Some _ as found -> found
+       | None -> find_binding s.parent)
+  in
+  find_binding (Some scope)
+
+(* Narrow the binding of [name] to its meet with [expected], store the
+   narrowed type in the scope that owns the binding, and return it.
+   Raises [Failure "Unbound variable"] when no scope binds [name] and
+   [Failure "Type error"] when the meet is [Nothing]. *)
+let assert_and_get_type_v (scope: type_scope) (name: string) (expected: type_v) =
+  let rec assert_binding scope =
+    match scope with
+    | None -> failwith "Unbound variable"
+    | Some s ->
+      (match Hashtbl.find_opt s.bindings name with
+       | Some v ->
+         let subtype = intersect_type_v v expected in
+         (* Checked before the table is written, so a rejected use leaves
+            the binding untouched. *)
+         if subtype = Nothing then failwith "Type error";
+         Hashtbl.replace s.bindings name subtype;
+         subtype
+       | None -> assert_binding s.parent)
+  in
+  assert_binding (Some scope)
+
+(* Return a generic name not handed out before in this scope chain, and
+   advance the shared counter. *)
+let gen_generic_free_name (scope: type_scope): string =
+  let n = !(scope.generics_count) in
+  scope.generics_count := n + 1;
+  string_of_int n
+
+(* Substitute [to_] for every [Generic from] inside [t]. *)
+let replace_generic_with (t: type_v) (from: string) (to_: type_v): type_v =
+  let rec replace t =
+    match t with
+    | Fun { arg; ret } -> Fun { arg = replace arg; ret = replace ret }
+    | Generic s when s = from -> to_
+    | Int | Generic _ | Universal | Nothing -> t
+  in
+  replace t
+
+(* Check [expr] against [required_type] and return the meet of the
+   required type with the type found for [expr].
+   Raises [Failure] when that meet is [Nothing], for unary operators
+   (not implemented), and whenever a sub-expression is rejected. *)
+let rec typecheck_expr (scope: type_scope) (expr: Parser.expr_tree) (required_type: type_v): type_v =
+  let actual_type =
+    match expr with
+    | Parser.LetExpr l -> typecheck_let_expr scope required_type l
+    | Parser.FunExpr ftree -> typecheck_fun_expr scope required_type ftree
+    | Parser.IfExpr (Parser.If (cond_expr, then_expr, else_expr)) ->
+      typecheck_if_expr scope required_type cond_expr then_expr else_expr
+    | Parser.BinOpExpr (op, left_expr, right_expr) ->
+      typecheck_bin_op_expr scope required_type op left_expr right_expr
+    | Parser.MonoOpExpr (_op, _expr) -> failwith "Not implemented"
+    | Parser.CallExpr (Parser.Call (func_expr, arg_expr)) ->
+      typecheck_call_expr scope required_type func_expr arg_expr
+    | Parser.Identifier name -> assert_and_get_type_v scope name required_type
+    | Parser.Number _n -> Int
+  in
+  let subtype = intersect_type_v required_type actual_type in
+  if subtype = Nothing then
+    failwith
+      (Printf.sprintf "Type error: expect %s but actual %s"
+         (type_v2str required_type) (type_v2str actual_type))
+  else subtype
+
+(* [let name = value in body]: the value is checked in the enclosing scope
+   against its annotation ([Universal] when there is none), then the body
+   is checked in a child scope that binds [name] to the value type. *)
+and typecheck_let_expr (scope: type_scope) (required_type: type_v) ({
+  name;
+  value_expr;
+  in_expr;
+  type_declare;
+}: Parser.let_expr_tree): type_v =
+  let value_required_type =
+    type_declare |> Option.map typetree2type_v |> Option.value ~default:Universal
+  in
+  let value_type = typecheck_expr scope value_expr value_required_type in
+  let new_scope = make_type_scope scope in
+  Hashtbl.add new_scope.bindings name value_type;
+  typecheck_expr new_scope in_expr required_type
+
+(* [fun arg -> body]: an argument without annotation starts as a fresh
+   generic. A generic name is drawn even when an annotation is present. *)
+and typecheck_fun_expr (scope: type_scope) (_required_type: type_v) ({
+  name = argname;
+  body_expr;
+  type_declare;
+}: Parser.fun_expr_tree): type_v =
+  let default_type = Generic (gen_generic_free_name scope) in
+  let declared_type =
+    type_declare |> Option.map typetree2type_v |> Option.value ~default:default_type
+  in
+  let new_scope = make_type_scope scope in
+  Hashtbl.add new_scope.bindings argname declared_type;
+  let ret_type = typecheck_expr new_scope body_expr Universal in
+  (* Checking the body may have narrowed the argument binding, so read it
+     back rather than reusing [declared_type]. *)
+  let arg_type = Hashtbl.find new_scope.bindings argname in
+  Fun { arg = arg_type; ret = ret_type }
+
+(* [if cond then a else b]: the condition must be an int and the result is
+   the meet of both branch types. *)
+and typecheck_if_expr (scope: type_scope) (required_type: type_v)
+    (cond_expr: Parser.expr_tree) (then_expr: Parser.expr_tree) (else_expr: Parser.expr_tree): type_v =
+  let (_ : type_v) = typecheck_expr scope cond_expr Int in
+  let then_type = typecheck_expr scope then_expr required_type in
+  let else_type = typecheck_expr scope else_expr required_type in
+  intersect_type_v then_type else_type
+
+(* Binary operators: both operands must be ints and the result is an int. *)
+and typecheck_bin_op_expr (scope: type_scope) (_required_type: type_v)
+    (_op: Parser.bin_op_type) (left_expr: Parser.expr_tree) (right_expr: Parser.expr_tree): type_v =
+  (* default int *)
+  let (_ : type_v) = typecheck_expr scope left_expr Int in
+  let (_ : type_v) = typecheck_expr scope right_expr Int in
+  Int
+
+(* Application: the callee must be a function. When its parameter type is
+   a bare generic, that generic is replaced in the return type by the
+   type found for the argument.
+   Raises [Failure "Type error"] when the callee is not a function. *)
+and typecheck_call_expr (scope: type_scope) (_required_type: type_v)
+    (func_expr: Parser.expr_tree) (arg_expr: Parser.expr_tree): type_v =
+  match typecheck_expr scope func_expr Universal with
+  | Fun { arg = arg_type; ret = ret_type } ->
+    let mono_arg_type = typecheck_expr scope arg_expr arg_type in
+    (match arg_type with
+     | Generic s ->
+       (* instance *)
+       replace_generic_with ret_type s mono_arg_type
+     | Int | Fun _ | Universal | Nothing -> ret_type)
+  | Int | Generic _ | Universal | Nothing -> failwith "Type error"
+
+(* Type-check a whole expression in a fresh top-level scope. *)
+let typecheck (expr: Parser.expr_tree): type_v =
+  typecheck_expr (make_top_type_scope ()) expr Universal
+
+(* Like [typecheck], but returns a raised exception as [Error]. *)
+let typecheck_result (expr: Parser.expr_tree): (type_v, exn) result =
+  match typecheck expr with
+  | t -> Ok t
+  | exception e -> Error e
+
+(* Lex, parse and type-check [content]. A parse that yields no tree is
+   reported as [Error (Failure "parse error")]. *)
+let test_typecheck (content: string) =
+  match Lexer.lex_tokens_seq content |> Parser.get_expr_tree_from_tokens with
+  | Some e -> typecheck_result e
+  | None -> Error (Failure "parse error")
+
+let%test "typecheck 1" =
+  let expr = "let x = fun y -> y in x 1" in
+  match test_typecheck expr with
+  | Ok t -> Printf.printf "%s\n" (type_v2str t); t = Int
+  | Error _ -> Printf.printf "error\n"; false
+
+let%test "typecheck 2: bound variable passed to a generic parameter" =
+  match test_typecheck "let f = fun y -> y in let x = 1 in f x" with
+  | Ok t -> t = Int
+  | Error _ -> false
\ No newline at end of file