diff --git a/pkg/repositories/gormimpl/tag.go b/pkg/repositories/gormimpl/tag.go
index adb9cbbf..e26f0593 100644
--- a/pkg/repositories/gormimpl/tag.go
+++ b/pkg/repositories/gormimpl/tag.go
@@ -4,10 +4,12 @@ import (
 	"context"
 
 	"github.com/jinzhu/gorm"
+	"github.com/lyft/datacatalog/pkg/common"
 	"github.com/lyft/datacatalog/pkg/repositories/errors"
 	"github.com/lyft/datacatalog/pkg/repositories/interfaces"
 	"github.com/lyft/datacatalog/pkg/repositories/models"
 	idl_datacatalog "github.com/lyft/datacatalog/protos/gen"
+	"github.com/lyft/flytestdlib/logger"
 	"github.com/lyft/flytestdlib/promutils"
 )
 
@@ -25,14 +27,111 @@ func NewTagRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope pro
 	}
 }
 
+// Create attaches tag to the artifact identified by tag.ArtifactID.
+//
+// A tag is associated with a single artifact for each partition combination, so
+// within one transaction Create loads the artifact being tagged together with
+// its partitions, soft-deletes the tag from the artifacts that already carry it
+// in the same partitions, and then restores a previously soft-deleted
+// (tag, artifact) row or inserts a new one. The transaction is rolled back and
+// an error is returned as soon as any step fails.
 func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error {
 	timer := h.repoMetrics.CreateDuration.Start(ctx)
 	defer timer.Stop()
 
-	db := h.db.Create(&tag)
+	tx := h.db.Begin()
+	if tx.Error != nil {
+		return h.errorTransformer.ToDataCatalogError(tx.Error)
+	}
 
-	if db.Error != nil {
-		return h.errorTransformer.ToDataCatalogError(db.Error)
+	// Statements are chained off tx without assigning the result back to it: gorm
+	// keeps the accumulated conditions on the handle it returns, so reusing that
+	// handle would leak this lookup's WHERE/Preload into every later statement.
+	var artifactToTag models.Artifact
+	err := tx.Preload("Partitions").Find(&artifactToTag, models.Artifact{
+		ArtifactKey: models.ArtifactKey{ArtifactID: tag.ArtifactID},
+	}).Error
+	if err != nil {
+		logger.Errorf(ctx, "Unable to find artifact to tag, rolling back, tag: [%v], err [%v]", tag, err)
+		tx.Rollback()
+		return h.errorTransformer.ToDataCatalogError(err)
+	}
+
+	// List artifacts with the same partitions and tag
+	// NOTE(review): each partition adds its own key/value equality filters; confirm
+	// applyListModelsInput combines them so multi-partition artifacts can match, and
+	// that "tag_name" resolves correctly with common.Artifact as the filter entity.
+	filters := make([]models.ModelValueFilter, 0, len(artifactToTag.Partitions)*2+1)
+	for _, partition := range artifactToTag.Partitions {
+		filters = append(filters, NewGormValueFilter(common.Partition, common.Equal, "key", partition.Key))
+		filters = append(filters, NewGormValueFilter(common.Partition, common.Equal, "value", partition.Value))
+	}
+
+	filters = append(filters, NewGormValueFilter(common.Artifact, common.Equal, "tag_name", tag.TagName))
+
+	listTaggedArtifacts := models.ListModelsInput{
+		JoinEntityToConditionMap: map[common.Entity]models.ModelJoinCondition{
+			common.Tag:       NewGormJoinCondition(common.Artifact, common.Tag),
+			common.Partition: NewGormJoinCondition(common.Artifact, common.Partition),
+		},
+		Filters: filters,
+	}
+
+	taggedQuery, err := applyListModelsInput(tx, common.Artifact, listTaggedArtifacts)
+	if err != nil {
+		tx.Rollback()
+		return err
+	}
+
+	var artifacts []models.Artifact
+	if err := taggedQuery.Find(&artifacts).Error; err != nil {
+		logger.Errorf(ctx, "Unable to find previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, err)
+		tx.Rollback()
+		return h.errorTransformer.ToDataCatalogError(err)
+	}
+
+	if len(artifacts) != 0 {
+		// Soft-delete the existing tags on the artifacts that are tagged by this tag in the partition.
+		// The rows are matched with an explicit WHERE because a slice of structs is not a
+		// condition gorm can translate into per-row predicates.
+		artifactIDs := make([]string, 0, len(artifacts))
+		for _, artifact := range artifacts {
+			artifactIDs = append(artifactIDs, artifact.ArtifactID)
+		}
+		untag := tx.Where("tag_name = ? AND artifact_id IN (?)", tag.TagName, artifactIDs).Delete(&models.Tag{})
+		if untag.Error != nil {
+			logger.Errorf(ctx, "Unable to remove tag from previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, untag.Error)
+			tx.Rollback()
+			return h.errorTransformer.ToDataCatalogError(untag.Error)
+		}
+	}
+
+	// Check if the artifact was ever previously tagged with this tag, if so undelete the record
+	var previouslyTagged models.Tag
+	lookup := tx.Unscoped().
+		Where("tag_name = ? AND artifact_id = ?", tag.TagName, tag.ArtifactID).
+		First(&previouslyTagged)
+	switch {
+	case lookup.RecordNotFound():
+		// Tag the new artifact
+		err = tx.Create(&tag).Error
+	case lookup.Error != nil:
+		err = lookup.Error
+	default:
+		// Clear deleted_at on the existing row instead of inserting a duplicate key.
+		err = tx.Unscoped().Model(&models.Tag{}).
+			Where("tag_name = ? AND artifact_id = ?", tag.TagName, tag.ArtifactID).
+			Update("deleted_at", gorm.Expr("NULL")).Error
+	}
+	if err != nil {
+		logger.Errorf(ctx, "Unable to create tag, rolling back, tag: [%v], err [%v]", tag, err)
+		tx.Rollback()
+		return h.errorTransformer.ToDataCatalogError(err)
+	}
+
+	if err := tx.Commit().Error; err != nil {
+		logger.Errorf(ctx, "Unable to commit tag, tag: [%v], err [%v]", tag, err)
+		return h.errorTransformer.ToDataCatalogError(err)
 	}
 	return nil
 }
diff --git a/pkg/repositories/gormimpl/tag_test.go b/pkg/repositories/gormimpl/tag_test.go
index 0fc15074..de1b91a6 100644
--- a/pkg/repositories/gormimpl/tag_test.go
+++ b/pkg/repositories/gormimpl/tag_test.go
@@ -50,6 +50,12 @@ func TestCreateTag(t *testing.T) {
 	GlobalMock.Logging = true
 
 	// Only match on queries that append expected filters
+	GlobalMock.NewMock().WithQuery(
+		`SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."artifact_id" = 123))`).WithReply(getDBArtifactResponse(getTestArtifact()))
+
+	GlobalMock.NewMock().WithQuery(
+		`SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123)))`).WithReply(getDBArtifactResponse(getTestArtifact()))
+
 	GlobalMock.NewMock().WithQuery(
 		`INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id","dataset_uuid") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback(
 		func(s string, values []driver.NamedValue) {
diff --git a/pkg/repositories/models/tag.go b/pkg/repositories/models/tag.go
index 057f78c3..a0ce3f2f 100644
--- a/pkg/repositories/models/tag.go
+++ b/pkg/repositories/models/tag.go
@@ -1,17 +1,17 @@
 package models
 
 type TagKey struct {
-	DatasetProject string `gorm:"primary_key"`
-	DatasetName    string `gorm:"primary_key"`
-	DatasetDomain  string `gorm:"primary_key"`
-	DatasetVersion string `gorm:"primary_key"`
+	DatasetProject string
+	DatasetName    string
+	DatasetDomain  string
+	DatasetVersion string
 	TagName        string `gorm:"primary_key"`
 }
 
 type Tag struct {
 	BaseModel
 	TagKey
-	ArtifactID  string
+	ArtifactID  string `gorm:"primary_key"`
 	DatasetUUID string `gorm:"type:uuid;index:tags_dataset_uuid_idx"`
 	Artifact    Artifact `gorm:"association_foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID;foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID"`
 }