@@ -24,6 +24,7 @@ import (
2424 "go/build"
2525 "go/parser"
2626 "go/token"
27+ "io/ioutil"
2728 "log"
2829 "path"
2930 "path/filepath"
@@ -48,19 +49,10 @@ func sourceMode(source string) (*model.Package, error) {
4849 return nil , fmt .Errorf ("failed getting source directory: %v" , err )
4950 }
5051
51- cfg := & packages.Config {Mode : packages .LoadFiles , Tests : true , Dir : srcDir }
52- pkgs , err := packages .Load (cfg , "file=" + source )
52+ packageImport , err := parsePackageImport (source , srcDir )
5353 if err != nil {
5454 return nil , err
5555 }
56- if packages .PrintErrors (pkgs ) > 0 || len (pkgs ) == 0 {
57- return nil , errors .New ("loading package failed" )
58- }
59-
60- packageImport := pkgs [0 ].PkgPath
61-
62- // It is illegal to import a _test package.
63- packageImport = strings .TrimSuffix (packageImport , "_test" )
6456
6557 fs := token .NewFileSet ()
6658 file , err := parser .ParseFile (fs , source , nil , 0 )
@@ -519,3 +511,46 @@ func isVariadic(f *ast.FuncType) bool {
519511 _ , ok := f .Params .List [nargs - 1 ].Type .(* ast.Ellipsis )
520512 return ok
521513}
514+
515+ // packageNameOfDir get package import path via dir
516+ func packageNameOfDir (srcDir string ) (string , error ) {
517+ files , err := ioutil .ReadDir (srcDir )
518+ if err != nil {
519+ log .Fatal (err )
520+ }
521+
522+ var goFilePath string
523+ for _ , file := range files {
524+ if ! file .IsDir () && strings .HasSuffix (file .Name (), ".go" ) {
525+ goFilePath = file .Name ()
526+ break
527+ }
528+ }
529+ if goFilePath == "" {
530+ return "" , fmt .Errorf ("go source file not found %s" , srcDir )
531+ }
532+
533+ packageImport , err := parsePackageImport (goFilePath , srcDir )
534+ if err != nil {
535+ return "" , err
536+ }
537+ return packageImport , nil
538+ }
539+
540+ // parseImportPackage get package import path via source file
541+ func parsePackageImport (source , srcDir string ) (string , error ) {
542+ cfg := & packages.Config {Mode : packages .LoadFiles , Tests : true , Dir : srcDir }
543+ pkgs , err := packages .Load (cfg , "file=" + source )
544+ if err != nil {
545+ return "" , err
546+ }
547+ if packages .PrintErrors (pkgs ) > 0 || len (pkgs ) == 0 {
548+ return "" , errors .New ("loading package failed" )
549+ }
550+
551+ packageImport := pkgs [0 ].PkgPath
552+
553+ // It is illegal to import a _test package.
554+ packageImport = strings .TrimSuffix (packageImport , "_test" )
555+ return packageImport , nil
556+ }
0 commit comments