cgen: fix sum type match (#5978)
							parent
							
								
									3df0ef249c
								
							
						
					
					
						commit
						1d59d35129
					
				|  | @ -99,6 +99,8 @@ mut: | ||||||
| 	inside_const          bool | 	inside_const          bool | ||||||
| 	comp_for_method       string // $for method in T {
 | 	comp_for_method       string // $for method in T {
 | ||||||
| 	comptime_var_type_map map[string]table.Type | 	comptime_var_type_map map[string]table.Type | ||||||
|  | 	match_sumtype_exprs   []ast.Expr | ||||||
|  | 	match_sumtype_syms    []table.TypeSymbol | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | @ -1760,6 +1762,9 @@ fn (mut g Gen) expr(node ast.Expr) { | ||||||
| 			g.write(node.val) | 			g.write(node.val) | ||||||
| 		} | 		} | ||||||
| 		ast.Ident { | 		ast.Ident { | ||||||
|  | 			if g.should_write_asterisk_due_to_match_sumtype(node) { | ||||||
|  | 				g.write('*') | ||||||
|  | 			} | ||||||
| 			g.ident(node) | 			g.ident(node) | ||||||
| 		} | 		} | ||||||
| 		ast.IfExpr { | 		ast.IfExpr { | ||||||
|  | @ -2269,6 +2274,16 @@ fn (mut g Gen) match_expr(node ast.MatchExpr) { | ||||||
| 		// g.write('/* EM ret type=${g.typ(node.return_type)}		expected_type=${g.typ(node.expected_type)}  */')
 | 		// g.write('/* EM ret type=${g.typ(node.return_type)}		expected_type=${g.typ(node.expected_type)}  */')
 | ||||||
| 	} | 	} | ||||||
| 	type_sym := g.table.get_type_symbol(node.cond_type) | 	type_sym := g.table.get_type_symbol(node.cond_type) | ||||||
|  | 	if node.is_sum_type { | ||||||
|  | 		g.match_sumtype_exprs << node.cond | ||||||
|  | 		g.match_sumtype_syms << type_sym | ||||||
|  | 	} | ||||||
|  | 	defer { | ||||||
|  | 		if node.is_sum_type { | ||||||
|  | 			g.match_sumtype_exprs.pop() | ||||||
|  | 			g.match_sumtype_syms.pop() | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	mut tmp := '' | 	mut tmp := '' | ||||||
| 	if type_sym.kind != .void { | 	if type_sym.kind != .void { | ||||||
| 		tmp = g.new_tmp_var() | 		tmp = g.new_tmp_var() | ||||||
|  | @ -2417,6 +2432,36 @@ fn (mut g Gen) ident(node ast.Ident) { | ||||||
| 	g.write(g.get_ternary_name(name)) | 	g.write(g.get_ternary_name(name)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | [unlikely] | ||||||
|  | fn (mut g Gen) should_write_asterisk_due_to_match_sumtype(expr ast.Expr) bool { | ||||||
|  | 	if expr is ast.Ident { | ||||||
|  | 		typ := if expr.info is ast.IdentVar { (expr.info as ast.IdentVar).typ } else { (expr.info as ast.IdentFn).typ } | ||||||
|  | 		return if typ.is_ptr() && g.match_sumtype_has_no_struct_and_contains(expr) { | ||||||
|  | 			true | ||||||
|  | 		} else { | ||||||
|  | 			false | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | [unlikely] | ||||||
|  | fn (mut g Gen) match_sumtype_has_no_struct_and_contains(node ast.Ident) bool { | ||||||
|  | 	for i, expr in g.match_sumtype_exprs { | ||||||
|  | 		if expr is ast.Ident && node.name == (expr as ast.Ident).name { | ||||||
|  | 			sumtype := g.match_sumtype_syms[i].info as table.SumType | ||||||
|  | 			for typ in sumtype.variants { | ||||||
|  | 				if g.table.get_type_symbol(typ).kind == .struct_ { | ||||||
|  | 					return false | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
| fn (mut g Gen) concat_expr(node ast.ConcatExpr) { | fn (mut g Gen) concat_expr(node ast.ConcatExpr) { | ||||||
| 	styp := g.typ(node.return_type) | 	styp := g.typ(node.return_type) | ||||||
| 	sym := g.table.get_type_symbol(node.return_type) | 	sym := g.table.get_type_symbol(node.return_type) | ||||||
|  |  | ||||||
|  | @ -391,7 +391,7 @@ fn (mut g Gen) method_call(node ast.CallExpr) { | ||||||
| 	// g.write('/*${g.typ(node.receiver_type)}*/')
 | 	// g.write('/*${g.typ(node.receiver_type)}*/')
 | ||||||
| 	// g.write('/*expr_type=${g.typ(node.left_type)} rec type=${g.typ(node.receiver_type)}*/')
 | 	// g.write('/*expr_type=${g.typ(node.left_type)} rec type=${g.typ(node.receiver_type)}*/')
 | ||||||
| 	// }
 | 	// }
 | ||||||
| 	if !node.receiver_type.is_ptr() && node.left_type.is_ptr() && node.name == 'str' { | 	if !node.receiver_type.is_ptr() && node.left_type.is_ptr() && node.name == 'str' && !g.should_write_asterisk_due_to_match_sumtype(node.left) { | ||||||
| 		g.write('ptr_str(') | 		g.write('ptr_str(') | ||||||
| 	} else { | 	} else { | ||||||
| 		g.write('${name}(') | 		g.write('${name}(') | ||||||
|  |  | ||||||
|  | @ -177,4 +177,95 @@ fn test_int_cast_to_sumtype() { | ||||||
| 			assert false | 			assert false | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // TODO: change definition once types other than any_int and any_float (int, f64, etc) are supported in sumtype
 | ||||||
|  | type Number = any_int | any_float | ||||||
|  | 
 | ||||||
|  | fn is_gt_simple(val string, dst Number) bool { | ||||||
|  | 	match dst { | ||||||
|  | 		any_int { | ||||||
|  | 			return val.int() > dst | ||||||
|  | 		} | ||||||
|  | 		any_float { | ||||||
|  | 			return dst < val.f64() | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | fn is_gt_nested(val string, dst Number) bool { | ||||||
|  | 	dst2 := dst | ||||||
|  | 	match dst { | ||||||
|  | 		any_int { | ||||||
|  | 			match dst2 { | ||||||
|  | 				any_int { | ||||||
|  | 					return val.int() > dst | ||||||
|  | 				} | ||||||
|  | 				// this branch should never been hit
 | ||||||
|  | 				else {  | ||||||
|  | 					return val.int() < dst  | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		any_float { | ||||||
|  | 			match dst2 { | ||||||
|  | 				any_float { | ||||||
|  | 					return dst < val.f64() | ||||||
|  | 				} | ||||||
|  | 				// this branch should never been hit
 | ||||||
|  | 				else {  | ||||||
|  | 					return dst > val.f64()  | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | fn concat(val string, dst Number) string { | ||||||
|  | 	match dst { | ||||||
|  | 		any_int { | ||||||
|  | 			mut res := val + '(int)' | ||||||
|  | 			res += dst.str() | ||||||
|  | 			return res | ||||||
|  | 		} | ||||||
|  | 		any_float { | ||||||
|  | 			mut res := val + '(float)' | ||||||
|  | 			res += dst.str() | ||||||
|  | 			return res | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | fn get_sum(val string, dst Number) f64 { | ||||||
|  | 	match dst { | ||||||
|  | 		any_int { | ||||||
|  | 			mut res := val.int() | ||||||
|  | 			res += dst | ||||||
|  | 			return res | ||||||
|  | 		} | ||||||
|  | 		any_float { | ||||||
|  | 			mut res := val.f64() | ||||||
|  | 			res += dst | ||||||
|  | 			return res | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | fn test_sum_type_match() { | ||||||
|  | 	assert is_gt_simple('3', 2) | ||||||
|  | 	assert !is_gt_simple('3', 5) | ||||||
|  | 	assert is_gt_simple('3', 1.2) | ||||||
|  | 	assert !is_gt_simple('3', 3.5)	 | ||||||
|  | 	assert is_gt_nested('3', 2) | ||||||
|  | 	assert !is_gt_nested('3', 5) | ||||||
|  | 	assert is_gt_nested('3', 1.2) | ||||||
|  | 	assert !is_gt_nested('3', 3.5) | ||||||
|  | 	assert concat('3', 2) == '3(int)2' | ||||||
|  | 	assert concat('3', 5) == '3(int)5' | ||||||
|  | 	assert concat('3', 1.2) == '3(float)1.2' | ||||||
|  | 	assert concat('3', 3.5) == '3(float)3.5' | ||||||
|  | 	assert get_sum('3', 2) == 5.0 | ||||||
|  | 	assert get_sum('3', 5) == 8.0 | ||||||
|  | 	assert get_sum('3', 1.2) == 4.2 | ||||||
|  | 	assert get_sum('3', 3.5) == 6.5 | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue