diff --git a/vlib/v/ast/ast.v b/vlib/v/ast/ast.v index c74a00e116..7975c5a01b 100644 --- a/vlib/v/ast/ast.v +++ b/vlib/v/ast/ast.v @@ -423,10 +423,15 @@ pub: is_arg bool // fn args should not be autofreed pub mut: typ table.Type + orig_type table.Type // original sumtype type; 0 if it's not a sumtype sum_type_casts []table.Type // nested sum types require nested smart casting, for that a list of types is needed - pos token.Position - is_used bool - is_changed bool // to detect mutable vars that are never changed + // TODO: move this to a real docs site later + // 10 <- original type (orig_type) + // [11, 12, 13] <- cast order (sum_type_casts) + // 12 <- the current casted type (typ) + pos token.Position + is_used bool + is_changed bool // to detect mutable vars that are never changed // // (for setting the position after the or block for autofree) is_or bool // `x := foo() or { ... }` @@ -442,6 +447,11 @@ pub: pos token.Position typ table.Type sum_type_casts []table.Type // nested sum types require nested smart casting, for that a list of types is needed + orig_type table.Type // original sumtype type; 0 if it's not a sumtype + // TODO: move this to a real docs site later + // 10 <- original type (orig_type) + // [11, 12, 13] <- cast order (sum_type_casts) + // 12 <- the current casted type (typ) } pub struct GlobalField { diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 056e93368d..7ec56ebcfa 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -4150,6 +4150,7 @@ fn (c Checker) smartcast_sumtype(expr ast.Expr, cur_type table.Type, to_type tab mut is_mut := false mut sum_type_casts := []table.Type{} expr_sym := c.table.get_type_symbol(expr.expr_type) + mut orig_type := 0 if field := c.table.struct_find_field(expr_sym, expr.field_name) { if field.is_mut { root_ident := expr.root_ident() @@ -4157,6 +4158,9 @@ fn (c Checker) smartcast_sumtype(expr ast.Expr, cur_type table.Type, to_type tab is_mut = v.is_mut } } + if orig_type == 0 { + orig_type = field.typ + } } if field := scope.find_struct_field(expr.expr_type, expr.field_name) { sum_type_casts << field.sum_type_casts @@ -4170,6 +4174,7 @@ fn (c Checker) smartcast_sumtype(expr ast.Expr, cur_type table.Type, to_type tab typ: cur_type sum_type_casts: sum_type_casts pos: expr.pos + orig_type: orig_type }) } } @@ -4177,11 +4182,14 @@ fn (c Checker) smartcast_sumtype(expr ast.Expr, cur_type table.Type, to_type tab mut is_mut := false mut sum_type_casts := []table.Type{} mut is_already_casted := false - if expr.obj is ast.Var { - v := expr.obj as ast.Var - is_mut = v.is_mut - sum_type_casts << v.sum_type_casts - is_already_casted = v.pos.pos == expr.pos.pos + mut orig_type := 0 + if mut expr.obj is ast.Var { + is_mut = expr.obj.is_mut + sum_type_casts << expr.obj.sum_type_casts + is_already_casted = expr.obj.pos.pos == expr.pos.pos + if orig_type == 0 { + orig_type = expr.obj.typ + } } // smartcast either if the value is immutable or if the mut argument is explicitly given if (!is_mut || expr.is_mut) && !is_already_casted { @@ -4193,6 +4201,7 @@ fn (c Checker) smartcast_sumtype(expr ast.Expr, cur_type table.Type, to_type tab is_used: true is_mut: expr.is_mut sum_type_casts: sum_type_casts + orig_type: orig_type }) } } diff --git a/vlib/v/gen/cgen.v b/vlib/v/gen/cgen.v index b26d20b84a..5982b6d374 100644 --- a/vlib/v/gen/cgen.v +++ b/vlib/v/gen/cgen.v @@ -2939,6 +2939,7 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) { return } mut sum_type_deref_field := '' + mut sum_type_dot := '.' if f := g.table.struct_find_field(sym, node.field_name) { field_sym := g.table.get_type_symbol(f.typ) if field_sym.kind == .sum_type { @@ -2946,6 +2947,9 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) { // check first if field is sum type because scope searching is expensive scope := g.file.scope.innermost(node.pos.pos) if field := scope.find_struct_field(node.expr_type, node.field_name) { + if field.orig_type.is_ptr() { + sum_type_dot = '->' + } // union sum type deref for i, typ in field.sum_type_casts { g.write('(*') @@ -2957,7 +2961,6 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) { if mut cast_sym.info is table.Aggregate { agg_sym := g.table.get_type_symbol(cast_sym.info.types[g.aggregate_type_idx]) sum_type_deref_field += '_$agg_sym.cname' - // sum_type_deref_field += '_${cast_sym.info.types[g.aggregate_type_idx]}' } else { sum_type_deref_field += '_$cast_sym.cname' } @@ -2997,7 +3000,7 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) { } g.write(c_name(node.field_name)) if sum_type_deref_field != '' { - g.write('.$sum_type_deref_field)') + g.write('$sum_type_dot$sum_type_deref_field)') } } @@ -3782,10 +3785,14 @@ fn (mut g Gen) ident(node ast.Ident) { } for i, typ in v.sum_type_casts { cast_sym := g.table.get_type_symbol(typ) + mut is_ptr := false if i == 0 { g.write(name) + if v.orig_type.is_ptr() { + is_ptr = true + } } - dot := if v.typ.is_ptr() { '->' } else { '.' } + dot := if is_ptr { '->' } else { '.' } if mut cast_sym.info is table.Aggregate { sym := g.table.get_type_symbol(cast_sym.info.types[g.aggregate_type_idx]) g.write('${dot}_$sym.cname') diff --git a/vlib/v/tests/if_smartcast_test.v b/vlib/v/tests/if_smartcast_test.v index d9dd4625f7..14b6306bb8 100644 --- a/vlib/v/tests/if_smartcast_test.v +++ b/vlib/v/tests/if_smartcast_test.v @@ -253,3 +253,51 @@ fn test_nested_sumtype_selector() { assert false } } + +struct Foo1 { + a int +} + +struct Foo2 { + a int +} + +struct Bar1 { + a int +} + +struct Bar2 { + a int +} + +type Sum1 = Foo1 | Foo2 + +type Sum2 = Bar1 | Bar2 + +type SumAll = Sum1 | Sum2 + +struct All_in_one { +pub mut: + ptrs []&SumAll + ptr &SumAll +} + +fn test_nested_pointer_smartcast() { + mut s := All_in_one{ + ptr: &Sum1(Foo1{a: 1}) + ptrs: [&SumAll(Sum2(Bar1{a: 3}))] + } + + if mut s.ptr is Sum1 { + if mut s.ptr is Foo1 { + assert s.ptr.a == 1 + } + } + + a := s.ptrs[0] + if a is Sum1 { + if a is Foo1{ + assert a.a == 3 + } + } +}