Skip to content

Commit 9d4365d

Browse files
committed
nix: add cuda, use a symlinked toolkit instead for cmake
1 parent 51a7cf5 commit 9d4365d

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

flake.nix

+21
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@
3535
);
3636
pkgs = import nixpkgs { inherit system; };
3737
nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
38+
cudatoolkit_joined = with pkgs; symlinkJoin {
39+
# HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit
40+
# see https://github.com/NixOS/nixpkgs/issues/224291
41+
# copied from jaxlib
42+
name = "${cudaPackages.cudatoolkit.name}-merged";
43+
paths = [
44+
cudaPackages.cudatoolkit.lib
45+
cudaPackages.cudatoolkit.out
46+
] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [
47+
# for some reason some of the required libs are in the targets/x86_64-linux
48+
# directory; not sure why but this works around it
49+
"${cudaPackages.cudatoolkit}/targets/${system}"
50+
];
51+
};
3852
llama-python =
3953
pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
4054
postPatch = ''
@@ -70,6 +84,13 @@
7084
"-DLLAMA_CLBLAST=ON"
7185
];
7286
};
87+
packages.cuda = pkgs.stdenv.mkDerivation {
88+
inherit name src meta postPatch nativeBuildInputs postInstall;
89+
buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ];
90+
cmakeFlags = cmakeFlags ++ [
91+
"-DLLAMA_CUBLAS=ON"
92+
];
93+
};
7394
packages.rocm = pkgs.stdenv.mkDerivation {
7495
inherit name src meta postPatch nativeBuildInputs postInstall;
7596
buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];

0 commit comments

Comments
 (0)