diff --git a/study.go b/study.go index 87bb18b4..4b199850 100644 --- a/study.go +++ b/study.go @@ -180,52 +180,6 @@ func (s *Study) WithContext(ctx context.Context) { s.ctx = ctx } -func (s *Study) callRelativeSampler(trialID int) ( - map[string]interface{}, - map[string]float64, - error, -) { - if s.RelativeSampler == nil { - return nil, nil, nil - } - - frozen, err := s.Storage.GetTrial(trialID) - if err != nil { - return nil, nil, err - } - - var searchSpace map[string]interface{} - if s.definedSearchSpace != nil { - searchSpace = s.definedSearchSpace - } else { - searchSpace, err = IntersectionSearchSpace(s) - if err != nil { - return nil, nil, err - } - } - if searchSpace == nil { - return nil, nil, nil - } - - for paramName := range searchSpace { - distribution := searchSpace[paramName] - if yes, _ := DistributionIsSingle(distribution); yes { - delete(searchSpace, paramName) - } - } - - relativeParams, err := s.RelativeSampler.SampleRelative(s, frozen, searchSpace) - if err == ErrUnsupportedSearchSpace { - s.logger.Warn("Your objective function contains unsupported search space for RelativeSampler.", - fmt.Sprintf("trialID=%d", trialID), - fmt.Sprintf("searchSpace=%#v", searchSpace)) - return nil, nil, nil - } else if err != nil { - return nil, nil, err - } - return searchSpace, relativeParams, nil -} - func (s *Study) runTrial(objective FuncObjective) (int, error) { trialID, err := s.popWaitingTrialID() if err != nil { @@ -241,19 +195,18 @@ func (s *Study) runTrial(objective FuncObjective) (int, error) { return -1, errCreateNewTrial } } - searchSpace, relativeParams, err := s.callRelativeSampler(trialID) + + trial := Trial{ + Study: s, + ID: trialID, + } + err = trial.CallRelativeSampler() if err != nil { s.logger.Error("failed to call relative sampler", fmt.Sprintf("err=%s", err)) return -1, err } - trial := Trial{ - Study: s, - ID: trialID, - relativeParams: relativeParams, - relativeSearchSpace: searchSpace, - } evaluation, objerr := objective(trial) var state TrialState if objerr == ErrTrialPruned { diff --git a/trial.go b/trial.go index 9bab3253..a65ec932 100644 --- a/trial.go +++ b/trial.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "reflect" ) @@ -97,6 +98,57 @@ func (t *Trial) isFixedParam(name string, distribution interface{}) (float64, bo return internalParam, true, nil } +// CallRelativeSampler should be called before evaluate an objective function only 1 time. +// Please note that this method is public for third party libraries like "Kubeflow/Katib". +// Goptuna users SHOULD NOT call this method. +func (t *Trial) CallRelativeSampler() error { + if t.Study.RelativeSampler == nil { + return nil + } + + var err error + var searchSpace map[string]interface{} + if t.Study.definedSearchSpace != nil { + searchSpace = t.Study.definedSearchSpace + } else { + searchSpace, err = IntersectionSearchSpace(t.Study) + if err != nil { + return err + } + } + if searchSpace == nil { + return nil + } + + relativeSearchSpace := make(map[string]interface{}, len(searchSpace)) + for paramName := range searchSpace { + distribution := searchSpace[paramName] + if yes, _ := DistributionIsSingle(distribution); yes { + continue + } + relativeSearchSpace[paramName] = distribution + } + + frozen, err := t.Study.Storage.GetTrial(t.ID) + if err != nil { + return err + } + + relativeParams, err := t.Study.RelativeSampler.SampleRelative(t.Study, frozen, searchSpace) + if err == ErrUnsupportedSearchSpace { + t.Study.logger.Warn("Your objective function contains unsupported search space for RelativeSampler.", + fmt.Sprintf("trialID=%d", t.ID), + fmt.Sprintf("searchSpace=%#v", searchSpace)) + return nil + } else if err != nil { + return err + } + + t.relativeSearchSpace = searchSpace + t.relativeParams = relativeParams + return nil +} + func (t *Trial) isRelativeParam(name string, distribution interface{}) bool { expected, ok := t.relativeSearchSpace[name] if !ok {