checker: allow accessing fields common to all sumtype members (#9201)

pull/9221/head
spaceface 2021-03-09 18:16:18 +01:00 committed by GitHub
parent eed6f7dbff
commit f1469a8761
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 330 additions and 69 deletions

View File

@ -879,6 +879,7 @@ pub:
is_pub bool
pos token.Position
comments []Comment
typ table.Type
pub mut:
variants []SumTypeVariant
}

View File

@ -2293,7 +2293,7 @@ pub fn (mut c Checker) selector_expr(mut selector_expr ast.SelectorExpr) table.T
c.error('ambiguous field `$field_name`', selector_expr.pos)
}
}
if sym.kind == .aggregate {
if sym.kind in [.aggregate, .sum_type] {
unknown_field_msg = err
}
}
@ -2321,7 +2321,7 @@ pub fn (mut c Checker) selector_expr(mut selector_expr ast.SelectorExpr) table.T
selector_expr.typ = field.typ
return field.typ
}
if sym.kind !in [.struct_, .aggregate, .interface_] {
if sym.kind !in [.struct_, .aggregate, .interface_, .sum_type] {
if sym.kind != .placeholder {
c.error('`$sym.name` is not a struct', selector_expr.pos)
}

View File

@ -0,0 +1,20 @@
vlib/v/checker/tests/sum_type_common_fields_alias_error.vv:35:14: error: field `name` does not exist or have the same type in all sumtype variants
33 | }
34 | println(m)
35 | assert m[0].name == 'abc'
| ~~~~
36 | assert m[1].name == 'def'
37 | assert m[2].name == 'xyz'
vlib/v/checker/tests/sum_type_common_fields_alias_error.vv:36:14: error: field `name` does not exist or have the same type in all sumtype variants
34 | println(m)
35 | assert m[0].name == 'abc'
36 | assert m[1].name == 'def'
| ~~~~
37 | assert m[2].name == 'xyz'
38 | }
vlib/v/checker/tests/sum_type_common_fields_alias_error.vv:37:14: error: field `name` does not exist or have the same type in all sumtype variants
35 | assert m[0].name == 'abc'
36 | assert m[1].name == 'def'
37 | assert m[2].name == 'xyz'
| ~~~~
38 | }

View File

@ -0,0 +1,38 @@
type Main = Sub1 | Sub2 | Sub3
// NB: the subtypes will have a common `name` field, of the same `string`
// type, except Sub3, which has `name` of type AliasedString.
type AliasedString = string
struct Sub1 {
mut:
name string
}
struct Sub2 {
mut:
name string
}
struct Sub3 {
mut:
name AliasedString
}
fn main() {
mut m := []Main{}
m << Sub1{
name: 'abc'
}
m << Sub2{
name: 'def'
}
m << Sub3{
name: 'xyz'
}
println(m)
assert m[0].name == 'abc'
assert m[1].name == 'def'
assert m[2].name == 'xyz'
}

View File

@ -0,0 +1,6 @@
vlib/v/checker/tests/sum_type_common_fields_error.vv:53:14: error: field `val` does not exist or have the same type in all sumtype variants
51 | assert m[2].name == '64bit integer'
52 | assert m[3].name == 'string'
53 | assert m[0].val == 123
| ~~~
54 | }

View File

@ -0,0 +1,54 @@
type Main = Sub1 | Sub2 | Sub3 | Sub4
// NB: all subtypes have a common name field, of the same `string` type
// but they also have a field `val` that is of a different type in the
// different subtypes => accessing `m[0].name` is fine, but *not* `m[0].val`
struct Sub1 {
mut:
val int
name string
}
struct Sub2 {
mut:
val f32
name string
}
struct Sub3 {
mut:
val i64
name string
}
struct Sub4 {
mut:
val string
name string
}
fn main() {
mut m := []Main{}
m << Sub1{
val: 123
name: 'integer'
}
m << Sub2{
val: 3.14
name: 'float'
}
m << Sub3{
val: 9_876_543_210
name: '64bit integer'
}
m << Sub4{
val: 'abcd'
name: 'string'
}
println(m)
assert m[0].name == 'integer'
assert m[1].name == 'float'
assert m[2].name == '64bit integer'
assert m[3].name == 'string'
assert m[0].val == 123
}

View File

@ -677,7 +677,7 @@ fn (mut g Gen) gen_str_for_union_sum_type(info table.SumType, styp string, str_f
clean_sum_type_v_type_name = '&' + clean_sum_type_v_type_name.replace('*', '')
}
clean_sum_type_v_type_name = util.strip_main_name(clean_sum_type_v_type_name)
g.auto_str_funcs.writeln('\tswitch(x.typ) {')
g.auto_str_funcs.writeln('\tswitch(x._typ) {')
for typ in info.variants {
mut value_fmt := '%.*s\\000'
if typ == table.string_type {

View File

@ -94,15 +94,16 @@ mut:
defer_stmts []ast.DeferStmt
defer_ifdef string
defer_profile_code string
str_types []string // types that need automatic str() generation
threaded_fns []string // for generating unique wrapper types and fns for `go xxx()`
waiter_fns []string // functions that wait for `go xxx()` to finish
array_fn_definitions []string // array equality functions that have been defined
map_fn_definitions []string // map equality functions that have been defined
struct_fn_definitions []string // struct equality functions that have been defined
alias_fn_definitions []string // alias equality functions that have been defined
auto_fn_definitions []string // auto generated functions defination list
anon_fn_definitions []string // anon generated functions defination list
str_types []string // types that need automatic str() generation
threaded_fns []string // for generating unique wrapper types and fns for `go xxx()`
waiter_fns []string // functions that wait for `go xxx()` to finish
array_fn_definitions []string // array equality functions that have been defined
map_fn_definitions []string // map equality functions that have been defined
struct_fn_definitions []string // struct equality functions that have been defined
alias_fn_definitions []string // alias equality functions that have been defined
auto_fn_definitions []string // auto generated functions defination list
anon_fn_definitions []string // anon generated functions defination list
sumtype_definitions map[int]bool // `_TypeA_to_sumtype_TypeB()` fns that have been generated
is_json_fn bool // inside json.encode()
json_types []string // to avoid json gen duplicates
pcs []ProfileCounterMeta // -prof profile counter fn_names => fn counter name
@ -1553,6 +1554,33 @@ fn (mut g Gen) for_in_stmt(node ast.ForInStmt) {
}
}
fn (mut g Gen) write_sumtype_casting_fn(got_ table.Type, exp_ table.Type) {
got, exp := got_.idx(), exp_.idx()
i := got | (exp << 16)
if got == exp || g.sumtype_definitions[i] {
return
}
g.sumtype_definitions[i] = true
got_sym := g.table.get_type_symbol(got)
exp_sym := g.table.get_type_symbol(exp)
mut sb := strings.new_builder(128)
got_cname, exp_cname := got_sym.cname, exp_sym.cname
sb.writeln('static inline $exp_cname ${got_cname}_to_sumtype_${exp_cname}($got_cname* x) {')
sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));')
sb.write_string('\treturn ($exp_cname){ ._$got_cname = ptr, ._typ = ${g.type_sidx(got)}')
for field in (exp_sym.info as table.SumType).fields {
field_styp := g.typ(field.typ)
if got_sym.kind in [.sum_type, .interface_] {
// the field is already a wrapped pointer; we shouldn't wrap it once again
sb.write_string(', .$field.name = ptr->$field.name')
} else {
sb.write_string(', .$field.name = ($field_styp*)((char*)ptr + __offsetof($got_cname, $field.name))')
}
}
sb.writeln('};\n}')
g.auto_fn_definitions << sb.str()
}
// use instead of expr() when you need to cast to a different type
fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw table.Type, expected_type table.Type) {
got_type := g.table.mktyp(got_type_raw)
@ -1590,45 +1618,39 @@ fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw table.Type, expected_t
expected_deref_type := if expected_is_ptr { expected_type.deref() } else { expected_type }
got_deref_type := if got_is_ptr { got_type.deref() } else { got_type }
if g.table.sumtype_has_variant(expected_deref_type, got_deref_type) {
// got_idx := got_type.idx()
got_sidx := g.type_sidx(got_type)
// TODO: do we need 1-3?
if expected_is_ptr && got_is_ptr {
exp_der_styp := g.typ(expected_deref_type)
g.write('/* sum type cast 1 */ ($exp_styp) memdup(&($exp_der_styp){._$got_sym.cname = ')
g.expr(expr)
g.write(', .typ = $got_sidx /* $got_sym.name */}, sizeof($exp_der_styp))')
} else if expected_is_ptr {
exp_der_styp := g.typ(expected_deref_type)
g.write('/* sum type cast 2 */ ($exp_styp) memdup(&($exp_der_styp){._$got_sym.cname = memdup(&($got_styp[]){')
g.expr(expr)
g.write('}, sizeof($got_styp)), .typ = $got_sidx /* $got_sym.name */}, sizeof($exp_der_styp))')
} else if got_is_ptr {
g.write('/* sum type cast 3 */ ($exp_styp){._$got_sym.cname = ')
g.expr(expr)
g.write(', .typ = $got_sidx /* $got_sym.name */}')
} else {
mut is_already_sum_type := false
scope := g.file.scope.innermost(expr.position().pos)
if expr is ast.Ident {
if v := scope.find_var(expr.name) {
if v.sum_type_casts.len > 0 {
is_already_sum_type = true
}
}
} else if expr is ast.SelectorExpr {
if _ := scope.find_struct_field(expr.expr_type, expr.field_name) {
mut is_already_sum_type := false
scope := g.file.scope.innermost(expr.position().pos)
if expr is ast.Ident {
if v := scope.find_var(expr.name) {
if v.sum_type_casts.len > 0 {
is_already_sum_type = true
}
}
if is_already_sum_type {
// Don't create a new sum type wrapper if there is already one
g.prevent_sum_type_unwrapping_once = true
} else if expr is ast.SelectorExpr {
if _ := scope.find_struct_field(expr.expr_type, expr.field_name) {
is_already_sum_type = true
}
}
if is_already_sum_type {
// Don't create a new sum type wrapper if there is already one
g.prevent_sum_type_unwrapping_once = true
g.expr(expr)
} else {
g.write_sumtype_casting_fn(got_type, expected_type)
if expected_is_ptr {
g.write('memdup(&($exp_sym.cname[]){')
}
g.write('${got_sym.cname}_to_sumtype_${exp_sym.cname}(')
if !got_is_ptr {
g.write('(&(($got_styp[]){')
g.expr(expr)
g.write('}[0])))')
} else {
g.write('/* sum type cast 4 */ ($exp_styp){._$got_sym.cname = memdup(&($got_styp[]){')
g.expr(expr)
g.write('}, sizeof($got_styp)), .typ = $got_sidx /* $got_sym.name */}')
g.write(')')
}
if expected_is_ptr {
g.write('}, sizeof($exp_sym.cname))')
}
}
return
@ -2707,7 +2729,7 @@ fn (mut g Gen) expr(node ast.Expr) {
g.concat_expr(node)
}
ast.CTempVar {
// g.write('/*ctmp .orig: $node.orig.str() , .typ: $node.typ, .is_ptr: $node.is_ptr */ ')
// g.write('/*ctmp .orig: $node.orig.str() , ._typ: $node.typ, .is_ptr: $node.is_ptr */ ')
g.write(node.name)
}
ast.EnumVal {
@ -2923,7 +2945,7 @@ fn (mut g Gen) typeof_expr(node ast.TypeOf) {
// because the subtype of the expression may change:
g.write('tos3( /* $sym.name */ v_typeof_sumtype_${sym.cname}( (')
g.expr(node.expr)
g.write(').typ ))')
g.write(')._typ ))')
} else if sym.kind == .array_fixed {
fixed_info := sym.info as table.ArrayFixed
typ_name := g.table.get_type_name(fixed_info.elem_type)
@ -2956,7 +2978,7 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) {
opt_base_typ := g.base_type(node.expr_type)
g.writeln('(*($opt_base_typ*)')
}
if sym.kind == .interface_ {
if sym.kind in [.interface_, .sum_type] {
g.write('(*(')
}
if sym.kind == .array_fixed {
@ -3037,7 +3059,7 @@ fn (mut g Gen) selector_expr(node ast.SelectorExpr) {
if sum_type_deref_field != '' {
g.write('$sum_type_dot$sum_type_deref_field)')
}
if sym.kind == .interface_ {
if sym.kind in [.interface_, .sum_type] {
g.write('))')
}
}
@ -3678,7 +3700,7 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str
g.write(cond_var)
dot_or_ptr := if node.cond_type.is_ptr() { '->' } else { '.' }
if sym.kind == .sum_type {
g.write('${dot_or_ptr}typ == ')
g.write('${dot_or_ptr}_typ == ')
g.expr(branch.exprs[sumtype_index])
} else if sym.kind == .interface_ {
typ := branch.exprs[sumtype_index] as ast.Type
@ -5074,13 +5096,19 @@ fn (mut g Gen) write_types(types []table.TypeSymbol) {
g.type_definitions.writeln('// | ${variant:4d} = ${g.typ(variant.idx()):-20s}')
}
g.type_definitions.writeln('struct $name {')
g.type_definitions.writeln(' union {')
g.type_definitions.writeln('\tunion {')
for variant in typ.info.variants {
variant_sym := g.table.get_type_symbol(variant)
g.type_definitions.writeln(' ${g.typ(variant.to_ptr())} _$variant_sym.cname;')
g.type_definitions.writeln('\t\t${g.typ(variant.to_ptr())} _$variant_sym.cname;')
}
g.type_definitions.writeln('\t};')
g.type_definitions.writeln('\tint _typ;')
if typ.info.fields.len > 0 {
g.writeln('\t// pointers to common sumtype fields')
for field in typ.info.fields {
g.type_definitions.writeln('\t${g.typ(field.typ.to_ptr())} $field.name;')
}
}
g.type_definitions.writeln(' };')
g.type_definitions.writeln(' int typ;')
g.type_definitions.writeln('};')
g.type_definitions.writeln('')
}
@ -5691,7 +5719,7 @@ fn (mut g Gen) as_cast(node ast.AsCast) {
// g.write('typ, /*expected:*/$node.typ)')
sidx := g.type_sidx(node.typ)
expected_sym := g.table.get_type_symbol(node.typ)
g.write('typ, $sidx) /*expected idx: $sidx, name: $expected_sym.name */ ')
g.write('_typ, $sidx) /*expected idx: $sidx, name: $expected_sym.name */ ')
// fill as cast name table
for variant in expr_type_sym.info.variants {
@ -5739,7 +5767,7 @@ fn (mut g Gen) is_expr(node ast.InfixExpr) {
g.write('_${c_name(sym.name)}_${c_name(sub_sym.name)}_index')
return
} else if sym.kind == .sum_type {
g.write('typ $eq ')
g.write('_typ $eq ')
}
g.expr(node.right)
}

View File

@ -542,7 +542,7 @@ fn (mut g Gen) method_call(node ast.CallExpr) {
g.write('tos3( /* $left_sym.name */ v_typeof_sumtype_${typ_sym.cname}( (')
g.expr(node.left)
dot := if node.left_type.is_ptr() { '->' } else { '.' }
g.write(')${dot}typ ))')
g.write(')${dot}_typ ))')
return
}
if left_sym.kind == .interface_ && node.name == 'type_name' {

View File

@ -2318,7 +2318,7 @@ fn (mut p Parser) type_decl() ast.TypeDecl {
}
variant_types := sum_variants.map(it.typ)
prepend_mod_name := p.prepend_mod(name)
p.table.register_type_symbol(table.TypeSymbol{
typ := p.table.register_type_symbol(table.TypeSymbol{
kind: .sum_type
name: prepend_mod_name
cname: util.no_dots(prepend_mod_name)
@ -2331,6 +2331,7 @@ fn (mut p Parser) type_decl() ast.TypeDecl {
comments = p.eat_comments(same_line: true)
return ast.SumTypeDecl{
name: name
typ: typ
is_pub: is_pub
variants: sum_variants
pos: decl_pos

View File

@ -277,20 +277,32 @@ pub fn (t &Table) find_field(s &TypeSymbol, name string) ?Field {
// println('find_field($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx')
mut ts := s
for {
if mut ts.info is Struct {
if field := ts.info.find_field(name) {
match mut ts.info {
Struct {
if field := ts.info.find_field(name) {
return field
}
}
Aggregate {
if field := ts.info.find_field(name) {
return field
}
field := t.register_aggregate_field(mut ts, name) or { return err }
return field
}
} else if mut ts.info is Aggregate {
if field := ts.info.find_field(name) {
return field
Interface {
if field := ts.info.find_field(name) {
return field
}
}
field := t.register_aggregate_field(mut ts, name) or { return err }
return field
} else if mut ts.info is Interface {
if field := ts.info.find_field(name) {
return field
SumType {
t.resolve_common_sumtype_fields(s)
if field := ts.info.find_field(name) {
return field
}
return error('field `$name` does not exist or have the same type in all sumtype variants')
}
else {}
}
if ts.parent_idx == 0 {
break
@ -326,6 +338,46 @@ pub fn (t &Table) find_field_with_embeds(sym &TypeSymbol, field_name string) ?Fi
}
}
pub fn (t &Table) resolve_common_sumtype_fields(sym_ &TypeSymbol) {
mut sym := sym_
mut info := sym.info as SumType
if info.found_fields {
return
}
mut field_map := map[string]Field{}
mut field_usages := map[string]int{}
for variant in info.variants {
mut v_sym := t.get_type_symbol(variant)
fields := match mut v_sym.info {
Struct {
v_sym.info.fields
}
SumType {
t.resolve_common_sumtype_fields(v_sym)
v_sym.info.fields
}
else {
[]Field{}
}
}
for field in fields {
if field.name !in field_map {
field_map[field.name] = field
field_usages[field.name]++
} else if field.equals(field_map[field.name]) {
field_usages[field.name]++
}
}
}
for field, nr_definitions in field_usages {
if nr_definitions == info.variants.len {
info.fields << field_map[field]
}
}
info.found_fields = true
sym.info = info
}
[inline]
pub fn (t &Table) find_type_idx(name string) int {
return t.type_idxs[name]

View File

@ -713,7 +713,7 @@ pub mut:
is_global bool
}
fn (f &Field) equals(o &Field) bool {
pub fn (f &Field) equals(o &Field) bool {
// TODO: f.is_mut == o.is_mut was removed here to allow read only access
// to (mut/not mut), but otherwise equal fields; some other new checks are needed:
// - if node is declared mut, and we mutate node.stmts, all stmts fields must be mutable
@ -755,6 +755,9 @@ pub mut:
pub struct SumType {
pub:
variants []Type
pub mut:
fields []Field
found_fields bool
}
// human readable type name
@ -971,6 +974,7 @@ pub fn (t &TypeSymbol) find_field(name string) ?Field {
Aggregate { return t.info.find_field(name) }
Struct { return t.info.find_field(name) }
Interface { return t.info.find_field(name) }
SumType { return t.info.find_field(name) }
else { return none }
}
}
@ -1025,6 +1029,15 @@ pub fn (s Struct) get_field(name string) Field {
panic('unknown field `$name`')
}
pub fn (s &SumType) find_field(name string) ?Field {
for field in s.fields {
if field.name == name {
return field
}
}
return none
}
pub fn (i Interface) defines_method(name string) bool {
for method in i.methods {
if method.name == name {

View File

@ -0,0 +1,48 @@
type Master = Sub1 | Sub2
struct Sub1 {
mut:
val int
name string
}
struct Sub2 {
name string
val int
}
struct Sub3 {
name string
val int
}
type Master2 = Master | Sub3
fn test_common_sumtype_field_access() {
mut out := []Master{}
out << Sub1{
val: 1
name: 'one'
}
out << Sub2{
val: 2
name: 'two'
}
out << Sub2{
val: 3
name: 'three'
}
assert out[0].val == 1
assert out[0].name == 'one'
assert out[1].val == 2
assert out[1].name == 'two'
assert out[2].val == 3
assert out[2].name == 'three'
mut out0 := Master2(out[0]) // common fields on a doubly-wrapped sumtype
assert out0.val == 1
assert out0.name == 'one'
}