diff --git a/src/main/java/rx/internal/operators/OnSubscribeRedo.java b/src/main/java/rx/internal/operators/OnSubscribeRedo.java index 1ba5d1f281..70bb5a8d21 100644 --- a/src/main/java/rx/internal/operators/OnSubscribeRedo.java +++ b/src/main/java/rx/internal/operators/OnSubscribeRedo.java @@ -35,7 +35,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; import rx.Notification; import rx.Observable; @@ -47,13 +46,14 @@ import rx.functions.Action0; import rx.functions.Func1; import rx.functions.Func2; +import rx.observers.Subscribers; import rx.schedulers.Schedulers; -import rx.subjects.PublishSubject; +import rx.subjects.BehaviorSubject; import rx.subscriptions.SerialSubscription; public final class OnSubscribeRedo implements OnSubscribe { - static final Func1>, Observable> REDO_INIFINITE = new Func1>, Observable>() { + static final Func1>, Observable> REDO_INFINITE = new Func1>, Observable>() { @Override public Observable call(Observable> ts) { return ts.map(new Func1, Notification>() { @@ -120,7 +120,7 @@ public Notification call(Notification n, Notification term) } public static Observable retry(Observable source) { - return retry(source, REDO_INIFINITE); + return retry(source, REDO_INFINITE); } public static Observable retry(Observable source, final long count) { @@ -144,7 +144,7 @@ public static Observable repeat(Observable source) { } public static Observable repeat(Observable source, Scheduler scheduler) { - return repeat(source, REDO_INIFINITE, scheduler); + return repeat(source, REDO_INFINITE, scheduler); } public static Observable repeat(Observable source, final long count) { @@ -172,10 +172,10 @@ public static Observable redo(Observable source, Func1(source, notificationHandler, false, false, scheduler)); } - private Observable source; + private final Observable source; private final Func1>, ? extends Observable> controlHandlerFunction; - private boolean stopOnComplete; - private boolean stopOnError; + private final boolean stopOnComplete; + private final boolean stopOnError; private final Scheduler scheduler; private OnSubscribeRedo(Observable source, Func1>, ? extends Observable> f, boolean stopOnComplete, boolean stopOnError, @@ -189,11 +189,12 @@ private OnSubscribeRedo(Observable source, Func1 child) { - final AtomicBoolean isLocked = new AtomicBoolean(true); + + // when true is a marker to say we are ready to resubscribe to source final AtomicBoolean resumeBoundary = new AtomicBoolean(true); + // incremented when requests are made, decremented when requests are fulfilled final AtomicLong consumerCapacity = new AtomicLong(0l); - final AtomicReference currentProducer = new AtomicReference(); final Scheduler.Worker worker = scheduler.createWorker(); child.add(worker); @@ -201,8 +202,18 @@ public void call(final Subscriber child) { final SerialSubscription sourceSubscriptions = new SerialSubscription(); child.add(sourceSubscriptions); - final PublishSubject> terminals = PublishSubject.create(); - + // use a subject to receive terminals (onCompleted and onError signals) from + // the source observable. We use a BehaviorSubject because subscribeToSource + // may emit a terminal before the restarts observable (transformed terminals) + // is subscribed + final BehaviorSubject> terminals = BehaviorSubject.create(); + final Subscriber> dummySubscriber = Subscribers.empty(); + // subscribe immediately so the last emission will be replayed to the next + // subscriber (which is the one we care about) + terminals.subscribe(dummySubscriber); + + final ProducerArbiter arbiter = new ProducerArbiter(); + final Action0 subscribeToSource = new Action0() { @Override public void call() { @@ -212,11 +223,11 @@ public void call() { Subscriber terminalDelegatingSubscriber = new Subscriber() { boolean done; + @Override public void onCompleted() { if (!done) { done = true; - currentProducer.set(null); unsubscribe(); terminals.onNext(Notification.createOnCompleted()); } @@ -226,7 +237,6 @@ public void onCompleted() { public void onError(Throwable e) { if (!done) { done = true; - currentProducer.set(null); unsubscribe(); terminals.onNext(Notification.createOnError(e)); } @@ -235,20 +245,30 @@ public void onError(Throwable e) { @Override public void onNext(T v) { if (!done) { - if (consumerCapacity.get() != Long.MAX_VALUE) { - consumerCapacity.decrementAndGet(); - } child.onNext(v); + decrementConsumerCapacity(); + arbiter.produced(1); + } + } + + private void decrementConsumerCapacity() { + // use a CAS loop because we don't want to decrement the + // value if it is Long.MAX_VALUE + while (true) { + long cc = consumerCapacity.get(); + if (cc != Long.MAX_VALUE) { + if (consumerCapacity.compareAndSet(cc, cc - 1)) { + break; + } + } else { + break; + } } } @Override public void setProducer(Producer producer) { - currentProducer.set(producer); - long c = consumerCapacity.get(); - if (c > 0) { - producer.request(c); - } + arbiter.setProducer(producer); } }; // new subscription each time so if it unsubscribes itself it does not prevent retries @@ -278,12 +298,11 @@ public void onError(Throwable e) { @Override public void onNext(Notification t) { - if (t.isOnCompleted() && stopOnComplete) - child.onCompleted(); - else if (t.isOnError() && stopOnError) - child.onError(t.getThrowable()); - else { - isLocked.set(false); + if (t.isOnCompleted() && stopOnComplete) { + filteredTerminals.onCompleted(); + } else if (t.isOnError() && stopOnError) { + filteredTerminals.onError(t.getThrowable()); + } else { filteredTerminals.onNext(t); } } @@ -313,10 +332,15 @@ public void onError(Throwable e) { @Override public void onNext(Object t) { - if (!isLocked.get() && !child.isUnsubscribed()) { + if (!child.isUnsubscribed()) { + // perform a best endeavours check on consumerCapacity + // with the intent of only resubscribing immediately + // if there is outstanding capacity if (consumerCapacity.get() > 0) { worker.schedule(subscribeToSource); } else { + // set this to true so that on next request + // subscribeToSource will be scheduled resumeBoundary.compareAndSet(false, true); } } @@ -334,16 +358,180 @@ public void setProducer(Producer producer) { @Override public void request(final long n) { - long c = BackpressureUtils.getAndAddRequest(consumerCapacity, n); - Producer producer = currentProducer.get(); - if (producer != null) { - producer.request(n); - } else - if (c == 0 && resumeBoundary.compareAndSet(true, false)) { - worker.schedule(subscribeToSource); + if (n > 0) { + BackpressureUtils.getAndAddRequest(consumerCapacity, n); + arbiter.request(n); + if (resumeBoundary.compareAndSet(true, false)) + worker.schedule(subscribeToSource); } } }); } + + /** + * Between when the source subscription finishes and the next subscription starts requests may arrive. + * ProducerArbiter keeps track of all requests made and all arriving emissions so that when setProducer + * is called for a new subscription the appropriate number of requests are made to the new producer. + */ + static final class ProducerArbiter implements Producer { + /** Guarded by this. */ + boolean emitting; + /** The current producer. Accessed while emitting. */ + Producer currentProducer; + /** The current requested count. */ + long requested; + + long missedRequested; + Producer missedProducer; + long missedProd; + + @Override + public void request(long n) { + if (n <= 0) { + return; + } + Producer mp; + long mprod; + synchronized (this) { + if (emitting) { + missedRequested += n; + return; + } + emitting = true; + mp = missedProducer; + mprod = missedProd; + + missedProducer = null; + missedProd = 0L; + } + + boolean skipFinal = false; + try { + emit(n, mp, mprod); + drainLoop(); + skipFinal = true; + } finally { + if (!skipFinal) { + synchronized (this) { + emitting = false; + } + } + } + } + void setProducer(Producer p) { + if (p == null) { + throw new NullPointerException(); + } + + long mreq; + long mprod; + synchronized (this) { + if (emitting) { + missedProducer = p; + return; + } + emitting = true; + mreq = missedRequested; + mprod = missedProd; + + missedRequested = 0L; + missedProd = 0L; + } + + boolean skipFinal = false; + try { + emit(mreq, p, mprod); + drainLoop(); + skipFinal = true; + } finally { + if (!skipFinal) { + synchronized (this) { + emitting = false; + } + } + } + } + void produced(long n) { + if (n <= 0) { + throw new IllegalArgumentException(n + " produced?!"); + } + + long mreq; + Producer mp; + synchronized (this) { + if (emitting) { + missedProd += n; + return; + } + emitting = true; + mreq = missedRequested; + mp = missedProducer; + + missedRequested = 0L; + missedProducer = null; + } + + boolean skipFinal = false; + try { + emit(mreq, mp, n); + drainLoop(); + skipFinal = true; + } finally { + if (!skipFinal) { + synchronized (this) { + emitting = false; + } + } + } + } + void drainLoop() { + for (;;) { + long mreq; + long mprod; + Producer mp; + synchronized (this) { + mreq = missedRequested; + mprod = missedProd; + mp = missedProducer; + if (mreq == 0L && mp == null && mprod == 0L) { + emitting = false; + return; + } + missedRequested = 0L; + missedProd = 0L; + missedProducer = null; + } + emit(mreq, mp, mprod); + } + } + void emit(long mreq, Producer mp, long mprod) { + boolean newMp = false; + if (mp != null) { + newMp = true; + currentProducer = mp; + } else { + mp = currentProducer; + } + + long u = requested + mreq; + if (u < 0) { + u = Long.MAX_VALUE; + } else + if (u != Long.MAX_VALUE) { + u -= mprod; + if (u < 0) { + throw new IllegalStateException("More produced than requested"); + } + } + requested = u; + + if (mreq > 0 && mp != null) { + mp.request(mreq); + } else + if (newMp && u > 0) { + mp.request(u); + } + } + } } diff --git a/src/main/java/rx/internal/operators/OperatorObserveOn.java b/src/main/java/rx/internal/operators/OperatorObserveOn.java index 13a78ca14c..e15c2f93cf 100644 --- a/src/main/java/rx/internal/operators/OperatorObserveOn.java +++ b/src/main/java/rx/internal/operators/OperatorObserveOn.java @@ -75,15 +75,19 @@ private static final class ObserveOnSubscriber extends Subscriber { final NotificationLite on = NotificationLite.instance(); final Queue queue; - volatile boolean completed = false; - volatile boolean failure = false; + + // the status of the current stream + volatile boolean finished = false; + @SuppressWarnings("unused") volatile long requested = 0; + @SuppressWarnings("rawtypes") static final AtomicLongFieldUpdater REQUESTED = AtomicLongFieldUpdater.newUpdater(ObserveOnSubscriber.class, "requested"); @SuppressWarnings("unused") volatile long counter; + @SuppressWarnings("rawtypes") static final AtomicLongFieldUpdater COUNTER_UPDATER = AtomicLongFieldUpdater.newUpdater(ObserveOnSubscriber.class, "counter"); @@ -127,7 +131,7 @@ public void onStart() { @Override public void onNext(final T t) { - if (isUnsubscribed() || completed) { + if (isUnsubscribed()) { return; } if (!queue.offer(on.next(t))) { @@ -139,30 +143,23 @@ public void onNext(final T t) { @Override public void onCompleted() { - if (isUnsubscribed() || completed) { + if (isUnsubscribed() || finished) { return; } - if (error != null) { - return; - } - completed = true; + finished = true; schedule(); } @Override public void onError(final Throwable e) { - if (isUnsubscribed() || completed) { - return; - } - if (error != null) { + if (isUnsubscribed() || finished) { return; } error = e; // unsubscribe eagerly since time will pass before the scheduled onError results in an unsubscribe event unsubscribe(); - // mark failure so the polling thread will skip onNext still in the queue - completed = true; - failure = true; + finished = true; + // polling thread should skip any onNext still in the queue schedule(); } @@ -185,52 +182,42 @@ protected void schedule() { void pollQueue() { int emitted = 0; do { - /* - * Set to 1 otherwise it could have grown very large while in the last poll loop - * and then we can end up looping all those times again here before exiting even once we've drained - */ counter = 1; - -// middle: - while (!scheduledUnsubscribe.isUnsubscribed()) { - if (failure) { - child.onError(error); - return; - } else { - if (requested == 0 && completed && queue.isEmpty()) { + long produced = 0; + long r = requested; + while (!child.isUnsubscribed()) { + Throwable error; + if (finished) { + if ((error = this.error) != null) { + // errors shortcut the queue so + // release the elements in the queue for gc + queue.clear(); + child.onError(error); + return; + } else + if (queue.isEmpty()) { child.onCompleted(); return; } - if (REQUESTED.getAndDecrement(this) != 0) { - Object o = queue.poll(); - if (o == null) { - if (completed) { - if (failure) { - child.onError(error); - } else { - child.onCompleted(); - } - return; - } - // nothing in queue - REQUESTED.incrementAndGet(this); - break; - } else { - if (!on.accept(child, o)) { - // non-terminal event so let's increment count - emitted++; - } - } + } + if (r > 0) { + Object o = queue.poll(); + if (o != null) { + child.onNext(on.getValue(o)); + r--; + emitted++; + produced++; } else { - // we hit the end ... so increment back to 0 again - REQUESTED.incrementAndGet(this); break; } + } else { + break; } } + if (produced > 0 && requested != Long.MAX_VALUE) { + REQUESTED.addAndGet(this, -produced); + } } while (COUNTER_UPDATER.decrementAndGet(this) > 0); - - // request the number of items that we emitted in this poll loop if (emitted > 0) { request(emitted); } diff --git a/src/test/java/rx/internal/operators/OperatorRetryTest.java b/src/test/java/rx/internal/operators/OperatorRetryTest.java index a5aa9f1c31..146ee3c254 100644 --- a/src/test/java/rx/internal/operators/OperatorRetryTest.java +++ b/src/test/java/rx/internal/operators/OperatorRetryTest.java @@ -395,9 +395,13 @@ public static class FuncWithErrors implements Observable.OnSubscribe { public void call(final Subscriber o) { o.setProducer(new Producer() { final AtomicLong req = new AtomicLong(); + // 0 = not set, 1 = fast path, 2 = backpressure + final AtomicInteger path = new AtomicInteger(0); + volatile boolean done = false; + @Override public void request(long n) { - if (n == Long.MAX_VALUE) { + if (n == Long.MAX_VALUE && path.compareAndSet(0, 1)) { o.onNext("beginningEveryTime"); int i = count.getAndIncrement(); if (i < numFailures) { @@ -408,11 +412,12 @@ public void request(long n) { } return; } - if (n > 0 && req.getAndAdd(n) == 0) { + if (n > 0 && req.getAndAdd(n) == 0 && (path.get() == 2 || path.compareAndSet(0, 2)) && !done) { int i = count.getAndIncrement(); if (i < numFailures) { o.onNext("beginningEveryTime"); o.onError(new RuntimeException("forced failure: " + (i + 1))); + done = true; } else { do { if (i == numFailures) { @@ -421,6 +426,7 @@ public void request(long n) { if (i > numFailures) { o.onNext("onSuccessOnly"); o.onCompleted(); + done = true; break; } i = count.getAndIncrement(); @@ -682,91 +688,88 @@ public void testTimeoutWithRetry() { assertEquals("Start 6 threads, retry 5 then fail on 6", 6, so.efforts.get()); } - @Test(timeout = 15000) + @Test//(timeout = 15000) public void testRetryWithBackpressure() throws InterruptedException { - final int NUM_RETRIES = RxRingBuffer.SIZE * 2; - for (int i = 0; i < 400; i++) { - @SuppressWarnings("unchecked") - Observer observer = mock(Observer.class); - Observable origin = Observable.create(new FuncWithErrors(NUM_RETRIES)); - TestSubscriber ts = new TestSubscriber(observer); - origin.retry().observeOn(Schedulers.computation()).unsafeSubscribe(ts); - ts.awaitTerminalEvent(5, TimeUnit.SECONDS); - - InOrder inOrder = inOrder(observer); - // should have no errors - verify(observer, never()).onError(any(Throwable.class)); - // should show NUM_RETRIES attempts - inOrder.verify(observer, times(NUM_RETRIES + 1)).onNext("beginningEveryTime"); - // should have a single success - inOrder.verify(observer, times(1)).onNext("onSuccessOnly"); - // should have a single successful onCompleted - inOrder.verify(observer, times(1)).onCompleted(); - inOrder.verifyNoMoreInteractions(); + final int NUM_LOOPS = 1; + for (int j=0;j observer = mock(Observer.class); + Observable origin = Observable.create(new FuncWithErrors(NUM_RETRIES)); + TestSubscriber ts = new TestSubscriber(observer); + origin.retry().observeOn(Schedulers.computation()).unsafeSubscribe(ts); + ts.awaitTerminalEvent(5, TimeUnit.SECONDS); + + InOrder inOrder = inOrder(observer); + // should have no errors + verify(observer, never()).onError(any(Throwable.class)); + // should show NUM_RETRIES attempts + inOrder.verify(observer, times(NUM_RETRIES + 1)).onNext("beginningEveryTime"); + // should have a single success + inOrder.verify(observer, times(1)).onNext("onSuccessOnly"); + // should have a single successful onCompleted + inOrder.verify(observer, times(1)).onCompleted(); + inOrder.verifyNoMoreInteractions(); + } } } - @Test(timeout = 15000) + @Test//(timeout = 15000) public void testRetryWithBackpressureParallel() throws InterruptedException { + final int NUM_LOOPS = 1; final int NUM_RETRIES = RxRingBuffer.SIZE * 2; int ncpu = Runtime.getRuntime().availableProcessors(); ExecutorService exec = Executors.newFixedThreadPool(Math.max(ncpu / 2, 2)); - final AtomicInteger timeouts = new AtomicInteger(); - final Map> data = new ConcurrentHashMap>(); - final Map> exceptions = new ConcurrentHashMap>(); - final Map completions = new ConcurrentHashMap(); - - int m = 5000; - final CountDownLatch cdl = new CountDownLatch(m); - for (int i = 0; i < m; i++) { - final int j = i; - exec.execute(new Runnable() { - @Override - public void run() { - final AtomicInteger nexts = new AtomicInteger(); - try { - Observable origin = Observable.create(new FuncWithErrors(NUM_RETRIES)); - TestSubscriber ts = new TestSubscriber(); - origin.retry() - .observeOn(Schedulers.computation()).unsafeSubscribe(ts); - ts.awaitTerminalEvent(2500, TimeUnit.MILLISECONDS); - if (ts.getOnCompletedEvents().size() != 1) { - completions.put(j, ts.getOnCompletedEvents().size()); - } - if (ts.getOnErrorEvents().size() != 0) { - exceptions.put(j, ts.getOnErrorEvents()); - } - if (ts.getOnNextEvents().size() != NUM_RETRIES + 2) { - data.put(j, ts.getOnNextEvents()); + try { + for (int r = 0; r < NUM_LOOPS; r++) { + if (r % 10 == 0) { + System.out.println("testRetryWithBackpressureParallelLoop -> " + r); + } + + final AtomicInteger timeouts = new AtomicInteger(); + final Map> data = new ConcurrentHashMap>(); + + int m = 5000; + final CountDownLatch cdl = new CountDownLatch(m); + for (int i = 0; i < m; i++) { + final int j = i; + exec.execute(new Runnable() { + @Override + public void run() { + final AtomicInteger nexts = new AtomicInteger(); + try { + Observable origin = Observable.create(new FuncWithErrors(NUM_RETRIES)); + TestSubscriber ts = new TestSubscriber(); + origin.retry() + .observeOn(Schedulers.computation()).unsafeSubscribe(ts); + ts.awaitTerminalEvent(2500, TimeUnit.MILLISECONDS); + List onNextEvents = new ArrayList(ts.getOnNextEvents()); + if (onNextEvents.size() != NUM_RETRIES + 2) { + for (Throwable t : ts.getOnErrorEvents()) { + onNextEvents.add(t.toString()); + } + for (Object o : ts.getOnCompletedEvents()) { + onNextEvents.add("onCompleted"); + } + data.put(j, onNextEvents); + } + } catch (Throwable t) { + timeouts.incrementAndGet(); + System.out.println(j + " | " + cdl.getCount() + " !!! " + nexts.get()); + } + cdl.countDown(); } - } catch (Throwable t) { - timeouts.incrementAndGet(); - System.out.println(j + " | " + cdl.getCount() + " !!! " + nexts.get()); - } - cdl.countDown(); + }); } - }); - } - exec.shutdown(); - cdl.await(); - assertEquals(0, timeouts.get()); - if (data.size() > 0) { - System.out.println(allSequenceFrequency(data)); - } - if (exceptions.size() > 0) { - System.out.println(exceptions); - } - if (completions.size() > 0) { - System.out.println(completions); - } - if (data.size() > 0) { - fail("Data content mismatch: " + allSequenceFrequency(data)); - } - if (exceptions.size() > 0) { - fail("Exceptions received: " + exceptions); - } - if (completions.size() > 0) { - fail("Multiple completions received: " + completions); + cdl.await(); + assertEquals(0, timeouts.get()); + if (data.size() > 0) { + fail("Data content mismatch: " + allSequenceFrequency(data)); + } + } + } finally { + exec.shutdown(); } } static StringBuilder allSequenceFrequency(Map> its) { @@ -783,10 +786,10 @@ static StringBuilder allSequenceFrequency(Map> its) { } static StringBuilder sequenceFrequency(Iterable it) { StringBuilder sb = new StringBuilder(); - + Object prev = null; int cnt = 0; - + for (Object curr : it) { if (sb.length() > 0) { if (!curr.equals(prev)) { @@ -805,10 +808,13 @@ static StringBuilder sequenceFrequency(Iterable it) { } prev = curr; } - + if (cnt > 1) { + sb.append(" x ").append(cnt); + } + return sb; } - @Test(timeout = 3000) + @Test//(timeout = 3000) public void testIssue1900() throws InterruptedException { @SuppressWarnings("unchecked") Observer observer = mock(Observer.class); @@ -849,7 +855,7 @@ public Observable call(GroupedObservable t1) { inOrder.verify(observer, times(1)).onCompleted(); inOrder.verifyNoMoreInteractions(); } - @Test(timeout = 3000) + @Test//(timeout = 3000) public void testIssue1900SourceNotSupportingBackpressure() { @SuppressWarnings("unchecked") Observer observer = mock(Observer.class);