diff --git a/vlib/v/checker/if.v b/vlib/v/checker/if.v index 5bbb3f2555..eb4c8e72ec 100644 --- a/vlib/v/checker/if.v +++ b/vlib/v/checker/if.v @@ -173,6 +173,13 @@ pub fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type { continue } last_expr.typ = c.expr(last_expr.expr) + if c.table.type_kind(c.expected_type) == .multi_return + && c.table.type_kind(last_expr.typ) == .multi_return { + if node.typ == ast.void_type { + node.is_expr = true + node.typ = c.expected_type + } + } if !c.check_types(last_expr.typ, node.typ) { if node.typ == ast.void_type { // first branch of if expression @@ -210,6 +217,15 @@ pub fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type { if is_noreturn_callexpr(last_expr.expr) { continue } + node_sym := c.table.sym(node.typ) + last_sym := c.table.sym(last_expr.typ) + if node_sym.kind == .multi_return && last_sym.kind == .multi_return { + node_types := node_sym.mr_info().types + last_types := last_sym.mr_info().types.map(ast.mktyp(it)) + if node_types == last_types { + continue + } + } c.error('mismatched types `${c.table.type_to_str(node.typ)}` and `${c.table.type_to_str(last_expr.typ)}`', node.pos) diff --git a/vlib/v/checker/match.v b/vlib/v/checker/match.v index 019be624ec..da305a00fc 100644 --- a/vlib/v/checker/match.v +++ b/vlib/v/checker/match.v @@ -72,7 +72,7 @@ pub fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { expr_type := c.expr(stmt.expr) if first_iteration { if node.is_expr && (node.expected_type.has_flag(.optional) - || c.table.type_kind(node.expected_type) == .sum_type) { + || c.table.type_kind(node.expected_type) in [.sum_type, .multi_return]) { ret_type = node.expected_type } else { ret_type = expr_type @@ -86,6 +86,14 @@ pub fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { && (ret_type.has_flag(.generic) || c.table.is_sumtype_or_in_variant(ret_type, expr_type))) && !is_noreturn { + expr_sym := c.table.sym(expr_type) + if expr_sym.kind == .multi_return && ret_sym.kind == .multi_return { + ret_types := ret_sym.mr_info().types + expr_types := expr_sym.mr_info().types.map(ast.mktyp(it)) + if expr_types == ret_types { + continue + } + } c.error('return type mismatch, it should be `$ret_sym.name`', stmt.expr.pos()) } diff --git a/vlib/v/tests/fn_multi_return_test.v b/vlib/v/tests/fn_multi_return_test.v new file mode 100644 index 0000000000..d47286f00d --- /dev/null +++ b/vlib/v/tests/fn_multi_return_test.v @@ -0,0 +1,49 @@ +fn multret1(i int, j int) (int, int) { + return if i > j { i, 10 } else { 10, j } +} + +fn multret2(i int, j int) (int, int) { + return match i > j { + true { i, 10 } + false { 10, j } + } +} + +fn multret3(i int, j int) (int, int) { + if i > j { + return i, 10 + } else { + return 10, j + } +} + +fn multret4(i int, j int) (int, int) { + match i > j { + true { return i, 10 } + false { return 10, j } + } +} + +fn test_fn_multi_return() { + mut a, mut b := 0, 0 + + println(multret1(3, 14)) + a, b = multret1(3, 14) + assert a == 10 + assert b == 14 + + println(multret2(3, 14)) + a, b = multret2(3, 14) + assert a == 10 + assert b == 14 + + println(multret3(3, 14)) + a, b = multret3(3, 14) + assert a == 10 + assert b == 14 + + println(multret4(3, 14)) + a, b = multret4(3, 14) + assert a == 10 + assert b == 14 +}