diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index e52312e326..64e3a5b1e2 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -3488,6 +3488,30 @@ fn (mut g Gen) lock_expr(node ast.LockExpr) { } } +fn (mut g Gen) need_tmp_var_in_match(node ast.MatchExpr) bool { + if node.is_expr && node.return_type != table.void_type && node.return_type != 0 { + sym := g.table.get_type_symbol(node.return_type) + if sym.kind == .multi_return { + return false + } + for branch in node.branches { + if branch.stmts.len > 1 { + return true + } + if branch.stmts.len == 1 { + if branch.stmts[0] is ast.ExprStmt { + stmt := branch.stmts[0] as ast.ExprStmt + if stmt.expr is ast.CallExpr || stmt.expr is ast.IfExpr + || stmt.expr is ast.MatchExpr { + return true + } + } + } + } + } + return false +} + fn (mut g Gen) match_expr(node ast.MatchExpr) { // println('match expr typ=$it.expr_type') // TODO @@ -3495,11 +3519,13 @@ fn (mut g Gen) match_expr(node ast.MatchExpr) { g.writeln('// match 0') return } + need_tmp_var := g.need_tmp_var_in_match(node) is_expr := (node.is_expr && node.return_type != table.void_type) || g.inside_ternary > 0 mut cond_var := '' - if is_expr { + mut tmp_var := '' + mut cur_line := '' + if is_expr && !need_tmp_var { g.inside_ternary++ - // g.write('/* EM ret type=${g.typ(node.return_type)} expected_type=${g.typ(node.expected_type)} */') } if node.cond is ast.Ident || node.cond is ast.SelectorExpr || node.cond is ast.IntegerLiteral || node.cond is ast.StringLiteral || node.cond is ast.FloatLiteral { @@ -3509,7 +3535,7 @@ fn (mut g Gen) match_expr(node ast.MatchExpr) { g.out.go_back(cond_var.len) cond_var = cond_var.trim_space() } else { - cur_line := if is_expr { + line := if is_expr { g.empty_line = true g.go_before_stmt(0) } else { @@ -3519,25 +3545,35 @@ fn (mut g Gen) match_expr(node ast.MatchExpr) { g.write('${g.typ(node.cond_type)} $cond_var = ') g.expr(node.cond) g.writeln('; ') - g.write(cur_line) + g.write(line) + } + if need_tmp_var { + g.empty_line = true + cur_line = g.go_before_stmt(0) + tmp_var = g.new_tmp_var() + g.writeln('\t${g.typ(node.return_type)} $tmp_var;') } - if is_expr { + if is_expr && !need_tmp_var { // brackets needed otherwise '?' will apply to everything on the left g.write('(') } if node.is_sum_type { - g.match_expr_sumtype(node, is_expr, cond_var) + g.match_expr_sumtype(node, is_expr, cond_var, tmp_var) } else { - g.match_expr_classic(node, is_expr, cond_var) + g.match_expr_classic(node, is_expr, cond_var, tmp_var) } - if is_expr { + g.write(cur_line) + if need_tmp_var { + g.write('$tmp_var') + } + if is_expr && !need_tmp_var { g.write(')') g.decrement_inside_ternary() } } -fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var string) { +fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var string, tmp_var string) { for j, branch in node.branches { mut sumtype_index := 0 // iterates through all types in sumtype branches @@ -3545,8 +3581,8 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str g.aggregate_type_idx = sumtype_index is_last := j == node.branches.len - 1 sym := g.table.get_type_symbol(node.cond_type) - if branch.is_else || (node.is_expr && is_last) { - if is_expr { + if branch.is_else || (node.is_expr && is_last && tmp_var.len == 0) { + if is_expr && tmp_var.len == 0 { // TODO too many branches. maybe separate ?: matches g.write(' : ') } else { @@ -3556,7 +3592,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str } } else { if j > 0 || sumtype_index > 0 { - if is_expr { + if is_expr && tmp_var.len == 0 { g.write(' : ') } else { g.writeln('') @@ -3564,7 +3600,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str g.write('else ') } } - if is_expr { + if is_expr && tmp_var.len == 0 { g.write('(') } else { if j == 0 && sumtype_index == 0 { @@ -3584,13 +3620,13 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str g.write('._interface_idx == ') } g.expr(branch.exprs[sumtype_index]) - if is_expr { + if is_expr && tmp_var.len == 0 { g.write(') ? ') } else { g.writeln(') {') } } - g.stmts(branch.stmts) + g.stmts_with_tmp_var(branch.stmts, tmp_var) if g.inside_ternary == 0 { g.write('}') } @@ -3604,13 +3640,13 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str } } -fn (mut g Gen) match_expr_classic(node ast.MatchExpr, is_expr bool, cond_var string) { +fn (mut g Gen) match_expr_classic(node ast.MatchExpr, is_expr bool, cond_var string, tmp_var string) { type_sym := g.table.get_type_symbol(node.cond_type) for j, branch in node.branches { is_last := j == node.branches.len - 1 - if branch.is_else || (node.is_expr && is_last) { + if branch.is_else || (node.is_expr && is_last && tmp_var.len == 0) { if node.branches.len > 1 { - if is_expr { + if is_expr && tmp_var.len == 0 { // TODO too many branches. maybe separate ?: matches g.write(' : ') } else { @@ -3621,7 +3657,7 @@ fn (mut g Gen) match_expr_classic(node ast.MatchExpr, is_expr bool, cond_var str } } else { if j > 0 { - if is_expr { + if is_expr && tmp_var.len == 0 { g.write(' : ') } else { g.writeln('') @@ -3629,7 +3665,7 @@ fn (mut g Gen) match_expr_classic(node ast.MatchExpr, is_expr bool, cond_var str g.write('else ') } } - if is_expr { + if is_expr && tmp_var.len == 0 { g.write('(') } else { if j == 0 { @@ -3678,13 +3714,13 @@ fn (mut g Gen) match_expr_classic(node ast.MatchExpr, is_expr bool, cond_var str g.expr(expr) } } - if is_expr { + if is_expr && tmp_var.len == 0 { g.write(') ? ') } else { g.writeln(') {') } } - g.stmts(branch.stmts) + g.stmts_with_tmp_var(branch.stmts, tmp_var) if g.inside_ternary == 0 && node.branches.len > 1 { g.write('}') } diff --git a/vlib/v/tests/match_with_complex_exprs_in_branches_test.v b/vlib/v/tests/match_with_complex_exprs_in_branches_test.v new file mode 100644 index 0000000000..037dd31529 --- /dev/null +++ b/vlib/v/tests/match_with_complex_exprs_in_branches_test.v @@ -0,0 +1,31 @@ +type Arr = []int | []string + +fn test_match_with_array_map_in_branches() { + arr := Arr([0, 1]) + ret := match arr { + []int { + arr.map(fn(s int) string { return s.str() }).str() + } + else { + '' + } + } + println(ret) + assert ret == "['0', '1']" +} + +fn test_match_expr_of_multi_expr_stmts() { + a := 1 + ret := match a { + 1 { + mut m := map[string]int{} + m['two'] = 2 + m['two'] + } + else { + int(0) + } + } + println(ret) + assert ret == 2 +}