all: match multi aggregate for union sum types (#6868)
parent
df4165c7ee
commit
e06756ef58
|
@ -3,6 +3,7 @@
|
|||
module checker
|
||||
|
||||
import os
|
||||
import strings
|
||||
import v.ast
|
||||
import v.vmod
|
||||
import v.table
|
||||
|
@ -3405,6 +3406,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol
|
|||
mut branch_exprs := map[string]int{}
|
||||
cond_type_sym := c.table.get_type_symbol(node.cond_type)
|
||||
for branch in node.branches {
|
||||
mut expr_types := []ast.Type{}
|
||||
for expr in branch.exprs {
|
||||
mut key := ''
|
||||
if expr is ast.RangeExpr {
|
||||
|
@ -3444,28 +3446,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol
|
|||
match expr {
|
||||
ast.Type {
|
||||
key = c.table.type_to_str(expr.typ)
|
||||
// smart cast only if one type is given (currently) // TODO make this work if types have same fields
|
||||
if branch.exprs.len == 1 && cond_type_sym.kind == .union_sum_type {
|
||||
mut scope := c.file.scope.innermost(branch.pos.pos)
|
||||
match node.cond as node_cond {
|
||||
ast.SelectorExpr { scope.register_struct_field(ast.ScopeStructField{
|
||||
struct_type: node_cond.expr_type
|
||||
name: node_cond.field_name
|
||||
typ: node.cond_type
|
||||
sum_type_cast: expr.typ
|
||||
pos: node_cond.pos
|
||||
}) }
|
||||
ast.Ident { scope.register(node.var_name, ast.Var{
|
||||
name: node.var_name
|
||||
typ: node.cond_type
|
||||
pos: node_cond.pos
|
||||
is_used: true
|
||||
is_mut: node.is_mut
|
||||
sum_type_cast: expr.typ
|
||||
}) }
|
||||
else {}
|
||||
}
|
||||
}
|
||||
expr_types << expr
|
||||
}
|
||||
ast.EnumVal {
|
||||
key = expr.val
|
||||
|
@ -3501,6 +3482,59 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol
|
|||
}
|
||||
branch_exprs[key] = val + 1
|
||||
}
|
||||
// when match is sum type matching, then register smart cast for every branch
|
||||
if expr_types.len > 0 {
|
||||
if cond_type_sym.kind == .union_sum_type {
|
||||
mut expr_type := table.Type(0)
|
||||
if expr_types.len > 1 {
|
||||
mut agg_name := strings.new_builder(20)
|
||||
agg_name.write('(')
|
||||
for i, expr in expr_types {
|
||||
if i > 0 {
|
||||
agg_name.write(' | ')
|
||||
}
|
||||
type_str := c.table.type_to_str(expr.typ)
|
||||
agg_name.write(if c.is_builtin_mod {
|
||||
type_str
|
||||
} else {
|
||||
'${c.mod}.$type_str'
|
||||
})
|
||||
}
|
||||
agg_name.write(')')
|
||||
name := agg_name.str()
|
||||
expr_type = c.table.register_type_symbol(table.TypeSymbol{
|
||||
name: name
|
||||
source_name: name
|
||||
kind: .aggregate
|
||||
mod: c.mod
|
||||
info: table.Aggregate{
|
||||
types: expr_types.map(it.typ)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
expr_type = expr_types[0].typ
|
||||
}
|
||||
mut scope := c.file.scope.innermost(branch.pos.pos)
|
||||
match node.cond as node_cond {
|
||||
ast.SelectorExpr { scope.register_struct_field(ast.ScopeStructField{
|
||||
struct_type: node_cond.expr_type
|
||||
name: node_cond.field_name
|
||||
typ: node.cond_type
|
||||
sum_type_cast: expr_type
|
||||
pos: node_cond.pos
|
||||
}) }
|
||||
ast.Ident { scope.register(node.var_name, ast.Var{
|
||||
name: node.var_name
|
||||
typ: node.cond_type
|
||||
pos: node_cond.pos
|
||||
is_used: true
|
||||
is_mut: node.is_mut
|
||||
sum_type_cast: expr_type
|
||||
}) }
|
||||
else {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// check that expressions are exhaustive
|
||||
// this is achieved either by putting an else
|
||||
|
|
|
@ -122,6 +122,11 @@ mut:
|
|||
// doing_autofree_tmp bool
|
||||
inside_lambda bool
|
||||
prevent_sum_type_unwrapping_once bool // needed for assign new values to sum type
|
||||
// used in match multi branch
|
||||
// TypeOne, TypeTwo {}
|
||||
// where an aggregate (at least two types) is generated
|
||||
// sum type deref needs to know which index to deref because unions take care of the correct field
|
||||
aggregate_type_idx int
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -2513,11 +2518,16 @@ fn (mut g Gen) expr(node ast.Expr) {
|
|||
if field := scope.find_struct_field(node.expr_type, node.field_name) {
|
||||
// union sum type deref
|
||||
g.write('(*')
|
||||
cast_sym := g.table.get_type_symbol(field.sum_type_cast)
|
||||
if cast_sym.info is table.Aggregate as sym_info {
|
||||
sum_type_deref_field = '_${sym_info.types[g.aggregate_type_idx]}'
|
||||
} else {
|
||||
sum_type_deref_field = '_$field.sum_type_cast'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
g.expr(node.expr)
|
||||
// struct embedding
|
||||
if sym.kind == .struct_ {
|
||||
|
@ -3003,6 +3013,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str
|
|||
mut sumtype_index := 0
|
||||
// iterates through all types in sumtype branches
|
||||
for {
|
||||
g.aggregate_type_idx = sumtype_index
|
||||
is_last := j == node.branches.len - 1
|
||||
sym := g.table.get_type_symbol(node.cond_type)
|
||||
if branch.is_else || (node.is_expr && is_last) {
|
||||
|
@ -3070,6 +3081,8 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str
|
|||
break
|
||||
}
|
||||
}
|
||||
// reset global field for next use
|
||||
g.aggregate_type_idx = 0
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3336,7 +3349,12 @@ fn (mut g Gen) ident(node ast.Ident) {
|
|||
if v := scope.find_var(node.name) {
|
||||
if v.sum_type_cast != 0 {
|
||||
if !prevent_sum_type_unwrapping_once {
|
||||
sym := g.table.get_type_symbol(v.sum_type_cast)
|
||||
if sym.info is table.Aggregate as sym_info {
|
||||
g.write('(*${name}._${sym_info.types[g.aggregate_type_idx]})')
|
||||
} else {
|
||||
g.write('(*${name}._$v.sum_type_cast)')
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -244,6 +244,7 @@ fn (mut p Parser) match_expr() ast.MatchExpr {
|
|||
}
|
||||
p.check(.comma)
|
||||
}
|
||||
if !is_union_match {
|
||||
mut it_typ := table.void_type
|
||||
if types.len == 1 {
|
||||
it_typ = types[0]
|
||||
|
@ -270,7 +271,6 @@ fn (mut p Parser) match_expr() ast.MatchExpr {
|
|||
}
|
||||
})
|
||||
}
|
||||
if !is_union_match {
|
||||
p.scope.register('it', ast.Var{
|
||||
name: 'it'
|
||||
typ: it_typ.to_ptr()
|
||||
|
|
|
@ -331,16 +331,6 @@ fn test_reassign_from_function_with_parameter_selector() {
|
|||
}
|
||||
}
|
||||
|
||||
fn test_match_multi_branch() {
|
||||
f := Expr3(CTempVarExpr{'ctemp'})
|
||||
match union f {
|
||||
CallExpr, CTempVarExpr {
|
||||
// this check works only if f is not castet
|
||||
assert f is CTempVarExpr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn test_typeof() {
|
||||
x := Expr3(CTempVarExpr{})
|
||||
assert typeof(x) == 'CTempVarExpr'
|
||||
|
@ -355,6 +345,25 @@ fn test_zero_value_init() {
|
|||
_ := Outer2{}
|
||||
}
|
||||
|
||||
struct Milk {
|
||||
name string
|
||||
}
|
||||
|
||||
struct Eggs {
|
||||
name string
|
||||
}
|
||||
|
||||
__type Food = Milk | Eggs
|
||||
|
||||
fn test_match_aggregate() {
|
||||
f := Food(Milk{'test'})
|
||||
match union f {
|
||||
Milk, Eggs {
|
||||
assert f.name == 'test'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn test_sum_type_match() {
|
||||
// TODO: Remove these casts
|
||||
assert is_gt_simple('3', int(2))
|
||||
|
|
Loading…
Reference in New Issue