all: match multi aggregate for union sum types (#6868)
							parent
							
								
									df4165c7ee
								
							
						
					
					
						commit
						e06756ef58
					
				|  | @ -3,6 +3,7 @@ | |||
| module checker | ||||
| 
 | ||||
| import os | ||||
| import strings | ||||
| import v.ast | ||||
| import v.vmod | ||||
| import v.table | ||||
|  | @ -3405,6 +3406,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol | |||
| 	mut branch_exprs := map[string]int{} | ||||
| 	cond_type_sym := c.table.get_type_symbol(node.cond_type) | ||||
| 	for branch in node.branches { | ||||
| 		mut expr_types := []ast.Type{} | ||||
| 		for expr in branch.exprs { | ||||
| 			mut key := '' | ||||
| 			if expr is ast.RangeExpr { | ||||
|  | @ -3444,28 +3446,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol | |||
| 			match expr { | ||||
| 				ast.Type { | ||||
| 					key = c.table.type_to_str(expr.typ) | ||||
| 					// smart cast only if one type is given (currently) // TODO make this work if types have same fields
 | ||||
| 					if branch.exprs.len == 1 && cond_type_sym.kind == .union_sum_type { | ||||
| 						mut scope := c.file.scope.innermost(branch.pos.pos) | ||||
| 						match node.cond as node_cond { | ||||
| 							ast.SelectorExpr { scope.register_struct_field(ast.ScopeStructField{ | ||||
| 									struct_type: node_cond.expr_type | ||||
| 									name: node_cond.field_name | ||||
| 									typ: node.cond_type | ||||
| 									sum_type_cast: expr.typ | ||||
| 									pos: node_cond.pos | ||||
| 								}) } | ||||
| 							ast.Ident { scope.register(node.var_name, ast.Var{ | ||||
| 									name: node.var_name | ||||
| 									typ: node.cond_type | ||||
| 									pos: node_cond.pos | ||||
| 									is_used: true | ||||
| 									is_mut: node.is_mut | ||||
| 									sum_type_cast: expr.typ | ||||
| 								}) } | ||||
| 							else {} | ||||
| 						} | ||||
| 					} | ||||
| 					expr_types << expr | ||||
| 				} | ||||
| 				ast.EnumVal { | ||||
| 					key = expr.val | ||||
|  | @ -3501,6 +3482,59 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol | |||
| 			} | ||||
| 			branch_exprs[key] = val + 1 | ||||
| 		} | ||||
| 		// when match is sum type matching, then register smart cast for every branch
 | ||||
| 		if expr_types.len > 0 { | ||||
| 			if cond_type_sym.kind == .union_sum_type { | ||||
| 				mut expr_type := table.Type(0) | ||||
| 				if expr_types.len > 1 { | ||||
| 					mut agg_name := strings.new_builder(20) | ||||
| 					agg_name.write('(') | ||||
| 					for i, expr in expr_types { | ||||
| 						if i > 0 { | ||||
| 							agg_name.write(' | ') | ||||
| 						} | ||||
| 						type_str := c.table.type_to_str(expr.typ) | ||||
| 						agg_name.write(if c.is_builtin_mod { | ||||
| 							type_str | ||||
| 						} else { | ||||
| 							'${c.mod}.$type_str' | ||||
| 						}) | ||||
| 					} | ||||
| 					agg_name.write(')') | ||||
| 					name := agg_name.str() | ||||
| 					expr_type = c.table.register_type_symbol(table.TypeSymbol{ | ||||
| 						name: name | ||||
| 						source_name: name | ||||
| 						kind: .aggregate | ||||
| 						mod: c.mod | ||||
| 						info: table.Aggregate{ | ||||
| 							types: expr_types.map(it.typ) | ||||
| 						} | ||||
| 					}) | ||||
| 				} else { | ||||
| 					expr_type = expr_types[0].typ | ||||
| 				} | ||||
| 				mut scope := c.file.scope.innermost(branch.pos.pos) | ||||
| 				match node.cond as node_cond { | ||||
| 					ast.SelectorExpr { scope.register_struct_field(ast.ScopeStructField{ | ||||
| 							struct_type: node_cond.expr_type | ||||
| 							name: node_cond.field_name | ||||
| 							typ: node.cond_type | ||||
| 							sum_type_cast: expr_type | ||||
| 							pos: node_cond.pos | ||||
| 						}) } | ||||
| 					ast.Ident { scope.register(node.var_name, ast.Var{ | ||||
| 							name: node.var_name | ||||
| 							typ: node.cond_type | ||||
| 							pos: node_cond.pos | ||||
| 							is_used: true | ||||
| 							is_mut: node.is_mut | ||||
| 							sum_type_cast: expr_type | ||||
| 						}) } | ||||
| 					else {} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	// check that expressions are exhaustive
 | ||||
| 	// this is achieved either by putting an else
 | ||||
|  |  | |||
|  | @ -122,6 +122,11 @@ mut: | |||
| 	// doing_autofree_tmp    bool
 | ||||
| 	inside_lambda                    bool | ||||
| 	prevent_sum_type_unwrapping_once bool // needed for assign new values to sum type
 | ||||
| 	// used in match multi branch
 | ||||
| 	// TypeOne, TypeTwo {}
 | ||||
| 	// where an aggregate (at least two types) is generated
 | ||||
| 	// sum type deref needs to know which index to deref because unions take care of the correct field
 | ||||
| 	aggregate_type_idx               int | ||||
| } | ||||
| 
 | ||||
| const ( | ||||
|  | @ -2513,7 +2518,12 @@ fn (mut g Gen) expr(node ast.Expr) { | |||
| 						if field := scope.find_struct_field(node.expr_type, node.field_name) { | ||||
| 							// union sum type deref
 | ||||
| 							g.write('(*') | ||||
| 							sum_type_deref_field = '_$field.sum_type_cast' | ||||
| 							cast_sym := g.table.get_type_symbol(field.sum_type_cast) | ||||
| 							if cast_sym.info is table.Aggregate as sym_info { | ||||
| 								sum_type_deref_field = '_${sym_info.types[g.aggregate_type_idx]}' | ||||
| 							} else { | ||||
| 								sum_type_deref_field = '_$field.sum_type_cast' | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
|  | @ -3003,6 +3013,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str | |||
| 		mut sumtype_index := 0 | ||||
| 		// iterates through all types in sumtype branches
 | ||||
| 		for { | ||||
| 			g.aggregate_type_idx = sumtype_index | ||||
| 			is_last := j == node.branches.len - 1 | ||||
| 			sym := g.table.get_type_symbol(node.cond_type) | ||||
| 			if branch.is_else || (node.is_expr && is_last) { | ||||
|  | @ -3070,6 +3081,8 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str | |||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		// reset global field for next use
 | ||||
| 		g.aggregate_type_idx = 0 | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -3336,7 +3349,12 @@ fn (mut g Gen) ident(node ast.Ident) { | |||
| 		if v := scope.find_var(node.name) { | ||||
| 			if v.sum_type_cast != 0 { | ||||
| 				if !prevent_sum_type_unwrapping_once { | ||||
| 					g.write('(*${name}._$v.sum_type_cast)') | ||||
| 					sym := g.table.get_type_symbol(v.sum_type_cast) | ||||
| 					if sym.info is table.Aggregate as sym_info { | ||||
| 						g.write('(*${name}._${sym_info.types[g.aggregate_type_idx]})') | ||||
| 					} else { | ||||
| 						g.write('(*${name}._$v.sum_type_cast)') | ||||
| 					} | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
|  |  | |||
|  | @ -244,33 +244,33 @@ fn (mut p Parser) match_expr() ast.MatchExpr { | |||
| 				} | ||||
| 				p.check(.comma) | ||||
| 			} | ||||
| 			mut it_typ := table.void_type | ||||
| 			if types.len == 1 { | ||||
| 				it_typ = types[0] | ||||
| 			} else { | ||||
| 				// there is more than one types, so we must create a type aggregate
 | ||||
| 				mut agg_name := strings.new_builder(20) | ||||
| 				agg_name.write('(') | ||||
| 				for i, typ in types { | ||||
| 					if i > 0 { | ||||
| 						agg_name.write(' | ') | ||||
| 					} | ||||
| 					type_str := p.table.type_to_str(typ) | ||||
| 					agg_name.write(p.prepend_mod(type_str)) | ||||
| 				} | ||||
| 				agg_name.write(')') | ||||
| 				name := agg_name.str() | ||||
| 				it_typ = p.table.register_type_symbol(table.TypeSymbol{ | ||||
| 					name: name | ||||
| 					source_name: name | ||||
| 					kind: .aggregate | ||||
| 					mod: p.mod | ||||
| 					info: table.Aggregate{ | ||||
| 						types: types | ||||
| 					} | ||||
| 				}) | ||||
| 			} | ||||
| 			if !is_union_match { | ||||
| 				mut it_typ := table.void_type | ||||
| 				if types.len == 1 { | ||||
| 					it_typ = types[0] | ||||
| 				} else { | ||||
| 					// there is more than one types, so we must create a type aggregate
 | ||||
| 					mut agg_name := strings.new_builder(20) | ||||
| 					agg_name.write('(') | ||||
| 					for i, typ in types { | ||||
| 						if i > 0 { | ||||
| 							agg_name.write(' | ') | ||||
| 						} | ||||
| 						type_str := p.table.type_to_str(typ) | ||||
| 						agg_name.write(p.prepend_mod(type_str)) | ||||
| 					} | ||||
| 					agg_name.write(')') | ||||
| 					name := agg_name.str() | ||||
| 					it_typ = p.table.register_type_symbol(table.TypeSymbol{ | ||||
| 						name: name | ||||
| 						source_name: name | ||||
| 						kind: .aggregate | ||||
| 						mod: p.mod | ||||
| 						info: table.Aggregate{ | ||||
| 							types: types | ||||
| 						} | ||||
| 					}) | ||||
| 				} | ||||
| 				p.scope.register('it', ast.Var{ | ||||
| 					name: 'it' | ||||
| 					typ: it_typ.to_ptr() | ||||
|  |  | |||
|  | @ -331,16 +331,6 @@ fn test_reassign_from_function_with_parameter_selector() { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| fn test_match_multi_branch() { | ||||
| 	f := Expr3(CTempVarExpr{'ctemp'}) | ||||
| 	match union f { | ||||
| 		CallExpr, CTempVarExpr { | ||||
| 			// this check works only if f is not castet
 | ||||
| 			assert f is CTempVarExpr | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| fn test_typeof() { | ||||
|     x := Expr3(CTempVarExpr{}) | ||||
| 	assert typeof(x) == 'CTempVarExpr' | ||||
|  | @ -355,6 +345,25 @@ fn test_zero_value_init() { | |||
| 	_ := Outer2{} | ||||
| } | ||||
| 
 | ||||
| struct Milk { | ||||
| 	name string | ||||
| } | ||||
| 
 | ||||
| struct Eggs { | ||||
| 	name string | ||||
| } | ||||
| 
 | ||||
| __type Food = Milk | Eggs | ||||
| 
 | ||||
| fn test_match_aggregate() { | ||||
| 	f := Food(Milk{'test'}) | ||||
| 	match union f { | ||||
| 		Milk, Eggs { | ||||
| 			assert f.name == 'test' | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| fn test_sum_type_match() { | ||||
| 	// TODO: Remove these casts
 | ||||
| 	assert is_gt_simple('3', int(2)) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue