transformer: eliminate unreachable branches & redundant branch expressions in MatchExpr (#12174)
							parent
							
								
									5b9553d5c4
								
							
						
					
					
						commit
						6f629d1a6a
					
				| 
						 | 
				
			
			@ -866,12 +866,12 @@ pub mut:
 | 
			
		|||
pub struct MatchBranch {
 | 
			
		||||
pub:
 | 
			
		||||
	ecmnts        [][]Comment // inline comments for each left side expr
 | 
			
		||||
	stmts         []Stmt      // right side
 | 
			
		||||
	pos           token.Position
 | 
			
		||||
	is_else       bool
 | 
			
		||||
	post_comments []Comment      // comments below ´... }´
 | 
			
		||||
	branch_pos    token.Position // for checker errors about invalid branches
 | 
			
		||||
pub mut:
 | 
			
		||||
	stmts []Stmt // right side
 | 
			
		||||
	exprs []Expr // left side
 | 
			
		||||
	scope &Scope
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -56,18 +56,19 @@ pub fn (t Transformer) stmt(mut node ast.Stmt) {
 | 
			
		|||
		ast.DeferStmt {}
 | 
			
		||||
		ast.EnumDecl {}
 | 
			
		||||
		ast.ExprStmt {
 | 
			
		||||
			if node.expr is ast.IfExpr {
 | 
			
		||||
				mut untrans_expr := node.expr as ast.IfExpr
 | 
			
		||||
				expr := t.if_expr(mut untrans_expr)
 | 
			
		||||
				node = &ast.ExprStmt{
 | 
			
		||||
					...node
 | 
			
		||||
					expr: expr
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				expr := t.expr(node.expr)
 | 
			
		||||
				node = &ast.ExprStmt{
 | 
			
		||||
					...node
 | 
			
		||||
					expr: expr
 | 
			
		||||
			expr := node.expr
 | 
			
		||||
			node = &ast.ExprStmt{
 | 
			
		||||
				...node
 | 
			
		||||
				expr: match mut expr {
 | 
			
		||||
					ast.IfExpr {
 | 
			
		||||
						t.if_expr(mut expr)
 | 
			
		||||
					}
 | 
			
		||||
					ast.MatchExpr {
 | 
			
		||||
						t.match_expr(mut expr)
 | 
			
		||||
					}
 | 
			
		||||
					else {
 | 
			
		||||
						t.expr(expr)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			@ -116,10 +117,77 @@ pub fn (t Transformer) expr(node ast.Expr) ast.Expr {
 | 
			
		|||
				index: t.expr(node.index)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		ast.MatchExpr {
 | 
			
		||||
		ast.IfExpr {
 | 
			
		||||
			for mut branch in node.branches {
 | 
			
		||||
				for mut stmt in branch.stmts {
 | 
			
		||||
				branch = ast.IfBranch{
 | 
			
		||||
					...(*branch)
 | 
			
		||||
					cond: t.expr(branch.cond)
 | 
			
		||||
				}
 | 
			
		||||
				for i, mut stmt in branch.stmts {
 | 
			
		||||
					t.stmt(mut stmt)
 | 
			
		||||
 | 
			
		||||
					if i == branch.stmts.len - 1 {
 | 
			
		||||
						if stmt is ast.ExprStmt {
 | 
			
		||||
							expr := (stmt as ast.ExprStmt).expr
 | 
			
		||||
 | 
			
		||||
							match expr {
 | 
			
		||||
								ast.IfExpr {
 | 
			
		||||
									if expr.branches.len == 1 {
 | 
			
		||||
										branch.stmts.pop()
 | 
			
		||||
										branch.stmts << expr.branches[0].stmts
 | 
			
		||||
										break
 | 
			
		||||
									}
 | 
			
		||||
								}
 | 
			
		||||
								ast.MatchExpr {
 | 
			
		||||
									if expr.branches.len == 1 {
 | 
			
		||||
										branch.stmts.pop()
 | 
			
		||||
										branch.stmts << expr.branches[0].stmts
 | 
			
		||||
										break
 | 
			
		||||
									}
 | 
			
		||||
								}
 | 
			
		||||
								else {}
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return node
 | 
			
		||||
		}
 | 
			
		||||
		ast.MatchExpr {
 | 
			
		||||
			node = ast.MatchExpr{
 | 
			
		||||
				...node
 | 
			
		||||
				cond: t.expr(node.cond)
 | 
			
		||||
			}
 | 
			
		||||
			for mut branch in node.branches {
 | 
			
		||||
				for mut expr in branch.exprs {
 | 
			
		||||
					expr = t.expr(expr)
 | 
			
		||||
				}
 | 
			
		||||
				for i, mut stmt in branch.stmts {
 | 
			
		||||
					t.stmt(mut stmt)
 | 
			
		||||
 | 
			
		||||
					if i == branch.stmts.len - 1 {
 | 
			
		||||
						if stmt is ast.ExprStmt {
 | 
			
		||||
							expr := (stmt as ast.ExprStmt).expr
 | 
			
		||||
 | 
			
		||||
							match expr {
 | 
			
		||||
								ast.IfExpr {
 | 
			
		||||
									if expr.branches.len == 1 {
 | 
			
		||||
										branch.stmts.pop()
 | 
			
		||||
										branch.stmts << expr.branches[0].stmts
 | 
			
		||||
										break
 | 
			
		||||
									}
 | 
			
		||||
								}
 | 
			
		||||
								ast.MatchExpr {
 | 
			
		||||
									if expr.branches.len == 1 {
 | 
			
		||||
										branch.stmts.pop()
 | 
			
		||||
										branch.stmts << expr.branches[0].stmts
 | 
			
		||||
										break
 | 
			
		||||
									}
 | 
			
		||||
								}
 | 
			
		||||
								else {}
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return node
 | 
			
		||||
| 
						 | 
				
			
			@ -132,6 +200,9 @@ pub fn (t Transformer) expr(node ast.Expr) ast.Expr {
 | 
			
		|||
 | 
			
		||||
pub fn (t Transformer) if_expr(mut original ast.IfExpr) ast.Expr {
 | 
			
		||||
	mut stop_index, mut unreachable_branches := -1, []int{cap: original.branches.len}
 | 
			
		||||
	if original.is_comptime {
 | 
			
		||||
		return *original
 | 
			
		||||
	}
 | 
			
		||||
	for i, mut branch in original.branches {
 | 
			
		||||
		for mut stmt in branch.stmts {
 | 
			
		||||
			t.stmt(mut stmt)
 | 
			
		||||
| 
						 | 
				
			
			@ -171,6 +242,84 @@ pub fn (t Transformer) if_expr(mut original ast.IfExpr) ast.Expr {
 | 
			
		|||
	return *original
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn (t Transformer) match_expr(mut original ast.MatchExpr) ast.Expr {
 | 
			
		||||
	cond, mut terminate := t.expr(original.cond), false
 | 
			
		||||
	original = ast.MatchExpr{
 | 
			
		||||
		...(*original)
 | 
			
		||||
		cond: cond
 | 
			
		||||
	}
 | 
			
		||||
	for mut branch in original.branches {
 | 
			
		||||
		if branch.is_else {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for mut stmt in branch.stmts {
 | 
			
		||||
			t.stmt(mut stmt)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for mut expr in branch.exprs {
 | 
			
		||||
			expr = t.expr(expr)
 | 
			
		||||
 | 
			
		||||
			match cond {
 | 
			
		||||
				ast.BoolLiteral {
 | 
			
		||||
					if expr is ast.BoolLiteral {
 | 
			
		||||
						if cond.val == (expr as ast.BoolLiteral).val {
 | 
			
		||||
							branch.exprs = [expr]
 | 
			
		||||
							original = ast.MatchExpr{
 | 
			
		||||
								...(*original)
 | 
			
		||||
								branches: [branch]
 | 
			
		||||
							}
 | 
			
		||||
							terminate = true
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				ast.IntegerLiteral {
 | 
			
		||||
					if expr is ast.IntegerLiteral {
 | 
			
		||||
						if cond.val.int() == (expr as ast.IntegerLiteral).val.int() {
 | 
			
		||||
							branch.exprs = [expr]
 | 
			
		||||
							original = ast.MatchExpr{
 | 
			
		||||
								...(*original)
 | 
			
		||||
								branches: [branch]
 | 
			
		||||
							}
 | 
			
		||||
							terminate = true
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				ast.FloatLiteral {
 | 
			
		||||
					if expr is ast.FloatLiteral {
 | 
			
		||||
						if cond.val.f32() == (expr as ast.FloatLiteral).val.f32() {
 | 
			
		||||
							branch.exprs = [expr]
 | 
			
		||||
							original = ast.MatchExpr{
 | 
			
		||||
								...(*original)
 | 
			
		||||
								branches: [branch]
 | 
			
		||||
							}
 | 
			
		||||
							terminate = true
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				ast.StringLiteral {
 | 
			
		||||
					if expr is ast.StringLiteral {
 | 
			
		||||
						if cond.val == (expr as ast.StringLiteral).val {
 | 
			
		||||
							branch.exprs = [expr]
 | 
			
		||||
							original = ast.MatchExpr{
 | 
			
		||||
								...(*original)
 | 
			
		||||
								branches: [branch]
 | 
			
		||||
							}
 | 
			
		||||
							terminate = true
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				else {}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if terminate {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return *original
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn (t Transformer) infix_expr(original ast.InfixExpr) ast.Expr {
 | 
			
		||||
	mut node := original
 | 
			
		||||
	node.left = t.expr(node.left)
 | 
			
		||||
| 
						 | 
				
			
			@ -357,6 +506,76 @@ pub fn (t Transformer) infix_expr(original ast.InfixExpr) ast.Expr {
 | 
			
		|||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		ast.FloatLiteral {
 | 
			
		||||
			match right_node {
 | 
			
		||||
				ast.FloatLiteral {
 | 
			
		||||
					left_val := left_node.val.f32()
 | 
			
		||||
					right_val := right_node.val.f32()
 | 
			
		||||
					match node.op {
 | 
			
		||||
						.eq {
 | 
			
		||||
							return ast.BoolLiteral{
 | 
			
		||||
								val: left_node.val == right_node.val
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						.ne {
 | 
			
		||||
							return ast.BoolLiteral{
 | 
			
		||||
								val: left_node.val != right_node.val
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						.gt {
 | 
			
		||||
							return ast.BoolLiteral{
 | 
			
		||||
								val: left_node.val > right_node.val
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						.ge {
 | 
			
		||||
							return ast.BoolLiteral{
 | 
			
		||||
								val: left_node.val >= right_node.val
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						.lt {
 | 
			
		||||
							return ast.BoolLiteral{
 | 
			
		||||
								val: left_node.val < right_node.val
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						.le {
 | 
			
		||||
							return ast.BoolLiteral{
 | 
			
		||||
								val: left_node.val <= right_node.val
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						.plus {
 | 
			
		||||
							return ast.FloatLiteral{
 | 
			
		||||
								val: (left_val + right_val).str()
 | 
			
		||||
								pos: pos
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						.mul {
 | 
			
		||||
							return ast.FloatLiteral{
 | 
			
		||||
								val: (left_val * right_val).str()
 | 
			
		||||
								pos: pos
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						.minus {
 | 
			
		||||
							return ast.FloatLiteral{
 | 
			
		||||
								val: (left_val - right_val).str()
 | 
			
		||||
								pos: pos
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						.div {
 | 
			
		||||
							return ast.FloatLiteral{
 | 
			
		||||
								val: (left_val / right_val).str()
 | 
			
		||||
								pos: pos
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
						else {
 | 
			
		||||
							return node
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				else {
 | 
			
		||||
					return node
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		else {
 | 
			
		||||
			return node
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue