cgen: refactor array sort generation (#11067)

pull/11072/head
Enzo 2021-08-06 02:55:48 +02:00 committed by GitHub
parent 8d2567740b
commit 7346aeca5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 110 additions and 115 deletions

View File

@ -808,25 +808,17 @@ fn test_fixed_array_literal_eq() {
fn test_sort() {
mut a := ['hi', '1', '5', '3']
a.sort()
assert a[0] == '1'
assert a[1] == '3'
assert a[2] == '5'
assert a[3] == 'hi'
assert a == ['1', '3', '5', 'hi']
mut nums := [67, -3, 108, 42, 7]
nums.sort()
assert nums[0] == -3
assert nums[1] == 7
assert nums[2] == 42
assert nums[3] == 67
assert nums[4] == 108
assert nums == [-3, 7, 42, 67, 108]
nums.sort(a < b)
assert nums[0] == -3
assert nums[1] == 7
assert nums[2] == 42
assert nums[3] == 67
assert nums[4] == 108
assert nums == [-3, 7, 42, 67, 108]
nums.sort(b < a)
assert nums == [108, 67, 42, 7, -3]
mut users := [User{22, 'Peter'}, User{20, 'Bob'}, User{25, 'Alice'}]
users.sort(a.age < b.age)
@ -842,23 +834,30 @@ fn test_sort() {
assert users[1].age == 22
assert users[2].age == 20
users.sort(a.name < b.name) // Test sorting by string fields
// assert users.map(it.name).join(' ') == 'Alice Bob Peter'
users.sort(b.age > a.age)
assert users[0].age == 20
assert users[1].age == 22
assert users[2].age == 25
users.sort(a.name < b.name)
assert users[0].name == 'Alice'
assert users[1].name == 'Bob'
assert users[2].name == 'Peter'
}
fn test_rune_sort() {
mut bs := [`f`, `e`, `d`, `b`, `c`, `a`]
bs.sort()
println(bs)
assert '$bs' == '[`a`, `b`, `c`, `d`, `e`, `f`]'
assert bs == [`a`, `b`, `c`, `d`, `e`, `f`]
bs.sort(a > b)
println(bs)
assert '$bs' == '[`f`, `e`, `d`, `c`, `b`, `a`]'
assert bs == [`f`, `e`, `d`, `c`, `b`, `a`]
bs.sort(a < b)
println(bs)
assert '$bs' == '[`a`, `b`, `c`, `d`, `e`, `f`]'
assert bs == [`a`, `b`, `c`, `d`, `e`, `f`]
}
fn test_sort_by_different_order_of_a_b() {
@ -876,9 +875,19 @@ fn test_sort_by_different_order_of_a_b() {
fn test_f32_sort() {
mut f := [f32(50.0), 15, 1, 79, 38, 0, 27]
f.sort()
assert f[0] == 0.0
assert f[1] == 1.0
assert f[6] == 79.0
assert f == [f32(0.0), 1, 15, 27, 38, 50, 79]
f.sort(a < b)
assert f == [f32(0.0), 1, 15, 27, 38, 50, 79]
f.sort(b > a)
assert f == [f32(0.0), 1, 15, 27, 38, 50, 79]
f.sort(b < a)
assert f == [f32(79.0), 50, 38, 27, 15, 1, 0]
f.sort(a > b)
assert f == [f32(79.0), 50, 38, 27, 15, 1, 0]
}
fn test_f64_sort() {
@ -897,6 +906,17 @@ fn test_i64_sort() {
assert f[6] == 79
}
fn test_sort_index_expr() {
mut f := [[i64(50), 48], [i64(15)], [i64(1)], [i64(79)], [i64(38)],
[i64(0)], [i64(27)]]
// TODO This currently gives "indexing pointer" error without unsafe
unsafe {
f.sort(a[0] < b[0])
}
assert f == [[i64(0)], [i64(1)], [i64(15)], [i64(27)], [i64(38)],
[i64(50), 48], [i64(79)]]
}
fn test_a_b_paras_sort() {
mut arr_i := [1, 3, 2]
arr_i.sort(a < b)

View File

@ -222,118 +222,92 @@ fn (mut g Gen) gen_array_sort(node ast.CallExpr) {
verror('.sort() is an array method')
}
info := rec_sym.info as ast.Array
// No arguments means we are sorting an array of builtins (e.g. `numbers.sort()`)
// The type for the comparison fns is the type of the element itself.
mut typ := info.elem_type
mut is_default := false
mut is_reverse := false
mut compare_fn := ''
if node.args.len == 0 {
is_default = true
} else {
infix_expr := node.args[0].expr as ast.InfixExpr
left_name := '$infix_expr.left'
is_default = left_name in ['a', 'b'] && '$infix_expr.right' in ['a', 'b']
is_reverse = (left_name.starts_with('a') && infix_expr.op == .gt)
|| (left_name.starts_with('b') && infix_expr.op == .lt)
}
if is_default {
// users.sort() or users.sort(a > b)
compare_fn = match typ {
ast.int_type, ast.int_type.to_ptr() { 'compare_ints' }
ast.string_type, ast.string_type.to_ptr() { 'compare_strings' }
else { '' }
}
if compare_fn != '' && is_reverse {
compare_fn += '_reverse'
}
}
if compare_fn == '' {
// `users.sort(a.age > b.age)`
// Generate a comparison function for a custom type
tmp_name := g.new_global_tmp_var()
styp := g.typ(typ).trim('*')
compare_fn = 'compare_${tmp_name}_$styp'
elem_stype := g.typ(info.elem_type)
mut compare_fn := 'compare_${elem_stype.replace('*', '_ptr')}'
mut comparison_type := g.unwrap(ast.void_type)
mut left_expr, mut right_expr := '', ''
// the only argument can only be an infix expression like `a < b` or `b.field > a.field`
if node.args.len == 0 {
comparison_type = g.unwrap(info.elem_type.set_nr_muls(0))
if compare_fn in g.array_sort_fn {
g.gen_array_sort_call(node, compare_fn)
return
}
left_expr = '*a'
right_expr = '*b'
} else {
infix_expr := node.args[0].expr as ast.InfixExpr
comparison_type = g.unwrap(infix_expr.left_type.set_nr_muls(0))
left_name := infix_expr.left.str()
if left_name.len > 1 {
compare_fn += '_by' + left_name[1..].replace_each(['.', '_', '[', '_', ']', '_'])
}
// is_reverse is `true` for `.sort(a > b)` and `.sort(b < a)`
is_reverse := (left_name.starts_with('a') && infix_expr.op == .gt)
|| (left_name.starts_with('b') && infix_expr.op == .lt)
if is_reverse {
compare_fn += '_reverse'
}
if compare_fn in g.array_sort_fn {
g.gen_array_sort_call(node, compare_fn)
return
}
if left_name.starts_with('a') != is_reverse {
left_expr = g.expr_string(infix_expr.left)
right_expr = g.expr_string(infix_expr.right)
if infix_expr.left is ast.Ident {
left_expr = '*' + left_expr
}
if infix_expr.right is ast.Ident {
right_expr = '*' + right_expr
}
} else {
left_expr = g.expr_string(infix_expr.right)
right_expr = g.expr_string(infix_expr.left)
if infix_expr.left is ast.Ident {
right_expr = '*' + right_expr
}
if infix_expr.right is ast.Ident {
left_expr = '*' + left_expr
}
}
}
// Register a new custom `compare_xxx` function for qsort()
// TODO: move to checker
g.table.register_fn(name: compare_fn, return_type: ast.int_type)
g.array_sort_fn[compare_fn] = true
if node.args.len == 0 {
styp_arg := g.typ(typ)
g.definitions.writeln('int $compare_fn ($styp_arg* a, $styp_arg* b) {')
sym := g.table.get_type_symbol(typ)
if !is_reverse && sym.has_method('<') {
g.definitions.writeln('\tif (${styp}__lt(*a, *b)) { return -1; } else { return 1; }}')
} else if is_reverse && sym.has_method('<') {
g.definitions.writeln('\tif (${styp}__lt(*b, *a)) { return -1; } else { return 1; }}')
stype_arg := g.typ(info.elem_type)
g.definitions.writeln('int ${compare_fn}($stype_arg* a, $stype_arg* b) {')
c_condition := if comparison_type.sym.has_method('<') {
'${g.typ(comparison_type.typ)}__lt($left_expr, $right_expr)'
} else if comparison_type.unaliased_sym.has_method('<') {
'${g.typ(comparison_type.unaliased)}__lt($left_expr, $right_expr)'
} else {
g.definitions.writeln('if (*a < *b) return -1;')
g.definitions.writeln('if (*a > *b) return 1; else return 0; }\n')
'$left_expr < $right_expr'
}
} else {
infix_expr := node.args[0].expr as ast.InfixExpr
// Variables `a` and `b` are used in the `.sort(a < b)` syntax, so we can reuse them
// when generating the function as long as the args are named the same.
styp_arg := g.typ(typ)
g.definitions.writeln('int $compare_fn ($styp_arg* a, $styp_arg* b) {')
sym := g.table.get_type_symbol(typ)
if !is_reverse && sym.has_method('<') && infix_expr.left.str().len == 1 {
g.definitions.writeln('\tif (${styp}__lt(*a, *b)) { return -1; } else { return 1; }}')
} else if is_reverse && sym.has_method('<') && infix_expr.left.str().len == 1 {
g.definitions.writeln('\tif (${styp}__lt(*b, *a)) { return -1; } else { return 1; }}')
} else {
field_type := g.typ(infix_expr.left_type)
left_name := '$infix_expr.left'
left_expr_str := g.expr_string(infix_expr.left)
right_expr_str := g.expr_string(infix_expr.right)
if left_name.starts_with('a') {
g.definitions.writeln('\t$field_type a_ = $left_expr_str;')
g.definitions.writeln('\t$field_type b_ = $right_expr_str;')
} else {
g.definitions.writeln('\t$field_type a_ = $right_expr_str;')
g.definitions.writeln('\t$field_type b_ = $left_expr_str;')
g.definitions.writeln('\tif ($c_condition) return -1;')
g.definitions.writeln('\telse return 1;')
g.definitions.writeln('}\n')
// write call to the generated function
g.gen_array_sort_call(node, compare_fn)
}
mut op1, mut op2 := '', ''
if infix_expr.left_type == ast.string_type {
if is_reverse {
op1 = 'string__lt(b_, a_)'
op2 = 'string__lt(a_, b_)'
} else {
op1 = 'string__lt(a_, b_)'
op2 = 'string__lt(b_, a_)'
}
} else {
deref_str := if infix_expr.left_type.is_ptr() { '*' } else { '' }
if is_reverse {
op1 = '${deref_str}a_ > ${deref_str}b_'
op2 = '${deref_str}a_ < ${deref_str}b_'
} else {
op1 = '${deref_str}a_ < ${deref_str}b_'
op2 = '${deref_str}a_ > ${deref_str}b_'
}
}
g.definitions.writeln('\tif ($op1) return -1;')
g.definitions.writeln('\tif ($op2) return 1; \n\telse return 0; \n}\n')
}
}
}
if is_reverse && !compare_fn.ends_with('_reverse') {
compare_fn += '_reverse'
}
//
deref := if node.left_type.is_ptr() || node.left_type.is_pointer() { '->' } else { '.' }
fn (mut g Gen) gen_array_sort_call(node ast.CallExpr, compare_fn string) {
deref_field := if node.left_type.is_ptr() || node.left_type.is_pointer() { '->' } else { '.' }
// eprintln('> qsort: pointer $node.left_type | deref: `$deref`')
g.empty_line = true
g.write('qsort(')
g.expr(node.left)
g.write('${deref}data, ')
g.write('${deref_field}data, ')
g.expr(node.left)
g.write('${deref}len, ')
g.write('${deref_field}len, ')
g.expr(node.left)
g.write('${deref}element_size, (int (*)(const void *, const void *))&$compare_fn)')
g.write('${deref_field}element_size, (int (*)(const void *, const void *))&$compare_fn)')
}
// `nums.filter(it % 2 == 0)`

View File

@ -179,6 +179,7 @@ mut:
expected_cast_type ast.Type // for match expr of sumtypes
defer_vars []string
anon_fn bool
array_sort_fn map[string]bool
}
pub fn gen(files []&ast.File, table &ast.Table, pref &pref.Preferences) string {