cgen: fix sum type match (#5978)

pull/6008/head^2
Ruofan XU 2020-07-29 04:17:25 +08:00 committed by GitHub
parent 3df0ef249c
commit 1d59d35129
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 138 additions and 2 deletions

View File

@ -99,6 +99,8 @@ mut:
inside_const bool
comp_for_method string // $for method in T {
comptime_var_type_map map[string]table.Type
match_sumtype_exprs []ast.Expr
match_sumtype_syms []table.TypeSymbol
}
const (
@ -1760,6 +1762,9 @@ fn (mut g Gen) expr(node ast.Expr) {
g.write(node.val)
}
ast.Ident {
if g.should_write_asterisk_due_to_match_sumtype(node) {
g.write('*')
}
g.ident(node)
}
ast.IfExpr {
@ -2269,6 +2274,16 @@ fn (mut g Gen) match_expr(node ast.MatchExpr) {
// g.write('/* EM ret type=${g.typ(node.return_type)} expected_type=${g.typ(node.expected_type)} */')
}
type_sym := g.table.get_type_symbol(node.cond_type)
if node.is_sum_type {
g.match_sumtype_exprs << node.cond
g.match_sumtype_syms << type_sym
}
defer {
if node.is_sum_type {
g.match_sumtype_exprs.pop()
g.match_sumtype_syms.pop()
}
}
mut tmp := ''
if type_sym.kind != .void {
tmp = g.new_tmp_var()
@ -2417,6 +2432,36 @@ fn (mut g Gen) ident(node ast.Ident) {
g.write(g.get_ternary_name(name))
}
[unlikely]
fn (mut g Gen) should_write_asterisk_due_to_match_sumtype(expr ast.Expr) bool {
if expr is ast.Ident {
typ := if expr.info is ast.IdentVar { (expr.info as ast.IdentVar).typ } else { (expr.info as ast.IdentFn).typ }
return if typ.is_ptr() && g.match_sumtype_has_no_struct_and_contains(expr) {
true
} else {
false
}
} else {
return false
}
}
[unlikely]
fn (mut g Gen) match_sumtype_has_no_struct_and_contains(node ast.Ident) bool {
for i, expr in g.match_sumtype_exprs {
if expr is ast.Ident && node.name == (expr as ast.Ident).name {
sumtype := g.match_sumtype_syms[i].info as table.SumType
for typ in sumtype.variants {
if g.table.get_type_symbol(typ).kind == .struct_ {
return false
}
}
return true
}
}
return false
}
fn (mut g Gen) concat_expr(node ast.ConcatExpr) {
styp := g.typ(node.return_type)
sym := g.table.get_type_symbol(node.return_type)

View File

@ -391,7 +391,7 @@ fn (mut g Gen) method_call(node ast.CallExpr) {
// g.write('/*${g.typ(node.receiver_type)}*/')
// g.write('/*expr_type=${g.typ(node.left_type)} rec type=${g.typ(node.receiver_type)}*/')
// }
if !node.receiver_type.is_ptr() && node.left_type.is_ptr() && node.name == 'str' {
if !node.receiver_type.is_ptr() && node.left_type.is_ptr() && node.name == 'str' && !g.should_write_asterisk_due_to_match_sumtype(node.left) {
g.write('ptr_str(')
} else {
g.write('${name}(')

View File

@ -177,4 +177,95 @@ fn test_int_cast_to_sumtype() {
assert false
}
}
}
}
// TODO: change definition once types other than any_int and any_float (int, f64, etc) are supported in sumtype
type Number = any_int | any_float
fn is_gt_simple(val string, dst Number) bool {
match dst {
any_int {
return val.int() > dst
}
any_float {
return dst < val.f64()
}
}
}
fn is_gt_nested(val string, dst Number) bool {
dst2 := dst
match dst {
any_int {
match dst2 {
any_int {
return val.int() > dst
}
// this branch should never been hit
else {
return val.int() < dst
}
}
}
any_float {
match dst2 {
any_float {
return dst < val.f64()
}
// this branch should never been hit
else {
return dst > val.f64()
}
}
}
}
}
fn concat(val string, dst Number) string {
match dst {
any_int {
mut res := val + '(int)'
res += dst.str()
return res
}
any_float {
mut res := val + '(float)'
res += dst.str()
return res
}
}
}
fn get_sum(val string, dst Number) f64 {
match dst {
any_int {
mut res := val.int()
res += dst
return res
}
any_float {
mut res := val.f64()
res += dst
return res
}
}
}
fn test_sum_type_match() {
assert is_gt_simple('3', 2)
assert !is_gt_simple('3', 5)
assert is_gt_simple('3', 1.2)
assert !is_gt_simple('3', 3.5)
assert is_gt_nested('3', 2)
assert !is_gt_nested('3', 5)
assert is_gt_nested('3', 1.2)
assert !is_gt_nested('3', 3.5)
assert concat('3', 2) == '3(int)2'
assert concat('3', 5) == '3(int)5'
assert concat('3', 1.2) == '3(float)1.2'
assert concat('3', 3.5) == '3(float)3.5'
assert get_sum('3', 2) == 5.0
assert get_sum('3', 5) == 8.0
assert get_sum('3', 1.2) == 4.2
assert get_sum('3', 3.5) == 6.5
}