From 7c0cd2f41d2ec299334444acc3c3b46d632c2cd2 Mon Sep 17 00:00:00 2001 From: yuyi Date: Thu, 27 May 2021 15:13:50 +0800 Subject: [PATCH] checker, cgen: fix match with complex sumtype exprs (#10215) --- vlib/v/checker/checker.v | 19 ++- vlib/v/gen/c/cgen.v | 12 +- .../match_with_complex_sumtype_exprs_test.v | 150 ++++++++++++++++++ 3 files changed, 176 insertions(+), 5 deletions(-) create mode 100644 vlib/v/tests/match_with_complex_sumtype_exprs_test.v diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index b80cd9588c..31d27ce694 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -5358,13 +5358,20 @@ pub fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { } expr_type := c.expr(stmt.expr) if ret_type == ast.void_type { - ret_type = expr_type - stmt.typ = ret_type + if node.is_expr + && c.table.get_type_symbol(node.expected_type).kind == .sum_type { + ret_type = node.expected_type + } else { + ret_type = expr_type + } + stmt.typ = expr_type } else if node.is_expr && ret_type != expr_type { if !c.check_types(ret_type, expr_type) { ret_sym := c.table.get_type_symbol(ret_type) - c.error('return type mismatch, it should be `$ret_sym.name`', - stmt.expr.position()) + if !(node.is_expr && ret_sym.kind == .sum_type) { + c.error('return type mismatch, it should be `$ret_sym.name`', + stmt.expr.position()) + } } } } @@ -5995,6 +6002,10 @@ pub fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type { } } } + if node.is_expr + && c.table.get_type_symbol(former_expected_type).kind == .sum_type { + continue + } c.error('mismatched types `${c.table.type_to_str(node.typ)}` and `${c.table.type_to_str(last_expr.typ)}`', node.pos) } diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index 0701994497..6e242064a1 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -173,6 +173,7 @@ mut: as_cast_type_names map[string]string // table for type name lookup in runtime (for __as_cast) obf_table map[string]string // main_fn_decl_node ast.FnDecl + expected_cast_type ast.Type // for match expr of sumtypes } pub fn gen(files []&ast.File, table &ast.Table, pref &pref.Preferences) string { @@ -1223,7 +1224,11 @@ fn (mut g Gen) stmt(node ast.Stmt) { // } old_is_void_expr_stmt := g.is_void_expr_stmt g.is_void_expr_stmt = !node.is_expr - g.expr(node.expr) + if node.typ != ast.void_type && g.expected_cast_type != 0 { + g.expr_with_cast(node.expr, node.typ, g.expected_cast_type) + } else { + g.expr(node.expr) + } g.is_void_expr_stmt = old_is_void_expr_stmt // if af { // g.autofree_call_postgen() @@ -4131,7 +4136,12 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str g.writeln(') {') } } + if is_expr && tmp_var.len > 0 + && g.table.get_type_symbol(node.return_type).kind == .sum_type { + g.expected_cast_type = node.return_type + } g.stmts_with_tmp_var(branch.stmts, tmp_var) + g.expected_cast_type = 0 if g.inside_ternary == 0 { g.write('}') } diff --git a/vlib/v/tests/match_with_complex_sumtype_exprs_test.v b/vlib/v/tests/match_with_complex_sumtype_exprs_test.v new file mode 100644 index 0000000000..8cbd46f45c --- /dev/null +++ b/vlib/v/tests/match_with_complex_sumtype_exprs_test.v @@ -0,0 +1,150 @@ +struct Empty {} + +struct Node { + value f64 + left Tree + right Tree +} + +type Tree = Empty | Node + +// return size(number of nodes) of BST +fn size(tree Tree) int { + return match tree { + // TODO: remove int() once match gets smarter + Empty { int(0) } + Node { 1 + size(tree.left) + size(tree.right) } + } +} + +// insert a value to BST +fn insert(tree Tree, x f64) Tree { + return match tree { + Empty { + Node{x, tree, tree} + } + Node { + if x == tree.value { + tree + } else if x < tree.value { + Node{ + ...tree + left: insert(tree.left, x) + } + } else { + Node{ + ...tree + right: insert(tree.right, x) + } + } + } + } +} + +// whether able to find a value in BST +fn search(tree Tree, x f64) bool { + return match tree { + Empty { + false + } + Node { + if x == tree.value { + true + } else if x < tree.value { + search(tree.left, x) + } else { + search(tree.right, x) + } + } + } +} + +// find the minimal value of a BST +fn min(tree Tree) f64 { + return match tree { + Empty { + 1e100 + } + Node { + if tree.value < min(tree.left) { + tree.value + } else { + min(tree.left) + } + } + } +} + +// delete a value in BST (if nonexistant do nothing) +fn delete(tree Tree, x f64) Tree { + return match tree { + Empty { + tree + } + Node { + if tree.left is Node && tree.right is Node { + if x < tree.value { + Node{ + ...tree + left: delete(tree.left, x) + } + } else if x > tree.value { + Node{ + ...tree + right: delete(tree.right, x) + } + } else { + Node{ + ...tree + value: min(tree.right) + right: delete(tree.right, min(tree.right)) + } + } + } else if tree.left is Node { + if x == tree.value { + tree.left + } else { + Node{ + ...tree + left: delete(tree.left, x) + } + } + } else { + if x == tree.value { + tree.right + } else { + Node{ + ...tree + right: delete(tree.right, x) + } + } + } + } + } +} + +fn test_match_with_complex_sumtype_exprs() { + mut tree := Tree(Empty{}) + input := [0.3, 0.2, 0.5, 0.0, 0.6, 0.8, 0.9, 1.0, 0.1, 0.4, 0.7] + for i in input { + tree = insert(tree, i) + } + print('[1] after insertion tree size is ') // 11 + println(size(tree)) + del := [-0.3, 0.0, 0.3, 0.6, 1.0, 1.5] + for i in del { + tree = delete(tree, i) + } + print('[2] after deletion tree size is ') // 7 + print(size(tree)) + print(', and these elements were deleted: ') // 0.0 0.3 0.6 1.0 + assert size(tree) == 7 + for i in input { + if !search(tree, i) { + print(i) + print(' ') + } + } + println('') + assert true +}