From 1fd65a09ef4e098c8e96d5ecfbae6cdd2ffdd7ed Mon Sep 17 00:00:00 2001 From: c-bata Date: Mon, 20 Apr 2020 04:14:40 +0900 Subject: [PATCH] Fix a bug when set attrs multiple times --- rdb/storage.go | 32 ++++++++++++++++++++++++-------- rdb/storage_test.go | 12 ++++++++++++ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/rdb/storage.go b/rdb/storage.go index a8ba4d58..4539c515 100644 --- a/rdb/storage.go +++ b/rdb/storage.go @@ -62,20 +62,28 @@ func (s *Storage) SetStudyDirection(studyID int, direction goptuna.StudyDirectio // SetStudyUserAttr to store the value for the user. func (s *Storage) SetStudyUserAttr(studyID int, key string, value string) error { - return s.db.Create(&studyUserAttributeModel{ + var result studyUserAttributeModel + return s.db.Where(&studyUserAttributeModel{ + UserAttributeReferStudy: studyID, + Key: key, + }).Assign(&studyUserAttributeModel{ UserAttributeReferStudy: studyID, Key: key, ValueJSON: encodeAttrValue(value), - }).Error + }).FirstOrCreate(&result).Error } // SetStudySystemAttr to store the value for the system. func (s *Storage) SetStudySystemAttr(studyID int, key string, value string) error { - return s.db.Create(&studySystemAttributeModel{ + var result studySystemAttributeModel + return s.db.Where(&studySystemAttributeModel{ + SystemAttributeReferStudy: studyID, + Key: key, + }).Assign(&studySystemAttributeModel{ SystemAttributeReferStudy: studyID, Key: key, ValueJSON: encodeAttrValue(value), - }).Error + }).FirstOrCreate(&result).Error } // GetStudyIDFromName return the study id from study name. @@ -550,20 +558,28 @@ func (s *Storage) SetTrialState(trialID int, state goptuna.TrialState) error { // SetTrialUserAttr to store the value for the user. func (s *Storage) SetTrialUserAttr(trialID int, key string, value string) error { - return s.db.Create(&trialUserAttributeModel{ + var result trialUserAttributeModel + return s.db.Where(&trialUserAttributeModel{ + UserAttributeReferTrial: trialID, + Key: key, + }).Assign(&trialUserAttributeModel{ UserAttributeReferTrial: trialID, Key: key, ValueJSON: encodeAttrValue(value), - }).Error + }).FirstOrCreate(&result).Error } // SetTrialSystemAttr to store the value for the system. func (s *Storage) SetTrialSystemAttr(trialID int, key string, value string) error { - return s.db.Create(&trialSystemAttributeModel{ + var result trialSystemAttributeModel + return s.db.Where(&trialSystemAttributeModel{ + SystemAttributeReferTrial: trialID, + Key: key, + }).Assign(&trialSystemAttributeModel{ SystemAttributeReferTrial: trialID, Key: key, ValueJSON: encodeAttrValue(value), - }).Error + }).FirstOrCreate(&result).Error } // GetTrialNumberFromID returns the trial's number. diff --git a/rdb/storage_test.go b/rdb/storage_test.go index efb5382d..903b446a 100644 --- a/rdb/storage_test.go +++ b/rdb/storage_test.go @@ -262,6 +262,12 @@ func TestStorage_TrialUserAttrs(t *testing.T) { if !reflect.DeepEqual(got, want) { t.Errorf("want %#v, but got %#v", want, got) } + + err = storage.SetTrialUserAttr(trialID, "key", "value") + if err != nil { + t.Errorf("error: %v != nil", err) + return + } } func TestStorage_TrialSystemAttrs(t *testing.T) { @@ -298,6 +304,12 @@ func TestStorage_TrialSystemAttrs(t *testing.T) { if v, ok := got["key"]; !ok || v != "value" { t.Errorf("want %#v, but got %v %v", "value", ok, got) } + + err = storage.SetTrialSystemAttr(trialID, "key", "value") + if err != nil { + t.Errorf("error: %v != nil", err) + return + } } func TestStorage_GetAllTrials(t *testing.T) {