Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.JsonUtils;
import ai.djl.util.Utils;

import com.google.gson.JsonParseException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.net.MalformedURLException;
import java.net.URI;
Expand Down Expand Up @@ -113,6 +116,12 @@ private static Map<String, RepositoryFactory> init() {
registry.put("file", new LocalRepositoryFactory());
registry.put("jar", new JarRepositoryFactory());
registry.put("djl", new DjlRepositoryFactory());
if (S3RepositoryFactory.findS3Fuse() != null) {
registry.put("s3", new S3RepositoryFactory());
}
if (GcsRepositoryFactory.findGcsFuse() != null) {
registry.put("gs", new GcsRepositoryFactory());
}

ServiceLoader<RepositoryFactory> factories = ServiceLoader.load(RepositoryFactory.class);
for (RepositoryFactory factory : factories) {
Expand All @@ -138,6 +147,34 @@ static String parseFilePath(URI uri) {
return uriPath;
}

private static String exec(String... cmd) throws IOException, InterruptedException {
Process exec = new ProcessBuilder(cmd).redirectErrorStream(true).start();
String logOutput;
try (InputStream is = exec.getInputStream()) {
logOutput = Utils.toString(is);
}
int exitCode = exec.waitFor();
if (0 != exitCode) {
logger.error("exit: {}, {}", exitCode, logOutput);
throw new IOException("Failed to execute: [" + String.join(" ", cmd) + "]");
} else {
logger.debug("{}", logOutput);
}
return logOutput;
}

private static boolean isMounted(String path) throws IOException, InterruptedException {
String out = exec("df");
String[] lines = out.split("\\s");
for (String line : lines) {
if (line.trim().equals(path)) {
logger.debug("Mount point already mounted");
return true;
}
}
return false;
}

private static final class JarRepositoryFactory implements RepositoryFactory {

/** {@inheritDoc} */
Expand Down Expand Up @@ -274,4 +311,126 @@ public Set<String> getSupportedScheme() {
return Collections.singleton("djl");
}
}

static final class S3RepositoryFactory implements RepositoryFactory {

/** {@inheritDoc} */
@Override
public Repository newInstance(String name, URI uri) {
try {
Path path = mount(uri);
return new SimpleRepository(name, uri, path);
} catch (IOException | InterruptedException e) {
throw new IllegalArgumentException("Failed to mount s3 bucket", e);
}
}

/** {@inheritDoc} */
@Override
public Set<String> getSupportedScheme() {
return Collections.singleton("s3");
}

static String findS3Fuse() {
if (System.getProperty("os.name").startsWith("Win")) {
logger.debug("mount-s3 is not supported on Windows");
return null;
}
String gcsFuse = Utils.getEnvOrSystemProperty("MOUNT_S3", "/usr/bin/mount-s3");
if (Files.isRegularFile(Paths.get(gcsFuse))) {
return gcsFuse;
}
String path = System.getenv("PATH");
String[] directories = path.split(File.pathSeparator);
for (String dir : directories) {
Path file = Paths.get(dir, "mount-s3");
if (Files.isRegularFile(file)) {
return file.toAbsolutePath().toString();
}
}
return null;
}

private static Path mount(URI uri) throws IOException, InterruptedException {
String bucket = uri.getHost();
String prefix = uri.getPath();
if (!prefix.isEmpty()) {
prefix = prefix.substring(1);
}
Path dir = Utils.getCacheDir().toAbsolutePath().normalize();
dir = dir.resolve("s3").resolve(Utils.hash(uri.toString()));
String path = dir.toString();
if (Files.isDirectory(dir)) {
if (isMounted(path)) {
return dir.resolve(prefix);
}
} else {
Files.createDirectories(dir);
}

exec(findS3Fuse(), bucket, path);
return dir.resolve(prefix);
}
}

static final class GcsRepositoryFactory implements RepositoryFactory {

/** {@inheritDoc} */
@Override
public Repository newInstance(String name, URI uri) {
try {
Path path = mount(uri);
return new SimpleRepository(name, uri, path);
} catch (IOException | InterruptedException e) {
throw new IllegalArgumentException("Failed to mount gs bucket", e);
}
}

/** {@inheritDoc} */
@Override
public Set<String> getSupportedScheme() {
return Collections.singleton("gs");
}

static String findGcsFuse() {
if (System.getProperty("os.name").startsWith("Win")) {
logger.debug("gcsfuse is not supported on Windows");
return null;
}
String gcsFuse = Utils.getEnvOrSystemProperty("GCSFUSE", "/usr/bin/gcsfuse");
if (Files.isRegularFile(Paths.get(gcsFuse))) {
return gcsFuse;
}
String path = System.getenv("PATH");
String[] directories = path.split(File.pathSeparator);
for (String dir : directories) {
Path file = Paths.get(dir, "gcsfuse");
if (Files.isRegularFile(file)) {
return file.toAbsolutePath().toString();
}
}
return null;
}

private static Path mount(URI uri) throws IOException, InterruptedException {
String bucket = uri.getHost();
String prefix = uri.getPath();
if (!prefix.isEmpty()) {
prefix = prefix.substring(1);
}
Path dir = Utils.getCacheDir().toAbsolutePath().normalize();
dir = dir.resolve("gs").resolve(Utils.hash(uri.toString()));
String path = dir.toString();
if (Files.isDirectory(dir)) {
if (isMounted(path)) {
return dir.resolve(prefix);
}
} else {
Files.createDirectories(dir);
}

exec(findGcsFuse(), "--implicit-dirs", bucket, path);
return dir.resolve(prefix);
}
}
}
80 changes: 80 additions & 0 deletions api/src/test/java/ai/djl/repository/FuseRepositoryTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.repository;

import org.testng.Assert;
import org.testng.SkipException;
import org.testng.annotations.Test;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.PosixFilePermission;
import java.nio.file.attribute.PosixFilePermissions;
import java.util.Set;

public class FuseRepositoryTest {

@Test
public void testGcsRepository() throws IOException {
if (System.getProperty("os.name").startsWith("Win")) {
throw new SkipException("GcsRepository is not supported on Windows");
}

Path gcsfuse = Paths.get("build/gcsfuse");
Set<PosixFilePermission> permissions = PosixFilePermissions.fromString("rwxr-xr-x");
Files.write(gcsfuse, new byte[0]);
Files.setAttribute(gcsfuse, "posix:permissions", permissions);

System.setProperty("GCSFUSE", "build/gcsfuse");
System.setProperty("DJL_CACHE_DIR", "build/cache");
Repository.registerRepositoryFactory(new RepositoryFactoryImpl.GcsRepositoryFactory());
try {
Repository repo = Repository.newInstance("gs", "gs://djl/resnet");
Assert.assertEquals(repo.getResources().size(), 0);

// test folder already exist
Repository.newInstance("gs", "gs://djl/resnet");
} finally {
System.clearProperty("GCSFUSE");
System.clearProperty("DJL_CACHE_DIR");
}
}

@Test
public void testS3Repository() throws IOException {
if (System.getProperty("os.name").startsWith("Win")) {
throw new SkipException("S3Repository is not supported on Windows");
}

Path gcsfuse = Paths.get("build/mount-s3");
Set<PosixFilePermission> permissions = PosixFilePermissions.fromString("rwxr-xr-x");
Files.write(gcsfuse, new byte[0]);
Files.setAttribute(gcsfuse, "posix:permissions", permissions);

System.setProperty("MOUNT_S3", "build/mount-s3");
System.setProperty("DJL_CACHE_DIR", "build/cache");
Repository.registerRepositoryFactory(new RepositoryFactoryImpl.S3RepositoryFactory());
try {
Repository repo = Repository.newInstance("s3", "s3://djl/resnet");
Assert.assertEquals(repo.getResources().size(), 0);

// test folder already exist
Repository.newInstance("s3", "s3://djl/resnet");
} finally {
System.clearProperty("MOUNT_S3");
System.clearProperty("DJL_CACHE_DIR");
}
}
}
Loading