all: match multi aggregate for union sum types (#6868)

pull/6865/head^2
Daniel Däschle 2020-11-18 20:52:00 +01:00 committed by GitHub
parent df4165c7ee
commit e06756ef58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 121 additions and 60 deletions

View File

@ -3,6 +3,7 @@
module checker
import os
import strings
import v.ast
import v.vmod
import v.table
@ -3405,6 +3406,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol
mut branch_exprs := map[string]int{}
cond_type_sym := c.table.get_type_symbol(node.cond_type)
for branch in node.branches {
mut expr_types := []ast.Type{}
for expr in branch.exprs {
mut key := ''
if expr is ast.RangeExpr {
@ -3444,28 +3446,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol
match expr {
ast.Type {
key = c.table.type_to_str(expr.typ)
// smart cast only if one type is given (currently) // TODO make this work if types have same fields
if branch.exprs.len == 1 && cond_type_sym.kind == .union_sum_type {
mut scope := c.file.scope.innermost(branch.pos.pos)
match node.cond as node_cond {
ast.SelectorExpr { scope.register_struct_field(ast.ScopeStructField{
struct_type: node_cond.expr_type
name: node_cond.field_name
typ: node.cond_type
sum_type_cast: expr.typ
pos: node_cond.pos
}) }
ast.Ident { scope.register(node.var_name, ast.Var{
name: node.var_name
typ: node.cond_type
pos: node_cond.pos
is_used: true
is_mut: node.is_mut
sum_type_cast: expr.typ
}) }
else {}
}
}
expr_types << expr
}
ast.EnumVal {
key = expr.val
@ -3501,6 +3482,59 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol
}
branch_exprs[key] = val + 1
}
// when match is sum type matching, then register smart cast for every branch
if expr_types.len > 0 {
if cond_type_sym.kind == .union_sum_type {
mut expr_type := table.Type(0)
if expr_types.len > 1 {
mut agg_name := strings.new_builder(20)
agg_name.write('(')
for i, expr in expr_types {
if i > 0 {
agg_name.write(' | ')
}
type_str := c.table.type_to_str(expr.typ)
agg_name.write(if c.is_builtin_mod {
type_str
} else {
'${c.mod}.$type_str'
})
}
agg_name.write(')')
name := agg_name.str()
expr_type = c.table.register_type_symbol(table.TypeSymbol{
name: name
source_name: name
kind: .aggregate
mod: c.mod
info: table.Aggregate{
types: expr_types.map(it.typ)
}
})
} else {
expr_type = expr_types[0].typ
}
mut scope := c.file.scope.innermost(branch.pos.pos)
match node.cond as node_cond {
ast.SelectorExpr { scope.register_struct_field(ast.ScopeStructField{
struct_type: node_cond.expr_type
name: node_cond.field_name
typ: node.cond_type
sum_type_cast: expr_type
pos: node_cond.pos
}) }
ast.Ident { scope.register(node.var_name, ast.Var{
name: node.var_name
typ: node.cond_type
pos: node_cond.pos
is_used: true
is_mut: node.is_mut
sum_type_cast: expr_type
}) }
else {}
}
}
}
}
// check that expressions are exhaustive
// this is achieved either by putting an else

View File

@ -122,6 +122,11 @@ mut:
// doing_autofree_tmp bool
inside_lambda bool
prevent_sum_type_unwrapping_once bool // needed for assign new values to sum type
// used in match multi branch
// TypeOne, TypeTwo {}
// where an aggregate (at least two types) is generated
// sum type deref needs to know which index to deref because unions take care of the correct field
aggregate_type_idx int
}
const (
@ -2513,11 +2518,16 @@ fn (mut g Gen) expr(node ast.Expr) {
if field := scope.find_struct_field(node.expr_type, node.field_name) {
// union sum type deref
g.write('(*')
cast_sym := g.table.get_type_symbol(field.sum_type_cast)
if cast_sym.info is table.Aggregate as sym_info {
sum_type_deref_field = '_${sym_info.types[g.aggregate_type_idx]}'
} else {
sum_type_deref_field = '_$field.sum_type_cast'
}
}
}
}
}
g.expr(node.expr)
// struct embedding
if sym.kind == .struct_ {
@ -3003,6 +3013,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str
mut sumtype_index := 0
// iterates through all types in sumtype branches
for {
g.aggregate_type_idx = sumtype_index
is_last := j == node.branches.len - 1
sym := g.table.get_type_symbol(node.cond_type)
if branch.is_else || (node.is_expr && is_last) {
@ -3070,6 +3081,8 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str
break
}
}
// reset global field for next use
g.aggregate_type_idx = 0
}
}
@ -3336,7 +3349,12 @@ fn (mut g Gen) ident(node ast.Ident) {
if v := scope.find_var(node.name) {
if v.sum_type_cast != 0 {
if !prevent_sum_type_unwrapping_once {
sym := g.table.get_type_symbol(v.sum_type_cast)
if sym.info is table.Aggregate as sym_info {
g.write('(*${name}._${sym_info.types[g.aggregate_type_idx]})')
} else {
g.write('(*${name}._$v.sum_type_cast)')
}
return
}
}

View File

@ -244,6 +244,7 @@ fn (mut p Parser) match_expr() ast.MatchExpr {
}
p.check(.comma)
}
if !is_union_match {
mut it_typ := table.void_type
if types.len == 1 {
it_typ = types[0]
@ -270,7 +271,6 @@ fn (mut p Parser) match_expr() ast.MatchExpr {
}
})
}
if !is_union_match {
p.scope.register('it', ast.Var{
name: 'it'
typ: it_typ.to_ptr()

View File

@ -331,16 +331,6 @@ fn test_reassign_from_function_with_parameter_selector() {
}
}
fn test_match_multi_branch() {
f := Expr3(CTempVarExpr{'ctemp'})
match union f {
CallExpr, CTempVarExpr {
// this check works only if f is not castet
assert f is CTempVarExpr
}
}
}
fn test_typeof() {
x := Expr3(CTempVarExpr{})
assert typeof(x) == 'CTempVarExpr'
@ -355,6 +345,25 @@ fn test_zero_value_init() {
_ := Outer2{}
}
struct Milk {
name string
}
struct Eggs {
name string
}
__type Food = Milk | Eggs
fn test_match_aggregate() {
f := Food(Milk{'test'})
match union f {
Milk, Eggs {
assert f.name == 'test'
}
}
}
fn test_sum_type_match() {
// TODO: Remove these casts
assert is_gt_simple('3', int(2))