cgen: support embedded struct fields on sumtype common fields (#11084)
							parent
							
								
									25d49bc615
								
							
						
					
					
						commit
						3b116d2455
					
				|  | @ -344,6 +344,28 @@ pub fn (t &Table) type_find_method(s &TypeSymbol, name string) ?Fn { | |||
| 	return none | ||||
| } | ||||
| 
 | ||||
| pub struct GetEmbedsOptions { | ||||
| 	preceding []Type | ||||
| } | ||||
| 
 | ||||
| // get_embeds returns all nested embedded structs
 | ||||
| // the hierarchy of embeds is returned as a list
 | ||||
| pub fn (t &Table) get_embeds(sym &TypeSymbol, options GetEmbedsOptions) [][]Type { | ||||
| 	mut embeds := [][]Type{} | ||||
| 	if sym.info is Struct { | ||||
| 		for embed in sym.info.embeds { | ||||
| 			embed_sym := t.get_type_symbol(embed) | ||||
| 			mut preceding := options.preceding | ||||
| 			preceding << embed | ||||
| 			embeds << t.get_embeds(embed_sym, preceding: preceding) | ||||
| 		} | ||||
| 		if sym.info.embeds.len == 0 && options.preceding.len > 0 { | ||||
| 			embeds << options.preceding | ||||
| 		} | ||||
| 	} | ||||
| 	return embeds | ||||
| } | ||||
| 
 | ||||
| pub fn (t &Table) type_find_method_from_embeds(sym &TypeSymbol, method_name string) ?(Fn, Type) { | ||||
| 	if sym.info is Struct { | ||||
| 		mut found_methods := []Fn{} | ||||
|  | @ -407,6 +429,20 @@ pub fn (t &Table) struct_has_field(struct_ &TypeSymbol, name string) bool { | |||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // struct_fields returns all fields including fields from embeds
 | ||||
| // use this instead symbol.info.fields to get all fields
 | ||||
| pub fn (t &Table) struct_fields(sym &TypeSymbol) []StructField { | ||||
| 	mut fields := []StructField{} | ||||
| 	if sym.info is Struct { | ||||
| 		fields << sym.info.fields | ||||
| 		for embed in sym.info.embeds { | ||||
| 			embed_sym := t.get_type_symbol(embed) | ||||
| 			fields << t.struct_fields(embed_sym) | ||||
| 		} | ||||
| 	} | ||||
| 	return fields | ||||
| } | ||||
| 
 | ||||
| // search from current type up through each parent looking for field
 | ||||
| pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField { | ||||
| 	// println('find_field($s.name, $name) types.len=$t.types.len s.parent_idx=$s.parent_idx')
 | ||||
|  | @ -447,6 +483,43 @@ pub fn (t &Table) find_field(s &TypeSymbol, name string) ?StructField { | |||
| 	return none | ||||
| } | ||||
| 
 | ||||
| // find_field_from_embeds is the same as find_field_from_embeds but also looks into nested embeds
 | ||||
| pub fn (t &Table) find_field_from_embeds_recursive(sym &TypeSymbol, field_name string) ?(StructField, []Type) { | ||||
| 	if sym.info is Struct { | ||||
| 		mut found_fields := []StructField{} | ||||
| 		mut embeds_of_found_fields := [][]Type{} | ||||
| 		for embed in sym.info.embeds { | ||||
| 			embed_sym := t.get_type_symbol(embed) | ||||
| 			if field := t.find_field(embed_sym, field_name) { | ||||
| 				found_fields << field | ||||
| 				embeds_of_found_fields << [embed] | ||||
| 			} else { | ||||
| 				field, types := t.find_field_from_embeds_recursive(embed_sym, field_name) or { | ||||
| 					StructField{}, []Type{} | ||||
| 				} | ||||
| 				found_fields << field | ||||
| 				embeds_of_found_fields << types | ||||
| 			} | ||||
| 		} | ||||
| 		if found_fields.len == 1 { | ||||
| 			return found_fields[0], embeds_of_found_fields[0] | ||||
| 		} else if found_fields.len > 1 { | ||||
| 			return error('ambiguous field `$field_name`') | ||||
| 		} | ||||
| 	} else if sym.info is Aggregate { | ||||
| 		for typ in sym.info.types { | ||||
| 			agg_sym := t.get_type_symbol(typ) | ||||
| 			field, embed_types := t.find_field_from_embeds_recursive(agg_sym, field_name) or { | ||||
| 				return err | ||||
| 			} | ||||
| 			if embed_types.len > 0 { | ||||
| 				return field, embed_types | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return none | ||||
| } | ||||
| 
 | ||||
| // find_field_from_embeds finds and returns a field in the embeddings of a struct and the embedding type
 | ||||
| pub fn (t &Table) find_field_from_embeds(sym &TypeSymbol, field_name string) ?(StructField, Type) { | ||||
| 	if sym.info is Struct { | ||||
|  | @ -500,7 +573,7 @@ pub fn (t &Table) resolve_common_sumtype_fields(sym_ &TypeSymbol) { | |||
| 		mut v_sym := t.get_type_symbol(variant) | ||||
| 		fields := match mut v_sym.info { | ||||
| 			Struct { | ||||
| 				v_sym.info.fields | ||||
| 				t.struct_fields(v_sym) | ||||
| 			} | ||||
| 			SumType { | ||||
| 				t.resolve_common_sumtype_fields(v_sym) | ||||
|  |  | |||
|  | @ -1875,14 +1875,41 @@ fn (mut g Gen) write_sumtype_casting_fn(got_ ast.Type, exp_ ast.Type) { | |||
| 	got_cname, exp_cname := got_sym.cname, exp_sym.cname | ||||
| 	sb.writeln('static inline $exp_cname ${got_cname}_to_sumtype_${exp_cname}($got_cname* x) {') | ||||
| 	sb.writeln('\t$got_cname* ptr = memdup(x, sizeof($got_cname));') | ||||
| 	for embed_hierarchy in g.table.get_embeds(got_sym) { | ||||
| 		// last embed in the hierarchy
 | ||||
| 		mut embed_cname := '' | ||||
| 		mut embed_name := '' | ||||
| 		mut accessor := '&x->' | ||||
| 		for j, embed in embed_hierarchy { | ||||
| 			embed_sym := g.table.get_type_symbol(embed) | ||||
| 			embed_cname = embed_sym.cname | ||||
| 			embed_name = embed_sym.embed_name() | ||||
| 			if j > 0 { | ||||
| 				accessor += '.' | ||||
| 			} | ||||
| 			accessor += embed_name | ||||
| 		} | ||||
| 		// if the variable is not used, the C compiler will optimize it away
 | ||||
| 		sb.writeln('\t$embed_cname* ${embed_name}_ptr = memdup($accessor, sizeof($embed_cname));') | ||||
| 	} | ||||
| 	sb.write_string('\treturn ($exp_cname){ ._$got_cname = ptr, ._typ = ${g.type_sidx(got)}') | ||||
| 	for field in (exp_sym.info as ast.SumType).fields { | ||||
| 		mut ptr := 'ptr' | ||||
| 		mut type_cname := got_cname | ||||
| 		_, embed_types := g.table.find_field_from_embeds_recursive(got_sym, field.name) or { | ||||
| 			ast.StructField{}, []ast.Type{} | ||||
| 		} | ||||
| 		if embed_types.len > 0 { | ||||
| 			embed_sym := g.table.get_type_symbol(embed_types.last()) | ||||
| 			ptr = '${embed_sym.embed_name()}_ptr' | ||||
| 			type_cname = embed_sym.cname | ||||
| 		} | ||||
| 		field_styp := g.typ(field.typ) | ||||
| 		if got_sym.kind in [.sum_type, .interface_] { | ||||
| 			// the field is already a wrapped pointer; we shouldn't wrap it once again
 | ||||
| 			sb.write_string(', .$field.name = ptr->$field.name') | ||||
| 		} else { | ||||
| 			sb.write_string(', .$field.name = ($field_styp*)((char*)ptr + __offsetof_ptr(ptr, $got_cname, $field.name))') | ||||
| 			sb.write_string(', .$field.name = ($field_styp*)((char*)$ptr + __offsetof_ptr($ptr, $type_cname, $field.name))') | ||||
| 		} | ||||
| 	} | ||||
| 	sb.writeln('};\n}') | ||||
|  |  | |||
|  | @ -107,7 +107,7 @@ fn test_converting_down() { | |||
| } | ||||
| 
 | ||||
| fn test_assignment_and_push() { | ||||
| 	mut expr1 := Expr{} | ||||
| 	mut expr1 := Expr(IfExpr{}) | ||||
| 	mut arr1 := []Expr{} | ||||
| 	expr := IntegerLiteral{ | ||||
| 		val: '111' | ||||
|  | @ -712,3 +712,48 @@ fn test_binary_search_tree() { | |||
| 	deleted.sort() | ||||
| 	assert deleted == [0.0, 0.3, 0.6, 1.0] | ||||
| } | ||||
| 
 | ||||
| struct Common { | ||||
| 	a int | ||||
| 	b int | ||||
| } | ||||
| 
 | ||||
| struct Common2 { | ||||
| 	Common | ||||
| } | ||||
| 
 | ||||
| struct Aa { | ||||
| 	Common | ||||
| 	x int | ||||
| } | ||||
| 
 | ||||
| struct Bb { | ||||
| 	Common | ||||
| 	x int | ||||
| } | ||||
| 
 | ||||
| struct Cc { | ||||
| 	a int | ||||
| } | ||||
| 
 | ||||
| struct Dd { | ||||
| 	Common2 | ||||
| } | ||||
| 
 | ||||
| type MySum = Aa | Bb | Cc | Dd | ||||
| 
 | ||||
| fn test_sumtype_access_embed_fields() { | ||||
| 	a := MySum(Aa{ | ||||
| 		a: 1 | ||||
| 	}) | ||||
| 	assert a.a == 1 | ||||
| } | ||||
| 
 | ||||
| fn test_sumtype_access_nested_embed_fields() { | ||||
| 	a := MySum(Dd{ | ||||
| 		Common2: Common2{ | ||||
| 			a: 2 | ||||
| 		} | ||||
| 	}) | ||||
| 	assert a.a == 2 | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue