cgen: support embedded struct fields on sumtype common fields (#11084)
parent
25d49bc615
commit
3b116d2455
|
@ -344,6 +344,28 @@ pub fn (t &Table) type_find_method(s &TypeSymbol, name string) ?Fn {
|
|||
return none
|
||||
}
|
||||
|
||||
pub struct GetEmbedsOptions {
|
||||
preceding []Type
|
||||
}
|
||||
|
||||
// get_embeds returns all nested embedded structs
|
||||
// the hierarchy of embeds is returned as a list
|
||||
pub fn (t &Table) get_embeds(sym &TypeSymbol, options GetEmbedsOptions) [][]Type {
|
||||
mut embeds := [][]Type{}
|
||||
if sym.info is Struct {
|
||||
for embed in sym.info.embeds {
|
||||
embed_sym := t.get_type_symbol(embed)
|
||||
mut preceding := options.preceding
|
||||
preceding << embed
|
||||
embeds << t.get_embeds(embed_sym, preceding: preceding)
|
||||
}
|
||||
if sym.info.embeds.len == 0 && options.preceding.len > 0 {
|
||||
embeds << options.preceding
|
||||
}
|
||||
}
|
||||
return embeds
|
||||
}
|
||||
|
||||
pub fn (t &Table) type_find_method_from_embeds(sym &TypeSymbol, method_name string) ?(Fn, Type) {
|
||||
if sym.info is Struct {
|
||||
mut found_methods := []Fn{}
|
||||
|
@ -407,6 +429,20 @@ pub fn (t &Table) struct_has_field(struct_ &TypeSymbol, name string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// struct_fields returns all fields including fields from embeds
|
||||
// use this instead symbol.info.fields to get all fields
|
||||
pub fn (t &Table) struct_fields(sym &TypeSymbol) []StructField {
|
||||
mut fields := []StructField{}
|
||||
if sym.info is Struct {
|
||||
fields << sym.info.fields
|
||||
for embed in sym.info.embeds {
|
||||
embed_sym := t.get_type_symbol(embed)
|
||||
fields << t.struct_fields(embed_sym)
|
||||
}
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// search from current type up through each parent looking for field
|
||||
pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField {
|
||||
// println('find_field($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx')
|
||||
|
@ -447,6 +483,43 @@ pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField {
|
|||
return none
|
||||
}
|
||||
|
||||
// find_field_from_embeds is the same as find_field_from_embeds but also looks into nested embeds
|
||||
pub fn (t &Table) find_field_from_embeds_recursive(sym &TypeSymbol, field_name string) ?(StructField, []Type) {
|
||||
if sym.info is Struct {
|
||||
mut found_fields := []StructField{}
|
||||
mut embeds_of_found_fields := [][]Type{}
|
||||
for embed in sym.info.embeds {
|
||||
embed_sym := t.get_type_symbol(embed)
|
||||
if field := t.find_field(embed_sym, field_name) {
|
||||
found_fields << field
|
||||
embeds_of_found_fields << [embed]
|
||||
} else {
|
||||
field, types := t.find_field_from_embeds_recursive(embed_sym, field_name) or {
|
||||
StructField{}, []Type{}
|
||||
}
|
||||
found_fields << field
|
||||
embeds_of_found_fields << types
|
||||
}
|
||||
}
|
||||
if found_fields.len == 1 {
|
||||
return found_fields[0], embeds_of_found_fields[0]
|
||||
} else if found_fields.len > 1 {
|
||||
return error('ambiguous field `$field_name`')
|
||||
}
|
||||
} else if sym.info is Aggregate {
|
||||
for typ in sym.info.types {
|
||||
agg_sym := t.get_type_symbol(typ)
|
||||
field, embed_types := t.find_field_from_embeds_recursive(agg_sym, field_name) or {
|
||||
return err
|
||||
}
|
||||
if embed_types.len > 0 {
|
||||
return field, embed_types
|
||||
}
|
||||
}
|
||||
}
|
||||
return none
|
||||
}
|
||||
|
||||
// find_field_from_embeds finds and returns a field in the embeddings of a struct and the embedding type
|
||||
pub fn (t &Table) find_field_from_embeds(sym &TypeSymbol, field_name string) ?(StructField, Type) {
|
||||
if sym.info is Struct {
|
||||
|
@ -500,7 +573,7 @@ pub fn (t &Table) resolve_common_sumtype_fields(sym_ &TypeSymbol) {
|
|||
mut v_sym := t.get_type_symbol(variant)
|
||||
fields := match mut v_sym.info {
|
||||
Struct {
|
||||
v_sym.info.fields
|
||||
t.struct_fields(v_sym)
|
||||
}
|
||||
SumType {
|
||||
t.resolve_common_sumtype_fields(v_sym)
|
||||
|
|
|
@ -1875,14 +1875,41 @@ fn (mut g Gen) write_sumtype_casting_fn(got_ ast.Type, exp_ ast.Type) {
|
|||
got_cname, exp_cname := got_sym.cname, exp_sym.cname
|
||||
sb.writeln('static inline $exp_cname ${got_cname}_to_sumtype_${exp_cname}($got_cname* x) {')
|
||||
sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));')
|
||||
for embed_hierarchy in g.table.get_embeds(got_sym) {
|
||||
// last embed in the hierarchy
|
||||
mut embed_cname := ''
|
||||
mut embed_name := ''
|
||||
mut accessor := '&x->'
|
||||
for j, embed in embed_hierarchy {
|
||||
embed_sym := g.table.get_type_symbol(embed)
|
||||
embed_cname = embed_sym.cname
|
||||
embed_name = embed_sym.embed_name()
|
||||
if j > 0 {
|
||||
accessor += '.'
|
||||
}
|
||||
accessor += embed_name
|
||||
}
|
||||
// if the variable is not used, the C compiler will optimize it away
|
||||
sb.writeln('\t$embed_cname* ${embed_name}_ptr = memdup($accessor, sizeof($embed_cname));')
|
||||
}
|
||||
sb.write_string('\treturn ($exp_cname){ ._$got_cname = ptr, ._typ = ${g.type_sidx(got)}')
|
||||
for field in (exp_sym.info as ast.SumType).fields {
|
||||
mut ptr := 'ptr'
|
||||
mut type_cname := got_cname
|
||||
_, embed_types := g.table.find_field_from_embeds_recursive(got_sym, field.name) or {
|
||||
ast.StructField{}, []ast.Type{}
|
||||
}
|
||||
if embed_types.len > 0 {
|
||||
embed_sym := g.table.get_type_symbol(embed_types.last())
|
||||
ptr = '${embed_sym.embed_name()}_ptr'
|
||||
type_cname = embed_sym.cname
|
||||
}
|
||||
field_styp := g.typ(field.typ)
|
||||
if got_sym.kind in [.sum_type, .interface_] {
|
||||
// the field is already a wrapped pointer; we shouldn't wrap it once again
|
||||
sb.write_string(', .$field.name = ptr->$field.name')
|
||||
} else {
|
||||
sb.write_string(', .$field.name = ($field_styp*)((char*)ptr + __offsetof_ptr(ptr, $got_cname, $field.name))')
|
||||
sb.write_string(', .$field.name = ($field_styp*)((char*)$ptr + __offsetof_ptr($ptr, $type_cname, $field.name))')
|
||||
}
|
||||
}
|
||||
sb.writeln('};\n}')
|
||||
|
|
|
@ -107,7 +107,7 @@ fn test_converting_down() {
|
|||
}
|
||||
|
||||
fn test_assignment_and_push() {
|
||||
mut expr1 := Expr{}
|
||||
mut expr1 := Expr(IfExpr{})
|
||||
mut arr1 := []Expr{}
|
||||
expr := IntegerLiteral{
|
||||
val: '111'
|
||||
|
@ -712,3 +712,48 @@ fn test_binary_search_tree() {
|
|||
deleted.sort()
|
||||
assert deleted == [0.0, 0.3, 0.6, 1.0]
|
||||
}
|
||||
|
||||
struct Common {
|
||||
a int
|
||||
b int
|
||||
}
|
||||
|
||||
struct Common2 {
|
||||
Common
|
||||
}
|
||||
|
||||
struct Aa {
|
||||
Common
|
||||
x int
|
||||
}
|
||||
|
||||
struct Bb {
|
||||
Common
|
||||
x int
|
||||
}
|
||||
|
||||
struct Cc {
|
||||
a int
|
||||
}
|
||||
|
||||
struct Dd {
|
||||
Common2
|
||||
}
|
||||
|
||||
type MySum = Aa | Bb | Cc | Dd
|
||||
|
||||
fn test_sumtype_access_embed_fields() {
|
||||
a := MySum(Aa{
|
||||
a: 1
|
||||
})
|
||||
assert a.a == 1
|
||||
}
|
||||
|
||||
fn test_sumtype_access_nested_embed_fields() {
|
||||
a := MySum(Dd{
|
||||
Common2: Common2{
|
||||
a: 2
|
||||
}
|
||||
})
|
||||
assert a.a == 2
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue