cgen: support embedded struct fields on sumtype common fields (#11084)
							parent
							
								
									25d49bc615
								
							
						
					
					
						commit
						3b116d2455
					
				| 
						 | 
					@ -344,6 +344,28 @@ pub fn (t &Table) type_find_method(s &TypeSymbol, name string) ?Fn {
 | 
				
			||||||
	return none
 | 
						return none
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					pub struct GetEmbedsOptions {
 | 
				
			||||||
 | 
						preceding []Type
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// get_embeds returns all nested embedded structs
 | 
				
			||||||
 | 
					// the hierarchy of embeds is returned as a list
 | 
				
			||||||
 | 
					pub fn (t &Table) get_embeds(sym &TypeSymbol, options GetEmbedsOptions) [][]Type {
 | 
				
			||||||
 | 
						mut embeds := [][]Type{}
 | 
				
			||||||
 | 
						if sym.info is Struct {
 | 
				
			||||||
 | 
							for embed in sym.info.embeds {
 | 
				
			||||||
 | 
								embed_sym := t.get_type_symbol(embed)
 | 
				
			||||||
 | 
								mut preceding := options.preceding
 | 
				
			||||||
 | 
								preceding << embed
 | 
				
			||||||
 | 
								embeds << t.get_embeds(embed_sym, preceding: preceding)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if sym.info.embeds.len == 0 && options.preceding.len > 0 {
 | 
				
			||||||
 | 
								embeds << options.preceding
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return embeds
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pub fn (t &Table) type_find_method_from_embeds(sym &TypeSymbol, method_name string) ?(Fn, Type) {
 | 
					pub fn (t &Table) type_find_method_from_embeds(sym &TypeSymbol, method_name string) ?(Fn, Type) {
 | 
				
			||||||
	if sym.info is Struct {
 | 
						if sym.info is Struct {
 | 
				
			||||||
		mut found_methods := []Fn{}
 | 
							mut found_methods := []Fn{}
 | 
				
			||||||
| 
						 | 
					@ -407,6 +429,20 @@ pub fn (t &Table) struct_has_field(struct_ &TypeSymbol, name string) bool {
 | 
				
			||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// struct_fields returns all fields including fields from embeds
 | 
				
			||||||
 | 
					// use this instead symbol.info.fields to get all fields
 | 
				
			||||||
 | 
					pub fn (t &Table) struct_fields(sym &TypeSymbol) []StructField {
 | 
				
			||||||
 | 
						mut fields := []StructField{}
 | 
				
			||||||
 | 
						if sym.info is Struct {
 | 
				
			||||||
 | 
							fields << sym.info.fields
 | 
				
			||||||
 | 
							for embed in sym.info.embeds {
 | 
				
			||||||
 | 
								embed_sym := t.get_type_symbol(embed)
 | 
				
			||||||
 | 
								fields << t.struct_fields(embed_sym)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return fields
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// search from current type up through each parent looking for field
 | 
					// search from current type up through each parent looking for field
 | 
				
			||||||
pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField {
 | 
					pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField {
 | 
				
			||||||
	// println('find_field($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx')
 | 
						// println('find_field($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx')
 | 
				
			||||||
| 
						 | 
					@ -447,6 +483,43 @@ pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField {
 | 
				
			||||||
	return none
 | 
						return none
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// find_field_from_embeds is the same as find_field_from_embeds but also looks into nested embeds
 | 
				
			||||||
 | 
					pub fn (t &Table) find_field_from_embeds_recursive(sym &TypeSymbol, field_name string) ?(StructField, []Type) {
 | 
				
			||||||
 | 
						if sym.info is Struct {
 | 
				
			||||||
 | 
							mut found_fields := []StructField{}
 | 
				
			||||||
 | 
							mut embeds_of_found_fields := [][]Type{}
 | 
				
			||||||
 | 
							for embed in sym.info.embeds {
 | 
				
			||||||
 | 
								embed_sym := t.get_type_symbol(embed)
 | 
				
			||||||
 | 
								if field := t.find_field(embed_sym, field_name) {
 | 
				
			||||||
 | 
									found_fields << field
 | 
				
			||||||
 | 
									embeds_of_found_fields << [embed]
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									field, types := t.find_field_from_embeds_recursive(embed_sym, field_name) or {
 | 
				
			||||||
 | 
										StructField{}, []Type{}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									found_fields << field
 | 
				
			||||||
 | 
									embeds_of_found_fields << types
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if found_fields.len == 1 {
 | 
				
			||||||
 | 
								return found_fields[0], embeds_of_found_fields[0]
 | 
				
			||||||
 | 
							} else if found_fields.len > 1 {
 | 
				
			||||||
 | 
								return error('ambiguous field `$field_name`')
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else if sym.info is Aggregate {
 | 
				
			||||||
 | 
							for typ in sym.info.types {
 | 
				
			||||||
 | 
								agg_sym := t.get_type_symbol(typ)
 | 
				
			||||||
 | 
								field, embed_types := t.find_field_from_embeds_recursive(agg_sym, field_name) or {
 | 
				
			||||||
 | 
									return err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if embed_types.len > 0 {
 | 
				
			||||||
 | 
									return field, embed_types
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return none
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// find_field_from_embeds finds and returns a field in the embeddings of a struct and the embedding type
 | 
					// find_field_from_embeds finds and returns a field in the embeddings of a struct and the embedding type
 | 
				
			||||||
pub fn (t &Table) find_field_from_embeds(sym &TypeSymbol, field_name string) ?(StructField, Type) {
 | 
					pub fn (t &Table) find_field_from_embeds(sym &TypeSymbol, field_name string) ?(StructField, Type) {
 | 
				
			||||||
	if sym.info is Struct {
 | 
						if sym.info is Struct {
 | 
				
			||||||
| 
						 | 
					@ -500,7 +573,7 @@ pub fn (t &Table) resolve_common_sumtype_fields(sym_ &TypeSymbol) {
 | 
				
			||||||
		mut v_sym := t.get_type_symbol(variant)
 | 
							mut v_sym := t.get_type_symbol(variant)
 | 
				
			||||||
		fields := match mut v_sym.info {
 | 
							fields := match mut v_sym.info {
 | 
				
			||||||
			Struct {
 | 
								Struct {
 | 
				
			||||||
				v_sym.info.fields
 | 
									t.struct_fields(v_sym)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			SumType {
 | 
								SumType {
 | 
				
			||||||
				t.resolve_common_sumtype_fields(v_sym)
 | 
									t.resolve_common_sumtype_fields(v_sym)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1875,14 +1875,41 @@ fn (mut g Gen) write_sumtype_casting_fn(got_ ast.Type, exp_ ast.Type) {
 | 
				
			||||||
	got_cname, exp_cname := got_sym.cname, exp_sym.cname
 | 
						got_cname, exp_cname := got_sym.cname, exp_sym.cname
 | 
				
			||||||
	sb.writeln('static inline $exp_cname ${got_cname}_to_sumtype_${exp_cname}($got_cname* x) {')
 | 
						sb.writeln('static inline $exp_cname ${got_cname}_to_sumtype_${exp_cname}($got_cname* x) {')
 | 
				
			||||||
	sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));')
 | 
						sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));')
 | 
				
			||||||
 | 
						for embed_hierarchy in g.table.get_embeds(got_sym) {
 | 
				
			||||||
 | 
							// last embed in the hierarchy
 | 
				
			||||||
 | 
							mut embed_cname := ''
 | 
				
			||||||
 | 
							mut embed_name := ''
 | 
				
			||||||
 | 
							mut accessor := '&x->'
 | 
				
			||||||
 | 
							for j, embed in embed_hierarchy {
 | 
				
			||||||
 | 
								embed_sym := g.table.get_type_symbol(embed)
 | 
				
			||||||
 | 
								embed_cname = embed_sym.cname
 | 
				
			||||||
 | 
								embed_name = embed_sym.embed_name()
 | 
				
			||||||
 | 
								if j > 0 {
 | 
				
			||||||
 | 
									accessor += '.'
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								accessor += embed_name
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							// if the variable is not used, the C compiler will optimize it away
 | 
				
			||||||
 | 
							sb.writeln('\t$embed_cname* ${embed_name}_ptr = memdup($accessor, sizeof($embed_cname));')
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	sb.write_string('\treturn ($exp_cname){ ._$got_cname = ptr, ._typ = ${g.type_sidx(got)}')
 | 
						sb.write_string('\treturn ($exp_cname){ ._$got_cname = ptr, ._typ = ${g.type_sidx(got)}')
 | 
				
			||||||
	for field in (exp_sym.info as ast.SumType).fields {
 | 
						for field in (exp_sym.info as ast.SumType).fields {
 | 
				
			||||||
 | 
							mut ptr := 'ptr'
 | 
				
			||||||
 | 
							mut type_cname := got_cname
 | 
				
			||||||
 | 
							_, embed_types := g.table.find_field_from_embeds_recursive(got_sym, field.name) or {
 | 
				
			||||||
 | 
								ast.StructField{}, []ast.Type{}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if embed_types.len > 0 {
 | 
				
			||||||
 | 
								embed_sym := g.table.get_type_symbol(embed_types.last())
 | 
				
			||||||
 | 
								ptr = '${embed_sym.embed_name()}_ptr'
 | 
				
			||||||
 | 
								type_cname = embed_sym.cname
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		field_styp := g.typ(field.typ)
 | 
							field_styp := g.typ(field.typ)
 | 
				
			||||||
		if got_sym.kind in [.sum_type, .interface_] {
 | 
							if got_sym.kind in [.sum_type, .interface_] {
 | 
				
			||||||
			// the field is already a wrapped pointer; we shouldn't wrap it once again
 | 
								// the field is already a wrapped pointer; we shouldn't wrap it once again
 | 
				
			||||||
			sb.write_string(', .$field.name = ptr->$field.name')
 | 
								sb.write_string(', .$field.name = ptr->$field.name')
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			sb.write_string(', .$field.name = ($field_styp*)((char*)ptr + __offsetof_ptr(ptr, $got_cname, $field.name))')
 | 
								sb.write_string(', .$field.name = ($field_styp*)((char*)$ptr + __offsetof_ptr($ptr, $type_cname, $field.name))')
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	sb.writeln('};\n}')
 | 
						sb.writeln('};\n}')
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -107,7 +107,7 @@ fn test_converting_down() {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
fn test_assignment_and_push() {
 | 
					fn test_assignment_and_push() {
 | 
				
			||||||
	mut expr1 := Expr{}
 | 
						mut expr1 := Expr(IfExpr{})
 | 
				
			||||||
	mut arr1 := []Expr{}
 | 
						mut arr1 := []Expr{}
 | 
				
			||||||
	expr := IntegerLiteral{
 | 
						expr := IntegerLiteral{
 | 
				
			||||||
		val: '111'
 | 
							val: '111'
 | 
				
			||||||
| 
						 | 
					@ -712,3 +712,48 @@ fn test_binary_search_tree() {
 | 
				
			||||||
	deleted.sort()
 | 
						deleted.sort()
 | 
				
			||||||
	assert deleted == [0.0, 0.3, 0.6, 1.0]
 | 
						assert deleted == [0.0, 0.3, 0.6, 1.0]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct Common {
 | 
				
			||||||
 | 
						a int
 | 
				
			||||||
 | 
						b int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct Common2 {
 | 
				
			||||||
 | 
						Common
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct Aa {
 | 
				
			||||||
 | 
						Common
 | 
				
			||||||
 | 
						x int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct Bb {
 | 
				
			||||||
 | 
						Common
 | 
				
			||||||
 | 
						x int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct Cc {
 | 
				
			||||||
 | 
						a int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct Dd {
 | 
				
			||||||
 | 
						Common2
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type MySum = Aa | Bb | Cc | Dd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					fn test_sumtype_access_embed_fields() {
 | 
				
			||||||
 | 
						a := MySum(Aa{
 | 
				
			||||||
 | 
							a: 1
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						assert a.a == 1
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					fn test_sumtype_access_nested_embed_fields() {
 | 
				
			||||||
 | 
						a := MySum(Dd{
 | 
				
			||||||
 | 
							Common2: Common2{
 | 
				
			||||||
 | 
								a: 2
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						assert a.a == 2
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue