@@ -85,6 +85,50 @@ func get(gopath, repodir, repo, rev string) error {
8585 return err
8686}
8787
88+ // getModuleDir returns the path of the directory containing a module for the
89+ // given GOPATH and repository dir values.
90+ func getModuleDir (gopath , repodir , module string ) (string , error ) {
91+ cmd := exec .Command ("go" , "list" , "-f" , "{{.Dir}}" , module )
92+ cmd .Dir = repodir
93+ cmd .Stderr = os .Stderr
94+ cmd .Env = append ([]string {
95+ "GOPATH=" + gopath ,
96+ }, passthroughEnv ()... )
97+ out , err := cmd .Output ()
98+ if err != nil {
99+ return "" , fmt .Errorf ("go list: args: %v; error: %w" , cmd .Args , err )
100+ }
101+ return string (bytes .TrimSpace (out )), nil
102+ }
103+
104+ // getDirectDependencies returns a set of all the direct dependencies of a
105+ // module for the given GOPATH and repository dir values. It first finds the
106+ // directory that contains this module, then uses go list in this directory
107+ // to get its direct dependencies.
108+ func getDirectDependencies (gopath , repodir , module string ) (map [string ]bool , error ) {
109+ dir , err := getModuleDir (gopath , repodir , module )
110+ if err != nil {
111+ return nil , fmt .Errorf ("get module dir: %w" , err )
112+ }
113+ cmd := exec .Command ("go" , "list" , "-m" , "-f" , "{{if not .Indirect}}{{.Path}}{{end}}" , "all" )
114+ cmd .Dir = dir
115+ cmd .Stderr = os .Stderr
116+ cmd .Env = append ([]string {
117+ "GOPATH=" + gopath ,
118+ }, passthroughEnv ()... )
119+ out , err := cmd .Output ()
120+ if err != nil {
121+ return nil , fmt .Errorf ("go list: args: %v; error: %w" , cmd .Args , err )
122+ }
123+ out = bytes .TrimRight (out , "\n " )
124+ lines := strings .Split (string (out ), "\n " )
125+ deps := make (map [string ]bool , len (lines ))
126+ for _ , line := range lines {
127+ deps [line ] = true
128+ }
129+ return deps , nil
130+ }
131+
88132func removeVendor (gopath string ) (found bool , _ error ) {
89133 err := filepath .Walk (gopath , func (path string , info os.FileInfo , err error ) error {
90134 if err != nil {
@@ -203,6 +247,12 @@ func estimate(importpath, revision string) error {
203247 return fmt .Errorf ("go mod graph: args: %v; error: %w" , cmd .Args , err )
204248 }
205249
250+ // Get direct dependencies, to filter out indirect ones from go mod graph output
251+ directDeps , err := getDirectDependencies (gopath , repodir , importpath )
252+ if err != nil {
253+ return fmt .Errorf ("get direct dependencies: %w" , err )
254+ }
255+
206256 // Retrieve already-packaged ones
207257 golangBinaries , err := getGolangBinaries ()
208258 if err != nil {
@@ -227,14 +277,29 @@ func estimate(importpath, revision string) error {
227277 // imported it, separated by a single space. The module names
228278 // can have a version information delimited by the @ character
229279 src , dep , _ := strings .Cut (line , " " )
280+ // Get the module names without their version, as we do not use
281+ // this information.
230282 // The root module is the only one that does not have a version
231283 // indication with @ in the output of go mod graph. We use this
232284 // to filter out the depencencies of the "dummymod" module.
233- if mod , _ , found := strings .Cut (src , "@" ); ! found {
285+ dep , _ , _ = strings .Cut (dep , "@" )
286+ src , _ , found := strings .Cut (src , "@" )
287+ if ! found {
234288 continue
235- } else if mod == importpath || strings .HasPrefix (mod , importpath + "/" ) {
289+ }
290+ // Due to importing all packages of the estimated module in a
291+ // dummy one, some modules can depend on submodules of the
292+ // estimated one. We do as if they are dependencies of the
293+ // root one.
294+ if strings .HasPrefix (src , importpath + "/" ) {
236295 src = importpath
237296 }
297+ // go mod graph also lists indirect dependencies as dependencies
298+ // of the current module, so we filter them out. They will still
299+ // appear later.
300+ if src == importpath && ! directDeps [dep ] {
301+ continue
302+ }
238303 depNode , ok := nodes [dep ]
239304 if ! ok {
240305 depNode = & Node {name : dep }
@@ -255,10 +320,7 @@ func estimate(importpath, revision string) error {
255320 needed := make (map [string ]int )
256321 var visit func (n * Node , indent int )
257322 visit = func (n * Node , indent int ) {
258- // Get the module name without its version, as go mod graph
259- // can return multiple times the same module with different
260- // versions.
261- mod , _ , _ := strings .Cut (n .name , "@" )
323+ mod := n .name
262324 count , isNeeded := needed [mod ]
263325 if isNeeded {
264326 count ++
0 commit comments