From ec5c378c05bdfe33b329cc463dc77f9d6a1aaa5c Mon Sep 17 00:00:00 2001 From: John Hopper Date: Mon, 30 Jun 2025 15:18:31 -0700 Subject: [PATCH 1/5] feat (v2): v2 --- cypher/models/cypher/model.go | 11 +- cypher/models/pgsql/format/format.go | 10 +- cypher/models/pgsql/test/testcase.go | 21 +- cypher/models/pgsql/test/translation_test.go | 2 +- .../pgsql/test/validation_integration_test.go | 39 +- .../pgsql/visualization/visualizer_test.go | 2 +- database/driver.go | 82 ++ database/neo4j/database.go | 163 +++ database/neo4j/database_integration_test.go | 1 + database/neo4j/driver.go | 202 +++ {drivers => database}/neo4j/mapper.go | 48 +- database/neo4j/neo4j.go | 48 + database/neo4j/schema.go | 334 +++++ database/pg/batch.go | 379 ++++++ database/pg/config.go | 7 + database/pg/database.go | 125 ++ database/pg/driver.go | 252 ++++ drivers/pg/facts.go => database/pg/error.go | 0 {drivers => database}/pg/manager.go | 253 ++-- {drivers => database}/pg/mapper.go | 8 +- {drivers => database}/pg/model/format.go | 6 +- database/pg/model/model.go | 68 + {drivers => database}/pg/pg.go | 16 +- {drivers => database}/pg/pgutil/kindmapper.go | 0 {drivers => database}/pg/query/definitions.go | 0 {drivers => database}/pg/query/format.go | 24 +- database/pg/query/query.go | 514 ++++++++ {drivers => database}/pg/query/sql.go | 2 +- .../pg/query/sql/insert_graph.sql | 0 .../pg/query/sql/insert_or_get_kind.sql | 0 .../pg/query/sql/schema_down.sql | 0 .../pg/query/sql/schema_up.sql | 0 .../pg/query/sql/select_graph_by_name.sql | 3 +- .../pg/query/sql/select_graphs.sql | 3 +- .../pg/query/sql/select_kind_id.sql | 0 .../pg/query/sql/select_kinds.sql | 0 .../pg/query/sql/select_table_indexes.sql | 0 {drivers => database}/pg/statements.go | 4 +- {drivers => database}/pg/types.go | 4 +- database/pg/util.go | 1 + database/schema.go | 68 + {drivers => database}/tooling.go | 2 +- database/v1compat/database.go | 153 +++ database/v1compat/errors.go | 23 + database/v1compat/graph.go | 111 ++ 
database/v1compat/node.go | 202 +++ {ops => database/v1compat/ops}/ops.go | 4 +- {ops => database/v1compat/ops}/parallel.go | 4 +- {ops => database/v1compat/ops}/traversal.go | 4 +- {graph => database/v1compat}/query.go | 81 +- {query => database/v1compat/query}/builder.go | 9 +- .../v1compat/query}/identifiers.go | 0 database/v1compat/query/model.go | 599 +++++++++ {query => database/v1compat/query}/rewrite.go | 0 {query => database/v1compat/query}/sort.go | 5 +- database/v1compat/relationship.go | 290 +++++ {graph => database/v1compat}/result.go | 34 +- {graph => database/v1compat}/switch.go | 2 +- .../v1compat/traversal}/collection.go | 24 +- database/v1compat/traversal/query.go | 45 + database/v1compat/traversal/traversal.go | 594 +++++++++ database/v1compat/traversal/traversal_test.go | 93 ++ database/v1compat/types.go | 193 +++ database/v1compat/wrapper.go | 434 +++++++ drivers/neo4j/batch.go | 254 ---- drivers/neo4j/const.go | 12 - drivers/neo4j/cypher.go | 319 ----- drivers/neo4j/cypher_internal_test.go | 89 -- drivers/neo4j/driver.go | 168 --- drivers/neo4j/index.go | 269 ---- drivers/neo4j/neo4j.go | 50 - drivers/neo4j/node.go | 219 ---- drivers/neo4j/relationship.go | 321 ----- drivers/neo4j/result.go | 56 - drivers/neo4j/result_internal_test.go | 61 - drivers/neo4j/transaction.go | 447 ------- drivers/neo4j/wrapper.go | 31 - drivers/pg/batch.go | 576 --------- drivers/pg/driver.go | 169 --- drivers/pg/model/model.go | 68 - drivers/pg/node.go | 139 --- drivers/pg/query.go | 63 - drivers/pg/query/query.go | 473 ------- drivers/pg/relationship.go | 238 ---- drivers/pg/result.go | 49 - drivers/pg/tooling.go | 125 -- drivers/pg/transaction.go | 299 ----- drivers/pg/util.go | 7 - graph/error.go | 1 - graph/graph.go | 147 --- graph/properties.go | 20 +- graph/{relationships.go => relationship.go} | 0 ...tionships_test.go => relationship_test.go} | 0 graph/schema.go | 45 - graphcache/cache.go | 106 -- query/model.go | 617 --------- query/neo4j/neo4j.go | 339 
----- query/neo4j/neo4j_test.go | 1100 ----------------- query/neo4j/rewrite.go | 153 --- query/query.go | 782 ++++++++++++ query/query_test.go | 62 + query/util.go | 174 +++ dawgs.go => registry.go | 21 +- registry_integration_test.go | 80 ++ traversal/traversal.go | 309 ++--- 105 files changed, 6556 insertions(+), 7508 deletions(-) create mode 100644 database/driver.go create mode 100644 database/neo4j/database.go create mode 100644 database/neo4j/database_integration_test.go create mode 100644 database/neo4j/driver.go rename {drivers => database}/neo4j/mapper.go (51%) create mode 100644 database/neo4j/neo4j.go create mode 100644 database/neo4j/schema.go create mode 100644 database/pg/batch.go create mode 100644 database/pg/config.go create mode 100644 database/pg/database.go create mode 100644 database/pg/driver.go rename drivers/pg/facts.go => database/pg/error.go (100%) rename {drivers => database}/pg/manager.go (59%) rename {drivers => database}/pg/mapper.go (90%) rename {drivers => database}/pg/model/format.go (83%) create mode 100644 database/pg/model/model.go rename {drivers => database}/pg/pg.go (86%) rename {drivers => database}/pg/pgutil/kindmapper.go (100%) rename {drivers => database}/pg/query/definitions.go (100%) rename {drivers => database}/pg/query/format.go (93%) create mode 100644 database/pg/query/query.go rename {drivers => database}/pg/query/sql.go (95%) rename {drivers => database}/pg/query/sql/insert_graph.sql (100%) rename {drivers => database}/pg/query/sql/insert_or_get_kind.sql (100%) rename {drivers => database}/pg/query/sql/schema_down.sql (100%) rename {drivers => database}/pg/query/sql/schema_up.sql (100%) rename {drivers => database}/pg/query/sql/select_graph_by_name.sql (71%) rename {drivers => database}/pg/query/sql/select_graphs.sql (70%) rename {drivers => database}/pg/query/sql/select_kind_id.sql (100%) rename {drivers => database}/pg/query/sql/select_kinds.sql (100%) rename {drivers => 
database}/pg/query/sql/select_table_indexes.sql (100%) rename {drivers => database}/pg/statements.go (92%) rename {drivers => database}/pg/types.go (97%) create mode 100644 database/pg/util.go create mode 100644 database/schema.go rename {drivers => database}/tooling.go (96%) create mode 100644 database/v1compat/database.go create mode 100644 database/v1compat/errors.go create mode 100644 database/v1compat/graph.go create mode 100644 database/v1compat/node.go rename {ops => database/v1compat/ops}/ops.go (99%) rename {ops => database/v1compat/ops}/parallel.go (99%) rename {ops => database/v1compat/ops}/traversal.go (98%) rename {graph => database/v1compat}/query.go (75%) rename {query => database/v1compat/query}/builder.go (97%) rename {query => database/v1compat/query}/identifiers.go (100%) create mode 100644 database/v1compat/query/model.go rename {query => database/v1compat/query}/rewrite.go (100%) rename {query => database/v1compat/query}/sort.go (87%) create mode 100644 database/v1compat/relationship.go rename {graph => database/v1compat}/result.go (73%) rename {graph => database/v1compat}/switch.go (99%) rename {traversal => database/v1compat/traversal}/collection.go (56%) create mode 100644 database/v1compat/traversal/query.go create mode 100644 database/v1compat/traversal/traversal.go create mode 100644 database/v1compat/traversal/traversal_test.go create mode 100644 database/v1compat/types.go create mode 100644 database/v1compat/wrapper.go delete mode 100644 drivers/neo4j/batch.go delete mode 100644 drivers/neo4j/const.go delete mode 100644 drivers/neo4j/cypher.go delete mode 100644 drivers/neo4j/cypher_internal_test.go delete mode 100644 drivers/neo4j/driver.go delete mode 100644 drivers/neo4j/index.go delete mode 100644 drivers/neo4j/neo4j.go delete mode 100644 drivers/neo4j/node.go delete mode 100644 drivers/neo4j/relationship.go delete mode 100644 drivers/neo4j/result.go delete mode 100644 drivers/neo4j/result_internal_test.go delete mode 100644 
drivers/neo4j/transaction.go delete mode 100644 drivers/neo4j/wrapper.go delete mode 100644 drivers/pg/batch.go delete mode 100644 drivers/pg/driver.go delete mode 100644 drivers/pg/model/model.go delete mode 100644 drivers/pg/node.go delete mode 100644 drivers/pg/query.go delete mode 100644 drivers/pg/query/query.go delete mode 100644 drivers/pg/relationship.go delete mode 100644 drivers/pg/result.go delete mode 100644 drivers/pg/tooling.go delete mode 100644 drivers/pg/transaction.go delete mode 100644 drivers/pg/util.go delete mode 100644 graph/error.go rename graph/{relationships.go => relationship.go} (100%) rename graph/{relationships_test.go => relationship_test.go} (100%) delete mode 100644 graph/schema.go delete mode 100644 query/model.go delete mode 100644 query/neo4j/neo4j.go delete mode 100644 query/neo4j/neo4j_test.go delete mode 100644 query/neo4j/rewrite.go create mode 100644 query/query.go create mode 100644 query/query_test.go create mode 100644 query/util.go rename dawgs.go => registry.go (50%) create mode 100644 registry_integration_test.go diff --git a/cypher/models/cypher/model.go b/cypher/models/cypher/model.go index 32cccb4..edbe618 100644 --- a/cypher/models/cypher/model.go +++ b/cypher/models/cypher/model.go @@ -770,6 +770,12 @@ type Literal struct { } func NewLiteral(value any, null bool) *Literal { + if !null { + if strValue, typeOK := value.(string); typeOK { + return NewStringLiteral(strValue) + } + } + return &Literal{ Value: value, Null: null, @@ -777,7 +783,10 @@ func NewLiteral(value any, null bool) *Literal { } func NewStringLiteral(value string) *Literal { - return NewLiteral("'"+value+"'", false) + return &Literal{ + Value: "'" + value + "'", + Null: false, + } } func (s *Literal) copy() *Literal { diff --git a/cypher/models/pgsql/format/format.go b/cypher/models/pgsql/format/format.go index 5aebd67..8a1d30a 100644 --- a/cypher/models/pgsql/format/format.go +++ b/cypher/models/pgsql/format/format.go @@ -4,6 +4,7 @@ import ( "fmt" 
"strconv" "strings" + "time" "github.com/specterops/dawgs/cypher/models/pgsql" ) @@ -70,6 +71,9 @@ func formatSlice[T any, TS []T](builder *OutputBuilder, slice TS, dataType pgsql func formatValue(builder *OutputBuilder, value any) error { switch typedValue := value.(type) { + case time.Time: + builder.Write("'", typedValue.Format(time.RFC3339Nano), "'::timestamp with time zone") + case uint: builder.Write(strconv.FormatUint(uint64(typedValue), 10)) @@ -116,7 +120,11 @@ func formatValue(builder *OutputBuilder, value any) error { return formatSlice(builder, typedValue, pgsql.Int8Array) case string: - builder.Write("'", typedValue, "'") + // Double single quotes per SQL string literal rules + builder.Write("'", strings.ReplaceAll(typedValue, "'", "''"), "'") + + case []string: + return formatSlice(builder, typedValue, pgsql.TextArray) case bool: builder.Write(strconv.FormatBool(typedValue)) diff --git a/cypher/models/pgsql/test/testcase.go b/cypher/models/pgsql/test/testcase.go index f924318..f52bfc5 100644 --- a/cypher/models/pgsql/test/testcase.go +++ b/cypher/models/pgsql/test/testcase.go @@ -13,14 +13,13 @@ import ( "testing" "time" - "github.com/specterops/dawgs/drivers/pg" - "cuelang.org/go/pkg/regexp" "github.com/specterops/dawgs/cypher/frontend" "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/pgsql" "github.com/specterops/dawgs/cypher/models/pgsql/translate" "github.com/specterops/dawgs/cypher/models/walk" + "github.com/specterops/dawgs/database" "github.com/stretchr/testify/require" ) @@ -183,7 +182,7 @@ func (s *TranslationTestCase) Assert(t *testing.T, expectedSQL string, kindMappe } } -func (s *TranslationTestCase) AssertLive(ctx context.Context, t *testing.T, driver *pg.Driver) { +func (s *TranslationTestCase) AssertLive(ctx context.Context, t *testing.T, db database.Instance) { if regularQuery, err := frontend.ParseCypher(frontend.NewContext(), s.Cypher); err != nil { t.Fatalf("Failed to compile cypher 
query: %s - %v", s.Cypher, err) } else { @@ -200,13 +199,15 @@ func (s *TranslationTestCase) AssertLive(ctx context.Context, t *testing.T, driv } } - if translation, err := translate.Translate(context.Background(), regularQuery, driver.KindMapper(), s.CypherParams); err != nil { - t.Fatalf("Failed to translate cypher query: %s - %v", s.Cypher, err) - } else if formattedQuery, err := translate.Translated(translation); err != nil { - t.Fatalf("Failed to format SQL translatedQuery: %v", err) - } else { - require.Nil(t, driver.Run(ctx, "explain "+formattedQuery, translation.Parameters)) - } + require.NoError(t, db.Session(ctx, func(ctx context.Context, driver database.Driver) error { + result := driver.Explain(ctx, regularQuery, s.CypherParams) + + if err := result.Close(ctx); err != nil { + return err + } + + return result.Error() + })) } } diff --git a/cypher/models/pgsql/test/translation_test.go b/cypher/models/pgsql/test/translation_test.go index 26aaa9a..143558e 100644 --- a/cypher/models/pgsql/test/translation_test.go +++ b/cypher/models/pgsql/test/translation_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - "github.com/specterops/dawgs/drivers/pg/pgutil" + "github.com/specterops/dawgs/database/pg/pgutil" "github.com/specterops/dawgs/cypher/models/pgsql" "github.com/specterops/dawgs/graph" diff --git a/cypher/models/pgsql/test/validation_integration_test.go b/cypher/models/pgsql/test/validation_integration_test.go index fec493d..918ceb4 100644 --- a/cypher/models/pgsql/test/validation_integration_test.go +++ b/cypher/models/pgsql/test/validation_integration_test.go @@ -9,10 +9,10 @@ import ( "runtime/debug" "testing" - "github.com/jackc/pgx/v5/pgxpool" + "github.com/specterops/dawgs/database" "github.com/specterops/dawgs" - "github.com/specterops/dawgs/drivers/pg" + "github.com/specterops/dawgs/database/pg" "github.com/specterops/dawgs/graph" "github.com/specterops/dawgs/util/size" "github.com/stretchr/testify/require" @@ -32,34 +32,25 @@ func 
TestTranslationTestCases(t *testing.T) { require.NotEmpty(t, pgConnectionStr) - if pgxPool, err := pgxpool.New(testCtx, pgConnectionStr); err != nil { - t.Fatalf("Failed opening database connection: %v", err) - } else if connection, err := dawgs.Open(context.TODO(), pg.DriverName, dawgs.Config{ + if connection, err := dawgs.Open(context.TODO(), pg.DriverName, dawgs.Config{ GraphQueryMemoryLimit: size.Gibibyte, - Pool: pgxPool, + ConnectionString: pgConnectionStr, }); err != nil { t.Fatalf("Failed opening database connection: %v", err) - } else if pgConnection, typeOK := connection.(*pg.Driver); !typeOK { - t.Fatalf("Invalid connection type: %T", connection) } else { defer connection.Close(testCtx) - graphSchema := graph.Schema{ - Graphs: []graph.Graph{{ - Name: "test", - Nodes: graph.Kinds{ - graph.StringKind("NodeKind1"), - graph.StringKind("NodeKind2"), - }, - Edges: graph.Kinds{ - graph.StringKind("EdgeKind1"), - graph.StringKind("EdgeKind2"), - }, - }}, - DefaultGraph: graph.Graph{ - Name: "test", + graphSchema := database.NewSchema("test", database.Graph{ + Name: "test", + Nodes: graph.Kinds{ + graph.StringKind("NodeKind1"), + graph.StringKind("NodeKind2"), }, - } + Edges: graph.Kinds{ + graph.StringKind("EdgeKind1"), + graph.StringKind("EdgeKind2"), + }, + }) if err := connection.AssertSchema(testCtx, graphSchema); err != nil { t.Fatalf("Failed asserting graph schema: %v", err) @@ -79,7 +70,7 @@ func TestTranslationTestCases(t *testing.T) { } }() - testCase.AssertLive(testCtx, t, pgConnection) + testCase.AssertLive(testCtx, t, connection) }) casesRun += 1 diff --git a/cypher/models/pgsql/visualization/visualizer_test.go b/cypher/models/pgsql/visualization/visualizer_test.go index 08dc604..5291675 100644 --- a/cypher/models/pgsql/visualization/visualizer_test.go +++ b/cypher/models/pgsql/visualization/visualizer_test.go @@ -5,7 +5,7 @@ import ( "context" "testing" - "github.com/specterops/dawgs/drivers/pg/pgutil" + 
"github.com/specterops/dawgs/database/pg/pgutil" "github.com/specterops/dawgs/cypher/frontend" "github.com/specterops/dawgs/cypher/models/pgsql/translate" diff --git a/database/driver.go b/database/driver.go new file mode 100644 index 0000000..b05d0bd --- /dev/null +++ b/database/driver.go @@ -0,0 +1,82 @@ +package database + +import ( + "context" + + "github.com/specterops/dawgs/graph" + + "github.com/specterops/dawgs/cypher/models/cypher" +) + +type Option int + +const ( + OptionReadOnly Option = 0 + OptionReadWrite Option = 1 +) + +type Result interface { + HasNext(ctx context.Context) bool + Scan(scanTargets ...any) error + Error() error + Close(ctx context.Context) error + + // Values returns the next values array from the result. + // + // Deprecated: This function will be removed in future version. + Values() []any +} + +type Driver interface { + WithGraph(target Graph) Driver + + CreateNode(ctx context.Context, node *graph.Node) (graph.ID, error) + CreateRelationship(ctx context.Context, relationship *graph.Relationship) (graph.ID, error) + + Exec(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) Result + Explain(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) Result + Profile(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) Result + Mapper() graph.ValueMapper +} + +type QueryLogic func(ctx context.Context, driver Driver) error + +type Instance interface { + AssertSchema(ctx context.Context, schema Schema) error + Session(ctx context.Context, driverLogic QueryLogic, options ...Option) error + Transaction(ctx context.Context, driverLogic QueryLogic, options ...Option) error + Close(ctx context.Context) error + + // FetchKinds retrieves the complete list of kinds available to the database. 
+ FetchKinds(ctx context.Context) (graph.Kinds, error) +} + +type errorResult struct { + err error +} + +func (s errorResult) HasNext(ctx context.Context) bool { + return false +} + +func (s errorResult) Scan(scanTargets ...any) error { + return s.err +} + +func (s errorResult) Error() error { + return s.err +} + +func (s errorResult) Values() []any { + return nil +} + +func (s errorResult) Close(ctx context.Context) error { + return nil +} + +func NewErrorResult(err error) Result { + return errorResult{ + err: err, + } +} diff --git a/database/neo4j/database.go b/database/neo4j/database.go new file mode 100644 index 0000000..96a7a1b --- /dev/null +++ b/database/neo4j/database.go @@ -0,0 +1,163 @@ +package neo4j + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/specterops/dawgs" + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/util/channels" + "github.com/specterops/dawgs/util/size" +) + +type instance struct { + internalDriver neo4j.DriverWithContext + defaultTransactionTimeout time.Duration + limiter channels.ConcurrencyLimiter + graphQueryMemoryLimit size.Size + schemaManager *SchemaManager +} + +func New(internalDriver neo4j.DriverWithContext, cfg dawgs.Config) database.Instance { + return &instance{ + internalDriver: internalDriver, + defaultTransactionTimeout: DefaultNeo4jTransactionTimeout, + limiter: channels.NewConcurrencyLimiter(DefaultConcurrentConnections), + graphQueryMemoryLimit: cfg.GraphQueryMemoryLimit, + schemaManager: NewSchemaManager(internalDriver), + } +} + +func (s *instance) AssertSchema(ctx context.Context, schema database.Schema) error { + return s.schemaManager.AssertSchema(ctx, schema) +} + +func (s *instance) acquireInternalSession(ctx context.Context, options []database.Option) (neo4j.SessionWithContext, error) { + // Attempt to acquire a connection slot or wait for a bit until one becomes available + if 
!s.limiter.Acquire(ctx) { + return nil, graph.ErrConcurrentConnectionSlotTimeOut + } + + var ( + sessionCfg = neo4j.SessionConfig{ + // Default to a write enabled session if no options are supplied + AccessMode: neo4j.AccessModeWrite, + } + ) + + for _, option := range options { + if option == database.OptionReadOnly { + sessionCfg.AccessMode = neo4j.AccessModeRead + } + } + + return s.internalDriver.NewSession(ctx, sessionCfg), nil +} + +func (s *instance) Session(ctx context.Context, driverLogic database.QueryLogic, options ...database.Option) error { + if session, err := s.acquireInternalSession(ctx, options); err != nil { + return err + } else { + // Release the connection slot when this function exits + defer s.limiter.Release() + + defer func() { + if err := session.Close(ctx); err != nil { + slog.DebugContext(ctx, "failed to close session", slog.String("err", err.Error())) + } + }() + + return driverLogic(ctx, newInternalDriver(&sessionDriver{ + session: session, + })) + } +} + +func (s *instance) Transaction(ctx context.Context, driverLogic database.QueryLogic, options ...database.Option) error { + if session, err := s.acquireInternalSession(ctx, options); err != nil { + return err + } else { + // Release the connection slot when this function exits + defer s.limiter.Release() + + defer func() { + if err := session.Close(ctx); err != nil { + slog.DebugContext(ctx, "failed to close session", slog.String("err", err.Error())) + } + }() + + // Acquire a new transaction + if transaction, err := session.BeginTransaction(ctx); err != nil { + return err + } else { + defer func() { + if err := transaction.Rollback(ctx); err != nil { + slog.DebugContext(ctx, "failed to rollback transaction", slog.String("err", err.Error())) + } + }() + + if err := driverLogic(ctx, newInternalDriver(transaction)); err != nil { + return err + } + + return transaction.Commit(ctx) + } + } +} + +func (s *instance) FetchKinds(ctx context.Context) (graph.Kinds, error) { + var kinds 
graph.Kinds + + if session, err := s.acquireInternalSession(ctx, nil); err != nil { + return nil, err + } else { + // Release the connection slot when this function exits + defer s.limiter.Release() + + defer func() { + if err := session.Close(ctx); err != nil { + slog.DebugContext(ctx, "failed to close session", slog.String("err", err.Error())) + } + }() + + consumeKindResult := func(result neo4j.ResultWithContext) error { + defer result.Consume(ctx) + + for result.Next(ctx) { + values := result.Record().Values + + if len(values) == 0 { + return fmt.Errorf("expected at least one value for labels") + } else if kindStr, typeOK := values[0].(string); !typeOK { + return fmt.Errorf("unexpected label type from Neo4j: %T", values[0]) + } else { + kinds = append(kinds, graph.StringKind(kindStr)) + } + } + + return nil + } + + if result, err := session.Run(ctx, "call db.labels();", nil); err != nil { + return nil, err + } else if err := consumeKindResult(result); err != nil { + return nil, err + } + + if result, err := session.Run(ctx, "call db.relationshipTypes();", nil); err != nil { + return nil, err + } else if err := consumeKindResult(result); err != nil { + return nil, err + } + } + + return kinds, nil +} + +func (s *instance) Close(ctx context.Context) error { + return s.internalDriver.Close(ctx) +} diff --git a/database/neo4j/database_integration_test.go b/database/neo4j/database_integration_test.go new file mode 100644 index 0000000..7183b34 --- /dev/null +++ b/database/neo4j/database_integration_test.go @@ -0,0 +1 @@ +package neo4j_test diff --git a/database/neo4j/driver.go b/database/neo4j/driver.go new file mode 100644 index 0000000..c84b993 --- /dev/null +++ b/database/neo4j/driver.go @@ -0,0 +1,202 @@ +package neo4j + +import ( + "context" + "fmt" + + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/cypher/format" + "github.com/specterops/dawgs/database" + 
"github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/query" +) + +var ( + resultValueMapper = newValueMapper() +) + +type sessionResult struct { + result neo4j.ResultWithContext + nextRecord *neo4j.Record + mapper graph.ValueMapper + err error +} + +func (s *sessionResult) Values() []any { + return s.nextRecord.Values +} + +func newResult(result neo4j.ResultWithContext, err error) database.Result { + return &sessionResult{ + result: result, + mapper: resultValueMapper, + err: err, + } +} + +func (s *sessionResult) HasNext(ctx context.Context) bool { + if s.err != nil { + return false + } + + hasNext := s.result.NextRecord(ctx, &s.nextRecord) + + if !hasNext { + s.err = s.result.Err() + } + + return hasNext +} + +func (s *sessionResult) Scan(scanTargets ...any) error { + if s.err != nil { + return s.err + } + + if len(scanTargets) != len(s.nextRecord.Values) { + return fmt.Errorf("expected to scan %d values but received %d to map to", len(s.nextRecord.Values), len(scanTargets)) + } + + for idx, nextTarget := range scanTargets { + nextValue := s.nextRecord.Values[idx] + + if !s.mapper.Map(nextValue, nextTarget) { + return fmt.Errorf("unable to scan type %T into target type %T", nextValue, nextTarget) + } + } + + return nil +} + +func (s *sessionResult) Close(ctx context.Context) error { + if s.err != nil { + return s.err + } + + _, err := s.result.Consume(ctx) + return err +} + +func (s *sessionResult) Error() error { + return s.err +} + +type neo4jDriver interface { + Run(ctx context.Context, cypher string, params map[string]any) (neo4j.ResultWithContext, error) +} + +type sessionDriver struct { + session neo4j.SessionWithContext +} + +func (s *sessionDriver) Run(ctx context.Context, cypher string, params map[string]any) (neo4j.ResultWithContext, error) { + return s.session.Run(ctx, cypher, params) +} + +type dawgsDriver struct { + internalDriver neo4jDriver +} + +func (s *dawgsDriver) Mapper() graph.ValueMapper { + return resultValueMapper +} + +func 
(s *dawgsDriver) CreateNode(ctx context.Context, node *graph.Node) (graph.ID, error) { + if builtQuery, err := query.New().Create( + query.Node().NodePattern(node.Kinds, node.Properties.MapOrEmpty()), + ).Return( + query.Node().ID(), + ).Build(); err != nil { + return 0, err + } else { + var ( + newEntityID graph.ID + result = s.Exec(ctx, builtQuery.Query, builtQuery.Parameters) + ) + + defer result.Close(ctx) + + if !result.HasNext(ctx) { + return 0, graph.ErrNoResultsFound + } + + if err := result.Scan(&newEntityID); err != nil { + return 0, err + } + + return newEntityID, result.Error() + } +} + +func (s *dawgsDriver) CreateRelationship(ctx context.Context, relationship *graph.Relationship) (graph.ID, error) { + if builtQuery, err := query.New().Where( + query.And( + query.Start().ID().Equals(relationship.StartID), + query.End().ID().Equals(relationship.EndID), + ), + ).Create( + query.Relationship().RelationshipPattern(relationship.Kind, relationship.Properties.MapOrEmpty(), graph.DirectionOutbound), + ).Return( + query.Relationship().ID(), + ).Build(); err != nil { + return 0, err + } else { + var ( + newEntityID graph.ID + result = s.Exec(ctx, builtQuery.Query, builtQuery.Parameters) + ) + + defer result.Close(ctx) + + if !result.HasNext(ctx) { + return 0, graph.ErrNoResultsFound + } + + if err := result.Scan(&newEntityID); err != nil { + return 0, err + } + + return newEntityID, result.Error() + } +} + +func newInternalDriver(internalDriver neo4jDriver) *dawgsDriver { + return &dawgsDriver{ + internalDriver: internalDriver, + } +} + +func (s *dawgsDriver) WithGraph(target database.Graph) database.Driver { + // NOOP for now + return s +} + +func (s *dawgsDriver) exec(ctx context.Context, cypherQuery string, parameters map[string]any) database.Result { + internalResult, err := s.internalDriver.Run(ctx, cypherQuery, parameters) + return newResult(internalResult, err) +} + +func (s *dawgsDriver) Exec(ctx context.Context, query *cypher.RegularQuery, parameters 
map[string]any) database.Result { + if cypherQuery, err := format.RegularQuery(query, false); err != nil { + return database.NewErrorResult(err) + } else { + return s.exec(ctx, cypherQuery, parameters) + } +} + +func (s *dawgsDriver) Explain(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) database.Result { + if cypherQuery, err := format.RegularQuery(query, false); err != nil { + return database.NewErrorResult(err) + } else { + return s.exec(ctx, "explain "+cypherQuery, parameters) + } +} + +func (s *dawgsDriver) Profile(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) database.Result { + if cypherQuery, err := format.RegularQuery(query, false); err != nil { + return database.NewErrorResult(err) + } else { + return s.exec(ctx, "profile "+cypherQuery, parameters) + } +} diff --git a/drivers/neo4j/mapper.go b/database/neo4j/mapper.go similarity index 51% rename from drivers/neo4j/mapper.go rename to database/neo4j/mapper.go index 401ae43..d27a58a 100644 --- a/drivers/neo4j/mapper.go +++ b/database/neo4j/mapper.go @@ -3,11 +3,13 @@ package neo4j import ( "time" + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype" "github.com/specterops/dawgs/graph" ) -func AsTime(value any) (time.Time, bool) { +func asTime(value any) (time.Time, bool) { switch typedValue := value.(type) { case dbtype.Time: return typedValue.Time(), true @@ -26,10 +28,50 @@ func AsTime(value any) (time.Time, bool) { } } +func newNode(internalNode neo4j.Node) *graph.Node { + var propertiesInst = internalNode.Props + + if propertiesInst == nil { + propertiesInst = make(map[string]any) + } + + return graph.NewNode(graph.ID(internalNode.Id), graph.AsProperties(propertiesInst), graph.StringsToKinds(internalNode.Labels)...) 
+} + +func newRelationship(internalRelationship neo4j.Relationship) *graph.Relationship { + propertiesInst := internalRelationship.Props + + if propertiesInst == nil { + propertiesInst = make(map[string]any) + } + + return graph.NewRelationship( + graph.ID(internalRelationship.Id), + graph.ID(internalRelationship.StartId), + graph.ID(internalRelationship.EndId), + graph.AsProperties(propertiesInst), + graph.StringKind(internalRelationship.Type), + ) +} + +func newPath(internalPath neo4j.Path) graph.Path { + path := graph.Path{} + + for _, node := range internalPath.Nodes { + path.Nodes = append(path.Nodes, newNode(node)) + } + + for _, relationship := range internalPath.Relationships { + path.Edges = append(path.Edges, newRelationship(relationship)) + } + + return path +} + func mapValue(rawValue, target any) bool { switch typedTarget := target.(type) { case *time.Time: - if value, typeOK := AsTime(rawValue); typeOK { + if value, typeOK := asTime(rawValue); typeOK { *typedTarget = value return true } @@ -68,6 +110,6 @@ func mapValue(rawValue, target any) bool { return false } -func NewValueMapper() graph.ValueMapper { +func newValueMapper() graph.ValueMapper { return graph.NewValueMapper(mapValue) } diff --git a/database/neo4j/neo4j.go b/database/neo4j/neo4j.go new file mode 100644 index 0000000..28768c2 --- /dev/null +++ b/database/neo4j/neo4j.go @@ -0,0 +1,48 @@ +package neo4j + +import ( + "context" + "fmt" + "net/url" + "time" + + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/specterops/dawgs" + "github.com/specterops/dawgs/database" +) + +const ( + DefaultNeo4jTransactionTimeout = time.Minute * 15 + DefaultBatchWriteSize = 20_000 + DefaultWriteFlushSize = DefaultBatchWriteSize * 5 + + // DefaultConcurrentConnections defines the default number of concurrent graph database connections allowed. 
+ DefaultConcurrentConnections = 50 + + Neo4jConnectionScheme = "neo4j" + DriverName = "neo4j_v2" +) + +func newNeo4jDB(ctx context.Context, cfg dawgs.Config) (database.Instance, error) { + if connectionURL, err := url.Parse(cfg.ConnectionString); err != nil { + return nil, err + } else if connectionURL.Scheme != Neo4jConnectionScheme { + return nil, fmt.Errorf("expected connection URL scheme %s for Neo4J but got %s", Neo4jConnectionScheme, connectionURL.Scheme) + } else if password, isSet := connectionURL.User.Password(); !isSet { + return nil, fmt.Errorf("no password provided in connection URL") + } else { + boltURL := fmt.Sprintf("bolt://%s:%s", connectionURL.Hostname(), connectionURL.Port()) + + if internalDriver, err := neo4j.NewDriverWithContext(boltURL, neo4j.BasicAuth(connectionURL.User.Username(), password, "")); err != nil { + return nil, fmt.Errorf("unable to connect to Neo4J: %w", err) + } else { + return New(internalDriver, cfg), nil + } + } +} + +func init() { + dawgs.Register(DriverName, func(ctx context.Context, cfg dawgs.Config) (database.Instance, error) { + return newNeo4jDB(ctx, cfg) + }) +} diff --git a/database/neo4j/schema.go b/database/neo4j/schema.go new file mode 100644 index 0000000..3586f5a --- /dev/null +++ b/database/neo4j/schema.go @@ -0,0 +1,334 @@ +package neo4j + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/specterops/dawgs/database" + + "github.com/specterops/dawgs/graph" +) + +const ( + nativeBTreeIndexProvider = "native-btree-1.0" + nativeLuceneIndexProvider = "lucene+native-3.0" + + dropPropertyIndexStatement = "drop index $name;" + dropPropertyConstraintStatement = "drop constraint $name;" + createPropertyIndexStatement = "call db.createIndex($name, $labels, $properties, $provider);" + createPropertyConstraintStatement = "call db.createUniquePropertyConstraint($name, $labels, $properties, $provider);" +) + +type neo4jIndex struct { + 
database.Index + + kind graph.Kind +} + +type neo4jConstraint struct { + database.Constraint + + kind graph.Kind +} + +type neo4jSchema struct { + Indexes map[string]neo4jIndex + Constraints map[string]neo4jConstraint +} + +func newNeo4jSchema() neo4jSchema { + return neo4jSchema{ + Indexes: map[string]neo4jIndex{}, + Constraints: map[string]neo4jConstraint{}, + } +} + +func toNeo4jSchema(dbSchema database.Schema) neo4jSchema { + neo4jSchemaInst := newNeo4jSchema() + + for _, graphSchema := range dbSchema.GraphSchemas { + for _, index := range graphSchema.NodeIndexes { + for _, kind := range graphSchema.Nodes { + indexName := strings.ToLower(kind.String()) + "_" + strings.ToLower(index.Field) + "_index" + + neo4jSchemaInst.Indexes[indexName] = neo4jIndex{ + Index: database.Index{ + Name: indexName, + Field: index.Field, + Type: index.Type, + }, + kind: kind, + } + } + } + + for _, constraint := range graphSchema.NodeConstraints { + for _, kind := range graphSchema.Nodes { + constraintName := strings.ToLower(kind.String()) + "_" + strings.ToLower(constraint.Field) + "_constraint" + + neo4jSchemaInst.Constraints[constraintName] = neo4jConstraint{ + Constraint: database.Constraint{ + Name: constraintName, + Field: constraint.Field, + Type: constraint.Type, + }, + kind: kind, + } + } + } + } + + return neo4jSchemaInst +} + +func parseProviderType(provider string) database.IndexType { + switch provider { + case nativeBTreeIndexProvider: + return database.IndexTypeBTree + case nativeLuceneIndexProvider: + return database.IndexTypeTextSearch + default: + return database.IndexTypeUnsupported + } +} + +func indexTypeProvider(indexType database.IndexType) string { + switch indexType { + case database.IndexTypeBTree: + return nativeBTreeIndexProvider + case database.IndexTypeTextSearch: + return nativeLuceneIndexProvider + default: + return "" + } +} + +type SchemaManager struct { + internalDriver neo4j.DriverWithContext +} + +func NewSchemaManager(internalDriver 
neo4j.DriverWithContext) *SchemaManager { + return &SchemaManager{ + internalDriver: internalDriver, + } +} + +func (s *SchemaManager) transaction(ctx context.Context, transactionHandler func(transaction neo4j.ExplicitTransaction) error) error { + session := s.internalDriver.NewSession(ctx, neo4j.SessionConfig{ + AccessMode: neo4j.AccessModeWrite, + }) + + defer func() { + if err := session.Close(ctx); err != nil { + slog.DebugContext(ctx, "failed to close session", slog.String("err", err.Error())) + } + }() + + // Acquire a new transaction + if transaction, err := session.BeginTransaction(ctx); err != nil { + return err + } else { + defer func() { + // Neo4j's error types make detecting if this is a rollback after close very difficult. Because this is a + // debug log output, accept the potential verbosity. + if err := transaction.Rollback(ctx); err != nil { + slog.DebugContext(ctx, "failed to rollback transaction", slog.String("err", err.Error())) + } + }() + + if err := transactionHandler(transaction); err != nil { + return err + } + + return transaction.Commit(ctx) + } +} + +func (s *SchemaManager) assertIndexes(ctx context.Context, indexesToRemove []string, indexesToAdd map[string]neo4jIndex) error { + if err := s.transaction(ctx, func(transaction neo4j.ExplicitTransaction) error { + for _, indexToRemove := range indexesToRemove { + slog.InfoContext(ctx, fmt.Sprintf("Removing index %s", indexToRemove)) + + if result, err := transaction.Run(ctx, strings.Replace(dropPropertyIndexStatement, "$name", indexToRemove, 1), nil); err != nil { + return err + } else if _, err := result.Consume(ctx); err != nil { + return err + } else if err := result.Err(); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + + return s.transaction(ctx, func(transaction neo4j.ExplicitTransaction) error { + for indexName, indexToAdd := range indexesToAdd { + slog.InfoContext(ctx, fmt.Sprintf("Adding index %s to labels %s on properties %s using %s", 
indexName, indexToAdd.kind.String(), indexToAdd.Field, indexTypeProvider(indexToAdd.Type))) + + if result, err := transaction.Run(ctx, createPropertyIndexStatement, map[string]interface{}{ + "name": indexName, + "labels": []string{indexToAdd.kind.String()}, + "properties": []string{indexToAdd.Field}, + "provider": indexTypeProvider(indexToAdd.Type), + }); err != nil { + return err + } else if _, err := result.Consume(ctx); err != nil { + return err + } else if err := result.Err(); err != nil { + return err + } + } + + return nil + }) +} + +func (s *SchemaManager) assertConstraints(ctx context.Context, constraintsToRemove []string, constraintsToAdd map[string]neo4jConstraint) error { + return s.transaction(ctx, func(transaction neo4j.ExplicitTransaction) error { + for _, constraintToRemove := range constraintsToRemove { + if result, err := transaction.Run(ctx, strings.Replace(dropPropertyConstraintStatement, "$name", constraintToRemove, 1), nil); err != nil { + return err + } else if _, err := result.Consume(ctx); err != nil { + return err + } else if err := result.Err(); err != nil { + return err + } + } + + for constraintName, constraintToAdd := range constraintsToAdd { + if result, err := transaction.Run(ctx, createPropertyConstraintStatement, map[string]interface{}{ + "name": constraintName, + "labels": []string{constraintToAdd.kind.String()}, + "properties": []string{constraintToAdd.Field}, + "provider": indexTypeProvider(constraintToAdd.Type), + }); err != nil { + return err + } else if _, err := result.Consume(ctx); err != nil { + return err + } else if err := result.Err(); err != nil { + return err + } + } + + return nil + }) +} + +func (s *SchemaManager) fetchPresentSchema(ctx context.Context) (neo4jSchema, error) { + var ( + presentSchema = newNeo4jSchema() + err = s.transaction(ctx, func(transaction neo4j.ExplicitTransaction) error { + if result, err := transaction.Run(ctx, "call db.indexes() yield name, uniqueness, provider, labelsOrTypes, properties;", 
nil); err != nil { + return err + } else { + defer result.Consume(ctx) + + var ( + name string + uniqueness string + provider string + labels []string + properties []string + ) + + for result.Next(ctx) { + values := result.Record().Values + + if !newValueMapper().MapAll(values, []any{&name, &uniqueness, &provider, &labels, &properties}) { + return errors.New("failed to map present schema") + } + + // Need this for neo4j 4.4+ which creates a weird index by default + if len(labels) == 0 { + continue + } + + if len(labels) > 1 || len(properties) > 1 { + return fmt.Errorf("composite index types are currently not supported") + } + + if uniqueness == "UNIQUE" { + presentSchema.Constraints[name] = neo4jConstraint{ + Constraint: database.Constraint{ + Name: name, + Field: properties[0], + Type: parseProviderType(provider), + }, + kind: graph.StringKind(labels[0]), + } + } else { + presentSchema.Indexes[name] = neo4jIndex{ + Index: database.Index{ + Name: name, + Field: properties[0], + Type: parseProviderType(provider), + }, + kind: graph.StringKind(labels[0]), + } + } + } + + return result.Err() + } + }) + ) + + return presentSchema, err +} + +func (s *SchemaManager) AssertSchema(ctx context.Context, required database.Schema) error { + requiredNeo4jSchema := toNeo4jSchema(required) + + if presentNeo4jSchema, err := s.fetchPresentSchema(ctx); err != nil { + return err + } else { + var ( + indexesToRemove []string + constraintsToRemove []string + indexesToAdd = map[string]neo4jIndex{} + constraintsToAdd = map[string]neo4jConstraint{} + ) + + for presentIndexName := range presentNeo4jSchema.Indexes { + if _, hasMatchingDefinition := requiredNeo4jSchema.Indexes[presentIndexName]; !hasMatchingDefinition { + indexesToRemove = append(indexesToRemove, presentIndexName) + } + } + + for presentConstraintName := range presentNeo4jSchema.Constraints { + if _, hasMatchingDefinition := requiredNeo4jSchema.Constraints[presentConstraintName]; !hasMatchingDefinition { + 
constraintsToRemove = append(constraintsToRemove, presentConstraintName) + } + } + + for requiredIndexName, requiredIndex := range requiredNeo4jSchema.Indexes { + if presentIndex, hasMatchingDefinition := presentNeo4jSchema.Indexes[requiredIndexName]; !hasMatchingDefinition { + indexesToAdd[requiredIndexName] = requiredIndex + } else if requiredIndex.Type != presentIndex.Type { + indexesToRemove = append(indexesToRemove, requiredIndexName) + indexesToAdd[requiredIndexName] = requiredIndex + } + } + + for requiredConstraintName, requiredConstraint := range requiredNeo4jSchema.Constraints { + if presentConstraint, hasMatchingDefinition := presentNeo4jSchema.Constraints[requiredConstraintName]; !hasMatchingDefinition { + constraintsToAdd[requiredConstraintName] = requiredConstraint + } else if requiredConstraint.Type != presentConstraint.Type { + constraintsToRemove = append(constraintsToRemove, requiredConstraintName) + constraintsToAdd[requiredConstraintName] = requiredConstraint + } + } + + if err := s.assertConstraints(ctx, constraintsToRemove, constraintsToAdd); err != nil { + return err + } + + return s.assertIndexes(ctx, indexesToRemove, indexesToAdd) + } +} diff --git a/database/pg/batch.go b/database/pg/batch.go new file mode 100644 index 0000000..d5789f3 --- /dev/null +++ b/database/pg/batch.go @@ -0,0 +1,379 @@ +package pg + +import ( + "bytes" + "context" + "fmt" + "strconv" + "strings" + + "github.com/jackc/pgtype" + "github.com/specterops/dawgs/cypher/models/pgsql" + "github.com/specterops/dawgs/database/pg/model" + "github.com/specterops/dawgs/database/pg/query" + "github.com/specterops/dawgs/graph" +) + +type Int2ArrayEncoder struct { + buffer *bytes.Buffer +} + +func NewInt2ArrayEncoder() Int2ArrayEncoder { + return Int2ArrayEncoder{ + buffer: &bytes.Buffer{}, + } +} + +func (s *Int2ArrayEncoder) Encode(values []int16) string { + s.buffer.Reset() + s.buffer.WriteRune('{') + + for idx, value := range values { + if idx > 0 { + s.buffer.WriteRune(',') + 
} + + s.buffer.WriteString(strconv.Itoa(int(value))) + } + + s.buffer.WriteRune('}') + return s.buffer.String() +} + +type NodeUpsertParameters struct { + IDFutures []*query.Future[graph.ID] + KindIDSlices []string + Properties []pgtype.JSONB +} + +func NewNodeUpsertParameters(size int) *NodeUpsertParameters { + return &NodeUpsertParameters{ + IDFutures: make([]*query.Future[graph.ID], 0, size), + KindIDSlices: make([]string, 0, size), + Properties: make([]pgtype.JSONB, 0, size), + } +} + +func (s *NodeUpsertParameters) Format(graphTarget model.Graph) []any { + return []any{ + graphTarget.ID, + s.KindIDSlices, + s.Properties, + } +} + +func (s *NodeUpsertParameters) Append(ctx context.Context, update *query.NodeUpdate, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + s.IDFutures = append(s.IDFutures, update.IDFuture) + + if mappedKindIDs, err := schemaManager.AssertKinds(ctx, update.Node.Kinds); err != nil { + return fmt.Errorf("unable to map kinds %w", err) + } else { + s.KindIDSlices = append(s.KindIDSlices, kindIDEncoder.Encode(mappedKindIDs)) + } + + if propertiesJSONB, err := pgsql.PropertiesToJSONB(update.Node.Properties); err != nil { + return err + } else { + s.Properties = append(s.Properties, propertiesJSONB) + } + + return nil +} + +func (s *NodeUpsertParameters) AppendAll(ctx context.Context, updates *query.NodeUpdateBatch, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + for _, nextUpdate := range updates.Updates { + if err := s.Append(ctx, nextUpdate, schemaManager, kindIDEncoder); err != nil { + return err + } + } + + return nil +} + +type RelationshipUpdateByParameters struct { + StartIDs []graph.ID + EndIDs []graph.ID + KindIDs []int16 + Properties []pgtype.JSONB +} + +func NewRelationshipUpdateByParameters(size int) *RelationshipUpdateByParameters { + return &RelationshipUpdateByParameters{ + StartIDs: make([]graph.ID, 0, size), + EndIDs: make([]graph.ID, 0, size), + KindIDs: make([]int16, 0, size), + 
Properties: make([]pgtype.JSONB, 0, size), + } +} + +func (s *RelationshipUpdateByParameters) Format(graphTarget model.Graph) []any { + return []any{ + graphTarget.ID, + s.StartIDs, + s.EndIDs, + s.KindIDs, + s.Properties, + } +} + +func (s *RelationshipUpdateByParameters) Append(ctx context.Context, update *query.RelationshipUpdate, schemaManager *SchemaManager) error { + s.StartIDs = append(s.StartIDs, update.StartID.Value) + s.EndIDs = append(s.EndIDs, update.EndID.Value) + + if mappedKindIDs, err := schemaManager.AssertKinds(ctx, []graph.Kind{update.Relationship.Kind}); err != nil { + return err + } else { + s.KindIDs = append(s.KindIDs, mappedKindIDs...) + } + + if propertiesJSONB, err := pgsql.PropertiesToJSONB(update.Relationship.Properties); err != nil { + return err + } else { + s.Properties = append(s.Properties, propertiesJSONB) + } + return nil +} + +func (s *RelationshipUpdateByParameters) AppendAll(ctx context.Context, updates *query.RelationshipUpdateBatch, schemaManager *SchemaManager) error { + for _, nextUpdate := range updates.Updates { + if err := s.Append(ctx, nextUpdate, schemaManager); err != nil { + return err + } + } + + return nil +} + +type relationshipCreateBatch struct { + startIDs []uint64 + endIDs []uint64 + edgeKindIDs []int16 + edgePropertyBags []pgtype.JSONB +} + +func newRelationshipCreateBatch(size int) *relationshipCreateBatch { + return &relationshipCreateBatch{ + startIDs: make([]uint64, 0, size), + endIDs: make([]uint64, 0, size), + edgeKindIDs: make([]int16, 0, size), + edgePropertyBags: make([]pgtype.JSONB, 0, size), + } +} + +func (s *relationshipCreateBatch) Add(startID, endID uint64, edgeKindID int16) { + s.startIDs = append(s.startIDs, startID) + s.edgeKindIDs = append(s.edgeKindIDs, edgeKindID) + s.endIDs = append(s.endIDs, endID) +} + +func (s *relationshipCreateBatch) EncodeProperties(edgePropertiesBatch []*graph.Properties) error { + for _, edgeProperties := range edgePropertiesBatch { + if propertiesJSONB, err := 
pgsql.PropertiesToJSONB(edgeProperties); err != nil { + return err + } else { + s.edgePropertyBags = append(s.edgePropertyBags, propertiesJSONB) + } + } + + return nil +} + +type relationshipCreateBatchBuilder struct { + keyToEdgeID map[string]uint64 + relationshipUpdateBatch *relationshipCreateBatch + edgePropertiesIndex map[uint64]int + edgePropertiesBatch []*graph.Properties +} + +func newRelationshipCreateBatchBuilder(size int) *relationshipCreateBatchBuilder { + return &relationshipCreateBatchBuilder{ + keyToEdgeID: map[string]uint64{}, + relationshipUpdateBatch: newRelationshipCreateBatch(size), + edgePropertiesIndex: map[uint64]int{}, + } +} + +func (s *relationshipCreateBatchBuilder) Build() (*relationshipCreateBatch, error) { + return s.relationshipUpdateBatch, s.relationshipUpdateBatch.EncodeProperties(s.edgePropertiesBatch) +} + +func (s *relationshipCreateBatchBuilder) Add(ctx context.Context, kindMapper KindMapper, edge *graph.Relationship) error { + keyBuilder := strings.Builder{} + + keyBuilder.WriteString(edge.StartID.String()) + keyBuilder.WriteString(edge.EndID.String()) + keyBuilder.WriteString(edge.Kind.String()) + + key := keyBuilder.String() + + if existingPropertiesIdx, hasExisting := s.keyToEdgeID[key]; hasExisting { + s.edgePropertiesBatch[existingPropertiesIdx].Merge(edge.Properties) + } else { + var ( + startID = edge.StartID.Uint64() + edgeID = edge.ID.Uint64() + endID = edge.EndID.Uint64() + edgeProperties = edge.Properties.Clone() + ) + + if edgeKindID, err := kindMapper.MapKind(ctx, edge.Kind); err != nil { + return err + } else { + s.relationshipUpdateBatch.Add(startID, endID, edgeKindID) + } + + s.keyToEdgeID[key] = edgeID + + s.edgePropertiesBatch = append(s.edgePropertiesBatch, edgeProperties) + s.edgePropertiesIndex[edgeID] = len(s.edgePropertiesBatch) - 1 + } + + return nil +} + +func (s *dawgsDriver) updateNodes(ctx context.Context, validatedBatch *query.NodeUpdateBatch) error { + var ( + parameters = 
NewNodeUpsertParameters(len(validatedBatch.Updates)) + kindIDEncoder = NewInt2ArrayEncoder() + ) + + if err := parameters.AppendAll(ctx, validatedBatch, s.schemaManager, kindIDEncoder); err != nil { + return err + } + + if graphTarget, err := s.getTargetGraph(ctx); err != nil { + return err + } else { + nodeUpsertQuery := query.FormatNodeUpsert(graphTarget, validatedBatch.IdentityProperties) + + if result, err := s.internalConn.Query(ctx, nodeUpsertQuery, parameters.Format(graphTarget)...); err != nil { + return err + } else { + idFutureIndex := 0 + + for result.Next() { + if err := result.Scan(¶meters.IDFutures[idFutureIndex].Value); err != nil { + return err + } + + idFutureIndex++ + } + + result.Close() + return result.Err() + } + } +} + +func (s *dawgsDriver) UpdateNodes(ctx context.Context, batch []graph.NodeUpdate) error { + if validatedBatch, err := query.ValidateNodeUpdateByBatch(batch); err != nil { + return err + } else { + return s.updateNodes(ctx, validatedBatch) + } +} + +func (s *dawgsDriver) UpdateRelationships(ctx context.Context, batch []graph.RelationshipUpdate) error { + if validatedBatch, err := query.ValidateRelationshipUpdateByBatch(batch); err != nil { + return err + } else if err := s.updateNodes(ctx, validatedBatch.NodeUpdates); err != nil { + return err + } else { + parameters := NewRelationshipUpdateByParameters(len(validatedBatch.Updates)) + + if err := parameters.AppendAll(ctx, validatedBatch, s.schemaManager); err != nil { + return err + } + + if graphTarget, err := s.getTargetGraph(ctx); err != nil { + return err + } else { + relationshipUpsertQuery := query.FormatRelationshipPartitionUpsert(graphTarget, validatedBatch.IdentityProperties) + + if result, err := s.internalConn.Query(ctx, relationshipUpsertQuery, parameters.Format(graphTarget)...); err != nil { + return err + } else { + result.Close() + } + } + } + + return nil +} + +func (s *dawgsDriver) CreateNodes(ctx context.Context, batch []*graph.Node) error { + var ( + numCreates 
= len(batch) + kindIDSlices = make([]string, numCreates) + kindIDEncoder = Int2ArrayEncoder{ + buffer: &bytes.Buffer{}, + } + properties = make([]pgtype.JSONB, numCreates) + ) + + for idx, nextNode := range batch { + if mappedKindIDs, err := s.schemaManager.AssertKinds(ctx, nextNode.Kinds); err != nil { + return fmt.Errorf("unable to map kinds %w", err) + } else { + kindIDSlices[idx] = kindIDEncoder.Encode(mappedKindIDs) + } + + if propertiesJSONB, err := pgsql.PropertiesToJSONB(nextNode.Properties); err != nil { + return err + } else { + properties[idx] = propertiesJSONB + } + } + + if graphTarget, err := s.getTargetGraph(ctx); err != nil { + return err + } else if result, err := s.internalConn.Query(ctx, createNodeWithoutIDBatchStatement, graphTarget.ID, kindIDSlices, properties); err != nil { + return err + } else { + result.Close() + } + + return nil +} + +func (s *dawgsDriver) CreateRelationships(ctx context.Context, batch []*graph.Relationship) error { + batchBuilder := newRelationshipCreateBatchBuilder(len(batch)) + + for _, nextRel := range batch { + if err := batchBuilder.Add(ctx, s.schemaManager, nextRel); err != nil { + return err + } + } + + if createBatch, err := batchBuilder.Build(); err != nil { + return err + } else if graphTarget, err := s.getTargetGraph(ctx); err != nil { + return err + } else if result, err := s.internalConn.Query(ctx, createEdgeBatchStatement, graphTarget.ID, createBatch.startIDs, createBatch.endIDs, createBatch.edgeKindIDs, createBatch.edgePropertyBags); err != nil { + return err + } else { + result.Close() + } + + return nil +} + +func (s *dawgsDriver) DeleteNodes(ctx context.Context, batch []graph.ID) error { + if result, err := s.internalConn.Query(ctx, deleteNodeWithIDStatement, batch); err != nil { + return err + } else { + result.Close() + } + + return nil +} + +func (s *dawgsDriver) DeleteRelationships(ctx context.Context, batch []graph.ID) error { + if result, err := s.internalConn.Query(ctx, deleteEdgeWithIDStatement, 
batch); err != nil { + return err + } else { + result.Close() + } + + return nil +} diff --git a/database/pg/config.go b/database/pg/config.go new file mode 100644 index 0000000..7cc620f --- /dev/null +++ b/database/pg/config.go @@ -0,0 +1,7 @@ +package pg + +import "github.com/jackc/pgx/v5/pgxpool" + +type Config struct { + Pool *pgxpool.Pool +} diff --git a/database/pg/database.go b/database/pg/database.go new file mode 100644 index 0000000..a9c4165 --- /dev/null +++ b/database/pg/database.go @@ -0,0 +1,125 @@ +package pg + +import ( + "context" + "fmt" + "log/slog" + + "github.com/specterops/dawgs" + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/database/v1compat" + "github.com/specterops/dawgs/graph" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/specterops/dawgs/util/size" +) + +func beginTx(ctx context.Context, conn *pgxpool.Conn, options []database.Option) (pgx.Tx, error) { + var ( + // Default to read-write + txAccessMode = pgx.ReadWrite + ) + + for _, option := range options { + if option == database.OptionReadOnly { + txAccessMode = pgx.ReadOnly + } + } + + return conn.BeginTx(ctx, pgx.TxOptions{ + AccessMode: txAccessMode, + }) +} + +type instance struct { + pool *pgxpool.Pool + graphQueryMemoryLimit size.Size + schemaManager *SchemaManager +} + +func (s *instance) RefreshKinds(ctx context.Context) error { + return s.schemaManager.Fetch(ctx) +} + +func (s *instance) Raw(ctx context.Context, query string, parameters map[string]any) error { + if acquiredConn, err := s.pool.Acquire(ctx); err != nil { + return err + } else { + defer acquiredConn.Release() + + _, err := acquiredConn.Exec(ctx, query, pgx.NamedArgs(parameters)) + return err + } +} + +func New(internalDriver *pgxpool.Pool, cfg dawgs.Config) v1compat.BackwardCompatibleInstance { + return &instance{ + pool: internalDriver, + graphQueryMemoryLimit: cfg.GraphQueryMemoryLimit, + schemaManager: NewSchemaManager(internalDriver), + } +} + +func (s 
*instance) AssertSchema(ctx context.Context, schema database.Schema) error { + return s.schemaManager.AssertSchema(ctx, schema) +} + +func (s *instance) Session(ctx context.Context, driverLogic database.QueryLogic, options ...database.Option) error { + if acquiredConn, err := s.pool.Acquire(ctx); err != nil { + return err + } else { + defer acquiredConn.Release() + return driverLogic(ctx, newInternalDriver(acquiredConn, s.schemaManager)) + } +} + +func (s *instance) Transaction(ctx context.Context, driverLogic database.QueryLogic, options ...database.Option) error { + if acquiredConn, err := s.pool.Acquire(ctx); err != nil { + return err + } else { + defer acquiredConn.Release() + + if transaction, err := beginTx(ctx, acquiredConn, options); err != nil { + return err + } else { + defer func() { + if err := transaction.Rollback(ctx); err != nil { + slog.DebugContext(ctx, "failed to rollback transaction", slog.String("err", err.Error())) + } + }() + + if err := driverLogic(ctx, newInternalDriver(transaction, s.schemaManager)); err != nil { + return err + } + + return transaction.Commit(ctx) + } + } +} + +func (s *instance) FetchKinds(ctx context.Context) (graph.Kinds, error) { + var ( + kindIDsByKind = s.schemaManager.GetKindIDsByKind() + kinds = make(graph.Kinds, 0, len(kindIDsByKind)) + ) + + for _, kind := range kindIDsByKind { + kinds = append(kinds, kind) + } + + return kinds, nil +} + +func (s *instance) Close(_ context.Context) error { + s.pool.Close() + return nil +} + +func SchemaManagerFromInstance(dbInst database.Instance) (*SchemaManager, error) { + if pgInstance, typeOK := dbInst.(*instance); !typeOK { + return nil, fmt.Errorf("dawgs pg: unable to get schema manager from instance type: %T", dbInst) + } else { + return pgInstance.schemaManager, nil + } +} diff --git a/database/pg/driver.go b/database/pg/driver.go new file mode 100644 index 0000000..b6349cd --- /dev/null +++ b/database/pg/driver.go @@ -0,0 +1,252 @@ +package pg + +import ( + "context" + 
"fmt" + + "github.com/jackc/pgx/v5" + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/pgsql" + "github.com/specterops/dawgs/cypher/models/pgsql/translate" + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/database/pg/model" + "github.com/specterops/dawgs/database/v1compat" + "github.com/specterops/dawgs/graph" +) + +type queryResult struct { + rows pgx.Rows + mapper graph.ValueMapper + values []any +} + +func newQueryResult(mapper graph.ValueMapper, rows pgx.Rows) database.Result { + return &queryResult{ + mapper: mapper, + rows: rows, + } +} + +func (s *queryResult) HasNext(ctx context.Context) bool { + if s.rows.Next() { + if values, err := s.rows.Values(); err != nil { + return false + } else { + s.values = values + return true + } + } + + return false +} + +func (s *queryResult) Values() []any { + return s.values +} + +func (s *queryResult) Scan(scanTargets ...any) error { + if s.values == nil { + return fmt.Errorf("no results to scan to; call HasNext()") + } + + if len(scanTargets) != len(s.values) { + return fmt.Errorf("expected to scan %d values but received %d to map to", len(s.values), len(scanTargets)) + } + + for idx, nextTarget := range scanTargets { + nextValue := s.values[idx] + + if !s.mapper.Map(nextValue, nextTarget) { + return fmt.Errorf("unable to scan type %T into target type %T", nextValue, nextTarget) + } + } + + return nil +} + +func (s *queryResult) Error() error { + return s.rows.Err() +} + +func (s *queryResult) Close(ctx context.Context) error { + s.rows.Close() + return nil +} + +type internalDriver interface { + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) +} + +type translatedQuery struct { + SQL string + Parameters map[string]any +} + +type dawgsDriver struct { + internalConn internalDriver + queryExecMode pgx.QueryExecMode + queryResultFormats pgx.QueryResultFormats + schemaManager *SchemaManager + targetGraph *database.Graph +} + +func (s 
*dawgsDriver) Mapper() graph.ValueMapper { + return newValueMapper(context.TODO(), s.schemaManager) +} + +func newInternalDriver(internalConn internalDriver, schemaManager *SchemaManager) v1compat.BackwardCompatibleDriver { + return &dawgsDriver{ + internalConn: internalConn, + queryExecMode: pgx.QueryExecModeCacheStatement, + queryResultFormats: pgx.QueryResultFormats{pgx.BinaryFormatCode}, + schemaManager: schemaManager, + } +} + +func (s *dawgsDriver) getTargetGraph(ctx context.Context) (model.Graph, error) { + var targetGraph database.Graph + + if s.targetGraph != nil { + targetGraph = *s.targetGraph + } else { + if defaultGraph, hasDefaultGraph := s.schemaManager.DefaultGraph(); !hasDefaultGraph { + return model.Graph{}, fmt.Errorf("no graph target set for operation") + } else { + targetGraph = defaultGraph + } + } + + return s.schemaManager.AssertGraph(ctx, targetGraph) +} + +func (s *dawgsDriver) queryArgs(parameters map[string]any) []any { + queryArgs := []any{s.queryExecMode, s.queryResultFormats} + + if parameters != nil && len(parameters) > 0 { + queryArgs = append(queryArgs, pgx.NamedArgs(parameters)) + } + + return queryArgs +} + +func (s *dawgsDriver) executeTranslated(ctx context.Context, query translatedQuery) database.Result { + if internalResult, err := s.internalConn.Query(ctx, query.SQL, s.queryArgs(query.Parameters)...); err != nil { + return database.NewErrorResult(err) + } else { + return newQueryResult(newValueMapper(ctx, s.schemaManager), internalResult) + } +} + +func (s *dawgsDriver) translateCypherToPGSQL(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) (translatedQuery, error) { + if translated, err := translate.Translate(ctx, query, s.schemaManager, parameters); err != nil { + return translatedQuery{}, err + } else if sqlQuery, err := translate.Translated(translated); err != nil { + return translatedQuery{}, err + } else { + return translatedQuery{ + SQL: sqlQuery, + Parameters: parameters, + }, nil + } +} + 
+func (s *dawgsDriver) Exec(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) database.Result { + if translated, err := s.translateCypherToPGSQL(ctx, query, parameters); err != nil { + return database.NewErrorResult(err) + } else { + return s.executeTranslated(ctx, translated) + } +} + +func (s *dawgsDriver) Explain(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) database.Result { + if translated, err := s.translateCypherToPGSQL(ctx, query, parameters); err != nil { + return database.NewErrorResult(err) + } else { + translated.SQL = "explain " + translated.SQL + return s.executeTranslated(ctx, translated) + } +} + +func (s *dawgsDriver) Profile(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) database.Result { + if translated, err := s.translateCypherToPGSQL(ctx, query, parameters); err != nil { + return database.NewErrorResult(err) + } else { + translated.SQL = "explain (verbose, analyze, costs, buffers, format json) " + translated.SQL + return s.executeTranslated(ctx, translated) + } +} + +func (s *dawgsDriver) WithGraph(targetGraph database.Graph) database.Driver { + s.targetGraph = &targetGraph + return s +} + +func (s *dawgsDriver) CreateRelationship(ctx context.Context, relationship *graph.Relationship) (graph.ID, error) { + if targetGraph, err := s.getTargetGraph(ctx); err != nil { + return 0, err + } else if kindIDSlice, err := s.schemaManager.AssertKinds(ctx, graph.Kinds{relationship.Kind}); err != nil { + return 0, err + } else if propertiesJSONB, err := pgsql.PropertiesToJSONB(relationship.Properties); err != nil { + return 0, err + } else { + var ( + newEdgeID graph.ID + result = s.executeTranslated(ctx, translatedQuery{ + SQL: createEdgeStatement, + Parameters: map[string]any{ + "graph_id": targetGraph.ID, + "start_id": relationship.StartID, + "end_id": relationship.EndID, + "kind_id": kindIDSlice[0], + "properties": propertiesJSONB, + }, + }) + ) + + defer 
result.Close(ctx) + + if !result.HasNext(ctx) { + return 0, graph.ErrNoResultsFound + } + + if err := result.Scan(&newEdgeID); err != nil { + return 0, err + } + + return newEdgeID, result.Error() + } +} + +func (s *dawgsDriver) CreateNode(ctx context.Context, node *graph.Node) (graph.ID, error) { + if targetGraph, err := s.getTargetGraph(ctx); err != nil { + return 0, err + } else if kindIDSlice, err := s.schemaManager.AssertKinds(ctx, node.Kinds); err != nil { + return 0, err + } else if propertiesJSONB, err := pgsql.PropertiesToJSONB(node.Properties); err != nil { + return 0, err + } else { + var ( + newNodeID graph.ID + result = s.executeTranslated(ctx, translatedQuery{ + SQL: createNodeStatement, + Parameters: map[string]any{ + "graph_id": targetGraph.ID, + "kind_ids": kindIDSlice, + "properties": propertiesJSONB, + }, + }) + ) + + defer result.Close(ctx) + + if !result.HasNext(ctx) { + return 0, graph.ErrNoResultsFound + } + + if err := result.Scan(&newNodeID); err != nil { + return 0, err + } + + return newNodeID, result.Error() + } +} diff --git a/drivers/pg/facts.go b/database/pg/error.go similarity index 100% rename from drivers/pg/facts.go rename to database/pg/error.go diff --git a/drivers/pg/manager.go b/database/pg/manager.go similarity index 59% rename from drivers/pg/manager.go rename to database/pg/manager.go index ded8be6..3759301 100644 --- a/drivers/pg/manager.go +++ b/database/pg/manager.go @@ -4,13 +4,16 @@ import ( "context" "errors" "fmt" + "log/slog" "strings" "sync" + "github.com/specterops/dawgs/database" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" - "github.com/specterops/dawgs/drivers/pg/model" - "github.com/specterops/dawgs/drivers/pg/query" + "github.com/specterops/dawgs/database/pg/model" + "github.com/specterops/dawgs/database/pg/query" "github.com/specterops/dawgs/graph" ) @@ -22,17 +25,17 @@ type KindMapper interface { AssertKinds(ctx context.Context, kinds graph.Kinds) ([]int16, error) } -func 
KindMapperFromGraphDatabase(graphDB graph.Database) (KindMapper, error) { +func KindMapperFromGraphDatabase(graphDB database.Instance) (KindMapper, error) { switch typedGraphDB := graphDB.(type) { - case *Driver: - return typedGraphDB.SchemaManager, nil + case *instance: + return typedGraphDB.schemaManager, nil default: return nil, fmt.Errorf("unsupported graph database type: %T", typedGraphDB) } } type SchemaManager struct { - defaultGraph model.Graph + defaultGraph database.Graph pool *pgxpool.Pool hasDefaultGraph bool graphs map[string]model.Graph @@ -52,35 +55,40 @@ func NewSchemaManager(pool *pgxpool.Pool) *SchemaManager { } } -func (s *SchemaManager) WriteTransaction(ctx context.Context, txDelegate graph.TransactionDelegate, options ...graph.TransactionOption) error { - if cfg, err := renderConfig(batchWriteSize, readWriteTxOptions, options); err != nil { - return err - } else if conn, err := s.pool.Acquire(ctx); err != nil { +func (s *SchemaManager) transaction(ctx context.Context, transactionLogic func(transaction pgx.Tx) error) error { + if acquiredConn, err := s.pool.Acquire(ctx); err != nil { return err } else { - defer conn.Release() + defer acquiredConn.Release() - if tx, err := newTransactionWrapper(ctx, conn, s, cfg, true); err != nil { + if transaction, err := acquiredConn.BeginTx(ctx, pgx.TxOptions{ + AccessMode: pgx.ReadWrite, + }); err != nil { return err } else { - defer tx.Close() + defer func() { + if err := transaction.Rollback(ctx); err != nil && !errors.Is(err, pgx.ErrTxClosed) { + slog.DebugContext(ctx, "failed to rollback transaction", slog.String("err", err.Error())) + } + }() - if err := txDelegate(tx); err != nil { + if err := transactionLogic(transaction); err != nil { return err } - return tx.Commit() + return transaction.Commit(ctx) } } } -func (s *SchemaManager) fetch(tx graph.Transaction) error { - if kinds, err := query.On(tx).SelectKinds(); err != nil { +func (s *SchemaManager) fetch(ctx context.Context, tx pgx.Tx) error { + if 
kinds, err := query.On(tx).SelectKinds(ctx); err != nil { return err } else { s.kindsByID = kinds for kind, kindID := range s.kindsByID { + s.kindsByID[kind] = kindID s.kindIDsByKind[kindID] = kind } } @@ -91,18 +99,31 @@ func (s *SchemaManager) fetch(tx graph.Transaction) error { func (s *SchemaManager) GetKindIDsByKind() map[int16]graph.Kind { s.lock.RLock() defer s.lock.RUnlock() - return s.kindIDsByKind + + kindIDsByKindCopy := make(map[int16]graph.Kind, len(s.kindIDsByKind)) + + for k, v := range s.kindIDsByKind { + kindIDsByKindCopy[k] = v + } + + return kindIDsByKindCopy } func (s *SchemaManager) Fetch(ctx context.Context) error { - return s.WriteTransaction(ctx, func(tx graph.Transaction) error { - return s.fetch(tx) - }, OptionSetQueryExecMode(pgx.QueryExecModeSimpleProtocol)) + s.lock.Lock() + defer s.lock.Unlock() + + clear(s.kindIDsByKind) + clear(s.kindsByID) + + return s.transaction(ctx, func(transaction pgx.Tx) error { + return s.fetch(ctx, transaction) + }) } -func (s *SchemaManager) defineKinds(tx graph.Transaction, kinds graph.Kinds) error { +func (s *SchemaManager) defineKinds(ctx context.Context, tx pgx.Tx, kinds graph.Kinds) error { for _, kind := range kinds { - if kindID, err := query.On(tx).InsertOrGetKind(kind); err != nil { + if kindID, err := query.On(tx).InsertOrGetKind(ctx, kind); err != nil { return err } else { s.kindsByID[kind] = kindID @@ -139,13 +160,14 @@ func (s *SchemaManager) MapKind(ctx context.Context, kind graph.Kind) (int16, er } s.lock.RUnlock() - s.lock.Lock() - defer s.lock.Unlock() if err := s.Fetch(ctx); err != nil { return -1, err } + s.lock.RLock() + defer s.lock.RUnlock() + if id, hasID := s.kindsByID[kind]; hasID { return id, nil } else { @@ -162,36 +184,20 @@ func (s *SchemaManager) MapKinds(ctx context.Context, kinds graph.Kinds) ([]int1 } s.lock.RUnlock() - s.lock.Lock() - defer s.lock.Unlock() if err := s.Fetch(ctx); err != nil { return nil, err } + s.lock.RLock() + defer s.lock.RUnlock() + if mappedKinds, 
missingKinds := s.mapKinds(kinds); len(missingKinds) == 0 { return mappedKinds, nil } else { return nil, fmt.Errorf("unable to map kinds: %s", strings.Join(missingKinds.Strings(), ", ")) } } -func (s *SchemaManager) ReadTransaction(ctx context.Context, txDelegate graph.TransactionDelegate, options ...graph.TransactionOption) error { - if cfg, err := renderConfig(batchWriteSize, readOnlyTxOptions, options); err != nil { - return err - } else if conn, err := s.pool.Acquire(ctx); err != nil { - return err - } else { - defer conn.Release() - - return txDelegate(&transaction{ - schemaManager: s, - queryExecMode: cfg.QueryExecMode, - ctx: ctx, - conn: conn, - targetSchemaSet: false, - }) - } -} func (s *SchemaManager) mapKindIDs(kindIDs []int16) (graph.Kinds, []int16) { var ( @@ -227,13 +233,14 @@ func (s *SchemaManager) MapKindIDs(ctx context.Context, kindIDs []int16) (graph. } s.lock.RUnlock() - s.lock.Lock() - defer s.lock.Unlock() if err := s.Fetch(ctx); err != nil { return nil, err } + s.lock.RLock() + defer s.lock.RUnlock() + if kinds, missingKinds := s.mapKindIDs(kindIDs); len(missingKinds) == 0 { return kinds, nil } else { @@ -249,9 +256,11 @@ func (s *SchemaManager) assertKinds(ctx context.Context, kinds graph.Kinds) ([]i // We have to re-acquire the missing kinds since there's a potential for another writer to acquire the write-lock // in between release of the read-lock and acquisition of the write-lock for this operation if _, missingKinds := s.mapKinds(kinds); len(missingKinds) > 0 { - if err := s.WriteTransaction(ctx, func(tx graph.Transaction) error { - return s.defineKinds(tx, missingKinds) - }, OptionSetQueryExecMode(pgx.QueryExecModeSimpleProtocol)); err != nil { + if err := s.transaction(ctx, func(transaction pgx.Tx) error { + // Previously calls like this required - pgx.QueryExecModeSimpleProtocol while that seems to no longer be + // the case, this comment has been left here in case the issue reappears + return s.defineKinds(ctx, transaction, 
missingKinds) + }); err != nil { return nil, err } } @@ -276,80 +285,84 @@ func (s *SchemaManager) AssertKinds(ctx context.Context, kinds graph.Kinds) ([]i return s.assertKinds(ctx, kinds) } -func (s *SchemaManager) setDefaultGraph(defaultGraph model.Graph, schema graph.Graph) { - s.lock.Lock() - defer s.lock.Unlock() +func (s *SchemaManager) DefaultGraph() (database.Graph, bool) { + s.lock.RLock() + defer s.lock.RUnlock() - if s.hasDefaultGraph { - // Another actor has already asserted or otherwise set a default graph - return - } + return s.defaultGraph, s.hasDefaultGraph +} - s.graphs[schema.Name] = defaultGraph +func (s *SchemaManager) assertGraph(ctx context.Context, schema database.Graph) (model.Graph, error) { + var assertedGraph model.Graph - s.defaultGraph = defaultGraph - s.hasDefaultGraph = true -} + if err := s.transaction(ctx, func(transaction pgx.Tx) error { + queries := query.On(transaction) -func (s *SchemaManager) SetDefaultGraph(ctx context.Context, schema graph.Graph) error { - return s.ReadTransaction(ctx, func(tx graph.Transaction) error { // Validate the schema if the graph already exists in the database - if graphModel, err := query.On(tx).SelectGraphByName(schema.Name); err != nil { - return err - } else { - s.setDefaultGraph(graphModel, schema) - return nil - } - }) -} + if definition, err := queries.SelectGraphByName(ctx, schema.Name); err != nil { + // ErrNoRows is ignored as it signifies that this graph must be created + if !errors.Is(err, pgx.ErrNoRows) { + return err + } -func (s *SchemaManager) AssertDefaultGraph(ctx context.Context, schema graph.Graph) error { - return s.WriteTransaction(ctx, func(tx graph.Transaction) error { - if graphModel, err := s.AssertGraph(tx, schema); err != nil { + if newDefinition, err := queries.CreateGraph(ctx, schema); err != nil { + return err + } else { + assertedGraph = newDefinition + } + } else if assertedDefinition, err := queries.AssertGraph(ctx, schema, definition); err != nil { return err } 
else { - s.setDefaultGraph(graphModel, schema) + // Graph exists and may have been updated + assertedGraph = assertedDefinition } return nil - }) -} - -func (s *SchemaManager) DefaultGraph() (model.Graph, bool) { - s.lock.RLock() - defer s.lock.RUnlock() + }); err != nil { + return model.Graph{}, err + } - return s.defaultGraph, s.hasDefaultGraph + // Cache the graph definition and return it + s.graphs[schema.Name] = assertedGraph + return assertedGraph, nil } -func (s *SchemaManager) assertGraph(tx graph.Transaction, schema graph.Graph) (model.Graph, error) { - var assertedGraph model.Graph +func (s *SchemaManager) assertSchema(ctx context.Context, schema database.Schema) error { + if defaultGraph, hasDefaultGraph := schema.DefaultGraph(); !hasDefaultGraph { + return fmt.Errorf("no default graph specified in schema") + } else { + s.defaultGraph = defaultGraph + s.hasDefaultGraph = true + } - // Validate the schema if the graph already exists in the database - if definition, err := query.On(tx).SelectGraphByName(schema.Name); err != nil { - // ErrNoRows is ignored as it signifies that this graph must be created - if !errors.Is(err, pgx.ErrNoRows) { - return model.Graph{}, err + return s.transaction(ctx, func(transaction pgx.Tx) error { + if err := query.On(transaction).CreateSchema(ctx); err != nil { + return err } - if newDefinition, err := query.On(tx).CreateGraph(schema); err != nil { - return model.Graph{}, err - } else { - assertedGraph = newDefinition + if err := s.fetch(ctx, transaction); err != nil { + return err } - } else if assertedDefinition, err := query.On(tx).AssertGraph(schema, definition); err != nil { - return model.Graph{}, err - } else { - // Graph existed and may have been updated - assertedGraph = assertedDefinition - } - // Cache the graph definition and return it - s.graphs[schema.Name] = assertedGraph - return assertedGraph, nil + for _, graphSchema := range schema.GraphSchemas { + if _, missingKinds := s.mapKinds(graphSchema.Nodes); 
len(missingKinds) > 0 { + if err := s.defineKinds(ctx, transaction, missingKinds); err != nil { + return err + } + } + + if _, missingKinds := s.mapKinds(graphSchema.Edges); len(missingKinds) > 0 { + if err := s.defineKinds(ctx, transaction, missingKinds); err != nil { + return err + } + } + } + + return nil + }) } -func (s *SchemaManager) AssertGraph(tx graph.Transaction, schema graph.Graph) (model.Graph, error) { +func (s *SchemaManager) AssertGraph(ctx context.Context, schema database.Graph) (model.Graph, error) { // Acquire a read-lock first to fast-pass validate if we're missing the graph definitions s.lock.RLock() @@ -370,40 +383,14 @@ func (s *SchemaManager) AssertGraph(tx graph.Transaction, schema graph.Graph) (m return graphInstance, nil } - return s.assertGraph(tx, schema) -} - -func (s *SchemaManager) assertSchema(tx graph.Transaction, schema graph.Schema) error { - if err := query.On(tx).CreateSchema(); err != nil { - return err - } - - if err := s.fetch(tx); err != nil { - return err - } - - for _, graphSchema := range schema.Graphs { - if _, missingKinds := s.mapKinds(graphSchema.Nodes); len(missingKinds) > 0 { - if err := s.defineKinds(tx, missingKinds); err != nil { - return err - } - } - - if _, missingKinds := s.mapKinds(graphSchema.Edges); len(missingKinds) > 0 { - if err := s.defineKinds(tx, missingKinds); err != nil { - return err - } - } - } - - return nil + return s.assertGraph(ctx, schema) } -func (s *SchemaManager) AssertSchema(ctx context.Context, schema graph.Schema) error { +func (s *SchemaManager) AssertSchema(ctx context.Context, schema database.Schema) error { s.lock.Lock() defer s.lock.Unlock() - return s.WriteTransaction(ctx, func(tx graph.Transaction) error { - return s.assertSchema(tx, schema) - }, OptionSetQueryExecMode(pgx.QueryExecModeSimpleProtocol)) + // Previously calls like this required - pgx.QueryExecModeSimpleProtocol while that seems to no longer be + // the case, this comment has been left here in case the issue 
reappears + return s.assertSchema(ctx, schema) } diff --git a/drivers/pg/mapper.go b/database/pg/mapper.go similarity index 90% rename from drivers/pg/mapper.go rename to database/pg/mapper.go index bde8659..181d221 100644 --- a/drivers/pg/mapper.go +++ b/database/pg/mapper.go @@ -8,7 +8,7 @@ import ( func mapKinds(ctx context.Context, kindMapper KindMapper, untypedValue any) (graph.Kinds, bool) { var ( - // Default assumption is that the untyped value contains a type that can be mapped from + // The default assumption is that the untyped value contains a type that can be mapped from validType = true kindIDs []int16 ) @@ -20,6 +20,10 @@ func mapKinds(ctx context.Context, kindMapper KindMapper, untypedValue any) (gra for idx, untypedElement := range typedValue { if typedElement, typeOK := untypedElement.(int16); typeOK { kindIDs[idx] = typedElement + } else { + // Type assertion failed, mark as invalid type + validType = false + break } } @@ -99,6 +103,6 @@ func newMapFunc(ctx context.Context, kindMapper KindMapper) graph.MapFunc { } } -func NewValueMapper(ctx context.Context, kindMapper KindMapper) graph.ValueMapper { +func newValueMapper(ctx context.Context, kindMapper KindMapper) graph.ValueMapper { return graph.NewValueMapper(newMapFunc(ctx, kindMapper)) } diff --git a/drivers/pg/model/format.go b/database/pg/model/format.go similarity index 83% rename from drivers/pg/model/format.go rename to database/pg/model/format.go index 89fe23b..7759f60 100644 --- a/drivers/pg/model/format.go +++ b/database/pg/model/format.go @@ -4,7 +4,7 @@ import ( "strconv" "strings" - "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/database" ) const ( @@ -24,7 +24,7 @@ func EdgePartitionTableName(graphID int32) string { return partitionTableName(EdgeTable, graphID) } -func IndexName(table string, index graph.Index) string { +func IndexName(table string, index database.Index) string { stringBuilder := strings.Builder{} stringBuilder.WriteString(table) @@ -35,7 +35,7 
@@ func IndexName(table string, index graph.Index) string { return stringBuilder.String() } -func ConstraintName(table string, constraint graph.Constraint) string { +func ConstraintName(table string, constraint database.Constraint) string { stringBuilder := strings.Builder{} stringBuilder.WriteString(table) diff --git a/database/pg/model/model.go b/database/pg/model/model.go new file mode 100644 index 0000000..0d41f15 --- /dev/null +++ b/database/pg/model/model.go @@ -0,0 +1,68 @@ +package model + +import ( + "github.com/specterops/dawgs/database" +) + +type IndexChangeSet struct { + NodeIndexesToRemove []string + EdgeIndexesToRemove []string + NodeConstraintsToRemove []string + EdgeConstraintsToRemove []string + NodeIndexesToAdd map[string]database.Index + EdgeIndexesToAdd map[string]database.Index + NodeConstraintsToAdd map[string]database.Constraint + EdgeConstraintsToAdd map[string]database.Constraint +} + +func NewIndexChangeSet() IndexChangeSet { + return IndexChangeSet{ + NodeIndexesToAdd: map[string]database.Index{}, + NodeConstraintsToAdd: map[string]database.Constraint{}, + EdgeIndexesToAdd: map[string]database.Index{}, + EdgeConstraintsToAdd: map[string]database.Constraint{}, + } +} + +type GraphPartition struct { + Name string + Indexes map[string]database.Index + Constraints map[string]database.Constraint +} + +func NewGraphPartition(name string) GraphPartition { + return GraphPartition{ + Name: name, + Indexes: map[string]database.Index{}, + Constraints: map[string]database.Constraint{}, + } +} + +func NewGraphPartitionFromSchema(name string, indexes []database.Index, constraints []database.Constraint) GraphPartition { + graphPartition := GraphPartition{ + Name: name, + Indexes: make(map[string]database.Index, len(indexes)), + Constraints: make(map[string]database.Constraint, len(constraints)), + } + + for _, index := range indexes { + graphPartition.Indexes[IndexName(name, index)] = index + } + + for _, constraint := range constraints { + 
graphPartition.Constraints[ConstraintName(name, constraint)] = constraint + } + + return graphPartition +} + +type GraphPartitions struct { + Node GraphPartition + Edge GraphPartition +} + +type Graph struct { + ID int32 + Name string + Partitions GraphPartitions +} diff --git a/drivers/pg/pg.go b/database/pg/pg.go similarity index 86% rename from drivers/pg/pg.go rename to database/pg/pg.go index 57dafeb..b586743 100644 --- a/drivers/pg/pg.go +++ b/database/pg/pg.go @@ -10,7 +10,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/specterops/dawgs" "github.com/specterops/dawgs/cypher/models/pgsql" - "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/database" ) const ( @@ -82,7 +82,17 @@ func NewPool(connectionString string) (*pgxpool.Pool, error) { } func init() { - dawgs.Register(DriverName, func(ctx context.Context, cfg dawgs.Config) (graph.Database, error) { - return NewDriver(cfg.Pool), nil + dawgs.Register(DriverName, func(ctx context.Context, cfg dawgs.Config) (database.Instance, error) { + if cfg.DriverConfig != nil { + if pgConfig, typeOK := cfg.DriverConfig.(Config); typeOK && pgConfig.Pool != nil { + return New(pgConfig.Pool, cfg), nil + } + } + + if pgxPool, err := NewPool(cfg.ConnectionString); err != nil { + return nil, err + } else { + return New(pgxPool, cfg), nil + } }) } diff --git a/drivers/pg/pgutil/kindmapper.go b/database/pg/pgutil/kindmapper.go similarity index 100% rename from drivers/pg/pgutil/kindmapper.go rename to database/pg/pgutil/kindmapper.go diff --git a/drivers/pg/query/definitions.go b/database/pg/query/definitions.go similarity index 100% rename from drivers/pg/query/definitions.go rename to database/pg/query/definitions.go diff --git a/drivers/pg/query/format.go b/database/pg/query/format.go similarity index 93% rename from drivers/pg/query/format.go rename to database/pg/query/format.go index db1ed24..63f2692 100644 --- a/drivers/pg/query/format.go +++ b/database/pg/query/format.go @@ -5,29 +5,31 @@ 
import ( "strconv" "strings" - "github.com/specterops/dawgs/drivers/pg/model" + "github.com/specterops/dawgs/database" + + "github.com/specterops/dawgs/database/pg/model" "github.com/specterops/dawgs/graph" ) -func postgresIndexType(indexType graph.IndexType) string { +func postgresIndexType(indexType database.IndexType) string { switch indexType { - case graph.BTreeIndex: + case database.IndexTypeBTree: return pgIndexTypeBTree - case graph.TextSearchIndex: + case database.IndexTypeTextSearch: return pgIndexTypeGIN default: return "NOT SUPPORTED" } } -func parsePostgresIndexType(pgType string) graph.IndexType { +func parsePostgresIndexType(pgType string) database.IndexType { switch strings.ToLower(pgType) { case pgIndexTypeBTree: - return graph.BTreeIndex + return database.IndexTypeBTree case pgIndexTypeGIN: - return graph.TextSearchIndex + return database.IndexTypeTextSearch default: - return graph.UnsupportedIndex + return database.IndexTypeUnsupported } } @@ -43,21 +45,21 @@ func formatDropPropertyConstraint(constraintName string) string { return join("drop index if exists ", constraintName, ";") } -func formatCreatePropertyConstraint(constraintName, tableName, fieldName string, indexType graph.IndexType) string { +func formatCreatePropertyConstraint(constraintName, tableName, fieldName string, indexType database.IndexType) string { pgIndexType := postgresIndexType(indexType) return join("create unique index ", constraintName, " on ", tableName, " using ", pgIndexType, " ((", tableName, ".", pgPropertiesColumn, " ->> '", fieldName, "'));") } -func formatCreatePropertyIndex(indexName, tableName, fieldName string, indexType graph.IndexType) string { +func formatCreatePropertyIndex(indexName, tableName, fieldName string, indexType database.IndexType) string { var ( pgIndexType = postgresIndexType(indexType) queryPartial = join("create index ", indexName, " on ", tableName, " using ", pgIndexType, " ((", tableName, ".", pgPropertiesColumn, " ->> '", fieldName) ) - 
if indexType == graph.TextSearchIndex { + if indexType == database.IndexTypeTextSearch { // GIN text search requires the column to be typed and to contain the tri-gram operation extension return join(queryPartial, "'::text) gin_trgm_ops);") } else { diff --git a/database/pg/query/query.go b/database/pg/query/query.go new file mode 100644 index 0000000..ae3750c --- /dev/null +++ b/database/pg/query/query.go @@ -0,0 +1,514 @@ +package query + +import ( + "context" + _ "embed" + "fmt" + "strings" + + "github.com/specterops/dawgs/database" + + "github.com/jackc/pgx/v5" + "github.com/specterops/dawgs/database/pg/model" + "github.com/specterops/dawgs/graph" +) + +type Query struct { + tx pgx.Tx +} + +func On(tx pgx.Tx) Query { + return Query{ + tx: tx, + } +} + +func (s Query) exec(ctx context.Context, statement string, args map[string]any, queryArgs ...any) error { + if args != nil && len(args) > 0 { + queryArgs = append(queryArgs, args) + } + + _, err := s.tx.Exec(ctx, statement, queryArgs...) + return err +} + +func (s Query) query(ctx context.Context, statement string, args map[string]any, queryArgs ...any) (pgx.Rows, error) { + if args != nil && len(args) > 0 { + queryArgs = append(queryArgs, pgx.NamedArgs(args)) + } + + return s.tx.Query(ctx, statement, queryArgs...) 
+} + +func (s Query) describeGraphPartition(ctx context.Context, name string) (model.GraphPartition, error) { + graphPartition := model.NewGraphPartition(name) + + if tableIndexDefinitions, err := s.SelectTableIndexDefinitions(ctx, name); err != nil { + return graphPartition, err + } else { + for _, tableIndexDefinition := range tableIndexDefinitions { + if captureGroups := pgPropertyIndexRegex.FindStringSubmatch(tableIndexDefinition); captureGroups == nil { + // If this index does not match our expected column index format then report it as a potential error + if !pgColumnIndexRegex.MatchString(tableIndexDefinition) { + return graphPartition, fmt.Errorf("regex mis-match on schema definition: %s", tableIndexDefinition) + } + } else { + indexName := captureGroups[pgIndexRegexGroupName] + + if captureGroups[pgIndexRegexGroupUnique] == pgIndexUniqueStr { + graphPartition.Constraints[indexName] = database.Constraint{ + Name: indexName, + Field: captureGroups[pgIndexRegexGroupFields], + Type: parsePostgresIndexType(captureGroups[pgIndexRegexGroupIndexType]), + } + } else { + graphPartition.Indexes[indexName] = database.Index{ + Name: indexName, + Field: captureGroups[pgIndexRegexGroupFields], + Type: parsePostgresIndexType(captureGroups[pgIndexRegexGroupIndexType]), + } + } + } + } + } + + return graphPartition, nil +} + +func (s Query) SelectKinds(ctx context.Context) (map[graph.Kind]int16, error) { + var ( + kindID int16 + kindName string + + kinds = map[graph.Kind]int16{} + ) + + if result, err := s.query(ctx, sqlSelectKinds, nil); err != nil { + return nil, err + } else { + defer result.Close() + + for result.Next() { + if err := result.Scan(&kindID, &kindName); err != nil { + return nil, err + } + + kinds[graph.StringKind(kindName)] = kindID + } + + return kinds, result.Err() + } +} + +func (s Query) selectGraphPartitions(ctx context.Context, graphID int32) (model.GraphPartitions, error) { + var ( + nodePartitionName = model.NodePartitionTableName(graphID) + 
edgePartitionName = model.EdgePartitionTableName(graphID) + ) + + if nodePartition, err := s.describeGraphPartition(ctx, nodePartitionName); err != nil { + return model.GraphPartitions{}, err + } else if edgePartition, err := s.describeGraphPartition(ctx, edgePartitionName); err != nil { + return model.GraphPartitions{}, err + } else { + return model.GraphPartitions{ + Node: nodePartition, + Edge: edgePartition, + }, nil + } +} + +func (s Query) selectGraphPartialByName(ctx context.Context, name string) (model.Graph, error) { + var graphID int32 + + if result, err := s.query(ctx, sqlSelectGraphByName, pgx.NamedArgs(map[string]any{ + "name": name, + })); err != nil { + return model.Graph{}, err + } else { + defer result.Close() + + if !result.Next() { + if err := result.Err(); err != nil { + return model.Graph{}, err + } + + return model.Graph{}, pgx.ErrNoRows + } + + if err := result.Scan(&graphID); err != nil { + return model.Graph{}, err + } + + return model.Graph{ + ID: graphID, + Name: name, + }, result.Err() + } +} + +func (s Query) SelectGraphByName(ctx context.Context, name string) (model.Graph, error) { + if definition, err := s.selectGraphPartialByName(ctx, name); err != nil { + return model.Graph{}, err + } else if graphPartitions, err := s.selectGraphPartitions(ctx, definition.ID); err != nil { + return model.Graph{}, err + } else { + definition.Partitions = graphPartitions + return definition, nil + } +} + +func (s Query) selectGraphPartials(ctx context.Context) ([]model.Graph, error) { + var ( + graphID int32 + graphName string + graphs []model.Graph + ) + + if result, err := s.query(ctx, sqlSelectGraphs, nil); err != nil { + return nil, err + } else { + defer result.Close() + + for result.Next() { + if err := result.Scan(&graphID, &graphName); err != nil { + return nil, err + } else { + graphs = append(graphs, model.Graph{ + ID: graphID, + Name: graphName, + }) + } + } + + return graphs, result.Err() + } +} + +func (s Query) SelectGraphs(ctx 
context.Context) (map[string]model.Graph, error) { + if definitions, err := s.selectGraphPartials(ctx); err != nil { + return nil, err + } else { + indexed := map[string]model.Graph{} + + for _, definition := range definitions { + if graphPartitions, err := s.selectGraphPartitions(ctx, definition.ID); err != nil { + return nil, err + } else { + definition.Partitions = graphPartitions + indexed[definition.Name] = definition + } + } + + return indexed, nil + } +} + +func (s Query) CreatePropertyIndex(ctx context.Context, indexName, tableName, fieldName string, indexType database.IndexType) error { + return s.exec(ctx, formatCreatePropertyIndex(indexName, tableName, fieldName, indexType), nil) +} + +func (s Query) CreatePropertyConstraint(ctx context.Context, indexName, tableName, fieldName string, indexType database.IndexType) error { + if indexType != database.IndexTypeBTree { + return fmt.Errorf("only b-tree indexing is supported for property constraints") + } + + return s.exec(ctx, formatCreatePropertyConstraint(indexName, tableName, fieldName, indexType), nil) +} + +func (s Query) DropIndex(ctx context.Context, indexName string) error { + return s.exec(ctx, formatDropPropertyIndex(indexName), nil) +} + +func (s Query) DropConstraint(ctx context.Context, constraintName string) error { + return s.exec(ctx, formatDropPropertyConstraint(constraintName), nil) +} + +func (s Query) CreateSchema(ctx context.Context) error { + if err := s.exec(ctx, sqlSchemaUp, nil); err != nil { + return err + } + + return nil +} + +func (s Query) DropSchema(ctx context.Context) error { + if err := s.exec(ctx, sqlSchemaDown, nil); err != nil { + return err + } + + return nil +} + +func (s Query) insertGraph(ctx context.Context, name string) (model.Graph, error) { + var graphID int32 + + if result, err := s.query(ctx, sqlInsertGraph, map[string]any{ + "name": name, + }); err != nil { + return model.Graph{}, err + } else { + defer result.Close() + + if !result.Next() { + if err := 
result.Err(); err != nil { + return model.Graph{}, err + } + + return model.Graph{}, pgx.ErrNoRows + } + + if err := result.Scan(&graphID); err != nil { + return model.Graph{}, fmt.Errorf("failed mapping ID from graph entry creation: %w", err) + } + + return model.Graph{ + ID: graphID, + Name: name, + }, nil + } +} + +func (s Query) CreatePartitionTable(ctx context.Context, name, parent string, graphID int32) (model.GraphPartition, error) { + if err := s.exec(ctx, formatCreatePartitionTable(name, parent, graphID), nil); err != nil { + return model.GraphPartition{}, err + } + + return model.GraphPartition{ + Name: name, + }, nil +} + +func (s Query) SelectTableIndexDefinitions(ctx context.Context, tableName string) ([]string, error) { + var ( + nextDefinition string + definitions []string + ) + + if result, err := s.query(ctx, sqlSelectTableIndexes, map[string]any{ + "tablename": tableName, + }); err != nil { + return nil, err + } else { + + defer result.Close() + + for result.Next() { + if err := result.Scan(&nextDefinition); err != nil { + return nil, err + } + + definitions = append(definitions, strings.ToLower(nextDefinition)) + } + + return definitions, result.Err() + } +} + +func (s Query) SelectKindID(ctx context.Context, kind graph.Kind) (int16, error) { + var kindID int16 + + if result, err := s.query(ctx, sqlSelectKindID, map[string]any{ + "name": kind.String(), + }); err != nil { + return -1, err + } else { + defer result.Close() + + if !result.Next() { + if err := result.Err(); err != nil { + return -1, err + } + + return -1, pgx.ErrNoRows + } + + if err := result.Scan(&kindID); err != nil { + return -1, err + } + + return kindID, result.Err() + } +} + +func (s Query) assertGraphPartitionIndexes(ctx context.Context, partitions model.GraphPartitions, indexChanges model.IndexChangeSet) error { + for _, indexToRemove := range append(indexChanges.NodeIndexesToRemove, indexChanges.EdgeIndexesToRemove...) 
{ + if err := s.DropIndex(ctx, indexToRemove); err != nil { + return err + } + } + + for _, constraintToRemove := range append(indexChanges.NodeConstraintsToRemove, indexChanges.EdgeConstraintsToRemove...) { + if err := s.DropConstraint(ctx, constraintToRemove); err != nil { + return err + } + } + + for indexName, index := range indexChanges.NodeIndexesToAdd { + if err := s.CreatePropertyIndex(ctx, indexName, partitions.Node.Name, index.Field, index.Type); err != nil { + return err + } + } + + for constraintName, constraint := range indexChanges.NodeConstraintsToAdd { + if err := s.CreatePropertyConstraint(ctx, constraintName, partitions.Node.Name, constraint.Field, constraint.Type); err != nil { + return err + } + } + + for indexName, index := range indexChanges.EdgeIndexesToAdd { + if err := s.CreatePropertyIndex(ctx, indexName, partitions.Edge.Name, index.Field, index.Type); err != nil { + return err + } + } + + for constraintName, constraint := range indexChanges.EdgeConstraintsToAdd { + if err := s.CreatePropertyConstraint(ctx, constraintName, partitions.Edge.Name, constraint.Field, constraint.Type); err != nil { + return err + } + } + + return nil +} + +func (s Query) AssertGraph(ctx context.Context, schema database.Graph, definition model.Graph) (model.Graph, error) { + var ( + requiredNodePartition = model.NewGraphPartitionFromSchema(definition.Partitions.Node.Name, schema.NodeIndexes, schema.NodeConstraints) + requiredEdgePartition = model.NewGraphPartitionFromSchema(definition.Partitions.Edge.Name, schema.EdgeIndexes, schema.EdgeConstraints) + indexChangeSet = model.NewIndexChangeSet() + ) + + if presentNodePartition, err := s.describeGraphPartition(ctx, definition.Partitions.Node.Name); err != nil { + return model.Graph{}, err + } else { + for presentNodeIndexName := range presentNodePartition.Indexes { + if _, hasMatchingDefinition := requiredNodePartition.Indexes[presentNodeIndexName]; !hasMatchingDefinition { + indexChangeSet.NodeIndexesToRemove = 
append(indexChangeSet.NodeIndexesToRemove, presentNodeIndexName) + } + } + + for presentNodeConstraintName := range presentNodePartition.Constraints { + if _, hasMatchingDefinition := requiredNodePartition.Constraints[presentNodeConstraintName]; !hasMatchingDefinition { + indexChangeSet.NodeConstraintsToRemove = append(indexChangeSet.NodeConstraintsToRemove, presentNodeConstraintName) + } + } + + for requiredNodeIndexName, requiredNodeIndex := range requiredNodePartition.Indexes { + if presentNodeIndex, hasMatchingDefinition := presentNodePartition.Indexes[requiredNodeIndexName]; !hasMatchingDefinition { + indexChangeSet.NodeIndexesToAdd[requiredNodeIndexName] = requiredNodeIndex + } else if requiredNodeIndex.Type != presentNodeIndex.Type { + indexChangeSet.NodeIndexesToRemove = append(indexChangeSet.NodeIndexesToRemove, requiredNodeIndexName) + indexChangeSet.NodeIndexesToAdd[requiredNodeIndexName] = requiredNodeIndex + } + } + + for requiredNodeConstraintName, requiredNodeConstraint := range requiredNodePartition.Constraints { + if presentNodeConstraint, hasMatchingDefinition := presentNodePartition.Constraints[requiredNodeConstraintName]; !hasMatchingDefinition { + indexChangeSet.NodeConstraintsToAdd[requiredNodeConstraintName] = requiredNodeConstraint + } else if requiredNodeConstraint.Type != presentNodeConstraint.Type { + indexChangeSet.NodeConstraintsToRemove = append(indexChangeSet.NodeConstraintsToRemove, requiredNodeConstraintName) + indexChangeSet.NodeConstraintsToAdd[requiredNodeConstraintName] = requiredNodeConstraint + } + } + } + + if presentEdgePartition, err := s.describeGraphPartition(ctx, definition.Partitions.Edge.Name); err != nil { + return model.Graph{}, err + } else { + for presentEdgeIndexName := range presentEdgePartition.Indexes { + if _, hasMatchingDefinition := requiredEdgePartition.Indexes[presentEdgeIndexName]; !hasMatchingDefinition { + indexChangeSet.EdgeIndexesToRemove = append(indexChangeSet.EdgeIndexesToRemove, 
presentEdgeIndexName) + } + } + + for presentEdgeConstraintName := range presentEdgePartition.Constraints { + if _, hasMatchingDefinition := requiredEdgePartition.Constraints[presentEdgeConstraintName]; !hasMatchingDefinition { + indexChangeSet.EdgeConstraintsToRemove = append(indexChangeSet.EdgeConstraintsToRemove, presentEdgeConstraintName) + } + } + + for requiredEdgeIndexName, requiredEdgeIndex := range requiredEdgePartition.Indexes { + if presentEdgeIndex, hasMatchingDefinition := presentEdgePartition.Indexes[requiredEdgeIndexName]; !hasMatchingDefinition { + indexChangeSet.EdgeIndexesToAdd[requiredEdgeIndexName] = requiredEdgeIndex + } else if requiredEdgeIndex.Type != presentEdgeIndex.Type { + indexChangeSet.EdgeIndexesToRemove = append(indexChangeSet.EdgeIndexesToRemove, requiredEdgeIndexName) + indexChangeSet.EdgeIndexesToAdd[requiredEdgeIndexName] = requiredEdgeIndex + } + } + + for requiredEdgeConstraintName, requiredEdgeConstraint := range requiredEdgePartition.Constraints { + if presentEdgeConstraint, hasMatchingDefinition := presentEdgePartition.Constraints[requiredEdgeConstraintName]; !hasMatchingDefinition { + indexChangeSet.EdgeConstraintsToAdd[requiredEdgeConstraintName] = requiredEdgeConstraint + } else if requiredEdgeConstraint.Type != presentEdgeConstraint.Type { + indexChangeSet.EdgeConstraintsToRemove = append(indexChangeSet.EdgeConstraintsToRemove, requiredEdgeConstraintName) + indexChangeSet.EdgeConstraintsToAdd[requiredEdgeConstraintName] = requiredEdgeConstraint + } + } + } + + return model.Graph{ + ID: definition.ID, + Name: definition.Name, + Partitions: model.GraphPartitions{ + Node: requiredNodePartition, + Edge: requiredEdgePartition, + }, + }, s.assertGraphPartitionIndexes(ctx, definition.Partitions, indexChangeSet) +} + +func (s Query) createGraphPartitions(ctx context.Context, definition model.Graph) (model.Graph, error) { + var ( + nodePartitionName = model.NodePartitionTableName(definition.ID) + edgePartitionName = 
model.EdgePartitionTableName(definition.ID) + ) + + if nodePartition, err := s.CreatePartitionTable(ctx, nodePartitionName, model.NodeTable, definition.ID); err != nil { + return model.Graph{}, err + } else { + definition.Partitions.Node = nodePartition + } + + if edgePartition, err := s.CreatePartitionTable(ctx, edgePartitionName, model.EdgeTable, definition.ID); err != nil { + return model.Graph{}, err + } else { + definition.Partitions.Edge = edgePartition + } + + return definition, nil +} + +func (s Query) CreateGraph(ctx context.Context, schema database.Graph) (model.Graph, error) { + if definition, err := s.insertGraph(ctx, schema.Name); err != nil { + return model.Graph{}, err + } else if graphPartitions, err := s.createGraphPartitions(ctx, definition); err != nil { + return model.Graph{}, err + } else { + return s.AssertGraph(ctx, schema, graphPartitions) + } +} + +func (s Query) InsertOrGetKind(ctx context.Context, kind graph.Kind) (int16, error) { + var kindID int16 + + if result, err := s.query(ctx, sqlInsertKind, map[string]any{ + "name": kind.String(), + }); err != nil { + return -1, err + } else { + defer result.Close() + + if !result.Next() { + if err := result.Err(); err != nil { + return -1, err + } + + return -1, pgx.ErrNoRows + } + + if err := result.Scan(&kindID); err != nil { + return -1, err + } + + return kindID, result.Err() + } +} diff --git a/drivers/pg/query/sql.go b/database/pg/query/sql.go similarity index 95% rename from drivers/pg/query/sql.go rename to database/pg/query/sql.go index bec8b8f..44f3681 100644 --- a/drivers/pg/query/sql.go +++ b/database/pg/query/sql.go @@ -46,7 +46,7 @@ var ( sqlSchemaUp = loadSQL("schema_up.sql") sqlSchemaDown = loadSQL("schema_down.sql") sqlSelectTableIndexes = loadSQL("select_table_indexes.sql") - sqlSelectKindID = loadSQL("select_table_indexes.sql") + sqlSelectKindID = loadSQL("select_kind_id.sql") sqlSelectGraphs = loadSQL("select_graphs.sql") sqlInsertGraph = loadSQL("insert_graph.sql") 
sqlInsertKind = loadSQL("insert_or_get_kind.sql") diff --git a/drivers/pg/query/sql/insert_graph.sql b/database/pg/query/sql/insert_graph.sql similarity index 100% rename from drivers/pg/query/sql/insert_graph.sql rename to database/pg/query/sql/insert_graph.sql diff --git a/drivers/pg/query/sql/insert_or_get_kind.sql b/database/pg/query/sql/insert_or_get_kind.sql similarity index 100% rename from drivers/pg/query/sql/insert_or_get_kind.sql rename to database/pg/query/sql/insert_or_get_kind.sql diff --git a/drivers/pg/query/sql/schema_down.sql b/database/pg/query/sql/schema_down.sql similarity index 100% rename from drivers/pg/query/sql/schema_down.sql rename to database/pg/query/sql/schema_down.sql diff --git a/drivers/pg/query/sql/schema_up.sql b/database/pg/query/sql/schema_up.sql similarity index 100% rename from drivers/pg/query/sql/schema_up.sql rename to database/pg/query/sql/schema_up.sql diff --git a/drivers/pg/query/sql/select_graph_by_name.sql b/database/pg/query/sql/select_graph_by_name.sql similarity index 71% rename from drivers/pg/query/sql/select_graph_by_name.sql rename to database/pg/query/sql/select_graph_by_name.sql index cae41ee..e760837 100644 --- a/drivers/pg/query/sql/select_graph_by_name.sql +++ b/database/pg/query/sql/select_graph_by_name.sql @@ -1,4 +1,5 @@ -- Selects the ID of a graph with the given name. select id from graph -where name = @name; +where name = @name +limit 1; diff --git a/drivers/pg/query/sql/select_graphs.sql b/database/pg/query/sql/select_graphs.sql similarity index 70% rename from drivers/pg/query/sql/select_graphs.sql rename to database/pg/query/sql/select_graphs.sql index f1e4d78..b8e0eb6 100644 --- a/drivers/pg/query/sql/select_graphs.sql +++ b/database/pg/query/sql/select_graphs.sql @@ -1,3 +1,4 @@ -- Selects all defined graphs in the database. 
select id, name -from graph; +from graph +order by name; diff --git a/drivers/pg/query/sql/select_kind_id.sql b/database/pg/query/sql/select_kind_id.sql similarity index 100% rename from drivers/pg/query/sql/select_kind_id.sql rename to database/pg/query/sql/select_kind_id.sql diff --git a/drivers/pg/query/sql/select_kinds.sql b/database/pg/query/sql/select_kinds.sql similarity index 100% rename from drivers/pg/query/sql/select_kinds.sql rename to database/pg/query/sql/select_kinds.sql diff --git a/drivers/pg/query/sql/select_table_indexes.sql b/database/pg/query/sql/select_table_indexes.sql similarity index 100% rename from drivers/pg/query/sql/select_table_indexes.sql rename to database/pg/query/sql/select_table_indexes.sql diff --git a/drivers/pg/statements.go b/database/pg/statements.go similarity index 92% rename from drivers/pg/statements.go rename to database/pg/statements.go index a0fdcee..d28f084 100644 --- a/drivers/pg/statements.go +++ b/database/pg/statements.go @@ -1,12 +1,12 @@ package pg const ( - createNodeStatement = `insert into node (graph_id, kind_ids, properties) values (@graph_id, @kind_ids, @properties) returning (id, kind_ids, properties)::nodeComposite;` + createNodeStatement = `insert into node (graph_id, kind_ids, properties) values (@graph_id, @kind_ids, @properties) returning id;` createNodeWithoutIDBatchStatement = `insert into node (graph_id, kind_ids, properties) select $1, unnest($2::text[])::int2[], unnest($3::jsonb[])` createNodeWithIDBatchStatement = `insert into node (graph_id, id, kind_ids, properties) select $1, unnest($2::int8[]), unnest($3::text[])::int2[], unnest($4::jsonb[])` deleteNodeWithIDStatement = `delete from node where node.id = any($1)` - createEdgeStatement = `insert into edge (graph_id, start_id, end_id, kind_id, properties) values (@graph_id, @start_id, @end_id, @kind_id, @properties) returning (id, start_id, end_id, kind_id, properties)::edgeComposite;` + createEdgeStatement = `insert into edge (graph_id, 
start_id, end_id, kind_id, properties) values (@graph_id, @start_id, @end_id, @kind_id, @properties) returning id;` // TODO: The query below is not a pure creation statement as it contains an `on conflict` clause to dance around // Azure post-processing. This was done because Azure post will submit the same creation request hundreds of diff --git a/drivers/pg/types.go b/database/pg/types.go similarity index 97% rename from drivers/pg/types.go rename to database/pg/types.go index 9cd26cc..359f4a3 100644 --- a/drivers/pg/types.go +++ b/database/pg/types.go @@ -216,7 +216,7 @@ func (s *pathComposite) TryMap(compositeMap map[string]any) bool { func (s *pathComposite) FromMap(compositeMap map[string]any) error { if rawNodes, hasNodes := compositeMap["nodes"]; hasNodes { if typedRawNodes, typeOK := rawNodes.([]any); !typeOK { - return fmt.Errorf("") + return fmt.Errorf("expected nodes to be []any but got %T", rawNodes) } else { for _, rawNode := range typedRawNodes { switch typedNode := rawNode.(type) { @@ -238,7 +238,7 @@ func (s *pathComposite) FromMap(compositeMap map[string]any) error { if rawEdges, hasEdges := compositeMap["edges"]; hasEdges { if typedRawEdges, typeOK := rawEdges.([]any); !typeOK { - return fmt.Errorf("") + return fmt.Errorf("expected edges to be []any but got %T", rawEdges) } else { for _, rawEdge := range typedRawEdges { switch typedNode := rawEdge.(type) { diff --git a/database/pg/util.go b/database/pg/util.go new file mode 100644 index 0000000..ae78be8 --- /dev/null +++ b/database/pg/util.go @@ -0,0 +1 @@ +package pg diff --git a/database/schema.go b/database/schema.go new file mode 100644 index 0000000..1151541 --- /dev/null +++ b/database/schema.go @@ -0,0 +1,68 @@ +package database + +import "github.com/specterops/dawgs/graph" + +type IndexType int + +const ( + IndexTypeUnsupported IndexType = 0 + IndexTypeBTree IndexType = 1 + IndexTypeTextSearch IndexType = 2 +) + +func (s IndexType) String() string { + switch s { + case IndexTypeBTree: + 
return "btree" + + case IndexTypeTextSearch: + return "fts" + + case IndexTypeUnsupported: + return "unsupported" + + default: + return "invalid" + } +} + +type Index struct { + Name string + Field string + Type IndexType +} + +type Constraint Index + +type Graph struct { + Name string + Nodes graph.Kinds + Edges graph.Kinds + NodeConstraints []Constraint + EdgeConstraints []Constraint + NodeIndexes []Index + EdgeIndexes []Index +} + +type Schema struct { + GraphSchemas map[string]Graph + DefaultGraphName string +} + +func NewSchema(defaultGraphName string, graphSchemas ...Graph) Schema { + graphSchemaMap := map[string]Graph{} + + for _, graphSchema := range graphSchemas { + graphSchemaMap[graphSchema.Name] = graphSchema + } + + return Schema{ + GraphSchemas: graphSchemaMap, + DefaultGraphName: defaultGraphName, + } +} + +func (s *Schema) DefaultGraph() (Graph, bool) { + defaultGraph, hasDefaultGraph := s.GraphSchemas[s.DefaultGraphName] + return defaultGraph, hasDefaultGraph +} diff --git a/drivers/tooling.go b/database/tooling.go similarity index 96% rename from drivers/tooling.go rename to database/tooling.go index a822fa4..b0b8f45 100644 --- a/drivers/tooling.go +++ b/database/tooling.go @@ -1,4 +1,4 @@ -package drivers +package database import "sync/atomic" diff --git a/database/v1compat/database.go b/database/v1compat/database.go new file mode 100644 index 0000000..3b8d1f8 --- /dev/null +++ b/database/v1compat/database.go @@ -0,0 +1,153 @@ +package v1compat + +import ( + "context" + "time" + + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/util/size" +) + +type Batch interface { + // WithGraph scopes the transaction to a specific graph. If the driver for the transaction does not support + // multiple graphs the resulting transaction will target the default graph instead and this call becomes a no-op. 
+ WithGraph(graphSchema database.Graph) Batch + + // CreateNode creates a new Node in the database and returns the creation as a NodeResult. + CreateNode(node *graph.Node) error + + // DeleteNode deletes a node by the given ID. + DeleteNode(id graph.ID) error + + // Nodes begins a batch query that can be used to update or delete nodes. + Nodes() NodeQuery + + // Relationships begins a batch query that can be used to update or delete relationships. + Relationships() RelationshipQuery + + // UpdateNodeBy is a stop-gap until the query interface can better support targeted batch create-update operations. + // Nodes identified by the NodeUpdate criteria will either be updated or in the case where the node does not yet + // exist, created. + UpdateNodeBy(update graph.NodeUpdate) error + + // TODO: Existing batch logic expects this to perform an upsert on conflicts with (start_id, end_id, kind). This is incorrect and should be refactored + CreateRelationship(relationship *graph.Relationship) error + + // Deprecated: Use CreateRelationship instead. + // + // CreateRelationshipByIDs creates a new Relationship from the start Node to the end Node with the given Kind and + // Properties and returns the creation as a RelationshipResult. + CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) error + + // DeleteRelationship deletes a relationship by the given ID. + DeleteRelationship(id graph.ID) error + + // UpdateRelationshipBy is a stop-gap until the query interface can better support targeted batch create-update + // operations. Relationships identified by the RelationshipUpdate criteria will either be updated or in the case + // where the relationship does not yet exist, created. + UpdateRelationshipBy(update graph.RelationshipUpdate) error + + // Commit calls to commit this batch transaction right away. 
+ Commit() error +} + +// Transaction is an interface that contains all operations that may be executed against a DAWGS driver. DAWGS drivers are +// expected to support all Transaction operations in-transaction. +type Transaction interface { + // WithGraph scopes the transaction to a specific graph. If the driver for the transaction does not support + // multiple graphs the resulting transaction will target the default graph instead and this call becomes a no-op. + WithGraph(graphSchema database.Graph) Transaction + + // CreateNode creates a new Node in the database and returns the creation as a NodeResult. + CreateNode(properties *graph.Properties, kinds ...graph.Kind) (*graph.Node, error) + + // UpdateNode updates a Node in the database with the given Node by ID. UpdateNode will not create missing Node + // entries in the database. Use CreateNode first to create a new Node. + UpdateNode(node *graph.Node) error + + // Nodes creates a new NodeQuery and returns it. + Nodes() NodeQuery + + // CreateRelationshipByIDs creates a new Relationship from the start Node to the end Node with the given Kind and + // Properties and returns the creation as a RelationshipResult. + CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) + + // UpdateRelationship updates a Relationship in the database with the given Relationship by ID. UpdateRelationship + // will not create missing Relationship entries in the database. Use CreateRelationship first to create a new + // Relationship. + UpdateRelationship(relationship *graph.Relationship) error + + // Relationships creates a new RelationshipQuery and returns it. + Relationships() RelationshipQuery + + // Query allows a user to execute a given cypher query that will be translated to the target database. + Query(query string, parameters map[string]any) Result + + // Commit calls to commit this transaction right away. 
+ Commit() error + + GraphQueryMemoryLimit() size.Size +} + +// TransactionDelegate represents a transactional database context actor. Errors returned from a TransactionDelegate +// result in the rollback of write enabled transactions. Successful execution of a TransactionDelegate (nil error +// return value) results in a transactional commit of work done within the TransactionDelegate. +type TransactionDelegate func(tx Transaction) error + +// BatchDelegate represents a transactional database context actor. +type BatchDelegate func(batch Batch) error + +// TransactionConfig is a generic configuration that may apply to all supported databases. +type TransactionConfig struct { + Timeout time.Duration + DriverConfig any +} + +// TransactionOption is a function that represents a configuration setting for the underlying database transaction. +type TransactionOption func(config *TransactionConfig) + +// Database is a high-level interface representing transactional entry-points into DAWGS driver implementations. +type Database interface { + // SetWriteFlushSize sets a new write flush interval on the current driver + SetWriteFlushSize(interval int) + + // SetBatchWriteSize sets a new batch write interval on the current driver + SetBatchWriteSize(interval int) + + // ReadTransaction opens up a new read transactional context in the database and then defers the context to the + // given logic function. + ReadTransaction(ctx context.Context, txDelegate TransactionDelegate, options ...TransactionOption) error + + // WriteTransaction opens up a new write transactional context in the database and then defers the context to the + // given logic function. + WriteTransaction(ctx context.Context, txDelegate TransactionDelegate, options ...TransactionOption) error + + // BatchOperation opens up a new write transactional context in the database and then defers the context to the + // given logic function. 
Batch operations are fundamentally different between databases supported by DAWGS, + // necessitating a different interface that lacks many of the convenience features of a regular read or write + // transaction. + BatchOperation(ctx context.Context, batchDelegate BatchDelegate) error + + // AssertSchema will apply the given schema to the underlying database. + AssertSchema(ctx context.Context, dbSchema database.Schema) error + + // SetDefaultGraph sets the default graph namespace for the connection. + SetDefaultGraph(ctx context.Context, graphSchema database.Graph) error + + // Run allows a user to pass statements directly to the database. Since results may rely on a transactional context + // only an error is returned from this function + Run(ctx context.Context, query string, parameters map[string]any) error + + // Close closes the database context and releases any pooled resources held by the instance. + Close(ctx context.Context) error + + // FetchKinds retrieves the complete list of kinds available to the database. + FetchKinds(ctx context.Context) (graph.Kinds, error) + + // RefreshKinds refreshes the in memory kinds maps + RefreshKinds(ctx context.Context) error + + // V2 returns the V2 interface for this V1 database instance + V2() database.Instance +} diff --git a/database/v1compat/errors.go b/database/v1compat/errors.go new file mode 100644 index 0000000..72e9d8b --- /dev/null +++ b/database/v1compat/errors.go @@ -0,0 +1,23 @@ +package v1compat + +import ( + "errors" + "fmt" +) + +func IsErrNotFound(err error) bool { + return errors.Is(err, ErrNoResultsFound) +} + +func IsErrPropertyNotFound(err error) bool { + return errors.Is(err, ErrPropertyNotFound) +} + +func IsMissingResultExpectation(err error) bool { + return errors.Is(err, ErrMissingResultExpectation) +} + +// NewError returns an error that contains the given query context elements. 
+func NewError(query string, driverErr error) error { + return fmt.Errorf("driver error: %w - query: %s", driverErr, query) +} diff --git a/database/v1compat/graph.go b/database/v1compat/graph.go new file mode 100644 index 0000000..25b28bd --- /dev/null +++ b/database/v1compat/graph.go @@ -0,0 +1,111 @@ +package v1compat + +import ( + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/graph" +) + +type String = graph.String +type Kind = graph.Kind +type Kinds = graph.Kinds +type KindBitmaps = graph.KindBitmaps +type ThreadSafeKindBitmap = graph.ThreadSafeKindBitmap +type ID = graph.ID +type Node = graph.Node +type NodeSet = graph.NodeSet +type ThreadSafeNodeSet = graph.ThreadSafeNodeSet +type NodeKindSet = graph.NodeKindSet +type NodeUpdate = graph.NodeUpdate +type Relationship = graph.Relationship +type RelationshipSet = graph.RelationshipSet +type RelationshipUpdate = graph.RelationshipUpdate +type Path = graph.Path +type PathSegment = graph.PathSegment +type PathSet = graph.PathSet +type Criteria = cypher.SyntaxNode +type Properties = graph.Properties +type PropertyMap = graph.PropertyMap +type PropertyValue = graph.PropertyValue +type Direction = graph.Direction +type Tree = graph.Tree + +var ( + NewPathSet = graph.NewPathSet + StringKind = graph.StringKind + NewProperties = graph.NewProperties + NewNode = graph.NewNode + NewNodeSet = graph.NewNodeSet + NewRelationship = graph.NewRelationship + NewRelationshipSet = graph.NewRelationshipSet + Uint32SliceToIDs = graph.Uint32SliceToIDs + Uint64SliceToIDs = graph.Uint64SliceToIDs + NewRootPathSegment = graph.NewRootPathSegment + NewThreadSafeNodeSet = graph.NewThreadSafeNodeSet + PrepareNode = graph.PrepareNode + PrepareRelationship = graph.PrepareRelationship + NewNodeKindSet = graph.NewNodeKindSet + NodeSetToDuplex = graph.NodeSetToDuplex + NodeIDsToDuplex = graph.NodeIDsToDuplex + StringsToKinds = graph.StringsToKinds + SortAndSliceNodeSet = 
graph.SortAndSliceNodeSet + FormatPathSegment = graph.FormatPathSegment + NewThreadSafeKindBitmap = graph.NewThreadSafeKindBitmap + NewTree = graph.NewTree + + EmptyKind = graph.EmptyKind + + ErrNoResultsFound = graph.ErrNoResultsFound + ErrMissingResultExpectation = graph.ErrMissingResultExpectation + ErrUnsupportedDatabaseOperation = graph.ErrUnsupportedDatabaseOperation + ErrPropertyNotFound = graph.ErrPropertyNotFound + ErrContextTimedOut = graph.ErrContextTimedOut + ErrConcurrentConnectionSlotTimeOut = graph.ErrConcurrentConnectionSlotTimeOut +) + +type Constraint = database.Constraint +type Schema = database.Schema +type Graph = database.Graph +type Index = database.Index + +const ( + BTreeIndex = database.IndexTypeBTree + TextSearchIndex = database.IndexTypeTextSearch + + UnregisteredNodeID = graph.UnregisteredNodeID + + DirectionOutbound = graph.DirectionOutbound + DirectionInbound = graph.DirectionInbound + DirectionBoth = graph.DirectionBoth +) + +func symbolMapToStringMap(props map[String]any) map[string]any { + store := make(map[string]any, len(props)) + + for k, v := range props { + store[k.String()] = v + } + + return store +} + +func AsProperties[T PropertyMap | map[String]any | map[string]any](rawStore T) *Properties { + var store map[string]any + + switch typedStore := any(rawStore).(type) { + case PropertyMap: + store = symbolMapToStringMap(typedStore) + + case map[String]any: + store = symbolMapToStringMap(typedStore) + + case map[string]any: + store = typedStore + } + + return &Properties{ + Map: store, + Modified: make(map[string]struct{}), + Deleted: make(map[string]struct{}), + } +} diff --git a/database/v1compat/node.go b/database/v1compat/node.go new file mode 100644 index 0000000..14dfd48 --- /dev/null +++ b/database/v1compat/node.go @@ -0,0 +1,202 @@ +package v1compat + +import ( + "context" + + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/query" +) + +type nodeQuery struct { 
+ ctx context.Context + driver database.Driver + builder query.QueryBuilder +} + +func newNodeQuery(ctx context.Context, driver database.Driver) NodeQuery { + return &nodeQuery{ + ctx: ctx, + driver: driver, + builder: query.New(), + } +} + +func (s nodeQuery) Filter(criteria Criteria) NodeQuery { + s.builder.Where(criteria) + return s +} + +func (s nodeQuery) Filterf(criteriaDelegate CriteriaProvider) NodeQuery { + s.builder.Where(criteriaDelegate()) + return s +} + +func (s nodeQuery) Query(delegate func(results Result) error, finalCriteria ...any) error { + s.builder.Return(finalCriteria...) + + if result, err := s.exec(); err != nil { + return err + } else { + defer result.Close() + return delegate(result) + } +} + +func (s nodeQuery) Delete() error { + s.builder.Delete(query.Node()) + + if result, err := s.exec(); err != nil { + return err + } else { + result.Close() + return nil + } +} + +func (s nodeQuery) Update(properties *graph.Properties) error { + s.builder.Update(query.Node().SetProperties(properties.MapOrEmpty())) + + if result, err := s.exec(); err != nil { + return err + } else { + result.Close() + return nil + } +} + +func (s nodeQuery) OrderBy(criteria ...Criteria) NodeQuery { + s.builder.OrderBy(criteria...) 
+ return s +} + +func (s nodeQuery) Offset(skip int) NodeQuery { + s.builder.Skip(skip) + return s +} + +func (s nodeQuery) Limit(limit int) NodeQuery { + s.builder.Limit(limit) + return s +} + +func (s nodeQuery) exec() (Result, error) { + if builtQuery, err := s.builder.Build(); err != nil { + return nil, err + } else { + result := s.driver.Exec(s.ctx, builtQuery.Query, builtQuery.Parameters) + return wrapResult(s.ctx, result, s.driver.Mapper()), nil + } +} + +func (s nodeQuery) Count() (int64, error) { + s.builder.Return(query.Node().Count()) + + if result, err := s.exec(); err != nil { + return 0, err + } else { + defer result.Close() + + if result.Next() { + var count int64 + + if err := result.Scan(&count); err != nil { + return 0, err + } + + return count, result.Error() + } + + if result.Error() != nil { + return 0, result.Error() + } + + return 0, ErrNoResultsFound + } +} + +func (s nodeQuery) First() (*graph.Node, error) { + s.builder.Return(query.Node()).Limit(1) + + if result, err := s.exec(); err != nil { + return nil, err + } else { + defer result.Close() + + if result.Next() { + var node graph.Node + + if err := result.Scan(&node); err != nil { + return nil, err + } + + return &node, nil + } + + if result.Error() != nil { + return nil, result.Error() + } + + return nil, ErrNoResultsFound + } +} + +func (s nodeQuery) Fetch(delegate func(cursor Cursor[*graph.Node]) error, finalCriteria ...Criteria) error { + s.builder.Return(query.Node()) + + if builtQuery, err := s.builder.Build(); err != nil { + return err + } else { + resultIter := NewResultIterator(s.ctx, s.driver.Exec(s.ctx, builtQuery.Query, builtQuery.Parameters), func(result database.Result) (*graph.Node, error) { + var ( + node graph.Node + err = result.Scan(&node) + ) + + return &node, err + }) + + defer resultIter.Close() + return delegate(resultIter) + } +} + +func (s nodeQuery) FetchIDs(delegate func(cursor Cursor[graph.ID]) error) error { + s.builder.Return(query.Node().ID()) + + if 
builtQuery, err := s.builder.Build(); err != nil { + return err + } else { + resultIter := NewResultIterator(s.ctx, s.driver.Exec(s.ctx, builtQuery.Query, builtQuery.Parameters), func(result database.Result) (graph.ID, error) { + var nodeID graph.ID + return nodeID, result.Scan(&nodeID) + }) + + defer resultIter.Close() + return delegate(resultIter) + } +} + +func (s nodeQuery) FetchKinds(delegate func(cursor Cursor[KindsResult]) error) error { + s.builder.Return(query.Node().ID(), query.Node().Kinds()) + + if builtQuery, err := s.builder.Build(); err != nil { + return err + } else { + resultIter := NewResultIterator(s.ctx, s.driver.Exec(s.ctx, builtQuery.Query, builtQuery.Parameters), func(result database.Result) (KindsResult, error) { + var ( + nodeID graph.ID + nodeKinds graph.Kinds + err = result.Scan(&nodeID, &nodeKinds) + ) + + return KindsResult{ + ID: nodeID, + Kinds: nodeKinds, + }, err + }) + + defer resultIter.Close() + return delegate(resultIter) + } +} diff --git a/ops/ops.go b/database/v1compat/ops/ops.go similarity index 99% rename from ops/ops.go rename to database/v1compat/ops/ops.go index 44d3358..126c85e 100644 --- a/ops/ops.go +++ b/database/v1compat/ops/ops.go @@ -11,8 +11,8 @@ import ( "github.com/specterops/dawgs/util/channels" "github.com/specterops/dawgs/cardinality" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" + graph "github.com/specterops/dawgs/database/v1compat" + "github.com/specterops/dawgs/database/v1compat/query" ) func FetchAllNodeProperties(tx graph.Transaction, nodes graph.NodeSet) error { diff --git a/ops/parallel.go b/database/v1compat/ops/parallel.go similarity index 99% rename from ops/parallel.go rename to database/v1compat/ops/parallel.go index 7345b56..9fbb35e 100644 --- a/ops/parallel.go +++ b/database/v1compat/ops/parallel.go @@ -5,8 +5,8 @@ import ( "errors" "sync" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" + graph 
"github.com/specterops/dawgs/database/v1compat" + "github.com/specterops/dawgs/database/v1compat/query" "github.com/specterops/dawgs/util" "github.com/specterops/dawgs/util/channels" ) diff --git a/ops/traversal.go b/database/v1compat/ops/traversal.go similarity index 98% rename from ops/traversal.go rename to database/v1compat/ops/traversal.go index dd2927d..643a782 100644 --- a/ops/traversal.go +++ b/database/v1compat/ops/traversal.go @@ -6,8 +6,8 @@ import ( "github.com/specterops/dawgs/cardinality" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" + graph "github.com/specterops/dawgs/database/v1compat" + "github.com/specterops/dawgs/database/v1compat/query" ) type LimitSkipTracker struct { diff --git a/graph/query.go b/database/v1compat/query.go similarity index 75% rename from graph/query.go rename to database/v1compat/query.go index 1f4114f..1b34dfd 100644 --- a/graph/query.go +++ b/database/v1compat/query.go @@ -1,11 +1,17 @@ -package graph +package v1compat -import "fmt" +import ( + "context" + "fmt" + + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/graph" +) type Result interface { Next() bool Values() []any - Mapper() ValueMapper + Mapper() graph.ValueMapper // Scan takes a list of target any and attempts to map the next row from the result to the targets. This function // is semantically equivalent to calling graph.ScanNextResult(...) @@ -33,6 +39,44 @@ func ScanNextResult(result Result, targets ...any) error { return nil } +type resultWrapper struct { + result database.Result + ctx context.Context + mapper graph.ValueMapper +} + +func (s resultWrapper) Next() bool { + return s.result.HasNext(s.ctx) +} + +func (s resultWrapper) Values() []any { + return s.result.Values() +} + +func (s resultWrapper) Mapper() graph.ValueMapper { + return s.mapper +} + +func (s resultWrapper) Scan(targets ...any) error { + return s.result.Scan(targets...) 
+} + +func (s resultWrapper) Error() error { + return s.result.Error() +} + +func (s resultWrapper) Close() { + s.result.Close(s.ctx) +} + +func wrapResult(ctx context.Context, result database.Result, mapper graph.ValueMapper) Result { + return &resultWrapper{ + ctx: ctx, + result: result, + mapper: mapper, + } +} + type ErrorResult struct { err error } @@ -45,8 +89,8 @@ func (s ErrorResult) Next() bool { return false } -func (s ErrorResult) Mapper() ValueMapper { - return ValueMapper{} +func (s ErrorResult) Mapper() graph.ValueMapper { + return graph.ValueMapper{} } func (s ErrorResult) Scan(targets ...any) error { @@ -66,9 +110,6 @@ func NewErrorResult(err error) Result { } } -// Criteria is a top-level alias for communicating structured query filter criteria to a query generator. -type Criteria any - // CriteriaProvider is a function delegate that returns criteria. type CriteriaProvider func() Criteria @@ -82,13 +123,13 @@ type NodeQuery interface { Filterf(criteriaDelegate CriteriaProvider) NodeQuery // Query completes the query and hands the raw result to the given delegate for unmarshalling - Query(delegate func(results Result) error, finalCriteria ...Criteria) error + Query(delegate func(results Result) error, finalCriteria ...any) error // Delete deletes any candidate nodes that match the query criteria Delete() error // Update updates all candidate nodes with the given properties - Update(properties *Properties) error + Update(properties *graph.Properties) error // OrderBy sets the OrderBy clause of the NodeQuery. OrderBy(criteria ...Criteria) NodeQuery @@ -104,15 +145,15 @@ type NodeQuery interface { Count() (int64, error) // First completes the query and returns the result and any error encountered during execution. - First() (*Node, error) + First() (*graph.Node, error) // Fetch completes the query and captures a cursor for iterating the result set. This cursor is passed to the given // delegate. 
Errors from the delegate are returned upwards as the error result of this call. - Fetch(delegate func(cursor Cursor[*Node]) error, finalCriteria ...Criteria) error + Fetch(delegate func(cursor Cursor[*graph.Node]) error, finalCriteria ...Criteria) error // FetchIDs completes the query and captures a cursor for iterating the result set. This cursor is passed to the given // delegate. Errors from the delegate are returned upwards as the error result of this call. - FetchIDs(delegate func(cursor Cursor[ID]) error) error + FetchIDs(delegate func(cursor Cursor[graph.ID]) error) error // FetchKinds returns the ID and Kinds of matched nodes and omits property fetching FetchKinds(func(cursor Cursor[KindsResult]) error) error @@ -129,7 +170,7 @@ type RelationshipQuery interface { // Update replaces the properties of all candidate relationships that matches the query criteria with the // given properties - Update(properties *Properties) error + Update(properties *graph.Properties) error // Delete deletes any candidate relationships that match the query criteria Delete() error @@ -148,28 +189,28 @@ type RelationshipQuery interface { Count() (int64, error) // First completes the query and returns the result and any error encountered during execution. - First() (*Relationship, error) + First() (*graph.Relationship, error) // Query completes the query and hands the raw result to the given delegate for unmarshalling - Query(delegate func(results Result) error, finalCriteria ...Criteria) error + Query(delegate func(results Result) error, finalCriteria ...any) error // Fetch completes the query and captures a cursor for iterating the result set. This cursor is passed to the given // delegate. Errors from the delegate are returned upwards as the error result of this call. 
- Fetch(delegate func(cursor Cursor[*Relationship]) error) error + Fetch(delegate func(cursor Cursor[*graph.Relationship]) error) error // FetchDirection completes the query and captures a cursor for iterating through the relationship related nodes // for the given path direction - FetchDirection(direction Direction, delegate func(cursor Cursor[DirectionalResult]) error) error + FetchDirection(direction graph.Direction, delegate func(cursor Cursor[DirectionalResult]) error) error // FetchIDs completes the query and captures a cursor for iterating the result set. This cursor is passed to the given // delegate. Errors from the delegate are returned upwards as the error result of this call. - FetchIDs(delegate func(cursor Cursor[ID]) error) error + FetchIDs(delegate func(cursor Cursor[graph.ID]) error) error // FetchTriples(delegate func(cursor Cursor[RelationshipTripleResult]) error) error // - FetchAllShortestPaths(delegate func(cursor Cursor[Path]) error) error + FetchAllShortestPaths(delegate func(cursor Cursor[graph.Path]) error) error // FetchKinds returns the ID, Kind, Start ID and End ID of matched relationships and omits property fetching FetchKinds(delegate func(cursor Cursor[RelationshipKindsResult]) error) error diff --git a/query/builder.go b/database/v1compat/query/builder.go similarity index 97% rename from query/builder.go rename to database/v1compat/query/builder.go index d5e3cd4..a008cd2 100644 --- a/query/builder.go +++ b/database/v1compat/query/builder.go @@ -4,9 +4,8 @@ import ( "errors" "fmt" - "github.com/specterops/dawgs/cypher/models/walk" - "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/walk" "github.com/specterops/dawgs/graph" ) @@ -29,7 +28,7 @@ func NewBuilder(cache *Cache) *Builder { } } -func NewBuilderWithCriteria(criteria ...graph.Criteria) *Builder { +func NewBuilderWithCriteria(criteria ...cypher.SyntaxNode) *Builder { builder := NewBuilder(nil) builder.Apply(criteria...) 
@@ -220,10 +219,10 @@ func (s *Builder) prepareMatch(allShortestPaths bool) error { return nil } -func (s *Builder) Apply(criteria ...graph.Criteria) { +func (s *Builder) Apply(criteria ...cypher.SyntaxNode) { for _, nextCriteria := range criteria { switch typedCriteria := nextCriteria.(type) { - case []graph.Criteria: + case []cypher.SyntaxNode: s.Apply(typedCriteria...) case *cypher.Where: diff --git a/query/identifiers.go b/database/v1compat/query/identifiers.go similarity index 100% rename from query/identifiers.go rename to database/v1compat/query/identifiers.go diff --git a/database/v1compat/query/model.go b/database/v1compat/query/model.go new file mode 100644 index 0000000..954b155 --- /dev/null +++ b/database/v1compat/query/model.go @@ -0,0 +1,599 @@ +package query + +import ( + "fmt" + "strings" + "time" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/graph" +) + +func convertCriteria[T any](criteria ...cypher.SyntaxNode) []T { + var ( + converted = make([]T, len(criteria)) + ) + + for idx, nextCriteria := range criteria { + converted[idx] = nextCriteria.(T) + } + + return converted +} + +func Update(clauses ...*cypher.UpdatingClause) []*cypher.UpdatingClause { + return clauses +} + +func AddKind(reference cypher.SyntaxNode, kind graph.Kind) *cypher.UpdatingClause { + return cypher.NewUpdatingClause(&cypher.Set{ + Items: []*cypher.SetItem{{ + Left: reference, + Operator: cypher.OperatorLabelAssignment, + Right: graph.Kinds{kind}, + }}, + }) +} + +func AddKinds(reference cypher.SyntaxNode, kinds graph.Kinds) *cypher.UpdatingClause { + return cypher.NewUpdatingClause(&cypher.Set{ + Items: []*cypher.SetItem{{ + Left: reference, + Operator: cypher.OperatorLabelAssignment, + Right: kinds, + }}, + }) +} + +func DeleteKind(reference cypher.SyntaxNode, kind graph.Kind) *cypher.UpdatingClause { + return cypher.NewUpdatingClause(&cypher.Remove{ + Items: []*cypher.RemoveItem{{ + KindMatcher: &cypher.KindMatcher{ + Reference: 
reference, + Kinds: graph.Kinds{kind}, + }, + }}, + }) +} + +func DeleteKinds(reference cypher.SyntaxNode, kinds graph.Kinds) *cypher.UpdatingClause { + return cypher.NewUpdatingClause(&cypher.Remove{ + Items: []*cypher.RemoveItem{{ + KindMatcher: &cypher.KindMatcher{ + Reference: reference, + Kinds: kinds, + }, + }}, + }) +} + +func SetProperty(reference cypher.SyntaxNode, value any) *cypher.UpdatingClause { + return cypher.NewUpdatingClause(&cypher.Set{ + Items: []*cypher.SetItem{{ + Left: reference, + Operator: cypher.OperatorAssignment, + Right: Parameter(value), + }}, + }) +} + +func SetProperties(reference cypher.SyntaxNode, properties map[string]any) *cypher.UpdatingClause { + set := &cypher.Set{} + + for key, value := range properties { + set.Items = append(set.Items, &cypher.SetItem{ + Left: Property(reference, key), + Operator: cypher.OperatorAssignment, + Right: Parameter(value), + }) + } + + return cypher.NewUpdatingClause(set) +} + +func DeleteProperty(reference *cypher.PropertyLookup) *cypher.UpdatingClause { + return cypher.NewUpdatingClause(&cypher.Remove{ + Items: []*cypher.RemoveItem{{ + Property: reference, + }}, + }) +} + +func DeleteProperties(reference cypher.SyntaxNode, propertyNames ...string) *cypher.UpdatingClause { + removeClause := &cypher.Remove{} + + for _, propertyName := range propertyNames { + removeClause.Items = append(removeClause.Items, &cypher.RemoveItem{ + Property: Property(reference, propertyName), + }) + } + + return cypher.NewUpdatingClause(removeClause) +} + +func Kind(reference cypher.SyntaxNode, kinds ...graph.Kind) *cypher.KindMatcher { + return &cypher.KindMatcher{ + Reference: reference, + Kinds: kinds, + } +} + +func KindIn(reference cypher.SyntaxNode, kinds ...graph.Kind) *cypher.KindMatcher { + return cypher.NewKindMatcher(reference, kinds) +} + +func NodeProperty(name string) *cypher.PropertyLookup { + return cypher.NewPropertyLookup(NodeSymbol, name) +} + +func RelationshipProperty(name string) 
*cypher.PropertyLookup { + return cypher.NewPropertyLookup(EdgeSymbol, name) +} + +func StartProperty(name string) *cypher.PropertyLookup { + return cypher.NewPropertyLookup(EdgeStartSymbol, name) +} + +func EndProperty(name string) *cypher.PropertyLookup { + return cypher.NewPropertyLookup(EdgeEndSymbol, name) +} + +func Property(qualifier cypher.SyntaxNode, name string) *cypher.PropertyLookup { + return &cypher.PropertyLookup{ + Atom: qualifier.(*cypher.Variable), + Symbol: name, + } +} + +func Count(reference cypher.SyntaxNode) *cypher.FunctionInvocation { + return &cypher.FunctionInvocation{ + Name: "count", + Arguments: []cypher.Expression{reference}, + } +} + +func CountDistinct(reference cypher.SyntaxNode) *cypher.FunctionInvocation { + return &cypher.FunctionInvocation{ + Name: "count", + Distinct: true, + Arguments: []cypher.Expression{reference}, + } +} + +func And(criteria ...cypher.SyntaxNode) *cypher.Conjunction { + return cypher.NewConjunction(convertCriteria[cypher.Expression](criteria...)...) +} + +func Or(criteria ...cypher.SyntaxNode) *cypher.Parenthetical { + return &cypher.Parenthetical{ + Expression: cypher.NewDisjunction(convertCriteria[cypher.Expression](criteria...)...), + } +} + +func Xor(criteria ...cypher.SyntaxNode) *cypher.ExclusiveDisjunction { + return cypher.NewExclusiveDisjunction(convertCriteria[cypher.Expression](criteria...)...) 
+} + +func Parameter(value any) *cypher.Parameter { + if parameter, isParameter := value.(*cypher.Parameter); isParameter { + return parameter + } + + return &cypher.Parameter{ + Value: value, + } +} + +func Literal(value any) *cypher.Literal { + return &cypher.Literal{ + Value: value, + Null: value == nil, + } +} + +func KindsOf(ref cypher.SyntaxNode) *cypher.FunctionInvocation { + switch typedRef := ref.(type) { + case *cypher.Variable: + switch typedRef.Symbol { + case NodeSymbol, EdgeStartSymbol, EdgeEndSymbol: + return &cypher.FunctionInvocation{ + Name: "labels", + Arguments: []cypher.Expression{ref}, + } + + case EdgeSymbol: + return &cypher.FunctionInvocation{ + Name: "type", + Arguments: []cypher.Expression{ref}, + } + + default: + return cypher.WithErrors(&cypher.FunctionInvocation{}, fmt.Errorf("invalid variable reference for KindsOf: %s", typedRef.Symbol)) + } + + default: + return cypher.WithErrors(&cypher.FunctionInvocation{}, fmt.Errorf("invalid reference type for KindsOf: %T", ref)) + } +} + +func Limit(limit int) *cypher.Limit { + return &cypher.Limit{ + Value: Literal(limit), + } +} + +func Offset(offset int) *cypher.Skip { + return &cypher.Skip{ + Value: Literal(offset), + } +} + +func StringContains(reference cypher.SyntaxNode, value string) *cypher.Comparison { + return cypher.NewComparison(reference, cypher.OperatorContains, Parameter(value)) +} + +func StringStartsWith(reference cypher.SyntaxNode, value string) *cypher.Comparison { + return cypher.NewComparison(reference, cypher.OperatorStartsWith, Parameter(value)) +} + +func StringEndsWith(reference cypher.SyntaxNode, value string) *cypher.Comparison { + return cypher.NewComparison(reference, cypher.OperatorEndsWith, Parameter(value)) +} + +func CaseInsensitiveStringContains(reference cypher.SyntaxNode, value string) *cypher.Comparison { + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", convertCriteria[cypher.Expression](reference)...), + 
cypher.OperatorContains, + Parameter(strings.ToLower(value)), + ) +} + +func CaseInsensitiveStringStartsWith(reference cypher.SyntaxNode, value string) *cypher.Comparison { + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", convertCriteria[cypher.Expression](reference)...), + cypher.OperatorStartsWith, + Parameter(strings.ToLower(value)), + ) +} + +func CaseInsensitiveStringEndsWith(reference cypher.SyntaxNode, value string) *cypher.Comparison { + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", convertCriteria[cypher.Expression](reference)...), + cypher.OperatorEndsWith, + Parameter(strings.ToLower(value)), + ) +} + +func Equals(reference cypher.SyntaxNode, value any) *cypher.Comparison { + return cypher.NewComparison(reference, cypher.OperatorEquals, Parameter(value)) +} + +func GreaterThan(reference cypher.SyntaxNode, value any) *cypher.Comparison { + return cypher.NewComparison(reference, cypher.OperatorGreaterThan, Parameter(value)) +} + +func After(reference cypher.SyntaxNode, value any) *cypher.Comparison { + return GreaterThan(reference, value) +} + +func GreaterThanOrEquals(reference cypher.SyntaxNode, value any) *cypher.Comparison { + return cypher.NewComparison(reference, cypher.OperatorGreaterThanOrEqualTo, Parameter(value)) +} + +func LessThan(reference cypher.SyntaxNode, value any) *cypher.Comparison { + return cypher.NewComparison(reference, cypher.OperatorLessThan, Parameter(value)) +} + +func LessThanGraphQuery(reference1, reference2 cypher.SyntaxNode) *cypher.Comparison { + return cypher.NewComparison(reference1, cypher.OperatorLessThan, reference2) +} + +func Before(reference cypher.SyntaxNode, value time.Time) *cypher.Comparison { + return LessThan(reference, value) +} + +func BeforeGraphQuery(reference1, reference2 cypher.SyntaxNode) *cypher.Comparison { + return LessThanGraphQuery(reference1, reference2) +} + +func LessThanOrEquals(reference cypher.SyntaxNode, value any) 
*cypher.Comparison { + return cypher.NewComparison(reference, cypher.OperatorLessThanOrEqualTo, Parameter(value)) +} + +func Exists(reference cypher.SyntaxNode) *cypher.Comparison { + return cypher.NewComparison( + reference, + cypher.OperatorIsNot, + cypher.NewLiteral(nil, true), + ) +} + +func HasRelationships(reference *cypher.Variable) *cypher.PatternPredicate { + patternPredicate := cypher.NewPatternPredicate() + + patternPredicate.AddElement(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(reference.Symbol), + }) + + patternPredicate.AddElement(&cypher.RelationshipPattern{ + Direction: graph.DirectionBoth, + }) + + patternPredicate.AddElement(&cypher.NodePattern{}) + + return patternPredicate +} + +func In(reference cypher.SyntaxNode, value any) *cypher.Comparison { + return cypher.NewComparison(reference, cypher.OperatorIn, Parameter(value)) +} + +func InInverted(reference cypher.SyntaxNode, value any) *cypher.Comparison { + return cypher.NewComparison(Parameter(value), cypher.OperatorIn, reference) +} + +func InIDs[T *cypher.FunctionInvocation | *cypher.Variable](reference T, ids ...graph.ID) *cypher.Comparison { + switch any(reference).(type) { + case *cypher.FunctionInvocation: + return cypher.NewComparison(reference, cypher.OperatorIn, Parameter(ids)) + + default: + return cypher.NewComparison(Identity(any(reference).(*cypher.Variable)), cypher.OperatorIn, Parameter(ids)) + } +} + +func Where(expression cypher.SyntaxNode) *cypher.Where { + whereClause := cypher.NewWhere() + whereClause.AddSlice(convertCriteria[cypher.Expression](expression)) + + return whereClause +} + +func OrderBy(leaves ...cypher.SyntaxNode) *cypher.Order { + return &cypher.Order{ + Items: convertCriteria[*cypher.SortItem](leaves...), + } +} + +func Order(reference, direction cypher.SyntaxNode) *cypher.SortItem { + switch direction { + case cypher.SortDescending: + return &cypher.SortItem{ + Ascending: false, + Expression: reference, + } + + default: + return 
&cypher.SortItem{ + Ascending: true, + Expression: reference, + } + } +} + +func Ascending() cypher.SortOrder { + return cypher.SortAscending +} + +func Descending() cypher.SortOrder { + return cypher.SortDescending +} + +func Delete(leaves ...cypher.SyntaxNode) *cypher.UpdatingClause { + deleteClause := &cypher.Delete{ + Detach: true, + } + + for _, leaf := range leaves { + switch leaf.(*cypher.Variable).Symbol { + case EdgeSymbol, EdgeStartSymbol, EdgeEndSymbol: + deleteClause.Detach = false + } + + deleteClause.Expressions = append(deleteClause.Expressions, leaf) + } + + return cypher.NewUpdatingClause(deleteClause) +} + +func NodePattern(kinds graph.Kinds, properties *cypher.Parameter) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(NodeSymbol), + Kinds: kinds, + Properties: properties, + } +} + +func StartNodePattern(kinds graph.Kinds, properties *cypher.Parameter) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(EdgeStartSymbol), + Kinds: kinds, + Properties: properties, + } +} + +func EndNodePattern(kinds graph.Kinds, properties *cypher.Parameter) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(EdgeEndSymbol), + Kinds: kinds, + Properties: properties, + } +} + +func RelationshipPattern(kind graph.Kind, properties *cypher.Parameter, direction graph.Direction) *cypher.RelationshipPattern { + return &cypher.RelationshipPattern{ + Variable: cypher.NewVariableWithSymbol(EdgeSymbol), + Kinds: graph.Kinds{kind}, + Properties: properties, + Direction: direction, + } +} + +func Create(elements ...cypher.SyntaxNode) *cypher.UpdatingClause { + var ( + pattern = &cypher.PatternPart{} + createClause = &cypher.Create{ + // Note: Unique is Neo4j specific and will not be supported here. Use of constraints for + // uniqueness is expected instead. 
+ Unique: false, + Pattern: []*cypher.PatternPart{pattern}, + } + ) + + for _, element := range elements { + switch typedElement := element.(type) { + case *cypher.Variable: + switch typedElement.Symbol { + case NodeSymbol, EdgeStartSymbol, EdgeEndSymbol: + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(typedElement.Symbol), + }) + + default: + createClause.AddError(fmt.Errorf("invalid variable reference create: %s", typedElement.Symbol)) + } + + case *cypher.NodePattern: + pattern.AddPatternElements(typedElement) + + case *cypher.RelationshipPattern: + pattern.AddPatternElements(typedElement) + + default: + createClause.AddError(fmt.Errorf("invalid type for create: %T", element)) + } + } + + return cypher.NewUpdatingClause(createClause) +} + +func ReturningDistinct(elements ...cypher.SyntaxNode) *cypher.Return { + returnCriteria := Returning(elements...) + returnCriteria.Projection.Distinct = true + + return returnCriteria +} + +func Returning(elements ...cypher.SyntaxNode) *cypher.Return { + projection := &cypher.Projection{} + + for _, element := range elements { + switch typedElement := element.(type) { + case *cypher.Order: + projection.Order = typedElement + + case *cypher.Limit: + projection.Limit = typedElement + + case *cypher.Skip: + projection.Skip = typedElement + + default: + projection.Items = append(projection.Items, &cypher.ProjectionItem{ + Expression: element, + }) + } + } + + return &cypher.Return{ + Projection: projection, + } +} + +func Size(expression cypher.SyntaxNode) *cypher.FunctionInvocation { + return cypher.NewSimpleFunctionInvocation("size", expression) +} + +func Not(expression cypher.SyntaxNode) *cypher.Negation { + return &cypher.Negation{ + Expression: &cypher.Parenthetical{ + Expression: expression, + }, + } +} + +func IsNull(reference cypher.SyntaxNode) *cypher.Comparison { + return cypher.NewComparison(reference, cypher.OperatorIs, Literal(nil)) +} + +func IsNotNull(reference 
cypher.SyntaxNode) *cypher.Comparison { + return cypher.NewComparison(reference, cypher.OperatorIsNot, Literal(nil)) +} + +func GetFirstReadingClause(query *cypher.RegularQuery) *cypher.ReadingClause { + if query.SingleQuery != nil && query.SingleQuery.SinglePartQuery != nil { + readingClauses := query.SingleQuery.SinglePartQuery.ReadingClauses + + if len(readingClauses) > 0 { + return readingClauses[0] + } + } + + return nil +} + +func SinglePartQuery(expressions ...cypher.SyntaxNode) *cypher.RegularQuery { + var ( + singlePartQuery = &cypher.SinglePartQuery{} + query = &cypher.RegularQuery{ + SingleQuery: &cypher.SingleQuery{ + SinglePartQuery: singlePartQuery, + }, + } + ) + + for _, expression := range expressions { + switch typedExpression := expression.(type) { + case *cypher.Where: + if firstReadingClause := GetFirstReadingClause(query); firstReadingClause != nil { + firstReadingClause.Match.Where = typedExpression + } else { + singlePartQuery.AddReadingClause(&cypher.ReadingClause{ + Match: &cypher.Match{ + Where: typedExpression, + }, + Unwind: nil, + }) + } + + case *cypher.Return: + singlePartQuery.Return = typedExpression + + case *cypher.Limit: + if singlePartQuery.Return != nil { + singlePartQuery.Return.Projection.Limit = typedExpression + } + + case *cypher.Skip: + if singlePartQuery.Return != nil { + singlePartQuery.Return.Projection.Skip = typedExpression + } + + case *cypher.Order: + if singlePartQuery.Return != nil { + singlePartQuery.Return.Projection.Order = typedExpression + } + + case *cypher.UpdatingClause: + singlePartQuery.AddUpdatingClause(typedExpression) + + case []*cypher.UpdatingClause: + for _, updatingClause := range typedExpression { + singlePartQuery.AddUpdatingClause(updatingClause) + } + + default: + singlePartQuery.AddError(fmt.Errorf("invalid type for dawgs query: %T %+v", expression, expression)) + } + } + + return query +} + +func EmptySinglePartQuery() *cypher.RegularQuery { + return &cypher.RegularQuery{ + SingleQuery: 
&cypher.SingleQuery{ + SinglePartQuery: &cypher.SinglePartQuery{}, + }, + } +} diff --git a/query/rewrite.go b/database/v1compat/query/rewrite.go similarity index 100% rename from query/rewrite.go rename to database/v1compat/query/rewrite.go diff --git a/query/sort.go b/database/v1compat/query/sort.go similarity index 87% rename from query/sort.go rename to database/v1compat/query/sort.go index 213db1d..fe8b8bf 100644 --- a/query/sort.go +++ b/database/v1compat/query/sort.go @@ -2,7 +2,6 @@ package query import ( "github.com/specterops/dawgs/cypher/models/cypher" - "github.com/specterops/dawgs/graph" ) type SortDirection string @@ -11,14 +10,14 @@ const SortDirectionAscending SortDirection = "asc" const SortDirectionDescending SortDirection = "desc" type SortItem struct { - SortCriteria graph.Criteria + SortCriteria cypher.SyntaxNode Direction SortDirection } type SortItems []SortItem func (s SortItems) FormatCypherOrder() *cypher.Order { - var orderCriteria []graph.Criteria + var orderCriteria []cypher.SyntaxNode for _, sortItem := range s { switch sortItem.Direction { diff --git a/database/v1compat/relationship.go b/database/v1compat/relationship.go new file mode 100644 index 0000000..b34e2cd --- /dev/null +++ b/database/v1compat/relationship.go @@ -0,0 +1,290 @@ +package v1compat + +import ( + "context" + "fmt" + + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/query" +) + +type relationshipQuery struct { + ctx context.Context + driver database.Driver + builder query.QueryBuilder +} + +func newRelationshipQuery(ctx context.Context, driver database.Driver) RelationshipQuery { + return &relationshipQuery{ + ctx: ctx, + driver: driver, + builder: query.New(), + } +} + +func (s relationshipQuery) exec() (Result, error) { + if builtQuery, err := s.builder.Build(); err != nil { + return nil, err + } else { + result := s.driver.Exec(s.ctx, builtQuery.Query, builtQuery.Parameters) + return wrapResult(s.ctx, 
result, s.driver.Mapper()), nil + } +} + +func (s relationshipQuery) Filter(criteria Criteria) RelationshipQuery { + s.builder.Where(criteria) + return s +} + +func (s relationshipQuery) Filterf(criteriaDelegate CriteriaProvider) RelationshipQuery { + s.builder.Where(criteriaDelegate()) + return s +} + +func (s relationshipQuery) Update(properties *graph.Properties) error { + s.builder.Update(query.Relationship().SetProperties(properties.MapOrEmpty())) + + if result, err := s.exec(); err != nil { + return err + } else { + result.Close() + return nil + } +} + +func (s relationshipQuery) Delete() error { + s.builder.Delete(query.Relationship()) + + if result, err := s.exec(); err != nil { + return err + } else { + result.Close() + return nil + } +} + +func (s relationshipQuery) OrderBy(criteria ...Criteria) RelationshipQuery { + s.builder.OrderBy(criteria...) + return s +} + +func (s relationshipQuery) Offset(skip int) RelationshipQuery { + s.builder.Skip(skip) + return s +} + +func (s relationshipQuery) Limit(limit int) RelationshipQuery { + s.builder.Limit(limit) + return s +} + +func (s relationshipQuery) Count() (int64, error) { + s.builder.Return(query.Relationship().Count()) + + if result, err := s.exec(); err != nil { + return 0, err + } else { + defer result.Close() + + if result.Next() { + var count int64 + + if err := result.Scan(&count); err != nil { + return 0, err + } + + return count, result.Error() + } + + if result.Error() != nil { + return 0, result.Error() + } + + return 0, ErrNoResultsFound + } +} + +func (s relationshipQuery) First() (*graph.Relationship, error) { + s.builder.Return(query.Relationship()).Limit(1) + + if result, err := s.exec(); err != nil { + return nil, err + } else { + defer result.Close() + + if result.Next() { + var relationship graph.Relationship + + if err := result.Scan(&relationship); err != nil { + return nil, err + } + + return &relationship, result.Error() + } + + if result.Error() != nil { + return nil, result.Error() 
+ } + + return nil, ErrNoResultsFound + } +} + +func (s relationshipQuery) Query(delegate func(results Result) error, finalCriteria ...any) error { + s.builder.Return(finalCriteria...) + + if result, err := s.exec(); err != nil { + return err + } else { + defer result.Close() + return delegate(result) + } +} + +func (s relationshipQuery) Fetch(delegate func(cursor Cursor[*graph.Relationship]) error) error { + s.builder.Return(query.Relationship()) + + if builtQuery, err := s.builder.Build(); err != nil { + return err + } else { + resultIter := NewResultIterator(s.ctx, s.driver.Exec(s.ctx, builtQuery.Query, builtQuery.Parameters), func(result database.Result) (*graph.Relationship, error) { + var ( + relationship graph.Relationship + err = result.Scan(&relationship) + ) + + return &relationship, err + }) + + defer resultIter.Close() + return delegate(resultIter) + } +} + +func (s relationshipQuery) FetchDirection(direction graph.Direction, delegate func(cursor Cursor[DirectionalResult]) error) error { + switch direction { + case DirectionOutbound: + s.builder.Return(query.Relationship(), query.Start()) + case DirectionInbound: + s.builder.Return(query.Relationship(), query.End()) + default: + return fmt.Errorf("unsupported direction: %v", direction) + } + + if builtQuery, err := s.builder.Build(); err != nil { + return err + } else { + resultIter := NewResultIterator(s.ctx, s.driver.Exec(s.ctx, builtQuery.Query, builtQuery.Parameters), func(result database.Result) (DirectionalResult, error) { + var ( + relationship graph.Relationship + node graph.Node + ) + + if err := result.Scan(&relationship, &node); err != nil { + return DirectionalResult{}, err + } + + return DirectionalResult{ + Direction: direction, + Relationship: &relationship, + Node: &node, + }, nil + }) + + defer resultIter.Close() + return delegate(resultIter) + } +} + +func (s relationshipQuery) FetchIDs(delegate func(cursor Cursor[graph.ID]) error) error { + s.builder.Return(query.Relationship().ID()) 
+ + if builtQuery, err := s.builder.Build(); err != nil { + return err + } else { + resultIter := NewResultIterator(s.ctx, s.driver.Exec(s.ctx, builtQuery.Query, builtQuery.Parameters), func(result database.Result) (graph.ID, error) { + var nodeID graph.ID + return nodeID, result.Scan(&nodeID) + }) + + defer resultIter.Close() + return delegate(resultIter) + } +} + +func (s relationshipQuery) FetchTriples(delegate func(cursor Cursor[RelationshipTripleResult]) error) error { + s.builder.Return(query.Start().ID(), query.End().ID(), query.Relationship().ID()) + + if builtQuery, err := s.builder.Build(); err != nil { + return err + } else { + resultIter := NewResultIterator(s.ctx, s.driver.Exec(s.ctx, builtQuery.Query, builtQuery.Parameters), func(result database.Result) (RelationshipTripleResult, error) { + var ( + startID graph.ID + endID graph.ID + edgeID graph.ID + err = result.Scan(&startID, &endID, &edgeID) + ) + + return RelationshipTripleResult{ + StartID: startID, + EndID: endID, + ID: edgeID, + }, err + }) + + defer resultIter.Close() + return delegate(resultIter) + } +} + +func (s relationshipQuery) FetchAllShortestPaths(delegate func(cursor Cursor[graph.Path]) error) error { + s.builder.WithAllShortestPaths().Return(query.Path()) + + if builtQuery, err := s.builder.Build(); err != nil { + return err + } else { + resultIter := NewResultIterator(s.ctx, s.driver.Exec(s.ctx, builtQuery.Query, builtQuery.Parameters), func(result database.Result) (graph.Path, error) { + var ( + path graph.Path + err = result.Scan(&path) + ) + + return path, err + }) + + defer resultIter.Close() + return delegate(resultIter) + } +} + +func (s relationshipQuery) FetchKinds(delegate func(cursor Cursor[RelationshipKindsResult]) error) error { + s.builder.Return(query.Start().ID(), query.End().ID(), query.Relationship().ID(), query.Relationship().Kind()) + + if builtQuery, err := s.builder.Build(); err != nil { + return err + } else { + resultIter := NewResultIterator(s.ctx, 
s.driver.Exec(s.ctx, builtQuery.Query, builtQuery.Parameters), func(result database.Result) (RelationshipKindsResult, error) { + var ( + startID graph.ID + endID graph.ID + edgeID graph.ID + edgeKind graph.Kind + err = result.Scan(&startID, &endID, &edgeID, &edgeKind) + ) + + return RelationshipKindsResult{ + RelationshipTripleResult: RelationshipTripleResult{ + StartID: startID, + EndID: endID, + ID: edgeID, + }, + Kind: edgeKind, + }, err + }) + + defer resultIter.Close() + return delegate(resultIter) + } +} diff --git a/graph/result.go b/database/v1compat/result.go similarity index 73% rename from graph/result.go rename to database/v1compat/result.go index e679caf..32e11d2 100644 --- a/graph/result.go +++ b/database/v1compat/result.go @@ -1,35 +1,37 @@ -package graph +package v1compat import ( "context" "errors" + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/graph" "github.com/specterops/dawgs/util/channels" ) type KindsResult struct { - ID ID - Kinds Kinds + ID graph.ID + Kinds graph.Kinds } type RelationshipTripleResult struct { - ID ID - StartID ID - EndID ID + ID graph.ID + StartID graph.ID + EndID graph.ID } type RelationshipKindsResult struct { RelationshipTripleResult - Kind Kind + Kind graph.Kind } type DirectionalResult struct { - Direction Direction - Relationship *Relationship - Node *Node + Direction graph.Direction + Relationship *graph.Relationship + Node *graph.Node } -func NewDirectionalResult(direction Direction, relationship *Relationship, node *Node) DirectionalResult { +func NewDirectionalResult(direction graph.Direction, relationship *graph.Relationship, node *graph.Node) DirectionalResult { return DirectionalResult{ Direction: direction, Relationship: relationship, @@ -51,18 +53,18 @@ type Cursor[T any] interface { Chan() chan T } -type ResultMarshaller[T any] func(scanner Result) (T, error) +type ResultMarshaller[T any] func(scanner database.Result) (T, error) type ResultIterator[T any] struct { ctx context.Context
- result Result + result database.Result cancelFunc func() valueC chan T marshaller ResultMarshaller[T] error error } -func NewResultIterator[T any](ctx context.Context, result Result, marshaller ResultMarshaller[T]) Cursor[T] { +func NewResultIterator[T any](ctx context.Context, result database.Result, marshaller ResultMarshaller[T]) Cursor[T] { var ( cursorCtx, cancelFunc = context.WithCancel(ctx) resultIterator = &ResultIterator[T]{ @@ -82,7 +84,7 @@ func (s *ResultIterator[T]) start() { go func() { defer close(s.valueC) - for s.result.Next() { + for s.result.HasNext(s.ctx) { if nextValue, err := s.marshaller(s.result); err != nil { s.error = err break @@ -108,7 +110,7 @@ func (s *ResultIterator[T]) Error() error { func (s *ResultIterator[T]) Close() { s.cancelFunc() - s.result.Close() + s.result.Close(s.ctx) } func (s *ResultIterator[T]) Chan() chan T { diff --git a/graph/switch.go b/database/v1compat/switch.go similarity index 99% rename from graph/switch.go rename to database/v1compat/switch.go index 2ecfb5a..8f4e127 100644 --- a/graph/switch.go +++ b/database/v1compat/switch.go @@ -1,4 +1,4 @@ -package graph +package v1compat import ( "context" diff --git a/traversal/collection.go b/database/v1compat/traversal/collection.go similarity index 56% rename from traversal/collection.go rename to database/v1compat/traversal/collection.go index dc4c78c..38c7ea2 100644 --- a/traversal/collection.go +++ b/database/v1compat/traversal/collection.go @@ -4,33 +4,33 @@ import ( "context" "sync" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/ops" + "github.com/specterops/dawgs/database/v1compat" + "github.com/specterops/dawgs/database/v1compat/ops" ) type NodeCollector struct { - Nodes graph.NodeSet + Nodes v1compat.NodeSet lock *sync.Mutex } func NewNodeCollector() *NodeCollector { return &NodeCollector{ - Nodes: graph.NewNodeSet(), + Nodes: v1compat.NewNodeSet(), lock: &sync.Mutex{}, } } -func (s *NodeCollector) Collect(next *graph.PathSegment) { 
+func (s *NodeCollector) Collect(next *v1compat.PathSegment) { s.Add(next.Node) } -func (s *NodeCollector) PopulateProperties(ctx context.Context, db graph.Database, propertyNames ...string) error { - return db.ReadTransaction(ctx, func(tx graph.Transaction) error { +func (s *NodeCollector) PopulateProperties(ctx context.Context, db v1compat.Database, propertyNames ...string) error { + return db.ReadTransaction(ctx, func(tx v1compat.Transaction) error { return ops.FetchNodeProperties(tx, s.Nodes, propertyNames) }) } -func (s *NodeCollector) Add(node *graph.Node) { +func (s *NodeCollector) Add(node *v1compat.Node) { s.lock.Lock() defer s.lock.Unlock() @@ -38,7 +38,7 @@ func (s *NodeCollector) Add(node *graph.Node) { } type PathCollector struct { - Paths graph.PathSet + Paths v1compat.PathSet lock *sync.Mutex } @@ -48,13 +48,13 @@ func NewPathCollector() *PathCollector { } } -func (s *PathCollector) PopulateNodeProperties(ctx context.Context, db graph.Database, propertyNames ...string) error { - return db.ReadTransaction(ctx, func(tx graph.Transaction) error { +func (s *PathCollector) PopulateNodeProperties(ctx context.Context, db v1compat.Database, propertyNames ...string) error { + return db.ReadTransaction(ctx, func(tx v1compat.Transaction) error { return ops.FetchNodeProperties(tx, s.Paths.AllNodes(), propertyNames) }) } -func (s *PathCollector) Add(path graph.Path) { +func (s *PathCollector) Add(path v1compat.Path) { s.lock.Lock() defer s.lock.Unlock() diff --git a/database/v1compat/traversal/query.go b/database/v1compat/traversal/query.go new file mode 100644 index 0000000..8fd7c62 --- /dev/null +++ b/database/v1compat/traversal/query.go @@ -0,0 +1,45 @@ +package traversal + +import ( + "github.com/specterops/dawgs/database/v1compat" + "github.com/specterops/dawgs/database/v1compat/query" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/graphcache" +) + +func fetchNodesByIDQuery(tx v1compat.Transaction, ids []v1compat.ID) v1compat.NodeQuery 
{ + return tx.Nodes().Filterf(func() v1compat.Criteria { + return query.InIDs(query.NodeID(), ids...) + }) +} + +func ShallowFetchNodesByID(tx v1compat.Transaction, cache graphcache.Cache, ids []v1compat.ID) ([]*v1compat.Node, error) { + cachedNodes, missingNodeIDs := cache.GetNodes(ids) + + if len(missingNodeIDs) > 0 { + newNodes := make([]*v1compat.Node, 0, len(missingNodeIDs)) + + if err := fetchNodesByIDQuery(tx, missingNodeIDs).FetchKinds(func(cursor v1compat.Cursor[v1compat.KindsResult]) error { + for next := range cursor.Chan() { + newNodes = append(newNodes, v1compat.NewNode(next.ID, nil, next.Kinds...)) + } + + return cursor.Error() + }); err != nil { + return nil, err + } + + // Put the fetched nodes into cache + cache.PutNodes(newNodes) + + // Verify all requested nodes were fetched + if len(newNodes) != len(missingNodeIDs) { + return nil, graph.ErrMissingResultExpectation + } + + // Append them to the end of the nodes being returned + cachedNodes = append(cachedNodes, newNodes...) + } + + return cachedNodes, nil +} diff --git a/database/v1compat/traversal/traversal.go b/database/v1compat/traversal/traversal.go new file mode 100644 index 0000000..9fc6e4b --- /dev/null +++ b/database/v1compat/traversal/traversal.go @@ -0,0 +1,594 @@ +package traversal + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync" + "sync/atomic" + + "github.com/specterops/dawgs/cardinality" + graph "github.com/specterops/dawgs/database/v1compat" + "github.com/specterops/dawgs/database/v1compat/ops" + "github.com/specterops/dawgs/database/v1compat/query" + "github.com/specterops/dawgs/graphcache" + "github.com/specterops/dawgs/util" + "github.com/specterops/dawgs/util/atomics" + "github.com/specterops/dawgs/util/channels" +) + +// Driver is a function that drives sending queries to the graph and retrieving vertexes and edges. Traversal +// drivers are expected to operate on a cactus tree representation of path space using the graph.PathSegment data +// structure. 
Path segments returned by a traversal driver are considered extensions of path space that require +// further expansion. If a traversal driver returns no descending path segments, then the given segment may be +// considered terminal. +type Driver = func(ctx context.Context, tx graph.Transaction, segment *graph.PathSegment) ([]*graph.PathSegment, error) + +type PatternMatchDelegate = func(terminal *graph.PathSegment) error + +// PatternContinuation is an openCypher inspired fluent pattern for defining parallel chained expansions. After +// building the pattern the user may call the Do(...) function and pass it a delegate for handling paths that match +// the pattern. +// +// The return value of the Do(...) function may be passed directly to a Traversal via a Plan as the Plan.Driver field. +type PatternContinuation interface { + Outbound(criteria ...graph.Criteria) PatternContinuation + OutboundWithDepth(min, max int, criteria ...graph.Criteria) PatternContinuation + Inbound(criteria ...graph.Criteria) PatternContinuation + InboundWithDepth(min, max int, criteria ...graph.Criteria) PatternContinuation + Do(delegate PatternMatchDelegate) Driver +} + +// expansion is an internal representation of a path expansion step. +type expansion struct { + criteria []graph.Criteria + direction graph.Direction + minDepth int + maxDepth int +} + +func (s expansion) PrepareCriteria(segment *graph.PathSegment) (graph.Criteria, error) { + var ( + criteria = s.criteria + ) + + switch s.direction { + case graph.DirectionOutbound: + criteria = append([]graph.Criteria{ + query.Equals(query.StartID(), segment.Node.ID), + }, criteria...) + + case graph.DirectionInbound: + criteria = append([]graph.Criteria{ + query.Equals(query.EndID(), segment.Node.ID), + }, criteria...) 
+ + default: + return nil, fmt.Errorf("unsupported direction %v", s.direction) + } + + return query.And(criteria...), nil +} + +type patternTag struct { + patternIdx int + depth int +} + +func popSegmentPatternTag(segment *graph.PathSegment) *patternTag { + var tag *patternTag + + if typedTag, typeOK := segment.Tag.(*patternTag); typeOK && typedTag != nil { + tag = typedTag + segment.Tag = nil + } else { + tag = &patternTag{ + patternIdx: 0, + depth: 0, + } + } + + return tag +} + +type pattern struct { + expansions []expansion + delegate PatternMatchDelegate +} + +// Do assigns the PatterMatchDelegate internally before returning a function pointer to the Driver receiver function. +func (s *pattern) Do(delegate PatternMatchDelegate) Driver { + s.delegate = delegate + return s.Driver +} + +// OutboundWithDepth specifies the next outbound expansion step for this pattern with depth parameters. +func (s *pattern) OutboundWithDepth(min, max int, criteria ...graph.Criteria) PatternContinuation { + if min < 0 { + min = 1 + slog.Warn("Negative mindepth not allowed. Setting min depth for expansion to 1") + } + + if max < 0 { + max = 0 + slog.Warn("Negative maxdepth not allowed. Setting max depth for expansion to 0") + } + + s.expansions = append(s.expansions, expansion{ + criteria: criteria, + direction: graph.DirectionOutbound, + minDepth: min, + maxDepth: max, + }) + + return s +} + +// Outbound specifies the next outbound expansion step for this pattern. By default, this expansion will use a minimum +// depth of 1 to make the expansion required and a maximum depth of 0 to expand indefinitely. +func (s *pattern) Outbound(criteria ...graph.Criteria) PatternContinuation { + return s.OutboundWithDepth(1, 0, criteria...) +} + +// InboundWithDepth specifies the next inbound expansion step for this pattern with depth parameters. 
+func (s *pattern) InboundWithDepth(min, max int, criteria ...graph.Criteria) PatternContinuation { + if min < 0 { + min = 1 + slog.Warn("Negative mindepth not allowed. Setting min depth for expansion to 1") + } + + if max < 0 { + max = 0 + slog.Warn("Negative maxdepth not allowed. Setting max depth for expansion to 0") + } + + s.expansions = append(s.expansions, expansion{ + criteria: criteria, + direction: graph.DirectionInbound, + minDepth: min, + maxDepth: max, + }) + + return s +} + +// Inbound specifies the next inbound expansion step for this pattern. By default, this expansion will use a minimum +// depth of 1 to make the expansion required and a maximum depth of 0 to expand indefinitely. +func (s *pattern) Inbound(criteria ...graph.Criteria) PatternContinuation { + return s.InboundWithDepth(1, 0, criteria...) +} + +// NewPattern returns a new PatternContinuation for building a new pattern. +func NewPattern() PatternContinuation { + return &pattern{} +} + +func (s *pattern) Driver(ctx context.Context, tx graph.Transaction, segment *graph.PathSegment) ([]*graph.PathSegment, error) { + var ( + nextSegments []*graph.PathSegment + + // The patternTag lives on the current terminal segment of each path. Once popped the pointer reference for + // this segment is set to nil. + tag = popSegmentPatternTag(segment) + currentExpansion = s.expansions[tag.patternIdx] + + // fetchFunc handles directional results from the graph database and is called twice to fetch segment + // expansions. 
+ fetchFunc = func(cursor graph.Cursor[graph.DirectionalResult]) error { + for next := range cursor.Chan() { + nextSegment := segment.Descend(next.Node, next.Relationship) + + // Don't emit cycles out of the fetch + if !nextSegment.IsCycle() { + nextSegment.Tag = &patternTag{ + // Use the tag's patternIdx and depth since this is a continuation of the expansions + patternIdx: tag.patternIdx, + depth: tag.depth + 1, + } + + nextSegments = append(nextSegments, nextSegment) + } + } + + return cursor.Error() + } + ) + + // The fetch direction is the reverse intent of the expansion direction + if fetchDirection, err := currentExpansion.direction.Reverse(); err != nil { + return nil, err + } else { + // If no max depth was set or if a max depth was set expand the current step further + if currentExpansion.maxDepth == 0 || tag.depth < currentExpansion.maxDepth { + // Perform the current expansion. + if criteria, err := currentExpansion.PrepareCriteria(segment); err != nil { + return nil, err + } else if err := tx.Relationships().Filter(criteria).FetchDirection(fetchDirection, fetchFunc); err != nil { + return nil, err + } + } + + // Check first if this current segment was fetched using the current expansion (i.e. non-optional) + if tag.depth > 0 && currentExpansion.minDepth == 0 || tag.depth >= currentExpansion.minDepth { + // No further expansions means this pattern segment is complete. Increment the pattern index to select the + // next pattern expansion. Additionally, set the depth back to zero for the tag since we are leaving the + // current expansion. + tag.patternIdx++ + tag.depth = 0 + + // Perform the next expansion if there is one. 
+ if tag.patternIdx < len(s.expansions) { + nextExpansion := s.expansions[tag.patternIdx] + + // Expand the next segments + if criteria, err := nextExpansion.PrepareCriteria(segment); err != nil { + return nil, err + } else if err := tx.Relationships().Filter(criteria).FetchDirection(fetchDirection, fetchFunc); err != nil { + return nil, err + } + + // If the next expansion is optional, make sure to preserve the current traversal branch + if nextExpansion.minDepth == 0 { + // Reattach the tag to the segment before adding it to the returned segments for the next expansion + segment.Tag = tag + nextSegments = append(nextSegments, segment) + } + } else if len(nextSegments) == 0 { + // If there are no expanded segments and there are no remaining expansions, this is a terminal segment. + // Hand it off to the delegate and handle any returned error. + if err := s.delegate(segment); err != nil { + return nil, err + } + } + } + + // If the above condition does not match then this current expansion is non-terminal and non-continuable + } + + // Return any collected segments + return nextSegments, nil +} + +type Plan struct { + Root *graph.Node + RootSegment *graph.PathSegment + Driver Driver +} + +type Traversal struct { + db graph.Database + numWorkers int +} + +func New(db graph.Database, numParallelWorkers int) Traversal { + return Traversal{ + db: db, + numWorkers: numParallelWorkers, + } +} + +func (s Traversal) BreadthFirst(ctx context.Context, plan Plan) error { + var ( + // workerWG keeps count of background workers launched in goroutines + workerWG = &sync.WaitGroup{} + + // descentWG keeps count of in-flight traversal work. When this wait group reaches a count of 0 the traversal + // is considered complete. 
+ completionC = make(chan struct{}, s.numWorkers*2) + descentCount = &atomic.Int64{} + errorCollector = util.NewErrorCollector() + traversalCtx, doneFunc = context.WithCancel(ctx) + segmentWriterC, segmentReaderC = channels.BufferedPipe[*graph.PathSegment](traversalCtx) + pathTree graph.Tree + ) + + // Defer calling the cancellation function of the context to ensure that all workers join, no matter what + defer doneFunc() + + // Close the writer channel to the buffered pipe + defer close(segmentWriterC) + + if plan.Root != nil { + pathTree = graph.NewTree(plan.Root) + } else if plan.RootSegment != nil { + pathTree = graph.Tree{ + Root: plan.RootSegment, + } + } else { + return fmt.Errorf("no root specified") + } + + // Launch the background traversal workers + for workerID := 0; workerID < s.numWorkers; workerID++ { + workerWG.Add(1) + + go func(workerID int) { + defer workerWG.Done() + + if err := s.db.ReadTransaction(ctx, func(tx graph.Transaction) error { + for { + if nextDescent, ok := channels.Receive(traversalCtx, segmentReaderC); !ok { + return nil + } else if tx.GraphQueryMemoryLimit() > 0 && pathTree.SizeOf() > tx.GraphQueryMemoryLimit() { + return fmt.Errorf("%w - Limit: %.2f MB - Memory In-Use: %.2f MB", ops.ErrGraphQueryMemoryLimit, tx.GraphQueryMemoryLimit().Mebibytes(), pathTree.SizeOf().Mebibytes()) + } else { + // Traverse the descending relationships of the current segment + if descendingSegments, err := plan.Driver(traversalCtx, tx, nextDescent); err != nil { + return err + } else { + for _, descendingSegment := range descendingSegments { + // Add to the descent count before submitting to the channel + descentCount.Add(1) + channels.Submit(traversalCtx, segmentWriterC, descendingSegment) + } + } + } + + // Mark descent for this segment as complete + descentCount.Add(-1) + + if !channels.Submit(traversalCtx, completionC, struct{}{}) { + return nil + } + } + }); err != nil && !errors.Is(err, graph.ErrContextTimedOut) && !errors.Is(err, 
context.Canceled) { + // A worker encountered a fatal error, kill the traversal context + doneFunc() + + errorCollector.Add(fmt.Errorf("reader %d failed: %w", workerID, err)) + } + }(workerID) + } + + // Add to the descent wait group and then queue the root of the path tree for traversal + descentCount.Add(1) + if channels.Submit(traversalCtx, segmentWriterC, pathTree.Root) { + for { + if _, ok := channels.Receive(traversalCtx, completionC); !ok || descentCount.Load() == 0 { + break + } + } + } + + // Actively cancel the traversal context to force any idle workers to join and exit + doneFunc() + + // Wait for all workers to exit + workerWG.Wait() + + return errorCollector.Combined() +} + +func newVisitorFilter(direction graph.Direction, userFilter graph.Criteria) func(segment *graph.PathSegment) graph.Criteria { + return func(segment *graph.PathSegment) graph.Criteria { + var filters []graph.Criteria + + if userFilter != nil { + filters = append(filters, userFilter) + } + + switch direction { + case graph.DirectionOutbound: + filters = append(filters, query.Equals(query.StartID(), segment.Node.ID)) + + case graph.DirectionInbound: + filters = append(filters, query.Equals(query.EndID(), segment.Node.ID)) + } + + return query.And(filters...) 
+ } +} + +func shallowFetchRelationships(direction graph.Direction, segment *graph.PathSegment, graphQuery graph.RelationshipQuery) ([]*graph.Relationship, error) { + var ( + relationships []*graph.Relationship + returnCriteria graph.Criteria + ) + + switch direction { + case graph.DirectionOutbound: + returnCriteria = query.Returning( + query.EndID(), + query.KindsOf(query.End()), + query.RelationshipID(), + query.KindsOf(query.Relationship()), + ) + + case graph.DirectionInbound: + returnCriteria = query.Returning( + query.StartID(), + query.KindsOf(query.Start()), + query.RelationshipID(), + query.KindsOf(query.Relationship()), + ) + + default: + return nil, fmt.Errorf("bi-directional or non-directed edges are not supported") + } + + if err := graphQuery.Query(func(results graph.Result) error { + defer results.Close() + + var ( + nodeID graph.ID + nodeKinds graph.Kinds + edgeID graph.ID + edgeKind graph.Kind + ) + + for results.Next() { + if err := results.Scan(&nodeID, &nodeKinds, &edgeID, &edgeKind); err != nil { + return err + } + + switch direction { + case graph.DirectionOutbound: + relationships = append(relationships, graph.NewRelationship(edgeID, segment.Node.ID, nodeID, nil, edgeKind)) + + case graph.DirectionInbound: + relationships = append(relationships, graph.NewRelationship(edgeID, nodeID, segment.Node.ID, nil, edgeKind)) + } + } + + return results.Error() + }, returnCriteria); err != nil { + return nil, err + } + + return relationships, nil +} + +// SegmentFilter is a function type that takes a given path segment and returns true if further descent into the path +// is allowed. +type SegmentFilter = func(next *graph.PathSegment) bool + +// SegmentVisitor is a function that receives a path segment as part of certain traversal strategies. +type SegmentVisitor = func(next *graph.PathSegment) + +// UniquePathSegmentFilter is a SegmentFilter constructor that will allow a traversal to all unique paths. 
This is done +// by tracking edge IDs traversed in a bitmap. +func UniquePathSegmentFilter(delegate SegmentFilter) SegmentFilter { + traversalBitmap := cardinality.ThreadSafeDuplex(cardinality.NewBitmap64()) + + return func(next *graph.PathSegment) bool { + // Bail on cycles + if next.IsCycle() { + return false + } + + // Return if we've seen this edge before + if !traversalBitmap.CheckedAdd(next.Edge.ID.Uint64()) { + return false + } + + // Pass this segment to the delegate if we've never seen it before + return delegate(next) + } +} + +// AcyclicNodeFilter is a SegmentFilter constructor that will allow traversal to a node only once. It will ignore all +// but the first inbound or outbound edge that traverses to it. +func AcyclicNodeFilter(filter SegmentFilter) SegmentFilter { + return func(next *graph.PathSegment) bool { + // Bail on counting ourselves + if next.IsCycle() { + return false + } + + // Descend only if we've never seen this node before. + return filter(next) + } +} + +// A SkipLimitFilter is a function that represents a collection and descent filter for PathSegments. This function must +// return two boolean values: +// +// The first boolean value in the return tuple communicates to the FilteredSkipLimit SegmentFilter if the given +// PathSegment is eligible for collection and therefore should be counted when considering the traversal's skip and +// limit parameters. +// +// The second boolean value in the return tuple communicates to the FilteredSkipLimit SegmentFilter if the given +// PathSegment is eligible for further descent. When this value is true the path will be expanded further during +// traversal. +type SkipLimitFilter = func(next *graph.PathSegment) (bool, bool) + +// FilteredSkipLimit is a SegmentFilter constructor that allows a caller to inform the skip-limit algorithm when a +// result was collected and if the traversal should continue to descend further during traversal. 
+func FilteredSkipLimit(filter SkipLimitFilter, visitorFilter SegmentVisitor, skip, limit int) SegmentFilter { + var ( + shouldCollect = atomics.NewCounter(uint64(skip)) + atLimit = atomics.NewCounter(uint64(limit)) + ) + + return func(next *graph.PathSegment) bool { + canCollect, shouldDescend := filter(next) + + if canCollect { + // Check to see if this result should be skipped + if skip == 0 || shouldCollect() { + // If we should collect this result, check to see if we're already at a limit for the number of results + if limit > 0 && atLimit() { + slog.Debug(fmt.Sprintf("At collection limit, rejecting path: %s", graph.FormatPathSegment(next))) + return false + } + + slog.Debug(fmt.Sprintf("Collected path: %s", graph.FormatPathSegment(next))) + visitorFilter(next) + } else { + slog.Debug(fmt.Sprintf("Skipping path visit: %s", graph.FormatPathSegment(next))) + } + } + + if shouldDescend { + slog.Debug(fmt.Sprintf("Descending into path: %s", graph.FormatPathSegment(next))) + } else { + slog.Debug(fmt.Sprintf("Rejecting further descent into path: %s", graph.FormatPathSegment(next))) + } + + return shouldDescend + } +} + +// LightweightDriver is a Driver constructor that fetches only IDs and Kind information from vertexes and +// edges stored in the database. This cuts down on network transit and is appropriate for traversals that may involve +// a large number of or all vertexes within a target graph. 
+func LightweightDriver(direction graph.Direction, cache graphcache.Cache, criteria graph.Criteria, filter SegmentFilter, terminalVisitors ...SegmentVisitor) Driver { + filterProvider := newVisitorFilter(direction, criteria) + + return func(ctx context.Context, tx graph.Transaction, nextSegment *graph.PathSegment) ([]*graph.PathSegment, error) { + var ( + nextSegments []*graph.PathSegment + nextQuery = tx.Relationships().Filter(filterProvider(nextSegment)).OrderBy( + // Order by relationship ID so that skip and limit behave somewhat predictably - cost of this is pretty + // small even for large result sets + query.Order(query.Identity(query.Relationship()), query.Ascending()), + ) + ) + + if relationships, err := shallowFetchRelationships(direction, nextSegment, nextQuery); err != nil { + return nil, err + } else { + // Reconcile the start and end nodes of the fetched relationships with the graph cache + nodesToFetch := cardinality.NewBitmap64() + + for _, nextRelationship := range relationships { + if nextID, err := direction.PickReverse(nextRelationship); err != nil { + return nil, err + } else { + nodesToFetch.Add(nextID.Uint64()) + } + } + + // Shallow fetching the nodes achieves the same result as shallowFetchRelationships(...) but with the added + // benefit of interacting with the graph cache. Any nodes not already in the cache are fetched just-in-time + // from the database and stored back in the cache for later. + if cachedNodes, err := ShallowFetchNodesByID(tx, cache, graph.DuplexToGraphIDs(nodesToFetch)); err != nil { + return nil, err + } else { + cachedNodeSet := graph.NewNodeSet(cachedNodes...) 
+ + for _, nextRelationship := range relationships { + if targetID, err := direction.PickReverse(nextRelationship); err != nil { + return nil, err + } else { + nextSegment := nextSegment.Descend(cachedNodeSet[targetID], nextRelationship) + + if filter(nextSegment) { + nextSegments = append(nextSegments, nextSegment) + } + } + } + } + } + + // If this segment has no further descent paths, render it as a path if we have a path visitor specified + if len(nextSegments) == 0 && len(terminalVisitors) > 0 { + for _, terminalVisitor := range terminalVisitors { + terminalVisitor(nextSegment) + } + } + + return nextSegments, nil + } +} diff --git a/database/v1compat/traversal/traversal_test.go b/database/v1compat/traversal/traversal_test.go new file mode 100644 index 0000000..3038165 --- /dev/null +++ b/database/v1compat/traversal/traversal_test.go @@ -0,0 +1,93 @@ +package traversal + +import ( + "testing" + + graph "github.com/specterops/dawgs/database/v1compat" + "github.com/stretchr/testify/require" +) + +var ( + kindA = graph.StringKind("a") + kindB = graph.StringKind("b") + kindR = graph.StringKind("r") + + node0 = graph.NewNode(0, nil, kindA) + node1 = graph.NewNode(1, nil, kindB) + node2 = graph.NewNode(2, nil, kindB) + node3 = graph.NewNode(3, nil, kindB) + + root = graph.NewRootPathSegment(node0) + + // node1Segment: (node0) <-[kindR]- (node1) + node1Segment = root.Descend(node1, graph.NewRelationship(100, 0, 1, nil, kindR)) + + // node2Segment: (node0) <-[kindR]- (node2) + node2Segment = root.Descend(node2, graph.NewRelationship(101, 0, 2, nil, kindR)) + + // node1Node3Segment: (node0) <-[kindR]- (node1) <-[kindR]- (node3) + node1Node3Segment = node1Segment.Descend(node3, graph.NewRelationship(102, 1, 3, nil, kindR)) + + // node2Node3Segment: (node0) <-[kindR]- (node2) <-[kindR]- (node3) + node2Node3Segment = node2Segment.Descend(node3, graph.NewRelationship(103, 2, 3, nil, kindR)) + + // cycleSegment: (node0) <-[kindR]- (node1) <-[kindR]- (node0) + cycleSegment = 
node1Segment.Descend(node0, graph.NewRelationship(104, 1, 0, nil, kindR)) +) + +func TestAcyclicSegmentVisitor(t *testing.T) { + visitor := AcyclicNodeFilter(func(next *graph.PathSegment) bool { + return true + }) + + // Disallow cycles + require.False(t, visitor(cycleSegment)) +} + +func TestUniquePathSegmentVisitor(t *testing.T) { + visitor := UniquePathSegmentFilter(func(next *graph.PathSegment) bool { + return true + }) + + // Visiting the segment for the first time should pass + require.True(t, visitor(node1Node3Segment)) + + // Allow traversal to the same node via different paths + require.True(t, visitor(node2Node3Segment)) + + // Disallow retraversal of the same path + require.False(t, visitor(node2Node3Segment)) + + // Disallow cycles + require.False(t, visitor(cycleSegment)) +} + +func TestFilteredSkipLimit(t *testing.T) { + var nodes []*graph.Node + + visitor := FilteredSkipLimit( + func(next *graph.PathSegment) (bool, bool) { + return next.Node.ID == 3, next.Node.ID == 3 + }, + func(next *graph.PathSegment) { + nodes = append(nodes, next.Node) + }, + 1, + 1) + + // Skip and descend + require.True(t, visitor(node1Node3Segment)) + + // Reject descent of node that doesn't match + require.False(t, visitor(node1Segment)) + + // Collect and descend + require.True(t, visitor(node2Node3Segment)) + + // At limit, reject descent + require.False(t, visitor(node1Node3Segment)) + + // Validate that we've collected exactly one node + require.Equal(t, 1, len(nodes)) + require.Equal(t, nodes[0].ID, graph.ID(3)) +} diff --git a/database/v1compat/types.go b/database/v1compat/types.go new file mode 100644 index 0000000..0f1e8bd --- /dev/null +++ b/database/v1compat/types.go @@ -0,0 +1,193 @@ +package v1compat + +import ( + "fmt" + + "github.com/specterops/dawgs/cardinality" + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/util/size" +) + +// IndexedSlice is a structure maps a comparable key to a value that implements size.Sizable. 
+type IndexedSlice[K comparable, V any] struct { + index map[K]int + values []V + size size.Size +} + +func NewIndexedSlice[K comparable, V any]() *IndexedSlice[K, V] { + return &IndexedSlice[K, V]{ + index: make(map[K]int), + size: 0, + } +} + +func (s *IndexedSlice[K, V]) Keys() []K { + keys := make([]K, 0, len(s.index)) + + for key := range s.index { + keys = append(keys, key) + } + + return keys +} + +func (s *IndexedSlice[K, V]) Values() []V { + return s.values +} + +func (s *IndexedSlice[K, V]) Merge(other *IndexedSlice[K, V]) { + for key, idx := range other.index { + s.Put(key, other.values[idx]) + } +} + +// Len returns the number of values stored. +func (s *IndexedSlice[K, V]) Len() int { + return len(s.values) +} + +// SizeOf returns the relative size of the IndexedSlice instance. +func (s *IndexedSlice[K, V]) SizeOf() size.Size { + return s.size +} + +func (s *IndexedSlice[K, V]) Get(key K) V { + if valueIdx, hasValue := s.index[key]; hasValue { + return s.values[valueIdx] + } + + var empty V + return empty +} + +func (s *IndexedSlice[K, V]) Has(key K) bool { + _, hasValue := s.index[key] + return hasValue +} + +func (s *IndexedSlice[K, V]) GetOr(key K, defaultConstructor func() V) V { + if valueIdx, hasValue := s.index[key]; hasValue { + return s.values[valueIdx] + } + + defaultValue := defaultConstructor() + + s.Put(key, defaultValue) + return defaultValue +} + +// CheckedGet returns a tuple containing the value and a boolean representing if a value was found for the +// given key. +func (s *IndexedSlice[K, V]) CheckedGet(key K) (V, bool) { + if valueIdx, hasValue := s.index[key]; hasValue { + return s.values[valueIdx], true + } + + var empty V + return empty, false +} + +// GetAll returns all found values for a given slice of keys. Any keys that do not have stored values +// in this IndexedSlice are returned as the second value of the tuple return for this function. 
+func (s *IndexedSlice[K, V]) GetAll(keys []K) ([]V, []K) { + var ( + values = make([]V, 0, len(keys)) + missingKeys = make([]K, 0, len(keys)) + ) + + for _, key := range keys { + if valueIdx, hasValue := s.index[key]; hasValue { + values = append(values, s.values[valueIdx]) + } else { + missingKeys = append(missingKeys, key) + } + } + + return values, missingKeys +} + +// GetAllIndexed returns all found values for a given slice of keys. Any keys that do not have stored values +// in this IndexedSlice are returned as the second value of the tuple return for this function. +func (s *IndexedSlice[K, V]) GetAllIndexed(keys []K) (*IndexedSlice[K, V], []K) { + var ( + values = NewIndexedSlice[K, V]() + missingKeys = make([]K, 0, len(keys)) + ) + + for _, key := range keys { + if valueIdx, hasValue := s.index[key]; hasValue { + values.Put(key, s.values[valueIdx]) + } else { + missingKeys = append(missingKeys, key) + } + } + + return values, missingKeys +} + +func sizeOf(value any) size.Size { + if sizeable, typeOK := value.(size.Sizable); typeOK { + return sizeable.SizeOf() + } + + return size.Of(value) +} + +// Put inserts the given value with the given key. +func (s *IndexedSlice[K, V]) Put(key K, value V) { + s.size += sizeOf(value) + + if valueIdx, hasValue := s.index[key]; hasValue { + s.size -= sizeOf(s.values[valueIdx]) + s.values[valueIdx] = value + } else { + s.values = append(s.values, value) + s.index[key] = len(s.values) - 1 + } +} + +func (s *IndexedSlice[K, V]) Each(delegate func(key K, value V) bool) { + for id, idx := range s.index { + if !delegate(id, s.values[idx]) { + break + } + } +} + +// DuplexToGraphIDs takes a Duplex provider and returns a slice of graph IDs. 
+func DuplexToGraphIDs[T uint32 | uint64](provider cardinality.Duplex[T]) []ID { + ids := make([]ID, 0, provider.Cardinality()) + + provider.Each(func(value T) bool { + ids = append(ids, ID(value)) + return true + }) + + return ids + +} + +func AnyToV2DB(instance any) (database.Instance, error) { + switch typedInstance := instance.(type) { + case *v1Wrapper: + return typedInstance.v2DB, nil + case database.Instance: + return typedInstance, nil + default: + return nil, fmt.Errorf("unsupported instance type: %T", instance) + } +} + +func AnyToV2Driver(session any) (database.Driver, error) { + switch typedSession := session.(type) { + case *driverTransactionWrapper: + return typedSession.driver, nil + case *driverBatchWrapper: + return typedSession.driver, nil + case database.Driver: + return typedSession, nil + default: + return nil, fmt.Errorf("unsupported session type: %T", session) + } +} diff --git a/database/v1compat/wrapper.go b/database/v1compat/wrapper.go new file mode 100644 index 0000000..a652715 --- /dev/null +++ b/database/v1compat/wrapper.go @@ -0,0 +1,434 @@ +package v1compat + +import ( + "context" + "fmt" + "time" + + "github.com/specterops/dawgs/cypher/frontend" + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/query" + "github.com/specterops/dawgs/util/size" +) + +// typedSliceToAnySlice converts a given slice to a slice of type []any +func typedSliceToAnySlice[T any](slice []T) any { + anyCopy := make([]any, len(slice)) + + for idx := 0; idx < len(slice); idx++ { + anyCopy[idx] = slice[idx] + } + + return anyCopy +} + +// downcastPropertyFields is a dawgs version 1 compatibility tool that emulates the JSON encode/decode pass for an +// entity's properties. This is done via a type cast to avoid the compute cost of actually running the JSON +// encode/decode. 
+func downcastPropertyFields(props *Properties) { + for key, value := range props.MapOrEmpty() { + switch typedValue := value.(type) { + case []string: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []time.Time: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []bool: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []uint: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []uint8: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []uint16: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []uint32: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []uint64: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []int: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []int8: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []int16: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []int32: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []int64: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []float32: + props.Map[key] = typedSliceToAnySlice(typedValue) + case []float64: + props.Map[key] = typedSliceToAnySlice(typedValue) + } + } +} + +type BackwardCompatibleInstance interface { + database.Instance + + RefreshKinds(ctx context.Context) error + Raw(ctx context.Context, query string, parameters map[string]any) error +} + +type BackwardCompatibleDriver interface { + database.Driver + + UpdateNodes(ctx context.Context, batch []graph.NodeUpdate) error + UpdateRelationships(ctx context.Context, batch []graph.RelationshipUpdate) error + CreateNodes(ctx context.Context, batch []*Node) error + CreateRelationships(ctx context.Context, batch []*Relationship) error + DeleteNodes(ctx context.Context, batch []graph.ID) error + DeleteRelationships(ctx context.Context, batch []graph.ID) error +} + +type v1Wrapper struct { + schema *database.Schema + v2DB BackwardCompatibleInstance + writeFlushSize int + batchWriteSize int +} + +func V1Wrapper(v2DB 
database.Instance) Database { + if v1CompatibleInstanceRef, typeOK := v2DB.(BackwardCompatibleInstance); !typeOK { + panic(fmt.Sprintf("type %T is not a v1CompatibleInstance", v2DB)) + } else { + return &v1Wrapper{ + v2DB: v1CompatibleInstanceRef, + } + } +} + +func (s *v1Wrapper) V2() database.Instance { + return s.v2DB +} + +func (s *v1Wrapper) ReadTransaction(ctx context.Context, txDelegate TransactionDelegate, options ...TransactionOption) error { + return s.v2DB.Session(ctx, func(ctx context.Context, driver database.Driver) error { + return txDelegate(wrapDriverToTransaction(ctx, driver)) + }, database.OptionReadOnly) +} + +func (s *v1Wrapper) WriteTransaction(ctx context.Context, txDelegate TransactionDelegate, options ...TransactionOption) error { + return s.v2DB.Session(ctx, func(ctx context.Context, driver database.Driver) error { + return txDelegate(wrapDriverToTransaction(ctx, driver)) + }) +} + +func (s *v1Wrapper) BatchOperation(ctx context.Context, batchDelegate BatchDelegate) error { + return s.v2DB.Session(ctx, func(ctx context.Context, driver database.Driver) error { + var ( + batchWrapper = wrapDriverToBatch(ctx, driver) + delegateErr = batchDelegate(batchWrapper) + ) + + if delegateErr != nil { + return delegateErr + } + + return batchWrapper.tryFlush(0) + }) +} + +func (s *v1Wrapper) AssertSchema(ctx context.Context, dbSchema database.Schema) error { + s.schema = &dbSchema + return s.v2DB.AssertSchema(ctx, dbSchema) +} + +func (s *v1Wrapper) SetDefaultGraph(ctx context.Context, graphSchema database.Graph) error { + if s.schema != nil { + s.schema.GraphSchemas[graphSchema.Name] = graphSchema + s.schema.DefaultGraphName = graphSchema.Name + } else { + s.schema = &database.Schema{ + GraphSchemas: map[string]database.Graph{ + graphSchema.Name: graphSchema, + }, + DefaultGraphName: graphSchema.Name, + } + } + + return s.v2DB.AssertSchema(ctx, *s.schema) +} + +func (s *v1Wrapper) Run(ctx context.Context, query string, parameters map[string]any) error 
{ + return s.v2DB.Raw(ctx, query, parameters) +} + +func (s *v1Wrapper) Close(ctx context.Context) error { + return s.v2DB.Close(ctx) +} + +func (s *v1Wrapper) FetchKinds(ctx context.Context) (graph.Kinds, error) { + return s.v2DB.FetchKinds(ctx) +} + +func (s *v1Wrapper) RefreshKinds(ctx context.Context) error { + return s.v2DB.RefreshKinds(ctx) +} + +func (s *v1Wrapper) SetWriteFlushSize(interval int) { + s.writeFlushSize = interval +} + +func (s *v1Wrapper) SetBatchWriteSize(interval int) { + s.batchWriteSize = interval +} + +type driverBatchWrapper struct { + ctx context.Context + driver BackwardCompatibleDriver + + nodeDeletionBuffer []graph.ID + relationshipDeletionBuffer []graph.ID + nodeCreateBuffer []*graph.Node + nodeUpdateByBuffer []graph.NodeUpdate + relationshipCreateBuffer []*graph.Relationship + relationshipUpdateByBuffer []graph.RelationshipUpdate + batchWriteSize int +} + +func wrapDriverToBatch(ctx context.Context, driver database.Driver) *driverBatchWrapper { + if v1CompatibleDriverRef, typeOK := driver.(BackwardCompatibleDriver); !typeOK { + panic(fmt.Sprintf("type %T is not a v1CompatibleDriver", driver)) + } else { + return &driverBatchWrapper{ + ctx: ctx, + driver: v1CompatibleDriverRef, + batchWriteSize: 2000, + } + } +} + +func (s *driverBatchWrapper) tryFlush(batchWriteSize int) error { + if len(s.nodeUpdateByBuffer) >= batchWriteSize { + if err := s.driver.UpdateNodes(s.ctx, s.nodeUpdateByBuffer); err != nil { + return err + } + + s.nodeUpdateByBuffer = s.nodeUpdateByBuffer[:0] + } + + if len(s.relationshipUpdateByBuffer) >= batchWriteSize { + if err := s.driver.UpdateRelationships(s.ctx, s.relationshipUpdateByBuffer); err != nil { + return err + } + + s.relationshipUpdateByBuffer = s.relationshipUpdateByBuffer[:0] + } + + if len(s.relationshipCreateBuffer) >= batchWriteSize { + if err := s.driver.CreateRelationships(s.ctx, s.relationshipCreateBuffer); err != nil { + return err + } + + s.relationshipCreateBuffer = 
s.relationshipCreateBuffer[:0] + } + + if len(s.nodeCreateBuffer) >= batchWriteSize { + if err := s.driver.CreateNodes(s.ctx, s.nodeCreateBuffer); err != nil { + return err + } + + s.nodeCreateBuffer = s.nodeCreateBuffer[:0] + } + + if len(s.nodeDeletionBuffer) >= batchWriteSize { + if err := s.driver.DeleteNodes(s.ctx, s.nodeDeletionBuffer); err != nil { + return err + } + + s.nodeDeletionBuffer = s.nodeDeletionBuffer[:0] + } + + if len(s.relationshipDeletionBuffer) >= batchWriteSize { + if err := s.driver.DeleteRelationships(s.ctx, s.relationshipDeletionBuffer); err != nil { + return err + } + + s.relationshipDeletionBuffer = s.relationshipDeletionBuffer[:0] + } + + return nil +} + +func (s *driverBatchWrapper) WithGraph(graphSchema database.Graph) Batch { + s.driver.WithGraph(graphSchema) + return s +} + +func (s *driverBatchWrapper) CreateNode(node *graph.Node) error { + _, err := s.driver.CreateNode(s.ctx, node) + return err +} + +func (s *driverBatchWrapper) DeleteNode(id graph.ID) error { + s.nodeDeletionBuffer = append(s.nodeDeletionBuffer, id) + return s.tryFlush(s.batchWriteSize) +} + +func (s *driverBatchWrapper) Nodes() NodeQuery { + return newNodeQuery(s.ctx, s.driver) +} + +func (s *driverBatchWrapper) Relationships() RelationshipQuery { + return newRelationshipQuery(s.ctx, s.driver) +} + +func (s *driverBatchWrapper) UpdateNodeBy(update graph.NodeUpdate) error { + s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update) + return s.tryFlush(s.batchWriteSize) +} + +func (s *driverBatchWrapper) CreateRelationship(relationship *graph.Relationship) error { + s.relationshipCreateBuffer = append(s.relationshipCreateBuffer, relationship) + return s.tryFlush(s.batchWriteSize) +} + +func (s *driverBatchWrapper) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) error { + return s.CreateRelationship(&graph.Relationship{ + StartID: startNodeID, + EndID: endNodeID, + Kind: kind, + Properties: properties, + 
}) +} + +func (s *driverBatchWrapper) DeleteRelationship(id graph.ID) error { + s.relationshipDeletionBuffer = append(s.relationshipDeletionBuffer, id) + return s.tryFlush(s.batchWriteSize) +} + +func (s *driverBatchWrapper) UpdateRelationshipBy(update graph.RelationshipUpdate) error { + s.relationshipUpdateByBuffer = append(s.relationshipUpdateByBuffer, update) + return s.tryFlush(s.batchWriteSize) +} + +func (s *driverBatchWrapper) Commit() error { + return s.tryFlush(0) +} + +type driverTransactionWrapper struct { + ctx context.Context + driver BackwardCompatibleDriver +} + +func (s driverTransactionWrapper) GraphQueryMemoryLimit() size.Size { + return size.Gibibyte +} + +func wrapDriverToTransaction(ctx context.Context, driver database.Driver) Transaction { + if v1CompatibleDriverRef, typeOK := driver.(BackwardCompatibleDriver); !typeOK { + panic(fmt.Sprintf("type %T is not a v1CompatibleDriver", driver)) + } else { + return &driverTransactionWrapper{ + ctx: ctx, + driver: v1CompatibleDriverRef, + } + } +} + +func (s driverTransactionWrapper) WithGraph(graphSchema database.Graph) Transaction { + s.driver.WithGraph(graphSchema) + return s +} + +func (s driverTransactionWrapper) CreateNode(properties *graph.Properties, kinds ...graph.Kind) (*graph.Node, error) { + newNode := graph.PrepareNode(properties, kinds...) 
+ + if nodeID, err := s.driver.CreateNode(s.ctx, newNode); err != nil { + return nil, err + } else { + newNode.ID = nodeID + downcastPropertyFields(newNode.Properties) + + return newNode, nil + } +} + +func (s driverTransactionWrapper) UpdateNode(node *graph.Node) error { + updateQuery := query.New().Where(query.Node().ID().Equals(node.ID)) + + if len(node.AddedKinds) > 0 { + updateQuery.Update(query.Node().Kinds().Add(node.AddedKinds)) + } + + if len(node.DeletedKinds) > 0 { + updateQuery.Update(query.Node().Kinds().Remove(node.DeletedKinds)) + } + + if modifiedProperties := node.Properties.ModifiedProperties(); len(modifiedProperties) > 0 { + updateQuery.Update(query.Node().SetProperties(modifiedProperties)) + } + + if deletedProperties := node.Properties.DeletedProperties(); len(deletedProperties) > 0 { + updateQuery.Update(query.Node().RemoveProperties(deletedProperties)) + } + + if buildQuery, err := updateQuery.Build(); err != nil { + return err + } else { + result := s.driver.Exec(s.ctx, buildQuery.Query, buildQuery.Parameters) + defer result.Close(s.ctx) + + return result.Error() + } +} + +func (s driverTransactionWrapper) Nodes() NodeQuery { + return newNodeQuery(s.ctx, s.driver) +} + +func (s driverTransactionWrapper) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) { + newRelationship := &graph.Relationship{ + StartID: startNodeID, + EndID: endNodeID, + Kind: kind, + Properties: properties, + } + + if newRelationshipID, err := s.driver.CreateRelationship(s.ctx, newRelationship); err != nil { + return nil, err + } else { + newRelationship.ID = newRelationshipID + downcastPropertyFields(newRelationship.Properties) + + return newRelationship, nil + } +} + +func (s driverTransactionWrapper) UpdateRelationship(relationship *graph.Relationship) error { + updateQuery := query.New().Where(query.Relationship().ID().Equals(relationship.ID)) + + if modifiedProperties := 
relationship.Properties.ModifiedProperties(); len(modifiedProperties) > 0 { + updateQuery.Update(query.Relationship().SetProperties(modifiedProperties)) + } + + if deletedProperties := relationship.Properties.DeletedProperties(); len(deletedProperties) > 0 { + updateQuery.Update(query.Relationship().RemoveProperties(deletedProperties)) + } + + if buildQuery, err := updateQuery.Build(); err != nil { + return err + } else { + result := s.driver.Exec(s.ctx, buildQuery.Query, buildQuery.Parameters) + defer result.Close(s.ctx) + + return result.Error() + } +} + +func (s driverTransactionWrapper) Relationships() RelationshipQuery { + return newRelationshipQuery(s.ctx, s.driver) +} + +func (s driverTransactionWrapper) Query(query string, parameters map[string]any) Result { + if cypherQuery, err := frontend.ParseCypher(frontend.NewContext(), query); err != nil { + return NewErrorResult(err) + } else { + return wrapResult(s.ctx, s.driver.Exec(s.ctx, cypherQuery, parameters), s.driver.Mapper()) + } +} + +func (s driverTransactionWrapper) Commit() error { + return nil +} diff --git a/drivers/neo4j/batch.go b/drivers/neo4j/batch.go deleted file mode 100644 index 3e8586c..0000000 --- a/drivers/neo4j/batch.go +++ /dev/null @@ -1,254 +0,0 @@ -package neo4j - -import ( - "context" - "strings" - - "github.com/neo4j/neo4j-go-driver/v5/neo4j" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/util/size" -) - -type createRelationshipByIDs struct { - startID graph.ID - endID graph.ID - kind graph.Kind - properties *graph.Properties -} - -type batchTransaction struct { - innerTx *neo4jTransaction - nodeDeletionBuffer []graph.ID - relationshipDeletionBuffer []graph.ID - nodeUpdateByBuffer []graph.NodeUpdate - relationshipCreateBuffer []createRelationshipByIDs - relationshipUpdateByBuffer []graph.RelationshipUpdate - batchWriteSize int -} - -func (s *batchTransaction) CreateNode(node *graph.Node) error { - _, err := s.innerTx.CreateNode(node.Properties, node.Kinds...) 
- return err -} - -func (s *batchTransaction) CreateRelationship(relationship *graph.Relationship) error { - return s.CreateRelationshipByIDs(relationship.StartID, relationship.EndID, relationship.Kind, relationship.Properties) -} - -func (s *batchTransaction) WithGraph(graphSchema graph.Graph) graph.Batch { - return s -} - -func (s *batchTransaction) Nodes() graph.NodeQuery { - return NewNodeQuery(s.innerTx.ctx, s) -} - -func (s *batchTransaction) Relationships() graph.RelationshipQuery { - return NewRelationshipQuery(s.innerTx.ctx, s) -} - -func (s *batchTransaction) UpdateNodeBy(update graph.NodeUpdate) error { - if s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update); len(s.nodeUpdateByBuffer) >= s.batchWriteSize { - return s.flushNodeUpdates() - } - - return nil -} - -func (s *batchTransaction) UpdateRelationshipBy(update graph.RelationshipUpdate) error { - if s.relationshipUpdateByBuffer = append(s.relationshipUpdateByBuffer, update); len(s.relationshipUpdateByBuffer) >= s.batchWriteSize { - return s.flushRelationshipUpdates() - } - - return nil -} - -func (s *batchTransaction) DeleteNodes(ids []graph.ID) error { - return s.innerTx.DeleteNodesBySlice(ids) -} - -func (s *batchTransaction) DeleteRelationships(ids []graph.ID) error { - return s.innerTx.DeleteRelationshipsBySlice(ids) -} - -func (s *batchTransaction) Commit() error { - if len(s.nodeUpdateByBuffer) > 0 { - if err := s.flushNodeUpdates(); err != nil { - return err - } - } - - if len(s.relationshipCreateBuffer) > 0 { - if err := s.flushRelationshipCreation(); err != nil { - return err - } - } - - if len(s.relationshipUpdateByBuffer) > 0 { - if err := s.flushRelationshipUpdates(); err != nil { - return err - } - } - - if len(s.nodeDeletionBuffer) > 0 { - if err := s.flushNodeDeletions(); err != nil { - return err - } - } - - if len(s.relationshipDeletionBuffer) > 0 { - if err := s.flushRelationshipDeletions(); err != nil { - return err - } - } - - return s.innerTx.Commit() -} - -func (s 
*batchTransaction) Close() error { - return s.innerTx.Close() -} - -func (s *batchTransaction) UpdateNode(target *graph.Node) error { - return s.innerTx.UpdateNode(target) -} - -func (s *batchTransaction) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) error { - nextUpdate := createRelationshipByIDs{ - startID: startNodeID, - endID: endNodeID, - kind: kind, - properties: properties, - } - - if s.relationshipCreateBuffer = append(s.relationshipCreateBuffer, nextUpdate); len(s.relationshipCreateBuffer) >= s.batchWriteSize { - return s.flushRelationshipCreation() - } - - return nil -} - -func (s *batchTransaction) DeleteNode(id graph.ID) error { - if s.nodeDeletionBuffer = append(s.nodeDeletionBuffer, id); len(s.nodeDeletionBuffer) >= s.batchWriteSize { - return s.flushNodeDeletions() - } - - return nil -} - -func (s *batchTransaction) DeleteRelationship(id graph.ID) error { - if s.relationshipDeletionBuffer = append(s.relationshipDeletionBuffer, id); len(s.relationshipDeletionBuffer) >= s.batchWriteSize { - return s.flushRelationshipDeletions() - } - - return nil -} - -func (s *batchTransaction) UpdateRelationship(relationship *graph.Relationship) error { - return s.innerTx.UpdateRelationship(relationship) -} - -func (s *batchTransaction) Raw(cypher string, params map[string]any) graph.Result { - return s.innerTx.Raw(cypher, params) -} - -type relationshipCreateByIDBatch struct { - numRelationships int - queryParameters map[string]any -} - -func cypherBuildRelationshipCreateByIDBatch(updates []createRelationshipByIDs) ([]string, []relationshipCreateByIDBatch) { - var ( - queries []string - queryParameters []relationshipCreateByIDBatch - - output = strings.Builder{} - updatesByRelKind = map[graph.Kind][]createRelationshipByIDs{} - ) - - for _, update := range updates { - updatesByRelKind[update.kind] = append(updatesByRelKind[update.kind], update) - } - - for kind, batchJobs := range updatesByRelKind { - 
output.WriteString("unwind $p as p match (s) where id(s) = p.s match(e) where id(e) = p.e merge (s)-[r:") - output.WriteString(kind.String()) - output.WriteString("]->(e) set r += p.p") - - nextQueryParameters := make([]map[string]any, len(batchJobs)) - - for idx, batchJob := range batchJobs { - nextQueryParameters[idx] = map[string]any{ - "s": batchJob.startID, - "e": batchJob.endID, - "p": batchJob.properties.Map, - } - } - - queries = append(queries, output.String()) - queryParameters = append(queryParameters, relationshipCreateByIDBatch{ - numRelationships: len(nextQueryParameters), - queryParameters: map[string]any{ - "p": nextQueryParameters, - }, - }) - - output.Reset() - } - - return queries, queryParameters -} - -func (s *batchTransaction) flushRelationshipCreation() error { - statements, batches := cypherBuildRelationshipCreateByIDBatch(s.relationshipCreateBuffer) - - for parameterIdx, statement := range statements { - nextBatch := batches[parameterIdx] - - if result := s.innerTx.runAndLog(statement, nextBatch.queryParameters, nextBatch.numRelationships); result.Error() != nil { - return result.Error() - } - } - - s.relationshipCreateBuffer = s.relationshipCreateBuffer[:0] - return nil -} - -func (s *batchTransaction) flushRelationshipDeletions() error { - buffer := s.relationshipDeletionBuffer - s.relationshipDeletionBuffer = s.relationshipDeletionBuffer[:0] - - return s.DeleteRelationships(buffer) -} - -func (s *batchTransaction) flushNodeUpdates() error { - buffer := s.nodeUpdateByBuffer - s.nodeUpdateByBuffer = s.nodeUpdateByBuffer[:0] - - return s.innerTx.updateNodesBy(buffer...) -} - -func (s *batchTransaction) flushRelationshipUpdates() error { - buffer := s.relationshipUpdateByBuffer - s.relationshipUpdateByBuffer = s.relationshipUpdateByBuffer[:0] - - return s.innerTx.updateRelationshipsBy(buffer...) 
-} - -func (s *batchTransaction) flushNodeDeletions() error { - buffer := s.nodeDeletionBuffer - s.nodeDeletionBuffer = s.nodeDeletionBuffer[:0] - - return s.innerTx.DeleteNodesBySlice(buffer) -} - -func newBatchOperation(ctx context.Context, session neo4j.Session, cfg graph.TransactionConfig, writeFlushSize int, batchWriteSize int, graphQueryMemoryLimit size.Size) *batchTransaction { - return &batchTransaction{ - innerTx: newTransaction(ctx, session, cfg, writeFlushSize, batchWriteSize, graphQueryMemoryLimit), - batchWriteSize: batchWriteSize, - nodeDeletionBuffer: make([]graph.ID, 0, batchWriteSize), - relationshipDeletionBuffer: make([]graph.ID, 0, batchWriteSize), - nodeUpdateByBuffer: make([]graph.NodeUpdate, 0, batchWriteSize), - relationshipUpdateByBuffer: make([]graph.RelationshipUpdate, 0, batchWriteSize), - } -} diff --git a/drivers/neo4j/const.go b/drivers/neo4j/const.go deleted file mode 100644 index ee8c258..0000000 --- a/drivers/neo4j/const.go +++ /dev/null @@ -1,12 +0,0 @@ -package neo4j - -// TODO: Deprecate these - -const ( - cypherDeleteNodeByID = `match (n) where id(n) = $id detach delete n` - cypherDeleteNodesByID = `match (n) where id(n) in $id_list detach delete n` - cypherDeleteRelationshipByID = `match ()-[r]->() where id(r) = $id delete r` - cypherDeleteRelationshipsByID = `unwind $p as rid match ()-[r]->() where id(r) = rid delete r` - idParameterName = "id" - idListParameterName = "id_list" -) diff --git a/drivers/neo4j/cypher.go b/drivers/neo4j/cypher.go deleted file mode 100644 index 378fc24..0000000 --- a/drivers/neo4j/cypher.go +++ /dev/null @@ -1,319 +0,0 @@ -package neo4j - -import ( - "bytes" - "fmt" - "log/slog" - "sort" - "strings" - - "github.com/specterops/dawgs/cypher/frontend" - "github.com/specterops/dawgs/cypher/models/cypher/format" - "github.com/specterops/dawgs/graph" -) - -func newUpdateKey(identityKind graph.Kind, identityProperties []string, updateKinds graph.Kinds) string { - var keys []string - - // Defensive check: 
identityKind may be nil or zero value - if identityKind != nil && !identityKind.Is(graph.EmptyKind) { - keys = append(keys, identityKind.String()) - } - - keys = append(keys, identityProperties...) - keys = append(keys, updateKinds.Strings()...) - - sort.Strings(keys) - - return strings.Join(keys, "") -} - -func relUpdateKey(update graph.RelationshipUpdate) string { - keys := []string{ - newUpdateKey(update.StartIdentityKind, update.StartIdentityProperties, update.Start.Kinds), - newUpdateKey(update.Relationship.Kind, update.IdentityProperties, nil), - newUpdateKey(update.EndIdentityKind, update.EndIdentityProperties, update.End.Kinds), - } - - return strings.Join(keys, "") -} - -type relUpdates struct { - identityKind graph.Kind - identityProperties []string - startIdentityKind graph.Kind - startIdentityProperties []string - startNodeKindsToAdd graph.Kinds - endIdentityKind graph.Kind - endIdentityProperties []string - endNodeKindsToAdd graph.Kinds - properties []map[string]any -} - -type relUpdateByMap map[string]*relUpdates - -func (s relUpdateByMap) add(update graph.RelationshipUpdate) { - var ( - updateKey = relUpdateKey(update) - updateProperties = map[string]any{ - "r": update.Relationship.Properties.Map, - "s": update.Start.Properties.Map, - "e": update.End.Properties.Map, - } - ) - - if updates, hasUpdates := s[updateKey]; hasUpdates { - updates.properties = append(updates.properties, updateProperties) - } else { - s[updateKey] = &relUpdates{ - identityKind: update.Relationship.Kind, - identityProperties: update.IdentityProperties, - startIdentityKind: update.StartIdentityKind, - startIdentityProperties: update.StartIdentityProperties, - startNodeKindsToAdd: update.Start.Kinds, - endIdentityKind: update.EndIdentityKind, - endIdentityProperties: update.EndIdentityProperties, - endNodeKindsToAdd: update.End.Kinds, - properties: []map[string]any{ - updateProperties, - }, - } - } -} - -func cypherBuildRelationshipUpdateQueryBatch(updates 
[]graph.RelationshipUpdate) ([]string, [][]map[string]any) { - var ( - queries []string - queryParameters [][]map[string]any - - output = strings.Builder{} - batchedUpdates = relUpdateByMap{} - ) - - for _, update := range updates { - batchedUpdates.add(update) - } - - for _, batch := range batchedUpdates { - output.WriteString("unwind $p as p merge (s") - - if batch.startIdentityKind != nil && !batch.startIdentityKind.Is(graph.EmptyKind) { - output.WriteString(fmt.Sprintf(":%s", batch.startIdentityKind.String())) - } - - if len(batch.startIdentityProperties) > 0 { - output.WriteString(" {") - - firstIdentityProperty := true - for _, identityProperty := range batch.startIdentityProperties { - if firstIdentityProperty { - firstIdentityProperty = false - } else { - output.WriteString(",") - } - - output.WriteString(identityProperty) - output.WriteString(":p.s.") - output.WriteString(identityProperty) - } - - output.WriteString("}") - } - - output.WriteString(") merge (e") - if batch.endIdentityKind != nil && !batch.endIdentityKind.Is(graph.EmptyKind) { - output.WriteString(fmt.Sprintf(":%s", batch.endIdentityKind.String())) - } - - if len(batch.endIdentityProperties) > 0 { - output.WriteString(" {") - - firstIdentityProperty := true - for _, identityProperty := range batch.endIdentityProperties { - if firstIdentityProperty { - firstIdentityProperty = false - } else { - output.WriteString(",") - } - - output.WriteString(identityProperty) - output.WriteString(":p.e.") - output.WriteString(identityProperty) - } - - output.WriteString("}") - } - - output.WriteString(") merge (s)-[r:") - output.WriteString(batch.identityKind.String()) - - if len(batch.identityProperties) > 0 { - output.WriteString(" {") - - firstIdentityProperty := true - for _, identityProperty := range batch.identityProperties { - if firstIdentityProperty { - firstIdentityProperty = false - } else { - output.WriteString(",") - } - - output.WriteString(identityProperty) - output.WriteString(":p.r.") - 
output.WriteString(identityProperty) - } - - output.WriteString("}") - } - - output.WriteString("]->(e) set s += p.s, e += p.e, r += p.r") - - if len(batch.startNodeKindsToAdd) > 0 { - for _, kindToAdd := range batch.startNodeKindsToAdd { - if kindToAdd == graph.EmptyKind { - continue // skip empty kinds - } - output.WriteString(", s:") - output.WriteString(kindToAdd.String()) - } - } - - if len(batch.endNodeKindsToAdd) > 0 { - for _, kindToAdd := range batch.endNodeKindsToAdd { - if kindToAdd == graph.EmptyKind { - continue // skip empty kinds - } - output.WriteString(", e:") - output.WriteString(kindToAdd.String()) - } - } - - output.WriteString(", s.lastseen = datetime({timezone: 'UTC'}), e.lastseen = datetime({timezone: 'UTC'});") - - // Write out the query to be run - queries = append(queries, output.String()) - queryParameters = append(queryParameters, batch.properties) - - output.Reset() - } - - return queries, queryParameters -} - -type nodeUpdates struct { - identityKind graph.Kind - identityProperties []string - nodeKindsToAdd graph.Kinds - nodeKindsToRemove graph.Kinds - properties []map[string]any -} - -type nodeUpdateByMap map[string]*nodeUpdates - -func (s nodeUpdateByMap) add(update graph.NodeUpdate) { - updateKey := newUpdateKey(update.IdentityKind, update.IdentityProperties, update.Node.Kinds) - - if updates, hasUpdates := s[updateKey]; hasUpdates { - updates.properties = append(updates.properties, update.Node.Properties.Map) - } else { - s[updateKey] = &nodeUpdates{ - identityKind: update.IdentityKind, - identityProperties: update.IdentityProperties, - nodeKindsToAdd: update.Node.Kinds, - nodeKindsToRemove: update.Node.DeletedKinds, - properties: []map[string]any{ - update.Node.Properties.Map, - }, - } - } -} - -func cypherBuildNodeUpdateQueryBatch(updates []graph.NodeUpdate) ([]string, []map[string]any) { - var ( - queries []string - queryParameters []map[string]any - - output = strings.Builder{} - batchedUpdates = nodeUpdateByMap{} - ) - - for 
_, update := range updates { - batchedUpdates.add(update) - } - - for _, batch := range batchedUpdates { - output.WriteString("unwind $p as p merge (n") - - if batch.identityKind != nil && !batch.identityKind.Is(graph.EmptyKind) { - output.WriteString(fmt.Sprintf(":%s", batch.identityKind.String())) - } - - if len(batch.identityProperties) > 0 { - output.WriteString(" {") - - firstIdentityProperty := true - for _, identityProperty := range batch.identityProperties { - if firstIdentityProperty { - firstIdentityProperty = false - } else { - output.WriteString(",") - } - - output.WriteString(identityProperty) - output.WriteString(":p.") - output.WriteString(identityProperty) - } - - output.WriteString("}") - } - - output.WriteString(") set n += p") - - if len(batch.nodeKindsToAdd) > 0 { - for _, kindToAdd := range batch.nodeKindsToAdd { - output.WriteString(", n:") - output.WriteString(kindToAdd.String()) - } - } - - if len(batch.nodeKindsToRemove) > 0 { - output.WriteString(" remove ") - - for idx, kindToRemove := range batch.nodeKindsToRemove { - if idx > 0 { - output.WriteString(",") - } - - output.WriteString("n:") - output.WriteString(kindToRemove.String()) - } - } - - output.WriteString(";") - - // Write out the query to be run - queries = append(queries, output.String()) - queryParameters = append(queryParameters, map[string]any{ - "p": batch.properties, - }) - - output.Reset() - } - - return queries, queryParameters -} - -func stripCypherQuery(rawQuery string) string { - var ( - strippedEmitter = format.NewCypherEmitter(true) - buffer = &bytes.Buffer{} - ) - - if queryModel, err := frontend.ParseCypher(frontend.DefaultCypherContext(), rawQuery); err != nil { - slog.Error(fmt.Sprintf("Error occurred parsing cypher query during sanitization: %v", err)) - } else if err = strippedEmitter.Write(queryModel, buffer); err != nil { - slog.Error(fmt.Sprintf("Error occurred sanitizing cypher query: %v", err)) - } - - return buffer.String() -} diff --git 
a/drivers/neo4j/cypher_internal_test.go b/drivers/neo4j/cypher_internal_test.go deleted file mode 100644 index 1482ba1..0000000 --- a/drivers/neo4j/cypher_internal_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package neo4j - -import ( - "strings" - "testing" - - "github.com/specterops/dawgs/graph" - - "github.com/stretchr/testify/require" -) - -func Test_relUpdateKey(t *testing.T) { - updateKey := relUpdateKey(graph.RelationshipUpdate{ - Relationship: &graph.Relationship{ - ID: 1, - StartID: 1, - EndID: 2, - Kind: graph.StringKind("MemberOf"), - Properties: graph.NewProperties(), - }, - Start: &graph.Node{ - ID: 1, - Kinds: graph.Kinds{graph.StringKind("User")}, - Properties: graph.AsProperties(map[string]any{ - "objectid": "OID-1", - }), - }, - StartIdentityKind: graph.StringKind("Base"), - StartIdentityProperties: []string{"objectid"}, - End: &graph.Node{ - ID: 2, - Kinds: graph.Kinds{graph.StringKind("Group")}, - Properties: graph.AsProperties(map[string]any{ - "objectid": "OID-2", - }), - }, - EndIdentityKind: graph.StringKind("Base"), - EndIdentityProperties: []string{"objectid"}, - }) - - // Order must be preserved to make each key unique. This is required as the batch insert is authored as an unwound - // merge statement. 
The update key groups like updates so that the generated query can address an entire batch of - // upsert entries at-once: - // - // unwind $p as p merge (s:Base {objectid: p.s.objectid}) merge (e:Base {objectid: p.e.objectid}) merge (s)-[r:MemberOf]->(e) set s += p.s, e += p.e, r += p.r, s:User, e:Group - require.Equal(t, "BaseUserobjectidMemberOfBaseGroupobjectid", updateKey) - - updateKey = relUpdateKey(graph.RelationshipUpdate{ - Relationship: &graph.Relationship{ - ID: 1, - StartID: 1, - EndID: 2, - Kind: graph.StringKind("GenericAll"), - Properties: graph.NewProperties(), - }, - Start: &graph.Node{ - ID: 1, - Kinds: graph.Kinds{graph.StringKind("User")}, - Properties: graph.AsProperties(map[string]any{ - "objectid": "OID-1", - }), - }, - StartIdentityKind: graph.StringKind("Base"), - StartIdentityProperties: []string{"objectid"}, - End: &graph.Node{ - ID: 2, - Kinds: graph.Kinds{graph.StringKind("Group")}, - Properties: graph.AsProperties(map[string]any{ - "objectid": "OID-2", - }), - }, - EndIdentityKind: graph.StringKind("Base"), - EndIdentityProperties: []string{"objectid"}, - }) - - // unwind $p as p merge (s:Base {objectid: p.s.objectid}) merge (e:Base {objectid: p.e.objectid}) merge (s)-[r:GenericAll]->(e) set s += p.s, e += p.e, r += p.r, s:User, e:Group - require.Equal(t, "BaseUserobjectidGenericAllBaseGroupobjectid", updateKey) -} - -func Test_StripCypher(t *testing.T) { - var ( - query = "match (u1:User {domain: \"DOMAIN1\"}), (u2:User {domain: \"DOMAIN2\"}) where u1.samaccountname <> \"krbtgt\" and u1.samaccountname = u2.samaccountname with u2 match p1 = (u2)-[*1..]->(g:Group) with p1 match p2 = (u2)-[*1..]->(g:Group) return p1, p2" - ) - - result := stripCypherQuery(query) - - require.Equalf(t, false, strings.Contains(result, "DOMAIN1"), "Cypher query not sanitized. Contains sensitive value: %s", result) - require.Equalf(t, false, strings.Contains(result, "DOMAIN2"), "Cypher query not sanitized. 
Contains sensitive value: %s", result) -} diff --git a/drivers/neo4j/driver.go b/drivers/neo4j/driver.go deleted file mode 100644 index 7498d9d..0000000 --- a/drivers/neo4j/driver.go +++ /dev/null @@ -1,168 +0,0 @@ -package neo4j - -import ( - "context" - "time" - - "github.com/neo4j/neo4j-go-driver/v5/neo4j" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/util/channels" - "github.com/specterops/dawgs/util/size" -) - -const ( - DriverName = "neo4j" -) - -func readCfg() neo4j.SessionConfig { - return neo4j.SessionConfig{ - AccessMode: neo4j.AccessModeRead, - } -} - -func writeCfg() neo4j.SessionConfig { - return neo4j.SessionConfig{ - AccessMode: neo4j.AccessModeWrite, - } -} - -type driver struct { - driver neo4j.Driver - limiter channels.ConcurrencyLimiter - defaultTransactionTimeout time.Duration - batchWriteSize int - writeFlushSize int - graphQueryMemoryLimit size.Size -} - -func (s *driver) SetBatchWriteSize(size int) { - s.batchWriteSize = size -} - -func (s *driver) SetWriteFlushSize(size int) { - s.writeFlushSize = size -} - -func (s *driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate) error { - // Attempt to acquire a connection slot or wait for a bit until one becomes available - if !s.limiter.Acquire(ctx) { - return graph.ErrContextTimedOut - } else { - defer s.limiter.Release() - } - - var ( - cfg = graph.TransactionConfig{ - Timeout: s.defaultTransactionTimeout, - } - - session = s.driver.NewSession(writeCfg()) - batch = newBatchOperation(ctx, session, cfg, s.writeFlushSize, s.batchWriteSize, s.graphQueryMemoryLimit) - ) - - defer session.Close() - defer batch.Close() - - if err := batchDelegate(batch); err != nil { - return err - } - - return batch.Commit() -} - -func (s *driver) Close(ctx context.Context) error { - return s.driver.Close() -} - -func (s *driver) transaction(ctx context.Context, txDelegate graph.TransactionDelegate, session neo4j.Session, options []graph.TransactionOption) error { - // 
Attempt to acquire a connection slot or wait for a bit until one becomes available - if !s.limiter.Acquire(ctx) { - return graph.ErrContextTimedOut - } else { - defer s.limiter.Release() - } - - cfg := graph.TransactionConfig{ - Timeout: s.defaultTransactionTimeout, - } - - // Apply the transaction options - for _, option := range options { - option(&cfg) - } - - tx := newTransaction(ctx, session, cfg, s.writeFlushSize, s.batchWriteSize, s.graphQueryMemoryLimit) - defer tx.Close() - - if err := txDelegate(tx); err != nil { - return err - } - - return tx.Commit() -} - -func (s *driver) ReadTransaction(ctx context.Context, txDelegate graph.TransactionDelegate, options ...graph.TransactionOption) error { - session := s.driver.NewSession(readCfg()) - defer session.Close() - - return s.transaction(ctx, txDelegate, session, options) -} - -func (s *driver) WriteTransaction(ctx context.Context, txDelegate graph.TransactionDelegate, options ...graph.TransactionOption) error { - session := s.driver.NewSession(writeCfg()) - defer session.Close() - - return s.transaction(ctx, txDelegate, session, options) -} - -func (s *driver) AssertSchema(ctx context.Context, schema graph.Schema) error { - return assertSchema(ctx, s, schema) -} - -func (s *driver) SetDefaultGraph(ctx context.Context, schema graph.Graph) error { - // Note: Neo4j does not support isolated physical graph namespaces. Namespacing can be emulated with Kinds but will - // not be supported for this driver since the fallback behavior is no different from storing all graph data in the - // same namespace. - // - // This is different for the PostgreSQL driver, specifically, since the driver in question supports on-disk - // isolation of graph namespaces. 
- return nil -} - -func (s *driver) Run(ctx context.Context, query string, parameters map[string]any) error { - return s.WriteTransaction(ctx, func(tx graph.Transaction) error { - result := tx.Raw(query, parameters) - defer result.Close() - - return result.Error() - }) -} - -func (s *driver) FetchKinds(ctx context.Context) (graph.Kinds, error) { - var kinds graph.Kinds - - if err := s.ReadTransaction(ctx, func(tx graph.Transaction) error { - if result := tx.Raw("CALL db.labels()", nil); result.Error() != nil { - return result.Error() - } else { - for result.Next() { - var kind string - if err := result.Scan(&kind); err != nil { - return err - } else { - kinds = append(kinds, graph.StringKind(kind)) - } - } - } - return nil - }); err != nil { - return nil, err - } - - return kinds, nil -} - -func (s *driver) RefreshKinds(_ context.Context) error { - // This isn't needed for neo4j - return nil -} diff --git a/drivers/neo4j/index.go b/drivers/neo4j/index.go deleted file mode 100644 index 16e33e9..0000000 --- a/drivers/neo4j/index.go +++ /dev/null @@ -1,269 +0,0 @@ -package neo4j - -import ( - "context" - "fmt" - "log/slog" - "strings" - - "github.com/specterops/dawgs/graph" -) - -const ( - nativeBTreeIndexProvider = "native-btree-1.0" - nativeLuceneIndexProvider = "lucene+native-3.0" - - dropPropertyIndexStatement = "drop index $name;" - dropPropertyConstraintStatement = "drop constraint $name;" - createPropertyIndexStatement = "call db.createIndex($name, $labels, $properties, $provider);" - createPropertyConstraintStatement = "call db.createUniquePropertyConstraint($name, $labels, $properties, $provider);" -) - -type neo4jIndex struct { - graph.Index - - kind graph.Kind -} - -type neo4jConstraint struct { - graph.Constraint - - kind graph.Kind -} - -type neo4jSchema struct { - Indexes map[string]neo4jIndex - Constraints map[string]neo4jConstraint -} - -func newNeo4jSchema() neo4jSchema { - return neo4jSchema{ - Indexes: map[string]neo4jIndex{}, - Constraints: 
map[string]neo4jConstraint{}, - } -} - -func toNeo4jSchema(dbSchema graph.Schema) neo4jSchema { - neo4jSchemaInst := newNeo4jSchema() - - for _, graphSchema := range dbSchema.Graphs { - for _, index := range graphSchema.NodeIndexes { - for _, kind := range graphSchema.Nodes { - indexName := strings.ToLower(kind.String()) + "_" + strings.ToLower(index.Field) + "_index" - - neo4jSchemaInst.Indexes[indexName] = neo4jIndex{ - Index: graph.Index{ - Name: indexName, - Field: index.Field, - Type: index.Type, - }, - kind: kind, - } - } - } - - for _, constraint := range graphSchema.NodeConstraints { - for _, kind := range graphSchema.Nodes { - constraintName := strings.ToLower(kind.String()) + "_" + strings.ToLower(constraint.Field) + "_constraint" - - neo4jSchemaInst.Constraints[constraintName] = neo4jConstraint{ - Constraint: graph.Constraint{ - Name: constraintName, - Field: constraint.Field, - Type: constraint.Type, - }, - kind: kind, - } - } - } - } - - return neo4jSchemaInst -} - -func parseProviderType(provider string) graph.IndexType { - switch provider { - case nativeBTreeIndexProvider: - return graph.BTreeIndex - case nativeLuceneIndexProvider: - return graph.TextSearchIndex - default: - return graph.UnsupportedIndex - } -} - -func indexTypeProvider(indexType graph.IndexType) string { - switch indexType { - case graph.BTreeIndex: - return nativeBTreeIndexProvider - case graph.TextSearchIndex: - return nativeLuceneIndexProvider - default: - return "" - } -} - -func assertIndexes(ctx context.Context, db graph.Database, indexesToRemove []string, indexesToAdd map[string]neo4jIndex) error { - if err := db.WriteTransaction(ctx, func(tx graph.Transaction) error { - for _, indexToRemove := range indexesToRemove { - slog.InfoContext(ctx, fmt.Sprintf("Removing index %s", indexToRemove)) - - result := tx.Raw(strings.Replace(dropPropertyIndexStatement, "$name", indexToRemove, 1), nil) - result.Close() - - if err := result.Error(); err != nil { - return err - } - } - - return 
nil - }); err != nil { - return err - } - - return db.WriteTransaction(ctx, func(tx graph.Transaction) error { - for indexName, indexToAdd := range indexesToAdd { - slog.InfoContext(ctx, fmt.Sprintf("Adding index %s to labels %s on properties %s using %s", indexName, indexToAdd.kind.String(), indexToAdd.Field, indexTypeProvider(indexToAdd.Type))) - - if err := db.Run(ctx, createPropertyIndexStatement, map[string]interface{}{ - "name": indexName, - "labels": []string{indexToAdd.kind.String()}, - "properties": []string{indexToAdd.Field}, - "provider": indexTypeProvider(indexToAdd.Type), - }); err != nil { - return err - } - } - - return nil - }) -} - -func assertConstraints(ctx context.Context, db graph.Database, constraintsToRemove []string, constraintsToAdd map[string]neo4jConstraint) error { - for _, constraintToRemove := range constraintsToRemove { - if err := db.Run(ctx, strings.Replace(dropPropertyConstraintStatement, "$name", constraintToRemove, 1), nil); err != nil { - return err - } - } - - for constraintName, constraintToAdd := range constraintsToAdd { - if err := db.Run(ctx, createPropertyConstraintStatement, map[string]interface{}{ - "name": constraintName, - "labels": []string{constraintToAdd.kind.String()}, - "properties": []string{constraintToAdd.Field}, - "provider": indexTypeProvider(constraintToAdd.Type), - }); err != nil { - return err - } - } - - return nil -} - -func fetchPresentSchema(ctx context.Context, db graph.Database) (neo4jSchema, error) { - presentSchema := newNeo4jSchema() - - return presentSchema, db.ReadTransaction(ctx, func(tx graph.Transaction) error { - if result := tx.Raw("call db.indexes() yield name, uniqueness, provider, labelsOrTypes, properties;", nil); result.Error() != nil { - return result.Error() - } else { - defer result.Close() - - var ( - name string - uniqueness string - provider string - labels []string - properties []string - ) - - for result.Next() { - if err := result.Scan(&name, &uniqueness, &provider, &labels, 
&properties); err != nil { - return err - } - - // Need this for neo4j 4.4+ which creates a weird index by default - if len(labels) == 0 { - continue - } - - if len(labels) > 1 || len(properties) > 1 { - return fmt.Errorf("composite index types are currently not supported") - } - - if uniqueness == "UNIQUE" { - presentSchema.Constraints[name] = neo4jConstraint{ - Constraint: graph.Constraint{ - Name: name, - Field: properties[0], - Type: parseProviderType(provider), - }, - kind: graph.StringKind(labels[0]), - } - } else { - presentSchema.Indexes[name] = neo4jIndex{ - Index: graph.Index{ - Name: name, - Field: properties[0], - Type: parseProviderType(provider), - }, - kind: graph.StringKind(labels[0]), - } - } - } - - return result.Error() - } - }) -} - -func assertSchema(ctx context.Context, db graph.Database, required graph.Schema) error { - requiredNeo4jSchema := toNeo4jSchema(required) - - if presentNeo4jSchema, err := fetchPresentSchema(ctx, db); err != nil { - return err - } else { - var ( - indexesToRemove []string - constraintsToRemove []string - indexesToAdd = map[string]neo4jIndex{} - constraintsToAdd = map[string]neo4jConstraint{} - ) - - for presentIndexName := range presentNeo4jSchema.Indexes { - if _, hasMatchingDefinition := requiredNeo4jSchema.Indexes[presentIndexName]; !hasMatchingDefinition { - indexesToRemove = append(indexesToRemove, presentIndexName) - } - } - - for presentConstraintName := range presentNeo4jSchema.Constraints { - if _, hasMatchingDefinition := requiredNeo4jSchema.Constraints[presentConstraintName]; !hasMatchingDefinition { - constraintsToRemove = append(constraintsToRemove, presentConstraintName) - } - } - - for requiredIndexName, requiredIndex := range requiredNeo4jSchema.Indexes { - if presentIndex, hasMatchingDefinition := presentNeo4jSchema.Indexes[requiredIndexName]; !hasMatchingDefinition { - indexesToAdd[requiredIndexName] = requiredIndex - } else if requiredIndex.Type != presentIndex.Type { - indexesToRemove = 
append(indexesToRemove, requiredIndexName) - indexesToAdd[requiredIndexName] = requiredIndex - } - } - - for requiredConstraintName, requiredConstraint := range requiredNeo4jSchema.Constraints { - if presentConstraint, hasMatchingDefinition := presentNeo4jSchema.Constraints[requiredConstraintName]; !hasMatchingDefinition { - constraintsToAdd[requiredConstraintName] = requiredConstraint - } else if requiredConstraint.Type != presentConstraint.Type { - constraintsToRemove = append(constraintsToRemove, requiredConstraintName) - constraintsToAdd[requiredConstraintName] = requiredConstraint - } - } - - if err := assertConstraints(ctx, db, constraintsToRemove, constraintsToAdd); err != nil { - return err - } - - return assertIndexes(ctx, db, indexesToRemove, indexesToAdd) - } -} diff --git a/drivers/neo4j/neo4j.go b/drivers/neo4j/neo4j.go deleted file mode 100644 index c6a5ea2..0000000 --- a/drivers/neo4j/neo4j.go +++ /dev/null @@ -1,50 +0,0 @@ -package neo4j - -import ( - "context" - "fmt" - "math" - "net/url" - - "github.com/neo4j/neo4j-go-driver/v5/neo4j" - "github.com/specterops/dawgs" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/util/channels" -) - -const ( - // defaultNeo4jTransactionTimeout is set to math.MinInt as this is what the core neo4j library defaults to when - // left unset. 
It is recommended that users set this for time-sensitive operations - defaultNeo4jTransactionTimeout = math.MinInt -) - -func newNeo4jDB(_ context.Context, cfg dawgs.Config) (graph.Database, error) { - if connectionURL, err := url.Parse(cfg.ConnectionString); err != nil { - return nil, err - } else if connectionURL.Scheme != DriverName { - return nil, fmt.Errorf("expected connection URL scheme %s for Neo4J but got %s", DriverName, connectionURL.Scheme) - } else if password, isSet := connectionURL.User.Password(); !isSet { - return nil, fmt.Errorf("no password provided in connection URL") - } else { - boltURL := fmt.Sprintf("bolt://%s:%s", connectionURL.Hostname(), connectionURL.Port()) - - if internalDriver, err := neo4j.NewDriver(boltURL, neo4j.BasicAuth(connectionURL.User.Username(), password, "")); err != nil { - return nil, fmt.Errorf("unable to connect to Neo4J: %w", err) - } else { - return &driver{ - driver: internalDriver, - defaultTransactionTimeout: defaultNeo4jTransactionTimeout, - limiter: channels.NewConcurrencyLimiter(DefaultConcurrentConnections), - writeFlushSize: DefaultWriteFlushSize, - batchWriteSize: DefaultBatchWriteSize, - graphQueryMemoryLimit: cfg.GraphQueryMemoryLimit, - }, nil - } - } -} - -func init() { - dawgs.Register(DriverName, func(ctx context.Context, cfg dawgs.Config) (graph.Database, error) { - return newNeo4jDB(ctx, cfg) - }) -} diff --git a/drivers/neo4j/node.go b/drivers/neo4j/node.go deleted file mode 100644 index 0484d28..0000000 --- a/drivers/neo4j/node.go +++ /dev/null @@ -1,219 +0,0 @@ -package neo4j - -import ( - "context" - - neo4j_core "github.com/neo4j/neo4j-go-driver/v5/neo4j" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" - "github.com/specterops/dawgs/query/neo4j" -) - -func newPath(internalPath neo4j_core.Path) graph.Path { - path := graph.Path{} - - for _, node := range internalPath.Nodes { - path.Nodes = append(path.Nodes, newNode(node)) - } - - for _, relationship := range 
internalPath.Relationships { - path.Edges = append(path.Edges, newRelationship(relationship)) - } - - return path -} - -func newNode(internalNode neo4j_core.Node) *graph.Node { - var propertiesInst = internalNode.Props - - if propertiesInst == nil { - propertiesInst = make(map[string]any) - } - - return graph.NewNode(graph.ID(internalNode.Id), graph.AsProperties(propertiesInst), graph.StringsToKinds(internalNode.Labels)...) -} - -type NodeQuery struct { - ctx context.Context - tx innerTransaction - queryBuilder *neo4j.QueryBuilder -} - -func NewNodeQuery(ctx context.Context, tx innerTransaction) graph.NodeQuery { - return &NodeQuery{ - ctx: ctx, - tx: tx, - queryBuilder: neo4j.NewEmptyQueryBuilder(), - } -} - -func (s *NodeQuery) run(statement string, parameters map[string]any) graph.Result { - return s.tx.Raw(statement, parameters) -} - -func (s *NodeQuery) Query(delegate func(results graph.Result) error, finalCriteria ...graph.Criteria) error { - for _, criteria := range finalCriteria { - s.queryBuilder.Apply(criteria) - } - - if err := s.queryBuilder.Prepare(); err != nil { - return err - } else if statement, err := s.queryBuilder.Render(); err != nil { - return err - } else if result := s.run(statement, s.queryBuilder.Parameters); result.Error() != nil { - return result.Error() - } else { - defer result.Close() - return delegate(result) - } -} - -func (s *NodeQuery) Debug() (string, map[string]any) { - statement, _ := s.queryBuilder.Render() - return statement, s.queryBuilder.Parameters -} - -func (s *NodeQuery) Delete() error { - s.queryBuilder.Apply(query.Delete( - query.Node(), - )) - - if err := s.queryBuilder.Prepare(); err != nil { - return err - } else if statement, err := s.queryBuilder.Render(); err != nil { - return err - } else { - result := s.run(statement, s.queryBuilder.Parameters) - return result.Error() - } -} - -func (s *NodeQuery) OrderBy(criteria ...graph.Criteria) graph.NodeQuery { - s.queryBuilder.Apply(query.OrderBy(criteria...)) - return 
s -} - -func (s *NodeQuery) Offset(offset int) graph.NodeQuery { - s.queryBuilder.Apply(query.Offset(offset)) - return s -} - -func (s *NodeQuery) Limit(limit int) graph.NodeQuery { - s.queryBuilder.Apply(query.Limit(limit)) - return s -} - -func (s *NodeQuery) Filter(criteria graph.Criteria) graph.NodeQuery { - s.queryBuilder.Apply(query.Where(criteria)) - return s -} - -func (s *NodeQuery) Filterf(criteriaDelegate graph.CriteriaProvider) graph.NodeQuery { - return s.Filter(criteriaDelegate()) -} - -func (s *NodeQuery) Count() (int64, error) { - var count int64 - - return count, s.Query(func(results graph.Result) error { - if !results.Next() { - return graph.ErrNoResultsFound - } - - return results.Scan(&count) - }, query.Returning( - query.Count(query.Node()), - )) -} - -func (s *NodeQuery) Update(properties *graph.Properties) error { - s.queryBuilder.Apply(query.Updatef(func() graph.Criteria { - var updateStatements []graph.Criteria - - if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 { - updateStatements = append(updateStatements, query.SetProperties(query.Node(), modifiedProperties)) - } - - if deletedProperties := properties.DeletedProperties(); len(deletedProperties) > 0 { - updateStatements = append(updateStatements, query.DeleteProperties(query.Node(), deletedProperties...)) - } - - return updateStatements - })) - - if err := s.queryBuilder.Prepare(); err != nil { - return err - } else if cypherQuery, err := s.queryBuilder.Render(); err != nil { - strippedQuery := stripCypherQuery(cypherQuery) - return graph.NewError(strippedQuery, err) - } else if result := s.run(cypherQuery, s.queryBuilder.Parameters); result.Error() != nil { - return result.Error() - } - - return nil -} - -func (s *NodeQuery) First() (*graph.Node, error) { - var node graph.Node - - return &node, s.Query(func(results graph.Result) error { - if !results.Next() { - return graph.ErrNoResultsFound - } - - return results.Scan(&node) - }, query.Returning( - 
query.Node(), - ), query.Limit(1)) -} - -func (s *NodeQuery) Fetch(delegate func(cursor graph.Cursor[*graph.Node]) error, finalCriteria ...graph.Criteria) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (*graph.Node, error) { - var node graph.Node - return &node, result.Scan(&node) - }) - - defer cursor.Close() - return delegate(cursor) - }, append([]graph.Criteria{query.Returning( - query.Node(), - )}, finalCriteria...)...) -} - -func (s *NodeQuery) FetchIDs(delegate func(cursor graph.Cursor[graph.ID]) error) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.ID, error) { - var nodeID graph.ID - return nodeID, result.Scan(&nodeID) - }) - - defer cursor.Close() - return delegate(cursor) - }, query.Returning( - query.NodeID(), - )) -} - -func (s *NodeQuery) FetchKinds(delegate func(cursor graph.Cursor[graph.KindsResult]) error) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.KindsResult, error) { - var ( - nodeID graph.ID - nodeKinds graph.Kinds - err = result.Scan(&nodeID, &nodeKinds) - ) - - return graph.KindsResult{ - ID: nodeID, - Kinds: nodeKinds, - }, err - }) - - defer cursor.Close() - return delegate(cursor) - }, query.Returning( - query.NodeID(), - query.KindsOf(query.Node()), - )) -} diff --git a/drivers/neo4j/relationship.go b/drivers/neo4j/relationship.go deleted file mode 100644 index 331a730..0000000 --- a/drivers/neo4j/relationship.go +++ /dev/null @@ -1,321 +0,0 @@ -package neo4j - -import ( - "context" - "fmt" - - neo4j_core "github.com/neo4j/neo4j-go-driver/v5/neo4j" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" - "github.com/specterops/dawgs/query/neo4j" -) - -func directionToReturnCriteria(direction graph.Direction) (graph.Criteria, error) { - 
switch direction { - case graph.DirectionInbound: - // Select the relationship and the end node - return query.Returning( - query.Relationship(), - query.End(), - ), nil - - case graph.DirectionOutbound: - // Select the relationship and the start node - return query.Returning( - query.Relationship(), - query.Start(), - ), nil - - default: - return nil, fmt.Errorf("bad direction: %d", direction) - } -} - -func newRelationship(internalRelationship neo4j_core.Relationship) *graph.Relationship { - propertiesInst := internalRelationship.Props - - if propertiesInst == nil { - propertiesInst = make(map[string]any) - } - - return graph.NewRelationship( - graph.ID(internalRelationship.Id), - graph.ID(internalRelationship.StartId), - graph.ID(internalRelationship.EndId), - graph.AsProperties(propertiesInst), - graph.StringKind(internalRelationship.Type), - ) -} - -type RelationshipQuery struct { - ctx context.Context - tx innerTransaction - queryBuilder *neo4j.QueryBuilder -} - -func NewRelationshipQuery(ctx context.Context, tx innerTransaction) graph.RelationshipQuery { - return &RelationshipQuery{ - ctx: ctx, - tx: tx, - queryBuilder: neo4j.NewEmptyQueryBuilder(), - } -} - -func (s *RelationshipQuery) run(statement string, parameters map[string]any) graph.Result { - return s.tx.Raw(statement, parameters) -} - -func (s *RelationshipQuery) Query(delegate func(results graph.Result) error, finalCriteria ...graph.Criteria) error { - for _, criteria := range finalCriteria { - s.queryBuilder.Apply(criteria) - } - - if err := s.queryBuilder.Prepare(); err != nil { - return err - } else if statement, err := s.queryBuilder.Render(); err != nil { - return err - } else if result := s.run(statement, s.queryBuilder.Parameters); result.Error() != nil { - return result.Error() - } else { - defer result.Close() - return delegate(result) - } -} - -func (s *RelationshipQuery) Debug() (string, map[string]any) { - rendered, _ := s.queryBuilder.Render() - return rendered, 
s.queryBuilder.Parameters -} - -func (s *RelationshipQuery) Update(properties *graph.Properties) error { - s.queryBuilder.Apply(query.Updatef(func() graph.Criteria { - var updateStatements []graph.Criteria - - if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 { - updateStatements = append(updateStatements, query.SetProperties(query.Relationship(), modifiedProperties)) - } - - if deletedProperties := properties.DeletedProperties(); len(deletedProperties) > 0 { - updateStatements = append(updateStatements, query.DeleteProperties(query.Relationship(), deletedProperties...)) - } - - return updateStatements - })) - - if err := s.queryBuilder.Prepare(); err != nil { - return err - } else if cypherQuery, err := s.queryBuilder.Render(); err != nil { - strippedQuery := stripCypherQuery(cypherQuery) - return graph.NewError(strippedQuery, err) - } else { - return s.run(cypherQuery, s.queryBuilder.Parameters).Error() - } -} - -func (s *RelationshipQuery) Delete() error { - s.queryBuilder.Apply(query.Delete( - query.Relationship(), - )) - - if err := s.queryBuilder.Prepare(); err != nil { - return err - } else if statement, err := s.queryBuilder.Render(); err != nil { - return err - } else { - return s.run(statement, s.queryBuilder.Parameters).Error() - } -} - -func (s *RelationshipQuery) OrderBy(criteria ...graph.Criteria) graph.RelationshipQuery { - s.queryBuilder.Apply(query.OrderBy(criteria...)) - return s -} - -func (s *RelationshipQuery) Offset(offset int) graph.RelationshipQuery { - s.queryBuilder.Apply(query.Offset(offset)) - return s -} - -func (s *RelationshipQuery) Limit(limit int) graph.RelationshipQuery { - s.queryBuilder.Apply(query.Limit(limit)) - return s -} - -func (s *RelationshipQuery) Filter(criteria graph.Criteria) graph.RelationshipQuery { - s.queryBuilder.Apply(query.Where(criteria)) - return s -} - -func (s *RelationshipQuery) Filterf(criteriaDelegate graph.CriteriaProvider) graph.RelationshipQuery { - 
s.queryBuilder.Apply(query.Where(criteriaDelegate())) - return s -} - -func (s *RelationshipQuery) Count() (int64, error) { - var count int64 - - return count, s.Query(func(results graph.Result) error { - if !results.Next() { - return graph.ErrNoResultsFound - } - - return results.Scan(&count) - }, query.Returning( - query.Count(query.Relationship()), - )) -} - -func (s *RelationshipQuery) FetchAllShortestPaths(delegate func(cursor graph.Cursor[graph.Path]) error) error { - s.queryBuilder.Apply(query.Returning( - query.Path(), - )) - - if err := s.queryBuilder.PrepareAllShortestPaths(); err != nil { - return err - } else if statement, err := s.queryBuilder.Render(); err != nil { - return err - } else if result := s.run(statement, s.queryBuilder.Parameters); result.Error() != nil { - return result.Error() - } else { - defer result.Close() - - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.Path, error) { - var ( - nextPath graph.Path - err = result.Scan(&nextPath) - ) - - return nextPath, err - }) - - defer cursor.Close() - return delegate(cursor) - } -} - -func (s *RelationshipQuery) FetchTriples(delegate func(cursor graph.Cursor[graph.RelationshipTripleResult]) error) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.RelationshipTripleResult, error) { - var ( - startID graph.ID - relationshipID graph.ID - endID graph.ID - err = result.Scan(&startID, &relationshipID, &endID) - ) - - return graph.RelationshipTripleResult{ - ID: relationshipID, - StartID: startID, - EndID: endID, - }, err - }) - - defer cursor.Close() - return delegate(cursor) - }, query.ReturningDistinct( - query.StartID(), - query.RelationshipID(), - query.EndID(), - )) -} - -func (s *RelationshipQuery) FetchKinds(delegate func(cursor graph.Cursor[graph.RelationshipKindsResult]) error) error { - return s.Query(func(result graph.Result) error { - cursor := 
graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.RelationshipKindsResult, error) { - var ( - startID graph.ID - relationshipID graph.ID - relationshipKind graph.Kind - endID graph.ID - err = result.Scan(&startID, &relationshipID, &relationshipKind, &endID) - ) - - return graph.RelationshipKindsResult{ - RelationshipTripleResult: graph.RelationshipTripleResult{ - ID: relationshipID, - StartID: startID, - EndID: endID, - }, - Kind: relationshipKind, - }, err - }) - - defer cursor.Close() - return delegate(cursor) - }, query.Returning( - query.StartID(), - query.RelationshipID(), - query.KindsOf(query.Relationship()), - query.EndID(), - )) -} - -func (s *RelationshipQuery) First() (*graph.Relationship, error) { - var relationship graph.Relationship - - return &relationship, s.Query(func(results graph.Result) error { - if !results.Next() { - return graph.ErrNoResultsFound - } - - return results.Scan(&relationship) - }, query.Returning( - query.Relationship(), - ), query.Limit(1)) -} - -func (s *RelationshipQuery) Fetch(delegate func(cursor graph.Cursor[*graph.Relationship]) error) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (*graph.Relationship, error) { - var relationship graph.Relationship - return &relationship, result.Scan(&relationship) - }) - - defer cursor.Close() - return delegate(cursor) - }, query.Returning( - query.Relationship(), - )) -} - -func (s *RelationshipQuery) FetchDirection(direction graph.Direction, delegate func(cursor graph.Cursor[graph.DirectionalResult]) error) error { - if returnCriteria, err := directionToReturnCriteria(direction); err != nil { - return err - } else { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.DirectionalResult, error) { - var ( - relationship graph.Relationship - node graph.Node - ) - - if err := 
result.Scan(&relationship, &node); err != nil { - return graph.DirectionalResult{}, err - } - - return graph.DirectionalResult{ - Direction: direction, - Relationship: &relationship, - Node: &node, - }, nil - }) - - defer cursor.Close() - return delegate(cursor) - }, returnCriteria) - } -} - -func (s *RelationshipQuery) FetchIDs(delegate func(cursor graph.Cursor[graph.ID]) error) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.ID, error) { - var relationshipID graph.ID - return relationshipID, result.Scan(&relationshipID) - }) - - defer cursor.Close() - return delegate(cursor) - }, query.Returning( - query.RelationshipID(), - )) -} diff --git a/drivers/neo4j/result.go b/drivers/neo4j/result.go deleted file mode 100644 index db7bb66..0000000 --- a/drivers/neo4j/result.go +++ /dev/null @@ -1,56 +0,0 @@ -package neo4j - -import ( - "github.com/neo4j/neo4j-go-driver/v5/neo4j" - "github.com/specterops/dawgs/graph" -) - -type internalResult struct { - query string - err error - driverResult neo4j.Result -} - -func NewResult(query string, err error, driverResult neo4j.Result) graph.Result { - return &internalResult{ - query: query, - err: err, - driverResult: driverResult, - } -} - -func (s *internalResult) Mapper() graph.ValueMapper { - return NewValueMapper() -} - -func (s *internalResult) Values() []any { - return s.driverResult.Record().Values -} - -func (s *internalResult) Scan(targets ...any) error { - return graph.ScanNextResult(s, targets...) 
-} - -func (s *internalResult) Next() bool { - return s.driverResult.Next() -} - -func (s *internalResult) Error() error { - if s.err != nil { - return s.err - } - - if s.driverResult != nil && s.driverResult.Err() != nil { - strippedQuery := stripCypherQuery(s.query) - return graph.NewError(strippedQuery, s.driverResult.Err()) - } - - return nil -} - -func (s *internalResult) Close() { - if s.driverResult != nil { - // Ignore the results of this call. This is called only as a best-effort attempt at a close - s.driverResult.Consume() - } -} diff --git a/drivers/neo4j/result_internal_test.go b/drivers/neo4j/result_internal_test.go deleted file mode 100644 index ae078f9..0000000 --- a/drivers/neo4j/result_internal_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package neo4j - -import ( - "testing" - "time" - - "github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype" - "github.com/specterops/dawgs/graph" - "github.com/stretchr/testify/require" -) - -func mapTestCase[T, V any](t *testing.T, source T, expected V) { - var ( - value V - mapper = NewValueMapper() - ) - - require.True(t, mapper.Map(source, &value)) - require.Equalf(t, expected, value, "Mapping case for type %T to %T failed. 
Value is: %v", source, &value, value) -} - -func Test_mapValue(t *testing.T) { - var ( - utcNow = time.Now().UTC() - anyStringSlice = []any{"a", "b", "c"} - stringSlice = []string{"a", "b", "c"} - kindSlice = []graph.Kind{graph.StringKind("a"), graph.StringKind("b"), graph.StringKind("c")} - kinds = graph.Kinds{graph.StringKind("a"), graph.StringKind("b"), graph.StringKind("c")} - ) - - mapTestCase[uint, uint](t, 0, 0) - mapTestCase[uint8, uint8](t, 0, 0) - mapTestCase[uint16, uint16](t, 0, 0) - mapTestCase[uint32, uint32](t, 0, 0) - mapTestCase[uint64, uint64](t, 0, 0) - - mapTestCase(t, 0, 0) // Inferred int - mapTestCase[int8, int8](t, 0, 0) - mapTestCase[int16, int16](t, 0, 0) - mapTestCase[int32, int32](t, 0, 0) - mapTestCase[int64, int64](t, 0, 0) - mapTestCase[int64, graph.ID](t, 0, 0) - - mapTestCase[float32, float32](t, 1.5, 1.5) - mapTestCase(t, 1.5, 1.5) // Inferred float64 - - mapTestCase(t, true, true) - mapTestCase(t, "test", "test") - - mapTestCase(t, utcNow, utcNow) - mapTestCase(t, utcNow.Format(time.RFC3339Nano), utcNow) - mapTestCase(t, utcNow.Unix(), time.Unix(utcNow.Unix(), 0)) - mapTestCase(t, dbtype.Time(utcNow), utcNow) - mapTestCase(t, dbtype.LocalTime(utcNow), utcNow) - mapTestCase(t, dbtype.Date(utcNow), utcNow) - mapTestCase(t, dbtype.LocalDateTime(utcNow), utcNow) - - mapTestCase(t, anyStringSlice, stringSlice) - mapTestCase(t, anyStringSlice, kindSlice) - mapTestCase(t, anyStringSlice, kinds) -} diff --git a/drivers/neo4j/transaction.go b/drivers/neo4j/transaction.go deleted file mode 100644 index dd2f1a9..0000000 --- a/drivers/neo4j/transaction.go +++ /dev/null @@ -1,447 +0,0 @@ -package neo4j - -import ( - "context" - "encoding/json" - "fmt" - "log/slog" - "sort" - "strings" - - "github.com/specterops/dawgs/drivers" - "github.com/specterops/dawgs/query/neo4j" - "github.com/specterops/dawgs/util/size" - - neo4j_core "github.com/neo4j/neo4j-go-driver/v5/neo4j" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" 
-) - -const ( - DefaultBatchWriteSize = 20_000 - DefaultWriteFlushSize = DefaultBatchWriteSize * 5 - - // DefaultConcurrentConnections defines the default number of concurrent graph database connections allowed. - DefaultConcurrentConnections = 50 -) - -type innerTransaction interface { - Raw(cypher string, params map[string]any) graph.Result -} - -type neo4jTransaction struct { - cfg graph.TransactionConfig - ctx context.Context - session neo4j_core.Session - innerTx neo4j_core.Transaction - writes int - writeFlushSize int - batchWriteSize int - graphQueryMemoryLimit size.Size -} - -func (s *neo4jTransaction) WithGraph(graphSchema graph.Graph) graph.Transaction { - // Neo4j does not support multiple graph namespaces within the same database. While Neo4j enterprise supports - // multiple databases this is not the same. Graph namespaces could be hacked using labels but this then requires - // a material change in how labels are applied and therefore was not plumbed. - // - // This has no material effect on the usage of the database: the schema is the same for all graph namespaces. 
- return s -} - -func (s *neo4jTransaction) Query(query string, parameters map[string]any) graph.Result { - return s.Raw(query, parameters) -} - -func (s *neo4jTransaction) updateRelationshipsBy(updates ...graph.RelationshipUpdate) error { - var ( - numUpdates = len(updates) - statements, queryParameterArrays = cypherBuildRelationshipUpdateQueryBatch(updates) - ) - - for parameterIdx, stmt := range statements { - propertyBags := queryParameterArrays[parameterIdx] - chunkMap := make([]map[string]any, 0, s.batchWriteSize) - - for _, val := range propertyBags { - chunkMap = append(chunkMap, val) - - if len(chunkMap) == s.batchWriteSize { - if result := s.Raw(stmt, map[string]any{ - "p": chunkMap, - }); result.Error() != nil { - return result.Error() - } - - chunkMap = chunkMap[:0] - } - } - - if len(chunkMap) > 0 { - if result := s.Raw(stmt, map[string]any{ - "p": chunkMap, - }); result.Error() != nil { - return result.Error() - } - } - } - - return s.logWrites(numUpdates) -} - -func (s *neo4jTransaction) UpdateRelationshipBy(update graph.RelationshipUpdate) error { - return s.updateRelationshipsBy(update) -} - -func (s *neo4jTransaction) updateNodesBy(updates ...graph.NodeUpdate) error { - var ( - numUpdates = len(updates) - statements, queryParameterMaps = cypherBuildNodeUpdateQueryBatch(updates) - ) - - for parameterIdx, stmt := range statements { - if result := s.Raw(stmt, queryParameterMaps[parameterIdx]); result.Error() != nil { - return fmt.Errorf("update nodes by error on statement (%s): %s", stmt, result.Error()) - } - } - - return s.logWrites(numUpdates) -} - -func (s *neo4jTransaction) UpdateNodeBy(update graph.NodeUpdate) error { - return s.updateNodesBy(update) -} - -func newTransaction(ctx context.Context, session neo4j_core.Session, cfg graph.TransactionConfig, writeFlushSize int, batchWriteSize int, graphQueryMemoryLimit size.Size) *neo4jTransaction { - return &neo4jTransaction{ - cfg: cfg, - ctx: ctx, - session: session, - writeFlushSize: 
writeFlushSize, - batchWriteSize: batchWriteSize, - graphQueryMemoryLimit: graphQueryMemoryLimit, - } -} - -func (s *neo4jTransaction) flushTx() error { - defer func() { - s.innerTx = nil - }() - - if err := s.innerTx.Commit(); err != nil { - return err - } - - return nil -} - -func (s *neo4jTransaction) currentTx() neo4j_core.Transaction { - if s.innerTx == nil { - if newTx, err := s.session.BeginTransaction(neo4j_core.WithTxTimeout(s.cfg.Timeout)); err != nil { - return newErrorTransactionWrapper(err) - } else { - s.innerTx = newTx - } - } - - return s.innerTx -} - -func (s *neo4jTransaction) logWrites(writes int) error { - if s.writes += writes; s.writes >= s.writeFlushSize { - if err := s.flushTx(); err != nil { - return err - } - - s.writes = 0 - } - - return nil -} - -func (s *neo4jTransaction) runAndLog(stmt string, params map[string]any, numWrites int) graph.Result { - result := s.Raw(stmt, params) - - if result.Error() == nil { - if err := s.logWrites(numWrites); err != nil { - return NewResult(stmt, err, nil) - } - } - - return result -} - -func (s *neo4jTransaction) updateNode(updatedNode *graph.Node) error { - queryBuilder := neo4j.NewQueryBuilder(query.SinglePartQuery( - query.Where( - query.Equals(query.NodeID(), updatedNode.ID), - ), - - query.Updatef(func() graph.Criteria { - var ( - properties = updatedNode.Properties - updateStatements []graph.Criteria - ) - - if addedKinds := updatedNode.AddedKinds; len(addedKinds) > 0 { - updateStatements = append(updateStatements, query.AddKinds(query.Node(), addedKinds)) - } - - if deletedKinds := updatedNode.DeletedKinds; len(deletedKinds) > 0 { - updateStatements = append(updateStatements, query.DeleteKinds(query.Node(), deletedKinds)) - } - - if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 { - updateStatements = append(updateStatements, query.SetProperties(query.Node(), modifiedProperties)) - } - - if deletedProperties := properties.DeletedProperties(); 
len(deletedProperties) > 0 { - updateStatements = append(updateStatements, query.DeleteProperties(query.Node(), deletedProperties...)) - } - - return updateStatements - }), - )) - - if err := queryBuilder.Prepare(); err != nil { - return err - } else if cypherQuery, err := queryBuilder.Render(); err != nil { - strippedQuery := stripCypherQuery(cypherQuery) - return graph.NewError(strippedQuery, err) - } else if result := s.Raw(cypherQuery, queryBuilder.Parameters); result.Error() != nil { - return result.Error() - } - - return nil -} - -func (s *neo4jTransaction) createNode(properties *graph.Properties, kinds ...graph.Kind) (*graph.Node, error) { - queryBuilder := neo4j.NewQueryBuilder(query.SinglePartQuery( - query.Create( - query.NodePattern( - kinds, - query.Parameter(properties.Map), - ), - ), - - query.Returning( - query.Node(), - ), - )) - - if err := queryBuilder.Prepare(); err != nil { - return nil, err - } else if statement, err := queryBuilder.Render(); err != nil { - return nil, err - } else if result := s.Raw(statement, queryBuilder.Parameters); result.Error() != nil { - return nil, result.Error() - } else if !result.Next() { - return nil, fmt.Errorf("%w: %w", graph.ErrNoResultsFound, result.Error()) - } else { - var node graph.Node - return &node, result.Scan(&node) - } -} - -func (s *neo4jTransaction) createRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) { - queryBuilder := neo4j.NewQueryBuilder(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.StartID(), startNodeID), - query.Equals(query.EndID(), endNodeID), - ), - ), - query.Create( - query.Start(), - query.RelationshipPattern(kind, query.Parameter(properties.Map), graph.DirectionOutbound), - query.End(), - ), - - query.Returning( - query.Relationship(), - ), - )) - - if err := queryBuilder.Prepare(); err != nil { - return nil, err - } else if statement, err := queryBuilder.Render(); err != nil { - 
return nil, err - } else if result := s.Raw(statement, queryBuilder.Parameters); result.Error() != nil { - return nil, result.Error() - } else if !result.Next() { - return nil, graph.ErrNoResultsFound - } else { - var relationship graph.Relationship - return &relationship, result.Scan(&relationship) - } -} - -func (s *neo4jTransaction) Raw(stmt string, params map[string]any) graph.Result { - const maxParametersToRender = 12 - - if drivers.IsQueryAnalysisEnabled() { - var ( - parametersWritten = 0 - prettyParameters strings.Builder - sortedKeys []string - ) - - if len(params) > maxParametersToRender { - sortedKeys = make([]string, 0, maxParametersToRender) - } else { - sortedKeys = make([]string, 0, len(params)) - } - - for key := range params { - if sortedKeys = append(sortedKeys, key); len(sortedKeys) >= maxParametersToRender { - break - } - } - - sort.Strings(sortedKeys) - - for _, key := range sortedKeys { - value := params[key] - - if parametersWritten++; parametersWritten >= maxParametersToRender { - break - } else if parametersWritten > 1 { - prettyParameters.WriteString(", ") - } - - prettyParameters.WriteString(key) - prettyParameters.WriteString(":") - - if marshalledValue, err := json.Marshal(value); err != nil { - slog.Error(fmt.Sprintf("Unable to marshal query parameter %s", key)) - } else { - prettyParameters.Write(marshalledValue) - } - } - - slog.Info(fmt.Sprintf("%s - %s", stmt, prettyParameters.String()), "dawgs_db_driver", DriverName) - } - - driverResult, err := s.currentTx().Run(stmt, params) - return NewResult(stmt, err, driverResult) -} - -func (s *neo4jTransaction) Nodes() graph.NodeQuery { - return NewNodeQuery(s.ctx, s) -} - -func (s *neo4jTransaction) Relationships() graph.RelationshipQuery { - return NewRelationshipQuery(s.ctx, s) -} - -func (s *neo4jTransaction) DeleteNodesBySlice(ids []graph.ID) error { - return s.runAndLog(cypherDeleteNodesByID, map[string]any{ - idListParameterName: ids, - }, len(ids)).Error() -} - -func (s 
*neo4jTransaction) DeleteRelationshipsBySlice(ids []graph.ID) error { - return s.runAndLog(cypherDeleteRelationshipsByID, map[string]any{ - "p": ids, - }, len(ids)).Error() -} - -func (s *neo4jTransaction) Commit() error { - if s.innerTx != nil { - txRef := s.innerTx - s.innerTx = nil - - return txRef.Commit() - } - - return nil -} - -func (s *neo4jTransaction) Close() error { - if s.innerTx != nil { - txRef := s.innerTx - s.innerTx = nil - - return txRef.Close() - } - - return nil -} - -func (s *neo4jTransaction) CreateNode(properties *graph.Properties, kinds ...graph.Kind) (*graph.Node, error) { - if node, err := s.createNode(properties, kinds...); err != nil { - return nil, err - } else { - return node, s.logWrites(1) - } -} - -func (s *neo4jTransaction) UpdateNode(target *graph.Node) error { - if err := s.updateNode(target); err != nil { - return err - } - - return s.logWrites(1) -} - -func (s *neo4jTransaction) CreateRelationship(startNode, endNode *graph.Node, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) { - return s.CreateRelationshipByIDs(startNode.ID, endNode.ID, kind, properties) -} - -func (s *neo4jTransaction) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) { - if rel, err := s.createRelationshipByIDs(startNodeID, endNodeID, kind, properties); err != nil { - return nil, err - } else { - return rel, s.logWrites(1) - } -} - -func (s *neo4jTransaction) DeleteNode(id graph.ID) error { - return s.runAndLog(cypherDeleteNodeByID, map[string]any{ - idParameterName: id, - }, 1).Error() -} - -func (s *neo4jTransaction) DeleteRelationship(id graph.ID) error { - return s.runAndLog(cypherDeleteRelationshipByID, map[string]any{ - idParameterName: id, - }, 1).Error() -} - -func (s *neo4jTransaction) UpdateRelationship(relationship *graph.Relationship) error { - queryBuilder := neo4j.NewQueryBuilder(query.SinglePartQuery( - query.Where( - 
query.Equals(query.RelationshipID(), relationship.ID), - ), - - query.Updatef(func() graph.Criteria { - var ( - properties = relationship.Properties - updateStatements []graph.Criteria - ) - - if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 { - updateStatements = append(updateStatements, query.SetProperties(query.Relationship(), modifiedProperties)) - } - - if deletedProperties := properties.DeletedProperties(); len(deletedProperties) > 0 { - updateStatements = append(updateStatements, query.DeleteProperties(query.Relationship(), deletedProperties...)) - } - - return updateStatements - }), - )) - - if err := queryBuilder.Prepare(); err != nil { - return err - } else if cypherQuery, err := queryBuilder.Render(); err != nil { - strippedQuery := stripCypherQuery(cypherQuery) - return graph.NewError(strippedQuery, err) - } else { - return s.runAndLog(cypherQuery, queryBuilder.Parameters, 1).Error() - } -} - -func (s *neo4jTransaction) GraphQueryMemoryLimit() size.Size { - return s.graphQueryMemoryLimit -} diff --git a/drivers/neo4j/wrapper.go b/drivers/neo4j/wrapper.go deleted file mode 100644 index e1ee168..0000000 --- a/drivers/neo4j/wrapper.go +++ /dev/null @@ -1,31 +0,0 @@ -package neo4j - -import ( - "github.com/neo4j/neo4j-go-driver/v5/neo4j" -) - -type errorTransactionWrapper struct { - err error -} - -func newErrorTransactionWrapper(err error) errorTransactionWrapper { - return errorTransactionWrapper{ - err: err, - } -} - -func (s errorTransactionWrapper) Run(cypher string, params map[string]any) (neo4j.Result, error) { - return nil, s.err -} - -func (s errorTransactionWrapper) Commit() error { - return s.err -} - -func (s errorTransactionWrapper) Rollback() error { - return s.err -} - -func (s errorTransactionWrapper) Close() error { - return s.err -} diff --git a/drivers/pg/batch.go b/drivers/pg/batch.go deleted file mode 100644 index a46a520..0000000 --- a/drivers/pg/batch.go +++ /dev/null @@ -1,576 +0,0 @@ -package pg 
- -import ( - "bytes" - "context" - "fmt" - "log/slog" - "strconv" - "strings" - - "github.com/jackc/pgtype" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/specterops/dawgs/cypher/models/pgsql" - "github.com/specterops/dawgs/drivers/pg/model" - sql "github.com/specterops/dawgs/drivers/pg/query" - "github.com/specterops/dawgs/graph" -) - -type Int2ArrayEncoder struct { - buffer *bytes.Buffer -} - -func (s *Int2ArrayEncoder) Encode(values []int16) string { - s.buffer.Reset() - s.buffer.WriteRune('{') - - for idx, value := range values { - if idx > 0 { - s.buffer.WriteRune(',') - } - - s.buffer.WriteString(strconv.Itoa(int(value))) - } - - s.buffer.WriteRune('}') - return s.buffer.String() -} - -type batch struct { - ctx context.Context - innerTransaction *transaction - schemaManager *SchemaManager - nodeDeletionBuffer []graph.ID - relationshipDeletionBuffer []graph.ID - nodeCreateBuffer []*graph.Node - nodeUpdateByBuffer []graph.NodeUpdate - relationshipCreateBuffer []*graph.Relationship - relationshipUpdateByBuffer []graph.RelationshipUpdate - batchWriteSize int - kindIDEncoder Int2ArrayEncoder -} - -func newBatch(ctx context.Context, conn *pgxpool.Conn, schemaManager *SchemaManager, cfg *Config) (*batch, error) { - if tx, err := newTransactionWrapper(ctx, conn, schemaManager, cfg, false); err != nil { - return nil, err - } else { - return &batch{ - ctx: ctx, - schemaManager: schemaManager, - innerTransaction: tx, - batchWriteSize: cfg.BatchWriteSize, - kindIDEncoder: Int2ArrayEncoder{ - buffer: &bytes.Buffer{}, - }, - }, nil - } -} - -func (s *batch) WithGraph(schema graph.Graph) graph.Batch { - s.innerTransaction.WithGraph(schema) - return s -} - -func (s *batch) CreateNode(node *graph.Node) error { - s.nodeCreateBuffer = append(s.nodeCreateBuffer, node) - return s.tryFlush(s.batchWriteSize) -} - -func (s *batch) Nodes() graph.NodeQuery { - return s.innerTransaction.Nodes() -} - -func (s *batch) Relationships() graph.RelationshipQuery { - return 
s.innerTransaction.Relationships() -} - -func (s *batch) UpdateNodeBy(update graph.NodeUpdate) error { - s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update) - return s.tryFlush(s.batchWriteSize) -} - -func (s *batch) flushNodeDeleteBuffer() error { - if _, err := s.innerTransaction.conn.Exec(s.ctx, deleteNodeWithIDStatement, s.nodeDeletionBuffer); err != nil { - return err - } - - s.nodeDeletionBuffer = s.nodeDeletionBuffer[:0] - return nil -} - -func (s *batch) flushRelationshipDeleteBuffer() error { - if _, err := s.innerTransaction.conn.Exec(s.ctx, deleteEdgeWithIDStatement, s.relationshipDeletionBuffer); err != nil { - return err - } - - s.relationshipDeletionBuffer = s.relationshipDeletionBuffer[:0] - return nil -} - -func (s *batch) flushNodeCreateBuffer() error { - var ( - withoutIDs = false - withIDs = false - ) - - for _, node := range s.nodeCreateBuffer { - if node.ID == 0 || node.ID == graph.UnregisteredNodeID { - withoutIDs = true - } else { - withIDs = true - } - - if withIDs && withoutIDs { - return fmt.Errorf("batch may not mix preset node IDs with entries that require an auto-generated ID") - } - } - - if withoutIDs { - return s.flushNodeCreateBufferWithoutIDs() - } - - return s.flushNodeCreateBufferWithIDs() -} - -func (s *batch) flushNodeCreateBufferWithIDs() error { - var ( - numCreates = len(s.nodeCreateBuffer) - nodeIDs = make([]uint64, numCreates) - kindIDSlices = make([]string, numCreates) - kindIDEncoder = Int2ArrayEncoder{ - buffer: &bytes.Buffer{}, - } - properties = make([]pgtype.JSONB, numCreates) - ) - - for idx, nextNode := range s.nodeCreateBuffer { - nodeIDs[idx] = nextNode.ID.Uint64() - - if mappedKindIDs, err := s.schemaManager.AssertKinds(s.ctx, nextNode.Kinds); err != nil { - return fmt.Errorf("unable to map kinds %w", err) - } else { - kindIDSlices[idx] = kindIDEncoder.Encode(mappedKindIDs) - } - - if propertiesJSONB, err := pgsql.PropertiesToJSONB(nextNode.Properties); err != nil { - return err - } else { - 
properties[idx] = propertiesJSONB - } - } - - if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { - return err - } else if _, err := s.innerTransaction.conn.Exec(s.ctx, createNodeWithIDBatchStatement, graphTarget.ID, nodeIDs, kindIDSlices, properties); err != nil { - return err - } - - s.nodeCreateBuffer = s.nodeCreateBuffer[:0] - return nil -} - -func (s *batch) flushNodeCreateBufferWithoutIDs() error { - var ( - numCreates = len(s.nodeCreateBuffer) - kindIDSlices = make([]string, numCreates) - kindIDEncoder = Int2ArrayEncoder{ - buffer: &bytes.Buffer{}, - } - properties = make([]pgtype.JSONB, numCreates) - ) - - for idx, nextNode := range s.nodeCreateBuffer { - if mappedKindIDs, err := s.schemaManager.AssertKinds(s.ctx, nextNode.Kinds); err != nil { - return fmt.Errorf("unable to map kinds %w", err) - } else { - kindIDSlices[idx] = kindIDEncoder.Encode(mappedKindIDs) - } - - if propertiesJSONB, err := pgsql.PropertiesToJSONB(nextNode.Properties); err != nil { - return err - } else { - properties[idx] = propertiesJSONB - } - } - - if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { - return err - } else if _, err := s.innerTransaction.conn.Exec(s.ctx, createNodeWithoutIDBatchStatement, graphTarget.ID, kindIDSlices, properties); err != nil { - return err - } - - s.nodeCreateBuffer = s.nodeCreateBuffer[:0] - return nil -} - -func (s *batch) flushNodeUpsertBatch(updates *sql.NodeUpdateBatch) error { - parameters := NewNodeUpsertParameters(len(updates.Updates)) - - if err := parameters.AppendAll(s.ctx, updates, s.schemaManager, s.kindIDEncoder); err != nil { - return err - } - - if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { - return err - } else { - query := sql.FormatNodeUpsert(graphTarget, updates.IdentityProperties) - - if rows, err := s.innerTransaction.conn.Query(s.ctx, query, parameters.Format(graphTarget)...); err != nil { - return err - } else { - defer rows.Close() - - idFutureIndex := 0 - - 
for rows.Next() { - if err := rows.Scan(¶meters.IDFutures[idFutureIndex].Value); err != nil { - return err - } - - idFutureIndex++ - } - } - } - - return nil -} - -func (s *batch) tryFlushNodeUpdateByBuffer() error { - if updates, err := sql.ValidateNodeUpdateByBatch(s.nodeUpdateByBuffer); err != nil { - return err - } else if err := s.flushNodeUpsertBatch(updates); err != nil { - return err - } - - s.nodeUpdateByBuffer = s.nodeUpdateByBuffer[:0] - return nil -} - -type NodeUpsertParameters struct { - IDFutures []*sql.Future[graph.ID] - KindIDSlices []string - Properties []pgtype.JSONB -} - -func NewNodeUpsertParameters(size int) *NodeUpsertParameters { - return &NodeUpsertParameters{ - IDFutures: make([]*sql.Future[graph.ID], 0, size), - KindIDSlices: make([]string, 0, size), - Properties: make([]pgtype.JSONB, 0, size), - } -} - -func (s *NodeUpsertParameters) Format(graphTarget model.Graph) []any { - return []any{ - graphTarget.ID, - s.KindIDSlices, - s.Properties, - } -} - -func (s *NodeUpsertParameters) Append(ctx context.Context, update *sql.NodeUpdate, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { - s.IDFutures = append(s.IDFutures, update.IDFuture) - - if mappedKindIDs, err := schemaManager.AssertKinds(ctx, update.Node.Kinds); err != nil { - return fmt.Errorf("unable to map kinds %w", err) - } else { - s.KindIDSlices = append(s.KindIDSlices, kindIDEncoder.Encode(mappedKindIDs)) - } - - if propertiesJSONB, err := pgsql.PropertiesToJSONB(update.Node.Properties); err != nil { - return err - } else { - s.Properties = append(s.Properties, propertiesJSONB) - } - - return nil -} - -func (s *NodeUpsertParameters) AppendAll(ctx context.Context, updates *sql.NodeUpdateBatch, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { - for _, nextUpdate := range updates.Updates { - if err := s.Append(ctx, nextUpdate, schemaManager, kindIDEncoder); err != nil { - return err - } - } - - return nil -} - -type 
RelationshipUpdateByParameters struct { - StartIDs []graph.ID - EndIDs []graph.ID - KindIDs []int16 - Properties []pgtype.JSONB -} - -func NewRelationshipUpdateByParameters(size int) *RelationshipUpdateByParameters { - return &RelationshipUpdateByParameters{ - StartIDs: make([]graph.ID, 0, size), - EndIDs: make([]graph.ID, 0, size), - KindIDs: make([]int16, 0, size), - Properties: make([]pgtype.JSONB, 0, size), - } -} - -func (s *RelationshipUpdateByParameters) Format(graphTarget model.Graph) []any { - return []any{ - graphTarget.ID, - s.StartIDs, - s.EndIDs, - s.KindIDs, - s.Properties, - } -} - -func (s *RelationshipUpdateByParameters) Append(ctx context.Context, update *sql.RelationshipUpdate, schemaManager *SchemaManager) error { - s.StartIDs = append(s.StartIDs, update.StartID.Value) - s.EndIDs = append(s.EndIDs, update.EndID.Value) - - if mappedKindIDs, err := schemaManager.AssertKinds(ctx, []graph.Kind{update.Relationship.Kind}); err != nil { - return err - } else { - s.KindIDs = append(s.KindIDs, mappedKindIDs...) 
- } - - if propertiesJSONB, err := pgsql.PropertiesToJSONB(update.Relationship.Properties); err != nil { - return err - } else { - s.Properties = append(s.Properties, propertiesJSONB) - } - return nil -} - -func (s *RelationshipUpdateByParameters) AppendAll(ctx context.Context, updates *sql.RelationshipUpdateBatch, schemaManager *SchemaManager) error { - for _, nextUpdate := range updates.Updates { - if err := s.Append(ctx, nextUpdate, schemaManager); err != nil { - return err - } - } - - return nil -} - -func (s *batch) flushRelationshipUpdateByBuffer(updates *sql.RelationshipUpdateBatch) error { - if err := s.flushNodeUpsertBatch(updates.NodeUpdates); err != nil { - return err - } - - parameters := NewRelationshipUpdateByParameters(len(updates.Updates)) - - if err := parameters.AppendAll(s.ctx, updates, s.schemaManager); err != nil { - return err - } - - if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { - return err - } else { - query := sql.FormatRelationshipPartitionUpsert(graphTarget, updates.IdentityProperties) - - if _, err := s.innerTransaction.conn.Exec(s.ctx, query, parameters.Format(graphTarget)...); err != nil { - return err - } - } - - return nil -} - -func (s *batch) tryFlushRelationshipUpdateByBuffer() error { - if updateBatch, err := sql.ValidateRelationshipUpdateByBatch(s.relationshipUpdateByBuffer); err != nil { - return err - } else if err := s.flushRelationshipUpdateByBuffer(updateBatch); err != nil { - return err - } - - s.relationshipUpdateByBuffer = s.relationshipUpdateByBuffer[:0] - return nil -} - -type relationshipCreateBatch struct { - startIDs []uint64 - endIDs []uint64 - edgeKindIDs []int16 - edgePropertyBags []pgtype.JSONB -} - -func newRelationshipCreateBatch(size int) *relationshipCreateBatch { - return &relationshipCreateBatch{ - startIDs: make([]uint64, 0, size), - endIDs: make([]uint64, 0, size), - edgeKindIDs: make([]int16, 0, size), - edgePropertyBags: make([]pgtype.JSONB, 0, size), - } -} - -func (s 
*relationshipCreateBatch) Add(startID, endID uint64, edgeKindID int16) { - s.startIDs = append(s.startIDs, startID) - s.edgeKindIDs = append(s.edgeKindIDs, edgeKindID) - s.endIDs = append(s.endIDs, endID) -} - -func (s *relationshipCreateBatch) EncodeProperties(edgePropertiesBatch []*graph.Properties) error { - for _, edgeProperties := range edgePropertiesBatch { - if propertiesJSONB, err := pgsql.PropertiesToJSONB(edgeProperties); err != nil { - return err - } else { - s.edgePropertyBags = append(s.edgePropertyBags, propertiesJSONB) - } - } - - return nil -} - -type relationshipCreateBatchBuilder struct { - keyToEdgeID map[string]uint64 - relationshipUpdateBatch *relationshipCreateBatch - edgePropertiesIndex map[uint64]int - edgePropertiesBatch []*graph.Properties -} - -func newRelationshipCreateBatchBuilder(size int) *relationshipCreateBatchBuilder { - return &relationshipCreateBatchBuilder{ - keyToEdgeID: map[string]uint64{}, - relationshipUpdateBatch: newRelationshipCreateBatch(size), - edgePropertiesIndex: map[uint64]int{}, - } -} - -func (s *relationshipCreateBatchBuilder) Build() (*relationshipCreateBatch, error) { - return s.relationshipUpdateBatch, s.relationshipUpdateBatch.EncodeProperties(s.edgePropertiesBatch) -} - -func (s *relationshipCreateBatchBuilder) Add(ctx context.Context, kindMapper KindMapper, edge *graph.Relationship) error { - keyBuilder := strings.Builder{} - - keyBuilder.WriteString(edge.StartID.String()) - keyBuilder.WriteString(edge.EndID.String()) - keyBuilder.WriteString(edge.Kind.String()) - - key := keyBuilder.String() - - if existingPropertiesIdx, hasExisting := s.keyToEdgeID[key]; hasExisting { - s.edgePropertiesBatch[existingPropertiesIdx].Merge(edge.Properties) - } else { - var ( - startID = edge.StartID.Uint64() - edgeID = edge.ID.Uint64() - endID = edge.EndID.Uint64() - edgeProperties = edge.Properties.Clone() - ) - - if edgeKindID, err := kindMapper.MapKind(ctx, edge.Kind); err != nil { - return err - } else { - 
s.relationshipUpdateBatch.Add(startID, endID, edgeKindID) - } - - s.keyToEdgeID[key] = edgeID - - s.edgePropertiesBatch = append(s.edgePropertiesBatch, edgeProperties) - s.edgePropertiesIndex[edgeID] = len(s.edgePropertiesBatch) - 1 - } - - return nil -} - -func (s *batch) flushRelationshipCreateBuffer() error { - batchBuilder := newRelationshipCreateBatchBuilder(len(s.relationshipCreateBuffer)) - - for _, nextRel := range s.relationshipCreateBuffer { - if err := batchBuilder.Add(s.ctx, s.schemaManager, nextRel); err != nil { - return err - } - } - - if createBatch, err := batchBuilder.Build(); err != nil { - return err - } else if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { - return err - } else if _, err := s.innerTransaction.conn.Exec(s.ctx, createEdgeBatchStatement, graphTarget.ID, createBatch.startIDs, createBatch.endIDs, createBatch.edgeKindIDs, createBatch.edgePropertyBags); err != nil { - slog.Info(fmt.Sprintf("Num merged property bags: %d - Num edge keys: %d - StartID batch size: %d", len(batchBuilder.edgePropertiesIndex), len(batchBuilder.keyToEdgeID), len(batchBuilder.relationshipUpdateBatch.startIDs))) - return err - } - - s.relationshipCreateBuffer = s.relationshipCreateBuffer[:0] - return nil -} - -func (s *batch) tryFlush(batchWriteSize int) error { - if len(s.nodeUpdateByBuffer) > batchWriteSize { - if err := s.tryFlushNodeUpdateByBuffer(); err != nil { - return err - } - } - - if len(s.relationshipUpdateByBuffer) > batchWriteSize { - if err := s.tryFlushRelationshipUpdateByBuffer(); err != nil { - return err - } - } - - if len(s.relationshipCreateBuffer) > batchWriteSize { - if err := s.flushRelationshipCreateBuffer(); err != nil { - return err - } - } - - if len(s.nodeCreateBuffer) > batchWriteSize { - if err := s.flushNodeCreateBuffer(); err != nil { - return err - } - } - - if len(s.nodeDeletionBuffer) > batchWriteSize { - if err := s.flushNodeDeleteBuffer(); err != nil { - return err - } - } - - if 
len(s.relationshipDeletionBuffer) > batchWriteSize { - if err := s.flushRelationshipDeleteBuffer(); err != nil { - return err - } - } - - return nil -} - -func (s *batch) CreateRelationship(relationship *graph.Relationship) error { - s.relationshipCreateBuffer = append(s.relationshipCreateBuffer, relationship) - return s.tryFlush(s.batchWriteSize) -} - -func (s *batch) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) error { - return s.CreateRelationship(&graph.Relationship{ - StartID: startNodeID, - EndID: endNodeID, - Kind: kind, - Properties: properties, - }) -} - -func (s *batch) UpdateRelationshipBy(update graph.RelationshipUpdate) error { - s.relationshipUpdateByBuffer = append(s.relationshipUpdateByBuffer, update) - return s.tryFlush(s.batchWriteSize) -} - -func (s *batch) Commit() error { - if err := s.tryFlush(0); err != nil { - return err - } - - return s.innerTransaction.Commit() -} - -func (s *batch) DeleteNode(id graph.ID) error { - s.nodeDeletionBuffer = append(s.nodeDeletionBuffer, id) - return s.tryFlush(s.batchWriteSize) -} - -func (s *batch) DeleteRelationship(id graph.ID) error { - s.relationshipDeletionBuffer = append(s.relationshipDeletionBuffer, id) - return s.tryFlush(s.batchWriteSize) -} - -func (s *batch) Close() { - s.innerTransaction.Close() -} diff --git a/drivers/pg/driver.go b/drivers/pg/driver.go deleted file mode 100644 index 1367343..0000000 --- a/drivers/pg/driver.go +++ /dev/null @@ -1,169 +0,0 @@ -package pg - -import ( - "context" - "fmt" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/specterops/dawgs/graph" -) - -var ( - batchWriteSize = defaultBatchWriteSize - readOnlyTxOptions = pgx.TxOptions{ - AccessMode: pgx.ReadOnly, - } - - readWriteTxOptions = pgx.TxOptions{ - AccessMode: pgx.ReadWrite, - } -) - -type Config struct { - Options pgx.TxOptions - QueryExecMode pgx.QueryExecMode - QueryResultFormats pgx.QueryResultFormats - BatchWriteSize 
int -} - -func OptionSetQueryExecMode(queryExecMode pgx.QueryExecMode) graph.TransactionOption { - return func(config *graph.TransactionConfig) { - if pgCfg, typeOK := config.DriverConfig.(*Config); typeOK { - pgCfg.QueryExecMode = queryExecMode - } - } -} - -type Driver struct { - pool *pgxpool.Pool - *SchemaManager -} - -func NewDriver(pool *pgxpool.Pool) *Driver { - return &Driver{ - pool: pool, - SchemaManager: NewSchemaManager(pool), - } -} - -func (s *Driver) SetDefaultGraph(ctx context.Context, graphSchema graph.Graph) error { - return s.SchemaManager.SetDefaultGraph(ctx, graphSchema) -} - -func (s *Driver) KindMapper() KindMapper { - return s.SchemaManager -} - -func (s *Driver) SetBatchWriteSize(size int) { - batchWriteSize = size -} - -func (s *Driver) SetWriteFlushSize(size int) { - // THis is a no-op function since PostgreSQL does not require transaction rotation like Neo4j does -} - -func (s *Driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate) error { - if cfg, err := renderConfig(batchWriteSize, readWriteTxOptions, nil); err != nil { - return err - } else if conn, err := s.pool.Acquire(ctx); err != nil { - return err - } else { - defer conn.Release() - - if batch, err := newBatch(ctx, conn, s.SchemaManager, cfg); err != nil { - return err - } else { - defer batch.Close() - - if err := batchDelegate(batch); err != nil { - return err - } - - return batch.Commit() - } - } -} - -func (s *Driver) Close(ctx context.Context) error { - s.pool.Close() - return nil -} - -func renderConfig(batchWriteSize int, pgxOptions pgx.TxOptions, userOptions []graph.TransactionOption) (*Config, error) { - graphCfg := graph.TransactionConfig{ - DriverConfig: &Config{ - Options: pgxOptions, - QueryExecMode: pgx.QueryExecModeCacheStatement, - QueryResultFormats: pgx.QueryResultFormats{pgx.BinaryFormatCode}, - BatchWriteSize: batchWriteSize, - }, - } - - for _, option := range userOptions { - option(&graphCfg) - } - - if graphCfg.DriverConfig != nil { 
- if pgCfg, typeOK := graphCfg.DriverConfig.(*Config); !typeOK { - return nil, fmt.Errorf("invalid driver config type %T", graphCfg.DriverConfig) - } else { - return pgCfg, nil - } - } - - return nil, fmt.Errorf("driver config is nil") -} - -func (s *Driver) FetchSchema(ctx context.Context) (graph.Schema, error) { - // TODO: This is not required for existing functionality as the SchemaManager type handles most of this negotiation - // however, in the future this function would make it easier to make schema management generic and should be - // implemented. - return graph.Schema{}, fmt.Errorf("not implemented") -} - -func (s *Driver) AssertSchema(ctx context.Context, schema graph.Schema) error { - // Resetting the pool must be done on every schema assertion as composite types may have changed OIDs - defer s.pool.Reset() - - // Assert that the base graph schema exists and has a matching schema definition - if err := s.SchemaManager.AssertSchema(ctx, schema); err != nil { - return err - } - - if schema.DefaultGraph.Name != "" { - // There's a default graph defined. 
Assert that it exists and has a matching schema - if err := s.SchemaManager.AssertDefaultGraph(ctx, schema.DefaultGraph); err != nil { - return err - } - } - - return nil -} - -func (s *Driver) Run(ctx context.Context, query string, parameters map[string]any) error { - return s.WriteTransaction(ctx, func(tx graph.Transaction) error { - result := tx.Raw(query, parameters) - defer result.Close() - - return result.Error() - }) -} - -func (s *Driver) FetchKinds(_ context.Context) (graph.Kinds, error) { - var kinds graph.Kinds - for _, kind := range s.SchemaManager.GetKindIDsByKind() { - kinds = append(kinds, kind) - } - - return kinds, nil -} - -func (s *Driver) RefreshKinds(ctx context.Context) error { - s.lock.Lock() - defer s.lock.Unlock() - - // Wipe this map to be rebuilt in the fetch call below - s.SchemaManager.kindIDsByKind = map[int16]graph.Kind{} - return s.SchemaManager.Fetch(ctx) -} diff --git a/drivers/pg/model/model.go b/drivers/pg/model/model.go deleted file mode 100644 index c0b3342..0000000 --- a/drivers/pg/model/model.go +++ /dev/null @@ -1,68 +0,0 @@ -package model - -import ( - "github.com/specterops/dawgs/graph" -) - -type IndexChangeSet struct { - NodeIndexesToRemove []string - EdgeIndexesToRemove []string - NodeConstraintsToRemove []string - EdgeConstraintsToRemove []string - NodeIndexesToAdd map[string]graph.Index - EdgeIndexesToAdd map[string]graph.Index - NodeConstraintsToAdd map[string]graph.Constraint - EdgeConstraintsToAdd map[string]graph.Constraint -} - -func NewIndexChangeSet() IndexChangeSet { - return IndexChangeSet{ - NodeIndexesToAdd: map[string]graph.Index{}, - NodeConstraintsToAdd: map[string]graph.Constraint{}, - EdgeIndexesToAdd: map[string]graph.Index{}, - EdgeConstraintsToAdd: map[string]graph.Constraint{}, - } -} - -type GraphPartition struct { - Name string - Indexes map[string]graph.Index - Constraints map[string]graph.Constraint -} - -func NewGraphPartition(name string) GraphPartition { - return GraphPartition{ - Name: 
name, - Indexes: map[string]graph.Index{}, - Constraints: map[string]graph.Constraint{}, - } -} - -func NewGraphPartitionFromSchema(name string, indexes []graph.Index, constraints []graph.Constraint) GraphPartition { - graphPartition := GraphPartition{ - Name: name, - Indexes: make(map[string]graph.Index, len(indexes)), - Constraints: make(map[string]graph.Constraint, len(constraints)), - } - - for _, index := range indexes { - graphPartition.Indexes[IndexName(name, index)] = index - } - - for _, constraint := range constraints { - graphPartition.Constraints[ConstraintName(name, constraint)] = constraint - } - - return graphPartition -} - -type GraphPartitions struct { - Node GraphPartition - Edge GraphPartition -} - -type Graph struct { - ID int32 - Name string - Partitions GraphPartitions -} diff --git a/drivers/pg/node.go b/drivers/pg/node.go deleted file mode 100644 index 08e1a72..0000000 --- a/drivers/pg/node.go +++ /dev/null @@ -1,139 +0,0 @@ -package pg - -import ( - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" -) - -type nodeQuery struct { - liveQuery -} - -func (s *nodeQuery) Filter(criteria graph.Criteria) graph.NodeQuery { - s.queryBuilder.Apply(query.Where(criteria)) - return s -} - -func (s *nodeQuery) Filterf(criteriaDelegate graph.CriteriaProvider) graph.NodeQuery { - return s.Filter(criteriaDelegate()) -} - -func (s *nodeQuery) Delete() error { - return s.exec(query.Delete( - query.Node(), - )) -} - -func (s *nodeQuery) Update(properties *graph.Properties) error { - return s.exec(query.Updatef(func() graph.Criteria { - var updateStatements []graph.Criteria - - if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 { - updateStatements = append(updateStatements, query.SetProperties(query.Node(), modifiedProperties)) - } - - if deletedProperties := properties.DeletedProperties(); len(deletedProperties) > 0 { - updateStatements = append(updateStatements, query.DeleteProperties(query.Node(), 
deletedProperties...)) - } - - return updateStatements - })) -} - -func (s *nodeQuery) OrderBy(criteria ...graph.Criteria) graph.NodeQuery { - s.queryBuilder.Apply(query.OrderBy(criteria...)) - return s -} - -func (s *nodeQuery) Offset(offset int) graph.NodeQuery { - s.queryBuilder.Apply(query.Offset(offset)) - return s -} - -func (s *nodeQuery) Limit(limit int) graph.NodeQuery { - s.queryBuilder.Apply(query.Limit(limit)) - return s -} - -func (s *nodeQuery) Count() (int64, error) { - var count int64 - - return count, s.Query(func(results graph.Result) error { - if !results.Next() { - return graph.ErrNoResultsFound - } - - return results.Scan(&count) - }, query.Returning( - query.Count(query.Node()), - )) -} - -func (s *nodeQuery) First() (*graph.Node, error) { - var node graph.Node - - return &node, s.Query( - func(results graph.Result) error { - if !results.Next() { - return graph.ErrNoResultsFound - } - - return results.Scan(&node) - }, - query.Returning( - query.Node(), - ), - query.Limit(1), - ) -} - -func (s *nodeQuery) Fetch(delegate func(cursor graph.Cursor[*graph.Node]) error, finalCriteria ...graph.Criteria) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (*graph.Node, error) { - var node graph.Node - return &node, result.Scan(&node) - }) - - defer cursor.Close() - return delegate(cursor) - }, append([]graph.Criteria{query.Returning( - query.Node(), - )}, finalCriteria...)...) 
-} - -func (s *nodeQuery) FetchIDs(delegate func(cursor graph.Cursor[graph.ID]) error) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.ID, error) { - var nodeID graph.ID - return nodeID, result.Scan(&nodeID) - }) - - defer cursor.Close() - return delegate(cursor) - }, query.Returning( - query.NodeID(), - )) -} - -func (s *nodeQuery) FetchKinds(delegate func(cursor graph.Cursor[graph.KindsResult]) error) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.KindsResult, error) { - var ( - nodeID graph.ID - nodeKinds graph.Kinds - err = result.Scan(&nodeID, &nodeKinds) - ) - - return graph.KindsResult{ - ID: nodeID, - Kinds: nodeKinds, - }, err - }) - - defer cursor.Close() - return delegate(cursor) - }, query.Returning( - query.NodeID(), - query.KindsOf(query.Node()), - )) -} diff --git a/drivers/pg/query.go b/drivers/pg/query.go deleted file mode 100644 index afa5559..0000000 --- a/drivers/pg/query.go +++ /dev/null @@ -1,63 +0,0 @@ -package pg - -import ( - "context" - - "github.com/specterops/dawgs/cypher/models/pgsql/translate" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" -) - -type liveQuery struct { - ctx context.Context - tx graph.Transaction - kindMapper KindMapper - queryBuilder *query.Builder -} - -func newLiveQuery(ctx context.Context, tx graph.Transaction, kindMapper KindMapper) liveQuery { - return liveQuery{ - ctx: ctx, - tx: tx, - kindMapper: kindMapper, - queryBuilder: query.NewBuilder(nil), - } -} - -func (s *liveQuery) runRegularQuery(allShortestPaths bool) graph.Result { - if regularQuery, err := s.queryBuilder.Build(allShortestPaths); err != nil { - return graph.NewErrorResult(err) - } else if translation, err := translate.FromCypher(s.ctx, regularQuery, s.kindMapper, false); err != nil { - return graph.NewErrorResult(err) - } else { - 
return s.tx.Raw(translation.Statement, translation.Parameters) - } -} - -func (s *liveQuery) Query(delegate func(results graph.Result) error, finalCriteria ...graph.Criteria) error { - s.queryBuilder.Apply(finalCriteria...) - - if result := s.runRegularQuery(false); result.Error() != nil { - return result.Error() - } else { - defer result.Close() - return delegate(result) - } -} - -func (s *liveQuery) QueryAllShortestPaths(delegate func(results graph.Result) error, finalCriteria ...graph.Criteria) error { - s.queryBuilder.Apply(finalCriteria...) - - if result := s.runRegularQuery(true); result.Error() != nil { - return result.Error() - } else { - defer result.Close() - return delegate(result) - } -} - -func (s *liveQuery) exec(finalCriteria ...graph.Criteria) error { - return s.Query(func(results graph.Result) error { - return results.Error() - }, finalCriteria...) -} diff --git a/drivers/pg/query/query.go b/drivers/pg/query/query.go deleted file mode 100644 index 85d3ac5..0000000 --- a/drivers/pg/query/query.go +++ /dev/null @@ -1,473 +0,0 @@ -package query - -import ( - _ "embed" - "fmt" - "strings" - - "github.com/jackc/pgx/v5" - "github.com/specterops/dawgs/drivers/pg/model" - "github.com/specterops/dawgs/graph" -) - -type Query struct { - tx graph.Transaction -} - -func On(tx graph.Transaction) Query { - return Query{ - tx: tx, - } -} - -func (s Query) exec(statement string, args map[string]any) error { - result := s.tx.Raw(statement, args) - defer result.Close() - - return result.Error() -} - -func (s Query) describeGraphPartition(name string) (model.GraphPartition, error) { - graphPartition := model.NewGraphPartition(name) - - if tableIndexDefinitions, err := s.SelectTableIndexDefinitions(name); err != nil { - return graphPartition, err - } else { - for _, tableIndexDefinition := range tableIndexDefinitions { - if captureGroups := pgPropertyIndexRegex.FindStringSubmatch(tableIndexDefinition); captureGroups == nil { - // If this index does not match our 
expected column index format then report it as a potential error - if !pgColumnIndexRegex.MatchString(tableIndexDefinition) { - return graphPartition, fmt.Errorf("regex mis-match on schema definition: %s", tableIndexDefinition) - } - } else { - indexName := captureGroups[pgIndexRegexGroupName] - - if captureGroups[pgIndexRegexGroupUnique] == pgIndexUniqueStr { - graphPartition.Constraints[indexName] = graph.Constraint{ - Name: indexName, - Field: captureGroups[pgIndexRegexGroupFields], - Type: parsePostgresIndexType(captureGroups[pgIndexRegexGroupIndexType]), - } - } else { - graphPartition.Indexes[indexName] = graph.Index{ - Name: indexName, - Field: captureGroups[pgIndexRegexGroupFields], - Type: parsePostgresIndexType(captureGroups[pgIndexRegexGroupIndexType]), - } - } - } - } - } - - return graphPartition, nil -} - -func (s Query) SelectKinds() (map[graph.Kind]int16, error) { - var ( - kindID int16 - kindName string - - kinds = map[graph.Kind]int16{} - result = s.tx.Raw(sqlSelectKinds, nil) - ) - - defer result.Close() - - for result.Next() { - if err := result.Scan(&kindID, &kindName); err != nil { - return nil, err - } - - kinds[graph.StringKind(kindName)] = kindID - } - - return kinds, result.Error() -} - -func (s Query) selectGraphPartitions(graphID int32) (model.GraphPartitions, error) { - var ( - nodePartitionName = model.NodePartitionTableName(graphID) - edgePartitionName = model.EdgePartitionTableName(graphID) - ) - - if nodePartition, err := s.describeGraphPartition(nodePartitionName); err != nil { - return model.GraphPartitions{}, err - } else if edgePartition, err := s.describeGraphPartition(edgePartitionName); err != nil { - return model.GraphPartitions{}, err - } else { - return model.GraphPartitions{ - Node: nodePartition, - Edge: edgePartition, - }, nil - } -} - -func (s Query) selectGraphPartialByName(name string) (model.Graph, error) { - var ( - graphID int32 - result = s.tx.Raw(sqlSelectGraphByName, map[string]any{ - "name": name, - }) - ) - - 
defer result.Close() - - if !result.Next() { - return model.Graph{}, pgx.ErrNoRows - } - - if err := result.Scan(&graphID); err != nil { - return model.Graph{}, err - } - - return model.Graph{ - ID: graphID, - Name: name, - }, result.Error() -} - -func (s Query) SelectGraphByName(name string) (model.Graph, error) { - if definition, err := s.selectGraphPartialByName(name); err != nil { - return model.Graph{}, err - } else if graphPartitions, err := s.selectGraphPartitions(definition.ID); err != nil { - return model.Graph{}, err - } else { - definition.Partitions = graphPartitions - return definition, nil - } -} - -func (s Query) selectGraphPartials() ([]model.Graph, error) { - var ( - graphID int32 - graphName string - graphs []model.Graph - - result = s.tx.Raw(sqlSelectGraphs, nil) - ) - - defer result.Close() - - for result.Next() { - if err := result.Scan(&graphID, &graphName); err != nil { - return nil, err - } else { - graphs = append(graphs, model.Graph{ - ID: graphID, - Name: graphName, - }) - } - } - - return graphs, result.Error() -} - -func (s Query) SelectGraphs() (map[string]model.Graph, error) { - if definitions, err := s.selectGraphPartials(); err != nil { - return nil, err - } else { - indexed := map[string]model.Graph{} - - for _, definition := range definitions { - if graphPartitions, err := s.selectGraphPartitions(definition.ID); err != nil { - return nil, err - } else { - definition.Partitions = graphPartitions - indexed[definition.Name] = definition - } - } - - return indexed, nil - } -} - -func (s Query) CreatePropertyIndex(indexName, tableName, fieldName string, indexType graph.IndexType) error { - return s.exec(formatCreatePropertyIndex(indexName, tableName, fieldName, indexType), nil) -} - -func (s Query) CreatePropertyConstraint(indexName, tableName, fieldName string, indexType graph.IndexType) error { - if indexType != graph.BTreeIndex { - return fmt.Errorf("only b-tree indexing is supported for property constraints") - } - - return 
s.exec(formatCreatePropertyConstraint(indexName, tableName, fieldName, indexType), nil) -} - -func (s Query) DropIndex(indexName string) error { - return s.exec(formatDropPropertyIndex(indexName), nil) -} - -func (s Query) DropConstraint(constraintName string) error { - return s.exec(formatDropPropertyConstraint(constraintName), nil) -} - -func (s Query) CreateSchema() error { - if err := s.exec(sqlSchemaUp, nil); err != nil { - return err - } - - return nil -} - -func (s Query) DropSchema() error { - if err := s.exec(sqlSchemaDown, nil); err != nil { - return err - } - - return nil -} - -func (s Query) insertGraph(name string) (model.Graph, error) { - var ( - graphID int32 - result = s.tx.Raw(sqlInsertGraph, map[string]any{ - "name": name, - }) - ) - - defer result.Close() - - if !result.Next() { - return model.Graph{}, result.Error() - } - - if err := result.Scan(&graphID); err != nil { - return model.Graph{}, fmt.Errorf("failed mapping ID from graph entry creation: %w", err) - } - - return model.Graph{ - ID: graphID, - Name: name, - }, nil -} - -func (s Query) CreatePartitionTable(name, parent string, graphID int32) (model.GraphPartition, error) { - if err := s.exec(formatCreatePartitionTable(name, parent, graphID), nil); err != nil { - return model.GraphPartition{}, err - } - - return model.GraphPartition{ - Name: name, - }, nil -} - -func (s Query) SelectTableIndexDefinitions(tableName string) ([]string, error) { - var ( - definition string - definitions []string - - result = s.tx.Raw(sqlSelectTableIndexes, map[string]any{ - "tablename": tableName, - }) - ) - - defer result.Close() - - for result.Next() { - if err := result.Scan(&definition); err != nil { - return nil, err - } - - definitions = append(definitions, strings.ToLower(definition)) - } - - return definitions, result.Error() -} - -func (s Query) SelectKindID(kind graph.Kind) (int16, error) { - var ( - kindID int16 - result = s.tx.Raw(sqlSelectKindID, map[string]any{ - "name": kind.String(), - }) - ) 
- - defer result.Close() - - if !result.Next() { - return -1, pgx.ErrNoRows - } - - if err := result.Scan(&kindID); err != nil { - return -1, err - } - - return kindID, result.Error() -} - -func (s Query) assertGraphPartitionIndexes(partitions model.GraphPartitions, indexChanges model.IndexChangeSet) error { - for _, indexToRemove := range append(indexChanges.NodeIndexesToRemove, indexChanges.EdgeIndexesToRemove...) { - if err := s.DropIndex(indexToRemove); err != nil { - return err - } - } - - for _, constraintToRemove := range append(indexChanges.NodeConstraintsToRemove, indexChanges.EdgeConstraintsToRemove...) { - if err := s.DropConstraint(constraintToRemove); err != nil { - return err - } - } - - for indexName, index := range indexChanges.NodeIndexesToAdd { - if err := s.CreatePropertyIndex(indexName, partitions.Node.Name, index.Field, index.Type); err != nil { - return err - } - } - - for constraintName, constraint := range indexChanges.NodeConstraintsToAdd { - if err := s.CreatePropertyConstraint(constraintName, partitions.Node.Name, constraint.Field, constraint.Type); err != nil { - return err - } - } - - for indexName, index := range indexChanges.EdgeIndexesToAdd { - if err := s.CreatePropertyIndex(indexName, partitions.Edge.Name, index.Field, index.Type); err != nil { - return err - } - } - - for constraintName, constraint := range indexChanges.EdgeConstraintsToAdd { - if err := s.CreatePropertyConstraint(constraintName, partitions.Edge.Name, constraint.Field, constraint.Type); err != nil { - return err - } - } - - return nil -} - -func (s Query) AssertGraph(schema graph.Graph, definition model.Graph) (model.Graph, error) { - var ( - requiredNodePartition = model.NewGraphPartitionFromSchema(definition.Partitions.Node.Name, schema.NodeIndexes, schema.NodeConstraints) - requiredEdgePartition = model.NewGraphPartitionFromSchema(definition.Partitions.Edge.Name, schema.EdgeIndexes, schema.EdgeConstraints) - indexChangeSet = model.NewIndexChangeSet() - ) - - if 
presentNodePartition, err := s.describeGraphPartition(definition.Partitions.Node.Name); err != nil { - return model.Graph{}, err - } else { - for presentNodeIndexName := range presentNodePartition.Indexes { - if _, hasMatchingDefinition := requiredNodePartition.Indexes[presentNodeIndexName]; !hasMatchingDefinition { - indexChangeSet.NodeIndexesToRemove = append(indexChangeSet.NodeIndexesToRemove, presentNodeIndexName) - } - } - - for presentNodeConstraintName := range presentNodePartition.Constraints { - if _, hasMatchingDefinition := requiredNodePartition.Constraints[presentNodeConstraintName]; !hasMatchingDefinition { - indexChangeSet.NodeConstraintsToRemove = append(indexChangeSet.NodeConstraintsToRemove, presentNodeConstraintName) - } - } - - for requiredNodeIndexName, requiredNodeIndex := range requiredNodePartition.Indexes { - if presentNodeIndex, hasMatchingDefinition := presentNodePartition.Indexes[requiredNodeIndexName]; !hasMatchingDefinition { - indexChangeSet.NodeIndexesToAdd[requiredNodeIndexName] = requiredNodeIndex - } else if requiredNodeIndex.Type != presentNodeIndex.Type { - indexChangeSet.NodeIndexesToRemove = append(indexChangeSet.NodeIndexesToRemove, requiredNodeIndexName) - indexChangeSet.NodeIndexesToAdd[requiredNodeIndexName] = requiredNodeIndex - } - } - - for requiredNodeConstraintName, requiredNodeConstraint := range requiredNodePartition.Constraints { - if presentNodeConstraint, hasMatchingDefinition := presentNodePartition.Constraints[requiredNodeConstraintName]; !hasMatchingDefinition { - indexChangeSet.NodeConstraintsToAdd[requiredNodeConstraintName] = requiredNodeConstraint - } else if requiredNodeConstraint.Type != presentNodeConstraint.Type { - indexChangeSet.NodeConstraintsToRemove = append(indexChangeSet.NodeConstraintsToRemove, requiredNodeConstraintName) - indexChangeSet.NodeConstraintsToAdd[requiredNodeConstraintName] = requiredNodeConstraint - } - } - } - - if presentEdgePartition, err := 
s.describeGraphPartition(definition.Partitions.Edge.Name); err != nil { - return model.Graph{}, err - } else { - for presentEdgeIndexName := range presentEdgePartition.Indexes { - if _, hasMatchingDefinition := requiredEdgePartition.Indexes[presentEdgeIndexName]; !hasMatchingDefinition { - indexChangeSet.EdgeIndexesToRemove = append(indexChangeSet.EdgeIndexesToRemove, presentEdgeIndexName) - } - } - - for presentEdgeConstraintName := range presentEdgePartition.Constraints { - if _, hasMatchingDefinition := requiredEdgePartition.Constraints[presentEdgeConstraintName]; !hasMatchingDefinition { - indexChangeSet.EdgeConstraintsToRemove = append(indexChangeSet.EdgeConstraintsToRemove, presentEdgeConstraintName) - } - } - - for requiredEdgeIndexName, requiredEdgeIndex := range requiredEdgePartition.Indexes { - if presentEdgeIndex, hasMatchingDefinition := presentEdgePartition.Indexes[requiredEdgeIndexName]; !hasMatchingDefinition { - indexChangeSet.EdgeIndexesToAdd[requiredEdgeIndexName] = requiredEdgeIndex - } else if requiredEdgeIndex.Type != presentEdgeIndex.Type { - indexChangeSet.EdgeIndexesToRemove = append(indexChangeSet.EdgeIndexesToRemove, requiredEdgeIndexName) - indexChangeSet.EdgeIndexesToAdd[requiredEdgeIndexName] = requiredEdgeIndex - } - } - - for requiredEdgeConstraintName, requiredEdgeConstraint := range requiredEdgePartition.Constraints { - if presentEdgeConstraint, hasMatchingDefinition := presentEdgePartition.Constraints[requiredEdgeConstraintName]; !hasMatchingDefinition { - indexChangeSet.EdgeConstraintsToAdd[requiredEdgeConstraintName] = requiredEdgeConstraint - } else if requiredEdgeConstraint.Type != presentEdgeConstraint.Type { - indexChangeSet.EdgeConstraintsToRemove = append(indexChangeSet.EdgeConstraintsToRemove, requiredEdgeConstraintName) - indexChangeSet.EdgeConstraintsToAdd[requiredEdgeConstraintName] = requiredEdgeConstraint - } - } - } - - return model.Graph{ - ID: definition.ID, - Name: definition.Name, - Partitions: 
model.GraphPartitions{ - Node: requiredNodePartition, - Edge: requiredEdgePartition, - }, - }, s.assertGraphPartitionIndexes(definition.Partitions, indexChangeSet) -} - -func (s Query) createGraphPartitions(definition model.Graph) (model.Graph, error) { - var ( - nodePartitionName = model.NodePartitionTableName(definition.ID) - edgePartitionName = model.EdgePartitionTableName(definition.ID) - ) - - if nodePartition, err := s.CreatePartitionTable(nodePartitionName, model.NodeTable, definition.ID); err != nil { - return model.Graph{}, err - } else { - definition.Partitions.Node = nodePartition - } - - if edgePartition, err := s.CreatePartitionTable(edgePartitionName, model.EdgeTable, definition.ID); err != nil { - return model.Graph{}, err - } else { - definition.Partitions.Edge = edgePartition - } - - return definition, nil -} - -func (s Query) CreateGraph(schema graph.Graph) (model.Graph, error) { - if definition, err := s.insertGraph(schema.Name); err != nil { - return model.Graph{}, err - } else if definition, err := s.createGraphPartitions(definition); err != nil { - return model.Graph{}, err - } else { - return s.AssertGraph(schema, definition) - } -} - -func (s Query) InsertOrGetKind(kind graph.Kind) (int16, error) { - var ( - kindID int16 - result = s.tx.Raw(sqlInsertKind, map[string]any{ - "name": kind.String(), - }) - ) - - defer result.Close() - - if !result.Next() { - return -1, pgx.ErrNoRows - } - - if err := result.Scan(&kindID); err != nil { - return -1, err - } - - return kindID, result.Error() -} diff --git a/drivers/pg/relationship.go b/drivers/pg/relationship.go deleted file mode 100644 index 72cd351..0000000 --- a/drivers/pg/relationship.go +++ /dev/null @@ -1,238 +0,0 @@ -package pg - -import ( - "fmt" - - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" -) - -func directionToReturnCriteria(direction graph.Direction) (graph.Criteria, error) { - switch direction { - case graph.DirectionInbound: - // Select the relationship 
and the end node - return query.Returning( - query.Relationship(), - query.End(), - ), nil - - case graph.DirectionOutbound: - // Select the relationship and the start node - return query.Returning( - query.Relationship(), - query.Start(), - ), nil - - default: - return nil, fmt.Errorf("bad direction: %d", direction) - } -} - -type relationshipQuery struct { - liveQuery -} - -func (s *relationshipQuery) Filter(criteria graph.Criteria) graph.RelationshipQuery { - s.queryBuilder.Apply(query.Where(criteria)) - return s -} - -func (s *relationshipQuery) Filterf(criteriaDelegate graph.CriteriaProvider) graph.RelationshipQuery { - return s.Filter(criteriaDelegate()) -} - -func (s *relationshipQuery) Delete() error { - return s.exec(query.Delete( - query.Relationship(), - )) -} - -func (s *relationshipQuery) Update(properties *graph.Properties) error { - return s.exec(query.Updatef(func() graph.Criteria { - var updateStatements []graph.Criteria - - if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 { - updateStatements = append(updateStatements, query.SetProperties(query.Node(), modifiedProperties)) - } - - if deletedProperties := properties.DeletedProperties(); len(deletedProperties) > 0 { - updateStatements = append(updateStatements, query.DeleteProperties(query.Node(), deletedProperties...)) - } - - return updateStatements - })) -} - -func (s *relationshipQuery) OrderBy(criteria ...graph.Criteria) graph.RelationshipQuery { - s.queryBuilder.Apply(query.OrderBy(criteria...)) - return s -} - -func (s *relationshipQuery) Offset(offset int) graph.RelationshipQuery { - s.queryBuilder.Apply(query.Offset(offset)) - return s -} - -func (s *relationshipQuery) Limit(limit int) graph.RelationshipQuery { - s.queryBuilder.Apply(query.Limit(limit)) - return s -} - -func (s *relationshipQuery) Count() (int64, error) { - var count int64 - - return count, s.Query(func(results graph.Result) error { - if !results.Next() { - return graph.ErrNoResultsFound 
- } - - return results.Scan(&count) - }, query.Returning( - query.Count(query.Relationship()), - )) -} - -func (s *relationshipQuery) FetchAllShortestPaths(delegate func(cursor graph.Cursor[graph.Path]) error) error { - return s.QueryAllShortestPaths(func(results graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, results, func(result graph.Result) (graph.Path, error) { - var path graph.Path - return path, result.Scan(&path) - }) - - defer cursor.Close() - return delegate(cursor) - }, query.Returning( - query.Path(), - )) -} - -func (s *relationshipQuery) FetchTriples(delegate func(cursor graph.Cursor[graph.RelationshipTripleResult]) error) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.RelationshipTripleResult, error) { - var ( - startID graph.ID - relationshipID graph.ID - endID graph.ID - err = result.Scan(&startID, &relationshipID, &endID) - ) - - return graph.RelationshipTripleResult{ - ID: relationshipID, - StartID: startID, - EndID: endID, - }, err - }) - - defer cursor.Close() - return delegate(cursor) - }, query.ReturningDistinct( - query.StartID(), - query.RelationshipID(), - query.EndID(), - )) -} - -func (s *relationshipQuery) FetchKinds(delegate func(cursor graph.Cursor[graph.RelationshipKindsResult]) error) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.RelationshipKindsResult, error) { - var ( - startID graph.ID - relationshipID graph.ID - relationshipKind graph.Kind - endID graph.ID - err = result.Scan(&startID, &relationshipID, &relationshipKind, &endID) - ) - - return graph.RelationshipKindsResult{ - RelationshipTripleResult: graph.RelationshipTripleResult{ - ID: relationshipID, - StartID: startID, - EndID: endID, - }, - Kind: relationshipKind, - }, err - }) - - defer cursor.Close() - return delegate(cursor) - }, query.Returning( - 
query.StartID(), - query.RelationshipID(), - query.KindsOf(query.Relationship()), - query.EndID(), - )) -} - -func (s *relationshipQuery) First() (*graph.Relationship, error) { - var relationship graph.Relationship - - return &relationship, s.Query( - func(results graph.Result) error { - if !results.Next() { - return graph.ErrNoResultsFound - } - - return results.Scan(&relationship) - }, - query.Returning( - query.Relationship(), - ), - query.Limit(1), - ) -} - -func (s *relationshipQuery) Fetch(delegate func(cursor graph.Cursor[*graph.Relationship]) error) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (*graph.Relationship, error) { - var relationship graph.Relationship - return &relationship, result.Scan(&relationship) - }) - - defer cursor.Close() - return delegate(cursor) - }, query.Returning( - query.Relationship(), - )) -} - -func (s *relationshipQuery) FetchDirection(direction graph.Direction, delegate func(cursor graph.Cursor[graph.DirectionalResult]) error) error { - if returnCriteria, err := directionToReturnCriteria(direction); err != nil { - return err - } else { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.DirectionalResult, error) { - var ( - relationship graph.Relationship - node graph.Node - ) - - if err := result.Scan(&relationship, &node); err != nil { - return graph.DirectionalResult{}, err - } - - return graph.DirectionalResult{ - Direction: direction, - Relationship: &relationship, - Node: &node, - }, nil - }) - - defer cursor.Close() - return delegate(cursor) - }, returnCriteria) - } -} - -func (s *relationshipQuery) FetchIDs(delegate func(cursor graph.Cursor[graph.ID]) error) error { - return s.Query(func(result graph.Result) error { - cursor := graph.NewResultIterator(s.ctx, result, func(result graph.Result) (graph.ID, error) { - var relationshipID graph.ID - return 
relationshipID, result.Scan(&relationshipID) - }) - - defer cursor.Close() - return delegate(cursor) - }, query.Returning( - query.RelationshipID(), - )) -} diff --git a/drivers/pg/result.go b/drivers/pg/result.go deleted file mode 100644 index 412bdaf..0000000 --- a/drivers/pg/result.go +++ /dev/null @@ -1,49 +0,0 @@ -package pg - -import ( - "context" - - "github.com/jackc/pgx/v5" - "github.com/specterops/dawgs/graph" -) - -type queryResult struct { - ctx context.Context - rows pgx.Rows - values []any - kindMapper KindMapper -} - -func (s *queryResult) Values() []any { - return s.values -} - -func (s *queryResult) Next() bool { - if s.rows.Next() { - // This error check exists just as a guard for a successful return of this function. The expectation is that - // the pgx type will have error information attached to it which is reflected by the Error receiver function - // of this type - if values, err := s.rows.Values(); err == nil { - s.values = values - return true - } - } - - return false -} - -func (s *queryResult) Mapper() graph.ValueMapper { - return NewValueMapper(s.ctx, s.kindMapper) -} - -func (s *queryResult) Scan(targets ...any) error { - return graph.ScanNextResult(s, targets...) 
-} - -func (s *queryResult) Error() error { - return s.rows.Err() -} - -func (s *queryResult) Close() { - s.rows.Close() -} diff --git a/drivers/pg/tooling.go b/drivers/pg/tooling.go deleted file mode 100644 index 6fe6af9..0000000 --- a/drivers/pg/tooling.go +++ /dev/null @@ -1,125 +0,0 @@ -package pg - -import ( - "regexp" - "sync" - - "github.com/specterops/dawgs/drivers" -) - -type IterationOptions interface { - Once() -} - -type QueryHookOptions interface { - Trace() IterationOptions -} - -type QueryHook interface { - OnStatementMatch(statement string) QueryHookOptions - OnStatementRegex(re *regexp.Regexp) QueryHookOptions -} - -type actionType int - -const ( - actionTrace actionType = iota -) - -type queryHook struct { - statementMatch *string - statementRegex *regexp.Regexp - action actionType - actionIterations int -} - -func (s *queryHook) Execute(query string, arguments ...any) { - switch s.action { - case actionTrace: - } -} - -func (s *queryHook) Catches(query string, arguments ...any) bool { - if s.statementMatch != nil { - if query == *s.statementMatch { - return true - } - } - - if s.statementRegex != nil { - if s.statementRegex.MatchString(query) { - return true - } - } - - return false -} - -func (s *queryHook) Once() { - s.actionIterations = 1 -} - -func (s *queryHook) Times(actionIterations int) { - s.actionIterations = actionIterations -} - -func (s *queryHook) Trace() IterationOptions { - s.action = actionTrace - return s -} - -func (s *queryHook) OnStatementMatch(statement string) QueryHookOptions { - s.statementMatch = &statement - return s -} - -func (s *queryHook) OnStatementRegex(re *regexp.Regexp) QueryHookOptions { - s.statementRegex = re - return s -} - -type QueryPathInspector interface { - Hook() QueryHook -} - -type queryPathInspector struct { - hooks []*queryHook - lock *sync.RWMutex -} - -func (s *queryPathInspector) Inspect(query string, arguments ...any) { - if !drivers.IsQueryAnalysisEnabled() { - return - } - - s.lock.RLock() - 
defer s.lock.RUnlock() - - for _, hook := range s.hooks { - if hook.Catches(query, arguments) { - hook.Execute(query, arguments) - } - } -} - -func (s *queryPathInspector) Hook() QueryHook { - s.lock.Lock() - defer s.lock.Unlock() - - hook := &queryHook{} - s.hooks = append(s.hooks, hook) - - return hook -} - -var inspectorInst = &queryPathInspector{ - lock: &sync.RWMutex{}, -} - -func inspector() *queryPathInspector { - return inspectorInst -} - -func Inspector() QueryPathInspector { - return inspectorInst -} diff --git a/drivers/pg/transaction.go b/drivers/pg/transaction.go deleted file mode 100644 index aebd6e5..0000000 --- a/drivers/pg/transaction.go +++ /dev/null @@ -1,299 +0,0 @@ -package pg - -import ( - "context" - "fmt" - - "github.com/specterops/dawgs/cypher/models/pgsql" - "github.com/specterops/dawgs/cypher/models/pgsql/translate" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/specterops/dawgs/cypher/frontend" - "github.com/specterops/dawgs/drivers/pg/model" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" - "github.com/specterops/dawgs/util/size" -) - -type driver interface { - Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) - Query(ctx context.Context, sql string, arguments ...any) (pgx.Rows, error) - QueryRow(ctx context.Context, sql string, arguments ...any) pgx.Row -} - -type inspectingDriver struct { - upstreamDriver driver -} - -func (s inspectingDriver) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { - inspector().Inspect(sql, arguments) - return s.upstreamDriver.Exec(ctx, sql, arguments...) -} - -func (s inspectingDriver) Query(ctx context.Context, sql string, arguments ...any) (pgx.Rows, error) { - inspector().Inspect(sql, arguments) - return s.upstreamDriver.Query(ctx, sql, arguments...) 
-} - -func (s inspectingDriver) QueryRow(ctx context.Context, sql string, arguments ...any) pgx.Row { - inspector().Inspect(sql, arguments) - return s.upstreamDriver.QueryRow(ctx, sql, arguments...) -} - -type transaction struct { - schemaManager *SchemaManager - queryExecMode pgx.QueryExecMode - queryResultsFormat pgx.QueryResultFormats - ctx context.Context - conn *pgxpool.Conn - tx pgx.Tx - targetSchema graph.Graph - targetSchemaSet bool -} - -func newTransactionWrapper(ctx context.Context, conn *pgxpool.Conn, schemaManager *SchemaManager, cfg *Config, allocateTransaction bool) (*transaction, error) { - wrapper := &transaction{ - schemaManager: schemaManager, - queryExecMode: cfg.QueryExecMode, - queryResultsFormat: cfg.QueryResultFormats, - ctx: ctx, - conn: conn, - targetSchemaSet: false, - } - - if allocateTransaction { - if pgxTx, err := conn.BeginTx(ctx, cfg.Options); err != nil { - return nil, err - } else { - wrapper.tx = pgxTx - } - } - - return wrapper, nil -} - -func (s *transaction) driver() driver { - if s.tx != nil { - return inspectingDriver{ - upstreamDriver: s.tx, - } - } - - return inspectingDriver{ - upstreamDriver: s.conn, - } -} - -func (s *transaction) GraphQueryMemoryLimit() size.Size { - return size.Gibibyte -} - -func (s *transaction) WithGraph(schema graph.Graph) graph.Transaction { - s.targetSchema = schema - s.targetSchemaSet = true - - return s -} - -func (s *transaction) Close() { - if s.tx != nil { - s.tx.Rollback(s.ctx) - s.tx = nil - } -} - -func (s *transaction) getTargetGraph() (model.Graph, error) { - if !s.targetSchemaSet { - // Look for a default graph target - if defaultGraph, hasDefaultGraph := s.schemaManager.DefaultGraph(); !hasDefaultGraph { - return model.Graph{}, fmt.Errorf("driver operation requires a graph target to be set") - } else { - return defaultGraph, nil - } - } - - return s.schemaManager.AssertGraph(s, s.targetSchema) -} - -func (s *transaction) CreateNode(properties *graph.Properties, kinds ...graph.Kind) 
(*graph.Node, error) { - if graphTarget, err := s.getTargetGraph(); err != nil { - return nil, err - } else if kindIDSlice, err := s.schemaManager.AssertKinds(s.ctx, kinds); err != nil { - return nil, err - } else if propertiesJSONB, err := pgsql.PropertiesToJSONB(properties); err != nil { - return nil, err - } else { - var ( - node graph.Node - result = s.Raw(createNodeStatement, map[string]any{ - "graph_id": graphTarget.ID, - "kind_ids": kindIDSlice, - "properties": propertiesJSONB, - }) - ) - - defer result.Close() - - if !result.Next() { - return nil, result.Error() - } - - return &node, result.Scan(&node) - } -} - -func (s *transaction) UpdateNode(node *graph.Node) error { - var ( - properties = node.Properties - updateStatements []graph.Criteria - ) - - if addedKinds := node.AddedKinds; len(addedKinds) > 0 { - updateStatements = append(updateStatements, query.AddKinds(query.Node(), addedKinds)) - } - - if deletedKinds := node.DeletedKinds; len(deletedKinds) > 0 { - updateStatements = append(updateStatements, query.DeleteKinds(query.Node(), deletedKinds)) - } - - if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 { - updateStatements = append(updateStatements, query.SetProperties(query.Node(), modifiedProperties)) - } - - if deletedProperties := properties.DeletedProperties(); len(deletedProperties) > 0 { - updateStatements = append(updateStatements, query.DeleteProperties(query.Node(), deletedProperties...)) - } - - return s.Nodes().Filter(query.Equals(query.NodeID(), node.ID)).Query(func(results graph.Result) error { - // We don't need to exhaust the result set as the defered close with discard it for us - return results.Error() - }, updateStatements...) 
-} - -func (s *transaction) Nodes() graph.NodeQuery { - return &nodeQuery{ - liveQuery: newLiveQuery(s.ctx, s, s.schemaManager), - } -} - -func (s *transaction) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) { - if graphTarget, err := s.getTargetGraph(); err != nil { - return nil, err - } else if kindIDSlice, err := s.schemaManager.AssertKinds(s.ctx, graph.Kinds{kind}); err != nil { - return nil, err - } else if propertiesJSONB, err := pgsql.PropertiesToJSONB(properties); err != nil { - return nil, err - } else { - var ( - edge graph.Relationship - result = s.Raw(createEdgeStatement, map[string]any{ - "graph_id": graphTarget.ID, - "start_id": startNodeID, - "end_id": endNodeID, - "kind_id": kindIDSlice[0], - "properties": propertiesJSONB, - }) - ) - - defer result.Close() - - if !result.Next() { - return nil, result.Error() - } - - return &edge, result.Scan(&edge) - } -} - -func (s *transaction) UpdateRelationship(relationship *graph.Relationship) error { - var ( - modifiedProperties = relationship.Properties.ModifiedProperties() - deletedProperties = relationship.Properties.DeletedProperties() - numModifiedProperties = len(modifiedProperties) - numDeletedProperties = len(deletedProperties) - - statement string - arguments []any - ) - - if numModifiedProperties > 0 { - if jsonbArgument, err := pgsql.ValueToJSONB(modifiedProperties); err != nil { - return err - } else { - arguments = append(arguments, jsonbArgument) - } - - if numDeletedProperties > 0 { - if textArrayArgument, err := pgsql.StringSliceToTextArray(deletedProperties); err != nil { - return err - } else { - arguments = append(arguments, textArrayArgument) - } - - statement = edgePropertySetAndDeleteStatement - } else { - statement = edgePropertySetOnlyStatement - } - } else if numDeletedProperties > 0 { - if textArrayArgument, err := pgsql.StringSliceToTextArray(deletedProperties); err != nil { - return err - } else { 
- arguments = append(arguments, textArrayArgument) - } - - statement = edgePropertyDeleteOnlyStatement - } - - _, err := s.driver().Exec(s.ctx, statement, append(arguments, relationship.ID)...) - return err -} - -func (s *transaction) Relationships() graph.RelationshipQuery { - return &relationshipQuery{ - liveQuery: newLiveQuery(s.ctx, s, s.schemaManager), - } -} - -func (s *transaction) query(query string, parameters map[string]any) (pgx.Rows, error) { - queryArgs := []any{s.queryExecMode, s.queryResultsFormat} - - if len(parameters) > 0 { - queryArgs = append(queryArgs, pgx.NamedArgs(parameters)) - } - - return s.driver().Query(s.ctx, query, queryArgs...) -} - -func (s *transaction) Query(query string, parameters map[string]any) graph.Result { - if parsedQuery, err := frontend.ParseCypher(frontend.NewContext(), query); err != nil { - return graph.NewErrorResult(err) - } else if translated, err := translate.Translate(s.ctx, parsedQuery, s.schemaManager, parameters); err != nil { - return graph.NewErrorResult(err) - } else if sqlQuery, err := translate.Translated(translated); err != nil { - return graph.NewErrorResult(err) - } else { - return s.Raw(sqlQuery, translated.Parameters) - } -} - -func (s *transaction) Raw(query string, parameters map[string]any) graph.Result { - if rows, err := s.query(query, parameters); err != nil { - return graph.NewErrorResult(err) - } else { - return &queryResult{ - ctx: s.ctx, - rows: rows, - kindMapper: s.schemaManager, - } - } -} - -func (s *transaction) Commit() error { - if s.tx != nil { - return s.tx.Commit(s.ctx) - } - - return nil -} diff --git a/drivers/pg/util.go b/drivers/pg/util.go deleted file mode 100644 index 2595323..0000000 --- a/drivers/pg/util.go +++ /dev/null @@ -1,7 +0,0 @@ -package pg - -import "github.com/specterops/dawgs/graph" - -func IsPostgreSQLGraph(db graph.Database) bool { - return graph.IsDriver[*Driver](db) -} diff --git a/graph/error.go b/graph/error.go deleted file mode 100644 index 
e103ddd..0000000 --- a/graph/error.go +++ /dev/null @@ -1 +0,0 @@ -package graph diff --git a/graph/graph.go b/graph/graph.go index 3cb13bc..c536f73 100644 --- a/graph/graph.go +++ b/graph/graph.go @@ -1,7 +1,6 @@ package graph import ( - "context" "errors" "slices" "strconv" @@ -16,8 +15,6 @@ const ( DirectionInbound Direction = 0 DirectionOutbound Direction = 1 DirectionBoth Direction = 2 - End = DirectionInbound - Start = DirectionOutbound ) var ErrInvalidDirection = errors.New("must be called with either an inbound or outbound direction") @@ -260,147 +257,3 @@ func (s RelationshipUpdate) EndIdentityPropertiesMap() map[string]any { return identityPropertiesMap } - -type Batch interface { - // WithGraph scopes the transaction to a specific graph. If the driver for the transaction does not support - // multiple graphs the resulting transaction will target the default graph instead and this call becomes a no-op. - WithGraph(graphSchema Graph) Batch - - // CreateNode creates a new Node in the database and returns the creation as a NodeResult. - CreateNode(node *Node) error - - // DeleteNode deletes a node by the given ID. - DeleteNode(id ID) error - - // Nodes begins a batch query that can be used to update or delete nodes. - Nodes() NodeQuery - - // Relationships begins a batch query that can be used to update or delete relationships. - Relationships() RelationshipQuery - - // UpdateNodeBy is a stop-gap until the query interface can better support targeted batch create-update operations. - // Nodes identified by the NodeUpdate criteria will either be updated or in the case where the node does not yet - // exist, created. - UpdateNodeBy(update NodeUpdate) error - - // TODO: Existing batch logic expects this to perform an upsert on conficts with (start_id, end_id, kind). 
This is incorrect and should be refactored - CreateRelationship(relationship *Relationship) error - - // Deprecated: Use CreateRelationship Instead - // - // CreateRelationshipByIDs creates a new Relationship from the start Node to the end Node with the given Kind and - // Properties and returns the creation as a RelationshipResult. - CreateRelationshipByIDs(startNodeID, endNodeID ID, kind Kind, properties *Properties) error - - // DeleteRelationship deletes a relationship by the given ID. - DeleteRelationship(id ID) error - - // UpdateRelationshipBy is a stop-gap until the query interface can better support targeted batch create-update - // operations. Relationships identified by the RelationshipUpdate criteria will either be updated or in the case - // where the relationship does not yet exist, created. - UpdateRelationshipBy(update RelationshipUpdate) error - - // Commit calls to commit this batch transaction right away. - Commit() error -} - -// Transaction is an interface that contains all operations that may be executed against a DAWGS driver. DAWGS drivers are -// expected to support all Transaction operations in-transaction. -type Transaction interface { - // WithGraph scopes the transaction to a specific graph. If the driver for the transaction does not support - // multiple graphs the resulting transaction will target the default graph instead and this call becomes a no-op. - WithGraph(graphSchema Graph) Transaction - - // CreateNode creates a new Node in the database and returns the creation as a NodeResult. - CreateNode(properties *Properties, kinds ...Kind) (*Node, error) - - // UpdateNode updates a Node in the database with the given Node by ID. UpdateNode will not create missing Node - // entries in the database. Use CreateNode first to create a new Node. - UpdateNode(node *Node) error - - // Nodes creates a new NodeQuery and returns it. 
- Nodes() NodeQuery - - // CreateRelationshipByIDs creates a new Relationship from the start Node to the end Node with the given Kind and - // Properties and returns the creation as a RelationshipResult. - CreateRelationshipByIDs(startNodeID, endNodeID ID, kind Kind, properties *Properties) (*Relationship, error) - - // UpdateRelationship updates a Relationship in the database with the given Relationship by ID. UpdateRelationship - // will not create missing Relationship entries in the database. Use CreateRelationship first to create a new - // Relationship. - UpdateRelationship(relationship *Relationship) error - - // Relationships creates a new RelationshipQuery and returns it. - Relationships() RelationshipQuery - - // Raw allows a user to pass raw queries directly to the database without translation. - Raw(query string, parameters map[string]any) Result - - // Query allows a user to execute a given cypher query that will be translated to the target database. - Query(query string, parameters map[string]any) Result - - // Commit calls to commit this transaction right away. - Commit() error - - // GraphQueryMemoryLimit returns the graph query memory limit of - GraphQueryMemoryLimit() size.Size -} - -// TransactionDelegate represents a transactional database context actor. Errors returned from a TransactionDelegate -// result in the rollback of write enabled transactions. Successful execution of a TransactionDelegate (nil error -// return value) results in a transactional commit of work done within the TransactionDelegate. -type TransactionDelegate func(tx Transaction) error - -// BatchDelegate represents a transactional database context actor. -type BatchDelegate func(batch Batch) error - -// TransactionConfig is a generic configuration that may apply to all supported databases. 
-type TransactionConfig struct { - Timeout time.Duration - DriverConfig any -} - -// TransactionOption is a function that represents a configuration setting for the underlying database transaction. -type TransactionOption func(config *TransactionConfig) - -// Database is a high-level interface representing transactional entry-points into DAWGS driver implementations. -type Database interface { - // SetWriteFlushSize sets a new write flush interval on the current driver - SetWriteFlushSize(interval int) - - // SetBatchWriteSize sets a new batch write interval on the current driver - SetBatchWriteSize(interval int) - - // ReadTransaction opens up a new read transactional context in the database and then defers the context to the - // given logic function. - ReadTransaction(ctx context.Context, txDelegate TransactionDelegate, options ...TransactionOption) error - - // WriteTransaction opens up a new write transactional context in the database and then defers the context to the - // given logic function. - WriteTransaction(ctx context.Context, txDelegate TransactionDelegate, options ...TransactionOption) error - - // BatchOperation opens up a new write transactional context in the database and then defers the context to the - // given logic function. Batch operations are fundamentally different between databases supported by DAWGS, - // necessitating a different interface that lacks many of the convenience features of a regular read or write - // transaction. - BatchOperation(ctx context.Context, batchDelegate BatchDelegate) error - - // AssertSchema will apply the given schema to the underlying database. - AssertSchema(ctx context.Context, dbSchema Schema) error - - // SetDefaultGraph sets the default graph namespace for the connection. - SetDefaultGraph(ctx context.Context, graphSchema Graph) error - - // Run allows a user to pass statements directly to the database. 
Since results may rely on a transactional context - // only an error is returned from this function - Run(ctx context.Context, query string, parameters map[string]any) error - - // Close closes the database context and releases any pooled resources held by the instance. - Close(ctx context.Context) error - - // FetchKinds retrieves the complete list of kinds available to the database. - FetchKinds(ctx context.Context) (Kinds, error) - - // RefreshKinds refreshes the in memory kinds maps - RefreshKinds(ctx context.Context) error -} diff --git a/graph/properties.go b/graph/properties.go index 5b563d2..603f18f 100644 --- a/graph/properties.go +++ b/graph/properties.go @@ -235,16 +235,28 @@ type Properties struct { } func (s *Properties) Merge(other *Properties) { + if s.Map == nil { + s.Map = make(map[string]any, len(other.Map)) + } + for otherKey, otherValue := range other.Map { s.Map[otherKey] = otherValue } + if s.Modified == nil { + s.Modified = make(map[string]struct{}, len(other.Modified)) + } + for otherModifiedKey := range other.Modified { s.Modified[otherModifiedKey] = struct{}{} delete(s.Deleted, otherModifiedKey) } + if s.Deleted == nil { + s.Deleted = make(map[string]struct{}, len(other.Deleted)) + } + for otherDeletedKey := range other.Deleted { s.Deleted[otherDeletedKey] = struct{}{} @@ -501,14 +513,6 @@ func NewProperties() *Properties { return &Properties{} } -func NewPropertiesRed() *Properties { - return &Properties{ - Map: map[string]any{}, - Modified: make(map[string]struct{}), - Deleted: make(map[string]struct{}), - } -} - type PropertyMap map[String]any func symbolMapToStringMap(props map[String]any) map[string]any { diff --git a/graph/relationships.go b/graph/relationship.go similarity index 100% rename from graph/relationships.go rename to graph/relationship.go diff --git a/graph/relationships_test.go b/graph/relationship_test.go similarity index 100% rename from graph/relationships_test.go rename to graph/relationship_test.go diff --git 
a/graph/schema.go b/graph/schema.go deleted file mode 100644 index 0a58da8..0000000 --- a/graph/schema.go +++ /dev/null @@ -1,45 +0,0 @@ -package graph - -type IndexType int - -const ( - UnsupportedIndex IndexType = 0 - BTreeIndex IndexType = 1 - TextSearchIndex IndexType = 2 -) - -func (s IndexType) String() string { - switch s { - case BTreeIndex: - return "btree" - - case TextSearchIndex: - return "fts" - - default: - return "invalid" - } -} - -type Index struct { - Name string - Field string - Type IndexType -} - -type Constraint Index - -type Graph struct { - Name string - Nodes Kinds - Edges Kinds - NodeConstraints []Constraint - EdgeConstraints []Constraint - NodeIndexes []Index - EdgeIndexes []Index -} - -type Schema struct { - Graphs []Graph - DefaultGraph Graph -} diff --git a/graphcache/cache.go b/graphcache/cache.go index 544d271..ed150a9 100644 --- a/graphcache/cache.go +++ b/graphcache/cache.go @@ -4,7 +4,6 @@ import ( "sync" "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" "github.com/specterops/dawgs/util/size" ) @@ -121,108 +120,3 @@ func (s Cache) PutRelationshipSet(relationships graph.RelationshipSet) { s.relationships.Put(relationship.ID, relationship) } } - -func fetchNodesByIDQuery(tx graph.Transaction, ids []graph.ID) graph.NodeQuery { - return tx.Nodes().Filterf(func() graph.Criteria { - return query.InIDs(query.NodeID(), ids...) - }) -} - -func fetchRelationshipsByIDQuery(tx graph.Transaction, ids []graph.ID) graph.RelationshipQuery { - return tx.Relationships().Filterf(func() graph.Criteria { - return query.InIDs(query.RelationshipID(), ids...) 
- }) -} - -func FetchNodesByID(tx graph.Transaction, cache Cache, ids []graph.ID) ([]*graph.Node, error) { - var ( - cachedNodes, missingNodeIDs = cache.GetNodes(ids) - toBeCachedCount = 0 - ) - - if len(missingNodeIDs) > 0 { - if err := fetchNodesByIDQuery(tx, missingNodeIDs).Fetch(func(cursor graph.Cursor[*graph.Node]) error { - for next := range cursor.Chan() { - cachedNodes = append(cachedNodes, next) - toBeCachedCount++ - } - - return cursor.Error() - }); err != nil { - return nil, err - } - - if toBeCachedCount > 0 { - cache.PutNodes(cachedNodes[len(cachedNodes)-toBeCachedCount:]) - } - - if len(missingNodeIDs) != toBeCachedCount { - return nil, graph.ErrMissingResultExpectation - } - } - - return cachedNodes, nil -} - -func FetchRelationshipsByID(tx graph.Transaction, cache Cache, ids []graph.ID) (graph.RelationshipSet, error) { - cachedRelationships, missingRelationshipIDs := cache.GetRelationships(ids) - - if len(missingRelationshipIDs) > 0 { - if err := fetchRelationshipsByIDQuery(tx, ids).Fetch(func(cursor graph.Cursor[*graph.Relationship]) error { - for next := range cursor.Chan() { - cachedRelationships = append(cachedRelationships, next) - } - - return cursor.Error() - }); err != nil { - return nil, err - } - - cache.PutRelationships(cachedRelationships[len(cachedRelationships)-len(missingRelationshipIDs):]) - } - - return graph.NewRelationshipSet(cachedRelationships...), nil -} - -func ShallowFetchRelationships(cache Cache, graphQuery graph.RelationshipQuery) ([]*graph.Relationship, error) { - var relationships []*graph.Relationship - - if err := graphQuery.FetchKinds(func(cursor graph.Cursor[graph.RelationshipKindsResult]) error { - for next := range cursor.Chan() { - relationships = append(relationships, graph.NewRelationship(next.ID, next.StartID, next.EndID, nil, next.Kind)) - } - - return cursor.Error() - }); err != nil { - return nil, err - } - - cache.PutRelationships(relationships) - return relationships, nil -} - -func 
ShallowFetchNodesByID(tx graph.Transaction, cache Cache, ids []graph.ID) ([]*graph.Node, error) { - cachedNodes, missingNodeIDs := cache.GetNodes(ids) - - if len(missingNodeIDs) > 0 { - newNodes := make([]*graph.Node, 0, len(missingNodeIDs)) - - if err := fetchNodesByIDQuery(tx, missingNodeIDs).FetchKinds(func(cursor graph.Cursor[graph.KindsResult]) error { - for next := range cursor.Chan() { - newNodes = append(newNodes, graph.NewNode(next.ID, nil, next.Kinds...)) - } - - return cursor.Error() - }); err != nil { - return nil, err - } - - // Put the fetched nodes into cache - cache.PutNodes(newNodes) - - // Append them to the end of the nodes being returned - cachedNodes = append(cachedNodes, newNodes...) - } - - return cachedNodes, nil -} diff --git a/query/model.go b/query/model.go deleted file mode 100644 index a3b2233..0000000 --- a/query/model.go +++ /dev/null @@ -1,617 +0,0 @@ -package query - -import ( - "fmt" - "strings" - "time" - - cypherModel "github.com/specterops/dawgs/cypher/models/cypher" - "github.com/specterops/dawgs/graph" -) - -func convertCriteria[T any](criteria ...graph.Criteria) []T { - var ( - converted = make([]T, len(criteria)) - ) - - for idx, nextCriteria := range criteria { - converted[idx] = nextCriteria.(T) - } - - return converted -} - -func Update(clauses ...*cypherModel.UpdatingClause) []*cypherModel.UpdatingClause { - return clauses -} - -func Updatef(provider graph.CriteriaProvider) []*cypherModel.UpdatingClause { - switch typedCriteria := provider().(type) { - case []*cypherModel.UpdatingClause: - return typedCriteria - - case []graph.Criteria: - return convertCriteria[*cypherModel.UpdatingClause](typedCriteria...) 
- - case *cypherModel.UpdatingClause: - return []*cypherModel.UpdatingClause{typedCriteria} - - default: - return []*cypherModel.UpdatingClause{ - cypherModel.WithErrors(&cypherModel.UpdatingClause{}, fmt.Errorf("invalid type %T for update clause", typedCriteria)), - } - } -} - -func AddKind(reference graph.Criteria, kind graph.Kind) *cypherModel.UpdatingClause { - return cypherModel.NewUpdatingClause(&cypherModel.Set{ - Items: []*cypherModel.SetItem{{ - Left: reference, - Operator: cypherModel.OperatorLabelAssignment, - Right: graph.Kinds{kind}, - }}, - }) -} - -func AddKinds(reference graph.Criteria, kinds graph.Kinds) *cypherModel.UpdatingClause { - return cypherModel.NewUpdatingClause(&cypherModel.Set{ - Items: []*cypherModel.SetItem{{ - Left: reference, - Operator: cypherModel.OperatorLabelAssignment, - Right: kinds, - }}, - }) -} - -func DeleteKind(reference graph.Criteria, kind graph.Kind) *cypherModel.UpdatingClause { - return cypherModel.NewUpdatingClause(&cypherModel.Remove{ - Items: []*cypherModel.RemoveItem{{ - KindMatcher: &cypherModel.KindMatcher{ - Reference: reference, - Kinds: graph.Kinds{kind}, - }, - }}, - }) -} - -func DeleteKinds(reference graph.Criteria, kinds graph.Kinds) *cypherModel.UpdatingClause { - return cypherModel.NewUpdatingClause(&cypherModel.Remove{ - Items: []*cypherModel.RemoveItem{{ - KindMatcher: &cypherModel.KindMatcher{ - Reference: reference, - Kinds: kinds, - }, - }}, - }) -} - -func SetProperty(reference graph.Criteria, value any) *cypherModel.UpdatingClause { - return cypherModel.NewUpdatingClause(&cypherModel.Set{ - Items: []*cypherModel.SetItem{{ - Left: reference, - Operator: cypherModel.OperatorAssignment, - Right: Parameter(value), - }}, - }) -} - -func SetProperties(reference graph.Criteria, properties map[string]any) *cypherModel.UpdatingClause { - set := &cypherModel.Set{} - - for key, value := range properties { - set.Items = append(set.Items, &cypherModel.SetItem{ - Left: Property(reference, key), - Operator: 
cypherModel.OperatorAssignment, - Right: Parameter(value), - }) - } - - return cypherModel.NewUpdatingClause(set) -} - -func DeleteProperty(reference *cypherModel.PropertyLookup) *cypherModel.UpdatingClause { - return cypherModel.NewUpdatingClause(&cypherModel.Remove{ - Items: []*cypherModel.RemoveItem{{ - Property: reference, - }}, - }) -} - -func DeleteProperties(reference graph.Criteria, propertyNames ...string) *cypherModel.UpdatingClause { - removeClause := &cypherModel.Remove{} - - for _, propertyName := range propertyNames { - removeClause.Items = append(removeClause.Items, &cypherModel.RemoveItem{ - Property: Property(reference, propertyName), - }) - } - - return cypherModel.NewUpdatingClause(removeClause) -} - -func Kind(reference graph.Criteria, kinds ...graph.Kind) *cypherModel.KindMatcher { - return &cypherModel.KindMatcher{ - Reference: reference, - Kinds: kinds, - } -} - -func KindIn(reference graph.Criteria, kinds ...graph.Kind) *cypherModel.KindMatcher { - return cypherModel.NewKindMatcher(reference, kinds) -} - -func NodeProperty(name string) *cypherModel.PropertyLookup { - return cypherModel.NewPropertyLookup(NodeSymbol, name) -} - -func RelationshipProperty(name string) *cypherModel.PropertyLookup { - return cypherModel.NewPropertyLookup(EdgeSymbol, name) -} - -func StartProperty(name string) *cypherModel.PropertyLookup { - return cypherModel.NewPropertyLookup(EdgeStartSymbol, name) -} - -func EndProperty(name string) *cypherModel.PropertyLookup { - return cypherModel.NewPropertyLookup(EdgeEndSymbol, name) -} - -func Property(qualifier graph.Criteria, name string) *cypherModel.PropertyLookup { - return &cypherModel.PropertyLookup{ - Atom: qualifier.(*cypherModel.Variable), - Symbol: name, - } -} - -func Count(reference graph.Criteria) *cypherModel.FunctionInvocation { - return &cypherModel.FunctionInvocation{ - Name: "count", - Arguments: []cypherModel.Expression{reference}, - } -} - -func CountDistinct(reference graph.Criteria) 
*cypherModel.FunctionInvocation { - return &cypherModel.FunctionInvocation{ - Name: "count", - Distinct: true, - Arguments: []cypherModel.Expression{reference}, - } -} - -func And(criteria ...graph.Criteria) *cypherModel.Conjunction { - return cypherModel.NewConjunction(convertCriteria[cypherModel.Expression](criteria...)...) -} - -func Or(criteria ...graph.Criteria) *cypherModel.Parenthetical { - return &cypherModel.Parenthetical{ - Expression: cypherModel.NewDisjunction(convertCriteria[cypherModel.Expression](criteria...)...), - } -} - -func Xor(criteria ...graph.Criteria) *cypherModel.ExclusiveDisjunction { - return cypherModel.NewExclusiveDisjunction(convertCriteria[cypherModel.Expression](criteria...)...) -} - -func Parameter(value any) *cypherModel.Parameter { - if parameter, isParameter := value.(*cypherModel.Parameter); isParameter { - return parameter - } - - return &cypherModel.Parameter{ - Value: value, - } -} - -func Literal(value any) *cypherModel.Literal { - return &cypherModel.Literal{ - Value: value, - Null: value == nil, - } -} - -func KindsOf(ref graph.Criteria) *cypherModel.FunctionInvocation { - switch typedRef := ref.(type) { - case *cypherModel.Variable: - switch typedRef.Symbol { - case NodeSymbol, EdgeStartSymbol, EdgeEndSymbol: - return &cypherModel.FunctionInvocation{ - Name: "labels", - Arguments: []cypherModel.Expression{ref}, - } - - case EdgeSymbol: - return &cypherModel.FunctionInvocation{ - Name: "type", - Arguments: []cypherModel.Expression{ref}, - } - - default: - return cypherModel.WithErrors(&cypherModel.FunctionInvocation{}, fmt.Errorf("invalid variable reference for KindsOf: %s", typedRef.Symbol)) - } - - default: - return cypherModel.WithErrors(&cypherModel.FunctionInvocation{}, fmt.Errorf("invalid reference type for KindsOf: %T", ref)) - } -} - -func Limit(limit int) *cypherModel.Limit { - return &cypherModel.Limit{ - Value: Literal(limit), - } -} - -func Offset(offset int) *cypherModel.Skip { - return &cypherModel.Skip{ - 
Value: Literal(offset), - } -} - -func StringContains(reference graph.Criteria, value string) *cypherModel.Comparison { - return cypherModel.NewComparison(reference, cypherModel.OperatorContains, Parameter(value)) -} - -func StringStartsWith(reference graph.Criteria, value string) *cypherModel.Comparison { - return cypherModel.NewComparison(reference, cypherModel.OperatorStartsWith, Parameter(value)) -} - -func StringEndsWith(reference graph.Criteria, value string) *cypherModel.Comparison { - return cypherModel.NewComparison(reference, cypherModel.OperatorEndsWith, Parameter(value)) -} - -func CaseInsensitiveStringContains(reference graph.Criteria, value string) *cypherModel.Comparison { - return cypherModel.NewComparison( - cypherModel.NewSimpleFunctionInvocation("toLower", convertCriteria[cypherModel.Expression](reference)...), - cypherModel.OperatorContains, - Parameter(strings.ToLower(value)), - ) -} - -func CaseInsensitiveStringStartsWith(reference graph.Criteria, value string) *cypherModel.Comparison { - return cypherModel.NewComparison( - cypherModel.NewSimpleFunctionInvocation("toLower", convertCriteria[cypherModel.Expression](reference)...), - cypherModel.OperatorStartsWith, - Parameter(strings.ToLower(value)), - ) -} - -func CaseInsensitiveStringEndsWith(reference graph.Criteria, value string) *cypherModel.Comparison { - return cypherModel.NewComparison( - cypherModel.NewSimpleFunctionInvocation("toLower", convertCriteria[cypherModel.Expression](reference)...), - cypherModel.OperatorEndsWith, - Parameter(strings.ToLower(value)), - ) -} - -func Equals(reference graph.Criteria, value any) *cypherModel.Comparison { - return cypherModel.NewComparison(reference, cypherModel.OperatorEquals, Parameter(value)) -} - -func GreaterThan(reference graph.Criteria, value any) *cypherModel.Comparison { - return cypherModel.NewComparison(reference, cypherModel.OperatorGreaterThan, Parameter(value)) -} - -func After(reference graph.Criteria, value any) 
*cypherModel.Comparison { - return GreaterThan(reference, value) -} - -func GreaterThanOrEquals(reference graph.Criteria, value any) *cypherModel.Comparison { - return cypherModel.NewComparison(reference, cypherModel.OperatorGreaterThanOrEqualTo, Parameter(value)) -} - -func LessThan(reference graph.Criteria, value any) *cypherModel.Comparison { - return cypherModel.NewComparison(reference, cypherModel.OperatorLessThan, Parameter(value)) -} - -func LessThanGraphQuery(reference1, reference2 graph.Criteria) *cypherModel.Comparison { - return cypherModel.NewComparison(reference1, cypherModel.OperatorLessThan, reference2) -} - -func Before(reference graph.Criteria, value time.Time) *cypherModel.Comparison { - return LessThan(reference, value) -} - -func BeforeGraphQuery(reference1, reference2 graph.Criteria) *cypherModel.Comparison { - return LessThanGraphQuery(reference1, reference2) -} - -func LessThanOrEquals(reference graph.Criteria, value any) *cypherModel.Comparison { - return cypherModel.NewComparison(reference, cypherModel.OperatorLessThanOrEqualTo, Parameter(value)) -} - -func Exists(reference graph.Criteria) *cypherModel.Comparison { - return cypherModel.NewComparison( - reference, - cypherModel.OperatorIsNot, - cypherModel.NewLiteral(nil, true), - ) -} - -func HasRelationships(reference *cypherModel.Variable) *cypherModel.PatternPredicate { - patternPredicate := cypherModel.NewPatternPredicate() - - patternPredicate.AddElement(&cypherModel.NodePattern{ - Variable: cypherModel.NewVariableWithSymbol(reference.Symbol), - }) - - patternPredicate.AddElement(&cypherModel.RelationshipPattern{ - Direction: graph.DirectionBoth, - }) - - patternPredicate.AddElement(&cypherModel.NodePattern{}) - - return patternPredicate -} - -func In(reference graph.Criteria, value any) *cypherModel.Comparison { - return cypherModel.NewComparison(reference, cypherModel.OperatorIn, Parameter(value)) -} - -func InInverted(reference graph.Criteria, value any) *cypherModel.Comparison { - 
return cypherModel.NewComparison(Parameter(value), cypherModel.OperatorIn, reference) -} - -func InIDs[T *cypherModel.FunctionInvocation | *cypherModel.Variable](reference T, ids ...graph.ID) *cypherModel.Comparison { - switch any(reference).(type) { - case *cypherModel.FunctionInvocation: - return cypherModel.NewComparison(reference, cypherModel.OperatorIn, Parameter(ids)) - - default: - return cypherModel.NewComparison(Identity(any(reference).(*cypherModel.Variable)), cypherModel.OperatorIn, Parameter(ids)) - } -} - -func Where(expression graph.Criteria) *cypherModel.Where { - whereClause := cypherModel.NewWhere() - whereClause.AddSlice(convertCriteria[cypherModel.Expression](expression)) - - return whereClause -} - -func OrderBy(leaves ...graph.Criteria) *cypherModel.Order { - return &cypherModel.Order{ - Items: convertCriteria[*cypherModel.SortItem](leaves...), - } -} - -func Order(reference, direction graph.Criteria) *cypherModel.SortItem { - switch direction { - case cypherModel.SortDescending: - return &cypherModel.SortItem{ - Ascending: false, - Expression: reference, - } - - default: - return &cypherModel.SortItem{ - Ascending: true, - Expression: reference, - } - } -} - -func Ascending() cypherModel.SortOrder { - return cypherModel.SortAscending -} - -func Descending() cypherModel.SortOrder { - return cypherModel.SortDescending -} - -func Delete(leaves ...graph.Criteria) *cypherModel.UpdatingClause { - deleteClause := &cypherModel.Delete{ - Detach: true, - } - - for _, leaf := range leaves { - switch leaf.(*cypherModel.Variable).Symbol { - case EdgeSymbol, EdgeStartSymbol, EdgeEndSymbol: - deleteClause.Detach = false - } - - deleteClause.Expressions = append(deleteClause.Expressions, leaf) - } - - return cypherModel.NewUpdatingClause(deleteClause) -} - -func NodePattern(kinds graph.Kinds, properties *cypherModel.Parameter) *cypherModel.NodePattern { - return &cypherModel.NodePattern{ - Variable: cypherModel.NewVariableWithSymbol(NodeSymbol), - Kinds: 
kinds, - Properties: properties, - } -} - -func StartNodePattern(kinds graph.Kinds, properties *cypherModel.Parameter) *cypherModel.NodePattern { - return &cypherModel.NodePattern{ - Variable: cypherModel.NewVariableWithSymbol(EdgeStartSymbol), - Kinds: kinds, - Properties: properties, - } -} - -func EndNodePattern(kinds graph.Kinds, properties *cypherModel.Parameter) *cypherModel.NodePattern { - return &cypherModel.NodePattern{ - Variable: cypherModel.NewVariableWithSymbol(EdgeEndSymbol), - Kinds: kinds, - Properties: properties, - } -} - -func RelationshipPattern(kind graph.Kind, properties *cypherModel.Parameter, direction graph.Direction) *cypherModel.RelationshipPattern { - return &cypherModel.RelationshipPattern{ - Variable: cypherModel.NewVariableWithSymbol(EdgeSymbol), - Kinds: graph.Kinds{kind}, - Properties: properties, - Direction: direction, - } -} - -func Create(elements ...graph.Criteria) *cypherModel.UpdatingClause { - var ( - pattern = &cypherModel.PatternPart{} - createClause = &cypherModel.Create{ - // Note: Unique is Neo4j specific and will not be supported here. Use of constraints for - // uniqueness is expected instead. 
- Unique: false, - Pattern: []*cypherModel.PatternPart{pattern}, - } - ) - - for _, element := range elements { - switch typedElement := element.(type) { - case *cypherModel.Variable: - switch typedElement.Symbol { - case NodeSymbol, EdgeStartSymbol, EdgeEndSymbol: - pattern.AddPatternElements(&cypherModel.NodePattern{ - Variable: cypherModel.NewVariableWithSymbol(typedElement.Symbol), - }) - - default: - createClause.AddError(fmt.Errorf("invalid variable reference create: %s", typedElement.Symbol)) - } - - case *cypherModel.NodePattern: - pattern.AddPatternElements(typedElement) - - case *cypherModel.RelationshipPattern: - pattern.AddPatternElements(typedElement) - - default: - createClause.AddError(fmt.Errorf("invalid type for create: %T", element)) - } - } - - return cypherModel.NewUpdatingClause(createClause) -} - -func ReturningDistinct(elements ...graph.Criteria) *cypherModel.Return { - returnCriteria := Returning(elements...) - returnCriteria.Projection.Distinct = true - - return returnCriteria -} - -func Returning(elements ...graph.Criteria) *cypherModel.Return { - projection := &cypherModel.Projection{} - - for _, element := range elements { - switch typedElement := element.(type) { - case *cypherModel.Order: - projection.Order = typedElement - - case *cypherModel.Limit: - projection.Limit = typedElement - - case *cypherModel.Skip: - projection.Skip = typedElement - - default: - projection.Items = append(projection.Items, &cypherModel.ProjectionItem{ - Expression: element, - }) - } - } - - return &cypherModel.Return{ - Projection: projection, - } -} - -func Size(expression graph.Criteria) *cypherModel.FunctionInvocation { - return cypherModel.NewSimpleFunctionInvocation("size", expression) -} - -func Not(expression graph.Criteria) *cypherModel.Negation { - return &cypherModel.Negation{ - Expression: &cypherModel.Parenthetical{ - Expression: expression, - }, - } -} - -func IsNull(reference graph.Criteria) *cypherModel.Comparison { - return 
cypherModel.NewComparison(reference, cypherModel.OperatorIs, Literal(nil)) -} - -func IsNotNull(reference graph.Criteria) *cypherModel.Comparison { - return cypherModel.NewComparison(reference, cypherModel.OperatorIsNot, Literal(nil)) -} - -func GetFirstReadingClause(query *cypherModel.RegularQuery) *cypherModel.ReadingClause { - if query.SingleQuery != nil && query.SingleQuery.SinglePartQuery != nil { - readingClauses := query.SingleQuery.SinglePartQuery.ReadingClauses - - if len(readingClauses) > 0 { - return readingClauses[0] - } - } - - return nil -} - -func SinglePartQuery(expressions ...graph.Criteria) *cypherModel.RegularQuery { - var ( - singlePartQuery = &cypherModel.SinglePartQuery{} - query = &cypherModel.RegularQuery{ - SingleQuery: &cypherModel.SingleQuery{ - SinglePartQuery: singlePartQuery, - }, - } - ) - - for _, expression := range expressions { - switch typedExpression := expression.(type) { - case *cypherModel.Where: - if firstReadingClause := GetFirstReadingClause(query); firstReadingClause != nil { - firstReadingClause.Match.Where = typedExpression - } else { - singlePartQuery.AddReadingClause(&cypherModel.ReadingClause{ - Match: &cypherModel.Match{ - Where: typedExpression, - }, - Unwind: nil, - }) - } - - case *cypherModel.Return: - singlePartQuery.Return = typedExpression - - case *cypherModel.Limit: - if singlePartQuery.Return != nil { - singlePartQuery.Return.Projection.Limit = typedExpression - } - - case *cypherModel.Skip: - if singlePartQuery.Return != nil { - singlePartQuery.Return.Projection.Skip = typedExpression - } - - case *cypherModel.Order: - if singlePartQuery.Return != nil { - singlePartQuery.Return.Projection.Order = typedExpression - } - - case *cypherModel.UpdatingClause: - singlePartQuery.AddUpdatingClause(typedExpression) - - case []*cypherModel.UpdatingClause: - for _, updatingClause := range typedExpression { - singlePartQuery.AddUpdatingClause(updatingClause) - } - - default: - 
singlePartQuery.AddError(fmt.Errorf("invalid type for dawgs query: %T %+v", expression, expression)) - } - } - - return query -} - -func EmptySinglePartQuery() *cypherModel.RegularQuery { - return &cypherModel.RegularQuery{ - SingleQuery: &cypherModel.SingleQuery{ - SinglePartQuery: &cypherModel.SinglePartQuery{}, - }, - } -} diff --git a/query/neo4j/neo4j.go b/query/neo4j/neo4j.go deleted file mode 100644 index 0689f74..0000000 --- a/query/neo4j/neo4j.go +++ /dev/null @@ -1,339 +0,0 @@ -package neo4j - -import ( - "bytes" - "errors" - "fmt" - - "github.com/specterops/dawgs/cypher/models/walk" - - cypherBackend "github.com/specterops/dawgs/cypher/models/cypher/format" - - "github.com/specterops/dawgs/cypher/models/cypher" - "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" -) - -var ( - ErrAmbiguousQueryVariables = errors.New("query mixes node and relationship query variables") -) - -type QueryBuilder struct { - Parameters map[string]any - - query *cypher.RegularQuery - order *cypher.Order - prepared bool -} - -func NewQueryBuilder(singleQuery *cypher.RegularQuery) *QueryBuilder { - return &QueryBuilder{ - query: cypher.Copy(singleQuery), - } -} - -func NewEmptyQueryBuilder() *QueryBuilder { - return &QueryBuilder{ - query: &cypher.RegularQuery{ - SingleQuery: &cypher.SingleQuery{ - SinglePartQuery: &cypher.SinglePartQuery{}, - }, - }, - } -} - -func (s *QueryBuilder) rewriteParameters() error { - parameterRewriter := query.NewParameterRewriter() - - if err := walk.Cypher(s.query, parameterRewriter); err != nil { - return err - } - - s.Parameters = parameterRewriter.Parameters - return nil -} - -func (s *QueryBuilder) Apply(criteria graph.Criteria) { - switch typedCriteria := criteria.(type) { - case *cypher.Where: - if query.GetFirstReadingClause(s.query) == nil { - s.query.SingleQuery.SinglePartQuery.AddReadingClause(&cypher.ReadingClause{ - Match: cypher.NewMatch(false), - }) - } - - query.GetFirstReadingClause(s.query).Match.Where = 
cypher.Copy(typedCriteria) - - case *cypher.Return: - s.query.SingleQuery.SinglePartQuery.Return = cypher.Copy(typedCriteria) - - case *cypher.Limit: - if s.query.SingleQuery.SinglePartQuery.Return != nil { - s.query.SingleQuery.SinglePartQuery.Return.Projection.Limit = cypher.Copy(typedCriteria) - } - - case *cypher.Skip: - if s.query.SingleQuery.SinglePartQuery.Return != nil { - s.query.SingleQuery.SinglePartQuery.Return.Projection.Skip = cypher.Copy(typedCriteria) - } - - case *cypher.Order: - s.order = cypher.Copy(typedCriteria) - - case []*cypher.UpdatingClause: - for _, updatingClause := range typedCriteria { - s.Apply(updatingClause) - } - - case *cypher.UpdatingClause: - s.query.SingleQuery.SinglePartQuery.AddUpdatingClause(cypher.Copy(typedCriteria)) - - default: - panic(fmt.Sprintf("invalid type for dawgs query: %T %+v", criteria, criteria)) - } -} - -func (s *QueryBuilder) prepareMatch() error { - var ( - patternPart = &cypher.PatternPart{} - - singleNodeBound = false - creatingSingleNode = false - - startNodeBound = false - creatingStartNode = false - endNodeBound = false - creatingEndNode = false - relationshipBound = false - creatingRelationship = false - - isRelationshipQuery = false - - bindWalk = walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, errorHandler walk.VisitorHandler) { - switch typedNode := node.(type) { - case *cypher.Variable: - switch typedNode.Symbol { - case query.NodeSymbol: - singleNodeBound = true - - case query.EdgeStartSymbol: - startNodeBound = true - isRelationshipQuery = true - - case query.EdgeEndSymbol: - endNodeBound = true - isRelationshipQuery = true - - case query.EdgeSymbol: - relationshipBound = true - isRelationshipQuery = true - } - } - }) - ) - - // Zip through updating clauses first - for _, updatingClause := range s.query.SingleQuery.SinglePartQuery.UpdatingClauses { - typedUpdatingClause, typeOK := updatingClause.(*cypher.UpdatingClause) - - if !typeOK { - return fmt.Errorf("unexpected 
updating clause type %T", typedUpdatingClause) - } - - switch typedClause := typedUpdatingClause.Clause.(type) { - case *cypher.Create: - if err := walk.Cypher(typedClause, walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, errorHandler walk.VisitorHandler) { - switch typedElement := node.(type) { - case *cypher.NodePattern: - switch typedElement.Variable.Symbol { - case query.NodeSymbol: - creatingSingleNode = true - - case query.EdgeStartSymbol: - creatingStartNode = true - - case query.EdgeEndSymbol: - creatingEndNode = true - } - - case *cypher.RelationshipPattern: - switch typedElement.Variable.Symbol { - case query.EdgeSymbol: - creatingRelationship = true - } - } - })); err != nil { - return err - } - - case *cypher.Delete: - if err := walk.Cypher(typedClause, bindWalk); err != nil { - return err - } - } - } - - // Is there a where clause? - if firstReadingClause := query.GetFirstReadingClause(s.query); firstReadingClause != nil && firstReadingClause.Match.Where != nil { - if err := walk.Cypher(firstReadingClause.Match.Where, bindWalk); err != nil { - return err - } - } - - // Is there a return clause - if spqReturn := s.query.SingleQuery.SinglePartQuery.Return; spqReturn != nil && spqReturn.Projection != nil { - // Did we have an order specified? 
- if s.order != nil { - if spqReturn.Projection.Order != nil { - return fmt.Errorf("order specified twice") - } - - s.query.SingleQuery.SinglePartQuery.Return.Projection.Order = s.order - } - - if err := walk.Cypher(s.query.SingleQuery.SinglePartQuery.Return, bindWalk); err != nil { - return err - } - } - - // Validate we're not mixing references - if isRelationshipQuery && singleNodeBound { - return ErrAmbiguousQueryVariables - } - - if singleNodeBound && !creatingSingleNode { - patternPart.AddPatternElements(&cypher.NodePattern{ - Variable: cypher.NewVariableWithSymbol(query.NodeSymbol), - }) - } - - if startNodeBound { - patternPart.AddPatternElements(&cypher.NodePattern{ - Variable: cypher.NewVariableWithSymbol(query.EdgeStartSymbol), - }) - } - - if isRelationshipQuery { - if !startNodeBound && !creatingStartNode { - patternPart.AddPatternElements(&cypher.NodePattern{}) - } - - if !creatingRelationship { - if relationshipBound { - patternPart.AddPatternElements(&cypher.RelationshipPattern{ - Variable: cypher.NewVariableWithSymbol(query.EdgeSymbol), - Direction: graph.DirectionOutbound, - }) - } else { - patternPart.AddPatternElements(&cypher.RelationshipPattern{ - Direction: graph.DirectionOutbound, - }) - } - } - - if !endNodeBound && !creatingEndNode { - patternPart.AddPatternElements(&cypher.NodePattern{}) - } - } - - if endNodeBound { - patternPart.AddPatternElements(&cypher.NodePattern{ - Variable: cypher.NewVariableWithSymbol(query.EdgeEndSymbol), - }) - } - - if firstReadingClause := query.GetFirstReadingClause(s.query); firstReadingClause != nil { - firstReadingClause.Match.Pattern = []*cypher.PatternPart{patternPart} - } else if len(patternPart.PatternElements) > 0 { - s.query.SingleQuery.SinglePartQuery.AddReadingClause(&cypher.ReadingClause{ - Match: &cypher.Match{ - Pattern: []*cypher.PatternPart{ - patternPart, - }, - }, - }) - } - - return nil -} - -func (s *QueryBuilder) compilationErrors() error { - var modelErrors []error - - 
walk.Cypher(s.query, walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, errorHandler walk.VisitorHandler) { - if errorNode, typeOK := node.(cypher.Fallible); typeOK { - if len(errorNode.Errors()) > 0 { - modelErrors = append(modelErrors, errorNode.Errors()...) - } - } - })) - - return errors.Join(modelErrors...) -} - -func (s *QueryBuilder) Prepare() error { - if s.prepared { - return nil - } - - s.prepared = true - - if s.query.SingleQuery.SinglePartQuery == nil { - return fmt.Errorf("single part query is nil") - } - - if err := s.compilationErrors(); err != nil { - return err - } - - if err := s.prepareMatch(); err != nil { - return err - } - - if err := s.rewriteParameters(); err != nil { - return err - } - - return walk.Cypher(s.query, NewExpressionListRewriter()) -} - -func (s *QueryBuilder) PrepareAllShortestPaths() error { - if err := s.Prepare(); err != nil { - return err - } else { - firstReadingClause := query.GetFirstReadingClause(s.query) - - // Set all pattern parts to search for the shortest paths and bind them - if len(firstReadingClause.Match.Pattern) > 1 { - return fmt.Errorf("only expected one pattern") - } - - // Grab the first pattern part - patternPart := firstReadingClause.Match.Pattern[0] - - // Bind the path - patternPart.Variable = cypher.NewVariableWithSymbol(query.PathSymbol) - - // Set the pattern to search for all shortest paths - patternPart.AllShortestPathsPattern = true - - // Update all relationship PatternElements to expand fully (*..) 
- for _, patternElement := range patternPart.PatternElements { - if relationshipPattern, isRelationshipPattern := patternElement.AsRelationshipPattern(); isRelationshipPattern { - relationshipPattern.Range = &cypher.PatternRange{} - } - } - - return nil - } -} - -func (s *QueryBuilder) Render() (string, error) { - buffer := &bytes.Buffer{} - - if err := cypherBackend.NewCypherEmitter(false).Write(s.query, buffer); err != nil { - return "", err - } else { - return buffer.String(), nil - } -} diff --git a/query/neo4j/neo4j_test.go b/query/neo4j/neo4j_test.go deleted file mode 100644 index 2efc343..0000000 --- a/query/neo4j/neo4j_test.go +++ /dev/null @@ -1,1100 +0,0 @@ -package neo4j_test - -import ( - "fmt" - "testing" - "time" - - "github.com/specterops/dawgs/cypher/models/cypher" - "github.com/specterops/dawgs/query" - "github.com/specterops/dawgs/query/neo4j" - - "github.com/specterops/dawgs/graph" - "github.com/stretchr/testify/require" -) - -var ( - SystemTags = "system_tags" - - User = graph.StringKind("User") - Domain = graph.StringKind("Domain") - Computer = graph.StringKind("Computer") - Group = graph.StringKind("Group") - HasSession = graph.StringKind("HasSession") - GenericWrite = graph.StringKind("GenericWrite") -) - -type QueryOutputAssertion struct { - Query string - Parameters map[string]any -} - -func expectAnalysisError(rawQuery *cypher.RegularQuery) func(t *testing.T) { - return func(t *testing.T) { - require.NotNil(t, neo4j.NewQueryBuilder(rawQuery).Prepare()) - } -} - -func assertQueryShortestPathResult(rawQuery *cypher.RegularQuery, expectedOutput string, expectedParameters ...map[string]any) func(t *testing.T) { - return func(t *testing.T) { - builder := neo4j.NewQueryBuilder(rawQuery) - - // Validate that building the query didn't throw an error - require.Nil(t, builder.PrepareAllShortestPaths()) - - if len(expectedParameters) == 1 { - require.Equal(t, expectedParameters[0], builder.Parameters) - } - - output, err := builder.Render() - - 
require.Nil(t, err) - require.Equal(t, expectedOutput, output) - } -} - -func assertQueryResult(rawQuery *cypher.RegularQuery, expectedOutput string, expectedParameters ...map[string]any) func(t *testing.T) { - return func(t *testing.T) { - var ( - builder = neo4j.NewQueryBuilder(rawQuery) - prepareErr = builder.Prepare() - ) - - // Validate that building the query didn't throw an error - if prepareErr != nil { - require.Nilf(t, prepareErr, prepareErr.Error()) - } - - if len(expectedParameters) == 1 { - require.Equal(t, expectedParameters[0], builder.Parameters) - } - - output, err := builder.Render() - - require.Nil(t, err) - require.Equal(t, expectedOutput, output) - } -} - -func assertOneOfQueryResult(rawQuery *cypher.RegularQuery, expectations []QueryOutputAssertion) func(t *testing.T) { - return func(t *testing.T) { - builder := neo4j.NewQueryBuilder(rawQuery) - - // Validate that building the query didn't throw an error - require.Nil(t, builder.Prepare()) - - output, err := builder.Render() - require.Nil(t, err) - - var matchingExpectation *QueryOutputAssertion - - for _, expectation := range expectations { - if expectation.Query == output { - matchingExpectation = &expectation - break - } - } - - if matchingExpectation == nil { - msg := fmt.Sprintf("Rendered query did not match any given options.\nActual:\n\t%s\nExpected one of: ", output) - - for _, expectation := range expectations { - msg += "\n\t" + expectation.Query - } - - t.Fatal(msg) - } else if matchingExpectation.Parameters != nil { - require.Equal(t, matchingExpectation.Parameters, builder.Parameters) - } - } -} - -func TestQueryBuilder_RenderShortestPaths(t *testing.T) { - t.Run("Shortest Paths with Unbound Relationship", assertQueryShortestPathResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.StartProperty("objectid"), "12345"), - query.KindIn(query.Start(), graph.StringKind("A"), graph.StringKind("B")), - - query.Equals(query.EndProperty("objectid"), "56789"), - 
query.KindIn(query.End(), graph.StringKind("B")), - ), - ), - - query.Returning( - query.Path(), - ), - ), "match p = allShortestPaths((s)-[*]->(e)) where s.objectid = $p0 and (s:A or s:B) and e.objectid = $p1 and e:B return p", map[string]any{ - "p0": "12345", - "p1": "56789", - })) - - t.Run("Shortest Paths with Bound Relationship", assertQueryShortestPathResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.StartProperty("objectid"), "12345"), - query.KindIn(query.Start(), graph.StringKind("A"), graph.StringKind("B")), - query.KindIn(query.Relationship(), graph.StringKind("R1"), graph.StringKind("R2")), - query.Equals(query.EndProperty("objectid"), "56789"), - query.KindIn(query.End(), graph.StringKind("B")), - ), - ), - - query.Returning( - query.Path(), - ), - ), "match p = allShortestPaths((s)-[r:R1|R2*]->(e)) where s.objectid = $p0 and (s:A or s:B) and e.objectid = $p1 and e:B return p", map[string]any{ - "p0": "12345", - "p1": "56789", - })) -} - -func TestQueryBuilder_Render(t *testing.T) { - // Node Queries - t.Run("Node Count", assertQueryResult(query.SinglePartQuery( - query.Where( - query.In(query.NodeID(), []graph.ID{1, 2, 3, 4}), - ), - - query.Returning( - query.Count(query.Node()), - ), - - query.Limit(10), - query.Offset(20), - ), "match (n) where id(n) in $p0 return count(n) skip 20 limit 10", map[string]any{ - "p0": []graph.ID{1, 2, 3, 4}, - })) - - t.Run("Node Item", assertQueryResult(query.SinglePartQuery( - query.Where( - query.In(query.NodeProperty("prop"), []int{1, 2, 3, 4}), - ), - - query.Returning( - query.Count(query.Node()), - ), - ), "match (n) where n.prop in $p0 return count(n)")) - - // TODO: Revisit parameter reuse - // - //reusedLiteral := query3.Literal([]int{1, 2, 3, 4}) - // - //t.Run("Node Item with Reused Literal", assertQueryResult(query3.Query( - // query3.Where( - // query3.And( - // query3.In(query3.NodeProperty("prop"), reusedLiteral), - // query3.In(query3.NodeProperty("other_prop"), 
reusedLiteral), - // ), - // ), - // - // query3.Returning( - // query3.Count(query3.Node()), - // ), - //), "match (n) where n.prop in $p0 and n.other_prop in $p0 return count(n)")) - - t.Run("Distinct Item", assertQueryResult(query.SinglePartQuery( - query.Where( - query.In(query.NodeProperty("prop"), []int{1, 2, 3, 4}), - ), - - query.ReturningDistinct( - query.NodeProperty("prop"), - ), - ), "match (n) where n.prop in $p0 return distinct n.prop")) - - t.Run("Count Distinct Item", assertQueryResult(query.SinglePartQuery( - query.Where( - query.In(query.NodeProperty("prop"), []int{1, 2, 3, 4}), - ), - - query.Returning( - query.CountDistinct(query.NodeProperty("prop")), - ), - ), "match (n) where n.prop in $p0 return count(distinct n.prop)")) - - t.Run("Set Node Labels", assertQueryResult(query.SinglePartQuery( - query.Where( - query.In(query.NodeProperty("prop"), []int{1, 2, 3, 4}), - ), - - query.Update( - query.AddKind(query.Node(), Domain), - query.AddKind(query.Node(), User), - ), - - query.Returning( - query.Count(query.Node()), - ), - ), "match (n) where n.prop in $p0 set n:Domain set n:User return count(n)")) - - t.Run("Remove Node Labels", assertQueryResult(query.SinglePartQuery( - query.Where( - query.In(query.NodeProperty("prop"), []int{1, 2, 3, 4}), - ), - - query.Update( - query.DeleteKind(query.Node(), Domain), - query.DeleteKind(query.Node(), User), - ), - - query.Returning( - query.Count(query.Node()), - ), - ), "match (n) where n.prop in $p0 remove n:Domain remove n:User return count(n)")) - - t.Run("Multiple Node ID References", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.NodeProperty("name"), "name"), - query.In(query.NodeID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.Identity(query.Node()), - query.Property(query.Node(), "value"), - ), - - query.Limit(10), - query.Offset(20), - ), "match (n) where n.name = $p0 and id(n) in $p1 return id(n), n.value skip 20 limit 10")) - - // 
Create node - t.Run("Create Node", assertQueryResult(query.SinglePartQuery( - query.Create( - query.NodePattern( - graph.Kinds{Domain, Computer}, - query.Parameter(map[string]any{ - "prop1": 1234, - }), - ), - ), - - query.Returning( - query.Identity(query.Node()), - ), - ), - "create (n:Domain:Computer $p0) return id(n)", - map[string]any{ - "p0": map[string]any{ - "prop1": 1234, - }, - }, - )) - - // Set with node - - t.Run("DeleteProperty with Multiple Node ID References", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.NodeProperty("name"), "name"), - query.In(query.NodeID(), []int{1, 2, 3, 4}), - ), - ), - - query.Update( - query.DeleteProperty(query.NodeProperty("other")), - query.DeleteProperty(query.NodeProperty("other2")), - ), - - query.Returning( - query.Identity(query.Node()), - query.Property(query.Node(), "value"), - ), - - query.Limit(10), - query.Offset(20), - ), "match (n) where n.name = $p0 and id(n) in $p1 remove n.other remove n.other2 return id(n), n.value skip 20 limit 10")) - - properties := graph.NewProperties() - properties.Set("test_1", "value_1") - properties.Set("test_2", "value_2") - - t.Run("Set from Map", assertOneOfQueryResult(query.SinglePartQuery( - query.Where( - query.Equals(query.NodeProperty("objectid"), "12345"), - ), - - query.Update( - query.SetProperties(query.Node(), properties.ModifiedProperties()), - ), - ), []QueryOutputAssertion{ - { - Query: "match (n) where n.objectid = $p0 set n.test_1 = $p1, n.test_2 = $p2", - Parameters: map[string]any{ - "p0": "12345", - "p1": "value_1", - "p2": "value_2", - }, - }, - { - Query: "match (n) where n.objectid = $p0 set n.test_2 = $p1, n.test_1 = $p2", - Parameters: map[string]any{ - "p0": "12345", - "p1": "value_2", - "p2": "value_1", - }, - }, - })) - - properties.Delete("test_1") - properties.Delete("test_2") - - t.Run("DeleteProperty from Map", assertOneOfQueryResult(query.SinglePartQuery( - query.Where( - 
query.Equals(query.NodeProperty("objectid"), "12345"), - ), - - query.Update( - query.DeleteProperties(query.Node(), properties.DeletedProperties()...), - ), - ), []QueryOutputAssertion{ - { - Query: "match (n) where n.objectid = $p0 remove n.test_2, n.test_1", - Parameters: map[string]any{ - "p0": "12345", - }, - }, - { - Query: "match (n) where n.objectid = $p0 remove n.test_1, n.test_2", - Parameters: map[string]any{ - "p0": "12345", - }, - }, - })) - - t.Run("Set with Multiple Node ID References", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.NodeProperty("name"), "name"), - query.In(query.NodeID(), []int{1, 2, 3, 4}), - ), - ), - - query.Update( - query.SetProperty(query.NodeProperty("other"), "value"), - ), - - query.Returning( - query.Identity(query.Node()), - query.Property(query.Node(), "value"), - ), - - query.Limit(10), - query.Offset(20), - ), "match (n) where n.name = $p0 and id(n) in $p1 set n.other = $p2 return id(n), n.value skip 20 limit 10")) - - updatedNode := graph.NewNode(graph.ID(1), graph.NewProperties(), User, Domain, Computer) - updatedNode.Properties.Set("test_1", "value_1") - updatedNode.Properties.Delete("test_2") - - t.Run("Node Set and Remove Multiple Kinds and Properties", assertQueryResult(query.SinglePartQuery( - query.Where( - query.Equals(query.NodeID(), updatedNode.ID), - ), - - query.Updatef(func() graph.Criteria { - var ( - properties = updatedNode.Properties - updateStatements = []graph.Criteria{ - query.AddKinds(query.Node(), updatedNode.Kinds), - } - ) - - if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 { - updateStatements = append(updateStatements, query.SetProperties(query.Node(), modifiedProperties)) - } - - if deletedProperties := properties.DeletedProperties(); len(deletedProperties) > 0 { - updateStatements = append(updateStatements, query.DeleteProperties(query.Node(), deletedProperties...)) - } - - return updateStatements - }), - ), 
"match (n) where id(n) = $p0 set n:User:Domain:Computer set n.test_1 = $p1 remove n.test_2")) - - t.Run("Node has Relationships", assertQueryResult(query.SinglePartQuery( - query.Where( - query.HasRelationships(query.Node()), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where (n)<-[]->() return n")) - - t.Run("Node has Relationships Order by Node Item", assertQueryResult(query.SinglePartQuery( - query.Where( - query.HasRelationships(query.Node()), - ), - - query.Returning( - query.Node(), - ), - - query.OrderBy( - query.Order(query.NodeProperty("value"), query.Ascending()), - ), - ), "match (n) where (n)<-[]->() return n order by n.value asc")) - - t.Run("Node has Relationships Order by Node Item", assertQueryResult(query.SinglePartQuery( - query.Where( - query.HasRelationships(query.Node()), - ), - - query.Returning( - query.Node(), - ), - - query.OrderBy( - query.Order(query.NodeProperty("value_1"), query.Ascending()), - query.Order(query.NodeProperty("value_2"), query.Descending()), - ), - ), "match (n) where (n)<-[]->() return n order by n.value_1 asc, n.value_2 desc")) - - t.Run("Node has Relationships Order by Node Item with Limit and Offset", assertQueryResult(query.SinglePartQuery( - query.Where( - query.HasRelationships(query.Node()), - ), - - query.Returning( - query.Node(), - ), - - query.OrderBy( - query.Order(query.NodeProperty("value_1"), query.Ascending()), - query.Order(query.NodeProperty("value_2"), query.Descending()), - ), - - query.Limit(10), - query.Offset(20), - ), "match (n) where (n)<-[]->() return n order by n.value_1 asc, n.value_2 desc skip 20 limit 10")) - - t.Run("Node has no Relationships", assertQueryResult(query.SinglePartQuery( - query.Where( - query.Not(query.HasRelationships(query.Node())), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where not ((n)<-[]->()) return n")) - - t.Run("Node Datetime Before", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - 
query.Before(query.NodeProperty("lastseen"), time.Now().UTC()), - query.In(query.NodeID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where n.lastseen < $p0 and id(n) in $p1 return n")) - - t.Run("Node Datetime Before or Equal to", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.LessThanOrEquals(query.NodeProperty("lastseen"), time.Now().UTC()), - query.In(query.NodeID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where n.lastseen <= $p0 and id(n) in $p1 return n")) - - t.Run("Node Datetime After", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.After(query.NodeProperty("lastseen"), time.Now().UTC()), - query.In(query.NodeID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where n.lastseen > $p0 and id(n) in $p1 return n")) - - t.Run("Node Datetime After or Equal to", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.GreaterThanOrEquals(query.NodeProperty("lastseen"), time.Now().UTC()), - query.In(query.NodeID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where n.lastseen >= $p0 and id(n) in $p1 return n")) - - t.Run("Node PropertyExists", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Exists(query.NodeProperty("lastseen")), - query.In(query.NodeID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where n.lastseen is not null and id(n) in $p0 return n")) - - t.Run("Select Node Kinds", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Kind(query.Node(), Domain), - ), - ), - - query.Returning( - query.KindsOf(query.Node()), - ), - ), "match (n) where n:Domain return labels(n)")) - - t.Run("Select Node ID and Kinds", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Kind(query.Node(), 
Domain), - ), - ), - - query.Returning( - query.NodeID(), - query.KindsOf(query.Node()), - ), - ), "match (n) where n:Domain return id(n), labels(n)")) - - t.Run("Node Kind Match", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Kind(query.Node(), Domain), - ), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where n:Domain return n")) - - t.Run("Node Kind In", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.KindIn(query.Node(), Domain, User, Group), - ), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where (n:Domain or n:User or n:Group) return n")) - - t.Run("Node String Item Contains", assertQueryResult(query.SinglePartQuery( - query.Where( - query.StringContains(query.NodeProperty("tags"), "tag_1"), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where n.tags contains $p0 return n")) - - t.Run("Node String Item Starts With", assertQueryResult(query.SinglePartQuery( - query.Where( - query.StringStartsWith(query.NodeProperty("tags"), "tag_1"), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where n.tags starts with $p0 return n")) - - t.Run("Node String Item Ends With", assertQueryResult(query.SinglePartQuery( - query.Where( - query.StringEndsWith(query.NodeProperty("tags"), "tag_1"), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where n.tags ends with $p0 return n")) - - t.Run("Node String Item Case Insensitive Contains", assertQueryResult(query.SinglePartQuery( - query.Where( - query.CaseInsensitiveStringContains(query.NodeProperty("tags"), "tag_1"), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where toLower(n.tags) contains $p0 return n")) - - t.Run("Node String Item Case Insensitive Starts With", assertQueryResult(query.SinglePartQuery( - query.Where( - query.CaseInsensitiveStringStartsWith(query.NodeProperty("tags"), "tag_1"), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where 
toLower(n.tags) starts with $p0 return n")) - - t.Run("Node String Item Case Insensitive Ends With", assertQueryResult(query.SinglePartQuery( - query.Where( - query.CaseInsensitiveStringEndsWith(query.NodeProperty("tags"), "tag_1"), - ), - - query.Returning( - query.Node(), - ), - ), "match (n) where toLower(n.tags) ends with $p0 return n")) - - t.Run("Node Delete", assertQueryResult(query.SinglePartQuery( - query.Where( - query.In(query.Node(), []graph.ID{1, 2, 3}), - ), - - query.Delete( - query.Node(), - ), - ), "match (n) where n in $p0 detach delete n")) - - // Relationship Queries - t.Run("Empty Relationship Query", assertQueryResult(query.SinglePartQuery( - query.Returning( - query.RelationshipID(), - ), - ), "match ()-[r]->() return id(r)")) - - t.Run("Empty Start Node Query", assertQueryResult(query.SinglePartQuery( - query.Returning( - query.StartID(), - ), - ), "match (s)-[]->() return id(s)")) - - t.Run("Empty End Node Query", assertQueryResult(query.SinglePartQuery( - query.Returning( - query.EndID(), - ), - ), "match ()-[]->(e) return id(e)")) - - t.Run("Returning Relationship Kind Query", assertQueryResult(query.SinglePartQuery( - query.Returning( - query.RelationshipID(), - query.KindsOf(query.Relationship()), - ), - ), "match ()-[r]->() return id(r), type(r)")) - - t.Run("Returning Start and Relationship Kind Query", assertQueryResult(query.SinglePartQuery( - query.Returning( - query.RelationshipID(), - query.KindsOf(query.Relationship()), - query.KindsOf(query.Start()), - ), - ), "match (s)-[r]->() return id(r), type(r), labels(s)")) - - t.Run("Returning End and Relationship Kind Query", assertQueryResult(query.SinglePartQuery( - query.Returning( - query.RelationshipID(), - query.KindsOf(query.Relationship()), - query.KindsOf(query.End()), - ), - ), "match ()-[r]->(e) return id(r), type(r), labels(e)")) - - t.Run("Returning Start, End and Relationship Kind Query", assertQueryResult(query.SinglePartQuery( - query.Returning( - 
query.RelationshipID(), - query.KindsOf(query.Relationship()), - query.KindsOf(query.Start()), - query.KindsOf(query.End()), - ), - ), "match (s)-[r]->(e) return id(r), type(r), labels(s), labels(e)")) - - t.Run("Relationship Item and ID References", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.RelationshipProperty("name"), "name"), - query.In(query.RelationshipID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.RelationshipID(), - query.Property(query.Relationship(), "value"), - ), - - query.Offset(20), - ), "match ()-[r]->() where r.name = $p0 and id(r) in $p1 return id(r), r.value skip 20")) - - t.Run("Relationship Select Start References", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.RelationshipProperty("name"), "name"), - query.In(query.RelationshipID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.StartID(), - query.Property(query.Relationship(), "value"), - ), - - query.Offset(20), - ), "match (s)-[r]->() where r.name = $p0 and id(r) in $p1 return id(s), r.value skip 20")) - - t.Run("Relationship Start Node ID Reference", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.StartID(), 1), - query.Equals(query.RelationshipProperty("name"), "name"), - query.In(query.RelationshipID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.RelationshipID(), - query.Property(query.Relationship(), "value"), - ), - - query.Offset(20), - ), "match (s)-[r]->() where id(s) = $p0 and r.name = $p1 and id(r) in $p2 return id(r), r.value skip 20")) - - t.Run("Relationship End Node ID Reference", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.EndID(), 1), - query.Equals(query.RelationshipProperty("name"), "name"), - query.In(query.RelationshipID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.RelationshipID(), - query.Property(query.Relationship(), "value"), - 
), - - query.Offset(20), - ), "match ()-[r]->(e) where id(e) = $p0 and r.name = $p1 and id(r) in $p2 return id(r), r.value skip 20")) - - t.Run("Relationship Start and End Node ID References", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.StartID(), 1), - query.Equals(query.EndID(), 1), - query.Equals(query.RelationshipProperty("name"), "name"), - query.In(query.RelationshipID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.RelationshipID(), - query.Property(query.Relationship(), "value"), - ), - ), "match (s)-[r]->(e) where id(s) = $p0 and id(e) = $p1 and r.name = $p2 and id(r) in $p3 return id(r), r.value")) - - t.Run("Relationship Kind Match without Joining Expression", assertQueryResult(query.SinglePartQuery( - query.Where( - query.KindIn(query.Relationship(), Domain, User, GenericWrite, HasSession), - ), - - query.Returning( - query.RelationshipID(), - query.Property(query.Relationship(), "value"), - ), - ), "match ()-[r:Domain|User|GenericWrite|HasSession]->() return id(r), r.value")) - - t.Run("Relationship Kind Match", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.StartID(), 1), - query.KindIn(query.Relationship(), HasSession), - query.Equals(query.RelationshipProperty("name"), "name"), - query.In(query.RelationshipID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.RelationshipID(), - query.Property(query.Relationship(), "value"), - ), - ), "match (s)-[r:HasSession]->() where id(s) = $p0 and r.name = $p1 and id(r) in $p2 return id(r), r.value")) - - updatedRelationship := graph.NewRelationship(graph.ID(1), graph.ID(2), graph.ID(3), graph.NewProperties(), HasSession) - updatedRelationship.Properties.Set("test_1", "value_1") - updatedRelationship.Properties.Delete("test_2") - - t.Run("Relationship Set and Remove Multiple Kinds and Properties", assertQueryResult(query.SinglePartQuery( - query.Where( - query.Equals(query.RelationshipID(), 
updatedRelationship.ID), - ), - - query.Updatef(func() graph.Criteria { - var ( - properties = updatedRelationship.Properties - updateStatements []graph.Criteria - ) - - if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 { - updateStatements = append(updateStatements, query.SetProperties(query.Relationship(), modifiedProperties)) - } - - if deletedProperties := properties.DeletedProperties(); len(deletedProperties) > 0 { - updateStatements = append(updateStatements, query.DeleteProperties(query.Relationship(), deletedProperties...)) - } - - return updateStatements - }), - ), "match ()-[r]->() where id(r) = $p0 set r.test_1 = $p1 remove r.test_2")) - - t.Run("Relationship Kind Match in", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.StartID(), 1), - query.KindIn(query.Relationship(), HasSession, GenericWrite), - query.Equals(query.RelationshipProperty("name"), "name"), - query.In(query.RelationshipID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.RelationshipID(), - query.Property(query.Relationship(), "value"), - ), - ), "match (s)-[r:HasSession|GenericWrite]->() where id(s) = $p0 and r.name = $p1 and id(r) in $p2 return id(r), r.value")) - - t.Run("Relationship Kind Match in and Start Node Kind Match in", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.KindIn(query.Start(), User, Computer), - query.KindIn(query.Relationship(), HasSession, GenericWrite), - query.Equals(query.RelationshipProperty("name"), "name"), - query.In(query.RelationshipID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.RelationshipID(), - query.Property(query.Relationship(), "value"), - ), - ), "match (s)-[r:HasSession|GenericWrite]->() where (s:User or s:Computer) and r.name = $p0 and id(r) in $p1 return id(r), r.value")) - - t.Run("Relationship Kind Match in and Delete Start Node and Relationship", assertQueryResult(query.SinglePartQuery( - query.Where( - 
query.And( - query.KindIn(query.Relationship(), HasSession, GenericWrite), - ), - ), - - query.Delete( - query.Start(), - query.Relationship(), - ), - ), "match (s)-[r:HasSession|GenericWrite]->() delete s, r")) - - t.Run("Relationship Kind Match in and Delete Start Node and Relationship Returning Count Relationships Deleted", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.KindIn(query.Relationship(), HasSession, GenericWrite), - ), - ), - - query.Delete( - query.Start(), - query.Relationship(), - ), - - query.Returning( - query.Count(query.Relationship()), - ), - ), "match (s)-[r:HasSession|GenericWrite]->() delete s, r return count(r)")) - - t.Run("Create Relationship", assertQueryResult(query.SinglePartQuery( - query.Create( - query.StartNodePattern( - graph.Kinds{Computer}, - query.Parameter(map[string]any{ - "prop1": 1234, - }), - ), - query.RelationshipPattern( - HasSession, - query.Parameter(map[string]any{ - "prop1": 1234, - }), - graph.DirectionOutbound, - ), - query.EndNodePattern( - graph.Kinds{User}, - query.Parameter(map[string]any{ - "prop1": 1234, - }), - ), - ), - - query.Returning( - query.Identity(query.Relationship()), - ), - ), - "create (s:Computer $p0)-[r:HasSession $p1]->(e:User $p2) return id(r)", - map[string]any{ - "p0": map[string]any{ - "prop1": 1234, - }, - "p1": map[string]any{ - "prop1": 1234, - }, - "p2": map[string]any{ - "prop1": 1234, - }, - }, - )) - - t.Run("Create Relationship with Match", assertQueryResult(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.StartID(), 1), - query.Equals(query.EndID(), 2), - ), - ), - - query.Create( - query.Start(), - query.RelationshipPattern( - HasSession, - query.Parameter(map[string]any{ - "prop1": 1234, - }), - graph.DirectionOutbound, - ), - query.End(), - ), - - query.Returning( - query.Identity(query.Relationship()), - ), - ), - "match (s), (e) where id(s) = $p0 and id(e) = $p1 create (s)-[r:HasSession $p2]->(e) return id(r)", - 
map[string]any{ - "p0": 1, - "p1": 2, - "p2": map[string]any{ - "prop1": 1234, - }, - }, - )) - - t.Run("Not String Contains Operator Rewrite", assertQueryResult(query.SinglePartQuery( - query.Where( - query.Not( - query.StringContains(query.Property(query.Node(), SystemTags), "admin_tier_0"), - ), - ), - - query.Returning( - query.Count(query.Node()), - ), - ), "match (n) where (not (n.system_tags contains $p0) or n.system_tags is null) return count(n)")) - - t.Run("Is Not Null", assertQueryResult(query.SinglePartQuery( - query.Where( - query.IsNotNull( - query.Property(query.Node(), SystemTags), - ), - ), - query.Returning( - query.Count(query.Node()), - )), - "match (n) where n.system_tags is not null return count(n)")) - - t.Run("Is Null", assertQueryResult(query.SinglePartQuery( - query.Where( - query.IsNull( - query.Property(query.Node(), SystemTags), - ), - ), - query.Returning( - query.Count(query.Node()), - )), - "match (n) where n.system_tags is null return count(n)")) -} - -func TestQueryBuilder_Analyze(t *testing.T) { - // Don't allow node query references to intermingle with relationship query references - t.Run("Should Reject Mixing Query Type References", expectAnalysisError(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.NodeID(), 1), - query.Equals(query.Property(query.Relationship(), "name"), "name"), - query.In(query.RelationshipID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.RelationshipID(), - query.Property(query.Relationship(), "value"), - ), - - query.Offset(20), - ))) - - t.Run("Should Reject Mixing Query Type References", expectAnalysisError(query.SinglePartQuery( - query.Where( - query.And( - query.Equals(query.NodeID(), 1), - query.Equals(query.Property(query.Relationship(), "name"), "name"), - query.In(query.RelationshipID(), []int{1, 2, 3, 4}), - ), - ), - - query.Returning( - query.RelationshipID(), - query.Property(query.Relationship(), "value"), - ), - - query.Offset(20), - ))) - - 
t.Run("Should fail on bad query criteria", expectAnalysisError(query.SinglePartQuery( - query.Node(), - ))) - - t.Run("Should fail on bad create criteria", expectAnalysisError(query.SinglePartQuery( - query.Create( - query.Where( - query.And(), - ), - ), - ))) - - t.Run("Should fail on bad variable reference types for KindOf", expectAnalysisError(query.SinglePartQuery( - query.Where( - query.KindsOf( - query.Create(), - ), - ), - ))) -} - -func Test_FormatCypherOrder(t *testing.T) { - var ( - sortItems = query.SortItems{ - {SortCriteria: query.NodeID(), Direction: query.SortDirectionAscending}, - {SortCriteria: query.Node(), Direction: query.SortDirectionDescending}, - {SortCriteria: query.Relationship(), Direction: query.SortDirectionAscending}, - } - ) - - require.Equal(t, true, sortItems.FormatCypherOrder().Items[0].Ascending) - require.Equal(t, false, sortItems.FormatCypherOrder().Items[1].Ascending) - require.Equal(t, true, sortItems.FormatCypherOrder().Items[2].Ascending) - - require.Equal(t, query.NodeID(), sortItems.FormatCypherOrder().Items[0].Expression) - require.Equal(t, query.Node(), sortItems.FormatCypherOrder().Items[1].Expression) - require.Equal(t, query.Relationship(), sortItems.FormatCypherOrder().Items[2].Expression) -} diff --git a/query/neo4j/rewrite.go b/query/neo4j/rewrite.go deleted file mode 100644 index 2e6f0ce..0000000 --- a/query/neo4j/rewrite.go +++ /dev/null @@ -1,153 +0,0 @@ -package neo4j - -import ( - "github.com/specterops/dawgs/cypher/models/cypher" - "github.com/specterops/dawgs/cypher/models/walk" - "github.com/specterops/dawgs/query" -) - -// ExpressionListRewriter contains rewriting logic related to folding Cypher syntax nodes along with additional -// guards for certain comparison checks. 
-type ExpressionListRewriter struct { - walk.Visitor[cypher.SyntaxNode] - - descentStack []cypher.SyntaxNode -} - -func NewExpressionListRewriter() walk.Visitor[cypher.SyntaxNode] { - return &ExpressionListRewriter{ - Visitor: walk.NewVisitor[cypher.SyntaxNode](), - } -} - -func (s *ExpressionListRewriter) pushExpression(expression cypher.SyntaxNode) { - s.descentStack = append(s.descentStack, expression) -} - -func (s *ExpressionListRewriter) peekExpression() (cypher.SyntaxNode, bool) { - if len(s.descentStack) == 0 { - return nil, false - } - - return s.descentStack[len(s.descentStack)-1], true -} - -func (s *ExpressionListRewriter) peekExpressionList() (cypher.ExpressionList, bool) { - if ancestorNode, hasPrevious := s.peekExpression(); hasPrevious { - ancestorExpressionList, isExpressionList := ancestorNode.(cypher.ExpressionList) - return ancestorExpressionList, isExpressionList - } - - return nil, false -} - -func (s *ExpressionListRewriter) popExpression() { - s.descentStack = s.descentStack[:len(s.descentStack)-1] -} - -func unwrapParenthetical(expression cypher.SyntaxNode) cypher.SyntaxNode { - cursor := expression - - for cursor != nil { - switch typedCursor := cursor.(type) { - case *cypher.Parenthetical: - cursor = typedCursor.Expression - continue - } - - break - } - - return cursor -} - -func (s *ExpressionListRewriter) rewriteStringNegation(negation *cypher.Negation) { - if ancestorExpressionList, isExpressionList := s.peekExpressionList(); isExpressionList { - switch typedNegatedExpression := unwrapParenthetical(negation.Expression).(type) { - case *cypher.Comparison: - firstPartial := typedNegatedExpression.FirstPartial() - - // If the negated expression is a comparison check to see if it's a string comparison. 
This is done since - // Neo4j comparison semantics for strings regarding `null` has edge cases that must be accounted for - switch firstPartial.Operator { - case cypher.OperatorStartsWith, cypher.OperatorEndsWith, cypher.OperatorContains: - // Rewrite this comparison is a disjunction of the negation and a follow-on comparison to handle null - // checks - ancestorExpressionList.Replace(ancestorExpressionList.IndexOf(negation), &cypher.Parenthetical{ - Expression: cypher.NewDisjunction( - negation, - cypher.NewComparison(typedNegatedExpression.Left, cypher.OperatorIs, query.Literal(nil)), - ), - }) - } - } - } -} - -func (s *ExpressionListRewriter) peekLastMatch() (*cypher.Match, bool) { - for idx := len(s.descentStack) - 1; idx >= 0; idx-- { - if lastMatch, typeOK := s.descentStack[idx].(*cypher.Match); typeOK { - return lastMatch, typeOK - } - } - - return nil, false -} - -func (s *ExpressionListRewriter) Enter(node cypher.SyntaxNode) { - // Push after visiting the node to avoid having ancestor references pointing to the currently visited node - s.pushExpression(node) -} - -func (s *ExpressionListRewriter) Exit(node cypher.SyntaxNode) { - attemptSelfRemoval := func() { - if ancestorNode, hasPrevious := s.peekExpression(); hasPrevious { - if ancestorExpressionList, isExpressionList := ancestorNode.(cypher.ExpressionList); isExpressionList { - ancestorExpressionList.Remove(node) - } - } - } - - s.popExpression() - - switch typedNode := node.(type) { - case cypher.ExpressionList: - if typedNode.Len() == 0 { - // Remove emtpy cypher expression lists - attemptSelfRemoval() - } - - case *cypher.KindMatcher: - if variable, typeOK := typedNode.Reference.(*cypher.Variable); !typeOK { - s.SetErrorf("expected a variable as the reference for a kind matcher but received: %T", node) - } else if variable.Symbol == query.EdgeSymbol { - // We need to remove this expression from the most recent expression list and tack it onto the - // relationship of the last match - if lastMatch, 
hasLastMatch := s.peekLastMatch(); !hasLastMatch { - s.SetErrorf("expected a match AST node") - } else if ancestorExpressionList, isExpressionList := s.peekExpressionList(); !isExpressionList { - s.SetErrorf("expected an expression list AST node") - } else { - firstRelationshipPattern := lastMatch.FirstRelationshipPattern() - firstRelationshipPattern.Kinds = append(firstRelationshipPattern.Kinds, typedNode.Kinds...) - - ancestorExpressionList.Remove(node) - } - } - - case *cypher.Negation: - s.rewriteStringNegation(typedNode) - - case *cypher.Parenthetical: - switch typedParentheticalElement := typedNode.Expression.(type) { - case cypher.ExpressionList: - if numExpressions := typedParentheticalElement.Len(); numExpressions == 0 { - attemptSelfRemoval() - } else if numExpressions == 1 { - // If the expression list has only one element, make it the only expression present in the - // parenthetical expression - typedNode.Expression = typedParentheticalElement.Get(0) - } - } - } -} diff --git a/query/query.go b/query/query.go new file mode 100644 index 0000000..7937972 --- /dev/null +++ b/query/query.go @@ -0,0 +1,782 @@ +package query + +import ( + "errors" + "fmt" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/graph" +) + +type runtimeIdentifiers struct { + path string + node string + start string + relationship string + end string +} + +func (s runtimeIdentifiers) Path() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.path) +} + +func (s runtimeIdentifiers) Node() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.node) +} + +func (s runtimeIdentifiers) Start() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.start) +} + +func (s runtimeIdentifiers) Relationship() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.relationship) +} + +func (s runtimeIdentifiers) End() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.end) +} + +var Identifiers = runtimeIdentifiers{ + path: "p", + 
node: "n", + start: "s", + relationship: "r", + end: "e", +} + +func joinedExpressionList(operator cypher.Operator, operands []cypher.SyntaxNode) cypher.SyntaxNode { + expressionList := &cypher.Comparison{} + + if len(operands) > 0 { + expressionList.Left = operands[0] + + for _, operand := range operands[1:] { + expressionList.NewPartialComparison(operator, operand) + } + } + + return expressionList +} + +func Not(operand cypher.Expression) cypher.Expression { + switch typedOperand := operand.(type) { + case *cypher.KindMatcher: + // If the type doesn't match, this code does not handle the error. This will be caught during query build time + // instead. + if identifier, typeOK := typedOperand.Reference.(*cypher.Variable); typeOK && identifier.Symbol == Identifiers.relationship { + if len(typedOperand.Kinds) == 1 { + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, identifier), + cypher.OperatorNotEquals, + cypher.NewStringLiteral(typedOperand.Kinds[0].String()), + ) + } else { + return cypher.NewNegation( + cypher.NewComparison( + cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, identifier), + cypher.OperatorIn, + cypher.NewStringListLiteral(typedOperand.Kinds.Strings()), + ), + ) + } + } + } + + return cypher.NewNegation(operand) +} + +func And(operands ...cypher.SyntaxNode) cypher.SyntaxNode { + return joinedExpressionList(cypher.OperatorAnd, operands) +} + +func Or(operands ...cypher.SyntaxNode) cypher.SyntaxNode { + return joinedExpressionList(cypher.OperatorOr, operands) +} + +func Node() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: Identifiers.Node(), + } +} + +func Path() PathContinuation { + return &entity[PathContinuation]{ + identifier: Identifiers.Path(), + } +} + +func Start() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: Identifiers.Start(), + } +} + +func Relationship() RelationshipContinuation { + return &entity[RelationshipContinuation]{ + 
identifier: Identifiers.Relationship(), + } +} + +func End() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: Identifiers.End(), + } +} + +type QualifiedExpression interface { + qualifier() cypher.Expression +} + +type EntityContinuation interface { + QualifiedExpression + + Count() cypher.Expression + ID() IdentityContinuation + Property(name string) PropertyContinuation +} + +type KindContinuation interface { + Is(kind graph.Kind) cypher.Expression + IsOneOf(kinds graph.Kinds) cypher.Expression +} + +type KindsContinuation interface { + Has(kind graph.Kind) cypher.Expression + HasOneOf(kinds graph.Kinds) cypher.Expression + Add(kinds graph.Kinds) cypher.Expression + Remove(kinds graph.Kinds) cypher.Expression +} + +type Comparable interface { + Contains(value any) cypher.Expression + Equals(value any) cypher.Expression + GreaterThan(value any) cypher.Expression + GreaterThanOrEqualTo(value any) cypher.Expression + LessThan(value any) cypher.Expression + LessThanOrEqualTo(value any) cypher.Expression +} + +type PropertyContinuation interface { + QualifiedExpression + Comparable + + Set(value any) *cypher.SetItem + Remove() *cypher.RemoveItem +} + +type IdentityContinuation interface { + QualifiedExpression + Comparable +} + +type comparisonContinuation struct { + qualifierExpression cypher.Expression +} + +func (s *comparisonContinuation) qualifier() cypher.Expression { + return s.qualifierExpression +} + +func (s *comparisonContinuation) asComparison(operator cypher.Operator, rOperand any) cypher.Expression { + return cypher.NewComparison( + s.qualifier(), + operator, + cypher.NewLiteral(rOperand, rOperand == nil), + ) +} + +func (s *comparisonContinuation) Contains(value any) cypher.Expression { + return s.asComparison(cypher.OperatorContains, value) +} + +func (s *comparisonContinuation) Equals(value any) cypher.Expression { + return s.asComparison(cypher.OperatorEquals, value) +} + +func (s *comparisonContinuation) GreaterThan(value any) 
cypher.Expression { + return s.asComparison(cypher.OperatorGreaterThan, value) +} + +func (s *comparisonContinuation) GreaterThanOrEqualTo(value any) cypher.Expression { + return s.asComparison(cypher.OperatorGreaterThanOrEqualTo, value) +} + +func (s *comparisonContinuation) LessThan(value any) cypher.Expression { + return s.asComparison(cypher.OperatorLessThan, value) +} + +func (s *comparisonContinuation) LessThanOrEqualTo(value any) cypher.Expression { + return s.asComparison(cypher.OperatorLessThanOrEqualTo, value) +} + +type propertyContinuation struct { + comparisonContinuation +} + +func (s *propertyContinuation) Set(value any) *cypher.SetItem { + return cypher.NewSetItem( + s.qualifier(), + cypher.OperatorAssignment, + cypher.NewLiteral(value, value == nil), + ) +} + +func (s *propertyContinuation) Remove() *cypher.RemoveItem { + return cypher.RemoveProperty(s.qualifier()) +} + +type entity[T any] struct { + identifier *cypher.Variable +} + +func (s *entity[T]) Kind() KindContinuation { + return kindContinuation{ + identifier: s.identifier, + } +} + +func (s *entity[T]) Kinds() KindsContinuation { + return kindsContinuation{ + identifier: s.identifier, + } +} + +func (s *entity[T]) Count() cypher.Expression { + return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, s.identifier) +} + +func (s *entity[T]) SetProperties(properties map[string]any) cypher.Expression { + set := &cypher.Set{} + + for key, value := range properties { + set.Items = append(set.Items, s.Property(key).Set(value)) + } + + return set +} + +func (s *entity[T]) RemoveProperties(properties []string) cypher.Expression { + remove := &cypher.Remove{} + + for _, key := range properties { + remove.Items = append(remove.Items, s.Property(key).Remove()) + } + + return remove +} + +func (s *entity[T]) RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) cypher.Expression { + return &cypher.RelationshipPattern{ + Variable: s.identifier, + Kinds: 
graph.Kinds{kind}, + Direction: direction, + Properties: properties, + } +} + +func (s *entity[T]) NodePattern(kinds graph.Kinds, properties cypher.Expression) cypher.Expression { + return &cypher.NodePattern{ + Variable: s.identifier, + Kinds: kinds, + Properties: properties, + } +} + +func (s *entity[T]) qualifier() cypher.Expression { + return s.identifier +} + +func (s *entity[T]) ID() IdentityContinuation { + return &comparisonContinuation{ + qualifierExpression: &cypher.FunctionInvocation{ + Distinct: false, + Name: cypher.IdentityFunction, + Arguments: []cypher.Expression{s.identifier}, + }, + } +} + +func (s *entity[T]) Property(propertyName string) PropertyContinuation { + return &propertyContinuation{ + comparisonContinuation: comparisonContinuation{ + qualifierExpression: cypher.NewPropertyLookup(s.identifier.Symbol, propertyName), + }, + } +} + +type kindContinuation struct { + identifier *cypher.Variable +} + +func (s kindContinuation) Is(kind graph.Kind) cypher.Expression { + return s.IsOneOf(graph.Kinds{kind}) +} + +func (s kindContinuation) IsOneOf(kinds graph.Kinds) cypher.Expression { + return &cypher.KindMatcher{ + Reference: s.identifier, + Kinds: kinds, + } +} + +type kindsContinuation struct { + identifier *cypher.Variable +} + +func (s kindsContinuation) Has(kind graph.Kind) cypher.Expression { + return s.HasOneOf(graph.Kinds{kind}) +} + +func (s kindsContinuation) HasOneOf(kinds graph.Kinds) cypher.Expression { + return &cypher.KindMatcher{ + Reference: s.identifier, + Kinds: kinds, + } +} + +func (s kindsContinuation) Add(kinds graph.Kinds) cypher.Expression { + return cypher.NewSetItem( + s.identifier, + cypher.OperatorLabelAssignment, + kinds, + ) +} + +func (s kindsContinuation) Remove(kinds graph.Kinds) cypher.Expression { + return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(s.identifier, kinds)) +} + +type PathContinuation interface { + QualifiedExpression + + Count() cypher.Expression +} + +type RelationshipContinuation 
interface { + EntityContinuation + + RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) cypher.Expression + + Kind() KindContinuation + SetProperties(properties map[string]any) cypher.Expression + RemoveProperties(properties []string) cypher.Expression +} + +type NodeContinuation interface { + EntityContinuation + + NodePattern(kinds graph.Kinds, properties cypher.Expression) cypher.Expression + + Kinds() KindsContinuation + SetProperties(properties map[string]any) cypher.Expression + RemoveProperties(properties []string) cypher.Expression +} + +type QueryBuilder interface { + Where(constraints ...cypher.SyntaxNode) QueryBuilder + OrderBy(sortItems ...cypher.SyntaxNode) QueryBuilder + Skip(offset int) QueryBuilder + Limit(limit int) QueryBuilder + Return(projections ...any) QueryBuilder + Update(updatingClauses ...any) QueryBuilder + Create(creationClauses ...any) QueryBuilder + Delete(expressions ...any) QueryBuilder + WithShortestPaths() QueryBuilder + WithAllShortestPaths() QueryBuilder + Build() (*PreparedQuery, error) +} + +type builder struct { + errors []error + constraints []cypher.SyntaxNode + sortItems []cypher.SyntaxNode + projections []any + creates []any + setItems []*cypher.SetItem + removeItems []*cypher.RemoveItem + deleteItems []cypher.Expression + detachDelete bool + shortestPathQuery bool + allShorestPathsQuery bool + skip *int + limit *int +} + +func New() QueryBuilder { + return &builder{} +} + +func (s *builder) WithShortestPaths() QueryBuilder { + s.shortestPathQuery = true + return s +} + +func (s *builder) WithAllShortestPaths() QueryBuilder { + s.allShorestPathsQuery = true + return s +} + +func (s *builder) OrderBy(sortItems ...cypher.SyntaxNode) QueryBuilder { + s.sortItems = append(s.sortItems, sortItems...) 
+ return s +} + +func (s *builder) Skip(skip int) QueryBuilder { + s.skip = &skip + return s +} + +func (s *builder) Limit(limit int) QueryBuilder { + s.limit = &limit + return s +} + +func (s *builder) Return(projections ...any) QueryBuilder { + s.projections = append(s.projections, projections...) + return s +} + +func (s *builder) Create(creationClauses ...any) QueryBuilder { + s.creates = append(s.creates, creationClauses...) + return s +} + +func (s *builder) Update(updates ...any) QueryBuilder { + for _, nextUpdate := range updates { + switch typedNextUpdate := nextUpdate.(type) { + case *cypher.Set: + s.setItems = append(s.setItems, typedNextUpdate.Items...) + + case *cypher.SetItem: + s.setItems = append(s.setItems, typedNextUpdate) + + case *cypher.Remove: + s.removeItems = append(s.removeItems, typedNextUpdate.Items...) + + case *cypher.RemoveItem: + s.removeItems = append(s.removeItems, typedNextUpdate) + + default: + s.trackError(fmt.Errorf("unknown update type: %T", nextUpdate)) + } + } + + return s +} + +func (s *builder) Delete(deleteItems ...any) QueryBuilder { + for _, nextDelete := range deleteItems { + switch typedNextUpdate := nextDelete.(type) { + case QualifiedExpression: + qualifier := typedNextUpdate.qualifier() + + switch qualifier { + case Identifiers.node, Identifiers.start, Identifiers.end: + s.detachDelete = true + } + + s.deleteItems = append(s.deleteItems, qualifier) + + case *cypher.Variable: + switch typedNextUpdate.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + s.detachDelete = true + } + + s.deleteItems = append(s.deleteItems, typedNextUpdate) + + default: + s.trackError(fmt.Errorf("unknown delete type: %T", nextDelete)) + } + } + + return s +} + +func (s *builder) trackError(err error) { + s.errors = append(s.errors, err) +} + +func (s *builder) Where(constraints ...cypher.SyntaxNode) QueryBuilder { + s.constraints = append(s.constraints, constraints...) 
+ return s +} + +func (s *builder) buildCreates(singlePartQuery *cypher.SinglePartQuery) error { + // Early exit to hide this part of the business logic while handling queries with no create statements + if len(s.creates) == 0 { + return nil + } + + var ( + pattern = &cypher.PatternPart{} + createClause = &cypher.Create{ + // Note: Unique is Neo4j specific and will not be supported here. Use of constraints for + // uniqueness is expected instead. + Unique: false, + Pattern: []*cypher.PatternPart{pattern}, + } + ) + + for _, nextCreate := range s.creates { + switch typedNextCreate := nextCreate.(type) { + case QualifiedExpression: + switch typedExpression := typedNextCreate.qualifier().(type) { + case *cypher.Variable: + switch typedExpression.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(typedExpression.Symbol), + }) + + default: + return fmt.Errorf("invalid variable reference for create: %s", typedExpression.Symbol) + } + } + + case *cypher.NodePattern: + pattern.AddPatternElements(typedNextCreate) + + case *cypher.RelationshipPattern: + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(Identifiers.start), + }) + + pattern.AddPatternElements(typedNextCreate) + + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(Identifiers.end), + }) + + default: + return fmt.Errorf("invalid type for create: %T", nextCreate) + } + } + + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause(createClause)) + return nil +} + +func (s *builder) buildUpdatingClauses(singlePartQuery *cypher.SinglePartQuery) error { + if len(s.setItems) > 0 { + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewSet(s.setItems), + )) + } + + if len(s.removeItems) > 0 { + singlePartQuery.UpdatingClauses = 
append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewRemove(s.removeItems), + )) + } + + if len(s.deleteItems) > 0 { + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewDelete( + s.detachDelete, + s.deleteItems, + ), + )) + } + + return s.buildCreates(singlePartQuery) +} + +func (s *builder) buildProjectionOrder() (*cypher.Order, error) { + var orderByNode *cypher.Order + + if len(s.sortItems) > 0 { + orderByNode = &cypher.Order{} + + for _, untypedSortItem := range s.sortItems { + switch typedSortItem := untypedSortItem.(type) { + case *cypher.Order: + for _, sortItem := range typedSortItem.Items { + orderByNode.Items = append(orderByNode.Items, sortItem) + } + + case *cypher.SortItem: + orderByNode.Items = append(orderByNode.Items, typedSortItem) + } + } + } + + return orderByNode, nil +} + +func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error { + var ( + hasProjectedItems = len(s.projections) > 0 + hasSkip = s.skip != nil && *s.skip > 0 + hasLimit = s.limit != nil && *s.limit > 0 + requiresProjection = hasProjectedItems || hasSkip || hasLimit + ) + + if requiresProjection { + if !hasProjectedItems { + return fmt.Errorf("query expected projected items") + } + + projection := singlePartQuery.NewProjection(false) + + for _, nextProjection := range s.projections { + switch typedNextProjection := nextProjection.(type) { + case *cypher.Return: + for _, returnItem := range typedNextProjection.Projection.Items { + if typedReturnItem, typeOK := returnItem.(*cypher.ProjectionItem); !typeOK { + return fmt.Errorf("invalid type for return: %T", returnItem) + } else { + projection.AddItem(typedReturnItem) + } + } + + case QualifiedExpression: + projection.AddItem(cypher.NewProjectionItemWithExpr(typedNextProjection.qualifier())) + + case kindContinuation: + var kindExpr cypher.Expression + + switch typedNextProjection.identifier.Symbol { + case Identifiers.node, 
Identifiers.start, Identifiers.end: + kindExpr = cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, typedNextProjection.identifier) + + case Identifiers.relationship: + kindExpr = cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, typedNextProjection.identifier) + } + + projection.AddItem(cypher.NewProjectionItemWithExpr(kindExpr)) + + case kindsContinuation: + var kindExpr cypher.Expression + + switch typedNextProjection.identifier.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + kindExpr = cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, typedNextProjection.identifier) + + case Identifiers.relationship: + kindExpr = cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, typedNextProjection.identifier) + } + + projection.AddItem(cypher.NewProjectionItemWithExpr(kindExpr)) + + default: + projection.AddItem(cypher.NewProjectionItemWithExpr(typedNextProjection)) + } + } + + if s.skip != nil && *s.skip > 0 { + projection.Skip = cypher.NewSkip(*s.skip) + } + + if s.limit != nil && *s.limit > 0 { + projection.Limit = cypher.NewLimit(*s.limit) + } + + if projectionOrder, err := s.buildProjectionOrder(); err != nil { + return err + } else if projectionOrder != nil { + projection.Order = projectionOrder + } + } + + return nil +} + +type PreparedQuery struct { + Query *cypher.RegularQuery + Parameters map[string]any +} + +func (s *builder) hasActions() bool { + return len(s.projections) > 0 || len(s.setItems) > 0 || len(s.removeItems) > 0 || len(s.creates) > 0 || len(s.deleteItems) > 0 +} + +func (s *builder) Build() (*PreparedQuery, error) { + if len(s.errors) > 0 { + return nil, errors.Join(s.errors...) 
+ } + + if !s.hasActions() { + return nil, fmt.Errorf("query has no action specified") + } + + var ( + regularQuery, singlePartQuery = cypher.NewRegularQueryWithSingleQuery() + match = &cypher.Match{} + seenIdentifiers = newIdentifierSet() + relationshipKinds graph.Kinds + ) + + if err := s.buildUpdatingClauses(singlePartQuery); err != nil { + return nil, err + } + + if err := s.buildProjection(singlePartQuery); err != nil { + return nil, err + } + + // If there are constraints, add them to the match with a where clause + if len(s.constraints) > 0 { + var ( + whereClause = match.NewWhere() + constraints = &cypher.Comparison{} + ) + + for _, nextConstraint := range s.constraints { + switch typedNextConstraint := nextConstraint.(type) { + case *cypher.KindMatcher: + if identifier, typeOK := typedNextConstraint.Reference.(*cypher.Variable); !typeOK { + return nil, fmt.Errorf("expected type *cypher.Variable, got %T", typedNextConstraint) + } else if identifier.Symbol == Identifiers.relationship { + relationshipKinds = relationshipKinds.Add(typedNextConstraint.Kinds...) 
+ continue + } + } + + if constraints.Left == nil { + constraints.Left = nextConstraint + } else { + constraints.NewPartialComparison(cypher.OperatorAnd, nextConstraint) + } + } + + if constraints.Left != nil { + whereClause.Add(constraints) + + if err := seenIdentifiers.CollectFromExpression(whereClause); err != nil { + return nil, err + } + } + } + + if err := seenIdentifiers.CollectFromExpression(singlePartQuery); err != nil { + return nil, err + } + + // Skip pattern preparation if there is a create clause with no constraints + if len(s.constraints) > 0 || len(s.creates) == 0 { + if isNodePattern(seenIdentifiers) { + if err := prepareNodePattern(match, seenIdentifiers); err != nil { + return nil, err + } + } else if isRelationshipPattern(seenIdentifiers) { + if err := prepareRelationshipPattern(match, seenIdentifiers, relationshipKinds, s.shortestPathQuery, s.allShorestPathsQuery); err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("query has no node and relationship query identifiers specified") + } + } + + if len(match.Pattern) > 0 { + newReadingClause := cypher.NewReadingClause() + newReadingClause.Match = match + + singlePartQuery.ReadingClauses = append(singlePartQuery.ReadingClauses, newReadingClause) + } + + return &PreparedQuery{ + Query: regularQuery, + Parameters: map[string]any{}, + }, nil +} diff --git a/query/query_test.go b/query/query_test.go new file mode 100644 index 0000000..72d067c --- /dev/null +++ b/query/query_test.go @@ -0,0 +1,62 @@ +package query_test + +import ( + "testing" + + "github.com/specterops/dawgs/cypher/models/cypher" + v2 "github.com/specterops/dawgs/query" + + "github.com/specterops/dawgs/cypher/models/cypher/format" + "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/require" +) + +func TestQuery(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Not(v2.Relationship().Kind().Is(graph.StringKind("test"))), + 
v2.Not(v2.Relationship().Kind().IsOneOf(graph.Kinds{graph.StringKind("A"), graph.StringKind("B")})), + v2.Relationship().Property("rel_prop").LessThanOrEqualTo(1234), + v2.Relationship().Property("other_prop").Equals(5678), + v2.Start().Kinds().HasOneOf(graph.Kinds{graph.StringKind("test")}), + ).Update( + v2.Start().Property("this_prop").Set(1234), + v2.End().Kinds().Remove(graph.Kinds{graph.StringKind("A"), graph.StringKind("B")}), + ).Delete( + v2.Start(), + ).Return( + v2.Relationship(), + v2.Start().Property("node_prop"), + ).Skip(10).Limit(10).Build() + require.NoError(t, err) + + cypherQueryStr, err := format.RegularQuery(preparedQuery.Query, false) + require.NoError(t, err) + + require.Equal(t, "match (s)-[r]->() where type(r) <> 'test' and not type(r) in ['A', 'B'] and r.rel_prop <= 1234 and r.other_prop = 5678 and s:test set s.this_prop = 1234 remove e:A:B delete s return r, s.node_prop skip 10 limit 10", cypherQueryStr) + + preparedQuery, err = v2.New().Create( + v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, cypher.NewParameter("props", map[string]any{})), + ).Build() + + require.NoError(t, err) + + cypherQueryStr, err = format.RegularQuery(preparedQuery.Query, false) + require.NoError(t, err) + + require.Equal(t, "create (n:A $props)", cypherQueryStr) + + // TODO: V1 compat wrecked the ergonomics experiment below. This should be revisited once V1-to-V2 is stable. 
+ // + //preparedQuery, err = v2.New().Where( + // v2.Start().ID().Equals(1234), + //).Create( + // v2.Relationship().RelationshipPattern(graph.StringKind("A"), cypher.NewParameter("props", map[string]any{}), graph.DirectionOutbound), + //).Build() + // + //require.NoError(t, err) + // + //cypherQueryStr, err = format.RegularQuery(preparedQuery.Query, false) + //require.NoError(t, err) + // + //require.Equal(t, "match (s), (e) where id(s) = 1234 create (s)-[r:A $props]->(e)", cypherQueryStr) +} diff --git a/query/util.go b/query/util.go new file mode 100644 index 0000000..782bfb2 --- /dev/null +++ b/query/util.go @@ -0,0 +1,174 @@ +package query + +import ( + "errors" + "fmt" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/walk" + "github.com/specterops/dawgs/graph" +) + +func isNodePattern(seen *identifierSet) bool { + return seen.Contains(Identifiers.node) +} + +func isRelationshipPattern(seen *identifierSet) bool { + var ( + hasStart = seen.Contains(Identifiers.start) + hasRelationship = seen.Contains(Identifiers.relationship) + hasEnd = seen.Contains(Identifiers.end) + ) + + return hasStart || hasRelationship || hasEnd +} + +func prepareNodePattern(match *cypher.Match, seen *identifierSet) error { + if isRelationshipPattern(seen) { + return fmt.Errorf("query mixes node and relationship query identifiers") + } + + match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ + Variable: Identifiers.Node(), + }) + + return nil +} + +func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, relationshipKinds graph.Kinds, shortestPaths, allShortestPaths bool) error { + if shortestPaths && allShortestPaths { + return errors.New("query is requesting both all shortest paths and shortest paths") + } + + var ( + newPatternPart = match.NewPatternPart() + startNodeSeen = seen.Contains(Identifiers.start) + relationshipSeen = seen.Contains(Identifiers.relationship) + endNodeSeen = 
seen.Contains(Identifiers.end) + ) + + newPatternPart.ShortestPathPattern = shortestPaths + newPatternPart.AllShortestPathsPattern = allShortestPaths + + if startNodeSeen { + newPatternPart.AddPatternElements(&cypher.NodePattern{ + Variable: Identifiers.Start(), + }) + } else { + newPatternPart.AddPatternElements(&cypher.NodePattern{}) + } + + relationshipPattern := &cypher.RelationshipPattern{ + Kinds: relationshipKinds, + Direction: graph.DirectionOutbound, + } + + if relationshipSeen { + relationshipPattern.Variable = Identifiers.Relationship() + } + + if shortestPaths || allShortestPaths { + newPatternPart.Variable = Identifiers.Path() + relationshipPattern.Range = &cypher.PatternRange{} + } + + newPatternPart.AddPatternElements(relationshipPattern) + + if endNodeSeen { + newPatternPart.AddPatternElements(&cypher.NodePattern{ + Variable: Identifiers.End(), + }) + } else { + newPatternPart.AddPatternElements(&cypher.NodePattern{}) + } + + return nil +} + +type identifierSet struct { + identifiers map[string]struct{} +} + +func newIdentifierSet() *identifierSet { + return &identifierSet{ + identifiers: map[string]struct{}{}, + } +} + +func (s *identifierSet) Add(identifier string) { + s.identifiers[identifier] = struct{}{} +} + +func (s *identifierSet) Or(other *identifierSet) { + for otherIdentifier := range other.identifiers { + s.identifiers[otherIdentifier] = struct{}{} + } +} + +func (s *identifierSet) Contains(identifier string) bool { + _, containsIdentifier := s.identifiers[identifier] + return containsIdentifier +} + +func (s *identifierSet) CollectFromExpression(expr cypher.Expression) error { + if exprIdentifiers, err := extractCypherIdentifiers(expr); err != nil { + return err + } else { + s.Or(exprIdentifiers) + return nil + } +} + +type identifierExtractor struct { + walk.Visitor[cypher.SyntaxNode] + + seen *identifierSet + + inDelete bool + inUpdate bool + inCreate bool + inWhere bool +} + +func newIdentifierExtractor() *identifierExtractor { + 
return &identifierExtractor{ + Visitor: walk.NewVisitor[cypher.SyntaxNode](), + seen: newIdentifierSet(), + } +} + +func (s *identifierExtractor) Enter(node cypher.SyntaxNode) { + switch typedNode := node.(type) { + case *cypher.Variable: + s.seen.Add(typedNode.Symbol) + + case *cypher.NodePattern: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + + case *cypher.RelationshipPattern: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + + case *cypher.PatternPart: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + + case *cypher.ProjectionItem: + if typedNode.Alias != nil { + s.seen.Add(typedNode.Alias.Symbol) + } + } +} + +func extractCypherIdentifiers(expression cypher.Expression) (*identifierSet, error) { + var ( + identifierExtractorVisitor = newIdentifierExtractor() + err = walk.Cypher(expression, identifierExtractorVisitor) + ) + + return identifierExtractorVisitor.seen, err +} diff --git a/dawgs.go b/registry.go similarity index 50% rename from dawgs.go rename to registry.go index 40c1f07..cc93fcd 100644 --- a/dawgs.go +++ b/registry.go @@ -4,8 +4,8 @@ import ( "context" "errors" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/database/v1compat" "github.com/specterops/dawgs/util/size" ) @@ -13,7 +13,7 @@ var ( ErrDriverMissing = errors.New("driver missing") ) -type DriverConstructor func(ctx context.Context, cfg Config) (graph.Database, error) +type DriverConstructor func(ctx context.Context, cfg Config) (database.Instance, error) var availableDrivers = map[string]DriverConstructor{} @@ -24,13 +24,24 @@ func Register(driverName string, constructor DriverConstructor) { type Config struct { GraphQueryMemoryLimit size.Size ConnectionString string - Pool *pgxpool.Pool + + // DriverConfig holds driver-specific configuration data that will be passed to the driver constructor. 
The type + // and structure depend on the specific driver. + DriverConfig any } -func Open(ctx context.Context, driverName string, config Config) (graph.Database, error) { +func Open(ctx context.Context, driverName string, config Config) (database.Instance, error) { if driverConstructor, hasDriver := availableDrivers[driverName]; !hasDriver { return nil, ErrDriverMissing } else { return driverConstructor(ctx, config) } } + +func OpenV1(ctx context.Context, driverName string, config Config) (v1compat.Database, error) { + if driver, err := Open(ctx, driverName, config); err != nil { + return nil, err + } else { + return v1compat.V1Wrapper(driver), nil + } +} diff --git a/registry_integration_test.go b/registry_integration_test.go new file mode 100644 index 0000000..968b70c --- /dev/null +++ b/registry_integration_test.go @@ -0,0 +1,80 @@ +//go:build manual_integration + +package dawgs_test + +import ( + "context" + "fmt" + "log/slog" + "testing" + + //pg_v2 "github.com/specterops/dawgs/drivers/pg/v2" + "github.com/specterops/dawgs/database" + "github.com/specterops/dawgs/database/neo4j" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/query" + "github.com/specterops/dawgs/util/size" + "github.com/stretchr/testify/require" +) + +func Test(t *testing.T) { + ctx := context.Background() + + graphDB, err := database.Open(ctx, neo4j.DriverName, database.Config{ + GraphQueryMemoryLimit: size.Gibibyte * 1, + ConnectionString: "neo4j://neo4j:neo4jj@localhost:7687", + }) + + //graphDB, err := v2.Open(ctx, pg_v2.DriverName, v2.Config{ + // GraphQueryMemoryLimit: size.Gibibyte * 1, + // ConnectionString: "postgresql://postgres:postgres@localhost:5432/bhe", + //}) + + require.NoError(t, err) + + require.NoError(t, graphDB.AssertSchema(ctx, database.NewSchema( + "default", + database.Graph{ + Name: "default", + Nodes: graph.Kinds{graph.StringKind("Node")}, + Edges: graph.Kinds{graph.StringKind("Edge")}, + NodeIndexes: []database.Index{{ + Name: 
"node_label_name_index", + Field: "name", + Type: database.IndexTypeTextSearch, + }}, + }))) + + preparedQuery, err := query.New().Return(query.Node()).Limit(10).Build() + require.NoError(t, err) + + require.NoError(t, graphDB.Session(ctx, func(ctx context.Context, driver database.Driver) error { + return driver.CreateNode(ctx, graph.PrepareNode(graph.AsProperties(map[string]any{ + "name": "THAT NODE", + }), graph.StringKind("Node"))) + })) + + require.NoError(t, graphDB.Session(ctx, database.FetchNodes(preparedQuery, func(node *graph.Node) error { + slog.Info(fmt.Sprintf("Got result from DB: %v", node)) + return nil + }))) + + require.NoError(t, graphDB.Transaction(ctx, database.FetchNodes(preparedQuery, func(node *graph.Node) error { + slog.Info(fmt.Sprintf("Got result from DB: %v", node)) + return nil + }))) + + //require.NoError(t, graphDB.Transaction(ctx, func(ctx context.Context, driver v2.Driver) error { + // builder := v2.Query().Create( + // v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, cypher.NewParameter("props", map[string]any{ + // "name": "1234", + // })), + // ) + // + // if preparedQuery, err := builder.Build(); err != nil { + // return err + // } else { + // return driver.CypherQuery(ctx, preparedQuery.Query, preparedQuery.Parameters).Close(ctx) + // } + //})) +} diff --git a/traversal/traversal.go b/traversal/traversal.go index a981519..f82ffa3 100644 --- a/traversal/traversal.go +++ b/traversal/traversal.go @@ -9,21 +9,18 @@ import ( "sync/atomic" "github.com/specterops/dawgs/cardinality" + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/database" "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/graphcache" - "github.com/specterops/dawgs/ops" "github.com/specterops/dawgs/query" "github.com/specterops/dawgs/util" "github.com/specterops/dawgs/util/atomics" "github.com/specterops/dawgs/util/channels" + "github.com/specterops/dawgs/util/size" ) -// Driver is a function that drives 
sending queries to the graph and retrieving vertexes and edges. Traversal -// drivers are expected to operate on a cactus tree representation of path space using the graph.PathSegment data -// structure. Path segments returned by a traversal driver are considered extensions of path space that require -// further expansion. If a traversal driver returns no descending path segments, then the given segment may be -// considered terminal. -type Driver = func(ctx context.Context, tx graph.Transaction, segment *graph.PathSegment) ([]*graph.PathSegment, error) +// Logic represents callable logic that drives sending queries to the graph +type Logic = func(ctx context.Context, tx database.Driver, segment *graph.PathSegment) ([]*graph.PathSegment, error) type PatternMatchDelegate = func(terminal *graph.PathSegment) error @@ -31,46 +28,23 @@ type PatternMatchDelegate = func(terminal *graph.PathSegment) error // building the pattern the user may call the Do(...) function and pass it a delegate for handling paths that match // the pattern. // -// The return value of the Do(...) function may be passed directly to a Traversal via a Plan as the Plan.Driver field. +// The return value of the Do(...) function may be passed directly to a Instance via a Plan as the Plan.Driver field. 
type PatternContinuation interface { - Outbound(criteria ...graph.Criteria) PatternContinuation - OutboundWithDepth(min, max int, criteria ...graph.Criteria) PatternContinuation - Inbound(criteria ...graph.Criteria) PatternContinuation - InboundWithDepth(min, max int, criteria ...graph.Criteria) PatternContinuation - Do(delegate PatternMatchDelegate) Driver + Outbound(criteria ...cypher.SyntaxNode) PatternContinuation + OutboundWithDepth(min, max int, criteria ...cypher.SyntaxNode) PatternContinuation + Inbound(criteria ...cypher.SyntaxNode) PatternContinuation + InboundWithDepth(min, max int, criteria ...cypher.SyntaxNode) PatternContinuation + Do(delegate PatternMatchDelegate) Logic } // expansion is an internal representation of a path expansion step. type expansion struct { - criteria []graph.Criteria + criteria []cypher.SyntaxNode direction graph.Direction minDepth int maxDepth int } -func (s expansion) PrepareCriteria(segment *graph.PathSegment) (graph.Criteria, error) { - var ( - criteria = s.criteria - ) - - switch s.direction { - case graph.DirectionOutbound: - criteria = append([]graph.Criteria{ - query.Equals(query.StartID(), segment.Node.ID), - }, criteria...) - - case graph.DirectionInbound: - criteria = append([]graph.Criteria{ - query.Equals(query.EndID(), segment.Node.ID), - }, criteria...) - - default: - return nil, fmt.Errorf("unsupported direction %v", s.direction) - } - - return query.And(criteria...), nil -} - type patternTag struct { patternIdx int depth int @@ -98,13 +72,13 @@ type pattern struct { } // Do assigns the PatterMatchDelegate internally before returning a function pointer to the Driver receiver function. -func (s *pattern) Do(delegate PatternMatchDelegate) Driver { +func (s *pattern) Do(delegate PatternMatchDelegate) Logic { s.delegate = delegate return s.Driver } // OutboundWithDepth specifies the next outbound expansion step for this pattern with depth parameters. 
-func (s *pattern) OutboundWithDepth(min, max int, criteria ...graph.Criteria) PatternContinuation { +func (s *pattern) OutboundWithDepth(min, max int, criteria ...cypher.SyntaxNode) PatternContinuation { if min < 0 { min = 1 slog.Warn("Negative mindepth not allowed. Setting min depth for expansion to 1") @@ -127,12 +101,12 @@ func (s *pattern) OutboundWithDepth(min, max int, criteria ...graph.Criteria) Pa // Outbound specifies the next outbound expansion step for this pattern. By default, this expansion will use a minimum // depth of 1 to make the expansion required and a maximum depth of 0 to expand indefinitely. -func (s *pattern) Outbound(criteria ...graph.Criteria) PatternContinuation { +func (s *pattern) Outbound(criteria ...cypher.SyntaxNode) PatternContinuation { return s.OutboundWithDepth(1, 0, criteria...) } // InboundWithDepth specifies the next inbound expansion step for this pattern with depth parameters. -func (s *pattern) InboundWithDepth(min, max int, criteria ...graph.Criteria) PatternContinuation { +func (s *pattern) InboundWithDepth(min, max int, criteria ...cypher.SyntaxNode) PatternContinuation { if min < 0 { min = 1 slog.Warn("Negative mindepth not allowed. Setting min depth for expansion to 1") @@ -155,7 +129,7 @@ func (s *pattern) InboundWithDepth(min, max int, criteria ...graph.Criteria) Pat // Inbound specifies the next inbound expansion step for this pattern. By default, this expansion will use a minimum // depth of 1 to make the expansion required and a maximum depth of 0 to expand indefinitely. -func (s *pattern) Inbound(criteria ...graph.Criteria) PatternContinuation { +func (s *pattern) Inbound(criteria ...cypher.SyntaxNode) PatternContinuation { return s.InboundWithDepth(1, 0, criteria...) 
} @@ -164,7 +138,7 @@ func NewPattern() PatternContinuation { return &pattern{} } -func (s *pattern) Driver(ctx context.Context, tx graph.Transaction, segment *graph.PathSegment) ([]*graph.PathSegment, error) { +func (s *pattern) Driver(ctx context.Context, dbDriver database.Driver, segment *graph.PathSegment) ([]*graph.PathSegment, error) { var ( nextSegments []*graph.PathSegment @@ -175,23 +149,61 @@ func (s *pattern) Driver(ctx context.Context, tx graph.Transaction, segment *gra // fetchFunc handles directional results from the graph database and is called twice to fetch segment // expansions. - fetchFunc = func(cursor graph.Cursor[graph.DirectionalResult]) error { - for next := range cursor.Chan() { - nextSegment := segment.Descend(next.Node, next.Relationship) - - // Don't emit cycles out of the fetch - if !nextSegment.IsCycle() { - nextSegment.Tag = &patternTag{ - // Use the tag's patternIdx and depth since this is a continuation of the expansions - patternIdx: tag.patternIdx, - depth: tag.depth + 1, + fetchFunc = func(criteria cypher.SyntaxNode, direction graph.Direction) error { + var ( + queryBuilder = query.New() + allCriteria = []cypher.SyntaxNode{criteria} + ) + + switch direction { + case graph.DirectionInbound: + queryBuilder.Where(append(allCriteria, query.Start().ID().Equals(segment.Node.ID))...).Return( + query.Relationship(), + query.End(), + ) + + case graph.DirectionOutbound: + queryBuilder.Where(append(allCriteria, query.End().ID().Equals(segment.Node.ID))...).Return( + query.Relationship(), + query.Start(), + ) + + default: + return fmt.Errorf("unsupported direction %v", direction) + } + + if preparedQuery, err := queryBuilder.Build(); err != nil { + return err + } else { + result := dbDriver.Exec(ctx, preparedQuery.Query, preparedQuery.Parameters) + defer result.Close(ctx) + + for result.HasNext(ctx) { + var ( + nextNode graph.Node + nextRelationship graph.Relationship + ) + + if err := result.Scan(&nextNode, &nextRelationship); err != nil { 
+ return err } - nextSegments = append(nextSegments, nextSegment) + nextSegment := segment.Descend(&nextNode, &nextRelationship) + + // Don't emit cycles out of the fetch + if !nextSegment.IsCycle() { + nextSegment.Tag = &patternTag{ + // Use the tag's patternIdx and depth since this is a continuation of the expansions + patternIdx: tag.patternIdx, + depth: tag.depth + 1, + } + + nextSegments = append(nextSegments, nextSegment) + } } - } - return cursor.Error() + return result.Error() + } } ) @@ -202,15 +214,13 @@ func (s *pattern) Driver(ctx context.Context, tx graph.Transaction, segment *gra // If no max depth was set or if a max depth was set expand the current step further if currentExpansion.maxDepth == 0 || tag.depth < currentExpansion.maxDepth { // Perform the current expansion. - if criteria, err := currentExpansion.PrepareCriteria(segment); err != nil { - return nil, err - } else if err := tx.Relationships().Filter(criteria).FetchDirection(fetchDirection, fetchFunc); err != nil { + if err := fetchFunc(currentExpansion.criteria, fetchDirection); err != nil { return nil, err } } // Check first if this current segment was fetched using the current expansion (i.e. non-optional) - if tag.depth > 0 && currentExpansion.minDepth == 0 || tag.depth >= currentExpansion.minDepth { + if (tag.depth > 0 && currentExpansion.minDepth == 0) || tag.depth >= currentExpansion.minDepth { // No further expansions means this pattern segment is complete. Increment the pattern index to select the // next pattern expansion. Additionally, set the depth back to zero for the tag since we are leaving the // current expansion. 
@@ -222,9 +232,7 @@ func (s *pattern) Driver(ctx context.Context, tx graph.Transaction, segment *gra nextExpansion := s.expansions[tag.patternIdx] // Expand the next segments - if criteria, err := nextExpansion.PrepareCriteria(segment); err != nil { - return nil, err - } else if err := tx.Relationships().Filter(criteria).FetchDirection(fetchDirection, fetchFunc); err != nil { + if err := fetchFunc(nextExpansion.criteria, fetchDirection); err != nil { return nil, err } @@ -253,29 +261,30 @@ func (s *pattern) Driver(ctx context.Context, tx graph.Transaction, segment *gra type Plan struct { Root *graph.Node RootSegment *graph.PathSegment - Driver Driver + Logic Logic } -type Traversal struct { - db graph.Database - numWorkers int +type Instance struct { + db database.Instance + numParallelWorkers int + memoryLimit size.Size } -func New(db graph.Database, numParallelWorkers int) Traversal { - return Traversal{ - db: db, - numWorkers: numParallelWorkers, +func New(db database.Instance, numParallelWorkers int) Instance { + return Instance{ + db: db, + numParallelWorkers: numParallelWorkers, } } -func (s Traversal) BreadthFirst(ctx context.Context, plan Plan) error { +func (s Instance) BreadthFirst(ctx context.Context, plan Plan) error { var ( // workerWG keeps count of background workers launched in goroutines workerWG = &sync.WaitGroup{} // descentWG keeps count of in-flight traversal work. When this wait group reaches a count of 0 the traversal // is considered complete. 
- completionC = make(chan struct{}, s.numWorkers*2) + completionC = make(chan struct{}, s.numParallelWorkers) descentCount = &atomic.Int64{} errorCollector = util.NewErrorCollector() traversalCtx, doneFunc = context.WithCancel(ctx) @@ -300,21 +309,21 @@ func (s Traversal) BreadthFirst(ctx context.Context, plan Plan) error { } // Launch the background traversal workers - for workerID := 0; workerID < s.numWorkers; workerID++ { + for workerID := 0; workerID < s.numParallelWorkers; workerID++ { workerWG.Add(1) go func(workerID int) { defer workerWG.Done() - if err := s.db.ReadTransaction(ctx, func(tx graph.Transaction) error { + if err := s.db.Session(ctx, func(ctx context.Context, driver database.Driver) error { for { if nextDescent, ok := channels.Receive(traversalCtx, segmentReaderC); !ok { return nil - } else if tx.GraphQueryMemoryLimit() > 0 && pathTree.SizeOf() > tx.GraphQueryMemoryLimit() { - return fmt.Errorf("%w - Limit: %.2f MB - Memory In-Use: %.2f MB", ops.ErrGraphQueryMemoryLimit, tx.GraphQueryMemoryLimit().Mebibytes(), pathTree.SizeOf().Mebibytes()) + } else if s.memoryLimit > 0 && s.memoryLimit <= pathTree.SizeOf() { + return fmt.Errorf("traversal memory limit reached - Limit: %.2f MB - Memory In-Use: %.2f MB", s.memoryLimit.Mebibytes(), pathTree.SizeOf().Mebibytes()) } else { // Traverse the descending relationships of the current segment - if descendingSegments, err := plan.Driver(traversalCtx, tx, nextDescent); err != nil { + if descendingSegments, err := plan.Logic(traversalCtx, driver, nextDescent); err != nil { return err } else { for _, descendingSegment := range descendingSegments { @@ -360,85 +369,6 @@ func (s Traversal) BreadthFirst(ctx context.Context, plan Plan) error { return errorCollector.Combined() } -func newVisitorFilter(direction graph.Direction, userFilter graph.Criteria) func(segment *graph.PathSegment) graph.Criteria { - return func(segment *graph.PathSegment) graph.Criteria { - var filters []graph.Criteria - - if userFilter != nil 
{ - filters = append(filters, userFilter) - } - - switch direction { - case graph.DirectionOutbound: - filters = append(filters, query.Equals(query.StartID(), segment.Node.ID)) - - case graph.DirectionInbound: - filters = append(filters, query.Equals(query.EndID(), segment.Node.ID)) - } - - return query.And(filters...) - } -} - -func shallowFetchRelationships(direction graph.Direction, segment *graph.PathSegment, graphQuery graph.RelationshipQuery) ([]*graph.Relationship, error) { - var ( - relationships []*graph.Relationship - returnCriteria graph.Criteria - ) - - switch direction { - case graph.DirectionOutbound: - returnCriteria = query.Returning( - query.EndID(), - query.KindsOf(query.End()), - query.RelationshipID(), - query.KindsOf(query.Relationship()), - ) - - case graph.DirectionInbound: - returnCriteria = query.Returning( - query.StartID(), - query.KindsOf(query.Start()), - query.RelationshipID(), - query.KindsOf(query.Relationship()), - ) - - default: - return nil, fmt.Errorf("bi-directional or non-directed edges are not supported") - } - - if err := graphQuery.Query(func(results graph.Result) error { - defer results.Close() - - var ( - nodeID graph.ID - nodeKinds graph.Kinds - edgeID graph.ID - edgeKind graph.Kind - ) - - for results.Next() { - if err := results.Scan(&nodeID, &nodeKinds, &edgeID, &edgeKind); err != nil { - return err - } - - switch direction { - case graph.DirectionOutbound: - relationships = append(relationships, graph.NewRelationship(edgeID, segment.Node.ID, nodeID, nil, edgeKind)) - - case graph.DirectionInbound: - relationships = append(relationships, graph.NewRelationship(edgeID, nodeID, segment.Node.ID, nil, edgeKind)) - } - } - - return results.Error() - }, returnCriteria); err != nil { - return nil, err - } - - return relationships, nil -} - // SegmentFilter is a function type that takes a given path segment and returns true if further descent into the path // is allowed. 
type SegmentFilter = func(next *graph.PathSegment) bool @@ -529,66 +459,3 @@ func FilteredSkipLimit(filter SkipLimitFilter, visitorFilter SegmentVisitor, ski return shouldDescend } } - -// LightweightDriver is a Driver constructor that fetches only IDs and Kind information from vertexes and -// edges stored in the database. This cuts down on network transit and is appropriate for traversals that may involve -// a large number of or all vertexes within a target graph. -func LightweightDriver(direction graph.Direction, cache graphcache.Cache, criteria graph.Criteria, filter SegmentFilter, terminalVisitors ...SegmentVisitor) Driver { - filterProvider := newVisitorFilter(direction, criteria) - - return func(ctx context.Context, tx graph.Transaction, nextSegment *graph.PathSegment) ([]*graph.PathSegment, error) { - var ( - nextSegments []*graph.PathSegment - nextQuery = tx.Relationships().Filter(filterProvider(nextSegment)).OrderBy( - // Order by relationship ID so that skip and limit behave somewhat predictably - cost of this is pretty - // small even for large result sets - query.Order(query.Identity(query.Relationship()), query.Ascending()), - ) - ) - - if relationships, err := shallowFetchRelationships(direction, nextSegment, nextQuery); err != nil { - return nil, err - } else { - // Reconcile the start and end nodes of the fetched relationships with the graph cache - nodesToFetch := cardinality.NewBitmap64() - - for _, nextRelationship := range relationships { - if nextID, err := direction.PickReverse(nextRelationship); err != nil { - return nil, err - } else { - nodesToFetch.Add(nextID.Uint64()) - } - } - - // Shallow fetching the nodes achieves the same result as shallowFetchRelationships(...) but with the added - // benefit of interacting with the graph cache. Any nodes not already in the cache are fetched just-in-time - // from the database and stored back in the cache for later. 
- if cachedNodes, err := graphcache.ShallowFetchNodesByID(tx, cache, graph.DuplexToGraphIDs(nodesToFetch)); err != nil { - return nil, err - } else { - cachedNodeSet := graph.NewNodeSet(cachedNodes...) - - for _, nextRelationship := range relationships { - if targetID, err := direction.PickReverse(nextRelationship); err != nil { - return nil, err - } else { - nextSegment := nextSegment.Descend(cachedNodeSet[targetID], nextRelationship) - - if filter(nextSegment) { - nextSegments = append(nextSegments, nextSegment) - } - } - } - } - } - - // If this segment has no further descent paths, render it as a path if we have a path visitor specified - if len(nextSegments) == 0 && len(terminalVisitors) > 0 { - for _, terminalVisitor := range terminalVisitors { - terminalVisitor(nextSegment) - } - } - - return nextSegments, nil - } -} From 1e956030e108d528693bedb8717387765bc86e27 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Tue, 19 Aug 2025 14:17:01 -0700 Subject: [PATCH 2/5] wip --- algo/closeness.go | 11 +- algo/katz.go | 22 ++- algo/sample.go | 6 + algo/scc.go | 205 ++++++++++++++++++++++++---- algo/weight.go | 58 ++++++++ cardinality/lock.go | 40 +++--- cmd/viz/graph.go | 120 +++++++++++++++++ cmd/viz/main.go | 43 ++++++ container/adjacencymap.go | 168 +++++++++++++++++++++++ container/bfs.go | 99 ++++++++++++++ container/digraph.go | 255 +++++++++-------------------------- container/fetch.go | 157 ++++++++++++++++----- container/gml.go | 46 +++++++ container/segment.go | 168 +++++++++++++++++++++++ container/segment_test.go | 55 ++++++++ container/triplestore.go | 228 +++++++++++++++++++++++++++++++ database/driver.go | 13 ++ database/v1compat/switch.go | 9 ++ go.mod | 6 +- go.sum | 56 ++------ graph/kind.go | 5 + graph/properties.go | 12 ++ query/query.go | 49 ++++--- registry.go | 7 + registry_integration_test.go | 74 +++++----- util/slog.go | 26 ++++ 26 files changed, 1553 insertions(+), 385 deletions(-) create mode 100644 cmd/viz/graph.go create mode 100644 
cmd/viz/main.go create mode 100644 container/adjacencymap.go create mode 100644 container/bfs.go create mode 100644 container/gml.go create mode 100644 container/segment.go create mode 100644 container/segment_test.go create mode 100644 container/triplestore.go create mode 100644 util/slog.go diff --git a/algo/closeness.go b/algo/closeness.go index 769c395..900ceed 100644 --- a/algo/closeness.go +++ b/algo/closeness.go @@ -6,13 +6,14 @@ import ( "github.com/specterops/dawgs/container" "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/util" ) func ClosenessForDirectedUnweightedGraph(digraph container.DirectedGraph, direction graph.Direction, sampleFunc SampleFunc, nSamples int) map[uint64]Weight { scores := make(map[uint64]Weight, nSamples) for _, nodeID := range sampleFunc(digraph, nSamples) { - if shortestPathTerminals := digraph.BFSTree(nodeID, direction); len(shortestPathTerminals) > 0 { + if shortestPathTerminals := container.BFSTree(digraph, nodeID, direction); len(shortestPathTerminals) > 0 { var distanceSum Weight = 0 for _, shortestPathTerminal := range shortestPathTerminals { @@ -29,14 +30,16 @@ func ClosenessForDirectedUnweightedGraph(digraph container.DirectedGraph, direct return scores } -func ClosenessForDirectedUnweightedGraphParallel(digraph container.DirectedGraph, direction graph.Direction, sampleFunc SampleFunc, nSamples int) map[uint64]Weight { +func ClosenessForDirectedUnweightedGraphParallel(digraph container.DirectedGraph, direction graph.Direction, sampleFunc SampleFunc, nSamples int) WeightMap { var ( - scores = make(map[uint64]Weight, nSamples) + scores = make(WeightMap, nSamples) scoresLock = &sync.Mutex{} workerWG = &sync.WaitGroup{} nodeC = make(chan uint64) ) + defer util.SLogMeasure("ClosenessForDirectedUnweightedGraphParallel")() + for workerID := 0; workerID < runtime.NumCPU(); workerID++ { workerWG.Add(1) @@ -44,7 +47,7 @@ func ClosenessForDirectedUnweightedGraphParallel(digraph container.DirectedGraph defer 
workerWG.Done() for nodeID := range nodeC { - if shortestPathTerminals := digraph.BFSTree(nodeID, direction); len(shortestPathTerminals) > 0 { + if shortestPathTerminals := container.BFSTree(digraph, nodeID, direction); len(shortestPathTerminals) > 0 { var distanceSum Weight = 0 for _, shortestPathTerminal := range shortestPathTerminals { diff --git a/algo/katz.go b/algo/katz.go index 40a9405..5668420 100644 --- a/algo/katz.go +++ b/algo/katz.go @@ -1,11 +1,13 @@ package algo import ( + "log/slog" "maps" "math" "github.com/specterops/dawgs/container" "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/util" ) /* @@ -31,13 +33,15 @@ High katz centrality values indicate that a node has significant influence withi edges, as well as direct ones. The centrality score also accounts for the declining importance of more distant relationships. */ -func CalculateKatzCentrality(digraph container.DirectedGraph, alpha, beta, epsilon Weight, iterations int) (map[uint64]float64, bool) { +func CalculateKatzCentrality(digraph container.DirectedGraph, alpha, beta, epsilon Weight, iterations int, direction graph.Direction) (map[uint64]Weight, bool) { var ( numNodes = digraph.Nodes().Cardinality() - centrality = make(map[uint64]float64, numNodes) - prevCentrality = make(map[uint64]float64, numNodes) + centrality = make(map[uint64]Weight, numNodes) + prevCentrality = make(map[uint64]Weight, numNodes) ) + defer util.SLogMeasure("CalculateKatzCentrality", slog.String("direction", direction.String()))() + // Initialize centrality scores to baseline digraph.Nodes().Each(func(value uint64) bool { centrality[value] = beta @@ -52,15 +56,21 @@ func CalculateKatzCentrality(digraph container.DirectedGraph, alpha, beta, epsil digraph.Nodes().Each(func(sourceNode uint64) bool { sum := 0.0 - digraph.EachAdjacent(sourceNode, graph.DirectionBoth, func(adjacentNode uint64) bool { + digraph.EachAdjacentNode(sourceNode, direction, func(adjacentNode uint64) bool { sum += 
prevCentrality[adjacentNode] return true }) centrality[sourceNode] = beta + alpha*sum - if math.Abs(centrality[sourceNode]-prevCentrality[sourceNode]) > epsilon { - changed = true + // Only calculate epsilon tolerance if there is no tolerance violation yet detected + if !changed { + diff := math.Abs(centrality[sourceNode] - prevCentrality[sourceNode]) + changed = diff > epsilon + + if changed { + slog.Info("Tolerance Check Failure", slog.Float64("diff", diff), slog.Float64("epsilon", epsilon), slog.Uint64("src_node", sourceNode)) + } } return true diff --git a/algo/sample.go b/algo/sample.go index 665a529..a2e1610 100644 --- a/algo/sample.go +++ b/algo/sample.go @@ -46,6 +46,12 @@ func sampleHighestDegrees(digraph container.DirectedGraph, nSamples int, directi return nodeSamples } +func SampleExact(samples []uint64) SampleFunc { + return func(digraph container.DirectedGraph, nSamples int) []uint64 { + return samples + } +} + func SampleHighestDegrees(direction graph.Direction) SampleFunc { return func(digraph container.DirectedGraph, nSamples int) []uint64 { return sampleHighestDegrees(digraph, nSamples, direction) diff --git a/algo/scc.go b/algo/scc.go index b5d8360..830cc12 100644 --- a/algo/scc.go +++ b/algo/scc.go @@ -3,12 +3,16 @@ package algo import ( "math" + "github.com/gammazero/deque" "github.com/specterops/dawgs/cardinality" "github.com/specterops/dawgs/container" "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/util" ) -func StronglyConnectedComponents(digraph container.DirectedGraph, direction graph.Direction) ([]cardinality.Duplex[uint64], map[uint64]int) { +func StronglyConnectedComponents(digraph container.DirectedGraph) ([]cardinality.Duplex[uint64], map[uint64]uint64) { + defer util.SLogMeasure("StronglyConnectedComponents")() + type descentCursor struct { id uint64 branches []uint64 @@ -26,7 +30,7 @@ func StronglyConnectedComponents(digraph container.DirectedGraph, direction grap stack = make([]uint64, 0, initialAlloc) 
dfsDescentStack = make([]*descentCursor, 0, initialAlloc) stronglyConnectedComponents = make([]cardinality.Duplex[uint64], 0, initialAlloc) - nodeToSCCIndex = make(map[uint64]int, numNodes) + nodeToSCCIndex = make(map[uint64]uint64, numNodes) ) digraph.EachNode(func(node uint64) bool { @@ -36,7 +40,7 @@ func StronglyConnectedComponents(digraph container.DirectedGraph, direction grap dfsDescentStack = append(dfsDescentStack, &descentCursor{ id: node, - branches: digraph.Adjacent(node, direction), + branches: digraph.AdjacentNodes(node, graph.DirectionOutbound), branchIdx: 0, }) @@ -70,7 +74,7 @@ func StronglyConnectedComponents(digraph container.DirectedGraph, direction grap dfsDescentStack = append(dfsDescentStack, &descentCursor{ id: nextBranchID, - branches: digraph.Adjacent(nextBranchID, direction), + branches: digraph.AdjacentNodes(nextBranchID, graph.DirectionOutbound), branchIdx: 0, }) } else if onStack.Contains(nextBranchID) { @@ -84,7 +88,7 @@ func StronglyConnectedComponents(digraph container.DirectedGraph, direction grap if lowLinks[nextCursor.id] == visitedIndex[nextCursor.id] { var ( scc = cardinality.NewBitmap64() - sccID = len(stronglyConnectedComponents) + sccID = uint64(len(stronglyConnectedComponents)) ) for { @@ -115,34 +119,177 @@ func StronglyConnectedComponents(digraph container.DirectedGraph, direction grap return stronglyConnectedComponents, nodeToSCCIndex } -type ComponentDirectedGraph struct { - Components []cardinality.Duplex[uint64] - OriginNodeToComponentIndex map[uint64]int - Digraph container.DirectedGraph +type ComponentGraph struct { + componentMembers []cardinality.Duplex[uint64] + memberComponentLookup map[uint64]uint64 + digraph container.DirectedGraph +} + +func (s ComponentGraph) Digraph() container.DirectedGraph { + return s.digraph +} + +func (s ComponentGraph) ContainingComponent(memberID uint64) (uint64, bool) { + component, inComponentDigraph := s.memberComponentLookup[memberID] + return component, inComponentDigraph } 
-func NewComponentDirectedGraph(digraph container.DirectedGraph, direction graph.Direction) ComponentDirectedGraph { +func (s ComponentGraph) CollectComponentMembers(componentID uint64, members cardinality.Duplex[uint64]) { + members.Or(s.componentMembers[componentID]) +} + +func (s ComponentGraph) ComponentSearch(startComponent, endComponent uint64) bool { + if startComponent == endComponent { + return true + } + var ( - components, nodeToComponentIndex = StronglyConnectedComponents(digraph, direction) - componentDigraph = container.NewDirectedGraph() + traversals deque.Deque[uint64] + visited = cardinality.NewBitmap64() + reachable = false ) - // Ensure all components are present as vertices, even if they have no edges - for idx := range components { - componentDigraph.Nodes().CheckedAdd(uint64(idx)) + traversals.PushBack(startComponent) + + for remainingTraversals := traversals.Len(); !reachable && remainingTraversals > 0; remainingTraversals = traversals.Len() { + nextComponent := traversals.PopFront() + + s.digraph.EachAdjacentNode(nextComponent, graph.DirectionOutbound, func(adjacentComponent uint64) bool { + reachable = adjacentComponent == endComponent + + if !reachable { + if visited.CheckedAdd(adjacentComponent) { + traversals.PushBack(adjacentComponent) + } + } + + return !reachable + }) } - digraph.EachNode(func(node uint64) bool { - nodeComponent := graph.ID(nodeToComponentIndex[node]) - - digraph.EachAdjacent(node, direction, func(adjacent uint64) bool { - if adjacentComponent := graph.ID(nodeToComponentIndex[adjacent]); nodeComponent != adjacentComponent { - switch direction { - case graph.DirectionInbound: - componentDigraph.AddEdge(adjacentComponent, nodeComponent) - case graph.DirectionOutbound: - componentDigraph.AddEdge(nodeComponent, adjacentComponent) + return reachable +} + +func (s ComponentGraph) ComponentReachable(startComponent, endComponent uint64) bool { + if startComponent == endComponent { + return true + } + + var ( + outboundQueue 
deque.Deque[uint64] + inboundQueue deque.Deque[uint64] + outboundComponents = cardinality.NewBitmap64() + inboundComponents = cardinality.NewBitmap64() + visitedComponents = cardinality.NewBitmap64() + reachable = false + ) + + outboundQueue.PushBack(startComponent) + outboundComponents.Add(startComponent) + + inboundQueue.PushBack(endComponent) + inboundComponents.Add(endComponent) + + for !reachable { + var ( + outboundQueueLen = outboundQueue.Len() + inboundQueueLen = inboundQueue.Len() + ) + + if outboundQueueLen > 0 && outboundQueueLen <= inboundQueueLen { + nextComponent := outboundQueue.PopFront() + + if !visitedComponents.CheckedAdd(nextComponent) { + continue + } + + s.digraph.EachAdjacentNode(nextComponent, graph.DirectionOutbound, func(adjacentComponent uint64) bool { + if outboundComponents.CheckedAdd(adjacentComponent) { + // Haven't seen this component yet, append to the traversal queue and check for reachability + outboundQueue.PushBack(adjacentComponent) + reachable = inboundComponents.Contains(adjacentComponent) + } + + // Continue iterating if not reachable + return !reachable + }) + } else if inboundQueueLen > 0 { + nextComponent := inboundQueue.PopFront() + + s.digraph.EachAdjacentNode(nextComponent, graph.DirectionInbound, func(adjacentComponent uint64) bool { + if inboundComponents.CheckedAdd(adjacentComponent) { + // Haven't seen this component yet, append to the traversal queue and check for reachability + inboundQueue.PushBack(adjacentComponent) + reachable = outboundComponents.Contains(adjacentComponent) } + + // Continue iterating if not reachable + return !reachable + }) + } else { + // No more expansions remain + break + } + } + + return reachable +} + +func (s ComponentGraph) ComponentHistogram(originNodes []uint64) map[uint64]uint64 { + histogram := map[uint64]uint64{} + + for _, originNode := range originNodes { + if component, inComponent := s.ContainingComponent(originNode); inComponent { + histogram[component] += 1 + } + } + + 
return histogram +} + +func (s ComponentGraph) OriginReachable(startID, endID uint64) bool { + var ( + startComponent, startInComponent = s.ContainingComponent(startID) + endComponent, endInComponent = s.ContainingComponent(endID) + ) + + if !startInComponent || !endInComponent { + return false + } + + return s.ComponentReachable(startComponent, endComponent) +} + +func NewComponentGraph(originGraph container.DirectedGraph) ComponentGraph { + var ( + componentMembers, memberComponentLookup = StronglyConnectedComponents(originGraph) + componentDigraph = container.NewAdjacencyMapGraph() + nextEdgeID = uint64(1) + ) + + defer util.SLogMeasure("NewComponentGraph")() + + // Ensure all components are present as vertices, even if they have no edges + for componentID := range componentMembers { + componentDigraph.Nodes().Add(uint64(componentID)) + } + + originGraph.EachNode(func(node uint64) bool { + nodeComponent := memberComponentLookup[node] + + originGraph.EachAdjacentNode(node, graph.DirectionInbound, func(adjacent uint64) bool { + if adjacentComponent := memberComponentLookup[adjacent]; nodeComponent != adjacentComponent { + componentDigraph.AddEdge(nextEdgeID, adjacentComponent, nodeComponent) + nextEdgeID += 1 + } + + return true + }) + + originGraph.EachAdjacentNode(node, graph.DirectionOutbound, func(adjacent uint64) bool { + if adjacentComponent := memberComponentLookup[adjacent]; nodeComponent != adjacentComponent { + componentDigraph.AddEdge(nextEdgeID, nodeComponent, adjacentComponent) + nextEdgeID += 1 } return true @@ -151,9 +298,9 @@ func NewComponentDirectedGraph(digraph container.DirectedGraph, direction graph. 
return true }) - return ComponentDirectedGraph{ - Components: components, - OriginNodeToComponentIndex: nodeToComponentIndex, - Digraph: componentDigraph, + return ComponentGraph{ + componentMembers: componentMembers, + memberComponentLookup: memberComponentLookup, + digraph: componentDigraph, } } diff --git a/algo/weight.go b/algo/weight.go index b2be0d6..7af51cb 100644 --- a/algo/weight.go +++ b/algo/weight.go @@ -1,3 +1,61 @@ package algo +import ( + "maps" + "sort" +) + type Weight = float64 +type WeightMap map[uint64]Weight + +func (s WeightMap) Keys() []uint64 { + keys := make([]uint64, 0, len(s)) + + for key := range s { + keys = append(keys, key) + } + + return keys +} + +func (s WeightMap) Copy() WeightMap { + return maps.Clone(s) +} + +func (s WeightMap) MultiplyInclusiveOnly(other WeightMap) { + for k := range s { + s[k] *= other[k] + } +} + +func (s WeightMap) SumInclusiveOnly(other WeightMap) { + for k := range s { + s[k] += other[k] + } +} + +func (s WeightMap) VisitSorted(visitor func(k uint64, v Weight) bool) { + type tuple struct { + key uint64 + value Weight + } + + sorted := make([]tuple, 0, len(s)) + + for k, v := range s { + sorted = append(sorted, tuple{ + key: k, + value: v, + }) + } + + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].value > sorted[j].value + }) + + for _, nextTuple := range sorted { + if !visitor(nextTuple.key, nextTuple.value) { + break + } + } +} diff --git a/cardinality/lock.go b/cardinality/lock.go index f6f19b3..6a4866f 100644 --- a/cardinality/lock.go +++ b/cardinality/lock.go @@ -6,13 +6,13 @@ import ( type threadSafeDuplex[T uint32 | uint64] struct { provider Duplex[T] - lock *sync.Mutex + lock *sync.RWMutex } func ThreadSafeDuplex[T uint32 | uint64](provider Duplex[T]) Duplex[T] { return threadSafeDuplex[T]{ provider: provider, - lock: &sync.Mutex{}, + lock: &sync.RWMutex{}, } } @@ -66,34 +66,38 @@ func (s threadSafeDuplex[T]) Or(other Provider[T]) { } func (s threadSafeDuplex[T]) Cardinality() uint64 { 
- s.lock.Lock() - defer s.lock.Unlock() + s.lock.RLock() + defer s.lock.RUnlock() return s.provider.Cardinality() } func (s threadSafeDuplex[T]) Slice() []T { - s.lock.Lock() - defer s.lock.Unlock() + s.lock.RLock() + defer s.lock.RUnlock() return s.provider.Slice() } func (s threadSafeDuplex[T]) Contains(value T) bool { - s.lock.Lock() - defer s.lock.Unlock() + s.lock.RLock() + defer s.lock.RUnlock() return s.provider.Contains(value) } func (s threadSafeDuplex[T]) Each(delegate func(value T) bool) { - s.lock.Lock() - defer s.lock.Unlock() + s.lock.RLock() + defer s.lock.RUnlock() s.provider.Each(delegate) } func (s threadSafeDuplex[T]) CheckedAdd(value T) bool { + if s.Contains(value) { + return false + } + s.lock.Lock() defer s.lock.Unlock() @@ -101,21 +105,21 @@ func (s threadSafeDuplex[T]) CheckedAdd(value T) bool { } func (s threadSafeDuplex[T]) Clone() Duplex[T] { - s.lock.Lock() - defer s.lock.Unlock() + s.lock.RLock() + defer s.lock.RUnlock() return ThreadSafeDuplex(s.provider.Clone()) } type threadSafeSimplex[T uint32 | uint64] struct { provider Simplex[T] - lock *sync.Mutex + lock *sync.RWMutex } func ThreadSafeSimplex[T uint32 | uint64](provider Simplex[T]) Simplex[T] { return threadSafeSimplex[T]{ provider: provider, - lock: &sync.Mutex{}, + lock: &sync.RWMutex{}, } } @@ -141,15 +145,15 @@ func (s threadSafeSimplex[T]) Or(other Provider[T]) { } func (s threadSafeSimplex[T]) Cardinality() uint64 { - s.lock.Lock() - defer s.lock.Unlock() + s.lock.RLock() + defer s.lock.RUnlock() return s.provider.Cardinality() } func (s threadSafeSimplex[T]) Clone() Simplex[T] { - s.lock.Lock() - defer s.lock.Unlock() + s.lock.RLock() + defer s.lock.RUnlock() return ThreadSafeSimplex(s.provider.Clone()) } diff --git a/cmd/viz/graph.go b/cmd/viz/graph.go new file mode 100644 index 0000000..1c4b542 --- /dev/null +++ b/cmd/viz/graph.go @@ -0,0 +1,120 @@ +package main + +import ( + "io" + "os" + "strconv" + + "github.com/go-echarts/go-echarts/v2/charts" + 
"github.com/go-echarts/go-echarts/v2/components" + "github.com/go-echarts/go-echarts/v2/opts" + "github.com/specterops/dawgs/container" + "github.com/specterops/dawgs/graph" +) + +var graphNodes = []opts.GraphNode{ + {Name: "Node1"}, + {Name: "Node2"}, + {Name: "Node3"}, + {Name: "Node4"}, + {Name: "Node5"}, + {Name: "Node6"}, + {Name: "Node7"}, + {Name: "Node8"}, +} + +func genLinks() []opts.GraphLink { + links := make([]opts.GraphLink, 0) + for i := 0; i < len(graphNodes); i++ { + for j := 0; j < len(graphNodes); j++ { + links = append(links, opts.GraphLink{Source: graphNodes[i].Name, Target: graphNodes[j].Name}) + } + } + return links +} + +func graphBase() *charts.Graph { + graph := charts.NewGraph() + graph.SetGlobalOptions( + charts.WithTitleOpts(opts.Title{Title: "basic graph example"}), + ) + graph.AddSeries("graph", graphNodes, genLinks(), + charts.WithGraphChartOpts( + opts.GraphChart{Force: &opts.GraphForce{Repulsion: 8000}}, + ), + ) + return graph +} + +func graphCircle() *charts.Graph { + graph := charts.NewGraph() + graph.SetGlobalOptions( + charts.WithTitleOpts(opts.Title{Title: "Circular layout"}), + ) + + graph.AddSeries("graph", graphNodes, genLinks()). 
+ SetSeriesOptions( + charts.WithGraphChartOpts( + opts.GraphChart{ + Force: &opts.GraphForce{Repulsion: 8000}, + Layout: "circular", + }), + charts.WithLabelOpts(opts.Label{Show: opts.Bool(true), Position: "right"}), + ) + return graph +} + +func graphDigraph(digraph container.DirectedGraph, direction graph.Direction) *charts.Graph { + graph := charts.NewGraph() + graph.SetGlobalOptions( + charts.WithTitleOpts(opts.Title{ + Title: "demo", + })) + + var ( + nodes []opts.GraphNode + links []opts.GraphLink + ) + + digraph.EachNode(func(node uint64) bool { + sourceNode := strconv.FormatUint(node, 10) + + nodes = append(nodes, opts.GraphNode{ + Name: sourceNode, + }) + + digraph.EachAdjacentNode(node, direction, func(adjacent uint64) bool { + links = append(links, opts.GraphLink{ + Source: sourceNode, + Target: strconv.FormatUint(adjacent, 10), + }) + + return true + }) + + return true + }) + + graph.AddSeries("graph", nodes, links). + SetSeriesOptions( + charts.WithGraphChartOpts(opts.GraphChart{ + Force: &opts.GraphForce{Repulsion: 8000}, + }), + ) + return graph +} + +func doTheGraph(digraph container.DirectedGraph, direction graph.Direction) { + page := components.NewPage() + page.AddCharts( + graphDigraph(digraph, direction), + ) + + f, err := os.Create("graph.html") + if err != nil { + panic(err) + + } + page.Render(io.MultiWriter(f)) + f.Close() +} diff --git a/cmd/viz/main.go b/cmd/viz/main.go new file mode 100644 index 0000000..9c0086b --- /dev/null +++ b/cmd/viz/main.go @@ -0,0 +1,43 @@ +package main + +import ( + "context" + "fmt" + + "github.com/specterops/dawgs" + "github.com/specterops/dawgs/container" + "github.com/specterops/dawgs/database/pg" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/query" +) + +func main() { + ctx, done := context.WithCancel(context.Background()) + defer done() + + dbInst, err := dawgs.Open(ctx, pg.DriverName, dawgs.Config{ + ConnectionString: "user=postgres dbname=bhe password=bhe4eva host=localhost", + }) 
+ + if err != nil { + panic(err) + } + + metaKinds := graph.Kinds{ + graph.StringKind("Meta"), + graph.StringKind("MetaDetail"), + } + + kindFilter := query.And( + query.Not(query.Start().Kinds().HasOneOf(metaKinds)), + query.Not(query.End().Kinds().HasOneOf(metaKinds)), + ) + + if digraph, err := container.FetchAdjacencyGraph(ctx, dbInst, kindFilter); err != nil { + panic(err) + } else { + fmt.Printf("Loaded %d nodes\n", digraph.Nodes().Cardinality()) + + // algo.CalculateKatzCentrality(digraph, 0.01, 1, 0.01, 1000) + } +} diff --git a/container/adjacencymap.go b/container/adjacencymap.go new file mode 100644 index 0000000..1d83303 --- /dev/null +++ b/container/adjacencymap.go @@ -0,0 +1,168 @@ +package container + +import ( + "github.com/specterops/dawgs/cardinality" + "github.com/specterops/dawgs/graph" +) + +type adjacencyMapDigraph struct { + inbound AdjacencyMap + outbound AdjacencyMap + nodes cardinality.Duplex[uint64] +} + +func NewAdjacencyMapGraph() DirectedGraph { + return &adjacencyMapDigraph{ + inbound: AdjacencyMap{}, + outbound: AdjacencyMap{}, + nodes: cardinality.NewBitmap64(), + } +} + +func (s *adjacencyMapDigraph) Normalize() ([]uint64, DirectedGraph) { + var ( + numNodes = s.NumNodes() + idIndex = make(map[uint64]uint64, numNodes) + reverseIDIndex = make([]uint64, numNodes) + newGraph = &adjacencyMapDigraph{ + inbound: AdjacencyMap{}, + outbound: AdjacencyMap{}, + nodes: cardinality.NewBitmap64(), + } + ) + + s.EachNode(func(node uint64) bool { + normalID := uint64(len(idIndex)) + + newGraph.nodes.Add(normalID) + idIndex[node] = normalID + reverseIDIndex[normalID] = node + + return true + }) + + for originID, originAdjacentBitmap := range s.inbound { + var ( + normalID = idIndex[originID] + normalAdjacent = cardinality.NewBitmap64() + ) + + originAdjacentBitmap.Each(func(originAdjacent uint64) bool { + normalAdjacent.Add(idIndex[originAdjacent]) + return true + }) + + newGraph.inbound[normalID] = normalAdjacent + } + + for originID, 
originAdjacentBitmap := range s.outbound { + var ( + normalID = idIndex[originID] + normalAdjacent = cardinality.NewBitmap64() + ) + + originAdjacentBitmap.Each(func(originAdjacent uint64) bool { + normalAdjacent.Add(idIndex[originAdjacent]) + return true + }) + + newGraph.outbound[normalID] = normalAdjacent + } + + return reverseIDIndex, newGraph +} + +func (s *adjacencyMapDigraph) NumNodes() uint64 { + return s.nodes.Cardinality() +} + +func (s *adjacencyMapDigraph) NumEdges() uint64 { + return s.nodes.Cardinality() +} + +func (s *adjacencyMapDigraph) getAdjacent(node uint64, direction graph.Direction) cardinality.Duplex[uint64] { + switch direction { + case graph.DirectionOutbound: + if adjacent, exists := s.outbound[node]; exists { + return adjacent + } + + case graph.DirectionInbound: + if adjacent, exists := s.inbound[node]; exists { + return adjacent + } + + case graph.DirectionBoth: + var ( + outboundAdjacent, hasOutbound = s.outbound[node] + inboundAdjacent, hasInbound = s.inbound[node] + ) + + if hasOutbound { + if hasInbound { + combinedAdjacent := outboundAdjacent.Clone() + combinedAdjacent.Or(inboundAdjacent) + + return combinedAdjacent + } + + return outboundAdjacent + } else if hasInbound { + return inboundAdjacent + } + } + + return nil +} + +func (s *adjacencyMapDigraph) AdjacentNodes(node uint64, direction graph.Direction) []uint64 { + if adjacent := s.getAdjacent(node, direction); adjacent != nil { + return adjacent.Slice() + } + + return nil +} + +func (s *adjacencyMapDigraph) Degrees(node uint64, direction graph.Direction) uint64 { + if adjacent := s.getAdjacent(node, direction); adjacent != nil { + return adjacent.Cardinality() + } + + return 0 +} + +func (s *adjacencyMapDigraph) Nodes() cardinality.Duplex[uint64] { + return s.nodes +} + +func (s *adjacencyMapDigraph) EachNode(delegate func(node uint64) bool) { + s.nodes.Each(delegate) +} + +func (s *adjacencyMapDigraph) EachAdjacentNode(node uint64, direction graph.Direction, delegate 
func(adjacent uint64) bool) { + if adjacent := s.getAdjacent(node, direction); adjacent != nil { + adjacent.Each(delegate) + } +} + +func (s *adjacencyMapDigraph) AddEdge(edge, start, end uint64) { + if edgeBitmap, exists := s.outbound[start]; exists { + edgeBitmap.Add(end) + } else { + edgeBitmap = cardinality.NewBitmap64() + edgeBitmap.Add(end) + + s.outbound[start] = edgeBitmap + } + + if edgeBitmap, exists := s.inbound[end]; exists { + edgeBitmap.Add(start) + } else { + edgeBitmap = cardinality.NewBitmap64() + edgeBitmap.Add(start) + + s.inbound[end] = edgeBitmap + } + + s.nodes.Add(start, end) +} diff --git a/container/bfs.go b/container/bfs.go new file mode 100644 index 0000000..19a095e --- /dev/null +++ b/container/bfs.go @@ -0,0 +1,99 @@ +package container + +import ( + "bufio" + "compress/gzip" + "context" + "os" + + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/util" +) + +type BFSTreeFile struct { + Path string + NumPaths uint64 +} + +func (s BFSTreeFile) Remove() error { + return os.Remove(s.Path) +} + +func (s BFSTreeFile) ReadEach(ctx context.Context, delegate func(next *Segment) (bool, error)) error { + if fin, err := os.Open(s.Path); err != nil { + return err + } else { + defer fin.Close() + + if gzipReader, err := gzip.NewReader(fin); err != nil { + return err + } else { + defer gzipReader.Close() + + scanner := bufio.NewScanner(fin) + scanner.Split(bufio.ScanLines) + + for scanner.Scan() { + if err := scanner.Err(); err != nil { + return err + } + + if shouldContinue, err := delegate(UnmarshalSegment(scanner.Bytes())); err != nil { + return err + } else if !shouldContinue { + break + } + } + } + } + + return nil +} + +func WriteZoneBFSTree(zoneNodes graph.NodeSet, ts Triplestore, scratchPath string, maxDepth int) (BFSTreeFile, error) { + var ( + numPaths = uint64(0) + zoneNodeMembership = zoneNodes.IDBitmap() + ) + + defer util.SLogMeasure("WriteZoneBFSTree")() + + if scratchFile, err := os.CreateTemp(scratchPath, "zcp"); err 
!= nil { + return BFSTreeFile{}, err + } else { + scratchFileWriter := gzip.NewWriter(scratchFile) + + defer scratchFile.Close() + defer scratchFileWriter.Close() + + for zoneNodeID, _ := range zoneNodes { + TSBFS( + ts, + zoneNodeID.Uint64(), + graph.DirectionInbound, + maxDepth, + func(edge Edge) bool { + return !zoneNodeMembership.Contains(edge.Start) + }, + func(segment *Segment) bool { + numPaths += 1 + + if err := MarshalSegment(segment, scratchFileWriter); err != nil { + panic(err) + } + + if _, err := scratchFileWriter.Write([]byte("\n")); err != nil { + panic(err) + } + + return true + }, + ) + } + + return BFSTreeFile{ + Path: scratchFile.Name(), + NumPaths: numPaths, + }, nil + } +} diff --git a/container/digraph.go b/container/digraph.go index 58a61c7..78cadc3 100644 --- a/container/digraph.go +++ b/container/digraph.go @@ -1,6 +1,9 @@ package container import ( + "fmt" + + "github.com/gammazero/deque" "github.com/specterops/dawgs/cardinality" "github.com/specterops/dawgs/graph" ) @@ -8,31 +11,51 @@ import ( type Weight = float64 type AdjacencyMap map[uint64]cardinality.Duplex[uint64] -type DirectedGraph interface { - AddEdge(start, end graph.ID) - NumNodes() uint64 - Nodes() cardinality.Duplex[uint64] - Degrees(node uint64, direction graph.Direction) uint64 - Adjacent(node uint64, direction graph.Direction) []uint64 - EachNode(delegate func(node uint64) bool) - EachAdjacent(node uint64, direction graph.Direction, delegate func(adjacent uint64) bool) - BFSTree(nodeID uint64, direction graph.Direction) []ShortestPathTerminal - Normalize() ([]uint64, DirectedGraph) - Dimensions(direction graph.Direction) (uint64, uint64) +type KindMap map[graph.Kind]cardinality.Duplex[uint64] + +func (s KindMap) Add(kind graph.Kind, member uint64) { + if members, hasMembers := s[kind]; hasMembers { + members.Add(member) + } else { + s[kind] = cardinality.NewBitmap64With(member) + } } -type directedGraph struct { - inbound AdjacencyMap - outbound AdjacencyMap - nodes 
cardinality.Duplex[uint64] +func (s KindMap) FindFirst(id uint64) graph.Kind { + for kind, membership := range s { + if membership.Contains(id) { + return kind + } + } + + panic(fmt.Sprintf("Can't find kind for edge ID %d", id)) + + return nil } -func NewDirectedGraph() DirectedGraph { - return &directedGraph{ - inbound: AdjacencyMap{}, - outbound: AdjacencyMap{}, - nodes: cardinality.NewBitmap64(), +func (s KindMap) FindAll(id uint64) graph.Kinds { + var matchedKinds graph.Kinds + + for kind, membership := range s { + if membership.Contains(id) { + matchedKinds = matchedKinds.Add(kind) + } } + + return matchedKinds +} + +type KindDatabase struct { + EdgeKindMap KindMap + NodeKindMap KindMap +} + +func (s KindDatabase) NodeKind(nodeID uint64) graph.Kinds { + return s.NodeKindMap.FindAll(nodeID) +} + +func (s KindDatabase) EdgeKind(edgeID uint64) graph.Kind { + return s.EdgeKindMap.FindFirst(edgeID) } type ShortestPathTerminal struct { @@ -40,41 +63,46 @@ type ShortestPathTerminal struct { Distance Weight } -func (s *directedGraph) Dimensions(direction graph.Direction) (uint64, uint64) { - var largestRow uint64 = 0 +type DirectedGraph interface { + AddEdge(edge, start, end uint64) + NumNodes() uint64 + Nodes() cardinality.Duplex[uint64] + EachNode(delegate func(node uint64) bool) + Degrees(node uint64, direction graph.Direction) uint64 + AdjacentNodes(node uint64, direction graph.Direction) []uint64 + EachAdjacentNode(node uint64, direction graph.Direction, delegate func(adjacent uint64) bool) +} - s.EachNode(func(node uint64) bool { - if adjacent := s.getAdjacent(node, direction); adjacent != nil { - count := adjacent.Cardinality() +func Dimensions(digraph DirectedGraph, direction graph.Direction) (uint64, uint64) { + var largestRow uint64 = 0 - if count > largestRow { - largestRow = count - } + digraph.EachNode(func(node uint64) bool { + if degrees := digraph.Degrees(node, direction); degrees > largestRow { + largestRow = degrees } return true }) - return 
s.Nodes().Cardinality(), largestRow + return digraph.Nodes().Cardinality(), largestRow } -func (s *directedGraph) BFSTree(nodeID uint64, direction graph.Direction) []ShortestPathTerminal { +func BFSTree(digraph DirectedGraph, nodeID uint64, direction graph.Direction) []ShortestPathTerminal { var ( visited = cardinality.NewBitmap64() - stack []ShortestPathTerminal + queue deque.Deque[ShortestPathTerminal] terminals []ShortestPathTerminal ) - stack = append(stack, ShortestPathTerminal{ + queue.PushBack(ShortestPathTerminal{ NodeID: nodeID, Distance: 0, }) - for len(stack) > 0 { - nextCursor := stack[len(stack)-1] - stack = stack[:len(stack)-1] + for queue.Len() > 0 { + nextCursor := queue.PopFront() - s.EachAdjacent(nextCursor.NodeID, direction, func(adjacentNodeID uint64) bool { + digraph.EachAdjacentNode(nextCursor.NodeID, direction, func(adjacentNodeID uint64) bool { if visited.CheckedAdd(adjacentNodeID) { terminalCursor := ShortestPathTerminal{ NodeID: adjacentNodeID, @@ -82,7 +110,7 @@ func (s *directedGraph) BFSTree(nodeID uint64, direction graph.Direction) []Shor } // If not visited, descend into this node next - stack = append(stack, terminalCursor) + queue.PushBack(terminalCursor) // This reachable node represents one of the shortest path terminals terminals = append(terminals, terminalCursor) @@ -94,156 +122,3 @@ func (s *directedGraph) BFSTree(nodeID uint64, direction graph.Direction) []Shor return terminals } - -func (s *directedGraph) Normalize() ([]uint64, DirectedGraph) { - var ( - numNodes = s.NumNodes() - idIndex = make(map[uint64]uint64, numNodes) - reverseIDIndex = make([]uint64, numNodes) - newGraph = &directedGraph{ - inbound: AdjacencyMap{}, - outbound: AdjacencyMap{}, - nodes: cardinality.NewBitmap64(), - } - ) - - s.EachNode(func(node uint64) bool { - normalID := uint64(len(idIndex)) - - newGraph.nodes.Add(normalID) - idIndex[node] = normalID - reverseIDIndex[normalID] = node - - return true - }) - - for originID, originAdjacentBitmap := range 
s.inbound { - var ( - normalID = idIndex[originID] - normalAdjacent = cardinality.NewBitmap64() - ) - - originAdjacentBitmap.Each(func(originAdjacent uint64) bool { - normalAdjacent.Add(idIndex[originAdjacent]) - return true - }) - - newGraph.inbound[normalID] = normalAdjacent - } - - for originID, originAdjacentBitmap := range s.outbound { - var ( - normalID = idIndex[originID] - normalAdjacent = cardinality.NewBitmap64() - ) - - originAdjacentBitmap.Each(func(originAdjacent uint64) bool { - normalAdjacent.Add(idIndex[originAdjacent]) - return true - }) - - newGraph.outbound[normalID] = normalAdjacent - } - - return reverseIDIndex, newGraph -} - -func (s *directedGraph) NumNodes() uint64 { - return s.nodes.Cardinality() -} - -func (s *directedGraph) NumEdges() uint64 { - return s.nodes.Cardinality() -} - -func (s *directedGraph) getAdjacent(node uint64, direction graph.Direction) cardinality.Duplex[uint64] { - switch direction { - case graph.DirectionOutbound: - if adjacent, exists := s.outbound[node]; exists { - return adjacent - } - - case graph.DirectionInbound: - if adjacent, exists := s.inbound[node]; exists { - return adjacent - } - - case graph.DirectionBoth: - var ( - outboundAdjacent, hasOutbound = s.outbound[node] - inboundAdjacent, hasInbound = s.inbound[node] - ) - - if hasOutbound { - if hasInbound { - combinedAdjacent := outboundAdjacent.Clone() - combinedAdjacent.Or(inboundAdjacent) - - return combinedAdjacent - } - - return outboundAdjacent - } else if hasInbound { - return inboundAdjacent - } - } - - return nil -} - -func (s *directedGraph) Adjacent(node uint64, direction graph.Direction) []uint64 { - if adjacent := s.getAdjacent(node, direction); adjacent != nil { - return adjacent.Slice() - } - - return nil -} - -func (s *directedGraph) Degrees(node uint64, direction graph.Direction) uint64 { - if adjacent := s.getAdjacent(node, direction); adjacent != nil { - return adjacent.Cardinality() - } - - return 0 -} - -func (s *directedGraph) Nodes() 
cardinality.Duplex[uint64] { - return s.nodes -} - -func (s *directedGraph) EachNode(delegate func(node uint64) bool) { - s.nodes.Each(delegate) -} - -func (s *directedGraph) EachAdjacent(node uint64, direction graph.Direction, delegate func(adjacent uint64) bool) { - if adjacent := s.getAdjacent(node, direction); adjacent != nil { - adjacent.Each(delegate) - } -} - -func (s *directedGraph) AddEdge(start, end graph.ID) { - var ( - startUint64 = start.Uint64() - endUint64 = end.Uint64() - ) - - if edgeBitmap, exists := s.outbound[startUint64]; exists { - edgeBitmap.Add(endUint64) - } else { - edgeBitmap = cardinality.NewBitmap64() - edgeBitmap.Add(endUint64) - - s.outbound[startUint64] = edgeBitmap - } - - if edgeBitmap, exists := s.inbound[endUint64]; exists { - edgeBitmap.Add(startUint64) - } else { - edgeBitmap = cardinality.NewBitmap64() - edgeBitmap.Add(startUint64) - - s.inbound[endUint64] = edgeBitmap - } - - s.nodes.Add(startUint64, endUint64) -} diff --git a/container/fetch.go b/container/fetch.go index ed1c1a1..fbb2b32 100644 --- a/container/fetch.go +++ b/container/fetch.go @@ -2,62 +2,151 @@ package container import ( "context" - "sync" + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/database" "github.com/specterops/dawgs/graph" "github.com/specterops/dawgs/query" + "github.com/specterops/dawgs/util" ) -func FetchDirectedGraph(ctx context.Context, graphDB graph.Database, relationshipFilter graph.Criteria) (DirectedGraph, error) { - type anonymousEdge struct { - StartID graph.ID - EndID graph.ID - } +const ( + channelBufferLen = 4096 +) - var ( - edgeC = make(chan anonymousEdge, 4096) - mergeWG = &sync.WaitGroup{} - digraph = NewDirectedGraph() - ) +type anonymousEdge struct { + EdgeID uint64 + StartID uint64 + EndID uint64 +} - mergeWG.Add(1) +func FetchAdjacencyGraph(ctx context.Context, graphDB database.Instance, relationshipFilter cypher.SyntaxNode) (DirectedGraph, error) { + digraph := NewAdjacencyMapGraph() - go 
func() { - defer mergeWG.Done() + return digraph, graphDB.Session(ctx, func(ctx context.Context, driver database.Driver) error { + builder := query.New() - for edge := range edgeC { - digraph.AddEdge(edge.StartID, edge.EndID) + if relationshipFilter != nil { + builder.Where(relationshipFilter) } - }() - if err := graphDB.ReadTransaction(ctx, func(tx graph.Transaction) error { - return tx.Relationships().Filter(relationshipFilter).Query(func(results graph.Result) error { - for results.Next() { + builder.Return( + query.Relationship().ID(), + query.Start().ID(), + query.End().ID(), + ) + + if preparedQuery, err := builder.Build(); err != nil { + return err + } else { + result := driver.Exec(ctx, preparedQuery.Query, preparedQuery.Parameters) + defer result.Close(ctx) + + for result.HasNext(ctx) { var ( - startID graph.ID - endID graph.ID + edgeID uint64 + startID uint64 + endID uint64 ) - if err := results.Scan(&startID, &endID); err != nil { + if err := result.Scan(&edgeID, &startID, &endID); err != nil { return err } - edgeC <- anonymousEdge{ - StartID: startID, - EndID: endID, + digraph.AddEdge(edgeID, startID, endID) + } + + return result.Error() + } + }) +} + +func FetchKindDatabase(ctx context.Context, graphDB database.Instance) (KindDatabase, error) { + defer util.SLogMeasure("FetchKindDatabase")() + + edgeKinds := KindMap{} + + if err := graphDB.Session(ctx, func(ctx context.Context, driver database.Driver) error { + builder := query.New() + builder.Return(query.Relationship().ID(), query.Relationship().Kind()) + + if builtQuery, err := builder.Build(); err != nil { + return err + } else { + var ( + result = driver.Exec(ctx, builtQuery.Query, builtQuery.Parameters) + edgeID uint64 + kind graph.Kind + ) + + for result.HasNext(ctx) { + if err := result.Scan(&edgeID, &kind); err != nil { + return err } + + edgeKinds.Add(kind, edgeID) } - return results.Error() - }, query.Returning(query.StartID(), query.EndID())) + result.Close(ctx) + return result.Error() + } 
}); err != nil { - // Ensure that the edge channel is closed to prevent goroutine leaks - close(edgeC) - return nil, err + return KindDatabase{}, err + } + + return KindDatabase{ + EdgeKindMap: edgeKinds, + }, nil +} + +type TSDB struct { + Triplestore Triplestore + EdgeKinds KindMap +} + +func FetchTriplestore(ctx context.Context, graphDB database.Instance, filter cypher.SyntaxNode) (TSDB, error) { + tsdb := TSDB{ + Triplestore: NewTriplestore(), + EdgeKinds: KindMap{}, } - close(edgeC) - mergeWG.Wait() + defer util.SLogMeasure("FetchTriplestore")() + + return tsdb, graphDB.Session(ctx, func(ctx context.Context, driver database.Driver) error { + query := query.New().Return( + query.Start().ID(), + query.Relationship().ID(), + query.Relationship().Kind(), + query.End().ID(), + ) + + if filter != nil { + query.Where(filter) + } + + if builtQuery, err := query.Build(); err != nil { + return err + } else { + result := driver.Exec(ctx, builtQuery.Query, builtQuery.Parameters) + defer result.Close(ctx) - return digraph, nil + for result.HasNext(ctx) { + var ( + startID uint64 + relationshipID uint64 + relationshipKind graph.Kind + endID uint64 + ) + + if err := result.Scan(&startID, &relationshipID, &relationshipKind, &endID); err != nil { + return err + } + + tsdb.Triplestore.AddEdge(relationshipID, startID, endID) + tsdb.EdgeKinds.Add(relationshipKind, relationshipID) + } + + return result.Error() + } + }) } diff --git a/container/gml.go b/container/gml.go new file mode 100644 index 0000000..8179d8e --- /dev/null +++ b/container/gml.go @@ -0,0 +1,46 @@ +package container + +import ( + "os" + "strconv" + + "github.com/specterops/dawgs/graph" +) + +func WriteDigraphToGML(digraph DirectedGraph, path string) error { + if out, err := os.Create(path); err != nil { + return err + } else { + defer out.Close() + + out.WriteString("graph [\n") + out.WriteString("directed 1\n") + + digraph.EachNode(func(node uint64) bool { + out.WriteString("node [\n") + out.WriteString("id ") + 
out.WriteString(strconv.FormatUint(node, 10)) + out.WriteString("\n]\n") + + return true + }) + + digraph.EachNode(func(node uint64) bool { + digraph.EachAdjacentNode(node, graph.DirectionBoth, func(adjacent uint64) bool { + out.WriteString("edge [\n") + out.WriteString("source ") + out.WriteString(strconv.FormatUint(node, 10)) + out.WriteString("\ntarget ") + out.WriteString(strconv.FormatUint(adjacent, 10)) + out.WriteString("\n]\n") + + return true + }) + + return true + }) + + out.WriteString("]\n") + return nil + } +} diff --git a/container/segment.go b/container/segment.go new file mode 100644 index 0000000..5047cc1 --- /dev/null +++ b/container/segment.go @@ -0,0 +1,168 @@ +package container + +import ( + "encoding/binary" + "io" + "strconv" + "strings" +) + +func MarshalSegment(segment *Segment, writer io.Writer) error { + nextBytes := make([]byte, 8) + + for cursor := segment; cursor != nil; cursor = cursor.Previous { + binary.LittleEndian.PutUint64(nextBytes, cursor.Node) + + if _, err := writer.Write(nextBytes); err != nil { + return err + } + + if cursor.Previous != nil { + binary.LittleEndian.PutUint64(nextBytes, cursor.Edge) + + if _, err := writer.Write(nextBytes); err != nil { + return err + } + } + } + + return nil +} + +func UnmarshalSegment(segmentBytes []byte) *Segment { + var ( + nextSegment *Segment + terminalSegment = &Segment{} + ) + + for byteNum := 0; byteNum*8 < len(segmentBytes); byteNum += 1 { + var ( + startIdx = byteNum * 8 + nextID = binary.LittleEndian.Uint64(segmentBytes[startIdx : startIdx+8]) + ) + + if nextSegment == nil { + nextSegment = terminalSegment + nextSegment.Node = nextID + } else if nextSegment.Previous == nil { + nextSegment.Edge = nextID + nextSegment.Previous = &Segment{} + } else { + nextSegment = nextSegment.Previous + nextSegment.Node = nextID + } + } + + return terminalSegment +} + +type Segment struct { + Node uint64 + Edge uint64 + Previous *Segment +} + +func (s *Segment) Each(visitor func(cursor *Segment) 
bool) { + for cursor := s; cursor != nil; cursor = cursor.Previous { + if !visitor(cursor) { + break + } + } +} + +func (s *Segment) Root() uint64 { + var root uint64 + + for cursor := s; cursor != nil; cursor = cursor.Previous { + root = cursor.Node + } + + return root +} + +// Trunk returns the edge closest to the root node of the segment +func (s *Segment) Trunk() uint64 { + var trunk uint64 + + for cursor := s; cursor != nil; cursor = cursor.Previous { + if cursor.Previous != nil { + trunk = cursor.Edge + } + } + + return trunk +} + +func (s *Segment) Edges() []uint64 { + var edges []uint64 + + for cursor := s; cursor != nil; cursor = cursor.Previous { + if cursor.Previous != nil { + edges = append(edges, cursor.Edge) + } + } + + return edges +} + +func (s *Segment) Nodes() []uint64 { + var nodes []uint64 + + for cursor := s; cursor != nil; cursor = cursor.Previous { + nodes = append(nodes, cursor.Node) + } + + return nodes +} + +func (s *Segment) Depth() int { + depth := 0 + + for cursor := s; cursor != nil; cursor = cursor.Previous { + depth += 1 + } + + return depth +} + +type SerializedSegment struct { + Nodes []uint64 + Edges []uint64 +} + +func (s SerializedSegment) ToSegment() *Segment { + cursor := &Segment{} + + for nodeIndex := 0; nodeIndex < len(s.Nodes); nodeIndex += 1 { + cursor.Node = s.Nodes[nodeIndex] + + if nodeIndex < len(s.Edges) { + nextSegment := &Segment{ + Edge: s.Edges[nodeIndex-1], + Previous: cursor, + } + + cursor = nextSegment + } + } + + return cursor +} + +func (s *Segment) Format() string { + var builder strings.Builder + + for cursor := s; cursor != nil; cursor = cursor.Previous { + builder.WriteString("(") + builder.WriteString(strconv.FormatUint(cursor.Node, 10)) + builder.WriteString(")") + + if cursor.Previous != nil { + builder.WriteString("-[") + builder.WriteString(strconv.FormatUint(cursor.Edge, 10)) + builder.WriteString("]->") + } + } + + return builder.String() +} diff --git a/container/segment_test.go 
b/container/segment_test.go new file mode 100644 index 0000000..25557b2 --- /dev/null +++ b/container/segment_test.go @@ -0,0 +1,55 @@ +package container_test + +import ( + "bytes" + "testing" + + "github.com/specterops/dawgs/container" +) + +func Test_MarshalSegment(t *testing.T) { + var ( + buffer = &bytes.Buffer{} + segment = &container.Segment{ + Node: 5, + Edge: 14, + Previous: &container.Segment{ + Node: 4, + Edge: 13, + Previous: &container.Segment{ + Node: 3, + Edge: 12, + Previous: &container.Segment{ + Node: 2, + Edge: 11, + Previous: &container.Segment{ + Node: 1, + }, + }, + }, + }, + } + ) + + container.MarshalSegment(segment, buffer) + + readSegment := container.UnmarshalSegment(buffer.Bytes()) + + segment.Each(func(cursor *container.Segment) bool { + if readSegment.Node != cursor.Node { + t.Fatalf("Segments do not match - marshaled node %d - origin node %d", readSegment.Node, cursor.Node) + } + + if readSegment.Previous == nil { + if cursor.Previous != nil { + t.Fatal("Segments do not match") + } + } else if readSegment.Edge != cursor.Edge { + t.Fatal("Segments do not match") + } else { + readSegment = readSegment.Previous + } + + return true + }) +} diff --git a/container/triplestore.go b/container/triplestore.go new file mode 100644 index 0000000..f350fc4 --- /dev/null +++ b/container/triplestore.go @@ -0,0 +1,228 @@ +package container + +import ( + "github.com/gammazero/deque" + "github.com/specterops/dawgs/cardinality" + "github.com/specterops/dawgs/graph" +) + +type Edge struct { + ID uint64 + Start uint64 + End uint64 +} + +func (s Edge) Pick(direction graph.Direction) uint64 { + if direction == graph.DirectionOutbound { + return s.End + } + + return s.Start +} + +type Triplestore interface { + DirectedGraph + + NumEdges() uint64 + AdjacentEdges(node uint64, direction graph.Direction) []uint64 + EachAdjacentEdge(node uint64, direction graph.Direction, delegate func(next Edge) bool) +} + +type triplestore struct { + nodes 
cardinality.Duplex[uint64] + edges []Edge + deletedEdges cardinality.Duplex[uint64] + startIndex map[uint64]cardinality.Duplex[uint64] + endIndex map[uint64]cardinality.Duplex[uint64] +} + +func NewTriplestore() Triplestore { + return &triplestore{ + nodes: cardinality.NewBitmap64(), + deletedEdges: cardinality.NewBitmap64(), + startIndex: map[uint64]cardinality.Duplex[uint64]{}, + endIndex: map[uint64]cardinality.Duplex[uint64]{}, + } +} + +func (s *triplestore) DeleteEdge(id uint64) { + s.deletedEdges.Add(id) +} + +func (s *triplestore) NumNodes() uint64 { + return s.nodes.Cardinality() +} + +func (s *triplestore) Nodes() cardinality.Duplex[uint64] { + return s.nodes.Clone() +} + +func (s *triplestore) EachNode(delegate func(node uint64) bool) { + s.nodes.Each(delegate) +} + +func (s *triplestore) AddEdge(edge, start, end uint64) { + s.edges = append(s.edges, Edge{ + ID: edge, + Start: start, + End: end, + }) + + edgeIdx := len(s.edges) - 1 + + if edgeBitmap, exists := s.startIndex[start]; exists { + edgeBitmap.Add(uint64(edgeIdx)) + } else { + edgeBitmap = cardinality.NewBitmap64() + edgeBitmap.Add(uint64(edgeIdx)) + + s.startIndex[start] = edgeBitmap + } + + if edgeBitmap, exists := s.endIndex[end]; exists { + edgeBitmap.Add(uint64(edgeIdx)) + } else { + edgeBitmap = cardinality.NewBitmap64() + edgeBitmap.Add(uint64(edgeIdx)) + + s.endIndex[end] = edgeBitmap + } + + s.nodes.Add(start, end) +} + +func (s *triplestore) adjacentEdgeIndices(node uint64, direction graph.Direction) cardinality.Duplex[uint64] { + edgeIndices := cardinality.NewBitmap64() + + switch direction { + case graph.DirectionOutbound: + if outboundEdges, hasOutbound := s.startIndex[node]; hasOutbound { + edgeIndices.Or(outboundEdges) + } + + case graph.DirectionInbound: + if inboundEdges, hasInbound := s.endIndex[node]; hasInbound { + edgeIndices.Or(inboundEdges) + } + + default: + if outboundEdges, hasOutbound := s.startIndex[node]; hasOutbound { + edgeIndices.Or(outboundEdges) + } + + if 
inboundEdges, hasInbound := s.endIndex[node]; hasInbound { + edgeIndices.Or(inboundEdges) + } + } + + return edgeIndices +} + +func (s *triplestore) AdjacentEdges(node uint64, direction graph.Direction) []uint64 { + var ( + edgeIndices = s.adjacentEdgeIndices(node, direction) + edgeIDs = make([]uint64, 0, edgeIndices.Cardinality()) + ) + + edgeIndices.Each(func(value uint64) bool { + edgeIDs = append(edgeIDs, s.edges[value].ID) + return true + }) + + return edgeIDs +} + +func (s *triplestore) adjacent(node uint64, direction graph.Direction) cardinality.Duplex[uint64] { + nodes := cardinality.NewBitmap64() + + s.adjacentEdgeIndices(node, direction).Each(func(edgeIndex uint64) bool { + if edge := s.edges[edgeIndex]; !s.deletedEdges.Contains(edge.ID) { + switch direction { + case graph.DirectionOutbound: + nodes.Add(edge.End) + + case graph.DirectionInbound: + nodes.Add(edge.Start) + + default: + nodes.Add(edge.End) + nodes.Add(edge.Start) + } + } + + return true + }) + + return nodes +} + +func (s *triplestore) AdjacentNodes(node uint64, direction graph.Direction) []uint64 { + return s.adjacent(node, direction).Slice() +} + +func (s *triplestore) EachAdjacentNode(node uint64, direction graph.Direction, delegate func(adjacent uint64) bool) { + s.adjacent(node, direction).Each(delegate) +} + +func (s *triplestore) Degrees(node uint64, direction graph.Direction) uint64 { + if adjacent := s.adjacent(node, direction); adjacent != nil { + return adjacent.Cardinality() + } + + return 0 +} + +func (s *triplestore) NumEdges() uint64 { + return uint64(len(s.edges)) +} + +func (s *triplestore) EachAdjacentEdge(node uint64, direction graph.Direction, delegate func(next Edge) bool) { + s.adjacentEdgeIndices(node, direction).Each(func(edgeIndex uint64) bool { + return delegate(s.edges[edgeIndex]) + }) +} + +func TSBFS(ts Triplestore, nodeID uint64, direction graph.Direction, maxDepth int, descentFilter func(edge Edge) bool, handler func(segment *Segment) bool) int { + var ( + 
traversals deque.Deque[*Segment] + numImcompletePaths = 0 + ) + + traversals.PushBack(&Segment{ + Node: nodeID, + }) + + for remainingTraversals := traversals.Len(); remainingTraversals > 0; remainingTraversals = traversals.Len() { + var ( + nextSegment = traversals.PopFront() + segmentDepth = nextSegment.Depth() + depthExceded = maxDepth > 0 && maxDepth < segmentDepth + ) + + if !depthExceded { + ts.EachAdjacentEdge(nextSegment.Node, direction, func(nextEdge Edge) bool { + if descentFilter(nextEdge) { + traversals.PushBack(&Segment{ + Node: nextEdge.Pick(direction), + Edge: nextEdge.ID, + Previous: nextSegment, + }) + } + + return true + }) + } + + if segmentDepth > 1 && remainingTraversals-1 == traversals.Len() { + if depthExceded { + numImcompletePaths += 1 + } + + if !handler(nextSegment) { + break + } + } + } + + return numImcompletePaths +} diff --git a/database/driver.go b/database/driver.go index b05d0bd..0cc967e 100644 --- a/database/driver.go +++ b/database/driver.go @@ -30,12 +30,25 @@ type Result interface { type Driver interface { WithGraph(target Graph) Driver + // Eventually to be deprecated - this is currently a translation gap where cysql doesn't correctly + // marshal create statements + // + // Deprecated: This function will be removed in future version. CreateNode(ctx context.Context, node *graph.Node) (graph.ID, error) + + // Eventually to be deprecated - this is currently a translation gap where cysql doesn't correctly + // marshal create statements + // + // Deprecated: This function will be removed in future version. 
CreateRelationship(ctx context.Context, relationship *graph.Relationship) (graph.ID, error) Exec(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) Result Explain(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) Result Profile(ctx context.Context, query *cypher.RegularQuery, parameters map[string]any) Result + + // Mapper is supporting backward compat for v1 + // + // Deprecated: This function will be removed in future version. Mapper() graph.ValueMapper } diff --git a/database/v1compat/switch.go b/database/v1compat/switch.go index 8f4e127..9af81d9 100644 --- a/database/v1compat/switch.go +++ b/database/v1compat/switch.go @@ -4,6 +4,8 @@ import ( "context" "errors" "sync" + + "github.com/specterops/dawgs/database" ) var ( @@ -51,6 +53,13 @@ func NewDatabaseSwitch(ctx context.Context, initialDB Database) *DatabaseSwitch } } +func (s *DatabaseSwitch) V2() database.Instance { + s.currentDBLock.RLock() + defer s.currentDBLock.RUnlock() + + return s.currentDB.V2() +} + func (s *DatabaseSwitch) SetDefaultGraph(ctx context.Context, graphSchema Graph) error { s.currentDBLock.RLock() defer s.currentDBLock.RUnlock() diff --git a/go.mod b/go.mod index de0f1cd..d5f3c66 100644 --- a/go.mod +++ b/go.mod @@ -10,10 +10,13 @@ require ( github.com/bits-and-blooms/bitset v1.24.4 github.com/cespare/xxhash/v2 v2.3.0 github.com/gammazero/deque v1.2.0 + github.com/go-echarts/go-echarts/v2 v2.6.7 github.com/jackc/pgtype v1.14.4 github.com/jackc/pgx/v5 v5.7.6 github.com/neo4j/neo4j-go-driver/v5 v5.28.4 github.com/stretchr/testify v1.11.1 + golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 + github.com/specterops/dawgs v0.3.1 ) require ( @@ -29,9 +32,6 @@ require ( github.com/mschoch/smat v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.46.0 // indirect - golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 // indirect - golang.org/x/net v0.47.0 // indirect - golang.org/x/oauth2 v0.32.0 // 
indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/text v0.32.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index b452d97..9417085 100644 --- a/go.sum +++ b/go.sum @@ -1,27 +1,19 @@ -cuelabs.dev/go/oci/ociregistry v0.0.0-20250304105642-27e071d2c9b1 h1:Dmbd5Q+ENb2C6carvwrMsrOUwJ9X9qfL5JdW32gYAHo= -cuelabs.dev/go/oci/ociregistry v0.0.0-20250304105642-27e071d2c9b1/go.mod h1:dqrnoZx62xbOZr11giMPrWbhlaV8euHwciXZEy3baT8= -cuelang.org/go v0.13.2 h1:SagzeEASX4E2FQnRbItsqa33sSelrJjQByLqH9uZCE8= -cuelang.org/go v0.13.2/go.mod h1:8MoQXu+RcXsa2s9mebJN1HJ1orVDc9aI9/yKi6Dzsi4= +cuelabs.dev/go/oci/ociregistry v0.0.0-20250722084951-074d06050084 h1:4k1yAtPvZJZQTu8DRY8muBo0LHv6TqtrE0AO5n6IPYs= +cuelabs.dev/go/oci/ociregistry v0.0.0-20250722084951-074d06050084/go.mod h1:4WWeZNxUO1vRoZWAHIG0KZOd6dA25ypyWuwD3ti0Tdc= cuelang.org/go v0.15.1 h1:MRnjc/KJE+K42rnJ3a+425f1jqXeOOgq9SK4tYRTtWw= cuelang.org/go v0.15.1/go.mod h1:NYw6n4akZcTjA7QQwJ1/gqWrrhsN4aZwhcAL0jv9rZE= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= -github.com/RoaringBitmap/roaring/v2 v2.10.0 h1:HbJ8Cs71lfCJyvmSptxeMX2PtvOC8yonlU0GQcy2Ak0= -github.com/RoaringBitmap/roaring/v2 v2.10.0/go.mod h1:FiJcsfkGje/nZBZgCu0ZxCPOKD/hVXDS2dXi7/eUFE0= github.com/RoaringBitmap/roaring/v2 v2.14.4 h1:4aKySrrg9G/5oRtJ3TrZLObVqxgQ9f1znCRBwEwjuVw= github.com/RoaringBitmap/roaring/v2 v2.14.4/go.mod h1:oMvV6omPWr+2ifRdeZvVJyaz+aoEUopyv5iH0u/+wbY= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/axiomhq/hyperloglog v0.2.5 h1:Hefy3i8nAs8zAI/tDp+wE7N+Ltr8JnwiW3875pvl0N8= github.com/axiomhq/hyperloglog v0.2.5/go.mod h1:DLUK9yIzpU5B6YFLjxTIcbHu1g4Y1WQb1m5RH3radaM= -github.com/bits-and-blooms/bitset v1.12.0/go.mod 
h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= -github.com/bits-and-blooms/bitset v1.24.0 h1:H4x4TuulnokZKvHLfzVRTHJfFfnHEeSYJizujEZvmAM= -github.com/bits-and-blooms/bitset v1.24.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/bits-and-blooms/bitset v1.24.4 h1:95H15Og1clikBrKr/DuzMXkQzECs1M6hhoGXLwLQOZE= github.com/bits-and-blooms/bitset v1.24.4/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg= github.com/cockroachdb/apd/v3 v3.2.1/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc= @@ -33,12 +25,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33 h1:ucRHb6/lvW/+mTEIGbvhcYU3S8+uSNkuMjx/qZFfhtM= github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= -github.com/emicklei/proto v1.14.0 h1:WYxC0OrBuuC+FUCTZvb8+fzEHdZMwLEF+OnVfZA3LXU= -github.com/emicklei/proto v1.14.0/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A= -github.com/gammazero/deque v1.0.0 h1:LTmimT8H7bXkkCy6gZX7zNLtkbz4NdS2z8LZuor3j34= -github.com/gammazero/deque v1.0.0/go.mod h1:iflpYvtGfM3U8S8j+sZEKIak3SAKYpA5/SQewgfXDKo= +github.com/emicklei/proto v1.14.2 h1:wJPxPy2Xifja9cEMrcA/g08art5+7CGJNFNk35iXC1I= +github.com/emicklei/proto v1.14.2/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A= github.com/gammazero/deque v1.2.0 h1:scEFO8Uidhw6KDU5qg1HA5fYwM0+us2qdeJqm43bitU= github.com/gammazero/deque 
v1.2.0/go.mod h1:JVrR+Bj1NMQbPnYclvDlvSX0nVGReLrQZ0aUMuWLctg= +github.com/go-echarts/go-echarts/v2 v2.6.7 h1:J9Y6/vVn06BBSGeoowPbdUWsxzHktwqF1uwOuSEUyTY= +github.com/go-echarts/go-echarts/v2 v2.6.7/go.mod h1:Z+spPygZRIEyqod69r0WMnkN5RV3MwhYDtw601w3G8w= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= @@ -97,19 +89,14 @@ github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQ github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/pgx/v4 v4.18.2 h1:xVpYkNR5pk5bMCZGfClbO962UIqVABcAGt7ha1s/FeU= github.com/jackc/pgx/v4 v4.18.2/go.mod h1:Ey4Oru5tH5sB6tV7hDmfWFahwF15Eb7DNXlRKx2CkVw= -github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs= -github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/kamstrup/intmap v0.5.1 h1:ENGAowczZA+PJPYYlreoqJvWgQVtAmX1l899WfYFVK0= -github.com/kamstrup/intmap 
v0.5.1/go.mod h1:gWUVWHKzWj8xpJVFf5GC0O26bWmv3GqdnIX/LMT6Aq4= github.com/kamstrup/intmap v0.5.2 h1:qnwBm1mh4XAnW9W9Ue9tZtTff8pS6+s6iKF6JRIV2Dk= github.com/kamstrup/intmap v0.5.2/go.mod h1:gWUVWHKzWj8xpJVFf5GC0O26bWmv3GqdnIX/LMT6Aq4= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -138,8 +125,6 @@ github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQ github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= github.com/mschoch/smat v0.2.0 h1:8imxQsjDm8yFEAVBe7azKmKSgzSkZXDuKkSq9374khM= github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw= -github.com/neo4j/neo4j-go-driver/v5 v5.28.1 h1:RKWQW7wTgYAY2fU9S+9LaJ9OwRPbRc0I17tlT7nDmAY= -github.com/neo4j/neo4j-go-driver/v5 v5.28.1/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k= github.com/neo4j/neo4j-go-driver/v5 v5.28.4 h1:7toxehVcYkZbyxV4W3Ib9VcnyRBQPucF+VwNNmtSXi4= github.com/neo4j/neo4j-go-driver/v5 v5.28.4/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= @@ -151,9 +136,8 @@ github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8 github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/protocolbuffers/txtpbfmt v0.0.0-20250129171521-feedd8250727 h1:A8EM8fVuYc0qbVMw9D6EiKdKTIm1SmLvAWcCc2mipGY= -github.com/protocolbuffers/txtpbfmt v0.0.0-20250129171521-feedd8250727/go.mod h1:VmWrOlMnBZNtToCWzRlZlIXcJqjo0hS5dwQbRD62gL8= github.com/protocolbuffers/txtpbfmt v0.0.0-20251016062345-16587c79cd91 h1:s1LvMaU6mVwoFtbxv/rCZKE7/fwDmDY684FfUe4c1Io= +github.com/protocolbuffers/txtpbfmt v0.0.0-20251016062345-16587c79cd91/go.mod 
h1:JSbkp0BviKovYYt9XunS95M3mLPibE9bGg+Y95DsEEY= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= @@ -178,8 +162,6 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -195,6 +177,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -207,12 +191,8 @@ 
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.20.0/go.mod h1:Xwo95rrVNIoSMx9wa1JroENMToLWn3RNVrTBpLHgZPQ= -golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= -golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= -golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 h1:MDfG8Cvcqlt9XXrmEiD4epKn7VJHZO84hejP9Jmp0MM= golang.org/x/exp v0.0.0-20251209150349-8475f28825e9/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -220,9 +200,8 @@ golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKG golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= -golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= +golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod 
h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -232,18 +211,13 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= -golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= -golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -278,8 +252,6 @@ golang.org/x/text v0.3.7/go.mod 
h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= -golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -293,14 +265,15 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= -golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.33.0 
h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= @@ -309,7 +282,6 @@ gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/graph/kind.go b/graph/kind.go index 5dd9ceb..7e6fd22 100644 --- a/graph/kind.go +++ b/graph/kind.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "sort" + "strings" "sync" "unsafe" @@ -111,6 +112,10 @@ func (s Kinds) Strings() []string { return kindStrings } +func (s Kinds) Formatted() string { + return strings.Join(s.Strings(), ",") +} + // ContainsOneOf returns true if the Kinds contains one of the given Kind types or false if it does not. 
func (s Kinds) ContainsOneOf(others ...Kind) bool { for _, kind := range s { diff --git a/graph/properties.go b/graph/properties.go index 603f18f..3ac3f41 100644 --- a/graph/properties.go +++ b/graph/properties.go @@ -425,6 +425,18 @@ func (s *Properties) GetOrDefault(key string, defaultValue any) PropertyValue { return s.GetWithFallback(key, defaultValue) } +func PropertiesMustGetOrDefault[T any](properties *Properties, key string, defaultValue T) T { + value := properties.GetWithFallback(key, defaultValue) + + if !value.IsNil() { + if typedValue, typeOK := value.Any().(T); typeOK { + return typedValue + } + } + + return defaultValue +} + func (s *Properties) GetWithFallback(key string, defaultValue any, fallbackKeys ...string) PropertyValue { value := defaultValue diff --git a/query/query.go b/query/query.go index 7937972..8080188 100644 --- a/query/query.go +++ b/query/query.go @@ -59,28 +59,28 @@ func joinedExpressionList(operator cypher.Operator, operands []cypher.SyntaxNode } func Not(operand cypher.Expression) cypher.Expression { - switch typedOperand := operand.(type) { - case *cypher.KindMatcher: - // If the type doesn't match, this code does not handle the error. This will be caught during query build time - // instead. 
- if identifier, typeOK := typedOperand.Reference.(*cypher.Variable); typeOK && identifier.Symbol == Identifiers.relationship { - if len(typedOperand.Kinds) == 1 { - return cypher.NewComparison( - cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, identifier), - cypher.OperatorNotEquals, - cypher.NewStringLiteral(typedOperand.Kinds[0].String()), - ) - } else { - return cypher.NewNegation( - cypher.NewComparison( - cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, identifier), - cypher.OperatorIn, - cypher.NewStringListLiteral(typedOperand.Kinds.Strings()), - ), - ) - } - } - } + // switch typedOperand := operand.(type) { + // case *cypher.KindMatcher: + // // If the type doesn't match, this code does not handle the error. This will be caught during query build time + // // instead. + // if identifier, typeOK := typedOperand.Reference.(*cypher.Variable); typeOK && identifier.Symbol == Identifiers.relationship { + // if len(typedOperand.Kinds) == 1 { + // return cypher.NewComparison( + // cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, identifier), + // cypher.OperatorNotEquals, + // cypher.NewStringLiteral(typedOperand.Kinds[0].String()), + // ) + // } else { + // return cypher.NewNegation( + // cypher.NewComparison( + // cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, identifier), + // cypher.OperatorIn, + // cypher.NewStringListLiteral(typedOperand.Kinds.Strings()), + // ), + // ) + // } + // } + // } return cypher.NewNegation(operand) } @@ -148,6 +148,7 @@ type KindsContinuation interface { } type Comparable interface { + In(value any) cypher.Expression Contains(value any) cypher.Expression Equals(value any) cypher.Expression GreaterThan(value any) cypher.Expression @@ -185,6 +186,10 @@ func (s *comparisonContinuation) asComparison(operator cypher.Operator, rOperand ) } +func (s *comparisonContinuation) In(value any) cypher.Expression { + return s.asComparison(cypher.OperatorIn, value) +} + func (s 
*comparisonContinuation) Contains(value any) cypher.Expression { return s.asComparison(cypher.OperatorContains, value) } diff --git a/registry.go b/registry.go index cc93fcd..436eacb 100644 --- a/registry.go +++ b/registry.go @@ -13,14 +13,18 @@ var ( ErrDriverMissing = errors.New("driver missing") ) +// DriverConstructor describes a function that takes a context and a dawgs configuration struct and returns either +// a valid `database.Instance` reference or the associated error that prevented instantiation. type DriverConstructor func(ctx context.Context, cfg Config) (database.Instance, error) var availableDrivers = map[string]DriverConstructor{} +// Register registers a dawgs driver under the given driverName func Register(driverName string, constructor DriverConstructor) { availableDrivers[driverName] = constructor } +// Config is the basic configuration struct for a dawgs connection type Config struct { GraphQueryMemoryLimit size.Size ConnectionString string @@ -30,6 +34,8 @@ type Config struct { DriverConfig any } +// Open creates a new dawgs graph database instance. This function expects the driver name, often imported to ensure +// that registration logic occurs. func Open(ctx context.Context, driverName string, config Config) (database.Instance, error) { if driverConstructor, hasDriver := availableDrivers[driverName]; !hasDriver { return nil, ErrDriverMissing @@ -38,6 +44,7 @@ func Open(ctx context.Context, driverName string, config Config) (database.Insta } } +// OpenV1 creates a new dawgs graph database instance but with a dawgs version 1 compatible interface. 
func OpenV1(ctx context.Context, driverName string, config Config) (v1compat.Database, error) { if driver, err := Open(ctx, driverName, config); err != nil { return nil, err diff --git a/registry_integration_test.go b/registry_integration_test.go index 968b70c..3f89b1e 100644 --- a/registry_integration_test.go +++ b/registry_integration_test.go @@ -4,13 +4,12 @@ package dawgs_test import ( "context" - "fmt" - "log/slog" "testing" - //pg_v2 "github.com/specterops/dawgs/drivers/pg/v2" + "github.com/specterops/dawgs" + "github.com/specterops/dawgs/database/pg" + "github.com/specterops/dawgs/database" - "github.com/specterops/dawgs/database/neo4j" "github.com/specterops/dawgs/graph" "github.com/specterops/dawgs/query" "github.com/specterops/dawgs/util/size" @@ -20,16 +19,11 @@ import ( func Test(t *testing.T) { ctx := context.Background() - graphDB, err := database.Open(ctx, neo4j.DriverName, database.Config{ + graphDB, err := dawgs.Open(ctx, pg.DriverName, dawgs.Config{ GraphQueryMemoryLimit: size.Gibibyte * 1, - ConnectionString: "neo4j://neo4j:neo4jj@localhost:7687", + ConnectionString: "postgresql://postgres:postgres@localhost:5432/bhe", }) - //graphDB, err := v2.Open(ctx, pg_v2.DriverName, v2.Config{ - // GraphQueryMemoryLimit: size.Gibibyte * 1, - // ConnectionString: "postgresql://postgres:postgres@localhost:5432/bhe", - //}) - require.NoError(t, err) require.NoError(t, graphDB.AssertSchema(ctx, database.NewSchema( @@ -39,42 +33,48 @@ func Test(t *testing.T) { Nodes: graph.Kinds{graph.StringKind("Node")}, Edges: graph.Kinds{graph.StringKind("Edge")}, NodeIndexes: []database.Index{{ - Name: "node_label_name_index", + Name: "node_name_index", Field: "name", Type: database.IndexTypeTextSearch, }}, }))) - preparedQuery, err := query.New().Return(query.Node()).Limit(10).Build() - require.NoError(t, err) - require.NoError(t, graphDB.Session(ctx, func(ctx context.Context, driver database.Driver) error { - return driver.CreateNode(ctx, 
graph.PrepareNode(graph.AsProperties(map[string]any{ + _, err := driver.CreateNode(ctx, graph.PrepareNode(graph.AsProperties(map[string]any{ "name": "THAT NODE", }), graph.StringKind("Node"))) + + return err })) - require.NoError(t, graphDB.Session(ctx, database.FetchNodes(preparedQuery, func(node *graph.Node) error { - slog.Info(fmt.Sprintf("Got result from DB: %v", node)) - return nil - }))) + require.NoError(t, graphDB.Session(ctx, func(ctx context.Context, driver database.Driver) error { + myQuery := query.New().Where( + query.Node().Property("a").Equals(1234), + ).OrderBy( + query.Node().Property("my_order"), + ) - require.NoError(t, graphDB.Transaction(ctx, database.FetchNodes(preparedQuery, func(node *graph.Node) error { - slog.Info(fmt.Sprintf("Got result from DB: %v", node)) - return nil - }))) + if builtQuery, err := myQuery.Build(); err != nil { + return err + } else { + var ( + node graph.Node + result = driver.Exec(ctx, builtQuery.Query, builtQuery.Parameters) + ) + + defer result.Close(ctx) - //require.NoError(t, graphDB.Transaction(ctx, func(ctx context.Context, driver v2.Driver) error { - // builder := v2.Query().Create( - // v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, cypher.NewParameter("props", map[string]any{ - // "name": "1234", - // })), - // ) - // - // if preparedQuery, err := builder.Build(); err != nil { - // return err - // } else { - // return driver.CypherQuery(ctx, preparedQuery.Query, preparedQuery.Parameters).Close(ctx) - // } - //})) + if result.HasNext(ctx) { + if err := result.Scan(&node); err != nil { + return err + } else { + require.Equal(t, "THAT NODE", node.Properties.GetOrDefault("name", "").Any()) + } + } else { + t.Fatal("no node found") + } + } + + return nil + })) } diff --git a/util/slog.go b/util/slog.go new file mode 100644 index 0000000..fe8a511 --- /dev/null +++ b/util/slog.go @@ -0,0 +1,26 @@ +package util + +import ( + "log/slog" + "sync/atomic" + "time" +) + +var slogMeasureID = &atomic.Int64{} + 
+func SLogMeasure(msg string, args ...any) func(args ...any) { + var ( + then = time.Now() + measurementID = slogMeasureID.Add(1) + allArgs = append(args, slog.Int64("measurement_id", measurementID)) + ) + + slog.Info(msg, append(allArgs, slog.String("state", "enter"))...) + + return func(args ...any) { + exitArgs := append(allArgs, slog.Duration("elapsed", time.Since(then)), slog.String("state", "exit")) + exitArgs = append(exitArgs, args...) + + slog.Info(msg, exitArgs...) + } +} From 88f7343239bc75ca7098279d80b15c105cd00058 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Wed, 17 Dec 2025 13:16:52 -0800 Subject: [PATCH 3/5] wip --- algo/scc.go | 139 +++++++++++++++++++-------------------- container/digraph.go | 4 -- container/fetch.go | 7 ++ container/triplestore.go | 45 +++++++++++++ database/driver.go | 4 -- go.mod | 3 +- util/context.go | 7 ++ util/slog.go | 5 ++ 8 files changed, 134 insertions(+), 80 deletions(-) create mode 100644 util/context.go diff --git a/algo/scc.go b/algo/scc.go index 830cc12..146fa27 100644 --- a/algo/scc.go +++ b/algo/scc.go @@ -1,6 +1,7 @@ package algo import ( + "context" "math" "github.com/gammazero/deque" @@ -10,7 +11,7 @@ import ( "github.com/specterops/dawgs/util" ) -func StronglyConnectedComponents(digraph container.DirectedGraph) ([]cardinality.Duplex[uint64], map[uint64]uint64) { +func StronglyConnectedComponents(ctx context.Context, digraph container.DirectedGraph) ([]cardinality.Duplex[uint64], map[uint64]uint64) { defer util.SLogMeasure("StronglyConnectedComponents")() type descentCursor struct { @@ -34,86 +35,84 @@ func StronglyConnectedComponents(digraph container.DirectedGraph) ([]cardinality ) digraph.EachNode(func(node uint64) bool { - if _, visited := visitedIndex[node]; visited { - return true - } - - dfsDescentStack = append(dfsDescentStack, &descentCursor{ - id: node, - branches: digraph.AdjacentNodes(node, graph.DirectionOutbound), - branchIdx: 0, - }) - - for len(dfsDescentStack) > 0 { - nextCursor := 
dfsDescentStack[len(dfsDescentStack)-1] + if _, visited := visitedIndex[node]; !visited { + dfsDescentStack = append(dfsDescentStack, &descentCursor{ + id: node, + branches: digraph.AdjacentNodes(node, graph.DirectionOutbound), + branchIdx: 0, + }) - if nextCursor.branchIdx == 0 { - // First visit of this node - visitedIndex[nextCursor.id] = index - lowLinks[nextCursor.id] = index - index += 1 + for len(dfsDescentStack) > 0 { + nextCursor := dfsDescentStack[len(dfsDescentStack)-1] - stack = append(stack, nextCursor.id) - onStack.Add(nextCursor.id) - } else if lastSearchedNodeID != nextCursor.id { - // Revisiting this node from a descending DFS - lowLinks[nextCursor.id] = min(lowLinks[nextCursor.id], lowLinks[lastSearchedNodeID]) - } + if nextCursor.branchIdx == 0 { + // First visit of this node + visitedIndex[nextCursor.id] = index + lowLinks[nextCursor.id] = index + index += 1 - // Set to the current cursor ID for ascent - lastSearchedNodeID = nextCursor.id - - if nextCursor.branchIdx < len(nextCursor.branches) { - // Advance to the next branch - nextBranchID := nextCursor.branches[nextCursor.branchIdx] - nextCursor.branchIdx += 1 - - if _, visited := visitedIndex[nextBranchID]; !visited { - // This node has not been visited yet, run a DFS for it - lastSearchedNodeID = nextBranchID - - dfsDescentStack = append(dfsDescentStack, &descentCursor{ - id: nextBranchID, - branches: digraph.AdjacentNodes(nextBranchID, graph.DirectionOutbound), - branchIdx: 0, - }) - } else if onStack.Contains(nextBranchID) { - // Branch is on the traversal stack; hence it is also in the current SCC - lowLinks[nextCursor.id] = min(lowLinks[nextCursor.id], visitedIndex[nextBranchID]) + stack = append(stack, nextCursor.id) + onStack.Add(nextCursor.id) + } else if lastSearchedNodeID != nextCursor.id { + // Revisiting this node from a descending DFS + lowLinks[nextCursor.id] = min(lowLinks[nextCursor.id], lowLinks[lastSearchedNodeID]) } - } else { - // Finished visiting branches; exiting node - 
dfsDescentStack = dfsDescentStack[:len(dfsDescentStack)-1] - if lowLinks[nextCursor.id] == visitedIndex[nextCursor.id] { - var ( - scc = cardinality.NewBitmap64() - sccID = uint64(len(stronglyConnectedComponents)) - ) + // Set to the current cursor ID for ascent + lastSearchedNodeID = nextCursor.id + + if nextCursor.branchIdx < len(nextCursor.branches) { + // Advance to the next branch + nextBranchID := nextCursor.branches[nextCursor.branchIdx] + nextCursor.branchIdx += 1 + + if _, visited := visitedIndex[nextBranchID]; !visited { + // This node has not been visited yet, run a DFS for it + lastSearchedNodeID = nextBranchID + + dfsDescentStack = append(dfsDescentStack, &descentCursor{ + id: nextBranchID, + branches: digraph.AdjacentNodes(nextBranchID, graph.DirectionOutbound), + branchIdx: 0, + }) + } else if onStack.Contains(nextBranchID) { + // Branch is on the traversal stack; hence it is also in the current SCC + lowLinks[nextCursor.id] = min(lowLinks[nextCursor.id], visitedIndex[nextBranchID]) + } + } else { + // Finished visiting branches; exiting node + dfsDescentStack = dfsDescentStack[:len(dfsDescentStack)-1] + + if lowLinks[nextCursor.id] == visitedIndex[nextCursor.id] { + var ( + scc = cardinality.NewBitmap64() + sccID = uint64(len(stronglyConnectedComponents)) + ) - for { - // Unwind the stack to the root of the component - currentNode := stack[len(stack)-1] - stack = stack[:len(stack)-1] + for { + // Unwind the stack to the root of the component + currentNode := stack[len(stack)-1] + stack = stack[:len(stack)-1] - onStack.Remove(currentNode) + onStack.Remove(currentNode) - scc.Add(currentNode) + scc.Add(currentNode) - // Reverse index origin node to SCC - nodeToSCCIndex[currentNode] = sccID + // Reverse index origin node to SCC + nodeToSCCIndex[currentNode] = sccID - if currentNode == nextCursor.id { - break + if currentNode == nextCursor.id { + break + } } - } - stronglyConnectedComponents = append(stronglyConnectedComponents, scc) + 
stronglyConnectedComponents = append(stronglyConnectedComponents, scc) + } } } } - return true + return util.IsContextLive(ctx) }) return stronglyConnectedComponents, nodeToSCCIndex @@ -260,9 +259,9 @@ func (s ComponentGraph) OriginReachable(startID, endID uint64) bool { return s.ComponentReachable(startComponent, endComponent) } -func NewComponentGraph(originGraph container.DirectedGraph) ComponentGraph { +func NewComponentGraph(ctx context.Context, originGraph container.DirectedGraph) ComponentGraph { var ( - componentMembers, memberComponentLookup = StronglyConnectedComponents(originGraph) + componentMembers, memberComponentLookup = StronglyConnectedComponents(ctx, originGraph) componentDigraph = container.NewAdjacencyMapGraph() nextEdgeID = uint64(1) ) @@ -283,7 +282,7 @@ func NewComponentGraph(originGraph container.DirectedGraph) ComponentGraph { nextEdgeID += 1 } - return true + return util.IsContextLive(ctx) }) originGraph.EachAdjacentNode(node, graph.DirectionOutbound, func(adjacent uint64) bool { @@ -292,10 +291,10 @@ func NewComponentGraph(originGraph container.DirectedGraph) ComponentGraph { nextEdgeID += 1 } - return true + return util.IsContextLive(ctx) }) - return true + return util.IsContextLive(ctx) }) return ComponentGraph{ diff --git a/container/digraph.go b/container/digraph.go index 78cadc3..8dd72da 100644 --- a/container/digraph.go +++ b/container/digraph.go @@ -1,8 +1,6 @@ package container import ( - "fmt" - "github.com/gammazero/deque" "github.com/specterops/dawgs/cardinality" "github.com/specterops/dawgs/graph" @@ -28,8 +26,6 @@ func (s KindMap) FindFirst(id uint64) graph.Kind { } } - panic(fmt.Sprintf("Can't find kind for edge ID %d", id)) - return nil } diff --git a/container/fetch.go b/container/fetch.go index fbb2b32..7eab94e 100644 --- a/container/fetch.go +++ b/container/fetch.go @@ -104,6 +104,13 @@ type TSDB struct { EdgeKinds KindMap } +func NewTSDB() TSDB { + return TSDB{ + Triplestore: NewTriplestore(), + EdgeKinds: KindMap{}, + 
} +} + func FetchTriplestore(ctx context.Context, graphDB database.Instance, filter cypher.SyntaxNode) (TSDB, error) { tsdb := TSDB{ Triplestore: NewTriplestore(), diff --git a/container/triplestore.go b/container/triplestore.go index f350fc4..1d4193d 100644 --- a/container/triplestore.go +++ b/container/triplestore.go @@ -182,6 +182,51 @@ func (s *triplestore) EachAdjacentEdge(node uint64, direction graph.Direction, d }) } +func TSDFS(ts Triplestore, nodeID uint64, direction graph.Direction, maxDepth int, descentFilter func(edge Edge) bool, handler func(segment *Segment) bool) int { + var ( + traversals deque.Deque[*Segment] + numImcompletePaths = 0 + ) + + traversals.PushBack(&Segment{ + Node: nodeID, + }) + + for remainingTraversals := traversals.Len(); remainingTraversals > 0; remainingTraversals = traversals.Len() { + var ( + nextSegment = traversals.PopBack() + segmentDepth = nextSegment.Depth() + depthExceded = maxDepth > 0 && maxDepth < segmentDepth + ) + + if !depthExceded { + ts.EachAdjacentEdge(nextSegment.Node, direction, func(nextEdge Edge) bool { + if descentFilter(nextEdge) { + traversals.PushBack(&Segment{ + Node: nextEdge.Pick(direction), + Edge: nextEdge.ID, + Previous: nextSegment, + }) + } + + return true + }) + } + + if segmentDepth > 1 && remainingTraversals-1 == traversals.Len() { + if depthExceded { + numImcompletePaths += 1 + } + + if !handler(nextSegment) { + break + } + } + } + + return numImcompletePaths +} + func TSBFS(ts Triplestore, nodeID uint64, direction graph.Direction, maxDepth int, descentFilter func(edge Edge) bool, handler func(segment *Segment) bool) int { var ( traversals deque.Deque[*Segment] diff --git a/database/driver.go b/database/driver.go index 0cc967e..0d01386 100644 --- a/database/driver.go +++ b/database/driver.go @@ -20,10 +20,6 @@ type Result interface { Scan(scanTargets ...any) error Error() error Close(ctx context.Context) error - - // Values returns the next values array from the result. 
- // - // Deprecated: This function will be removed in future version. Values() []any } diff --git a/go.mod b/go.mod index d5f3c66..6ef87e6 100644 --- a/go.mod +++ b/go.mod @@ -15,8 +15,6 @@ require ( github.com/jackc/pgx/v5 v5.7.6 github.com/neo4j/neo4j-go-driver/v5 v5.28.4 github.com/stretchr/testify v1.11.1 - golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 - github.com/specterops/dawgs v0.3.1 ) require ( @@ -32,6 +30,7 @@ require ( github.com/mschoch/smat v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.46.0 // indirect + golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/text v0.32.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/util/context.go b/util/context.go new file mode 100644 index 0000000..88539ac --- /dev/null +++ b/util/context.go @@ -0,0 +1,7 @@ +package util + +import "context" + +func IsContextLive(ctx context.Context) bool { + return ctx.Err() == nil +} diff --git a/util/slog.go b/util/slog.go index fe8a511..55bee45 100644 --- a/util/slog.go +++ b/util/slog.go @@ -24,3 +24,8 @@ func SLogMeasure(msg string, args ...any) func(args ...any) { slog.Info(msg, exitArgs...) } } + +func SLogError(msg string, err error, args ...any) { + allArgs := append([]any{slog.String("err", err.Error())}, args...) + slog.Error(msg, allArgs...) 
+} From 73df06b7d471e7e2deeff65653ec8310585ecaf7 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Wed, 17 Dec 2025 21:38:38 -0800 Subject: [PATCH 4/5] wip --- algo/closeness.go | 4 +- container/digraph.go | 23 +++++----- container/triplestore.go | 96 ++++++++++++++++++++++++++++++++++++++++ database/pg/mapper.go | 44 ++++++++++++++++++ graph/node.go | 34 ++++++-------- graph/properties.go | 14 ++++-- graph/relationship.go | 24 ++++++++++ 7 files changed, 203 insertions(+), 36 deletions(-) diff --git a/algo/closeness.go b/algo/closeness.go index 900ceed..298ab2e 100644 --- a/algo/closeness.go +++ b/algo/closeness.go @@ -17,7 +17,7 @@ func ClosenessForDirectedUnweightedGraph(digraph container.DirectedGraph, direct var distanceSum Weight = 0 for _, shortestPathTerminal := range shortestPathTerminals { - distanceSum += shortestPathTerminal.Distance + distanceSum += Weight(shortestPathTerminal.Distance) } if distanceSum > 0 { @@ -51,7 +51,7 @@ func ClosenessForDirectedUnweightedGraphParallel(digraph container.DirectedGraph var distanceSum Weight = 0 for _, shortestPathTerminal := range shortestPathTerminals { - distanceSum += shortestPathTerminal.Distance + distanceSum += Weight(shortestPathTerminal.Distance) } if distanceSum > 0 { diff --git a/container/digraph.go b/container/digraph.go index 8dd72da..81f9099 100644 --- a/container/digraph.go +++ b/container/digraph.go @@ -54,9 +54,10 @@ func (s KindDatabase) EdgeKind(edgeID uint64) graph.Kind { return s.EdgeKindMap.FindFirst(edgeID) } -type ShortestPathTerminal struct { - NodeID uint64 - Distance Weight +type PathTerminal struct { + Node uint64 + Weight Weight + Distance int } type DirectedGraph interface { @@ -83,25 +84,25 @@ func Dimensions(digraph DirectedGraph, direction graph.Direction) (uint64, uint6 return digraph.Nodes().Cardinality(), largestRow } -func BFSTree(digraph DirectedGraph, nodeID uint64, direction graph.Direction) []ShortestPathTerminal { +func BFSTree(digraph DirectedGraph, nodeID uint64, 
direction graph.Direction) []PathTerminal { var ( visited = cardinality.NewBitmap64() - queue deque.Deque[ShortestPathTerminal] - terminals []ShortestPathTerminal + queue deque.Deque[PathTerminal] + terminals []PathTerminal ) - queue.PushBack(ShortestPathTerminal{ - NodeID: nodeID, + queue.PushBack(PathTerminal{ + Node: nodeID, Distance: 0, }) for queue.Len() > 0 { nextCursor := queue.PopFront() - digraph.EachAdjacentNode(nextCursor.NodeID, direction, func(adjacentNodeID uint64) bool { + digraph.EachAdjacentNode(nextCursor.Node, direction, func(adjacentNodeID uint64) bool { if visited.CheckedAdd(adjacentNodeID) { - terminalCursor := ShortestPathTerminal{ - NodeID: adjacentNodeID, + terminalCursor := PathTerminal{ + Node: adjacentNodeID, Distance: nextCursor.Distance + 1, } diff --git a/container/triplestore.go b/container/triplestore.go index 1d4193d..112e0dd 100644 --- a/container/triplestore.go +++ b/container/triplestore.go @@ -1,6 +1,8 @@ package container import ( + "sync" + "github.com/gammazero/deque" "github.com/specterops/dawgs/cardinality" "github.com/specterops/dawgs/graph" @@ -271,3 +273,97 @@ func TSBFS(ts Triplestore, nodeID uint64, direction graph.Direction, maxDepth in return numImcompletePaths } + +type ThreadSafeDeque[T any] struct { + lock *sync.RWMutex + container deque.Deque[T] +} + +func (s *ThreadSafeDeque[T]) PushFront(elem T) { + s.lock.Lock() + s.container.PushFront(elem) + s.lock.Unlock() +} + +func (s *ThreadSafeDeque[T]) PushBack(elem T) { + s.lock.Lock() + s.container.PushBack(elem) + s.lock.Unlock() +} + +func (s *ThreadSafeDeque[T]) PopFront() T { + s.lock.Lock() + value := s.container.PopFront() + s.lock.Unlock() + return value +} + +func (s *ThreadSafeDeque[T]) PopBack() T { + s.lock.Lock() + value := s.container.PopBack() + s.lock.Unlock() + return value +} + +func (s *ThreadSafeDeque[T]) Len() int { + s.lock.RLock() + numElements := s.container.Len() + s.lock.RUnlock() + return numElements +} + +// SSPBFS is a parallel stateless 
shortest-path breadth-first search. +func TSStatelessBFS(ts Triplestore, rootNode uint64, direction graph.Direction, maxDepth int, descentFilter func(edge Edge) (Weight, bool), terminalHandler func(terminal PathTerminal) bool, numWorkers int) int { + var ( + traversals deque.Deque[PathTerminal] + numImcompletePaths = 0 + ) + + traversals.PushBack(PathTerminal{ + Node: rootNode, + Weight: 0, + Distance: 0, + }) + + for traversals.Len() > 0 { + var ( + nextSegment = traversals.PopFront() + hasExpansions = false + depthExceded = maxDepth > 0 && maxDepth < nextSegment.Distance + ) + + if !depthExceded { + nextDistance := nextSegment.Distance + 1 + + ts.EachAdjacentEdge(nextSegment.Node, direction, func(nextEdge Edge) bool { + if weight, shouldDescend := descentFilter(nextEdge); shouldDescend { + hasExpansions = true + + if nextSegment.Distance > 0 { + weight *= nextSegment.Weight + } + + traversals.PushBack(PathTerminal{ + Node: nextEdge.Pick(direction), + Distance: nextDistance, + Weight: weight, + }) + } + + return true + }) + } + + if nextSegment.Distance >= 1 && !hasExpansions { + if depthExceded { + numImcompletePaths += 1 + } + + if !terminalHandler(nextSegment) { + break + } + } + } + + return numImcompletePaths +} diff --git a/database/pg/mapper.go b/database/pg/mapper.go index 181d221..6b4b17a 100644 --- a/database/pg/mapper.go +++ b/database/pg/mapper.go @@ -48,6 +48,50 @@ func mapKinds(ctx context.Context, kindMapper KindMapper, untypedValue any) (gra return nil, validType } +func TryMapNode(ctx context.Context, values map[string]any, kindMapper KindMapper) (*graph.Node, bool) { + var node nodeComposite + + if node.TryMap(values) { + var graphNode graph.Node + + if err := node.ToNode(ctx, kindMapper, &graphNode); err == nil { + return &graphNode, true + } + } + + return nil, false +} + +func TryMapRelationship(ctx context.Context, values map[string]any, kindMapper KindMapper) (*graph.Relationship, bool) { + var edge edgeComposite + + if edge.TryMap(values) { 
+ var graphRelationship graph.Relationship + + if err := edge.ToRelationship(ctx, kindMapper, &graphRelationship); err == nil { + return &graphRelationship, true + } + } + + return nil, false +} + +func TryMapToGraphType(ctx context.Context, value any, kindMapper KindMapper) any { + switch typedValue := value.(type) { + case map[string]any: + if node, mapped := TryMapNode(ctx, typedValue, kindMapper); mapped { + return node + } + + if relationship, mapped := TryMapRelationship(ctx, typedValue, kindMapper); mapped { + return relationship + } + + } + + return value +} + func newMapFunc(ctx context.Context, kindMapper KindMapper) graph.MapFunc { return func(value, target any) bool { switch typedTarget := target.(type) { diff --git a/graph/node.go b/graph/node.go index e6db3c8..d379782 100644 --- a/graph/node.go +++ b/graph/node.go @@ -32,19 +32,17 @@ func NewNode(id ID, properties *Properties, kinds ...Kind) *Node { } type serializableNode struct { - ID ID `json:"id"` - Kinds []string `json:"kinds"` - AddedKinds []string `json:"added_kinds"` - DeletedKinds []string `json:"deleted_kinds"` - Properties *Properties `json:"properties"` + ID ID `json:"id"` + Kinds []string `json:"kinds"` + Properties *Properties `json:"properties"` } type Node struct { - ID ID `json:"id"` - Kinds Kinds `json:"kinds"` - AddedKinds Kinds `json:"added_kinds"` - DeletedKinds Kinds `json:"deleted_kinds"` - Properties *Properties `json:"properties"` + ID ID + Kinds Kinds + AddedKinds Kinds + DeletedKinds Kinds + Properties *Properties } func (s *Node) Merge(other *Node) { @@ -101,11 +99,9 @@ func (s *Node) DeleteKinds(kinds ...Kind) { func (s *Node) MarshalJSON() ([]byte, error) { var ( jsonNode = serializableNode{ - ID: s.ID, - Kinds: s.Kinds.Strings(), - AddedKinds: s.AddedKinds.Strings(), - DeletedKinds: s.DeletedKinds.Strings(), - Properties: s.Properties, + ID: s.ID, + Kinds: s.Kinds.Strings(), + Properties: s.Properties, } ) @@ -253,11 +249,9 @@ func (s *NodeSet) UnmarshalJSON(input 
[]byte) error { nodeSet := make(NodeSet, len(tmpMap)) for key, value := range tmpMap { nodeSet[key] = &Node{ - ID: value.ID, - Kinds: StringsToKinds(value.Kinds), - AddedKinds: StringsToKinds(value.AddedKinds), - DeletedKinds: StringsToKinds(value.DeletedKinds), - Properties: value.Properties, + ID: value.ID, + Kinds: StringsToKinds(value.Kinds), + Properties: value.Properties, } } diff --git a/graph/properties.go b/graph/properties.go index 3ac3f41..ef9bea2 100644 --- a/graph/properties.go +++ b/graph/properties.go @@ -229,9 +229,13 @@ func NewPropertyResult(key string, value any) PropertyValue { // Properties is a map type that satisfies the Properties interface. type Properties struct { - Map map[string]any `json:"map"` - Deleted map[string]struct{} `json:"deleted"` - Modified map[string]struct{} `json:"modified"` + Map map[string]any + Deleted map[string]struct{} + Modified map[string]struct{} +} + +func (s *Properties) MarshalJSON() ([]byte, error) { + return json.Marshal(s.Map) } func (s *Properties) Merge(other *Properties) { @@ -426,6 +430,10 @@ func (s *Properties) GetOrDefault(key string, defaultValue any) PropertyValue { } func PropertiesMustGetOrDefault[T any](properties *Properties, key string, defaultValue T) T { + if properties == nil { + return defaultValue + } + value := properties.GetWithFallback(key, defaultValue) if !value.IsNil() { diff --git a/graph/relationship.go b/graph/relationship.go index 9ffc74b..865a281 100644 --- a/graph/relationship.go +++ b/graph/relationship.go @@ -1,10 +1,20 @@ package graph import ( + "encoding/json" + "github.com/specterops/dawgs/cardinality" "github.com/specterops/dawgs/util/size" ) +type serializableRelationship struct { + ID ID `json:"id"` + StartID ID `json:"start_id"` + EndID ID `json:"end_id"` + Kind string `json:"kind"` + Properties *Properties `json:"properties"` +} + type Relationship struct { ID ID StartID ID @@ -13,6 +23,20 @@ type Relationship struct { Properties *Properties } +func (s *Relationship) 
MarshalJSON() ([]byte, error) { + var ( + jsonNode = serializableRelationship{ + ID: s.ID, + StartID: s.StartID, + EndID: s.EndID, + Kind: s.Kind.String(), + Properties: s.Properties, + } + ) + + return json.Marshal(jsonNode) +} + func (s *Relationship) Merge(other *Relationship) { s.Properties.Merge(other.Properties) } From eb6274e060154f51c7f73da1c56db8657e3fb15d Mon Sep 17 00:00:00 2001 From: John Hopper Date: Tue, 6 Jan 2026 16:18:27 -0800 Subject: [PATCH 5/5] wip --- algo/katz.go | 6 +- algo/sample.go | 16 ++- algo/scc.go | 13 +- container/adjacencymap.go | 8 +- container/digraph.go | 37 +++++- container/fetch.go | 49 +++++--- container/threadsafe.go | 45 +++++++ container/traversal.go | 152 +++++++++++++++++++++++ container/triplestore.go | 251 ++++++++++++++------------------------ 9 files changed, 370 insertions(+), 207 deletions(-) create mode 100644 container/threadsafe.go create mode 100644 container/traversal.go diff --git a/algo/katz.go b/algo/katz.go index 5668420..2b6b0b9 100644 --- a/algo/katz.go +++ b/algo/katz.go @@ -35,7 +35,7 @@ relationships. 
*/ func CalculateKatzCentrality(digraph container.DirectedGraph, alpha, beta, epsilon Weight, iterations int, direction graph.Direction) (map[uint64]Weight, bool) { var ( - numNodes = digraph.Nodes().Cardinality() + numNodes = digraph.NumNodes() centrality = make(map[uint64]Weight, numNodes) prevCentrality = make(map[uint64]Weight, numNodes) ) @@ -43,7 +43,7 @@ func CalculateKatzCentrality(digraph container.DirectedGraph, alpha, beta, epsil defer util.SLogMeasure("CalculateKatzCentrality", slog.String("direction", direction.String()))() // Initialize centrality scores to baseline - digraph.Nodes().Each(func(value uint64) bool { + digraph.EachNode(func(value uint64) bool { centrality[value] = beta return true }) @@ -53,7 +53,7 @@ func CalculateKatzCentrality(digraph container.DirectedGraph, alpha, beta, epsil changed := false - digraph.Nodes().Each(func(sourceNode uint64) bool { + digraph.EachNode(func(sourceNode uint64) bool { sum := 0.0 digraph.EachAdjacentNode(sourceNode, direction, func(adjacentNode uint64) bool { diff --git a/algo/sample.go b/algo/sample.go index a2e1610..377f87f 100644 --- a/algo/sample.go +++ b/algo/sample.go @@ -16,18 +16,20 @@ func sampleHighestDegrees(digraph container.DirectedGraph, nSamples int, directi Degrees uint64 } - if numNodes := int(digraph.Nodes().Cardinality()); nSamples <= 0 || numNodes == 0 { + numNodes := int(digraph.NumNodes()) + + if nSamples <= 0 || numNodes == 0 { return nil } else if nSamples > numNodes { nSamples = numNodes } - entries := make([]entry, 0, digraph.Nodes().Cardinality()) + entries := make([]entry, 0, numNodes) digraph.EachNode(func(node uint64) bool { entries = append(entries, entry{ NodeID: node, - Degrees: digraph.Degrees(node, direction), + Degrees: container.Degrees(digraph, node, direction), }) return true @@ -59,7 +61,9 @@ func SampleHighestDegrees(direction graph.Direction) SampleFunc { } func SampleRandom(digraph container.DirectedGraph, nSamples int) []uint64 { - if numNodes := 
int(digraph.Nodes().Cardinality()); nSamples <= 0 || numNodes == 0 { + numNodes := int(digraph.NumNodes()) + + if nSamples <= 0 || numNodes == 0 { return nil } else if nSamples > numNodes { nSamples = numNodes @@ -67,12 +71,12 @@ func SampleRandom(digraph container.DirectedGraph, nSamples int) []uint64 { var ( samples = make([]uint64, 0, nSamples) - stride = digraph.Nodes().Cardinality() / uint64(nSamples) + stride = uint64(numNodes) / uint64(nSamples) counter = (rand.Uint64() % stride) + 1 remainder = stride - counter ) - digraph.Nodes().Each(func(value uint64) bool { + digraph.EachNode(func(value uint64) bool { if counter -= 1; counter == 0 { samples = append(samples, value) counter = (rand.Uint64() % stride) + 1 + remainder diff --git a/algo/scc.go b/algo/scc.go index 146fa27..f4e3afb 100644 --- a/algo/scc.go +++ b/algo/scc.go @@ -38,7 +38,7 @@ func StronglyConnectedComponents(ctx context.Context, digraph container.Directed if _, visited := visitedIndex[node]; !visited { dfsDescentStack = append(dfsDescentStack, &descentCursor{ id: node, - branches: digraph.AdjacentNodes(node, graph.DirectionOutbound), + branches: container.AdjacentNodes(digraph, node, graph.DirectionOutbound), branchIdx: 0, }) @@ -72,7 +72,7 @@ func StronglyConnectedComponents(ctx context.Context, digraph container.Directed dfsDescentStack = append(dfsDescentStack, &descentCursor{ id: nextBranchID, - branches: digraph.AdjacentNodes(nextBranchID, graph.DirectionOutbound), + branches: container.AdjacentNodes(digraph, nextBranchID, graph.DirectionOutbound), branchIdx: 0, }) } else if onStack.Contains(nextBranchID) { @@ -263,14 +263,13 @@ func NewComponentGraph(ctx context.Context, originGraph container.DirectedGraph) var ( componentMembers, memberComponentLookup = StronglyConnectedComponents(ctx, originGraph) componentDigraph = container.NewAdjacencyMapGraph() - nextEdgeID = uint64(1) ) defer util.SLogMeasure("NewComponentGraph")() // Ensure all components are present as vertices, even if they 
have no edges for componentID := range componentMembers { - componentDigraph.Nodes().Add(uint64(componentID)) + componentDigraph.AddNode(uint64(componentID)) } originGraph.EachNode(func(node uint64) bool { @@ -278,8 +277,7 @@ func NewComponentGraph(ctx context.Context, originGraph container.DirectedGraph) originGraph.EachAdjacentNode(node, graph.DirectionInbound, func(adjacent uint64) bool { if adjacentComponent := memberComponentLookup[adjacent]; nodeComponent != adjacentComponent { - componentDigraph.AddEdge(nextEdgeID, adjacentComponent, nodeComponent) - nextEdgeID += 1 + componentDigraph.AddEdge(adjacentComponent, nodeComponent) } return util.IsContextLive(ctx) @@ -287,8 +285,7 @@ func NewComponentGraph(ctx context.Context, originGraph container.DirectedGraph) originGraph.EachAdjacentNode(node, graph.DirectionOutbound, func(adjacent uint64) bool { if adjacentComponent := memberComponentLookup[adjacent]; nodeComponent != adjacentComponent { - componentDigraph.AddEdge(nextEdgeID, nodeComponent, adjacentComponent) - nextEdgeID += 1 + componentDigraph.AddEdge(nodeComponent, adjacentComponent) } return util.IsContextLive(ctx) diff --git a/container/adjacencymap.go b/container/adjacencymap.go index 1d83303..74253ab 100644 --- a/container/adjacencymap.go +++ b/container/adjacencymap.go @@ -11,7 +11,7 @@ type adjacencyMapDigraph struct { nodes cardinality.Duplex[uint64] } -func NewAdjacencyMapGraph() DirectedGraph { +func NewAdjacencyMapGraph() MutableDirectedGraph { return &adjacencyMapDigraph{ inbound: AdjacencyMap{}, outbound: AdjacencyMap{}, @@ -19,6 +19,10 @@ func NewAdjacencyMapGraph() DirectedGraph { } } +func (s *adjacencyMapDigraph) AddNode(node uint64) { + s.nodes.Add(node) +} + func (s *adjacencyMapDigraph) Normalize() ([]uint64, DirectedGraph) { var ( numNodes = s.NumNodes() @@ -145,7 +149,7 @@ func (s *adjacencyMapDigraph) EachAdjacentNode(node uint64, direction graph.Dire } } -func (s *adjacencyMapDigraph) AddEdge(edge, start, end uint64) { +func (s 
*adjacencyMapDigraph) AddEdge(start, end uint64) { if edgeBitmap, exists := s.outbound[start]; exists { edgeBitmap.Add(end) } else { diff --git a/container/digraph.go b/container/digraph.go index 81f9099..93bb703 100644 --- a/container/digraph.go +++ b/container/digraph.go @@ -60,28 +60,53 @@ type PathTerminal struct { Distance int } +func Degrees(digraph DirectedGraph, node uint64, direction graph.Direction) uint64 { + degrees := uint64(0) + + digraph.EachAdjacentNode(node, direction, func(adjacent uint64) bool { + degrees += 1 + return true + }) + + return degrees +} + +func AdjacentNodes(digraph DirectedGraph, node uint64, direction graph.Direction) []uint64 { + var nodes []uint64 + + digraph.EachAdjacentNode(node, direction, func(adjacent uint64) bool { + nodes = append(nodes, adjacent) + return true + }) + + return nodes +} + type DirectedGraph interface { - AddEdge(edge, start, end uint64) NumNodes() uint64 - Nodes() cardinality.Duplex[uint64] EachNode(delegate func(node uint64) bool) - Degrees(node uint64, direction graph.Direction) uint64 - AdjacentNodes(node uint64, direction graph.Direction) []uint64 EachAdjacentNode(node uint64, direction graph.Direction, delegate func(adjacent uint64) bool) } +type MutableDirectedGraph interface { + DirectedGraph + + AddNode(node uint64) + AddEdge(start, end uint64) +} + func Dimensions(digraph DirectedGraph, direction graph.Direction) (uint64, uint64) { var largestRow uint64 = 0 digraph.EachNode(func(node uint64) bool { - if degrees := digraph.Degrees(node, direction); degrees > largestRow { + if degrees := Degrees(digraph, node, direction); degrees > largestRow { largestRow = degrees } return true }) - return digraph.Nodes().Cardinality(), largestRow + return digraph.NumNodes(), largestRow } func BFSTree(digraph DirectedGraph, nodeID uint64, direction graph.Direction) []PathTerminal { diff --git a/container/fetch.go b/container/fetch.go index 7eab94e..4a1a71e 100644 --- a/container/fetch.go +++ b/container/fetch.go @@ 
-3,6 +3,7 @@ package container import ( "context" + "github.com/specterops/dawgs/cardinality" "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/database" "github.com/specterops/dawgs/graph" @@ -14,13 +15,9 @@ const ( channelBufferLen = 4096 ) -type anonymousEdge struct { - EdgeID uint64 - StartID uint64 - EndID uint64 -} - func FetchAdjacencyGraph(ctx context.Context, graphDB database.Instance, relationshipFilter cypher.SyntaxNode) (DirectedGraph, error) { + defer util.SLogMeasure("FetchAdjacencyGraph")() + digraph := NewAdjacencyMapGraph() return digraph, graphDB.Session(ctx, func(ctx context.Context, driver database.Driver) error { @@ -31,7 +28,6 @@ func FetchAdjacencyGraph(ctx context.Context, graphDB database.Instance, relatio } builder.Return( - query.Relationship().ID(), query.Start().ID(), query.End().ID(), ) @@ -44,16 +40,15 @@ func FetchAdjacencyGraph(ctx context.Context, graphDB database.Instance, relatio for result.HasNext(ctx) { var ( - edgeID uint64 startID uint64 endID uint64 ) - if err := result.Scan(&edgeID, &startID, &endID); err != nil { + if err := result.Scan(&startID, &endID); err != nil { return err } - digraph.AddEdge(edgeID, startID, endID) + digraph.AddEdge(startID, endID) } return result.Error() @@ -100,22 +95,36 @@ func FetchKindDatabase(ctx context.Context, graphDB database.Instance) (KindData } type TSDB struct { - Triplestore Triplestore - EdgeKinds KindMap + Store Triplestore + EdgeKinds KindMap } -func NewTSDB() TSDB { +func (s TSDB) Projection(deletedNodes, deletedEdges cardinality.Duplex[uint64]) TSDB { return TSDB{ - Triplestore: NewTriplestore(), - EdgeKinds: KindMap{}, + Store: s.Store.Projection(deletedNodes, deletedEdges), + EdgeKinds: s.EdgeKinds, } } -func FetchTriplestore(ctx context.Context, graphDB database.Instance, filter cypher.SyntaxNode) (TSDB, error) { - tsdb := TSDB{ - Triplestore: NewTriplestore(), - EdgeKinds: KindMap{}, +func NewTSDB(ts Triplestore, edgeKinds KindMap) TSDB { + 
return TSDB{ + Store: ts, + EdgeKinds: edgeKinds, } +} + +func EmptyTSDB() TSDB { + return NewTSDB(NewTriplestore(), KindMap{}) +} + +func FetchTSDB(ctx context.Context, graphDB database.Instance, filter cypher.SyntaxNode) (TSDB, error) { + var ( + store = NewTriplestore() + tsdb = TSDB{ + Store: store, + EdgeKinds: KindMap{}, + } + ) defer util.SLogMeasure("FetchTriplestore")() @@ -149,7 +158,7 @@ func FetchTriplestore(ctx context.Context, graphDB database.Instance, filter cyp return err } - tsdb.Triplestore.AddEdge(relationshipID, startID, endID) + store.AddTriple(relationshipID, startID, endID) tsdb.EdgeKinds.Add(relationshipKind, relationshipID) } diff --git a/container/threadsafe.go b/container/threadsafe.go new file mode 100644 index 0000000..e430c9b --- /dev/null +++ b/container/threadsafe.go @@ -0,0 +1,45 @@ +package container + +import ( + "sync" + + "github.com/gammazero/deque" +) + +type ThreadSafeDeque[T any] struct { + lock *sync.RWMutex + container deque.Deque[T] +} + +func (s *ThreadSafeDeque[T]) PushFront(elem T) { + s.lock.Lock() + s.container.PushFront(elem) + s.lock.Unlock() +} + +func (s *ThreadSafeDeque[T]) PushBack(elem T) { + s.lock.Lock() + s.container.PushBack(elem) + s.lock.Unlock() +} + +func (s *ThreadSafeDeque[T]) PopFront() T { + s.lock.Lock() + value := s.container.PopFront() + s.lock.Unlock() + return value +} + +func (s *ThreadSafeDeque[T]) PopBack() T { + s.lock.Lock() + value := s.container.PopBack() + s.lock.Unlock() + return value +} + +func (s *ThreadSafeDeque[T]) Len() int { + s.lock.RLock() + numElements := s.container.Len() + s.lock.RUnlock() + return numElements +} diff --git a/container/traversal.go b/container/traversal.go new file mode 100644 index 0000000..4b00b31 --- /dev/null +++ b/container/traversal.go @@ -0,0 +1,152 @@ +package container + +import ( + "github.com/gammazero/deque" + "github.com/specterops/dawgs/graph" +) + +func TSDFS(ts Triplestore, nodeID uint64, direction graph.Direction, maxDepth int, 
descentFilter func(edge Edge) bool, handler func(segment *Segment) bool) int { + var ( + traversals deque.Deque[*Segment] + numImcompletePaths = 0 + ) + + traversals.PushBack(&Segment{ + Node: nodeID, + }) + + for remainingTraversals := traversals.Len(); remainingTraversals > 0; remainingTraversals = traversals.Len() { + var ( + nextSegment = traversals.PopBack() + segmentDepth = nextSegment.Depth() + depthExceded = maxDepth > 0 && maxDepth < segmentDepth + ) + + if !depthExceded { + ts.EachAdjacentEdge(nextSegment.Node, direction, func(nextEdge Edge) bool { + if descentFilter(nextEdge) { + traversals.PushBack(&Segment{ + Node: nextEdge.Pick(direction), + Edge: nextEdge.ID, + Previous: nextSegment, + }) + } + + return true + }) + } + + if segmentDepth > 1 && remainingTraversals-1 == traversals.Len() { + if depthExceded { + numImcompletePaths += 1 + } + + if !handler(nextSegment) { + break + } + } + } + + return numImcompletePaths +} + +func TSBFS(ts Triplestore, nodeID uint64, direction graph.Direction, maxDepth int, descentFilter func(edge Edge) bool, handler func(segment *Segment) bool) int { + var ( + traversals deque.Deque[*Segment] + numImcompletePaths = 0 + ) + + traversals.PushBack(&Segment{ + Node: nodeID, + }) + + for remainingTraversals := traversals.Len(); remainingTraversals > 0; remainingTraversals = traversals.Len() { + var ( + nextSegment = traversals.PopFront() + segmentDepth = nextSegment.Depth() + depthExceded = maxDepth > 0 && maxDepth < segmentDepth + ) + + if !depthExceded { + ts.EachAdjacentEdge(nextSegment.Node, direction, func(nextEdge Edge) bool { + if descentFilter(nextEdge) { + traversals.PushBack(&Segment{ + Node: nextEdge.Pick(direction), + Edge: nextEdge.ID, + Previous: nextSegment, + }) + } + + return true + }) + } + + if segmentDepth > 1 && remainingTraversals-1 == traversals.Len() { + if depthExceded { + numImcompletePaths += 1 + } + + if !handler(nextSegment) { + break + } + } + } + + return numImcompletePaths +} + +// SSPBFS is a 
parallel stateless shortest-path breadth-first search. +func TSStatelessBFS(ts Triplestore, rootNode uint64, direction graph.Direction, maxDepth int, descentFilter func(edge Edge) (Weight, bool), terminalHandler func(terminal PathTerminal) bool, numWorkers int) int { + var ( + traversals deque.Deque[PathTerminal] + numImcompletePaths = 0 + ) + + traversals.PushBack(PathTerminal{ + Node: rootNode, + Weight: 0, + Distance: 0, + }) + + for traversals.Len() > 0 { + var ( + nextSegment = traversals.PopFront() + hasExpansions = false + depthExceded = maxDepth > 0 && maxDepth < nextSegment.Distance + ) + + if !depthExceded { + nextDistance := nextSegment.Distance + 1 + + ts.EachAdjacentEdge(nextSegment.Node, direction, func(nextEdge Edge) bool { + if weight, shouldDescend := descentFilter(nextEdge); shouldDescend { + hasExpansions = true + + if nextSegment.Distance > 0 { + weight *= nextSegment.Weight + } + + traversals.PushBack(PathTerminal{ + Node: nextEdge.Pick(direction), + Distance: nextDistance, + Weight: weight, + }) + } + + return true + }) + } + + if nextSegment.Distance >= 1 && !hasExpansions { + if depthExceded { + numImcompletePaths += 1 + } + + if !terminalHandler(nextSegment) { + break + } + } + } + + return numImcompletePaths +} diff --git a/container/triplestore.go b/container/triplestore.go index 112e0dd..53e4c58 100644 --- a/container/triplestore.go +++ b/container/triplestore.go @@ -1,9 +1,6 @@ package container import ( - "sync" - - "github.com/gammazero/deque" "github.com/specterops/dawgs/cardinality" "github.com/specterops/dawgs/graph" ) @@ -26,8 +23,16 @@ type Triplestore interface { DirectedGraph NumEdges() uint64 - AdjacentEdges(node uint64, direction graph.Direction) []uint64 + EachEdge(delegate func(next Edge) bool) EachAdjacentEdge(node uint64, direction graph.Direction, delegate func(next Edge) bool) + + Projection(deletedNodes, deletedEdges cardinality.Duplex[uint64]) Triplestore +} + +type MutableTriplestore interface { + Triplestore + + 
AddTriple(edge, start, end uint64) } type triplestore struct { @@ -38,7 +43,7 @@ type triplestore struct { endIndex map[uint64]cardinality.Duplex[uint64] } -func NewTriplestore() Triplestore { +func NewTriplestore() MutableTriplestore { return &triplestore{ nodes: cardinality.NewBitmap64(), deletedEdges: cardinality.NewBitmap64(), @@ -51,19 +56,31 @@ func (s *triplestore) DeleteEdge(id uint64) { s.deletedEdges.Add(id) } +func (s *triplestore) Edges() []Edge { + return s.edges +} + func (s *triplestore) NumNodes() uint64 { return s.nodes.Cardinality() } -func (s *triplestore) Nodes() cardinality.Duplex[uint64] { - return s.nodes.Clone() +func (s *triplestore) AddNode(node uint64) { + s.nodes.Add(node) } func (s *triplestore) EachNode(delegate func(node uint64) bool) { s.nodes.Each(delegate) } -func (s *triplestore) AddEdge(edge, start, end uint64) { +func (s *triplestore) EachEdge(delegate func(edge Edge) bool) { + for _, nextEdge := range s.edges { + if !delegate(nextEdge) { + break + } + } +} + +func (s *triplestore) AddTriple(edge, start, end uint64) { s.edges = append(s.edges, Edge{ ID: edge, Start: start, @@ -184,186 +201,96 @@ func (s *triplestore) EachAdjacentEdge(node uint64, direction graph.Direction, d }) } -func TSDFS(ts Triplestore, nodeID uint64, direction graph.Direction, maxDepth int, descentFilter func(edge Edge) bool, handler func(segment *Segment) bool) int { - var ( - traversals deque.Deque[*Segment] - numImcompletePaths = 0 - ) - - traversals.PushBack(&Segment{ - Node: nodeID, - }) - - for remainingTraversals := traversals.Len(); remainingTraversals > 0; remainingTraversals = traversals.Len() { - var ( - nextSegment = traversals.PopBack() - segmentDepth = nextSegment.Depth() - depthExceded = maxDepth > 0 && maxDepth < segmentDepth - ) - - if !depthExceded { - ts.EachAdjacentEdge(nextSegment.Node, direction, func(nextEdge Edge) bool { - if descentFilter(nextEdge) { - traversals.PushBack(&Segment{ - Node: nextEdge.Pick(direction), - Edge: 
nextEdge.ID, - Previous: nextSegment, - }) - } - - return true - }) - } - - if segmentDepth > 1 && remainingTraversals-1 == traversals.Len() { - if depthExceded { - numImcompletePaths += 1 - } - - if !handler(nextSegment) { - break - } - } +func (s *triplestore) Projection(deletedNodes, deletedEdges cardinality.Duplex[uint64]) Triplestore { + return &triplestoreProjection{ + origin: s, + deletedNodes: deletedNodes, + deletedEdges: deletedEdges, } +} - return numImcompletePaths +type triplestoreProjection struct { + origin *triplestore + deletedNodes cardinality.Duplex[uint64] + deletedEdges cardinality.Duplex[uint64] } -func TSBFS(ts Triplestore, nodeID uint64, direction graph.Direction, maxDepth int, descentFilter func(edge Edge) bool, handler func(segment *Segment) bool) int { +func (s *triplestoreProjection) Projection(deletedNodes, deletedEdges cardinality.Duplex[uint64]) Triplestore { var ( - traversals deque.Deque[*Segment] - numImcompletePaths = 0 + allDeletedNodes = s.deletedNodes.Clone() + allDeletedEdges = s.deletedEdges.Clone() ) - traversals.PushBack(&Segment{ - Node: nodeID, - }) + allDeletedNodes.Or(deletedNodes) + allDeletedEdges.Or(deletedEdges) - for remainingTraversals := traversals.Len(); remainingTraversals > 0; remainingTraversals = traversals.Len() { - var ( - nextSegment = traversals.PopFront() - segmentDepth = nextSegment.Depth() - depthExceded = maxDepth > 0 && maxDepth < segmentDepth - ) - - if !depthExceded { - ts.EachAdjacentEdge(nextSegment.Node, direction, func(nextEdge Edge) bool { - if descentFilter(nextEdge) { - traversals.PushBack(&Segment{ - Node: nextEdge.Pick(direction), - Edge: nextEdge.ID, - Previous: nextSegment, - }) - } - - return true - }) - } + return &triplestoreProjection{ + origin: s.origin, + deletedNodes: allDeletedNodes, + deletedEdges: allDeletedEdges, + } +} - if segmentDepth > 1 && remainingTraversals-1 == traversals.Len() { - if depthExceded { - numImcompletePaths += 1 - } +func (s *triplestoreProjection) 
NumNodes() uint64 { + count := uint64(0) - if !handler(nextSegment) { - break - } + s.origin.EachNode(func(value uint64) bool { + if !s.deletedNodes.Contains(value) { + count += 1 } - } - - return numImcompletePaths -} -type ThreadSafeDeque[T any] struct { - lock *sync.RWMutex - container deque.Deque[T] -} + return true + }) -func (s *ThreadSafeDeque[T]) PushFront(elem T) { - s.lock.Lock() - s.container.PushFront(elem) - s.lock.Unlock() + return count } -func (s *ThreadSafeDeque[T]) PushBack(elem T) { - s.lock.Lock() - s.container.PushBack(elem) - s.lock.Unlock() -} +func (s *triplestoreProjection) NumEdges() uint64 { + count := uint64(0) -func (s *ThreadSafeDeque[T]) PopFront() T { - s.lock.Lock() - value := s.container.PopFront() - s.lock.Unlock() - return value -} + s.origin.EachEdge(func(next Edge) bool { + if !s.deletedEdges.Contains(next.ID) { + count += 1 + } -func (s *ThreadSafeDeque[T]) PopBack() T { - s.lock.Lock() - value := s.container.PopBack() - s.lock.Unlock() - return value -} + return true + }) -func (s *ThreadSafeDeque[T]) Len() int { - s.lock.RLock() - numElements := s.container.Len() - s.lock.RUnlock() - return numElements + return count } -// SSPBFS is a parallel stateless shortest-path breadth-first search. 
-func TSStatelessBFS(ts Triplestore, rootNode uint64, direction graph.Direction, maxDepth int, descentFilter func(edge Edge) (Weight, bool), terminalHandler func(terminal PathTerminal) bool, numWorkers int) int { - var ( - traversals deque.Deque[PathTerminal] - numImcompletePaths = 0 - ) +func (s *triplestoreProjection) EachNode(delegate func(node uint64) bool) { + s.origin.EachNode(func(node uint64) bool { + if !s.deletedNodes.Contains(node) { + return delegate(node) } - traversals.PushBack(PathTerminal{ - Node: rootNode, - Weight: 0, - Distance: 0, + return true }) +} - for traversals.Len() > 0 { - var ( - nextSegment = traversals.PopFront() - hasExpansions = false - depthExceded = maxDepth > 0 && maxDepth < nextSegment.Distance - ) - - if !depthExceded { - nextDistance := nextSegment.Distance + 1 - - ts.EachAdjacentEdge(nextSegment.Node, direction, func(nextEdge Edge) bool { - if weight, shouldDescend := descentFilter(nextEdge); shouldDescend { - hasExpansions = true - - if nextSegment.Distance > 0 { - weight *= nextSegment.Weight - } - - traversals.PushBack(PathTerminal{ - Node: nextEdge.Pick(direction), - Distance: nextDistance, - Weight: weight, - }) - } - - return true - }) +func (s *triplestoreProjection) EachEdge(delegate func(next Edge) bool) { + s.origin.EachEdge(func(next Edge) bool { + if !s.deletedEdges.Contains(next.ID) && !s.deletedNodes.Contains(next.Start) && !s.deletedNodes.Contains(next.End) { + return delegate(next) } - if nextSegment.Distance >= 1 && !hasExpansions { - if depthExceded { - numImcompletePaths += 1 - } + return true + }) +} - if !terminalHandler(nextSegment) { - break - } +func (s *triplestoreProjection) EachAdjacentEdge(node uint64, direction graph.Direction, delegate func(next Edge) bool) { + s.origin.EachAdjacentEdge(node, direction, func(next Edge) bool { + if !s.deletedEdges.Contains(next.ID) && !s.deletedNodes.Contains(next.Start) && !s.deletedNodes.Contains(next.End) { + return delegate(next) } - } - return 
numImcompletePaths + return true + }) +} + +func (s *triplestoreProjection) EachAdjacentNode(node uint64, direction graph.Direction, delegate func(adjacent uint64) bool) { + s.EachAdjacentEdge(node, direction, func(next Edge) bool { + return delegate(next.Pick(direction)) + }) }