diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index ec32488af5..9714362eff 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -1528,6 +1528,9 @@ pub fn (mut c Checker) infix_expr(mut node ast.InfixExpr) ast.Type { .gt, .lt, .ge, .le { if left_sym.kind in [.array, .array_fixed] && right_sym.kind in [.array, .array_fixed] { c.error('only `==` and `!=` are defined on arrays', node.pos) + } else if left_sym.kind == .struct_ + && (left_sym.info as ast.Struct).generic_types.len > 0 { + return ast.bool_type } else if left_sym.kind == .struct_ && right_sym.kind == .struct_ && node.op in [.eq, .lt] { if !(left_sym.has_method(node.op.str()) && right_sym.has_method(node.op.str())) { @@ -4286,6 +4289,9 @@ pub fn (mut c Checker) assign_stmt(mut node ast.AssignStmt) { .mult_assign { '*' } else { 'unknown op' } } + if left_sym.kind == .struct_ && (left_sym.info as ast.Struct).generic_types.len > 0 { + continue + } if method := left_sym.find_method(extracted_op) { if method.return_type != left_type { c.error('operator `$extracted_op` must return `$left_name` to be used as an assignment operator', diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index 7c7486e6f4..25ce1c533e 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -2812,16 +2812,28 @@ fn (mut g Gen) gen_assign_stmt(assign_stmt ast.AssignStmt) { else { 'unknown op' } } g.expr(left) - g.write(' = ${styp}_${util.replace_op(extracted_op)}(') - method := g.table.type_find_method(left_sym, extracted_op) or { - // the checker will most likely have found this, already... - g.error('assignemnt operator `$extracted_op=` used but no `$extracted_op` method defined', - assign_stmt.pos) - ast.Fn{} + if left_sym.kind == .struct_ && (left_sym.info as ast.Struct).generic_types.len > 0 { + concrete_types := (left_sym.info as ast.Struct).concrete_types + mut method_name := left_sym.cname + '_' + util.replace_op(extracted_op) + method_name = g.generic_fn_name(concrete_types, method_name, true) + g.write(' = ${method_name}(') + g.expr(left) + g.write(', ') + g.expr(val) + g.writeln(');') + return + } else { + g.write(' = ${styp}_${util.replace_op(extracted_op)}(') + method := g.table.type_find_method(left_sym, extracted_op) or { + // the checker will most likely have found this, already... + g.error('assignemnt operator `$extracted_op=` used but no `$extracted_op` method defined', + assign_stmt.pos) + ast.Fn{} + } + op_expected_left = method.params[0].typ + op_expected_right = method.params[1].typ + op_overloaded = true } - op_expected_left = method.params[0].typ - op_expected_right = method.params[1].typ - op_overloaded = true } if right_sym.kind == .function && is_decl { if is_inside_ternary && is_decl { diff --git a/vlib/v/gen/c/infix_expr.v b/vlib/v/gen/c/infix_expr.v index de9bff7c27..f57d4b2f9f 100644 --- a/vlib/v/gen/c/infix_expr.v +++ b/vlib/v/gen/c/infix_expr.v @@ -266,7 +266,32 @@ fn (mut g Gen) infix_expr_cmp_op(node ast.InfixExpr) { left := g.unwrap(node.left_type) right := g.unwrap(node.right_type) has_operator_overloading := g.table.type_has_method(left.sym, '<') - if left.sym.kind == right.sym.kind && has_operator_overloading { + if left.sym.kind == .struct_ && (left.sym.info as ast.Struct).generic_types.len > 0 { + if node.op in [.le, .ge] { + g.write('!') + } + concrete_types := (left.sym.info as ast.Struct).concrete_types + mut method_name := left.sym.cname + '__lt' + method_name = g.generic_fn_name(concrete_types, method_name, true) + g.write(method_name) + if node.op in [.lt, .ge] { + g.write('(') + g.write('*'.repeat(left.typ.nr_muls())) + g.expr(node.left) + g.write(', ') + g.write('*'.repeat(right.typ.nr_muls())) + g.expr(node.right) + g.write(')') + } else { + g.write('(') + g.write('*'.repeat(right.typ.nr_muls())) + g.expr(node.right) + g.write(', ') + g.write('*'.repeat(left.typ.nr_muls())) + g.expr(node.left) + g.write(')') + } + } else if left.sym.kind == right.sym.kind && has_operator_overloading { if node.op in [.le, .ge] { g.write('!') } diff --git a/vlib/v/tests/generic_operator_overload_test.v b/vlib/v/tests/generic_operator_overload_test.v index cbc227321c..8395c0184e 100644 --- a/vlib/v/tests/generic_operator_overload_test.v +++ b/vlib/v/tests/generic_operator_overload_test.v @@ -26,10 +26,51 @@ fn (m1 Matrix) + (m2 Matrix) Matrix { return res } -fn test_generic_operator_overload() { - result := from_array([[1, 2, 3], [4, 5, 6]]) + from_array([[7, 8, 9], [10, 11, 12]]) - println(result) - assert result.row == 2 - assert result.col == 3 - assert result.data == [[8, 10, 12], [14, 16, 18]] +fn (m1 Matrix) == (m2 Matrix) bool { + return m1.row == m2.row && m1.col == m2.col && m1.data == m2.data +} + +fn (m1 Matrix) < (m2 Matrix) bool { + return m1.row < m2.row && m1.col < m2.col +} + +fn test_generic_operator_overload() { + mut a1 := from_array([[1, 2, 3], [4, 5, 6]]) + a2 := from_array([[7, 8, 9], [10, 11, 12]]) + + plus_ret := a1 + a2 + println(plus_ret) + assert plus_ret.row == 2 + assert plus_ret.col == 3 + assert plus_ret.data == [[8, 10, 12], [14, 16, 18]] + + a1 += a2 + println(a1) + assert a1.row == 2 + assert a1.col == 3 + assert a1.data == [[15, 18, 21], [24, 27, 30]] + + eq_ret := a1 == a2 + println(eq_ret) + assert !eq_ret + + ne_ret := a1 != a2 + println(ne_ret) + assert ne_ret + + lt_ret := a1 < a2 + println(lt_ret) + assert !lt_ret + + le_ret := a1 <= a2 + println(le_ret) + assert le_ret + + gt_ret := a1 > a2 + println(gt_ret) + assert !gt_ret + + ge_ret := a1 >= a2 + println(ge_ret) + assert ge_ret }