cgen: fix sum type match (#5978)
							parent
							
								
									3df0ef249c
								
							
						
					
					
						commit
						1d59d35129
					
				|  | @ -99,6 +99,8 @@ mut: | |||
| 	inside_const          bool | ||||
| 	comp_for_method       string // $for method in T {
 | ||||
| 	comptime_var_type_map map[string]table.Type | ||||
| 	match_sumtype_exprs   []ast.Expr | ||||
| 	match_sumtype_syms    []table.TypeSymbol | ||||
| } | ||||
| 
 | ||||
| const ( | ||||
|  | @ -1760,6 +1762,9 @@ fn (mut g Gen) expr(node ast.Expr) { | |||
| 			g.write(node.val) | ||||
| 		} | ||||
| 		ast.Ident { | ||||
| 			if g.should_write_asterisk_due_to_match_sumtype(node) { | ||||
| 				g.write('*') | ||||
| 			} | ||||
| 			g.ident(node) | ||||
| 		} | ||||
| 		ast.IfExpr { | ||||
|  | @ -2269,6 +2274,16 @@ fn (mut g Gen) match_expr(node ast.MatchExpr) { | |||
| 		// g.write('/* EM ret type=${g.typ(node.return_type)}		expected_type=${g.typ(node.expected_type)}  */')
 | ||||
| 	} | ||||
| 	type_sym := g.table.get_type_symbol(node.cond_type) | ||||
| 	if node.is_sum_type { | ||||
| 		g.match_sumtype_exprs << node.cond | ||||
| 		g.match_sumtype_syms << type_sym | ||||
| 	} | ||||
| 	defer { | ||||
| 		if node.is_sum_type { | ||||
| 			g.match_sumtype_exprs.pop() | ||||
| 			g.match_sumtype_syms.pop() | ||||
| 		} | ||||
| 	} | ||||
| 	mut tmp := '' | ||||
| 	if type_sym.kind != .void { | ||||
| 		tmp = g.new_tmp_var() | ||||
|  | @ -2417,6 +2432,36 @@ fn (mut g Gen) ident(node ast.Ident) { | |||
| 	g.write(g.get_ternary_name(name)) | ||||
| } | ||||
| 
 | ||||
| [unlikely] | ||||
| fn (mut g Gen) should_write_asterisk_due_to_match_sumtype(expr ast.Expr) bool { | ||||
| 	if expr is ast.Ident { | ||||
| 		typ := if expr.info is ast.IdentVar { (expr.info as ast.IdentVar).typ } else { (expr.info as ast.IdentFn).typ } | ||||
| 		return if typ.is_ptr() && g.match_sumtype_has_no_struct_and_contains(expr) { | ||||
| 			true | ||||
| 		} else { | ||||
| 			false | ||||
| 		} | ||||
| 	} else { | ||||
| 		return false | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| [unlikely] | ||||
| fn (mut g Gen) match_sumtype_has_no_struct_and_contains(node ast.Ident) bool { | ||||
| 	for i, expr in g.match_sumtype_exprs { | ||||
| 		if expr is ast.Ident && node.name == (expr as ast.Ident).name { | ||||
| 			sumtype := g.match_sumtype_syms[i].info as table.SumType | ||||
| 			for typ in sumtype.variants { | ||||
| 				if g.table.get_type_symbol(typ).kind == .struct_ { | ||||
| 					return false | ||||
| 				} | ||||
| 			} | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| fn (mut g Gen) concat_expr(node ast.ConcatExpr) { | ||||
| 	styp := g.typ(node.return_type) | ||||
| 	sym := g.table.get_type_symbol(node.return_type) | ||||
|  |  | |||
|  | @ -391,7 +391,7 @@ fn (mut g Gen) method_call(node ast.CallExpr) { | |||
| 	// g.write('/*${g.typ(node.receiver_type)}*/')
 | ||||
| 	// g.write('/*expr_type=${g.typ(node.left_type)} rec type=${g.typ(node.receiver_type)}*/')
 | ||||
| 	// }
 | ||||
| 	if !node.receiver_type.is_ptr() && node.left_type.is_ptr() && node.name == 'str' { | ||||
| 	if !node.receiver_type.is_ptr() && node.left_type.is_ptr() && node.name == 'str' && !g.should_write_asterisk_due_to_match_sumtype(node.left) { | ||||
| 		g.write('ptr_str(') | ||||
| 	} else { | ||||
| 		g.write('${name}(') | ||||
|  |  | |||
|  | @ -177,4 +177,95 @@ fn test_int_cast_to_sumtype() { | |||
| 			assert false | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| } | ||||
| 
 | ||||
| // TODO: change definition once types other than any_int and any_float (int, f64, etc) are supported in sumtype
 | ||||
| type Number = any_int | any_float | ||||
| 
 | ||||
| fn is_gt_simple(val string, dst Number) bool { | ||||
| 	match dst { | ||||
| 		any_int { | ||||
| 			return val.int() > dst | ||||
| 		} | ||||
| 		any_float { | ||||
| 			return dst < val.f64() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| fn is_gt_nested(val string, dst Number) bool { | ||||
| 	dst2 := dst | ||||
| 	match dst { | ||||
| 		any_int { | ||||
| 			match dst2 { | ||||
| 				any_int { | ||||
| 					return val.int() > dst | ||||
| 				} | ||||
| 				// this branch should never been hit
 | ||||
| 				else {  | ||||
| 					return val.int() < dst  | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		any_float { | ||||
| 			match dst2 { | ||||
| 				any_float { | ||||
| 					return dst < val.f64() | ||||
| 				} | ||||
| 				// this branch should never been hit
 | ||||
| 				else {  | ||||
| 					return dst > val.f64()  | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| fn concat(val string, dst Number) string { | ||||
| 	match dst { | ||||
| 		any_int { | ||||
| 			mut res := val + '(int)' | ||||
| 			res += dst.str() | ||||
| 			return res | ||||
| 		} | ||||
| 		any_float { | ||||
| 			mut res := val + '(float)' | ||||
| 			res += dst.str() | ||||
| 			return res | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| fn get_sum(val string, dst Number) f64 { | ||||
| 	match dst { | ||||
| 		any_int { | ||||
| 			mut res := val.int() | ||||
| 			res += dst | ||||
| 			return res | ||||
| 		} | ||||
| 		any_float { | ||||
| 			mut res := val.f64() | ||||
| 			res += dst | ||||
| 			return res | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| fn test_sum_type_match() { | ||||
| 	assert is_gt_simple('3', 2) | ||||
| 	assert !is_gt_simple('3', 5) | ||||
| 	assert is_gt_simple('3', 1.2) | ||||
| 	assert !is_gt_simple('3', 3.5)	 | ||||
| 	assert is_gt_nested('3', 2) | ||||
| 	assert !is_gt_nested('3', 5) | ||||
| 	assert is_gt_nested('3', 1.2) | ||||
| 	assert !is_gt_nested('3', 3.5) | ||||
| 	assert concat('3', 2) == '3(int)2' | ||||
| 	assert concat('3', 5) == '3(int)5' | ||||
| 	assert concat('3', 1.2) == '3(float)1.2' | ||||
| 	assert concat('3', 3.5) == '3(float)3.5' | ||||
| 	assert get_sum('3', 2) == 5.0 | ||||
| 	assert get_sum('3', 5) == 8.0 | ||||
| 	assert get_sum('3', 1.2) == 4.2 | ||||
| 	assert get_sum('3', 3.5) == 6.5 | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue