cgen: refactor array sort generation (#11067)
							parent
							
								
									8d2567740b
								
							
						
					
					
						commit
						7346aeca5f
					
				|  | @ -808,25 +808,17 @@ fn test_fixed_array_literal_eq() { | |||
| fn test_sort() { | ||||
| 	mut a := ['hi', '1', '5', '3'] | ||||
| 	a.sort() | ||||
| 	assert a[0] == '1' | ||||
| 	assert a[1] == '3' | ||||
| 	assert a[2] == '5' | ||||
| 	assert a[3] == 'hi' | ||||
| 	assert a == ['1', '3', '5', 'hi'] | ||||
| 
 | ||||
| 	mut nums := [67, -3, 108, 42, 7] | ||||
| 	nums.sort() | ||||
| 	assert nums[0] == -3 | ||||
| 	assert nums[1] == 7 | ||||
| 	assert nums[2] == 42 | ||||
| 	assert nums[3] == 67 | ||||
| 	assert nums[4] == 108 | ||||
| 	assert nums == [-3, 7, 42, 67, 108] | ||||
| 
 | ||||
| 	nums.sort(a < b) | ||||
| 	assert nums[0] == -3 | ||||
| 	assert nums[1] == 7 | ||||
| 	assert nums[2] == 42 | ||||
| 	assert nums[3] == 67 | ||||
| 	assert nums[4] == 108 | ||||
| 	assert nums == [-3, 7, 42, 67, 108] | ||||
| 
 | ||||
| 	nums.sort(b < a) | ||||
| 	assert nums == [108, 67, 42, 7, -3] | ||||
| 
 | ||||
| 	mut users := [User{22, 'Peter'}, User{20, 'Bob'}, User{25, 'Alice'}] | ||||
| 	users.sort(a.age < b.age) | ||||
|  | @ -842,23 +834,30 @@ fn test_sort() { | |||
| 	assert users[1].age == 22 | ||||
| 	assert users[2].age == 20 | ||||
| 
 | ||||
| 	users.sort(a.name < b.name) // Test sorting by string fields
 | ||||
| 	// assert users.map(it.name).join(' ') == 'Alice Bob Peter'
 | ||||
| 	users.sort(b.age > a.age) | ||||
| 	assert users[0].age == 20 | ||||
| 	assert users[1].age == 22 | ||||
| 	assert users[2].age == 25 | ||||
| 
 | ||||
| 	users.sort(a.name < b.name) | ||||
| 	assert users[0].name == 'Alice' | ||||
| 	assert users[1].name == 'Bob' | ||||
| 	assert users[2].name == 'Peter' | ||||
| } | ||||
| 
 | ||||
| fn test_rune_sort() { | ||||
| 	mut bs := [`f`, `e`, `d`, `b`, `c`, `a`] | ||||
| 	bs.sort() | ||||
| 	println(bs) | ||||
| 	assert '$bs' == '[`a`, `b`, `c`, `d`, `e`, `f`]' | ||||
| 	assert bs == [`a`, `b`, `c`, `d`, `e`, `f`] | ||||
| 
 | ||||
| 	bs.sort(a > b) | ||||
| 	println(bs) | ||||
| 	assert '$bs' == '[`f`, `e`, `d`, `c`, `b`, `a`]' | ||||
| 	assert bs == [`f`, `e`, `d`, `c`, `b`, `a`] | ||||
| 
 | ||||
| 	bs.sort(a < b) | ||||
| 	println(bs) | ||||
| 	assert '$bs' == '[`a`, `b`, `c`, `d`, `e`, `f`]' | ||||
| 	assert bs == [`a`, `b`, `c`, `d`, `e`, `f`] | ||||
| } | ||||
| 
 | ||||
| fn test_sort_by_different_order_of_a_b() { | ||||
|  | @ -876,9 +875,19 @@ fn test_sort_by_different_order_of_a_b() { | |||
| fn test_f32_sort() { | ||||
| 	mut f := [f32(50.0), 15, 1, 79, 38, 0, 27] | ||||
| 	f.sort() | ||||
| 	assert f[0] == 0.0 | ||||
| 	assert f[1] == 1.0 | ||||
| 	assert f[6] == 79.0 | ||||
| 	assert f == [f32(0.0), 1, 15, 27, 38, 50, 79] | ||||
| 
 | ||||
| 	f.sort(a < b) | ||||
| 	assert f == [f32(0.0), 1, 15, 27, 38, 50, 79] | ||||
| 
 | ||||
| 	f.sort(b > a) | ||||
| 	assert f == [f32(0.0), 1, 15, 27, 38, 50, 79] | ||||
| 
 | ||||
| 	f.sort(b < a) | ||||
| 	assert f == [f32(79.0), 50, 38, 27, 15, 1, 0] | ||||
| 
 | ||||
| 	f.sort(a > b) | ||||
| 	assert f == [f32(79.0), 50, 38, 27, 15, 1, 0] | ||||
| } | ||||
| 
 | ||||
| fn test_f64_sort() { | ||||
|  | @ -897,6 +906,17 @@ fn test_i64_sort() { | |||
| 	assert f[6] == 79 | ||||
| } | ||||
| 
 | ||||
| fn test_sort_index_expr() { | ||||
| 	mut f := [[i64(50), 48], [i64(15)], [i64(1)], [i64(79)], [i64(38)], | ||||
| 		[i64(0)], [i64(27)]] | ||||
| 	// TODO This currently gives "indexing pointer" error without unsafe
 | ||||
| 	unsafe { | ||||
| 		f.sort(a[0] < b[0]) | ||||
| 	} | ||||
| 	assert f == [[i64(0)], [i64(1)], [i64(15)], [i64(27)], [i64(38)], | ||||
| 		[i64(50), 48], [i64(79)]] | ||||
| } | ||||
| 
 | ||||
| fn test_a_b_paras_sort() { | ||||
| 	mut arr_i := [1, 3, 2] | ||||
| 	arr_i.sort(a < b) | ||||
|  |  | |||
|  | @ -222,118 +222,92 @@ fn (mut g Gen) gen_array_sort(node ast.CallExpr) { | |||
| 		verror('.sort() is an array method') | ||||
| 	} | ||||
| 	info := rec_sym.info as ast.Array | ||||
| 	// No arguments means we are sorting an array of builtins (e.g. `numbers.sort()`)
 | ||||
| 	// The type for the comparison fns is the type of the element itself.
 | ||||
| 	mut typ := info.elem_type | ||||
| 	mut is_default := false | ||||
| 	mut is_reverse := false | ||||
| 	mut compare_fn := '' | ||||
| 	// `users.sort(a.age > b.age)`
 | ||||
| 	// Generate a comparison function for a custom type
 | ||||
| 	elem_stype := g.typ(info.elem_type) | ||||
| 	mut compare_fn := 'compare_${elem_stype.replace('*', '_ptr')}' | ||||
| 	mut comparison_type := g.unwrap(ast.void_type) | ||||
| 	mut left_expr, mut right_expr := '', '' | ||||
| 	// the only argument can only be an infix expression like `a < b` or `b.field > a.field`
 | ||||
| 	if node.args.len == 0 { | ||||
| 		is_default = true | ||||
| 		comparison_type = g.unwrap(info.elem_type.set_nr_muls(0)) | ||||
| 		if compare_fn in g.array_sort_fn { | ||||
| 			g.gen_array_sort_call(node, compare_fn) | ||||
| 			return | ||||
| 		} | ||||
| 		left_expr = '*a' | ||||
| 		right_expr = '*b' | ||||
| 	} else { | ||||
| 		infix_expr := node.args[0].expr as ast.InfixExpr | ||||
| 		left_name := '$infix_expr.left' | ||||
| 		is_default = left_name in ['a', 'b'] && '$infix_expr.right' in ['a', 'b'] | ||||
| 		is_reverse = (left_name.starts_with('a') && infix_expr.op == .gt) | ||||
| 		comparison_type = g.unwrap(infix_expr.left_type.set_nr_muls(0)) | ||||
| 		left_name := infix_expr.left.str() | ||||
| 		if left_name.len > 1 { | ||||
| 			compare_fn += '_by' + left_name[1..].replace_each(['.', '_', '[', '_', ']', '_']) | ||||
| 		} | ||||
| 		// is_reverse is `true` for `.sort(a > b)` and `.sort(b < a)`
 | ||||
| 		is_reverse := (left_name.starts_with('a') && infix_expr.op == .gt) | ||||
| 			|| (left_name.starts_with('b') && infix_expr.op == .lt) | ||||
| 	} | ||||
| 	if is_default { | ||||
| 		// users.sort() or users.sort(a > b)
 | ||||
| 		compare_fn = match typ { | ||||
| 			ast.int_type, ast.int_type.to_ptr() { 'compare_ints' } | ||||
| 			ast.string_type, ast.string_type.to_ptr() { 'compare_strings' } | ||||
| 			else { '' } | ||||
| 		} | ||||
| 		if compare_fn != '' && is_reverse { | ||||
| 			compare_fn += '_reverse' | ||||
| 		} | ||||
| 	} | ||||
| 	if compare_fn == '' { | ||||
| 		// `users.sort(a.age > b.age)`
 | ||||
| 		// Generate a comparison function for a custom type
 | ||||
| 		tmp_name := g.new_global_tmp_var() | ||||
| 		styp := g.typ(typ).trim('*') | ||||
| 		compare_fn = 'compare_${tmp_name}_$styp' | ||||
| 		if is_reverse { | ||||
| 			compare_fn += '_reverse' | ||||
| 		} | ||||
| 		// Register a new custom `compare_xxx` function for qsort()
 | ||||
| 		// TODO: move to checker
 | ||||
| 		g.table.register_fn(name: compare_fn, return_type: ast.int_type) | ||||
| 
 | ||||
| 		if node.args.len == 0 { | ||||
| 			styp_arg := g.typ(typ) | ||||
| 			g.definitions.writeln('int $compare_fn ($styp_arg* a, $styp_arg* b) {') | ||||
| 			sym := g.table.get_type_symbol(typ) | ||||
| 			if !is_reverse && sym.has_method('<') { | ||||
| 				g.definitions.writeln('\tif (${styp}__lt(*a, *b)) { return -1; } else { return 1; }}') | ||||
| 			} else if is_reverse && sym.has_method('<') { | ||||
| 				g.definitions.writeln('\tif (${styp}__lt(*b, *a)) { return -1; } else { return 1; }}') | ||||
| 			} else { | ||||
| 				g.definitions.writeln('if (*a < *b) return -1;') | ||||
| 				g.definitions.writeln('if (*a > *b) return 1; else return 0; }\n') | ||||
| 		if compare_fn in g.array_sort_fn { | ||||
| 			g.gen_array_sort_call(node, compare_fn) | ||||
| 			return | ||||
| 		} | ||||
| 		if left_name.starts_with('a') != is_reverse { | ||||
| 			left_expr = g.expr_string(infix_expr.left) | ||||
| 			right_expr = g.expr_string(infix_expr.right) | ||||
| 			if infix_expr.left is ast.Ident { | ||||
| 				left_expr = '*' + left_expr | ||||
| 			} | ||||
| 			if infix_expr.right is ast.Ident { | ||||
| 				right_expr = '*' + right_expr | ||||
| 			} | ||||
| 		} else { | ||||
| 			infix_expr := node.args[0].expr as ast.InfixExpr | ||||
| 			// Variables `a` and `b` are used in the `.sort(a < b)` syntax, so we can reuse them
 | ||||
| 			// when generating the function as long as the args are named the same.
 | ||||
| 			styp_arg := g.typ(typ) | ||||
| 			g.definitions.writeln('int $compare_fn ($styp_arg* a, $styp_arg* b) {') | ||||
| 			sym := g.table.get_type_symbol(typ) | ||||
| 			if !is_reverse && sym.has_method('<') && infix_expr.left.str().len == 1 { | ||||
| 				g.definitions.writeln('\tif (${styp}__lt(*a, *b)) { return -1; } else { return 1; }}') | ||||
| 			} else if is_reverse && sym.has_method('<') && infix_expr.left.str().len == 1 { | ||||
| 				g.definitions.writeln('\tif (${styp}__lt(*b, *a)) { return -1; } else { return 1; }}') | ||||
| 			} else { | ||||
| 				field_type := g.typ(infix_expr.left_type) | ||||
| 				left_name := '$infix_expr.left' | ||||
| 				left_expr_str := g.expr_string(infix_expr.left) | ||||
| 				right_expr_str := g.expr_string(infix_expr.right) | ||||
| 				if left_name.starts_with('a') { | ||||
| 					g.definitions.writeln('\t$field_type a_ = $left_expr_str;') | ||||
| 					g.definitions.writeln('\t$field_type b_ = $right_expr_str;') | ||||
| 				} else { | ||||
| 					g.definitions.writeln('\t$field_type a_ = $right_expr_str;') | ||||
| 					g.definitions.writeln('\t$field_type b_ = $left_expr_str;') | ||||
| 				} | ||||
| 				mut op1, mut op2 := '', '' | ||||
| 				if infix_expr.left_type == ast.string_type { | ||||
| 					if is_reverse { | ||||
| 						op1 = 'string__lt(b_, a_)' | ||||
| 						op2 = 'string__lt(a_, b_)' | ||||
| 					} else { | ||||
| 						op1 = 'string__lt(a_, b_)' | ||||
| 						op2 = 'string__lt(b_, a_)' | ||||
| 					} | ||||
| 				} else { | ||||
| 					deref_str := if infix_expr.left_type.is_ptr() { '*' } else { '' } | ||||
| 					if is_reverse { | ||||
| 						op1 = '${deref_str}a_ > ${deref_str}b_' | ||||
| 						op2 = '${deref_str}a_ < ${deref_str}b_' | ||||
| 					} else { | ||||
| 						op1 = '${deref_str}a_ < ${deref_str}b_' | ||||
| 						op2 = '${deref_str}a_ > ${deref_str}b_' | ||||
| 					} | ||||
| 				} | ||||
| 				g.definitions.writeln('\tif ($op1) return -1;') | ||||
| 				g.definitions.writeln('\tif ($op2) return 1; \n\telse return 0; \n}\n') | ||||
| 			left_expr = g.expr_string(infix_expr.right) | ||||
| 			right_expr = g.expr_string(infix_expr.left) | ||||
| 			if infix_expr.left is ast.Ident { | ||||
| 				right_expr = '*' + right_expr | ||||
| 			} | ||||
| 			if infix_expr.right is ast.Ident { | ||||
| 				left_expr = '*' + left_expr | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	if is_reverse && !compare_fn.ends_with('_reverse') { | ||||
| 		compare_fn += '_reverse' | ||||
| 
 | ||||
| 	// Register a new custom `compare_xxx` function for qsort()
 | ||||
| 	// TODO: move to checker
 | ||||
| 	g.table.register_fn(name: compare_fn, return_type: ast.int_type) | ||||
| 	g.array_sort_fn[compare_fn] = true | ||||
| 
 | ||||
| 	stype_arg := g.typ(info.elem_type) | ||||
| 	g.definitions.writeln('int ${compare_fn}($stype_arg* a, $stype_arg* b) {') | ||||
| 	c_condition := if comparison_type.sym.has_method('<') { | ||||
| 		'${g.typ(comparison_type.typ)}__lt($left_expr, $right_expr)' | ||||
| 	} else if comparison_type.unaliased_sym.has_method('<') { | ||||
| 		'${g.typ(comparison_type.unaliased)}__lt($left_expr, $right_expr)' | ||||
| 	} else { | ||||
| 		'$left_expr < $right_expr' | ||||
| 	} | ||||
| 	//
 | ||||
| 	deref := if node.left_type.is_ptr() || node.left_type.is_pointer() { '->' } else { '.' } | ||||
| 	g.definitions.writeln('\tif ($c_condition) return -1;') | ||||
| 	g.definitions.writeln('\telse return 1;') | ||||
| 	g.definitions.writeln('}\n') | ||||
| 
 | ||||
| 	// write call to the generated function
 | ||||
| 	g.gen_array_sort_call(node, compare_fn) | ||||
| } | ||||
| 
 | ||||
| fn (mut g Gen) gen_array_sort_call(node ast.CallExpr, compare_fn string) { | ||||
| 	deref_field := if node.left_type.is_ptr() || node.left_type.is_pointer() { '->' } else { '.' } | ||||
| 	// eprintln('> qsort: pointer $node.left_type | deref: `$deref`')
 | ||||
| 	g.empty_line = true | ||||
| 	g.write('qsort(') | ||||
| 	g.expr(node.left) | ||||
| 	g.write('${deref}data, ') | ||||
| 	g.write('${deref_field}data, ') | ||||
| 	g.expr(node.left) | ||||
| 	g.write('${deref}len, ') | ||||
| 	g.write('${deref_field}len, ') | ||||
| 	g.expr(node.left) | ||||
| 	g.write('${deref}element_size, (int (*)(const void *, const void *))&$compare_fn)') | ||||
| 	g.write('${deref_field}element_size, (int (*)(const void *, const void *))&$compare_fn)') | ||||
| } | ||||
| 
 | ||||
| // `nums.filter(it % 2 == 0)`
 | ||||
|  |  | |||
|  | @ -179,6 +179,7 @@ mut: | |||
| 	expected_cast_type ast.Type // for match expr of sumtypes
 | ||||
| 	defer_vars         []string | ||||
| 	anon_fn            bool | ||||
| 	array_sort_fn      map[string]bool | ||||
| } | ||||
| 
 | ||||
| pub fn gen(files []&ast.File, table &ast.Table, pref &pref.Preferences) string { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue