aboutsummaryrefslogtreecommitdiff
path: root/infra/libkookie/nixpkgs/pkgs/development/libraries/science/math/libtorch/bin.nix
diff options
context:
space:
mode:
Diffstat (limited to 'infra/libkookie/nixpkgs/pkgs/development/libraries/science/math/libtorch/bin.nix')
-rw-r--r--infra/libkookie/nixpkgs/pkgs/development/libraries/science/math/libtorch/bin.nix28
1 files changed, 16 insertions, 12 deletions
diff --git a/infra/libkookie/nixpkgs/pkgs/development/libraries/science/math/libtorch/bin.nix b/infra/libkookie/nixpkgs/pkgs/development/libraries/science/math/libtorch/bin.nix
index 9631f3931cab..241eb5a37211 100644
--- a/infra/libkookie/nixpkgs/pkgs/development/libraries/science/math/libtorch/bin.nix
+++ b/infra/libkookie/nixpkgs/pkgs/development/libraries/science/math/libtorch/bin.nix
@@ -8,11 +8,17 @@
, fixDarwinDylibNames
, cudaSupport
-, nvidia_x11
+, cudatoolkit_10_2
+, cudnn_cudatoolkit_10_2
}:
let
- version = "1.7.1";
+ # The binary libtorch distribution statically links the CUDA
+ # toolkit. This means that we do not need to provide CUDA to
+ # this derivation. However, we should ensure on version bumps
+ # that the CUDA toolkit for `passthru.tests` is still
+ # up-to-date.
+ version = "1.8.0";
device = if cudaSupport then "cuda" else "cpu";
srcs = import ./binary-hashes.nix version;
unavailable = throw "libtorch is not available for this platform";
@@ -24,12 +30,7 @@ in stdenv.mkDerivation {
nativeBuildInputs =
if stdenv.isDarwin then [ fixDarwinDylibNames ]
- else [ addOpenGLRunpath patchelf ]
- ++ lib.optionals cudaSupport [ addOpenGLRunpath ];
-
- buildInputs = [
- stdenv.cc.cc
- ] ++ lib.optionals cudaSupport [ nvidia_x11 ];
+ else [ patchelf ] ++ lib.optionals cudaSupport [ addOpenGLRunpath ];
dontBuild = true;
dontConfigure = true;
@@ -56,9 +57,7 @@ in stdenv.mkDerivation {
'';
postFixup = let
- libPaths = [ stdenv.cc.cc.lib ]
- ++ lib.optionals cudaSupport [ nvidia_x11 ];
- rpath = lib.makeLibraryPath libPaths;
+ rpath = lib.makeLibraryPath [ stdenv.cc.cc.lib ];
in lib.optionalString stdenv.isLinux ''
find $out/lib -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
echo "setting rpath for $lib..."
@@ -108,12 +107,17 @@ in stdenv.mkDerivation {
outputs = [ "out" "dev" ];
- passthru.tests.cmake = callPackage ./test { };
+ passthru.tests.cmake = callPackage ./test {
+ inherit cudaSupport;
+ cudatoolkit = cudatoolkit_10_2;
+ cudnn = cudnn_cudatoolkit_10_2;
+ };
meta = with lib; {
description = "C++ API of the PyTorch machine learning framework";
homepage = "https://pytorch.org/";
license = licenses.unfree; # Includes CUDA and Intel MKL.
+ maintainers = with maintainers; [ danieldk ];
platforms = with platforms; linux ++ darwin;
};
}