mirror of
https://github.com/kazhuravlev/database-gateway.git
synced 2025-06-26 17:53:16 +00:00
219 lines
6.5 KiB
Go
219 lines
6.5 KiB
Go
// Database Gateway provides access to servers with ACL for safe and restricted database interactions.
|
|
// Copyright (C) 2024 Kirill Zhuravlev
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU General Public License
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package parser
|
|
|
|
import (
|
|
"fmt"
|
|
"slices"
|
|
|
|
"github.com/kazhuravlev/just"
|
|
pg "github.com/pganalyze/pg_query_go/v6"
|
|
)
|
|
|
|
type InsertVec struct { //nolint:recvcheck
|
|
Tbl string
|
|
Target []string
|
|
}
|
|
|
|
func (s *InsertVec) Columns() []string {
|
|
columns := just.SliceUniq(s.Target)
|
|
|
|
return columns
|
|
}
|
|
|
|
func (InsertVec) isVector() {}
|
|
|
|
func handleInsert(req *pg.InsertStmt) ([]Vector, error) { //nolint:gocyclo,gocognit,cyclop,funlen
|
|
if req.GetWithClause() != nil ||
|
|
req.GetOverride() != pg.OverridingKind_OVERRIDING_NOT_SET {
|
|
return nil, fmt.Errorf("unknown clause: %w", ErrNotImplemented)
|
|
}
|
|
|
|
rel := req.GetRelation()
|
|
tables := Tables{m: make(map[string]string)}
|
|
fqTableName := tables.Put(rel.GetCatalogname(), rel.GetSchemaname(), rel.GetRelname(), rel.GetAlias().GetAliasname())
|
|
tables.Finalize()
|
|
|
|
targetCols, err := pNodes2Columns(req.GetCols(), fqTableName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse target columns: %w", err)
|
|
}
|
|
|
|
switch node := req.GetSelectStmt().GetNode().(type) {
|
|
default:
|
|
return nil, fmt.Errorf("target select type (%T): %w", node, ErrNotImplemented)
|
|
case *pg.Node_SelectStmt:
|
|
sel := node.SelectStmt
|
|
|
|
if sel.DistinctClause != nil ||
|
|
sel.GetIntoClause() != nil ||
|
|
sel.TargetList != nil ||
|
|
sel.FromClause != nil ||
|
|
sel.GetWhereClause() != nil ||
|
|
sel.GroupClause != nil ||
|
|
sel.GetGroupDistinct() ||
|
|
sel.GetHavingClause() != nil ||
|
|
sel.WindowClause != nil ||
|
|
sel.SortClause != nil ||
|
|
sel.GetLimitOffset() != nil ||
|
|
sel.GetLimitCount() != nil ||
|
|
sel.GetLimitOption() != pg.LimitOption_LIMIT_OPTION_DEFAULT ||
|
|
sel.LockingClause != nil ||
|
|
sel.GetWithClause() != nil ||
|
|
sel.GetOp() != pg.SetOperation_SETOP_NONE ||
|
|
sel.GetAll() ||
|
|
sel.GetLarg() != nil ||
|
|
sel.GetRarg() != nil {
|
|
return nil, fmt.Errorf("unknown clause: %w", ErrNotImplemented)
|
|
}
|
|
|
|
if len(sel.GetFromClause()) != 0 {
|
|
return nil, fmt.Errorf("insert from select: %w", ErrNotImplemented)
|
|
}
|
|
|
|
for _, node := range sel.GetValuesLists() {
|
|
switch node := node.GetNode().(type) {
|
|
default:
|
|
return nil, fmt.Errorf("unknown node type (%T): %w", node, ErrNotImplemented)
|
|
case *pg.Node_List:
|
|
for _, node := range node.List.GetItems() {
|
|
switch node := node.GetNode().(type) {
|
|
default:
|
|
return nil, fmt.Errorf("unknown node list node type (%T): %w", node, ErrNotImplemented)
|
|
// Allow `insert ... values (DEFAULT)`
|
|
case *pg.Node_SetToDefault:
|
|
// Allow `insert ... values (1)`
|
|
// Allow `insert ... values ('1')`
|
|
case *pg.Node_AConst:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
case nil: // This is `insert into t1 default values`
|
|
}
|
|
|
|
retCols, err := pNodes2Columns(req.GetReturningList(), fqTableName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse returning columns: %w", err)
|
|
}
|
|
|
|
allColumns := slices.Concat(targetCols, retCols)
|
|
|
|
if confl := req.GetOnConflictClause(); confl != nil {
|
|
switch confl.GetAction() { //nolint:exhaustive
|
|
default:
|
|
return nil, fmt.Errorf("unknonw action on conflict: %w", ErrNotImplemented)
|
|
case pg.OnConflictAction_ONCONFLICT_NONE, pg.OnConflictAction_ONCONFLICT_NOTHING, pg.OnConflictAction_ONCONFLICT_UPDATE:
|
|
}
|
|
|
|
if confl.GetWhereClause() != nil ||
|
|
confl.GetInfer().GetWhereClause() != nil ||
|
|
confl.GetInfer().GetConname() != "" {
|
|
return nil, fmt.Errorf("unknown on-conflict clause: %w", ErrNotImplemented)
|
|
}
|
|
|
|
conflictCols, err := pNodes2Columns(confl.GetInfer().GetIndexElems(), fqTableName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse conflict columns: %w", err)
|
|
}
|
|
|
|
targetCols, err := pNodes2Columns(confl.GetTargetList(), fqTableName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse target columns: %w", err)
|
|
}
|
|
|
|
allColumns = slices.Concat(allColumns, conflictCols, targetCols)
|
|
}
|
|
|
|
table2target := make(map[string]Columns, len(allColumns))
|
|
for _, column := range allColumns {
|
|
tbl, ok := tables.Get(column.Table())
|
|
if !ok {
|
|
return nil, fmt.Errorf("table not found: %s", column.Table()) //nolint:err113
|
|
}
|
|
|
|
table2target[tbl] = append(table2target[tbl], column)
|
|
}
|
|
|
|
vectors := make([]Vector, 0, len(table2target))
|
|
for tbl, cols := range table2target {
|
|
vectors = append(vectors, InsertVec{
|
|
Tbl: tbl,
|
|
Target: cols.ListNames(),
|
|
})
|
|
}
|
|
|
|
return vectors, nil
|
|
}
|
|
|
|
func pNodes2Columns(nodes []*pg.Node, defaultTable string) (Columns, error) { //nolint:cyclop
|
|
var allColumns Columns
|
|
|
|
for _, ret := range nodes {
|
|
switch node := ret.GetNode().(type) {
|
|
default:
|
|
return nil, fmt.Errorf("col type (%T): %w", ret.GetNode(), ErrNotImplemented)
|
|
case *pg.Node_IndexElem:
|
|
idx := node.IndexElem
|
|
if idx.GetExpr() != nil ||
|
|
idx.GetIndexcolname() != "" ||
|
|
idx.Collation != nil ||
|
|
idx.Opclass != nil ||
|
|
idx.Opclassopts != nil ||
|
|
idx.GetOrdering() != pg.SortByDir_SORTBY_DEFAULT ||
|
|
idx.GetNullsOrdering() != pg.SortByNulls_SORTBY_NULLS_DEFAULT {
|
|
return nil, fmt.Errorf("unknown index elem: %w", ErrNotImplemented)
|
|
}
|
|
|
|
col := Column{
|
|
table: []string{defaultTable},
|
|
column: idx.GetName(),
|
|
}
|
|
allColumns = append(allColumns, col)
|
|
case *pg.Node_ResTarget:
|
|
resTarget := node.ResTarget
|
|
if resTarget.Indirection != nil {
|
|
return nil, fmt.Errorf("resTarget field: %w", ErrNotImplemented)
|
|
}
|
|
|
|
if resTarget.GetName() != "" {
|
|
col := Column{
|
|
table: []string{defaultTable},
|
|
column: resTarget.GetName(),
|
|
}
|
|
allColumns = append(allColumns, col)
|
|
|
|
continue
|
|
}
|
|
|
|
switch node := resTarget.GetVal().GetNode().(type) {
|
|
default:
|
|
return nil, fmt.Errorf("resTarget type (%T): %w", resTarget.GetVal().GetNode(), ErrNotImplemented)
|
|
case *pg.Node_ColumnRef:
|
|
column, err := pNodeColumnRef(node)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse column: %w", err)
|
|
}
|
|
|
|
allColumns = append(allColumns, column)
|
|
}
|
|
}
|
|
}
|
|
|
|
return allColumns, nil
|
|
}
|