orm: `update` cgen

pull/5450/head
Alexander Medvednikov 2020-06-25 17:12:32 +02:00
parent a8b0dfb38a
commit b280e08ee0
5 changed files with 109 additions and 67 deletions

View File

@ -116,11 +116,15 @@ fn test_orm_sqlite() {
assert customer.is_customer == true assert customer.is_customer == true
assert customer.name == 'Kate' assert customer.name == 'Kate'
// //
/*
sql db { sql db {
update User set age = 31 where name == 'Kate' update User set age = 31 where name == 'Kate'
} }
*/ kate2 := sql db {
select from User where id == 3
}
println(kate2)
assert kate2.age == 31
assert kate2.name == 'Kate'
} }
struct User { struct User {

View File

@ -822,6 +822,9 @@ pub:
object_var_name string // `user` object_var_name string // `user`
table_type table.Type table_type table.Type
pos token.Position pos token.Position
where_expr Expr
updated_columns []string //for `update set x=y`
update_exprs []Expr//for `update`
pub mut: pub mut:
table_name string table_name string
fields []table.Field fields []table.Field

View File

@ -319,8 +319,8 @@ pub fn (mut c Checker) struct_decl(decl ast.StructDecl) {
if !c.check_types(field_expr_type, field.typ) { if !c.check_types(field_expr_type, field.typ) {
field_expr_type_sym := c.table.get_type_symbol(field_expr_type) field_expr_type_sym := c.table.get_type_symbol(field_expr_type)
field_type_sym := c.table.get_type_symbol(field.typ) field_type_sym := c.table.get_type_symbol(field.typ)
c.error('default expression for field `$field.name` ' + c.error('default expression for field `$field.name` ' + 'has type `$field_expr_type_sym.name`, but should be `$field_type_sym.name`',
'has type `$field_expr_type_sym.name`, but should be `$field_type_sym.name`', field.default_expr.position()) field.default_expr.position())
} }
} }
} }
@ -413,9 +413,8 @@ pub fn (mut c Checker) struct_init(mut struct_init ast.StructInit) table.Type {
expr_type := c.expr(field.expr) expr_type := c.expr(field.expr)
expr_type_sym := c.table.get_type_symbol(expr_type) expr_type_sym := c.table.get_type_symbol(expr_type)
field_type_sym := c.table.get_type_symbol(info_field.typ) field_type_sym := c.table.get_type_symbol(info_field.typ)
if !c.check_types(expr_type, info_field.typ) && expr_type != table.void_type && if !c.check_types(expr_type, info_field.typ) && expr_type != table.void_type &&
expr_type_sym.kind != .placeholder { expr_type_sym.kind != .placeholder {
c.error('!cannot assign $expr_type_sym.kind `$expr_type_sym.name` as `$field_type_sym.name` for field `$info_field.name`', c.error('!cannot assign $expr_type_sym.kind `$expr_type_sym.name` as `$field_type_sym.name` for field `$info_field.name`',
field.pos) field.pos)
} }
@ -535,8 +534,7 @@ pub fn (mut c Checker) infix_expr(mut infix_expr ast.InfixExpr) table.Type {
if infix_expr.op in [.div, .mod] { if infix_expr.op in [.div, .mod] {
if (infix_expr.right is ast.IntegerLiteral && if (infix_expr.right is ast.IntegerLiteral &&
infix_expr.right.str() == '0') || infix_expr.right.str() == '0') ||
(infix_expr.right is ast.FloatLiteral && (infix_expr.right is ast.FloatLiteral && infix_expr.right.str().f64() == 0.0) {
infix_expr.right.str().f64() == 0.0) {
oper := if infix_expr.op == .div { 'division' } else { 'modulo' } oper := if infix_expr.op == .div { 'division' } else { 'modulo' }
c.error('$oper by zero', right_pos) c.error('$oper by zero', right_pos)
} }
@ -825,7 +823,8 @@ pub fn (mut c Checker) call_method(mut call_expr ast.CallExpr) table.Type {
} }
if method := c.table.type_find_method(left_type_sym, method_name) { if method := c.table.type_find_method(left_type_sym, method_name) {
if !method.is_pub && !c.is_builtin_mod && !c.pref.is_test && if !method.is_pub && !c.is_builtin_mod && !c.pref.is_test &&
left_type_sym.mod != c.mod && left_type_sym.mod != '' { // method.mod != c.mod { left_type_sym.mod != c.mod &&
left_type_sym.mod != '' { // method.mod != c.mod {
// If a private method is called outside of the module // If a private method is called outside of the module
// its receiver type is defined in, show an error. // its receiver type is defined in, show an error.
// println('warn $method_name lef.mod=$left_type_sym.mod c.mod=$c.mod') // println('warn $method_name lef.mod=$left_type_sym.mod c.mod=$c.mod')
@ -836,7 +835,8 @@ pub fn (mut c Checker) call_method(mut call_expr ast.CallExpr) table.Type {
// call_expr.is_mut = true // call_expr.is_mut = true
} }
if method.return_type == table.void_type && if method.return_type == table.void_type &&
method.ctdefine.len > 0 && method.ctdefine !in c.pref.compile_defines { method.ctdefine.len > 0 &&
method.ctdefine !in c.pref.compile_defines {
call_expr.should_be_skipped = true call_expr.should_be_skipped = true
} }
nr_args := if method.args.len == 0 { 0 } else { method.args.len - 1 } nr_args := if method.args.len == 0 { 0 } else { method.args.len - 1 }
@ -1020,7 +1020,8 @@ pub fn (mut c Checker) call_fn(mut call_expr ast.CallExpr) table.Type {
} }
call_expr.return_type = f.return_type call_expr.return_type = f.return_type
if f.return_type == table.void_type && if f.return_type == table.void_type &&
f.ctdefine.len > 0 && f.ctdefine !in c.pref.compile_defines { f.ctdefine.len > 0 &&
f.ctdefine !in c.pref.compile_defines {
call_expr.should_be_skipped = true call_expr.should_be_skipped = true
} }
if f.language != .v || call_expr.language != .v { if f.language != .v || call_expr.language != .v {
@ -2339,7 +2340,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol
ast.EnumVal { key = expr.val } ast.EnumVal { key = expr.val }
else { key = expr.str() } else { key = expr.str() }
} }
val := if key in branch_exprs { branch_exprs[key] } /**/ else { 0 } val := if key in branch_exprs { branch_exprs[key] } else { 0 }
if val == 1 { if val == 1 {
c.error('match case `$key` is handled more than once', branch.pos) c.error('match case `$key` is handled more than once', branch.pos)
} }
@ -2714,7 +2715,6 @@ fn (mut c Checker) sql_expr(mut node ast.SqlExpr) table.Type {
return node.typ return node.typ
} }
fn (mut c Checker) sql_stmt(mut node ast.SqlStmt) table.Type { fn (mut c Checker) sql_stmt(mut node ast.SqlStmt) table.Type {
sym := c.table.get_type_symbol(node.table_type) sym := c.table.get_type_symbol(node.table_type)
info := sym.info as table.Struct info := sym.info as table.Struct
@ -2736,8 +2736,6 @@ fn (c &Checker) fetch_and_verify_orm_fields(info table.Struct, pos token.Positio
return fields return fields
} }
fn (mut c Checker) fn_decl(it ast.FnDecl) { fn (mut c Checker) fn_decl(it ast.FnDecl) {
if it.is_generic && c.cur_generic_type == 0 { // need the cur_generic_type check to avoid inf. recursion if it.is_generic && c.cur_generic_type == 0 { // need the cur_generic_type check to avoid inf. recursion
// loop thru each generic type and generate a function // loop thru each generic type and generate a function
@ -2774,8 +2772,8 @@ fn (mut c Checker) fn_decl(it ast.FnDecl) {
} }
sym.methods.delete(idx) sym.methods.delete(idx)
// //
c.error('cannot define new methods on non-local `$sym.name` (' + c.error('cannot define new methods on non-local `$sym.name` (' + 'current module is `$c.mod`, `$sym.name` is from `$sym.mod`)',
'current module is `$c.mod`, `$sym.name` is from `$sym.mod`)', it.pos) it.pos)
} }
} }
if it.language == .v { if it.language == .v {

View File

@ -11,51 +11,84 @@ const (
dbtype = 'sqlite' dbtype = 'sqlite'
) )
enum SqlExprSide { left right } enum SqlExprSide {
left
right
}
fn (mut g Gen) sql_stmt(node ast.SqlStmt) { fn (mut g Gen) sql_stmt(node ast.SqlStmt) {
g.sql_i = 0
g.writeln('\n\t// sql insert') g.writeln('\n\t// sql insert')
db_name := g.new_tmp_var() db_name := g.new_tmp_var()
g.sql_stmt_name = g.new_tmp_var() g.sql_stmt_name = g.new_tmp_var()
g.write('${dbtype}__DB $db_name = ') g.write('${dbtype}__DB $db_name = ')
g.expr(node.db_expr) g.expr(node.db_expr)
g.writeln(';') g.writeln(';')
mut q := 'insert into $node.table_name (' g.write('sqlite3_stmt* $g.sql_stmt_name = ${dbtype}__DB_init_stmt($db_name, tos_lit("')
for i, field in node.fields { if node.kind == .insert {
if field.name == 'id' { g.write('insert into $node.table_name (')
continue } else {
g.write('update $node.table_name set ')
} }
q += '$field.name' if node.kind == .insert {
if i < node.fields.len - 1 { for i, field in node.fields {
q += ', ' if field.name == 'id' {
continue
}
g.write(field.name)
if i < node.fields.len - 1 {
g.write(', ')
}
}
g.write( ') values (')
for i, field in node.fields {
if field.name == 'id' {
continue
}
g.write('?${i+0}')
if i < node.fields.len - 1 {
g.write(', ')
}
}
g.write(')')
} else if node.kind == .update {
for i, col in node.updated_columns {
g.write(' $col = ')
g.expr_to_sql(node.update_exprs[i])
}
g.write(' where ')
}
if node.kind == .update {
g.expr_to_sql(node.where_expr)
}
g.writeln('"));')
if node.kind == .insert {
for i, field in node.fields {
if field.name == 'id' {
continue
}
x := '${node.object_var_name}.$field.name'
if field.typ == table.string_type {
g.writeln('sqlite3_bind_text($g.sql_stmt_name, ${i+0}, ${x}.str, ${x}.len, 0);')
} else {
g.writeln('sqlite3_bind_int($g.sql_stmt_name, ${i+0}, $x); // stmt')
}
} }
} }
q += ') values (' else if node.kind == .update {
for i, field in node.fields {
if field.name == 'id' {
continue
} }
q += '?${i+0}'
if i < node.fields.len - 1 { // Dump all sql parameters generated by our custom expr handler
q += ', ' binds := g.sql_buf.str()
} g.sql_buf = strings.new_builder(100)
} g.writeln(binds)
q += ')'
println(q)
g.writeln('sqlite3_stmt* $g.sql_stmt_name = ${dbtype}__DB_init_stmt($db_name, tos_lit("$q"));')
for i, field in node.fields {
if field.name == 'id' {
continue
}
x := '${node.object_var_name}.$field.name'
if field.typ == table.string_type {
g.writeln('sqlite3_bind_text($g.sql_stmt_name, ${i+0}, ${x}.str, ${x}.len, 0);')
} else {
g.writeln('sqlite3_bind_int($g.sql_stmt_name, ${i+0}, $x); //insertl')
}
}
g.writeln('sqlite3_step($g.sql_stmt_name);') g.writeln('sqlite3_step($g.sql_stmt_name);')
g.writeln('puts(sqlite3_errmsg(${db_name}.conn));') g.writeln('if (strcmp(sqlite3_errmsg(${db_name}.conn), "not an error") != 0) puts(sqlite3_errmsg(${db_name}.conn)); ')
g.writeln('sqlite3_finalize($g.sql_stmt_name);') g.writeln('sqlite3_finalize($g.sql_stmt_name);')
} }
@ -107,7 +140,7 @@ fn (mut g Gen) sql_select_expr(node ast.SqlExpr) {
binds := g.sql_buf.str() binds := g.sql_buf.str()
g.sql_buf = strings.new_builder(100) g.sql_buf = strings.new_builder(100)
g.writeln(binds) g.writeln(binds)
g.writeln('puts(sqlite3_errmsg(${db_name}.conn));') g.writeln('if (strcmp(sqlite3_errmsg(${db_name}.conn), "not an error") != 0) puts(sqlite3_errmsg(${db_name}.conn)); ')
// //
if node.is_count { if node.is_count {
g.writeln('$cur_line ${dbtype}__get_int_from_stmt($g.sql_stmt_name);') g.writeln('$cur_line ${dbtype}__get_int_from_stmt($g.sql_stmt_name);')
@ -127,7 +160,7 @@ fn (mut g Gen) sql_select_expr(node ast.SqlExpr) {
g.writeln('\t$elem_type_str $tmp = ($elem_type_str) {') g.writeln('\t$elem_type_str $tmp = ($elem_type_str) {')
// //
sym := g.table.get_type_symbol(array_info.elem_type) sym := g.table.get_type_symbol(array_info.elem_type)
info := sym.info as table.Struct info := sym.info as table.Struct
for field in info.fields { for field in info.fields {
g.zero_struct_field(field) g.zero_struct_field(field)
} }
@ -139,7 +172,7 @@ fn (mut g Gen) sql_select_expr(node ast.SqlExpr) {
// If we don't, string values are going to be nil etc for fields that are not returned // If we don't, string values are going to be nil etc for fields that are not returned
// by the db engine. // by the db engine.
sym := g.table.get_type_symbol(node.typ) sym := g.table.get_type_symbol(node.typ)
info := sym.info as table.Struct info := sym.info as table.Struct
for field in info.fields { for field in info.fields {
g.zero_struct_field(field) g.zero_struct_field(field)
} }
@ -148,7 +181,7 @@ fn (mut g Gen) sql_select_expr(node ast.SqlExpr) {
// //
g.writeln('int _step_res$tmp = sqlite3_step($g.sql_stmt_name);') g.writeln('int _step_res$tmp = sqlite3_step($g.sql_stmt_name);')
if node.is_array { if node.is_array {
//g.writeln('\tprintf("step res=%d\\n", _step_res$tmp);') // g.writeln('\tprintf("step res=%d\\n", _step_res$tmp);')
g.writeln('\tif (_step_res$tmp == SQLITE_DONE) break;') g.writeln('\tif (_step_res$tmp == SQLITE_DONE) break;')
g.writeln('\tif (_step_res$tmp == SQLITE_ROW) ;') // another row g.writeln('\tif (_step_res$tmp == SQLITE_ROW) ;') // another row
g.writeln('\telse if (_step_res$tmp != SQLITE_OK) break;') g.writeln('\telse if (_step_res$tmp != SQLITE_OK) break;')
@ -177,12 +210,10 @@ fn (mut g Gen) sql_select_expr(node ast.SqlExpr) {
fn (mut g Gen) sql_bind_int(val string) { fn (mut g Gen) sql_bind_int(val string) {
g.sql_buf.writeln('sqlite3_bind_int($g.sql_stmt_name, $g.sql_i, $val);') g.sql_buf.writeln('sqlite3_bind_int($g.sql_stmt_name, $g.sql_i, $val);')
} }
fn (mut g Gen) sql_bind_string(val string, len string) { fn (mut g Gen) sql_bind_string(val, len string) {
g.sql_buf.writeln('sqlite3_bind_text($g.sql_stmt_name, $g.sql_i, $val, $len, 0);') g.sql_buf.writeln('sqlite3_bind_text($g.sql_stmt_name, $g.sql_i, $val, $len, 0);')
} }
fn (mut g Gen) expr_to_sql(expr ast.Expr) { fn (mut g Gen) expr_to_sql(expr ast.Expr) {
@ -221,13 +252,17 @@ fn (mut g Gen) expr_to_sql(expr ast.Expr) {
// true/false literals were added to Sqlite 3.23 (2018-04-02) // true/false literals were added to Sqlite 3.23 (2018-04-02)
// but lots of apps/distros use older sqlite (e.g. Ubuntu 18.04 LTS ) // but lots of apps/distros use older sqlite (e.g. Ubuntu 18.04 LTS )
g.inc_sql_i() g.inc_sql_i()
g.sql_bind_int(if it.val { '1' } else { '0' }) g.sql_bind_int(if it.val {
'1'
} else {
'0'
})
} }
ast.Ident { ast.Ident {
// `name == user_name` => `name == ?1` // `name == user_name` => `name == ?1`
// for left sides just add a string, for right sides, generate the bindings // for left sides just add a string, for right sides, generate the bindings
if g.sql_side == .left { if g.sql_side == .left {
//println("sql gen left $expr.name") // println("sql gen left $expr.name")
g.write(expr.name) g.write(expr.name)
} else { } else {
g.inc_sql_i() g.inc_sql_i()
@ -235,11 +270,9 @@ fn (mut g Gen) expr_to_sql(expr ast.Expr) {
typ := info.typ typ := info.typ
if typ == table.string_type { if typ == table.string_type {
g.sql_bind_string('${expr.name}.str', '${expr.name}.len') g.sql_bind_string('${expr.name}.str', '${expr.name}.len')
} } else if typ == table.int_type {
else if typ == table.int_type {
g.sql_bind_int(expr.name) g.sql_bind_int(expr.name)
} } else {
else {
verror('bad sql type $typ') verror('bad sql type $typ')
} }
} }

View File

@ -60,11 +60,9 @@ fn (mut p Parser) sql_expr() ast.Expr {
is_count: is_count is_count: is_count
typ: typ typ: typ
db_expr: db_expr db_expr: db_expr
//table_name: table_name
table_type: table_type table_type: table_type
where_expr: where_expr where_expr: where_expr
has_where: has_where has_where: has_where
//fields: fields
is_array: !query_one is_array: !query_one
pos: pos pos: pos
} }
@ -109,6 +107,7 @@ fn (mut p Parser) sql_stmt() ast.SqlStmt {
} }
n = p.check_name() // into n = p.check_name() // into
mut updated_columns := []string{} mut updated_columns := []string{}
mut update_exprs := []ast.Expr{cap: 5}
if kind == .insert && n != 'into' { if kind == .insert && n != 'into' {
p.error('expecting `into`') p.error('expecting `into`')
} else if kind == .update { } else if kind == .update {
@ -118,9 +117,10 @@ fn (mut p Parser) sql_stmt() ast.SqlStmt {
column := p.check_name() column := p.check_name()
updated_columns << column updated_columns << column
p.check(.assign) p.check(.assign)
p.expr(0) update_exprs << p.expr(0)
} }
mut table_type := table.Type(0) mut table_type := table.Type(0)
mut where_expr := ast.Expr{}
if kind == .insert { if kind == .insert {
table_type = p.parse_type() // `User` table_type = p.parse_type() // `User`
sym := p.table.get_type_symbol(table_type) sym := p.table.get_type_symbol(table_type)
@ -131,7 +131,7 @@ fn (mut p Parser) sql_stmt() ast.SqlStmt {
idx := p.table.find_type_idx(table_name) idx := p.table.find_type_idx(table_name)
table_type = table.new_type(idx) table_type = table.new_type(idx)
p.check_sql_keyword('where') p.check_sql_keyword('where')
p.expr(0) where_expr = p.expr(0)
} }
p.check(.rcbr) p.check(.rcbr)
return ast.SqlStmt{ return ast.SqlStmt{
@ -140,6 +140,10 @@ fn (mut p Parser) sql_stmt() ast.SqlStmt {
table_type: table_type table_type: table_type
object_var_name: inserted_var_name object_var_name: inserted_var_name
pos: pos pos: pos
updated_columns: updated_columns
update_exprs: update_exprs
kind: kind
where_expr: where_expr
} }
} }