cgen: support embedded struct fields on sumtype common fields (#11084)

pull/11092/head
Daniel Däschle 2021-08-06 20:26:19 +02:00 committed by GitHub
parent 25d49bc615
commit 3b116d2455
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 148 additions and 3 deletions

View File

@ -344,6 +344,28 @@ pub fn (t &Table) type_find_method(s &TypeSymbol, name string) ?Fn {
return none return none
} }
pub struct GetEmbedsOptions {
preceding []Type
}
// get_embeds returns all nested embedded structs
// the hierarchy of embeds is returned as a list
pub fn (t &Table) get_embeds(sym &TypeSymbol, options GetEmbedsOptions) [][]Type {
mut embeds := [][]Type{}
if sym.info is Struct {
for embed in sym.info.embeds {
embed_sym := t.get_type_symbol(embed)
mut preceding := options.preceding
preceding << embed
embeds << t.get_embeds(embed_sym, preceding: preceding)
}
if sym.info.embeds.len == 0 && options.preceding.len > 0 {
embeds << options.preceding
}
}
return embeds
}
pub fn (t &Table) type_find_method_from_embeds(sym &TypeSymbol, method_name string) ?(Fn, Type) { pub fn (t &Table) type_find_method_from_embeds(sym &TypeSymbol, method_name string) ?(Fn, Type) {
if sym.info is Struct { if sym.info is Struct {
mut found_methods := []Fn{} mut found_methods := []Fn{}
@ -407,6 +429,20 @@ pub fn (t &Table) struct_has_field(struct_ &TypeSymbol, name string) bool {
return false return false
} }
// struct_fields returns all fields including fields from embeds
// use this instead symbol.info.fields to get all fields
pub fn (t &Table) struct_fields(sym &TypeSymbol) []StructField {
mut fields := []StructField{}
if sym.info is Struct {
fields << sym.info.fields
for embed in sym.info.embeds {
embed_sym := t.get_type_symbol(embed)
fields << t.struct_fields(embed_sym)
}
}
return fields
}
// search from current type up through each parent looking for field // search from current type up through each parent looking for field
pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField { pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField {
// println('find_field($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx') // println('find_field($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx')
@ -447,6 +483,43 @@ pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField {
return none return none
} }
// find_field_from_embeds is the same as find_field_from_embeds but also looks into nested embeds
pub fn (t &Table) find_field_from_embeds_recursive(sym &TypeSymbol, field_name string) ?(StructField, []Type) {
if sym.info is Struct {
mut found_fields := []StructField{}
mut embeds_of_found_fields := [][]Type{}
for embed in sym.info.embeds {
embed_sym := t.get_type_symbol(embed)
if field := t.find_field(embed_sym, field_name) {
found_fields << field
embeds_of_found_fields << [embed]
} else {
field, types := t.find_field_from_embeds_recursive(embed_sym, field_name) or {
StructField{}, []Type{}
}
found_fields << field
embeds_of_found_fields << types
}
}
if found_fields.len == 1 {
return found_fields[0], embeds_of_found_fields[0]
} else if found_fields.len > 1 {
return error('ambiguous field `$field_name`')
}
} else if sym.info is Aggregate {
for typ in sym.info.types {
agg_sym := t.get_type_symbol(typ)
field, embed_types := t.find_field_from_embeds_recursive(agg_sym, field_name) or {
return err
}
if embed_types.len > 0 {
return field, embed_types
}
}
}
return none
}
// find_field_from_embeds finds and returns a field in the embeddings of a struct and the embedding type // find_field_from_embeds finds and returns a field in the embeddings of a struct and the embedding type
pub fn (t &Table) find_field_from_embeds(sym &TypeSymbol, field_name string) ?(StructField, Type) { pub fn (t &Table) find_field_from_embeds(sym &TypeSymbol, field_name string) ?(StructField, Type) {
if sym.info is Struct { if sym.info is Struct {
@ -500,7 +573,7 @@ pub fn (t &Table) resolve_common_sumtype_fields(sym_ &TypeSymbol) {
mut v_sym := t.get_type_symbol(variant) mut v_sym := t.get_type_symbol(variant)
fields := match mut v_sym.info { fields := match mut v_sym.info {
Struct { Struct {
v_sym.info.fields t.struct_fields(v_sym)
} }
SumType { SumType {
t.resolve_common_sumtype_fields(v_sym) t.resolve_common_sumtype_fields(v_sym)

View File

@ -1875,14 +1875,41 @@ fn (mut g Gen) write_sumtype_casting_fn(got_ ast.Type, exp_ ast.Type) {
got_cname, exp_cname := got_sym.cname, exp_sym.cname got_cname, exp_cname := got_sym.cname, exp_sym.cname
sb.writeln('static inline $exp_cname ${got_cname}_to_sumtype_${exp_cname}($got_cname* x) {') sb.writeln('static inline $exp_cname ${got_cname}_to_sumtype_${exp_cname}($got_cname* x) {')
sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));') sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));')
for embed_hierarchy in g.table.get_embeds(got_sym) {
// last embed in the hierarchy
mut embed_cname := ''
mut embed_name := ''
mut accessor := '&x->'
for j, embed in embed_hierarchy {
embed_sym := g.table.get_type_symbol(embed)
embed_cname = embed_sym.cname
embed_name = embed_sym.embed_name()
if j > 0 {
accessor += '.'
}
accessor += embed_name
}
// if the variable is not used, the C compiler will optimize it away
sb.writeln('\t$embed_cname* ${embed_name}_ptr = memdup($accessor, sizeof($embed_cname));')
}
sb.write_string('\treturn ($exp_cname){ ._$got_cname = ptr, ._typ = ${g.type_sidx(got)}') sb.write_string('\treturn ($exp_cname){ ._$got_cname = ptr, ._typ = ${g.type_sidx(got)}')
for field in (exp_sym.info as ast.SumType).fields { for field in (exp_sym.info as ast.SumType).fields {
mut ptr := 'ptr'
mut type_cname := got_cname
_, embed_types := g.table.find_field_from_embeds_recursive(got_sym, field.name) or {
ast.StructField{}, []ast.Type{}
}
if embed_types.len > 0 {
embed_sym := g.table.get_type_symbol(embed_types.last())
ptr = '${embed_sym.embed_name()}_ptr'
type_cname = embed_sym.cname
}
field_styp := g.typ(field.typ) field_styp := g.typ(field.typ)
if got_sym.kind in [.sum_type, .interface_] { if got_sym.kind in [.sum_type, .interface_] {
// the field is already a wrapped pointer; we shouldn't wrap it once again // the field is already a wrapped pointer; we shouldn't wrap it once again
sb.write_string(', .$field.name = ptr->$field.name') sb.write_string(', .$field.name = ptr->$field.name')
} else { } else {
sb.write_string(', .$field.name = ($field_styp*)((char*)ptr + __offsetof_ptr(ptr, $got_cname, $field.name))') sb.write_string(', .$field.name = ($field_styp*)((char*)$ptr + __offsetof_ptr($ptr, $type_cname, $field.name))')
} }
} }
sb.writeln('};\n}') sb.writeln('};\n}')

View File

@ -107,7 +107,7 @@ fn test_converting_down() {
} }
fn test_assignment_and_push() { fn test_assignment_and_push() {
mut expr1 := Expr{} mut expr1 := Expr(IfExpr{})
mut arr1 := []Expr{} mut arr1 := []Expr{}
expr := IntegerLiteral{ expr := IntegerLiteral{
val: '111' val: '111'
@ -712,3 +712,48 @@ fn test_binary_search_tree() {
deleted.sort() deleted.sort()
assert deleted == [0.0, 0.3, 0.6, 1.0] assert deleted == [0.0, 0.3, 0.6, 1.0]
} }
struct Common {
a int
b int
}
struct Common2 {
Common
}
struct Aa {
Common
x int
}
struct Bb {
Common
x int
}
struct Cc {
a int
}
struct Dd {
Common2
}
type MySum = Aa | Bb | Cc | Dd
fn test_sumtype_access_embed_fields() {
a := MySum(Aa{
a: 1
})
assert a.a == 1
}
fn test_sumtype_access_nested_embed_fields() {
a := MySum(Dd{
Common2: Common2{
a: 2
}
})
assert a.a == 2
}