cgen: support RangeExpr while emitting enum switch case (#12226)

pull/12236/head
ChAoS_UnItY 2021-10-19 22:02:22 +08:00 committed by GitHub
parent d8ea9e4969
commit ab350d52ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 211 additions and 39 deletions

View File

@ -4289,27 +4289,6 @@ fn (mut g Gen) need_tmp_var_in_match(node ast.MatchExpr) bool {
return false
}
fn (mut g Gen) branches_all_resolvable_in_runtime(node ast.MatchExpr, typ ast.TypeSymbol) bool {
for branch in node.branches {
for expr in branch.exprs {
if expr is ast.EnumVal {
continue
} else if expr is ast.RangeExpr {
return false
// we must implement constant folding on enum fields and make it accessible
// anywhere to prove that range expr's actual branches are resolvale
// if expr.high !is ast.IntegerLiteral || expr.low !is ast.IntegerLiteral {
// return false
// }
// continue
}
return true
}
}
return true
}
fn (mut g Gen) match_expr(node ast.MatchExpr) {
// println('match expr typ=$it.expr_type')
// TODO
@ -4354,11 +4333,9 @@ fn (mut g Gen) match_expr(node ast.MatchExpr) {
g.write('(')
}
typ := g.table.get_final_type_symbol(node.cond_type)
all_resolvable := g.branches_all_resolvable_in_runtime(node, typ)
if node.is_sum_type {
g.match_expr_sumtype(node, is_expr, cond_var, tmp_var)
} else if typ.kind == .enum_ && g.loop_depth == 0 && node.branches.len > 5 && g.fn_decl != 0
&& all_resolvable { // do not optimize while in top-level
} else if typ.kind == .enum_ && g.loop_depth == 0 && node.branches.len > 5 && g.fn_decl != 0 { // do not optimize while in top-level
g.match_expr_switch(node, is_expr, cond_var, tmp_var, typ)
} else {
g.match_expr_classic(node, is_expr, cond_var, tmp_var)
@ -4450,7 +4427,9 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str
fn (mut g Gen) match_expr_switch(node ast.MatchExpr, is_expr bool, cond_var string, tmp_var string, enum_typ ast.TypeSymbol) {
cname := '${enum_typ.cname}__'
mut covered_enum := []string{cap: (enum_typ.info as ast.Enum).vals.len}
mut covered_enum := []string{cap: (enum_typ.info as ast.Enum).vals.len} // collects missing enum variant branches to avoid cstrict errors
mut range_branches := []ast.MatchBranch{cap: node.branches.len} // branches have RangeExpr cannot emit as switch case branch, we handle it in default branch
mut default_generated := false
g.empty_line = true
g.writeln('switch ($cond_var) {')
g.indent++
@ -4462,19 +4441,56 @@ fn (mut g Gen) match_expr_switch(node ast.MatchExpr, is_expr bool, cond_var stri
}
}
g.writeln('default:')
default_generated = true
if range_branches.len > 0 {
g.indent++
for range_branch in range_branches {
g.write('if (')
for i, expr in range_branch.exprs {
if i > 0 {
g.write(' || ')
}
if expr is ast.RangeExpr {
// if type is unsigned and low is 0, check is unneeded
mut skip_low := false
if expr.low is ast.IntegerLiteral {
if node.cond_type in [ast.u16_type, ast.u32_type, ast.u64_type]
&& expr.low.val == '0' {
skip_low = true
}
}
g.write('(')
if !skip_low {
g.write('$cond_var >= ')
g.expr(expr.low)
g.write(' && ')
}
g.write('$cond_var <= ')
g.expr(expr.high)
g.write(')')
} else {
g.write('$cond_var == (')
g.expr(expr)
g.write(')')
}
}
g.writeln(') {')
g.stmts_with_tmp_var(range_branch.stmts, tmp_var)
g.writeln('}')
}
g.indent--
}
} else {
if branch.exprs.any(it is ast.RangeExpr) {
range_branches << branch
continue
}
for expr in branch.exprs {
if expr is ast.EnumVal {
covered_enum << (expr as ast.EnumVal).val
covered_enum << expr.val
g.write('case ')
g.expr(expr)
g.writeln(': ')
} else if expr is ast.RangeExpr {
// low, high := (expr.low as ast.IntegerLiteral).val.int(), (expr.high as ast.IntegerLiteral).val.int()
// for val in (enum_typ.info as ast.Enum).vals[low..high + 1] {
// covered_enum << val
// g.writeln('case $cname$val:')
// }
}
}
}
@ -4484,6 +4500,45 @@ fn (mut g Gen) match_expr_switch(node ast.MatchExpr, is_expr bool, cond_var stri
g.writeln('} break;')
g.indent--
}
if range_branches.len > 0 && !default_generated {
g.writeln('default:')
g.indent++
for range_branch in range_branches {
g.write('if (')
for i, expr in range_branch.exprs {
if i > 0 {
g.write(' || ')
}
if expr is ast.RangeExpr {
// if type is unsigned and low is 0, check is unneeded
mut skip_low := false
if expr.low is ast.IntegerLiteral {
if node.cond_type in [ast.u16_type, ast.u32_type, ast.u64_type]
&& expr.low.val == '0' {
skip_low = true
}
}
g.write('(')
if !skip_low {
g.write('$cond_var >= ')
g.expr(expr.low)
g.write(' && ')
}
g.write('$cond_var <= ')
g.expr(expr.high)
g.write(')')
} else {
g.write('$cond_var == (')
g.expr(expr)
g.write(')')
}
}
g.writeln(') {')
g.stmts_with_tmp_var(range_branch.stmts, tmp_var)
g.writeln('}')
}
g.indent--
}
g.indent--
g.writeln('}')
}

View File

@ -0,0 +1,5 @@
case main__Enum__e1:
case main__Enum__e4:
case main__Enum__e6:
default:
(e >= 4 && e <= 5)

View File

@ -0,0 +1,20 @@
enum Enum {
e1
e2
e3
e4
e5
e6
}
fn main() {
e := Enum.e1
match e {
.e1 {}
.e2 {}
.e3 {}
.e4 {}
4...5 {}
else {}
}
}

View File

@ -2151,7 +2151,7 @@ struct CondExpr {
fn (mut g JsGen) match_cond(cond MatchCond) {
match cond {
CondString {
g.writeln(cond.s)
g.write(cond.s)
}
CondExpr {
g.expr(cond.expr)
@ -2175,7 +2175,7 @@ fn (mut g JsGen) match_expr(node ast.MatchExpr) {
}
if node.cond in [ast.Ident, ast.SelectorExpr, ast.IntegerLiteral, ast.StringLiteral,
ast.FloatLiteral, ast.CallExpr] {
ast.FloatLiteral, ast.CallExpr, ast.EnumVal] {
cond_var = CondExpr{node.cond}
} else {
s := g.new_tmp_var()
@ -2197,7 +2197,7 @@ fn (mut g JsGen) match_expr(node ast.MatchExpr) {
if node.is_sum_type {
g.match_expr_sumtype(node, is_expr, cond_var, tmp_var)
} else if typ.kind == .enum_ && !g.inside_loop && node.branches.len > 5 && g.fn_decl != 0 { // do not optimize while in top-level
g.match_expr_switch(node, is_expr, cond_var, tmp_var)
g.match_expr_switch(node, is_expr, cond_var, tmp_var, typ)
} else {
g.match_expr_classic(node, is_expr, cond_var, tmp_var)
}
@ -2313,7 +2313,9 @@ fn (mut g JsGen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var M
}
}
fn (mut g JsGen) match_expr_switch(node ast.MatchExpr, is_expr bool, cond_var MatchCond, tmp_var string) {
fn (mut g JsGen) match_expr_switch(node ast.MatchExpr, is_expr bool, cond_var MatchCond, tmp_var string, enum_typ ast.TypeSymbol) {
mut range_branches := []ast.MatchBranch{cap: node.branches.len} // branches have RangeExpr cannot emit as switch case branch, we handle it in default branch
mut default_generated := false
g.empty_line = true
g.write('switch (')
g.match_cond(cond_var)
@ -2322,11 +2324,59 @@ fn (mut g JsGen) match_expr_switch(node ast.MatchExpr, is_expr bool, cond_var Ma
for branch in node.branches {
if branch.is_else {
g.writeln('default:')
default_generated = true
if range_branches.len > 0 {
g.inc_indent()
for range_branch in range_branches {
g.write('if (')
for i, expr in range_branch.exprs {
if i > 0 {
g.write(' || ')
}
if expr is ast.RangeExpr {
// if type is unsigned and low is 0, check is unneeded
mut skip_low := false
if expr.low is ast.IntegerLiteral {
if node.cond_type in [ast.u16_type, ast.u32_type, ast.u64_type]
&& expr.low.val == '0' {
skip_low = true
}
}
g.write('(')
if !skip_low {
g.match_cond(cond_var)
g.write(' >= ')
g.expr(expr.low)
g.write(' && ')
}
g.match_cond(cond_var)
g.write(' <= ')
g.expr(expr.high)
g.write(')')
} else {
g.match_cond(cond_var)
g.write(' == (')
g.expr(expr)
g.write(')')
}
}
g.writeln(') {')
g.stmts_with_tmp_var(range_branch.stmts, tmp_var)
g.writeln('}')
}
g.dec_indent()
}
} else {
if branch.exprs.any(it is ast.RangeExpr) {
range_branches << branch
continue
}
for expr in branch.exprs {
g.write('case ')
g.expr(expr)
g.writeln(': ')
if expr is ast.EnumVal {
g.write('case ')
g.expr(expr)
g.writeln(': ')
}
}
}
g.inc_indent()
@ -2335,6 +2385,48 @@ fn (mut g JsGen) match_expr_switch(node ast.MatchExpr, is_expr bool, cond_var Ma
g.writeln('} break;')
g.dec_indent()
}
if range_branches.len > 0 && !default_generated {
g.writeln('default:')
g.inc_indent()
for range_branch in range_branches {
g.write('if (')
for i, expr in range_branch.exprs {
if i > 0 {
g.write(' || ')
}
if expr is ast.RangeExpr {
// if type is unsigned and low is 0, check is unneeded
mut skip_low := false
if expr.low is ast.IntegerLiteral {
if node.cond_type in [ast.u16_type, ast.u32_type, ast.u64_type]
&& expr.low.val == '0' {
skip_low = true
}
}
g.write('(')
if !skip_low {
g.match_cond(cond_var)
g.write(' >= ')
g.expr(expr.low)
g.write(' && ')
}
g.match_cond(cond_var)
g.write(' <= ')
g.expr(expr.high)
g.write(')')
} else {
g.match_cond(cond_var)
g.write(' == (')
g.expr(expr)
g.write(')')
}
}
g.writeln(') {')
g.stmts_with_tmp_var(range_branch.stmts, tmp_var)
g.writeln('}')
}
g.dec_indent()
}
g.dec_indent()
g.writeln('}')
}