feat: implement type checking for expressions and add type scope management
This commit is contained in:
parent
b503bd29e8
commit
2d3d8ccbd1
3 changed files with 245 additions and 34 deletions
|
@ -63,7 +63,7 @@ and eval_fun_expr scope (ftree: Parser.fun_expr_tree) =
|
||||||
Fun { argname = ftree.name; body = ftree.body_expr; scope = scope }
|
Fun { argname = ftree.name; body = ftree.body_expr; scope = scope }
|
||||||
and eval_bin_op_expr scope op left_expr right_expr =
|
and eval_bin_op_expr scope op left_expr right_expr =
|
||||||
let left = eval_expr scope left_expr in
|
let left = eval_expr scope left_expr in
|
||||||
let right = eval_expr scope right_expr in
|
let right = eval_expr scope right_expr in
|
||||||
(match op with
|
(match op with
|
||||||
| Add -> (
|
| Add -> (
|
||||||
match (left, right) with
|
match (left, right) with
|
||||||
|
|
|
@ -371,50 +371,46 @@ let get_expr_tree_from_tokens (tokens: (Token.t * Lexer.lexer_context) Seq.t): e
|
||||||
| Some (e, _) -> Some e
|
| Some (e, _) -> Some e
|
||||||
| None -> None
|
| None -> None
|
||||||
|
|
||||||
|
(* [normalize_calc_string s] lexes [s], parses the token sequence and renders
   the resulting tree with [expr2str]. Returns the empty string when
   [get_expr_tree_from_tokens] yields no tree. *)
let normalize_calc_string (s: string): string =
  let tokens = Lexer.lex_tokens_seq s in
  match get_expr_tree_from_tokens tokens with
  | Some tree -> expr2str tree
  | None -> ""
|
||||||
|
|
||||||
let%test "test get_expr_tree_from_tokens 1" =
|
let%test "test get_expr_tree_from_tokens 1" =
|
||||||
let tokens = Lexer.lex_tokens_seq "let x = 1 in\n x" in
|
let actual = normalize_calc_string "let x = 1 in\n x" in
|
||||||
match get_expr_tree_from_tokens tokens with
|
let expected = "let x = 1 in\nx" in
|
||||||
| Some e -> expr2str e = "let x = 1 in\nx"
|
actual = expected
|
||||||
| None -> false
|
|
||||||
|
|
||||||
let%test "test get_expr_tree_from_tokens 2" =
|
let%test "test get_expr_tree_from_tokens 2" =
|
||||||
let tokens = Lexer.lex_tokens_seq "fun x -> x" in
|
let actual = normalize_calc_string "fun x -> x" in
|
||||||
match get_expr_tree_from_tokens tokens with
|
let expected = "fun x ->\nx" in
|
||||||
| Some e -> expr2str e = "fun x ->\nx"
|
actual = expected
|
||||||
| None -> false
|
|
||||||
|
|
||||||
let%test "test get_expr_tree_from_tokens 3" =
|
let%test "test get_expr_tree_from_tokens 3" =
|
||||||
let tokens = Lexer.lex_tokens_seq "if 1 then 2 else 3" in
|
let actual = normalize_calc_string "if 1 then 2 else 3" in
|
||||||
match get_expr_tree_from_tokens tokens with
|
let expected = "if 1 then 2 else 3" in
|
||||||
| Some e -> expr2str e = "if 1 then 2 else 3"
|
actual = expected
|
||||||
| None -> false
|
|
||||||
|
|
||||||
let%test "test get_expr_tree_from_tokens 4" =
|
let%test "test get_expr_tree_from_tokens 4" =
|
||||||
let tokens = Lexer.lex_tokens_seq "1 + 2 * 3" in
|
let actual = normalize_calc_string "1 + 2 * 3" in
|
||||||
match get_expr_tree_from_tokens tokens with
|
let expected = "1 + 2 * 3" in
|
||||||
| Some e -> expr2str e = "1 + 2 * 3"
|
actual = expected
|
||||||
| None -> false
|
|
||||||
|
|
||||||
let%test "test get_expr_tree_from_tokens 5" =
|
let%test "test get_expr_tree_from_tokens 5" =
|
||||||
let tokens = Lexer.lex_tokens_seq "x 1 2" in
|
let actual = normalize_calc_string "x 1 2" in
|
||||||
match get_expr_tree_from_tokens tokens with
|
let expected = "x(1)(2)" in
|
||||||
| Some e -> expr2str e = "x(1)(2)"
|
actual = expected
|
||||||
| None -> false
|
|
||||||
|
|
||||||
let%test "test get_expr_tree_from_tokens 6 with type" =
|
let%test "test get_expr_tree_from_tokens 6 with type" =
|
||||||
let tokens = Lexer.lex_tokens_seq "let x: int = 1 in\n x" in
|
let actual = normalize_calc_string "let x: int = 1 in\n x" in
|
||||||
match get_expr_tree_from_tokens tokens with
|
let expected = "let x: int = 1 in\nx" in
|
||||||
| Some e -> expr2str e = "let x: int = 1 in\nx"
|
actual = expected
|
||||||
| None -> false
|
|
||||||
|
|
||||||
let%test "test get_expr_tree_from_tokens 7 with type" =
|
let%test "test get_expr_tree_from_tokens 7 with type" =
|
||||||
let tokens = Lexer.lex_tokens_seq "fun (x: int) -> x" in
|
let actual = normalize_calc_string "fun (x: int) -> x" in
|
||||||
match get_expr_tree_from_tokens tokens with
|
let expected = "fun (x: int) ->\nx" in
|
||||||
| Some e -> expr2str e = "fun (x: int) ->\nx"
|
actual = expected
|
||||||
| None -> false
|
|
||||||
|
|
||||||
let%test "test get_expr_tree_from_tokens 8" =
|
let%test "test get_expr_tree_from_tokens 8" =
|
||||||
let tokens = Lexer.lex_tokens_seq "fun (x) -> x" in
|
let actual = normalize_calc_string "fun (x) -> x" in
|
||||||
match get_expr_tree_from_tokens tokens with
|
let expected = "fun x ->\nx" in
|
||||||
| Some e -> expr2str e = "fun x ->\nx"
|
actual = expected
|
||||||
| None -> false
|
|
||||||
|
|
215
lib/typecheck.ml
Normal file
215
lib/typecheck.ml
Normal file
|
@ -0,0 +1,215 @@
|
||||||
|
(* Types manipulated by the checker. *)
type type_v =
  (* Integer type; printed as "int" by [type_v2str]. *)
  | Int
  (* Function type taking [arg] and producing [ret]. *)
  | Fun of { arg : type_v; ret : type_v }
  (* Named type variable; printed with a leading quote by [type_v2str]. *)
  | Generic of string
  (* Neutral element of [intersect_type_v]: intersecting with it yields the
     other operand. *)
  | Universal
  (* Result of [intersect_type_v] when the operands are incompatible; the
     checker treats it as a type error. *)
  | Nothing
|
||||||
|
|
||||||
|
|
||||||
|
(* One level of the type environment.
   - [parent]: the enclosing scope, [None] for the top scope.
   - [bindings]: types of the names introduced at this level.
   - [generics_count]: counter used to name fresh generics; scopes built with
     [make_type_scope] share their parent's counter. *)
type type_scope = {
  parent : type_scope option;
  bindings : (string, type_v) Hashtbl.t;
  generics_count : int ref;
}
|
||||||
|
|
||||||
|
(* [make_type_scope parent] builds an empty scope nested inside [parent].
   The child reuses [parent]'s [generics_count] ref, so the counter is shared
   along the whole scope chain. *)
let make_type_scope (parent: type_scope): type_scope =
  let { generics_count; _ } = parent in
  let bindings = Hashtbl.create 10 in
  { parent = Some parent; bindings; generics_count }
|
||||||
|
|
||||||
|
(* [make_top_type_scope ()] builds a root scope: no parent, no bindings and a
   brand-new generic counter starting at 0. *)
let make_top_type_scope (): type_scope =
  let fresh_counter = ref 0 in
  let empty_bindings = Hashtbl.create 10 in
  { parent = None; bindings = empty_bindings; generics_count = fresh_counter }
|
||||||
|
|
||||||
|
(* [typetree2type_v t] converts a parsed type annotation into a [type_v].
   Only the identifier "int" and arrows built from supported types are
   accepted; any other identifier raises [Failure]. *)
let rec typetree2type_v (t: Parser.type_tree): type_v =
  match t with
  | Parser.TypeArrow (dom, cod) ->
    Fun { arg = typetree2type_v dom; ret = typetree2type_v cod }
  | Parser.TypeIdentifier "int" -> Int
  | Parser.TypeIdentifier _ ->
    failwith "not implemented (type alias is not supported yet)"
|
||||||
|
|
||||||
|
(* [type_v2str t] renders [t] as text: "int", "(arg -> ret)", "'name" for a
   generic, "universal" and "nothing". *)
let rec type_v2str (t: type_v): string =
  match t with
  | Nothing -> "nothing"
  | Universal -> "universal"
  | Int -> "int"
  | Generic name -> Printf.sprintf "'%s" name
  | Fun { arg; ret } ->
    let arg_str = type_v2str arg in
    let ret_str = type_v2str ret in
    "(" ^ arg_str ^ " -> " ^ ret_str ^ ")"
|
||||||
|
|
||||||
|
(* meet *)
|
||||||
|
(* [intersect_type_v a b] computes the meet of [a] and [b].

   - [Universal] is neutral: the other operand is returned.
   - [Int] meets [Int]; two function types meet component-wise.
   - A generic meets an identically named generic as itself; otherwise the
     generic gives way to the other operand, whichever side it is on.
   - Every remaining combination is incompatible and yields [Nothing].

   The generic rule used to apply only when the generic was the LEFT operand,
   so [intersect_type_v Int (Generic "0")] was [Nothing] while the swapped
   call was [Int]. The operation is now symmetric in that respect.

   TODO: generics are still resolved without recording any substitution. *)
let rec intersect_type_v (a: type_v) (b: type_v): type_v =
  match a, b with
  | Universal, _ -> b
  | _, Universal -> a
  | Int, Int -> Int
  | Fun { arg = arg1; ret = ret1 }, Fun { arg = arg2; ret = ret2 } ->
    (* NOTE(review): the original comment here said "contravariance", but the
       argument types are intersected exactly like the return types. Confirm
       the intended variance before relying on this for subtyping. *)
    let arg = intersect_type_v arg1 arg2 in
    let ret = intersect_type_v ret1 ret2 in
    Fun { arg; ret }
  | Generic s1, Generic s2 when s1 = s2 -> a
  | Generic _, _ -> b
  (* Mirror of the arm above: a generic on the right gives way as well. *)
  | _, Generic _ -> a
  | (Int | Fun _ | Nothing), (Int | Fun _ | Nothing) -> Nothing
|
||||||
|
(* join *)
|
||||||
|
(* and union_type_v (a: type_v) (b: type_v): type_v =
|
||||||
|
match a, b with
|
||||||
|
| Universal, _ -> Universal
|
||||||
|
| _, Universal -> Universal
|
||||||
|
| Int, Int -> Int
|
||||||
|
| Fun { arg = arg1; ret = ret1 }, Fun { arg = arg2; ret = ret2 } ->
|
||||||
|
|
||||||
|
| Generic s1, Generic s2 when s1 = s2 -> Generic s1
|
||||||
|
| Generic _, _ -> b
|
||||||
|
| _ -> Nothing *)
|
||||||
|
|
||||||
|
(* [find_type_v_opt scope name] looks [name] up in [scope] and then in each
   enclosing scope in turn. Returns the first binding found, or [None] when
   no scope in the chain binds [name]. *)
let find_type_v_opt (scope: type_scope) (name: string): type_v option =
  let rec lookup (current: type_scope) =
    match Hashtbl.find_opt current.bindings name with
    | Some _ as found -> found
    | None ->
      (match current.parent with
       | Some outer -> lookup outer
       | None -> None)
  in
  lookup scope
|
||||||
|
|
||||||
|
(* [assert_and_get_type_v scope name expected] finds the nearest binding of
   [name] (starting at [scope], then outwards), intersects its type with
   [expected], stores the intersection back into the scope that owns the
   binding and returns it.

   Raises [Failure "Unbound variable"] when no scope binds [name] and
   [Failure "Type error"] when the intersection is [Nothing]. *)
let assert_and_get_type_v (scope: type_scope) (name: string) (expected: type_v) =
  let rec narrow_in (current: type_scope) =
    match Hashtbl.find_opt current.bindings name with
    | None ->
      (match current.parent with
       | None -> failwith "Unbound variable"
       | Some outer -> narrow_in outer)
    | Some bound ->
      let narrowed = intersect_type_v bound expected in
      if narrowed = Nothing then failwith "Type error";
      (* Remember the refined type so later uses see it. *)
      Hashtbl.replace current.bindings name narrowed;
      narrowed
  in
  narrow_in scope
|
||||||
|
|
||||||
|
(* [gen_generic_free_name scope] returns the current value of the scope's
   generic counter as a decimal string and then increments the counter. *)
let gen_generic_free_name (scope: type_scope): string =
  let counter = scope.generics_count in
  let fresh = string_of_int !counter in
  incr counter;
  fresh
|
||||||
|
|
||||||
|
(* [replace_generic_with t from to_] returns [t] with every occurrence of the
   generic named [from] replaced by [to_]; other generics and all non-generic
   leaves are left untouched. *)
let replace_generic_with (t: type_v) (from: string) (to_: type_v): type_v =
  let rec subst = function
    | Generic s -> if s = from then to_ else Generic s
    | Fun { arg; ret } -> Fun { arg = subst arg; ret = subst ret }
    | (Int | Universal | Nothing) as leaf -> leaf
  in
  subst t
|
||||||
|
|
||||||
|
(* [typecheck_expr scope expr required_type] computes the type of [expr] in
   [scope] by dispatching on the expression form, then intersects that type
   with [required_type] and returns the intersection.

   Raises [Failure] when the intersection is [Nothing], and for unary
   operator expressions, which are not implemented. *)
let rec typecheck_expr (scope: type_scope) (expr: Parser.expr_tree) (required_type: type_v): type_v =
  let actual_type =
    match expr with
    | Parser.Number _ -> Int
    | Parser.Identifier name ->
      assert_and_get_type_v scope name required_type
    | Parser.LetExpr let_tree ->
      typecheck_let_expr scope required_type let_tree
    | Parser.FunExpr fun_tree ->
      typecheck_fun_expr scope required_type fun_tree
    | Parser.IfExpr (Parser.If (cond_expr, then_expr, else_expr)) ->
      typecheck_if_expr scope required_type cond_expr then_expr else_expr
    | Parser.BinOpExpr (op, lhs, rhs) ->
      typecheck_bin_op_expr scope required_type op lhs rhs
    | Parser.CallExpr (Parser.Call (callee, argument)) ->
      typecheck_call_expr scope required_type callee argument
    | Parser.MonoOpExpr _ ->
      failwith "Not implemented"
  in
  match intersect_type_v required_type actual_type with
  | Nothing ->
    let expected_str = type_v2str required_type in
    let actual_str = type_v2str actual_type in
    failwith (Printf.sprintf "Type error: expect %s but actual %s" expected_str actual_str)
  | narrowed -> narrowed
|
||||||
|
(* Checks [let name (: type)? = value_expr in in_expr].
   The bound value is checked against the declared type, or against
   [Universal] when there is no annotation. The body is then checked against
   [required_type] in a child scope where [name] has the value's type. *)
and typecheck_let_expr (scope: type_scope) (required_type: type_v)
    ({ name; value_expr; in_expr; type_declare }: Parser.let_expr_tree): type_v =
  let declared =
    match type_declare with
    | Some annotation -> typetree2type_v annotation
    | None -> Universal
  in
  let bound_type = typecheck_expr scope value_expr declared in
  let body_scope = make_type_scope scope in
  Hashtbl.add body_scope.bindings name bound_type;
  typecheck_expr body_scope in_expr required_type
|
||||||
|
(* Checks a function literal and returns its [Fun] type.
   The parameter gets its declared type, or a fresh generic when it has no
   annotation (a fresh name is drawn from the counter either way). The body
   is checked against [Universal] in a child scope. The required type passed
   by the caller is not consulted here.

   Also prints the resulting argument and return types to stdout. *)
and typecheck_fun_expr (scope: type_scope) (_required_type: type_v)
    ({ name = argname; body_expr; type_declare }: Parser.fun_expr_tree): type_v =
  let fresh_generic = Generic (gen_generic_free_name scope) in
  let declared_arg =
    match type_declare with
    | Some annotation -> typetree2type_v annotation
    | None -> fresh_generic
  in
  let body_scope = make_type_scope scope in
  Hashtbl.add body_scope.bindings argname declared_arg;
  let ret_type = typecheck_expr body_scope body_expr Universal in
  (* Re-read the parameter's binding: checking the body may have replaced it
     with a narrower type (see [assert_and_get_type_v]). *)
  let arg_type = Hashtbl.find body_scope.bindings argname in
  Printf.printf "arg: %s, ret: %s\n" (type_v2str arg_type) (type_v2str ret_type);
  Fun { arg = arg_type; ret = ret_type }
|
||||||
|
(* Checks [if cond then a else b]: the condition must check against [Int],
   both branches are checked against [required_type], and the result is the
   intersection of the two branch types. *)
and typecheck_if_expr (scope: type_scope) (required_type: type_v)
    (cond_expr: Parser.expr_tree) (then_expr: Parser.expr_tree) (else_expr: Parser.expr_tree): type_v =
  ignore (typecheck_expr scope cond_expr Int : type_v);
  let check_branch branch = typecheck_expr scope branch required_type in
  let then_type = check_branch then_expr in
  let else_type = check_branch else_expr in
  intersect_type_v then_type else_type
|
||||||
|
(* Checks a binary operation: whatever the operator, the left and then the
   right operand are checked against [Int] and the result is [Int]. *)
and typecheck_bin_op_expr (scope: type_scope) (_required_type: type_v)
    (_op: Parser.bin_op_type) (left_expr: Parser.expr_tree) (right_expr: Parser.expr_tree): type_v =
  let require_int operand = ignore (typecheck_expr scope operand Int : type_v) in
  List.iter require_int [left_expr; right_expr];
  Int
|
||||||
|
(* Checks an application [func_expr arg_expr].
   The callee is checked against [Universal] and must come out as a [Fun];
   anything else raises [Failure "Type error"]. The argument is checked
   against the parameter type. When the parameter type is itself a generic,
   that generic is replaced by the argument's type in the return type;
   otherwise the return type is used unchanged.

   Also prints intermediate types to stdout. *)
and typecheck_call_expr (scope: type_scope) (_required_type: type_v)
    (func_expr: Parser.expr_tree) (arg_expr: Parser.expr_tree): type_v =
  let func_type = typecheck_expr scope func_expr Universal in
  Printf.printf "func_type: %s\n" (type_v2str func_type);
  match func_type with
  | Int | Generic _ | Universal | Nothing -> failwith "Type error"
  | Fun { arg = param_type; ret = ret_type } ->
    let mono_arg_type = typecheck_expr scope arg_expr param_type in
    Printf.printf "arg_type: %s\n" (type_v2str mono_arg_type);
    (match param_type with
     | Int | Fun _ | Universal | Nothing -> ret_type
     | Generic generic_name ->
       (* Instantiate the generic with the concrete argument type. *)
       let instantiated = replace_generic_with ret_type generic_name mono_arg_type in
       Printf.printf "new_ret_type: %s\n" (type_v2str instantiated);
       instantiated)
|
||||||
|
|
||||||
|
(* [typecheck expr] type-checks [expr] in a fresh top-level scope with no
   constraint on the result ([Universal]). Raises [Failure] on type errors. *)
let typecheck (expr: Parser.expr_tree): type_v =
  let top_scope = make_top_type_scope () in
  typecheck_expr top_scope expr Universal
|
||||||
|
|
||||||
|
(* [typecheck_result expr] runs [typecheck] and wraps the outcome: [Ok] with
   the type on success, [Error] with whatever exception was raised. *)
let typecheck_result (expr: Parser.expr_tree): (type_v, exn) result =
  match typecheck expr with
  | checked -> Result.Ok checked
  | exception e -> Result.Error e
|
||||||
|
|
||||||
|
(* [test_typecheck content] lexes, parses and type-checks the source text
   [content]. Returns [Error (Failure "parse error")] when the parser yields
   no tree; otherwise the result of [typecheck_result]. *)
let test_typecheck (content:string) =
  match content |> Lexer.lex_tokens_seq |> Parser.get_expr_tree_from_tokens with
  | None -> Result.Error (Failure "parse error")
  | Some tree -> typecheck_result tree
|
||||||
|
|
||||||
|
(* Applying a let-bound identity function to a number must type as [Int].
   Prints the inferred type (or "error") before returning the verdict. *)
let%test "typecheck 1" =
  let source = "let x = fun y -> y in x 1" in
  match test_typecheck source with
  | Result.Error _ ->
    Printf.printf "error\n";
    false
  | Result.Ok inferred ->
    Printf.printf "%s\n" (type_v2str inferred);
    inferred = Int
|
Loading…
Add table
Reference in a new issue