cgen: refactor array sort generation (#11067)
parent
8d2567740b
commit
7346aeca5f
|
@ -808,25 +808,17 @@ fn test_fixed_array_literal_eq() {
|
||||||
fn test_sort() {
|
fn test_sort() {
|
||||||
mut a := ['hi', '1', '5', '3']
|
mut a := ['hi', '1', '5', '3']
|
||||||
a.sort()
|
a.sort()
|
||||||
assert a[0] == '1'
|
assert a == ['1', '3', '5', 'hi']
|
||||||
assert a[1] == '3'
|
|
||||||
assert a[2] == '5'
|
|
||||||
assert a[3] == 'hi'
|
|
||||||
|
|
||||||
mut nums := [67, -3, 108, 42, 7]
|
mut nums := [67, -3, 108, 42, 7]
|
||||||
nums.sort()
|
nums.sort()
|
||||||
assert nums[0] == -3
|
assert nums == [-3, 7, 42, 67, 108]
|
||||||
assert nums[1] == 7
|
|
||||||
assert nums[2] == 42
|
|
||||||
assert nums[3] == 67
|
|
||||||
assert nums[4] == 108
|
|
||||||
|
|
||||||
nums.sort(a < b)
|
nums.sort(a < b)
|
||||||
assert nums[0] == -3
|
assert nums == [-3, 7, 42, 67, 108]
|
||||||
assert nums[1] == 7
|
|
||||||
assert nums[2] == 42
|
nums.sort(b < a)
|
||||||
assert nums[3] == 67
|
assert nums == [108, 67, 42, 7, -3]
|
||||||
assert nums[4] == 108
|
|
||||||
|
|
||||||
mut users := [User{22, 'Peter'}, User{20, 'Bob'}, User{25, 'Alice'}]
|
mut users := [User{22, 'Peter'}, User{20, 'Bob'}, User{25, 'Alice'}]
|
||||||
users.sort(a.age < b.age)
|
users.sort(a.age < b.age)
|
||||||
|
@ -842,23 +834,30 @@ fn test_sort() {
|
||||||
assert users[1].age == 22
|
assert users[1].age == 22
|
||||||
assert users[2].age == 20
|
assert users[2].age == 20
|
||||||
|
|
||||||
users.sort(a.name < b.name) // Test sorting by string fields
|
users.sort(b.age > a.age)
|
||||||
// assert users.map(it.name).join(' ') == 'Alice Bob Peter'
|
assert users[0].age == 20
|
||||||
|
assert users[1].age == 22
|
||||||
|
assert users[2].age == 25
|
||||||
|
|
||||||
|
users.sort(a.name < b.name)
|
||||||
|
assert users[0].name == 'Alice'
|
||||||
|
assert users[1].name == 'Bob'
|
||||||
|
assert users[2].name == 'Peter'
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_rune_sort() {
|
fn test_rune_sort() {
|
||||||
mut bs := [`f`, `e`, `d`, `b`, `c`, `a`]
|
mut bs := [`f`, `e`, `d`, `b`, `c`, `a`]
|
||||||
bs.sort()
|
bs.sort()
|
||||||
println(bs)
|
println(bs)
|
||||||
assert '$bs' == '[`a`, `b`, `c`, `d`, `e`, `f`]'
|
assert bs == [`a`, `b`, `c`, `d`, `e`, `f`]
|
||||||
|
|
||||||
bs.sort(a > b)
|
bs.sort(a > b)
|
||||||
println(bs)
|
println(bs)
|
||||||
assert '$bs' == '[`f`, `e`, `d`, `c`, `b`, `a`]'
|
assert bs == [`f`, `e`, `d`, `c`, `b`, `a`]
|
||||||
|
|
||||||
bs.sort(a < b)
|
bs.sort(a < b)
|
||||||
println(bs)
|
println(bs)
|
||||||
assert '$bs' == '[`a`, `b`, `c`, `d`, `e`, `f`]'
|
assert bs == [`a`, `b`, `c`, `d`, `e`, `f`]
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_sort_by_different_order_of_a_b() {
|
fn test_sort_by_different_order_of_a_b() {
|
||||||
|
@ -876,9 +875,19 @@ fn test_sort_by_different_order_of_a_b() {
|
||||||
fn test_f32_sort() {
|
fn test_f32_sort() {
|
||||||
mut f := [f32(50.0), 15, 1, 79, 38, 0, 27]
|
mut f := [f32(50.0), 15, 1, 79, 38, 0, 27]
|
||||||
f.sort()
|
f.sort()
|
||||||
assert f[0] == 0.0
|
assert f == [f32(0.0), 1, 15, 27, 38, 50, 79]
|
||||||
assert f[1] == 1.0
|
|
||||||
assert f[6] == 79.0
|
f.sort(a < b)
|
||||||
|
assert f == [f32(0.0), 1, 15, 27, 38, 50, 79]
|
||||||
|
|
||||||
|
f.sort(b > a)
|
||||||
|
assert f == [f32(0.0), 1, 15, 27, 38, 50, 79]
|
||||||
|
|
||||||
|
f.sort(b < a)
|
||||||
|
assert f == [f32(79.0), 50, 38, 27, 15, 1, 0]
|
||||||
|
|
||||||
|
f.sort(a > b)
|
||||||
|
assert f == [f32(79.0), 50, 38, 27, 15, 1, 0]
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_f64_sort() {
|
fn test_f64_sort() {
|
||||||
|
@ -897,6 +906,17 @@ fn test_i64_sort() {
|
||||||
assert f[6] == 79
|
assert f[6] == 79
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn test_sort_index_expr() {
|
||||||
|
mut f := [[i64(50), 48], [i64(15)], [i64(1)], [i64(79)], [i64(38)],
|
||||||
|
[i64(0)], [i64(27)]]
|
||||||
|
// TODO This currently gives "indexing pointer" error without unsafe
|
||||||
|
unsafe {
|
||||||
|
f.sort(a[0] < b[0])
|
||||||
|
}
|
||||||
|
assert f == [[i64(0)], [i64(1)], [i64(15)], [i64(27)], [i64(38)],
|
||||||
|
[i64(50), 48], [i64(79)]]
|
||||||
|
}
|
||||||
|
|
||||||
fn test_a_b_paras_sort() {
|
fn test_a_b_paras_sort() {
|
||||||
mut arr_i := [1, 3, 2]
|
mut arr_i := [1, 3, 2]
|
||||||
arr_i.sort(a < b)
|
arr_i.sort(a < b)
|
||||||
|
|
|
@ -222,118 +222,92 @@ fn (mut g Gen) gen_array_sort(node ast.CallExpr) {
|
||||||
verror('.sort() is an array method')
|
verror('.sort() is an array method')
|
||||||
}
|
}
|
||||||
info := rec_sym.info as ast.Array
|
info := rec_sym.info as ast.Array
|
||||||
// No arguments means we are sorting an array of builtins (e.g. `numbers.sort()`)
|
// `users.sort(a.age > b.age)`
|
||||||
// The type for the comparison fns is the type of the element itself.
|
// Generate a comparison function for a custom type
|
||||||
mut typ := info.elem_type
|
elem_stype := g.typ(info.elem_type)
|
||||||
mut is_default := false
|
mut compare_fn := 'compare_${elem_stype.replace('*', '_ptr')}'
|
||||||
mut is_reverse := false
|
mut comparison_type := g.unwrap(ast.void_type)
|
||||||
mut compare_fn := ''
|
mut left_expr, mut right_expr := '', ''
|
||||||
|
// the only argument can only be an infix expression like `a < b` or `b.field > a.field`
|
||||||
if node.args.len == 0 {
|
if node.args.len == 0 {
|
||||||
is_default = true
|
comparison_type = g.unwrap(info.elem_type.set_nr_muls(0))
|
||||||
|
if compare_fn in g.array_sort_fn {
|
||||||
|
g.gen_array_sort_call(node, compare_fn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
left_expr = '*a'
|
||||||
|
right_expr = '*b'
|
||||||
} else {
|
} else {
|
||||||
infix_expr := node.args[0].expr as ast.InfixExpr
|
infix_expr := node.args[0].expr as ast.InfixExpr
|
||||||
left_name := '$infix_expr.left'
|
comparison_type = g.unwrap(infix_expr.left_type.set_nr_muls(0))
|
||||||
is_default = left_name in ['a', 'b'] && '$infix_expr.right' in ['a', 'b']
|
left_name := infix_expr.left.str()
|
||||||
is_reverse = (left_name.starts_with('a') && infix_expr.op == .gt)
|
if left_name.len > 1 {
|
||||||
|
compare_fn += '_by' + left_name[1..].replace_each(['.', '_', '[', '_', ']', '_'])
|
||||||
|
}
|
||||||
|
// is_reverse is `true` for `.sort(a > b)` and `.sort(b < a)`
|
||||||
|
is_reverse := (left_name.starts_with('a') && infix_expr.op == .gt)
|
||||||
|| (left_name.starts_with('b') && infix_expr.op == .lt)
|
|| (left_name.starts_with('b') && infix_expr.op == .lt)
|
||||||
}
|
|
||||||
if is_default {
|
|
||||||
// users.sort() or users.sort(a > b)
|
|
||||||
compare_fn = match typ {
|
|
||||||
ast.int_type, ast.int_type.to_ptr() { 'compare_ints' }
|
|
||||||
ast.string_type, ast.string_type.to_ptr() { 'compare_strings' }
|
|
||||||
else { '' }
|
|
||||||
}
|
|
||||||
if compare_fn != '' && is_reverse {
|
|
||||||
compare_fn += '_reverse'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if compare_fn == '' {
|
|
||||||
// `users.sort(a.age > b.age)`
|
|
||||||
// Generate a comparison function for a custom type
|
|
||||||
tmp_name := g.new_global_tmp_var()
|
|
||||||
styp := g.typ(typ).trim('*')
|
|
||||||
compare_fn = 'compare_${tmp_name}_$styp'
|
|
||||||
if is_reverse {
|
if is_reverse {
|
||||||
compare_fn += '_reverse'
|
compare_fn += '_reverse'
|
||||||
}
|
}
|
||||||
// Register a new custom `compare_xxx` function for qsort()
|
if compare_fn in g.array_sort_fn {
|
||||||
// TODO: move to checker
|
g.gen_array_sort_call(node, compare_fn)
|
||||||
g.table.register_fn(name: compare_fn, return_type: ast.int_type)
|
return
|
||||||
|
}
|
||||||
if node.args.len == 0 {
|
if left_name.starts_with('a') != is_reverse {
|
||||||
styp_arg := g.typ(typ)
|
left_expr = g.expr_string(infix_expr.left)
|
||||||
g.definitions.writeln('int $compare_fn ($styp_arg* a, $styp_arg* b) {')
|
right_expr = g.expr_string(infix_expr.right)
|
||||||
sym := g.table.get_type_symbol(typ)
|
if infix_expr.left is ast.Ident {
|
||||||
if !is_reverse && sym.has_method('<') {
|
left_expr = '*' + left_expr
|
||||||
g.definitions.writeln('\tif (${styp}__lt(*a, *b)) { return -1; } else { return 1; }}')
|
}
|
||||||
} else if is_reverse && sym.has_method('<') {
|
if infix_expr.right is ast.Ident {
|
||||||
g.definitions.writeln('\tif (${styp}__lt(*b, *a)) { return -1; } else { return 1; }}')
|
right_expr = '*' + right_expr
|
||||||
} else {
|
|
||||||
g.definitions.writeln('if (*a < *b) return -1;')
|
|
||||||
g.definitions.writeln('if (*a > *b) return 1; else return 0; }\n')
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
infix_expr := node.args[0].expr as ast.InfixExpr
|
left_expr = g.expr_string(infix_expr.right)
|
||||||
// Variables `a` and `b` are used in the `.sort(a < b)` syntax, so we can reuse them
|
right_expr = g.expr_string(infix_expr.left)
|
||||||
// when generating the function as long as the args are named the same.
|
if infix_expr.left is ast.Ident {
|
||||||
styp_arg := g.typ(typ)
|
right_expr = '*' + right_expr
|
||||||
g.definitions.writeln('int $compare_fn ($styp_arg* a, $styp_arg* b) {')
|
}
|
||||||
sym := g.table.get_type_symbol(typ)
|
if infix_expr.right is ast.Ident {
|
||||||
if !is_reverse && sym.has_method('<') && infix_expr.left.str().len == 1 {
|
left_expr = '*' + left_expr
|
||||||
g.definitions.writeln('\tif (${styp}__lt(*a, *b)) { return -1; } else { return 1; }}')
|
|
||||||
} else if is_reverse && sym.has_method('<') && infix_expr.left.str().len == 1 {
|
|
||||||
g.definitions.writeln('\tif (${styp}__lt(*b, *a)) { return -1; } else { return 1; }}')
|
|
||||||
} else {
|
|
||||||
field_type := g.typ(infix_expr.left_type)
|
|
||||||
left_name := '$infix_expr.left'
|
|
||||||
left_expr_str := g.expr_string(infix_expr.left)
|
|
||||||
right_expr_str := g.expr_string(infix_expr.right)
|
|
||||||
if left_name.starts_with('a') {
|
|
||||||
g.definitions.writeln('\t$field_type a_ = $left_expr_str;')
|
|
||||||
g.definitions.writeln('\t$field_type b_ = $right_expr_str;')
|
|
||||||
} else {
|
|
||||||
g.definitions.writeln('\t$field_type a_ = $right_expr_str;')
|
|
||||||
g.definitions.writeln('\t$field_type b_ = $left_expr_str;')
|
|
||||||
}
|
|
||||||
mut op1, mut op2 := '', ''
|
|
||||||
if infix_expr.left_type == ast.string_type {
|
|
||||||
if is_reverse {
|
|
||||||
op1 = 'string__lt(b_, a_)'
|
|
||||||
op2 = 'string__lt(a_, b_)'
|
|
||||||
} else {
|
|
||||||
op1 = 'string__lt(a_, b_)'
|
|
||||||
op2 = 'string__lt(b_, a_)'
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
deref_str := if infix_expr.left_type.is_ptr() { '*' } else { '' }
|
|
||||||
if is_reverse {
|
|
||||||
op1 = '${deref_str}a_ > ${deref_str}b_'
|
|
||||||
op2 = '${deref_str}a_ < ${deref_str}b_'
|
|
||||||
} else {
|
|
||||||
op1 = '${deref_str}a_ < ${deref_str}b_'
|
|
||||||
op2 = '${deref_str}a_ > ${deref_str}b_'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
g.definitions.writeln('\tif ($op1) return -1;')
|
|
||||||
g.definitions.writeln('\tif ($op2) return 1; \n\telse return 0; \n}\n')
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if is_reverse && !compare_fn.ends_with('_reverse') {
|
|
||||||
compare_fn += '_reverse'
|
// Register a new custom `compare_xxx` function for qsort()
|
||||||
|
// TODO: move to checker
|
||||||
|
g.table.register_fn(name: compare_fn, return_type: ast.int_type)
|
||||||
|
g.array_sort_fn[compare_fn] = true
|
||||||
|
|
||||||
|
stype_arg := g.typ(info.elem_type)
|
||||||
|
g.definitions.writeln('int ${compare_fn}($stype_arg* a, $stype_arg* b) {')
|
||||||
|
c_condition := if comparison_type.sym.has_method('<') {
|
||||||
|
'${g.typ(comparison_type.typ)}__lt($left_expr, $right_expr)'
|
||||||
|
} else if comparison_type.unaliased_sym.has_method('<') {
|
||||||
|
'${g.typ(comparison_type.unaliased)}__lt($left_expr, $right_expr)'
|
||||||
|
} else {
|
||||||
|
'$left_expr < $right_expr'
|
||||||
}
|
}
|
||||||
//
|
g.definitions.writeln('\tif ($c_condition) return -1;')
|
||||||
deref := if node.left_type.is_ptr() || node.left_type.is_pointer() { '->' } else { '.' }
|
g.definitions.writeln('\telse return 1;')
|
||||||
|
g.definitions.writeln('}\n')
|
||||||
|
|
||||||
|
// write call to the generated function
|
||||||
|
g.gen_array_sort_call(node, compare_fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn (mut g Gen) gen_array_sort_call(node ast.CallExpr, compare_fn string) {
|
||||||
|
deref_field := if node.left_type.is_ptr() || node.left_type.is_pointer() { '->' } else { '.' }
|
||||||
// eprintln('> qsort: pointer $node.left_type | deref: `$deref`')
|
// eprintln('> qsort: pointer $node.left_type | deref: `$deref`')
|
||||||
g.empty_line = true
|
g.empty_line = true
|
||||||
g.write('qsort(')
|
g.write('qsort(')
|
||||||
g.expr(node.left)
|
g.expr(node.left)
|
||||||
g.write('${deref}data, ')
|
g.write('${deref_field}data, ')
|
||||||
g.expr(node.left)
|
g.expr(node.left)
|
||||||
g.write('${deref}len, ')
|
g.write('${deref_field}len, ')
|
||||||
g.expr(node.left)
|
g.expr(node.left)
|
||||||
g.write('${deref}element_size, (int (*)(const void *, const void *))&$compare_fn)')
|
g.write('${deref_field}element_size, (int (*)(const void *, const void *))&$compare_fn)')
|
||||||
}
|
}
|
||||||
|
|
||||||
// `nums.filter(it % 2 == 0)`
|
// `nums.filter(it % 2 == 0)`
|
||||||
|
|
|
@ -179,6 +179,7 @@ mut:
|
||||||
expected_cast_type ast.Type // for match expr of sumtypes
|
expected_cast_type ast.Type // for match expr of sumtypes
|
||||||
defer_vars []string
|
defer_vars []string
|
||||||
anon_fn bool
|
anon_fn bool
|
||||||
|
array_sort_fn map[string]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn gen(files []&ast.File, table &ast.Table, pref &pref.Preferences) string {
|
pub fn gen(files []&ast.File, table &ast.Table, pref &pref.Preferences) string {
|
||||||
|
|
Loading…
Reference in New Issue