Skip to content

Commit c0bad78

Browse files
author
Connor Baker
committed
cudaPackages: tidy, add cudnn back (causes infinite recursion due to use of flags attribute?)
1 parent 7d0f702 commit c0bad78

15 files changed

Lines changed: 279 additions & 263 deletions

File tree

pkgs/development/cuda-modules/backendStdenv.nix

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
1-
{ lib
2-
, nixpkgsCompatibleLibstdcxx
3-
, nvccCompatibleCC
4-
, overrideCC
5-
, stdenv
6-
, wrapCCWith
7-
}:
1+
{
2+
lib,
3+
nvccCompatibilities,
4+
cudaVersion,
5+
buildPackages,
6+
overrideCC,
7+
stdenv,
8+
wrapCCWith,
9+
}: let
10+
gccMajorVersion = nvccCompatibilities.${cudaVersion}.gccMaxMajorVersion;
11+
# We use buildPackages (= pkgsBuildHost) because we look for a gcc that
12+
# runs on our build platform, and that produces executables for the host
13+
# platform (= platform on which we deploy and run the downstream packages).
14+
# The target platform of buildPackages.gcc is our host platform, so its
15+
# .lib output should be the libstdc++ we want to be writing in the runpaths
16+
# Cf. https://github.com/NixOS/nixpkgs/pull/225661#discussion_r1164564576
17+
nixpkgsCompatibleLibstdcxx = buildPackages.gcc.cc.lib;
18+
nvccCompatibleCC = buildPackages."gcc${gccMajorVersion}".cc;
819

9-
let
10-
cc = wrapCCWith
20+
cc =
21+
wrapCCWith
1122
{
1223
cc = nvccCompatibleCC;
1324

@@ -26,8 +37,7 @@ let
2637
};
2738
assertCondition = true;
2839
in
29-
lib.extendDerivation
40+
lib.extendDerivation
3041
assertCondition
3142
passthruExtra
3243
cudaStdenv
33-

pkgs/development/cuda-modules/cuda/overrides.nix

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,13 @@
11
final: prev: let
2-
inherit (prev) lib pkgs;
2+
inherit (prev) pkgs;
3+
inherit (prev.lib) attrsets lists strings;
34
# cudaVersionOlder : Version -> Boolean
4-
cudaVersionOlder = lib.versionOlder final.cudaVersion;
5+
cudaVersionOlder = strings.versionOlder final.cudaVersion;
56
# cudaVersionAtLeast : Version -> Boolean
6-
cudaVersionAtLeast = lib.versionAtLeast final.cudaVersion;
7-
# cudaVersionAtMost : Version -> Boolean
8-
cudaVersionAtMost = flip versionAtLeast cudaVersion;
9-
# cudaVersionBounded : Version -> Version -> Boolean
10-
# NOTE: This is inclusive on both ends.
11-
cudaVersionBounded = min: max: cudaVersionAtLeast min && cudaVersionAtMost max;
12-
7+
cudaVersionAtLeast = strings.versionAtLeast final.cudaVersion;
138
inherit (builtins) hasAttr;
14-
inherit (final) cudaVersion addBuildInputs addAutoPatchelfIgnoreMissingDeps;
15-
inherit (prev.lib.attrsets) filterAttrs optionalAttrs;
16-
inherit (prev.lib.lists) optionals;
17-
inherit (prev.lib.strings) versionAtLeast;
18-
inherit (prev.lib.trivial) flip pipe;
19-
inherit (prev.pkgs.stdenv.hostPlatform) isx86_64 isAarch64 isPower64;
20-
inherit
21-
(prev.pkgs)
22-
pkgsBuildHost # for nativeBuildInputs
23-
pkgsHostTarget # good ol' pkgs, for buildInputs
24-
;
259
in
26-
filterAttrs (attr: _: (hasAttr attr prev)) {
10+
attrsets.filterAttrs (attr: _: (hasAttr attr prev)) {
2711
### Overrides to fix the components of cudatoolkit-redist
2812

2913
# Attributes that don't exist in the previous set are removed.
@@ -41,7 +25,7 @@ in
4125
autoPatchelfIgnoreMissingDeps =
4226
["libcuda.so.1"]
4327
# Before 12.0 libcufile depends on itself for some reason.
44-
++ lib.optionals (cudaVersionOlder "12.0") [
28+
++ lists.optionals (cudaVersionOlder "12.0") [
4529
"libcufile.so.0"
4630
];
4731
});
@@ -50,24 +34,24 @@ in
5034
# Always depends on this
5135
[final.libcublas.lib]
5236
# Dependency from 12.0 and on
53-
++ lib.optionals (cudaVersionAtLeast "12.0") [
37+
++ lists.optionals (cudaVersionAtLeast "12.0") [
5438
final.libnvjitlink.lib
5539
]
5640
# Dependency from 12.1 and on
57-
++ lib.optionals (cudaVersionAtLeast "12.1") [
41+
++ lists.optionals (cudaVersionAtLeast "12.1") [
5842
final.libcusparse.lib
5943
]
6044
);
6145

6246
libcusparse = final.addBuildInputs prev.libcusparse (
63-
lib.optionals (cudaVersionAtLeast "12.0") [
47+
lists.optionals (cudaVersionAtLeast "12.0") [
6448
final.libnvjitlink.lib
6549
]
6650
);
6751

6852
cuda_gdb = final.addBuildInputs prev.cuda_gdb (
6953
# x86_64 only needs gmp from 12.0 and on
70-
lib.optionals (cudaVersionAtLeast "12.0") [
54+
lists.optionals (cudaVersionAtLeast "12.0") [
7155
pkgs.gmp
7256
]
7357
);

pkgs/development/cuda-modules/cudnn/extension.nix

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,9 @@
22
# https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn-880/support-matrix/index.html
33
final: prev: let
44
inherit (final) callPackage;
5-
inherit (prev) cudaVersion;
6-
inherit (prev.lib) attrsets lists modules versions;
7-
inherit (prev.lib.strings) replaceStrings versionAtLeast versionOlder;
8-
9-
# Compute versioned attribute name to be used in this package set
10-
# Patch version changes should not break the build, so we only use major and minor
11-
# computeName :: String -> String
12-
computeName = version: "cudnn_${replaceStrings ["."] ["_"] (versions.majorMinor version)}";
13-
14-
# Check whether a CUDNN release supports our CUDA version
15-
# Thankfully we're able to do lexicographic comparison on the version strings
16-
# isSupported :: Release -> Bool
17-
isSupported = release:
18-
versionAtLeast cudaVersion release.minCudaVersion
19-
&& versionAtLeast release.maxCudaVersion cudaVersion;
5+
inherit (prev) cudaVersion flags;
6+
inherit (prev.pkgs) hostPlatform;
7+
inherit (prev.lib) attrsets lists modules versions strings;
208

219
evaluatedModules = modules.evalModules {
2210
modules = [
@@ -25,43 +13,55 @@ final: prev: let
2513
];
2614
};
2715

28-
cudnn_releases = evaluatedModules.config.cudnn.releases;
16+
# NOTE: Important types:
17+
# - Releases: ../modules/cudnn/releases/releases.nix
18+
# - Package: ../modules/cudnn/releases/package.nix
2919

30-
# useCudatoolkitRunfile :: Bool
31-
useCudatoolkitRunfile = versionOlder cudaVersion "11.3.999";
20+
# All CUDNN releases across all platforms
21+
# See ../modules/cudnn/releases/releases.nix
22+
allCudnnReleases = evaluatedModules.config.cudnn.releases;
3223

33-
# buildCuDnnPackage :: Release -> Derivation
34-
buildCuDnnPackage = callPackage ./generic.nix {inherit useCudatoolkitRunfile;};
24+
# Compute versioned attribute name to be used in this package set
25+
# Patch version changes should not break the build, so we only use major and minor
26+
# computeName :: Package -> String
27+
computeName = package: "cudnn_${strings.replaceStrings ["."] ["_"] (versions.majorMinor package.version)}";
3528

36-
# Reverse the list to have the latest release first
37-
# cudnnReleases :: List Release
38-
cudnnReleases = lists.reverseList cudnnReleaseAttrs.${cudaRedistPlatform};
29+
# Check whether a CUDNN package supports our CUDA version
30+
# isSupported :: Package -> Bool
31+
isSupported = package:
32+
strings.versionAtLeast cudaVersion package.minCudaVersion
33+
&& strings.versionAtLeast package.maxCudaVersion cudaVersion;
3934

40-
# Check whether a CUDNN release supports our CUDA version
41-
# supportedReleases :: List Release
42-
supportedReleases = builtins.filter isSupported cudnnReleases;
35+
# Get all of the packages for our given platform.
36+
redistArch = flags.getRedistArch hostPlatform.system;
4337

44-
# Function to transform our releases into build attributes
45-
# toBuildAttrs :: Release -> { name: String, value: Derivation }
46-
toBuildAttrs = release: {
47-
name = computeName release.version;
48-
value = buildCuDnnPackage release;
49-
};
38+
# All the packages for our platform.
39+
# cudnnPackages :: List (AttrSet Packages)
40+
cudnnPackages = builtins.filter isSupported (allCudnnReleases.${redistArch} or []);
5041

51-
# Add all supported builds as attributes
52-
# allBuilds :: AttrSet String Derivation
53-
allBuilds = builtins.listToAttrs (builtins.map toBuildAttrs supportedReleases);
42+
# newestToOldestCudnnPackages :: List (AttrSet Packages)
43+
newestToOldestCudnnPackages = lists.reverseList cudnnPackages;
5444

55-
defaultBuild = attrsets.optionalAttrs (supportedReleases != []) {
56-
cudnn = let
57-
# The latest release is the first element of the list and will be our default choice
58-
# latestReleaseName :: String
59-
latestReleaseName = computeName (builtins.head supportedReleases).version;
60-
in
61-
allBuilds.${latestReleaseName};
45+
# buildCudnnPackage :: Package -> Derivation
46+
buildCudnnPackage = package: {
47+
name = computeName package;
48+
value = callPackage ./generic.nix {
49+
inherit package;
50+
platforms = lists.map (flags.getNixSystem) (builtins.attrNames allCudnnReleases);
51+
useCudatoolkitRunfile = strings.versionOlder cudaVersion "11.3.999";
52+
};
6253
};
6354

64-
# builds :: AttrSet String Derivation
65-
builds = allBuilds // defaultBuild;
55+
# allCudnnDerivations :: AttrSet Derivation
56+
versionedCudnnDerivations = builtins.listToAttrs (lists.map buildCudnnPackage newestToOldestCudnnPackages);
57+
58+
# allCudnnDerivations :: AttrSet Derivation
59+
allCudnnDerivations = let
60+
nameOfNewest = computeName (builtins.head newestToOldestCudnnPackages);
61+
containsDefault = attrsets.optionalAttrs (versionedCudnnDerivations != {}) {
62+
cudnn = versionedCudnnDerivations.${nameOfNewest};
63+
};
64+
in
65+
versionedCudnnDerivations // containsDefault;
6666
in
67-
builds
67+
allCudnnDerivations

pkgs/development/cuda-modules/cudnn/generic.nix

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
1-
{ stdenv,
1+
{
2+
# General arguments supplied by callPackage
3+
stdenv,
24
backendStdenv,
35
lib,
46
lndir,
57
zlib,
6-
useCudatoolkitRunfile ? false,
7-
cudaVersion,
8-
cudaMajorVersion,
9-
cudatoolkit, # For cuda < 11
108
libcublas ? null, # cuda <11 doesn't ship redist packages
119
autoPatchelfHook,
1210
autoAddOpenGLRunpathHook,
1311
fetchurl,
14-
}: {
15-
version,
16-
url,
17-
hash,
18-
minCudaVersion,
19-
maxCudaVersion,
12+
cudatoolkit, # For cuda < 11
13+
cudaVersion,
14+
# Arguments supplied by the caller
15+
useCudatoolkitRunfile ? false,
16+
# See ../modules/cudnn/releases/package.nix for type of package
17+
package,
18+
# Platforms supported by the package
19+
# platforms :: List String
20+
platforms,
2021
}:
21-
assert useCudatoolkitRunfile || (libcublas != null); let
22-
inherit (lib) lists strings trivial versions;
22+
assert libcublas == null -> useCudatoolkitRunfile; let
23+
inherit (lib) lists strings trivial versions maintainers licenses meta sourceTypes;
2324

2425
# majorMinorPatch :: String -> String
2526
majorMinorPatch = (trivial.flip trivial.pipe) [
@@ -30,16 +31,16 @@ assert useCudatoolkitRunfile || (libcublas != null); let
3031

3132
# versionTriple :: String
3233
# Version with three components: major.minor.patch
33-
versionTriple = majorMinorPatch version;
34+
versionTriple = majorMinorPatch package.version;
3435
in
3536
backendStdenv.mkDerivation {
36-
pname = "cudatoolkit-${cudaMajorVersion}-cudnn";
37-
version = versionTriple;
37+
pname = "cudnn";
38+
inherit (package) version;
3839
strictDeps = true;
3940
outputs = ["out" "lib" "static" "dev"];
4041

4142
src = fetchurl {
42-
inherit url hash;
43+
inherit (package) url hash;
4344
};
4445

4546
# We do need some other phases, like configurePhase, so the multiple-output setup hook works.
@@ -53,17 +54,20 @@ in
5354
];
5455

5556
# Used by autoPatchelfHook
56-
buildInputs = [
57-
# Note this libstdc++ isn't from the (possibly older) nvcc-compatible
58-
# stdenv, but from the (newer) stdenv that the rest of nixpkgs uses
59-
stdenv.cc.cc.lib
60-
61-
zlib
62-
] ++ lists.optionals useCudatoolkitRunfile [
63-
cudatoolkit
64-
] ++ lists.optionals (!useCudatoolkitRunfile) [
65-
libcublas.lib
66-
];
57+
buildInputs =
58+
[
59+
# Note this libstdc++ isn't from the (possibly older) nvcc-compatible
60+
# stdenv, but from the (newer) stdenv that the rest of nixpkgs uses
61+
stdenv.cc.cc.lib
62+
63+
zlib
64+
]
65+
++ lists.optionals useCudatoolkitRunfile [
66+
cudatoolkit
67+
]
68+
++ lists.optionals (!useCudatoolkitRunfile) [
69+
libcublas.lib
70+
];
6771

6872
# We used to patch Runpath here, but now we use autoPatchelfHook
6973
#
@@ -74,18 +78,17 @@ in
7478
# output.
7579
# Note that moveToOutput operates on all outputs:
7680
# https://github.com/NixOS/nixpkgs/blob/2920b6fc16a9ed5d51429e94238b28306ceda79e/pkgs/build-support/setup-hooks/multiple-outputs.sh#L105-L107
77-
installPhase =
78-
''
79-
runHook preInstall
81+
installPhase = ''
82+
runHook preInstall
8083
81-
mkdir -p "$out"
82-
mv * "$out"
83-
moveToOutput "lib64" "$lib"
84-
moveToOutput "lib" "$lib"
85-
moveToOutput "**/*.a" "$static"
84+
mkdir -p "$out"
85+
mv * "$out"
86+
moveToOutput "lib64" "$lib"
87+
moveToOutput "lib" "$lib"
88+
moveToOutput "**/*.a" "$static"
8689
87-
runHook postInstall
88-
'';
90+
runHook postInstall
91+
'';
8992

9093
# Without --add-needed autoPatchelf forgets $ORIGIN on cuda>=8.0.5.
9194
postFixup = strings.optionalString (strings.versionAtLeast versionTriple "8.0.5") ''
@@ -109,9 +112,9 @@ in
109112
# found: <customPhaseName>".
110113
postPatchelf = ''
111114
mkdir -p "$out"
112-
${lib.meta.getExe lndir} "$lib" "$out"
113-
${lib.meta.getExe lndir} "$static" "$out"
114-
${lib.meta.getExe lndir} "$dev" "$out"
115+
${meta.getExe lndir} "$lib" "$out"
116+
${meta.getExe lndir} "$static" "$out"
117+
${meta.getExe lndir} "$dev" "$out"
115118
'';
116119

117120
passthru = {
@@ -132,22 +135,22 @@ in
132135
# unqualified (that is, without an explicit output).
133136
outputSpecified = true;
134137

135-
meta = with lib; {
138+
meta = {
136139
# Check that the cudatoolkit version satisfies our min/max constraints (both
137140
# inclusive). We mark the package as broken if it fails to satisfies the
138141
# official version constraints (as recorded in default.nix). In some cases
139142
# you _may_ be able to smudge version constraints, just know that you're
140143
# embarking into unknown and unsupported territory when doing so.
141144
broken =
142-
strings.versionOlder cudaVersion minCudaVersion
143-
|| strings.versionOlder maxCudaVersion cudaVersion;
145+
strings.versionOlder cudaVersion package.minCudaVersion
146+
|| strings.versionOlder package.maxCudaVersion cudaVersion;
144147
description = "NVIDIA CUDA Deep Neural Network library (cuDNN)";
145148
homepage = "https://developer.nvidia.com/cudnn";
146149
sourceProvenance = with sourceTypes; [binaryNativeCode];
147150
# TODO: consider marking unfreRedistributable when not using runfile
148151
license = licenses.unfree;
149-
platforms = ["x86_64-linux"];
150-
maintainers = with maintainers; [mdaiter samuela];
152+
inherit platforms;
153+
maintainers = with maintainers; [mdaiter samuela connorbaker];
151154
# Force the use of the default, fat output by default (even though `dev` exists, which
152155
# causes Nix to prefer that output over the others if outputSpecified isn't set).
153156
outputsToInstall = ["out"];

0 commit comments

Comments
 (0)