From e06756ef58c0b15e7b2a216a43a395f46d6cb4e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20D=C3=A4schle?= Date: Wed, 18 Nov 2020 20:52:00 +0100 Subject: [PATCH] all: match multi aggregate for union sum types (#6868) --- vlib/v/checker/checker.v | 78 +++++++++++++++++++++--------- vlib/v/gen/cgen.v | 22 ++++++++- vlib/v/parser/if_match.v | 52 ++++++++++---------- vlib/v/tests/union_sum_type_test.v | 29 +++++++---- 4 files changed, 121 insertions(+), 60 deletions(-) diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index d661e8ba0f..e24cea9e72 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -3,6 +3,7 @@ module checker import os +import strings import v.ast import v.vmod import v.table @@ -3405,6 +3406,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol mut branch_exprs := map[string]int{} cond_type_sym := c.table.get_type_symbol(node.cond_type) for branch in node.branches { + mut expr_types := []ast.Type{} for expr in branch.exprs { mut key := '' if expr is ast.RangeExpr { @@ -3444,28 +3446,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol match expr { ast.Type { key = c.table.type_to_str(expr.typ) - // smart cast only if one type is given (currently) // TODO make this work if types have same fields - if branch.exprs.len == 1 && cond_type_sym.kind == .union_sum_type { - mut scope := c.file.scope.innermost(branch.pos.pos) - match node.cond as node_cond { - ast.SelectorExpr { scope.register_struct_field(ast.ScopeStructField{ - struct_type: node_cond.expr_type - name: node_cond.field_name - typ: node.cond_type - sum_type_cast: expr.typ - pos: node_cond.pos - }) } - ast.Ident { scope.register(node.var_name, ast.Var{ - name: node.var_name - typ: node.cond_type - pos: node_cond.pos - is_used: true - is_mut: node.is_mut - sum_type_cast: expr.typ - }) } - else {} - } - } + expr_types << expr } ast.EnumVal { key = expr.val @@ -3501,6 +3482,59 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol } branch_exprs[key] = val + 1 } + // when match is sum type matching, then register smart cast for every branch + if expr_types.len > 0 { + if cond_type_sym.kind == .union_sum_type { + mut expr_type := table.Type(0) + if expr_types.len > 1 { + mut agg_name := strings.new_builder(20) + agg_name.write('(') + for i, expr in expr_types { + if i > 0 { + agg_name.write(' | ') + } + type_str := c.table.type_to_str(expr.typ) + agg_name.write(if c.is_builtin_mod { + type_str + } else { + '${c.mod}.$type_str' + }) + } + agg_name.write(')') + name := agg_name.str() + expr_type = c.table.register_type_symbol(table.TypeSymbol{ + name: name + source_name: name + kind: .aggregate + mod: c.mod + info: table.Aggregate{ + types: expr_types.map(it.typ) + } + }) + } else { + expr_type = expr_types[0].typ + } + mut scope := c.file.scope.innermost(branch.pos.pos) + match node.cond as node_cond { + ast.SelectorExpr { scope.register_struct_field(ast.ScopeStructField{ + struct_type: node_cond.expr_type + name: node_cond.field_name + typ: node.cond_type + sum_type_cast: expr_type + pos: node_cond.pos + }) } + ast.Ident { scope.register(node.var_name, ast.Var{ + name: node.var_name + typ: node.cond_type + pos: node_cond.pos + is_used: true + is_mut: node.is_mut + sum_type_cast: expr_type + }) } + else {} + } + } + } } // check that expressions are exhaustive // this is achieved either by putting an else diff --git a/vlib/v/gen/cgen.v b/vlib/v/gen/cgen.v index 2db29903f5..f00a70aa23 100644 --- a/vlib/v/gen/cgen.v +++ b/vlib/v/gen/cgen.v @@ -122,6 +122,11 @@ mut: // doing_autofree_tmp bool inside_lambda bool prevent_sum_type_unwrapping_once bool // needed for assign new values to sum type + // used in match multi branch + // TypeOne, TypeTwo {} + // where an aggregate (at least two types) is generated + // sum type deref needs to know which index to deref because unions take care of the correct field + aggregate_type_idx int } const ( @@ -2513,7 +2518,12 @@ fn (mut g Gen) expr(node ast.Expr) { if field := scope.find_struct_field(node.expr_type, node.field_name) { // union sum type deref g.write('(*') - sum_type_deref_field = '_$field.sum_type_cast' + cast_sym := g.table.get_type_symbol(field.sum_type_cast) + if cast_sym.info is table.Aggregate as sym_info { + sum_type_deref_field = '_${sym_info.types[g.aggregate_type_idx]}' + } else { + sum_type_deref_field = '_$field.sum_type_cast' + } } } } @@ -3003,6 +3013,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str mut sumtype_index := 0 // iterates through all types in sumtype branches for { + g.aggregate_type_idx = sumtype_index is_last := j == node.branches.len - 1 sym := g.table.get_type_symbol(node.cond_type) if branch.is_else || (node.is_expr && is_last) { @@ -3070,6 +3081,8 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str break } } + // reset global field for next use + g.aggregate_type_idx = 0 } } @@ -3336,7 +3349,12 @@ fn (mut g Gen) ident(node ast.Ident) { if v := scope.find_var(node.name) { if v.sum_type_cast != 0 { if !prevent_sum_type_unwrapping_once { - g.write('(*${name}._$v.sum_type_cast)') + sym := g.table.get_type_symbol(v.sum_type_cast) + if sym.info is table.Aggregate as sym_info { + g.write('(*${name}._${sym_info.types[g.aggregate_type_idx]})') + } else { + g.write('(*${name}._$v.sum_type_cast)') + } return } } diff --git a/vlib/v/parser/if_match.v b/vlib/v/parser/if_match.v index 0643beb86f..c4e7cc4db7 100644 --- a/vlib/v/parser/if_match.v +++ b/vlib/v/parser/if_match.v @@ -244,33 +244,33 @@ fn (mut p Parser) match_expr() ast.MatchExpr { } p.check(.comma) } - mut it_typ := table.void_type - if types.len == 1 { - it_typ = types[0] - } else { - // there is more than one types, so we must create a type aggregate - mut agg_name := strings.new_builder(20) - agg_name.write('(') - for i, typ in types { - if i > 0 { - agg_name.write(' | ') - } - type_str := p.table.type_to_str(typ) - agg_name.write(p.prepend_mod(type_str)) - } - agg_name.write(')') - name := agg_name.str() - it_typ = p.table.register_type_symbol(table.TypeSymbol{ - name: name - source_name: name - kind: .aggregate - mod: p.mod - info: table.Aggregate{ - types: types - } - }) - } if !is_union_match { + mut it_typ := table.void_type + if types.len == 1 { + it_typ = types[0] + } else { + // there is more than one types, so we must create a type aggregate + mut agg_name := strings.new_builder(20) + agg_name.write('(') + for i, typ in types { + if i > 0 { + agg_name.write(' | ') + } + type_str := p.table.type_to_str(typ) + agg_name.write(p.prepend_mod(type_str)) + } + agg_name.write(')') + name := agg_name.str() + it_typ = p.table.register_type_symbol(table.TypeSymbol{ + name: name + source_name: name + kind: .aggregate + mod: p.mod + info: table.Aggregate{ + types: types + } + }) + } p.scope.register('it', ast.Var{ name: 'it' typ: it_typ.to_ptr() diff --git a/vlib/v/tests/union_sum_type_test.v b/vlib/v/tests/union_sum_type_test.v index 7335dfafea..73f73f26ec 100644 --- a/vlib/v/tests/union_sum_type_test.v +++ b/vlib/v/tests/union_sum_type_test.v @@ -331,16 +331,6 @@ fn test_reassign_from_function_with_parameter_selector() { } } -fn test_match_multi_branch() { - f := Expr3(CTempVarExpr{'ctemp'}) - match union f { - CallExpr, CTempVarExpr { - // this check works only if f is not castet - assert f is CTempVarExpr - } - } -} - fn test_typeof() { x := Expr3(CTempVarExpr{}) assert typeof(x) == 'CTempVarExpr' @@ -355,6 +345,25 @@ fn test_zero_value_init() { _ := Outer2{} } +struct Milk { + name string +} + +struct Eggs { + name string +} + +__type Food = Milk | Eggs + +fn test_match_aggregate() { + f := Food(Milk{'test'}) + match union f { + Milk, Eggs { + assert f.name == 'test' + } + } +} + fn test_sum_type_match() { // TODO: Remove these casts assert is_gt_simple('3', int(2))