diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 14fe732621..48fdb09d31 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -2635,6 +2635,13 @@ pub fn (mut c Checker) if_expr(mut node ast.IfExpr) table.Type { is_used: true is_mut: left_expr.is_mut }) + scope.register(left_expr.name, ast.Var{ + name: left_expr.name + typ: right_expr.typ.to_ptr() + pos: left_expr.pos + is_used: true + is_mut: left_expr.is_mut + }) node.branches[i].smartcast = true } } diff --git a/vlib/v/gen/cgen.v b/vlib/v/gen/cgen.v index 092591e9e7..551a3b5421 100644 --- a/vlib/v/gen/cgen.v +++ b/vlib/v/gen/cgen.v @@ -963,9 +963,17 @@ fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type, expected_type table.Type) got_styp := g.typ(got_type) exp_styp := g.typ(expected_type) got_idx := got_type.idx() - g.write('/* sum type cast */ ($exp_styp) {.obj = memdup(&($got_styp[]) {') - g.expr(expr) - g.write('}, sizeof($got_styp)), .typ = $got_idx /* $got_sym.name */}') + if got_type.is_ptr() { + g.write('/* sum type cast */ ($exp_styp) {.obj = ') + g.expr(expr) + g.write(', .typ = $got_idx /* $got_sym.name */}') + } + else { + g.write('/* sum type cast */ ($exp_styp) {.obj = memdup(&($got_styp[]) {') + g.expr(expr) + g.write('}, sizeof($got_styp)), .typ = $got_idx /* $got_sym.name */}') + } + return } } @@ -2314,6 +2322,7 @@ fn (mut g Gen) if_expr(node ast.IfExpr) { infix := branch.cond as ast.InfixExpr right_type := infix.right as ast.Type left_type := infix.left_type + left_expr := infix.left as ast.Ident it_type := g.typ(right_type.typ) g.write('\t$it_type* it = ($it_type*)') g.expr(infix.left) @@ -2323,6 +2332,7 @@ fn (mut g Gen) if_expr(node ast.IfExpr) { g.write('.') } g.writeln('obj;') + g.writeln('\t$it_type* $left_expr.name = it;') } g.stmts(branch.stmts) } diff --git a/vlib/v/tests/sum_type_test.v b/vlib/v/tests/sum_type_test.v index be4f21d7a7..45a1baf523 100644 --- a/vlib/v/tests/sum_type_test.v +++ b/vlib/v/tests/sum_type_test.v @@ -110,7 +110,7 @@ fn test_nested_sumtype() { mut a := Node{} mut b := Node{} a = StructDecl{pos: 1} - b = IfExpr {pos: 1} + b = IfExpr{pos: 1} match a { StructDecl { assert true @@ -127,6 +127,18 @@ fn test_nested_sumtype() { else { assert false } + c := Node(Expr(IfExpr{pos:1})) + if c is Expr { + if c is IfExpr { + assert true + } + else { + assert false + } + } + else { + assert false + } } type Abc = int | string