checker, cgen: fix generic operator overload of 'cmp' (#11489)
parent
8862c3af0f
commit
9554470985
|
@ -1528,6 +1528,9 @@ pub fn (mut c Checker) infix_expr(mut node ast.InfixExpr) ast.Type {
|
||||||
.gt, .lt, .ge, .le {
|
.gt, .lt, .ge, .le {
|
||||||
if left_sym.kind in [.array, .array_fixed] && right_sym.kind in [.array, .array_fixed] {
|
if left_sym.kind in [.array, .array_fixed] && right_sym.kind in [.array, .array_fixed] {
|
||||||
c.error('only `==` and `!=` are defined on arrays', node.pos)
|
c.error('only `==` and `!=` are defined on arrays', node.pos)
|
||||||
|
} else if left_sym.kind == .struct_
|
||||||
|
&& (left_sym.info as ast.Struct).generic_types.len > 0 {
|
||||||
|
return ast.bool_type
|
||||||
} else if left_sym.kind == .struct_ && right_sym.kind == .struct_
|
} else if left_sym.kind == .struct_ && right_sym.kind == .struct_
|
||||||
&& node.op in [.eq, .lt] {
|
&& node.op in [.eq, .lt] {
|
||||||
if !(left_sym.has_method(node.op.str()) && right_sym.has_method(node.op.str())) {
|
if !(left_sym.has_method(node.op.str()) && right_sym.has_method(node.op.str())) {
|
||||||
|
@ -4286,6 +4289,9 @@ pub fn (mut c Checker) assign_stmt(mut node ast.AssignStmt) {
|
||||||
.mult_assign { '*' }
|
.mult_assign { '*' }
|
||||||
else { 'unknown op' }
|
else { 'unknown op' }
|
||||||
}
|
}
|
||||||
|
if left_sym.kind == .struct_ && (left_sym.info as ast.Struct).generic_types.len > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if method := left_sym.find_method(extracted_op) {
|
if method := left_sym.find_method(extracted_op) {
|
||||||
if method.return_type != left_type {
|
if method.return_type != left_type {
|
||||||
c.error('operator `$extracted_op` must return `$left_name` to be used as an assignment operator',
|
c.error('operator `$extracted_op` must return `$left_name` to be used as an assignment operator',
|
||||||
|
|
|
@ -2812,16 +2812,28 @@ fn (mut g Gen) gen_assign_stmt(assign_stmt ast.AssignStmt) {
|
||||||
else { 'unknown op' }
|
else { 'unknown op' }
|
||||||
}
|
}
|
||||||
g.expr(left)
|
g.expr(left)
|
||||||
g.write(' = ${styp}_${util.replace_op(extracted_op)}(')
|
if left_sym.kind == .struct_ && (left_sym.info as ast.Struct).generic_types.len > 0 {
|
||||||
method := g.table.type_find_method(left_sym, extracted_op) or {
|
concrete_types := (left_sym.info as ast.Struct).concrete_types
|
||||||
// the checker will most likely have found this, already...
|
mut method_name := left_sym.cname + '_' + util.replace_op(extracted_op)
|
||||||
g.error('assignemnt operator `$extracted_op=` used but no `$extracted_op` method defined',
|
method_name = g.generic_fn_name(concrete_types, method_name, true)
|
||||||
assign_stmt.pos)
|
g.write(' = ${method_name}(')
|
||||||
ast.Fn{}
|
g.expr(left)
|
||||||
|
g.write(', ')
|
||||||
|
g.expr(val)
|
||||||
|
g.writeln(');')
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
g.write(' = ${styp}_${util.replace_op(extracted_op)}(')
|
||||||
|
method := g.table.type_find_method(left_sym, extracted_op) or {
|
||||||
|
// the checker will most likely have found this, already...
|
||||||
|
g.error('assignemnt operator `$extracted_op=` used but no `$extracted_op` method defined',
|
||||||
|
assign_stmt.pos)
|
||||||
|
ast.Fn{}
|
||||||
|
}
|
||||||
|
op_expected_left = method.params[0].typ
|
||||||
|
op_expected_right = method.params[1].typ
|
||||||
|
op_overloaded = true
|
||||||
}
|
}
|
||||||
op_expected_left = method.params[0].typ
|
|
||||||
op_expected_right = method.params[1].typ
|
|
||||||
op_overloaded = true
|
|
||||||
}
|
}
|
||||||
if right_sym.kind == .function && is_decl {
|
if right_sym.kind == .function && is_decl {
|
||||||
if is_inside_ternary && is_decl {
|
if is_inside_ternary && is_decl {
|
||||||
|
|
|
@ -266,7 +266,32 @@ fn (mut g Gen) infix_expr_cmp_op(node ast.InfixExpr) {
|
||||||
left := g.unwrap(node.left_type)
|
left := g.unwrap(node.left_type)
|
||||||
right := g.unwrap(node.right_type)
|
right := g.unwrap(node.right_type)
|
||||||
has_operator_overloading := g.table.type_has_method(left.sym, '<')
|
has_operator_overloading := g.table.type_has_method(left.sym, '<')
|
||||||
if left.sym.kind == right.sym.kind && has_operator_overloading {
|
if left.sym.kind == .struct_ && (left.sym.info as ast.Struct).generic_types.len > 0 {
|
||||||
|
if node.op in [.le, .ge] {
|
||||||
|
g.write('!')
|
||||||
|
}
|
||||||
|
concrete_types := (left.sym.info as ast.Struct).concrete_types
|
||||||
|
mut method_name := left.sym.cname + '__lt'
|
||||||
|
method_name = g.generic_fn_name(concrete_types, method_name, true)
|
||||||
|
g.write(method_name)
|
||||||
|
if node.op in [.lt, .ge] {
|
||||||
|
g.write('(')
|
||||||
|
g.write('*'.repeat(left.typ.nr_muls()))
|
||||||
|
g.expr(node.left)
|
||||||
|
g.write(', ')
|
||||||
|
g.write('*'.repeat(right.typ.nr_muls()))
|
||||||
|
g.expr(node.right)
|
||||||
|
g.write(')')
|
||||||
|
} else {
|
||||||
|
g.write('(')
|
||||||
|
g.write('*'.repeat(right.typ.nr_muls()))
|
||||||
|
g.expr(node.right)
|
||||||
|
g.write(', ')
|
||||||
|
g.write('*'.repeat(left.typ.nr_muls()))
|
||||||
|
g.expr(node.left)
|
||||||
|
g.write(')')
|
||||||
|
}
|
||||||
|
} else if left.sym.kind == right.sym.kind && has_operator_overloading {
|
||||||
if node.op in [.le, .ge] {
|
if node.op in [.le, .ge] {
|
||||||
g.write('!')
|
g.write('!')
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,10 +26,51 @@ fn (m1 Matrix<T>) + (m2 Matrix<T>) Matrix<T> {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_generic_operator_overload() {
|
fn (m1 Matrix<T>) == (m2 Matrix<T>) bool {
|
||||||
result := from_array([[1, 2, 3], [4, 5, 6]]) + from_array([[7, 8, 9], [10, 11, 12]])
|
return m1.row == m2.row && m1.col == m2.col && m1.data == m2.data
|
||||||
println(result)
|
}
|
||||||
assert result.row == 2
|
|
||||||
assert result.col == 3
|
fn (m1 Matrix<T>) < (m2 Matrix<T>) bool {
|
||||||
assert result.data == [[8, 10, 12], [14, 16, 18]]
|
return m1.row < m2.row && m1.col < m2.col
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_generic_operator_overload() {
|
||||||
|
mut a1 := from_array([[1, 2, 3], [4, 5, 6]])
|
||||||
|
a2 := from_array([[7, 8, 9], [10, 11, 12]])
|
||||||
|
|
||||||
|
plus_ret := a1 + a2
|
||||||
|
println(plus_ret)
|
||||||
|
assert plus_ret.row == 2
|
||||||
|
assert plus_ret.col == 3
|
||||||
|
assert plus_ret.data == [[8, 10, 12], [14, 16, 18]]
|
||||||
|
|
||||||
|
a1 += a2
|
||||||
|
println(a1)
|
||||||
|
assert a1.row == 2
|
||||||
|
assert a1.col == 3
|
||||||
|
assert a1.data == [[15, 18, 21], [24, 27, 30]]
|
||||||
|
|
||||||
|
eq_ret := a1 == a2
|
||||||
|
println(eq_ret)
|
||||||
|
assert !eq_ret
|
||||||
|
|
||||||
|
ne_ret := a1 != a2
|
||||||
|
println(ne_ret)
|
||||||
|
assert ne_ret
|
||||||
|
|
||||||
|
lt_ret := a1 < a2
|
||||||
|
println(lt_ret)
|
||||||
|
assert !lt_ret
|
||||||
|
|
||||||
|
le_ret := a1 <= a2
|
||||||
|
println(le_ret)
|
||||||
|
assert le_ret
|
||||||
|
|
||||||
|
gt_ret := a1 > a2
|
||||||
|
println(gt_ret)
|
||||||
|
assert !gt_ret
|
||||||
|
|
||||||
|
ge_ret := a1 >= a2
|
||||||
|
println(ge_ret)
|
||||||
|
assert ge_ret
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue