all: match multi aggregate for union sum types (#6868)
parent
df4165c7ee
commit
e06756ef58
|
@ -3,6 +3,7 @@
|
||||||
module checker
|
module checker
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import strings
|
||||||
import v.ast
|
import v.ast
|
||||||
import v.vmod
|
import v.vmod
|
||||||
import v.table
|
import v.table
|
||||||
|
@ -3405,6 +3406,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol
|
||||||
mut branch_exprs := map[string]int{}
|
mut branch_exprs := map[string]int{}
|
||||||
cond_type_sym := c.table.get_type_symbol(node.cond_type)
|
cond_type_sym := c.table.get_type_symbol(node.cond_type)
|
||||||
for branch in node.branches {
|
for branch in node.branches {
|
||||||
|
mut expr_types := []ast.Type{}
|
||||||
for expr in branch.exprs {
|
for expr in branch.exprs {
|
||||||
mut key := ''
|
mut key := ''
|
||||||
if expr is ast.RangeExpr {
|
if expr is ast.RangeExpr {
|
||||||
|
@ -3444,28 +3446,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol
|
||||||
match expr {
|
match expr {
|
||||||
ast.Type {
|
ast.Type {
|
||||||
key = c.table.type_to_str(expr.typ)
|
key = c.table.type_to_str(expr.typ)
|
||||||
// smart cast only if one type is given (currently) // TODO make this work if types have same fields
|
expr_types << expr
|
||||||
if branch.exprs.len == 1 && cond_type_sym.kind == .union_sum_type {
|
|
||||||
mut scope := c.file.scope.innermost(branch.pos.pos)
|
|
||||||
match node.cond as node_cond {
|
|
||||||
ast.SelectorExpr { scope.register_struct_field(ast.ScopeStructField{
|
|
||||||
struct_type: node_cond.expr_type
|
|
||||||
name: node_cond.field_name
|
|
||||||
typ: node.cond_type
|
|
||||||
sum_type_cast: expr.typ
|
|
||||||
pos: node_cond.pos
|
|
||||||
}) }
|
|
||||||
ast.Ident { scope.register(node.var_name, ast.Var{
|
|
||||||
name: node.var_name
|
|
||||||
typ: node.cond_type
|
|
||||||
pos: node_cond.pos
|
|
||||||
is_used: true
|
|
||||||
is_mut: node.is_mut
|
|
||||||
sum_type_cast: expr.typ
|
|
||||||
}) }
|
|
||||||
else {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
ast.EnumVal {
|
ast.EnumVal {
|
||||||
key = expr.val
|
key = expr.val
|
||||||
|
@ -3501,6 +3482,59 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol
|
||||||
}
|
}
|
||||||
branch_exprs[key] = val + 1
|
branch_exprs[key] = val + 1
|
||||||
}
|
}
|
||||||
|
// when match is sum type matching, then register smart cast for every branch
|
||||||
|
if expr_types.len > 0 {
|
||||||
|
if cond_type_sym.kind == .union_sum_type {
|
||||||
|
mut expr_type := table.Type(0)
|
||||||
|
if expr_types.len > 1 {
|
||||||
|
mut agg_name := strings.new_builder(20)
|
||||||
|
agg_name.write('(')
|
||||||
|
for i, expr in expr_types {
|
||||||
|
if i > 0 {
|
||||||
|
agg_name.write(' | ')
|
||||||
|
}
|
||||||
|
type_str := c.table.type_to_str(expr.typ)
|
||||||
|
agg_name.write(if c.is_builtin_mod {
|
||||||
|
type_str
|
||||||
|
} else {
|
||||||
|
'${c.mod}.$type_str'
|
||||||
|
})
|
||||||
|
}
|
||||||
|
agg_name.write(')')
|
||||||
|
name := agg_name.str()
|
||||||
|
expr_type = c.table.register_type_symbol(table.TypeSymbol{
|
||||||
|
name: name
|
||||||
|
source_name: name
|
||||||
|
kind: .aggregate
|
||||||
|
mod: c.mod
|
||||||
|
info: table.Aggregate{
|
||||||
|
types: expr_types.map(it.typ)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
expr_type = expr_types[0].typ
|
||||||
|
}
|
||||||
|
mut scope := c.file.scope.innermost(branch.pos.pos)
|
||||||
|
match node.cond as node_cond {
|
||||||
|
ast.SelectorExpr { scope.register_struct_field(ast.ScopeStructField{
|
||||||
|
struct_type: node_cond.expr_type
|
||||||
|
name: node_cond.field_name
|
||||||
|
typ: node.cond_type
|
||||||
|
sum_type_cast: expr_type
|
||||||
|
pos: node_cond.pos
|
||||||
|
}) }
|
||||||
|
ast.Ident { scope.register(node.var_name, ast.Var{
|
||||||
|
name: node.var_name
|
||||||
|
typ: node.cond_type
|
||||||
|
pos: node_cond.pos
|
||||||
|
is_used: true
|
||||||
|
is_mut: node.is_mut
|
||||||
|
sum_type_cast: expr_type
|
||||||
|
}) }
|
||||||
|
else {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// check that expressions are exhaustive
|
// check that expressions are exhaustive
|
||||||
// this is achieved either by putting an else
|
// this is achieved either by putting an else
|
||||||
|
|
|
@ -122,6 +122,11 @@ mut:
|
||||||
// doing_autofree_tmp bool
|
// doing_autofree_tmp bool
|
||||||
inside_lambda bool
|
inside_lambda bool
|
||||||
prevent_sum_type_unwrapping_once bool // needed for assign new values to sum type
|
prevent_sum_type_unwrapping_once bool // needed for assign new values to sum type
|
||||||
|
// used in match multi branch
|
||||||
|
// TypeOne, TypeTwo {}
|
||||||
|
// where an aggregate (at least two types) is generated
|
||||||
|
// sum type deref needs to know which index to deref because unions take care of the correct field
|
||||||
|
aggregate_type_idx int
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -2513,7 +2518,12 @@ fn (mut g Gen) expr(node ast.Expr) {
|
||||||
if field := scope.find_struct_field(node.expr_type, node.field_name) {
|
if field := scope.find_struct_field(node.expr_type, node.field_name) {
|
||||||
// union sum type deref
|
// union sum type deref
|
||||||
g.write('(*')
|
g.write('(*')
|
||||||
sum_type_deref_field = '_$field.sum_type_cast'
|
cast_sym := g.table.get_type_symbol(field.sum_type_cast)
|
||||||
|
if cast_sym.info is table.Aggregate as sym_info {
|
||||||
|
sum_type_deref_field = '_${sym_info.types[g.aggregate_type_idx]}'
|
||||||
|
} else {
|
||||||
|
sum_type_deref_field = '_$field.sum_type_cast'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3003,6 +3013,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str
|
||||||
mut sumtype_index := 0
|
mut sumtype_index := 0
|
||||||
// iterates through all types in sumtype branches
|
// iterates through all types in sumtype branches
|
||||||
for {
|
for {
|
||||||
|
g.aggregate_type_idx = sumtype_index
|
||||||
is_last := j == node.branches.len - 1
|
is_last := j == node.branches.len - 1
|
||||||
sym := g.table.get_type_symbol(node.cond_type)
|
sym := g.table.get_type_symbol(node.cond_type)
|
||||||
if branch.is_else || (node.is_expr && is_last) {
|
if branch.is_else || (node.is_expr && is_last) {
|
||||||
|
@ -3070,6 +3081,8 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// reset global field for next use
|
||||||
|
g.aggregate_type_idx = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3336,7 +3349,12 @@ fn (mut g Gen) ident(node ast.Ident) {
|
||||||
if v := scope.find_var(node.name) {
|
if v := scope.find_var(node.name) {
|
||||||
if v.sum_type_cast != 0 {
|
if v.sum_type_cast != 0 {
|
||||||
if !prevent_sum_type_unwrapping_once {
|
if !prevent_sum_type_unwrapping_once {
|
||||||
g.write('(*${name}._$v.sum_type_cast)')
|
sym := g.table.get_type_symbol(v.sum_type_cast)
|
||||||
|
if sym.info is table.Aggregate as sym_info {
|
||||||
|
g.write('(*${name}._${sym_info.types[g.aggregate_type_idx]})')
|
||||||
|
} else {
|
||||||
|
g.write('(*${name}._$v.sum_type_cast)')
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -244,33 +244,33 @@ fn (mut p Parser) match_expr() ast.MatchExpr {
|
||||||
}
|
}
|
||||||
p.check(.comma)
|
p.check(.comma)
|
||||||
}
|
}
|
||||||
mut it_typ := table.void_type
|
|
||||||
if types.len == 1 {
|
|
||||||
it_typ = types[0]
|
|
||||||
} else {
|
|
||||||
// there is more than one types, so we must create a type aggregate
|
|
||||||
mut agg_name := strings.new_builder(20)
|
|
||||||
agg_name.write('(')
|
|
||||||
for i, typ in types {
|
|
||||||
if i > 0 {
|
|
||||||
agg_name.write(' | ')
|
|
||||||
}
|
|
||||||
type_str := p.table.type_to_str(typ)
|
|
||||||
agg_name.write(p.prepend_mod(type_str))
|
|
||||||
}
|
|
||||||
agg_name.write(')')
|
|
||||||
name := agg_name.str()
|
|
||||||
it_typ = p.table.register_type_symbol(table.TypeSymbol{
|
|
||||||
name: name
|
|
||||||
source_name: name
|
|
||||||
kind: .aggregate
|
|
||||||
mod: p.mod
|
|
||||||
info: table.Aggregate{
|
|
||||||
types: types
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if !is_union_match {
|
if !is_union_match {
|
||||||
|
mut it_typ := table.void_type
|
||||||
|
if types.len == 1 {
|
||||||
|
it_typ = types[0]
|
||||||
|
} else {
|
||||||
|
// there is more than one types, so we must create a type aggregate
|
||||||
|
mut agg_name := strings.new_builder(20)
|
||||||
|
agg_name.write('(')
|
||||||
|
for i, typ in types {
|
||||||
|
if i > 0 {
|
||||||
|
agg_name.write(' | ')
|
||||||
|
}
|
||||||
|
type_str := p.table.type_to_str(typ)
|
||||||
|
agg_name.write(p.prepend_mod(type_str))
|
||||||
|
}
|
||||||
|
agg_name.write(')')
|
||||||
|
name := agg_name.str()
|
||||||
|
it_typ = p.table.register_type_symbol(table.TypeSymbol{
|
||||||
|
name: name
|
||||||
|
source_name: name
|
||||||
|
kind: .aggregate
|
||||||
|
mod: p.mod
|
||||||
|
info: table.Aggregate{
|
||||||
|
types: types
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
p.scope.register('it', ast.Var{
|
p.scope.register('it', ast.Var{
|
||||||
name: 'it'
|
name: 'it'
|
||||||
typ: it_typ.to_ptr()
|
typ: it_typ.to_ptr()
|
||||||
|
|
|
@ -331,16 +331,6 @@ fn test_reassign_from_function_with_parameter_selector() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_match_multi_branch() {
|
|
||||||
f := Expr3(CTempVarExpr{'ctemp'})
|
|
||||||
match union f {
|
|
||||||
CallExpr, CTempVarExpr {
|
|
||||||
// this check works only if f is not castet
|
|
||||||
assert f is CTempVarExpr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn test_typeof() {
|
fn test_typeof() {
|
||||||
x := Expr3(CTempVarExpr{})
|
x := Expr3(CTempVarExpr{})
|
||||||
assert typeof(x) == 'CTempVarExpr'
|
assert typeof(x) == 'CTempVarExpr'
|
||||||
|
@ -355,6 +345,25 @@ fn test_zero_value_init() {
|
||||||
_ := Outer2{}
|
_ := Outer2{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Milk {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Eggs {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
__type Food = Milk | Eggs
|
||||||
|
|
||||||
|
fn test_match_aggregate() {
|
||||||
|
f := Food(Milk{'test'})
|
||||||
|
match union f {
|
||||||
|
Milk, Eggs {
|
||||||
|
assert f.name == 'test'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn test_sum_type_match() {
|
fn test_sum_type_match() {
|
||||||
// TODO: Remove these casts
|
// TODO: Remove these casts
|
||||||
assert is_gt_simple('3', int(2))
|
assert is_gt_simple('3', int(2))
|
||||||
|
|
Loading…
Reference in New Issue