cgen: refactor array sort generation (#11067)
							parent
							
								
									8d2567740b
								
							
						
					
					
						commit
						7346aeca5f
					
				| 
						 | 
					@ -808,25 +808,17 @@ fn test_fixed_array_literal_eq() {
 | 
				
			||||||
fn test_sort() {
 | 
					fn test_sort() {
 | 
				
			||||||
	mut a := ['hi', '1', '5', '3']
 | 
						mut a := ['hi', '1', '5', '3']
 | 
				
			||||||
	a.sort()
 | 
						a.sort()
 | 
				
			||||||
	assert a[0] == '1'
 | 
						assert a == ['1', '3', '5', 'hi']
 | 
				
			||||||
	assert a[1] == '3'
 | 
					 | 
				
			||||||
	assert a[2] == '5'
 | 
					 | 
				
			||||||
	assert a[3] == 'hi'
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	mut nums := [67, -3, 108, 42, 7]
 | 
						mut nums := [67, -3, 108, 42, 7]
 | 
				
			||||||
	nums.sort()
 | 
						nums.sort()
 | 
				
			||||||
	assert nums[0] == -3
 | 
						assert nums == [-3, 7, 42, 67, 108]
 | 
				
			||||||
	assert nums[1] == 7
 | 
					 | 
				
			||||||
	assert nums[2] == 42
 | 
					 | 
				
			||||||
	assert nums[3] == 67
 | 
					 | 
				
			||||||
	assert nums[4] == 108
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	nums.sort(a < b)
 | 
						nums.sort(a < b)
 | 
				
			||||||
	assert nums[0] == -3
 | 
						assert nums == [-3, 7, 42, 67, 108]
 | 
				
			||||||
	assert nums[1] == 7
 | 
					
 | 
				
			||||||
	assert nums[2] == 42
 | 
						nums.sort(b < a)
 | 
				
			||||||
	assert nums[3] == 67
 | 
						assert nums == [108, 67, 42, 7, -3]
 | 
				
			||||||
	assert nums[4] == 108
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	mut users := [User{22, 'Peter'}, User{20, 'Bob'}, User{25, 'Alice'}]
 | 
						mut users := [User{22, 'Peter'}, User{20, 'Bob'}, User{25, 'Alice'}]
 | 
				
			||||||
	users.sort(a.age < b.age)
 | 
						users.sort(a.age < b.age)
 | 
				
			||||||
| 
						 | 
					@ -842,23 +834,30 @@ fn test_sort() {
 | 
				
			||||||
	assert users[1].age == 22
 | 
						assert users[1].age == 22
 | 
				
			||||||
	assert users[2].age == 20
 | 
						assert users[2].age == 20
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	users.sort(a.name < b.name) // Test sorting by string fields
 | 
						users.sort(b.age > a.age)
 | 
				
			||||||
	// assert users.map(it.name).join(' ') == 'Alice Bob Peter'
 | 
						assert users[0].age == 20
 | 
				
			||||||
 | 
						assert users[1].age == 22
 | 
				
			||||||
 | 
						assert users[2].age == 25
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						users.sort(a.name < b.name)
 | 
				
			||||||
 | 
						assert users[0].name == 'Alice'
 | 
				
			||||||
 | 
						assert users[1].name == 'Bob'
 | 
				
			||||||
 | 
						assert users[2].name == 'Peter'
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
fn test_rune_sort() {
 | 
					fn test_rune_sort() {
 | 
				
			||||||
	mut bs := [`f`, `e`, `d`, `b`, `c`, `a`]
 | 
						mut bs := [`f`, `e`, `d`, `b`, `c`, `a`]
 | 
				
			||||||
	bs.sort()
 | 
						bs.sort()
 | 
				
			||||||
	println(bs)
 | 
						println(bs)
 | 
				
			||||||
	assert '$bs' == '[`a`, `b`, `c`, `d`, `e`, `f`]'
 | 
						assert bs == [`a`, `b`, `c`, `d`, `e`, `f`]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	bs.sort(a > b)
 | 
						bs.sort(a > b)
 | 
				
			||||||
	println(bs)
 | 
						println(bs)
 | 
				
			||||||
	assert '$bs' == '[`f`, `e`, `d`, `c`, `b`, `a`]'
 | 
						assert bs == [`f`, `e`, `d`, `c`, `b`, `a`]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	bs.sort(a < b)
 | 
						bs.sort(a < b)
 | 
				
			||||||
	println(bs)
 | 
						println(bs)
 | 
				
			||||||
	assert '$bs' == '[`a`, `b`, `c`, `d`, `e`, `f`]'
 | 
						assert bs == [`a`, `b`, `c`, `d`, `e`, `f`]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
fn test_sort_by_different_order_of_a_b() {
 | 
					fn test_sort_by_different_order_of_a_b() {
 | 
				
			||||||
| 
						 | 
					@ -876,9 +875,19 @@ fn test_sort_by_different_order_of_a_b() {
 | 
				
			||||||
fn test_f32_sort() {
 | 
					fn test_f32_sort() {
 | 
				
			||||||
	mut f := [f32(50.0), 15, 1, 79, 38, 0, 27]
 | 
						mut f := [f32(50.0), 15, 1, 79, 38, 0, 27]
 | 
				
			||||||
	f.sort()
 | 
						f.sort()
 | 
				
			||||||
	assert f[0] == 0.0
 | 
						assert f == [f32(0.0), 1, 15, 27, 38, 50, 79]
 | 
				
			||||||
	assert f[1] == 1.0
 | 
					
 | 
				
			||||||
	assert f[6] == 79.0
 | 
						f.sort(a < b)
 | 
				
			||||||
 | 
						assert f == [f32(0.0), 1, 15, 27, 38, 50, 79]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						f.sort(b > a)
 | 
				
			||||||
 | 
						assert f == [f32(0.0), 1, 15, 27, 38, 50, 79]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						f.sort(b < a)
 | 
				
			||||||
 | 
						assert f == [f32(79.0), 50, 38, 27, 15, 1, 0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						f.sort(a > b)
 | 
				
			||||||
 | 
						assert f == [f32(79.0), 50, 38, 27, 15, 1, 0]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
fn test_f64_sort() {
 | 
					fn test_f64_sort() {
 | 
				
			||||||
| 
						 | 
					@ -897,6 +906,17 @@ fn test_i64_sort() {
 | 
				
			||||||
	assert f[6] == 79
 | 
						assert f[6] == 79
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					fn test_sort_index_expr() {
 | 
				
			||||||
 | 
						mut f := [[i64(50), 48], [i64(15)], [i64(1)], [i64(79)], [i64(38)],
 | 
				
			||||||
 | 
							[i64(0)], [i64(27)]]
 | 
				
			||||||
 | 
						// TODO This currently gives "indexing pointer" error without unsafe
 | 
				
			||||||
 | 
						unsafe {
 | 
				
			||||||
 | 
							f.sort(a[0] < b[0])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						assert f == [[i64(0)], [i64(1)], [i64(15)], [i64(27)], [i64(38)],
 | 
				
			||||||
 | 
							[i64(50), 48], [i64(79)]]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
fn test_a_b_paras_sort() {
 | 
					fn test_a_b_paras_sort() {
 | 
				
			||||||
	mut arr_i := [1, 3, 2]
 | 
						mut arr_i := [1, 3, 2]
 | 
				
			||||||
	arr_i.sort(a < b)
 | 
						arr_i.sort(a < b)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -222,118 +222,92 @@ fn (mut g Gen) gen_array_sort(node ast.CallExpr) {
 | 
				
			||||||
		verror('.sort() is an array method')
 | 
							verror('.sort() is an array method')
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	info := rec_sym.info as ast.Array
 | 
						info := rec_sym.info as ast.Array
 | 
				
			||||||
	// No arguments means we are sorting an array of builtins (e.g. `numbers.sort()`)
 | 
						// `users.sort(a.age > b.age)`
 | 
				
			||||||
	// The type for the comparison fns is the type of the element itself.
 | 
						// Generate a comparison function for a custom type
 | 
				
			||||||
	mut typ := info.elem_type
 | 
						elem_stype := g.typ(info.elem_type)
 | 
				
			||||||
	mut is_default := false
 | 
						mut compare_fn := 'compare_${elem_stype.replace('*', '_ptr')}'
 | 
				
			||||||
	mut is_reverse := false
 | 
						mut comparison_type := g.unwrap(ast.void_type)
 | 
				
			||||||
	mut compare_fn := ''
 | 
						mut left_expr, mut right_expr := '', ''
 | 
				
			||||||
 | 
						// the only argument can only be an infix expression like `a < b` or `b.field > a.field`
 | 
				
			||||||
	if node.args.len == 0 {
 | 
						if node.args.len == 0 {
 | 
				
			||||||
		is_default = true
 | 
							comparison_type = g.unwrap(info.elem_type.set_nr_muls(0))
 | 
				
			||||||
 | 
							if compare_fn in g.array_sort_fn {
 | 
				
			||||||
 | 
								g.gen_array_sort_call(node, compare_fn)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							left_expr = '*a'
 | 
				
			||||||
 | 
							right_expr = '*b'
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		infix_expr := node.args[0].expr as ast.InfixExpr
 | 
							infix_expr := node.args[0].expr as ast.InfixExpr
 | 
				
			||||||
		left_name := '$infix_expr.left'
 | 
							comparison_type = g.unwrap(infix_expr.left_type.set_nr_muls(0))
 | 
				
			||||||
		is_default = left_name in ['a', 'b'] && '$infix_expr.right' in ['a', 'b']
 | 
							left_name := infix_expr.left.str()
 | 
				
			||||||
		is_reverse = (left_name.starts_with('a') && infix_expr.op == .gt)
 | 
							if left_name.len > 1 {
 | 
				
			||||||
 | 
								compare_fn += '_by' + left_name[1..].replace_each(['.', '_', '[', '_', ']', '_'])
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							// is_reverse is `true` for `.sort(a > b)` and `.sort(b < a)`
 | 
				
			||||||
 | 
							is_reverse := (left_name.starts_with('a') && infix_expr.op == .gt)
 | 
				
			||||||
			|| (left_name.starts_with('b') && infix_expr.op == .lt)
 | 
								|| (left_name.starts_with('b') && infix_expr.op == .lt)
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if is_default {
 | 
					 | 
				
			||||||
		// users.sort() or users.sort(a > b)
 | 
					 | 
				
			||||||
		compare_fn = match typ {
 | 
					 | 
				
			||||||
			ast.int_type, ast.int_type.to_ptr() { 'compare_ints' }
 | 
					 | 
				
			||||||
			ast.string_type, ast.string_type.to_ptr() { 'compare_strings' }
 | 
					 | 
				
			||||||
			else { '' }
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if compare_fn != '' && is_reverse {
 | 
					 | 
				
			||||||
			compare_fn += '_reverse'
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if compare_fn == '' {
 | 
					 | 
				
			||||||
		// `users.sort(a.age > b.age)`
 | 
					 | 
				
			||||||
		// Generate a comparison function for a custom type
 | 
					 | 
				
			||||||
		tmp_name := g.new_global_tmp_var()
 | 
					 | 
				
			||||||
		styp := g.typ(typ).trim('*')
 | 
					 | 
				
			||||||
		compare_fn = 'compare_${tmp_name}_$styp'
 | 
					 | 
				
			||||||
		if is_reverse {
 | 
							if is_reverse {
 | 
				
			||||||
			compare_fn += '_reverse'
 | 
								compare_fn += '_reverse'
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		// Register a new custom `compare_xxx` function for qsort()
 | 
							if compare_fn in g.array_sort_fn {
 | 
				
			||||||
		// TODO: move to checker
 | 
								g.gen_array_sort_call(node, compare_fn)
 | 
				
			||||||
		g.table.register_fn(name: compare_fn, return_type: ast.int_type)
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		if node.args.len == 0 {
 | 
							if left_name.starts_with('a') != is_reverse {
 | 
				
			||||||
			styp_arg := g.typ(typ)
 | 
								left_expr = g.expr_string(infix_expr.left)
 | 
				
			||||||
			g.definitions.writeln('int $compare_fn ($styp_arg* a, $styp_arg* b) {')
 | 
								right_expr = g.expr_string(infix_expr.right)
 | 
				
			||||||
			sym := g.table.get_type_symbol(typ)
 | 
								if infix_expr.left is ast.Ident {
 | 
				
			||||||
			if !is_reverse && sym.has_method('<') {
 | 
									left_expr = '*' + left_expr
 | 
				
			||||||
				g.definitions.writeln('\tif (${styp}__lt(*a, *b)) { return -1; } else { return 1; }}')
 | 
								}
 | 
				
			||||||
			} else if is_reverse && sym.has_method('<') {
 | 
								if infix_expr.right is ast.Ident {
 | 
				
			||||||
				g.definitions.writeln('\tif (${styp}__lt(*b, *a)) { return -1; } else { return 1; }}')
 | 
									right_expr = '*' + right_expr
 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				g.definitions.writeln('if (*a < *b) return -1;')
 | 
					 | 
				
			||||||
				g.definitions.writeln('if (*a > *b) return 1; else return 0; }\n')
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			infix_expr := node.args[0].expr as ast.InfixExpr
 | 
								left_expr = g.expr_string(infix_expr.right)
 | 
				
			||||||
			// Variables `a` and `b` are used in the `.sort(a < b)` syntax, so we can reuse them
 | 
								right_expr = g.expr_string(infix_expr.left)
 | 
				
			||||||
			// when generating the function as long as the args are named the same.
 | 
								if infix_expr.left is ast.Ident {
 | 
				
			||||||
			styp_arg := g.typ(typ)
 | 
									right_expr = '*' + right_expr
 | 
				
			||||||
			g.definitions.writeln('int $compare_fn ($styp_arg* a, $styp_arg* b) {')
 | 
								}
 | 
				
			||||||
			sym := g.table.get_type_symbol(typ)
 | 
								if infix_expr.right is ast.Ident {
 | 
				
			||||||
			if !is_reverse && sym.has_method('<') && infix_expr.left.str().len == 1 {
 | 
									left_expr = '*' + left_expr
 | 
				
			||||||
				g.definitions.writeln('\tif (${styp}__lt(*a, *b)) { return -1; } else { return 1; }}')
 | 
					 | 
				
			||||||
			} else if is_reverse && sym.has_method('<') && infix_expr.left.str().len == 1 {
 | 
					 | 
				
			||||||
				g.definitions.writeln('\tif (${styp}__lt(*b, *a)) { return -1; } else { return 1; }}')
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				field_type := g.typ(infix_expr.left_type)
 | 
					 | 
				
			||||||
				left_name := '$infix_expr.left'
 | 
					 | 
				
			||||||
				left_expr_str := g.expr_string(infix_expr.left)
 | 
					 | 
				
			||||||
				right_expr_str := g.expr_string(infix_expr.right)
 | 
					 | 
				
			||||||
				if left_name.starts_with('a') {
 | 
					 | 
				
			||||||
					g.definitions.writeln('\t$field_type a_ = $left_expr_str;')
 | 
					 | 
				
			||||||
					g.definitions.writeln('\t$field_type b_ = $right_expr_str;')
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					g.definitions.writeln('\t$field_type a_ = $right_expr_str;')
 | 
					 | 
				
			||||||
					g.definitions.writeln('\t$field_type b_ = $left_expr_str;')
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				mut op1, mut op2 := '', ''
 | 
					 | 
				
			||||||
				if infix_expr.left_type == ast.string_type {
 | 
					 | 
				
			||||||
					if is_reverse {
 | 
					 | 
				
			||||||
						op1 = 'string__lt(b_, a_)'
 | 
					 | 
				
			||||||
						op2 = 'string__lt(a_, b_)'
 | 
					 | 
				
			||||||
					} else {
 | 
					 | 
				
			||||||
						op1 = 'string__lt(a_, b_)'
 | 
					 | 
				
			||||||
						op2 = 'string__lt(b_, a_)'
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					deref_str := if infix_expr.left_type.is_ptr() { '*' } else { '' }
 | 
					 | 
				
			||||||
					if is_reverse {
 | 
					 | 
				
			||||||
						op1 = '${deref_str}a_ > ${deref_str}b_'
 | 
					 | 
				
			||||||
						op2 = '${deref_str}a_ < ${deref_str}b_'
 | 
					 | 
				
			||||||
					} else {
 | 
					 | 
				
			||||||
						op1 = '${deref_str}a_ < ${deref_str}b_'
 | 
					 | 
				
			||||||
						op2 = '${deref_str}a_ > ${deref_str}b_'
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				g.definitions.writeln('\tif ($op1) return -1;')
 | 
					 | 
				
			||||||
				g.definitions.writeln('\tif ($op2) return 1; \n\telse return 0; \n}\n')
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if is_reverse && !compare_fn.ends_with('_reverse') {
 | 
					
 | 
				
			||||||
		compare_fn += '_reverse'
 | 
						// Register a new custom `compare_xxx` function for qsort()
 | 
				
			||||||
 | 
						// TODO: move to checker
 | 
				
			||||||
 | 
						g.table.register_fn(name: compare_fn, return_type: ast.int_type)
 | 
				
			||||||
 | 
						g.array_sort_fn[compare_fn] = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						stype_arg := g.typ(info.elem_type)
 | 
				
			||||||
 | 
						g.definitions.writeln('int ${compare_fn}($stype_arg* a, $stype_arg* b) {')
 | 
				
			||||||
 | 
						c_condition := if comparison_type.sym.has_method('<') {
 | 
				
			||||||
 | 
							'${g.typ(comparison_type.typ)}__lt($left_expr, $right_expr)'
 | 
				
			||||||
 | 
						} else if comparison_type.unaliased_sym.has_method('<') {
 | 
				
			||||||
 | 
							'${g.typ(comparison_type.unaliased)}__lt($left_expr, $right_expr)'
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							'$left_expr < $right_expr'
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	//
 | 
						g.definitions.writeln('\tif ($c_condition) return -1;')
 | 
				
			||||||
	deref := if node.left_type.is_ptr() || node.left_type.is_pointer() { '->' } else { '.' }
 | 
						g.definitions.writeln('\telse return 1;')
 | 
				
			||||||
 | 
						g.definitions.writeln('}\n')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// write call to the generated function
 | 
				
			||||||
 | 
						g.gen_array_sort_call(node, compare_fn)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					fn (mut g Gen) gen_array_sort_call(node ast.CallExpr, compare_fn string) {
 | 
				
			||||||
 | 
						deref_field := if node.left_type.is_ptr() || node.left_type.is_pointer() { '->' } else { '.' }
 | 
				
			||||||
	// eprintln('> qsort: pointer $node.left_type | deref: `$deref`')
 | 
						// eprintln('> qsort: pointer $node.left_type | deref: `$deref`')
 | 
				
			||||||
	g.empty_line = true
 | 
						g.empty_line = true
 | 
				
			||||||
	g.write('qsort(')
 | 
						g.write('qsort(')
 | 
				
			||||||
	g.expr(node.left)
 | 
						g.expr(node.left)
 | 
				
			||||||
	g.write('${deref}data, ')
 | 
						g.write('${deref_field}data, ')
 | 
				
			||||||
	g.expr(node.left)
 | 
						g.expr(node.left)
 | 
				
			||||||
	g.write('${deref}len, ')
 | 
						g.write('${deref_field}len, ')
 | 
				
			||||||
	g.expr(node.left)
 | 
						g.expr(node.left)
 | 
				
			||||||
	g.write('${deref}element_size, (int (*)(const void *, const void *))&$compare_fn)')
 | 
						g.write('${deref_field}element_size, (int (*)(const void *, const void *))&$compare_fn)')
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// `nums.filter(it % 2 == 0)`
 | 
					// `nums.filter(it % 2 == 0)`
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -179,6 +179,7 @@ mut:
 | 
				
			||||||
	expected_cast_type ast.Type // for match expr of sumtypes
 | 
						expected_cast_type ast.Type // for match expr of sumtypes
 | 
				
			||||||
	defer_vars         []string
 | 
						defer_vars         []string
 | 
				
			||||||
	anon_fn            bool
 | 
						anon_fn            bool
 | 
				
			||||||
 | 
						array_sort_fn      map[string]bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pub fn gen(files []&ast.File, table &ast.Table, pref &pref.Preferences) string {
 | 
					pub fn gen(files []&ast.File, table &ast.Table, pref &pref.Preferences) string {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue