diff --git a/LauncherCore/src/main/java/pro/gravit/launcher/AsyncDownloader.java b/LauncherCore/src/main/java/pro/gravit/launcher/AsyncDownloader.java index 49124e95..dd5d73e3 100644 --- a/LauncherCore/src/main/java/pro/gravit/launcher/AsyncDownloader.java +++ b/LauncherCore/src/main/java/pro/gravit/launcher/AsyncDownloader.java @@ -30,6 +30,7 @@ public class AsyncDownloader { @LauncherInject("launcher.certificatePinning") private static boolean isCertificatePinning; private static volatile SSLSocketFactory sslSocketFactory; + private static volatile SSLContext sslContext; public final Callback callback; public AsyncDownloader(Callback callback) { @@ -70,14 +71,20 @@ public void downloadFile(URL url, Path target) throws IOException { } } - public SSLSocketFactory makeSSLSocketFactory() throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException, KeyManagementException { + public static SSLSocketFactory makeSSLSocketFactory() throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException, KeyManagementException { if (sslSocketFactory != null) return sslSocketFactory; - SSLContext sslContext = SSLContext.getInstance("TLS"); - sslContext.init(null, CertificatePinningTrustManager.getTrustManager().getTrustManagers(), new SecureRandom()); + SSLContext sslContext = makeSSLContext(); sslSocketFactory = sslContext.getSocketFactory(); return sslSocketFactory; } + public static SSLContext makeSSLContext() throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException, KeyManagementException { + if (sslContext != null) return sslContext; + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, CertificatePinningTrustManager.getTrustManager().getTrustManagers(), new SecureRandom()); + return sslContext; + } + public void downloadListInOneThread(List files, String baseURL, Path targetDir) throws URISyntaxException, IOException { URI baseUri = new URI(baseURL); String scheme = baseUri.getScheme(); diff --git a/LauncherCore/src/main/java/pro/gravit/utils/Downloader.java b/LauncherCore/src/main/java/pro/gravit/utils/Downloader.java new file mode 100644 index 00000000..cb76e722 --- /dev/null +++ b/LauncherCore/src/main/java/pro/gravit/utils/Downloader.java @@ -0,0 +1,39 @@ +package pro.gravit.utils; + +import pro.gravit.launcher.AsyncDownloader; + +import java.nio.file.Path; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +public class Downloader { + public interface DownloadCallback { + void apply(long fullDiff); + } + + public CompletableFuture downloadList(List files, String baseURL, Path targetDir, DownloadCallback callback, ExecutorService executor, int threads) throws Exception { + final boolean closeExecutor; + if (executor == null) { + executor = Executors.newWorkStealingPool(4); + closeExecutor = true; + } else { + closeExecutor = false; + } + AsyncDownloader asyncDownloader = new AsyncDownloader((diff) -> { + if (callback != null) { + callback.apply(diff); + } + }); + List> list = asyncDownloader.sortFiles(files, threads); + CompletableFuture future = CompletableFuture.allOf(asyncDownloader.runDownloadList(list, baseURL, targetDir, executor)); + + ExecutorService finalExecutor = executor; + return future.thenAccept(e -> { + if (closeExecutor) { + finalExecutor.shutdownNow(); + } + }); + } +} diff --git a/LauncherCore/src/main/java11/pro/gravit/utils/Downloader.java b/LauncherCore/src/main/java11/pro/gravit/utils/Downloader.java new file mode 100644 index 00000000..9ea80c9a --- /dev/null +++ b/LauncherCore/src/main/java11/pro/gravit/utils/Downloader.java @@ -0,0 +1,168 @@ +package pro.gravit.utils; + +import pro.gravit.launcher.AsyncDownloader; +import pro.gravit.launcher.LauncherInject; + +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.ByteBuffer; +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + +public class Downloader { + @LauncherInject("launcher.certificatePinning") + private static boolean isCertificatePinning; + @LauncherInject("launcher.noHttp2") + private static boolean isNoHttp2; + + public interface DownloadCallback { + void apply(long fullDiff); + + void onComplete(Path path); + } + + public static CompletableFuture downloadList(List files, String baseURL, Path targetDir, DownloadCallback callback, ExecutorService executor, int threads) throws Exception { + boolean closeExecutor = false; + if (executor == null) { + executor = Executors.newWorkStealingPool(Math.min(3, threads)); + closeExecutor = true; + } + HttpClient.Builder builder = HttpClient.newBuilder() + .version(isNoHttp2 ? HttpClient.Version.HTTP_1_1 : HttpClient.Version.HTTP_2) + .followRedirects(HttpClient.Redirect.NORMAL) + .executor(executor); + if (isCertificatePinning) { + try { + builder.sslContext(AsyncDownloader.makeSSLContext()); + } catch (Exception e) { + throw new SecurityException(e); + } + } + CompletableFuture future = downloadList(builder.build(), files, baseURL, targetDir, callback, executor, threads); + if (closeExecutor) { + ExecutorService finalExecutor = executor; + future = future.thenAccept(e -> { + finalExecutor.shutdownNow(); + }); + } + return future; + } + + private static class ConsumerObject { + Consumer> next = null; + } + + public static CompletableFuture downloadList(HttpClient client, List files, String baseURL, Path targetDir, DownloadCallback callback, ExecutorService executor, int threads) throws Exception { + // URI scheme + URI baseUri = new URI(baseURL); + Collections.shuffle(files); + Queue queue = new ConcurrentLinkedDeque<>(files); + CompletableFuture future = new CompletableFuture<>(); + AtomicInteger currentThreads = new AtomicInteger(threads); + ConsumerObject consumerObject = new ConsumerObject(); + Consumer> next = e -> { + if (callback != null && e != null) { + callback.onComplete(e.body()); + } + AsyncDownloader.SizedFile file = queue.poll(); + if (file == null) { + if (currentThreads.decrementAndGet() == 0) + future.complete(null); + return; + } + try { + sendAsync(client, file, baseUri, targetDir, callback).thenAccept(consumerObject.next); + } catch (Exception exception) { + future.completeExceptionally(exception); + } + }; + consumerObject.next = next; + for (int i = 0; i < threads; ++i) { + next.accept(null); + } + return future; + } + + private static CompletableFuture> sendAsync(HttpClient client, AsyncDownloader.SizedFile file, URI baseUri, Path targetDir, DownloadCallback callback) throws Exception { + return client.sendAsync(makeHttpRequest(baseUri, file.urlPath), makeBodyHandler(targetDir.resolve(file.filePath), callback)); + } + + private static HttpRequest makeHttpRequest(URI baseUri, String filePath) throws URISyntaxException { + String scheme = baseUri.getScheme(); + String host = baseUri.getHost(); + int port = baseUri.getPort(); + if (port != -1) + host = host + ":" + port; + String path = baseUri.getPath(); + return HttpRequest.newBuilder() + .GET() + .uri(new URI(scheme, host, path + filePath, "", "")) + .header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/45.0.2454.85 Safari/537.36") + .build(); + } + + private static HttpResponse.BodyHandler makeBodyHandler(Path file, DownloadCallback callback) { + return new ProgressTrackingBodyHandler<>(HttpResponse.BodyHandlers.ofFile(file), callback); + } + + public static class ProgressTrackingBodyHandler implements HttpResponse.BodyHandler { + private final HttpResponse.BodyHandler delegate; + private final DownloadCallback callback; + + public ProgressTrackingBodyHandler(HttpResponse.BodyHandler delegate, DownloadCallback callback) { + this.delegate = delegate; + this.callback = callback; + } + + @Override + public HttpResponse.BodySubscriber apply(HttpResponse.ResponseInfo responseInfo) { + return delegate.apply(responseInfo); + } + + private class ProgressTrackingBodySubscriber implements HttpResponse.BodySubscriber { + private final HttpResponse.BodySubscriber delegate; + + public ProgressTrackingBodySubscriber(HttpResponse.BodySubscriber delegate) { + this.delegate = delegate; + } + + @Override + public CompletionStage getBody() { + return delegate.getBody(); + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + delegate.onSubscribe(subscription); + } + + @Override + public void onNext(List byteBuffers) { + delegate.onNext(byteBuffers); + long diff = 0; + for (ByteBuffer buffer : byteBuffers) { + diff += buffer.remaining(); + } + if (callback != null) callback.apply(diff); + } + + @Override + public void onError(Throwable throwable) { + delegate.onError(throwable); + } + + @Override + public void onComplete() { + delegate.onComplete(); + } + } + } +}