001package ca.uhn.fhir.interceptor.executor;
002
003/*-
004 * #%L
005 * HAPI FHIR - Core Library
006 * %%
007 * Copyright (C) 2014 - 2019 University Health Network
008 * %%
009 * Licensed under the Apache License, Version 2.0 (the "License");
010 * you may not use this file except in compliance with the License.
011 * You may obtain a copy of the License at
012 *
013 *      http://www.apache.org/licenses/LICENSE-2.0
014 *
015 * Unless required by applicable law or agreed to in writing, software
016 * distributed under the License is distributed on an "AS IS" BASIS,
017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
018 * See the License for the specific language governing permissions and
019 * limitations under the License.
020 * #L%
021 */
022
023import ca.uhn.fhir.interceptor.api.*;
024import ca.uhn.fhir.rest.server.exceptions.InternalErrorException;
025import com.google.common.annotations.VisibleForTesting;
026import com.google.common.collect.ArrayListMultimap;
027import com.google.common.collect.ListMultimap;
028import com.google.common.collect.Multimaps;
029import org.apache.commons.lang3.Validate;
030import org.apache.commons.lang3.builder.ToStringBuilder;
031import org.apache.commons.lang3.builder.ToStringStyle;
032import org.apache.commons.lang3.reflect.MethodUtils;
033import org.slf4j.Logger;
034import org.slf4j.LoggerFactory;
035
036import javax.annotation.Nonnull;
037import javax.annotation.Nullable;
038import java.lang.annotation.Annotation;
039import java.lang.reflect.AnnotatedElement;
040import java.lang.reflect.InvocationTargetException;
041import java.lang.reflect.Method;
042import java.util.*;
043import java.util.concurrent.atomic.AtomicInteger;
044import java.util.stream.Collectors;
045
046public class InterceptorService implements IInterceptorService, IInterceptorBroadcaster {
047        private static final Logger ourLog = LoggerFactory.getLogger(InterceptorService.class);
048        private final List<Object> myInterceptors = new ArrayList<>();
049        private final ListMultimap<Pointcut, BaseInvoker> myGlobalInvokers = ArrayListMultimap.create();
050        private final ListMultimap<Pointcut, BaseInvoker> myAnonymousInvokers = ArrayListMultimap.create();
051        private final Object myRegistryMutex = new Object();
052        private final ThreadLocal<ListMultimap<Pointcut, BaseInvoker>> myThreadlocalInvokers = new ThreadLocal<>();
053        private String myName;
054        private boolean myThreadlocalInvokersEnabled = true;
055
056        /**
057         * Constructor which uses a default name of "default"
058         */
059        public InterceptorService() {
060                this("default");
061        }
062
063        /**
064         * Constructor
065         *
066         * @param theName The name for this registry (useful for troubleshooting)
067         */
068        public InterceptorService(String theName) {
069                super();
070                myName = theName;
071        }
072
073        /**
074         * Are threadlocal interceptors enabled on this registry (defaults to true)
075         */
076        public boolean isThreadlocalInvokersEnabled() {
077                return myThreadlocalInvokersEnabled;
078        }
079
080        /**
081         * Are threadlocal interceptors enabled on this registry (defaults to true)
082         */
083        public void setThreadlocalInvokersEnabled(boolean theThreadlocalInvokersEnabled) {
084                myThreadlocalInvokersEnabled = theThreadlocalInvokersEnabled;
085        }
086
087        @VisibleForTesting
088        List<Object> getGlobalInterceptorsForUnitTest() {
089                return myInterceptors;
090        }
091
092        @Override
093        @VisibleForTesting
094        public void registerAnonymousInterceptor(Pointcut thePointcut, IAnonymousInterceptor theInterceptor) {
095                registerAnonymousInterceptor(thePointcut, Interceptor.DEFAULT_ORDER, theInterceptor);
096        }
097
098        public void setName(String theName) {
099                myName = theName;
100        }
101
102        @Override
103        public void registerAnonymousInterceptor(Pointcut thePointcut, int theOrder, IAnonymousInterceptor theInterceptor) {
104                Validate.notNull(thePointcut);
105                Validate.notNull(theInterceptor);
106                synchronized (myRegistryMutex) {
107
108                        myAnonymousInvokers.put(thePointcut, new AnonymousLambdaInvoker(thePointcut, theInterceptor, theOrder));
109                        if (!isInterceptorAlreadyRegistered(theInterceptor)) {
110                                myInterceptors.add(theInterceptor);
111                        }
112                }
113        }
114
115        @Override
116        public List<Object> getAllRegisteredInterceptors() {
117                synchronized (myRegistryMutex) {
118                        List<Object> retVal = new ArrayList<>();
119                        retVal.addAll(myInterceptors);
120                        return Collections.unmodifiableList(retVal);
121                }
122        }
123
124        @Override
125        @VisibleForTesting
126        public void unregisterAllInterceptors() {
127                synchronized (myRegistryMutex) {
128                        myAnonymousInvokers.clear();
129                        myGlobalInvokers.clear();
130                        myInterceptors.clear();
131                }
132        }
133
134        @Override
135        public void unregisterInterceptors(@Nullable Collection<?> theInterceptors) {
136                if (theInterceptors != null) {
137                        theInterceptors.forEach(t -> unregisterInterceptor(t));
138                }
139        }
140
141        @Override
142        public void registerInterceptors(@Nullable Collection<?> theInterceptors) {
143                if (theInterceptors != null) {
144                        theInterceptors.forEach(t -> registerInterceptor(t));
145                }
146        }
147
148        @Override
149        public boolean registerThreadLocalInterceptor(Object theInterceptor) {
150                if (!myThreadlocalInvokersEnabled) {
151                        return false;
152                }
153                ListMultimap<Pointcut, BaseInvoker> invokers = getThreadLocalInvokerMultimap();
154                scanInterceptorAndAddToInvokerMultimap(theInterceptor, invokers);
155                return !invokers.isEmpty();
156
157        }
158
159        @Override
160        public void unregisterThreadLocalInterceptor(Object theInterceptor) {
161                if (myThreadlocalInvokersEnabled) {
162                        ListMultimap<Pointcut, BaseInvoker> invokers = getThreadLocalInvokerMultimap();
163                        invokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
164                        if (invokers.isEmpty()) {
165                                myThreadlocalInvokers.remove();
166                        }
167                }
168        }
169
170        private ListMultimap<Pointcut, BaseInvoker> getThreadLocalInvokerMultimap() {
171                ListMultimap<Pointcut, BaseInvoker> invokers = myThreadlocalInvokers.get();
172                if (invokers == null) {
173                        invokers = Multimaps.synchronizedListMultimap(ArrayListMultimap.create());
174                        myThreadlocalInvokers.set(invokers);
175                }
176                return invokers;
177        }
178
179        @Override
180        public boolean registerInterceptor(Object theInterceptor) {
181                synchronized (myRegistryMutex) {
182
183                        if (isInterceptorAlreadyRegistered(theInterceptor)) {
184                                return false;
185                        }
186
187                        List<HookInvoker> addedInvokers = scanInterceptorAndAddToInvokerMultimap(theInterceptor, myGlobalInvokers);
188                        if (addedInvokers.isEmpty()) {
189                                ourLog.warn("Interceptor registered with no valid hooks - Type was: {}", theInterceptor.getClass().getName());
190                                return false;
191                        }
192
193                        // Add to the global list
194                        myInterceptors.add(theInterceptor);
195                        sortByOrderAnnotation(myInterceptors);
196
197                        return true;
198                }
199        }
200
201        private boolean isInterceptorAlreadyRegistered(Object theInterceptor) {
202                for (Object next : myInterceptors) {
203                        if (next == theInterceptor) {
204                                return true;
205                        }
206                }
207                return false;
208        }
209
210        @Override
211        public boolean unregisterInterceptor(Object theInterceptor) {
212                synchronized (myRegistryMutex) {
213                        boolean removed = myInterceptors.removeIf(t -> t == theInterceptor);
214                        removed |= myGlobalInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
215                        removed |= myAnonymousInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
216                        return removed;
217                }
218        }
219
220        private void sortByOrderAnnotation(List<Object> theObjects) {
221                IdentityHashMap<Object, Integer> interceptorToOrder = new IdentityHashMap<>();
222                for (Object next : theObjects) {
223                        Interceptor orderAnnotation = next.getClass().getAnnotation(Interceptor.class);
224                        int order = orderAnnotation != null ? orderAnnotation.order() : 0;
225                        interceptorToOrder.put(next, order);
226                }
227
228                theObjects.sort((a, b) -> {
229                        Integer orderA = interceptorToOrder.get(a);
230                        Integer orderB = interceptorToOrder.get(b);
231                        return orderA - orderB;
232                });
233        }
234
235        @Override
236        public Object callHooksAndReturnObject(Pointcut thePointcut, HookParams theParams) {
237                assert haveAppropriateParams(thePointcut, theParams);
238                assert thePointcut.getReturnType() != void.class;
239
240                return doCallHooks(thePointcut, theParams, null);
241        }
242
243        @Override
244        public boolean hasHooks(Pointcut thePointcut) {
245                return myGlobalInvokers.containsKey(thePointcut)
246                        || myAnonymousInvokers.containsKey(thePointcut)
247                        || hasThreadLocalHooks(thePointcut);
248        }
249
250        private boolean hasThreadLocalHooks(Pointcut thePointcut) {
251                ListMultimap<Pointcut, BaseInvoker> hooks = myThreadlocalInvokersEnabled ? myThreadlocalInvokers.get() : null;
252                return hooks != null && hooks.containsKey(thePointcut);
253        }
254
255        @Override
256        public boolean callHooks(Pointcut thePointcut, HookParams theParams) {
257                assert haveAppropriateParams(thePointcut, theParams);
258                assert thePointcut.getReturnType() == void.class || thePointcut.getReturnType() == boolean.class;
259
260                Object retValObj = doCallHooks(thePointcut, theParams, true);
261                return (Boolean) retValObj;
262        }
263
264        private Object doCallHooks(Pointcut thePointcut, HookParams theParams, Object theRetVal) {
265                List<BaseInvoker> invokers = getInvokersForPointcut(thePointcut);
266
267                /*
268                 * Call each hook in order
269                 */
270                for (BaseInvoker nextInvoker : invokers) {
271                        Object nextOutcome = nextInvoker.invoke(theParams);
272                        Class<?> pointcutReturnType = thePointcut.getReturnType();
273                        if (pointcutReturnType.equals(boolean.class)) {
274                                Boolean nextOutcomeAsBoolean = (Boolean) nextOutcome;
275                                if (Boolean.FALSE.equals(nextOutcomeAsBoolean)) {
276                                        ourLog.trace("callHooks({}) for invoker({}) returned false", thePointcut, nextInvoker);
277                                        theRetVal = false;
278                                        break;
279                                }
280                        } else if (pointcutReturnType.equals(void.class) == false) {
281                                if (nextOutcome != null) {
282                                        theRetVal = nextOutcome;
283                                        break;
284                                }
285                        }
286                }
287
288                return theRetVal;
289        }
290
291        @VisibleForTesting
292        List<Object> getInterceptorsWithInvokersForPointcut(Pointcut thePointcut) {
293                return getInvokersForPointcut(thePointcut)
294                        .stream()
295                        .map(BaseInvoker::getInterceptor)
296                        .collect(Collectors.toList());
297        }
298
299        /**
300         * Returns an ordered list of invokers for the given pointcut. Note that
301         * a new and stable list is returned to.. do whatever you want with it.
302         */
303        private List<BaseInvoker> getInvokersForPointcut(Pointcut thePointcut) {
304                List<BaseInvoker> invokers;
305
306                synchronized (myRegistryMutex) {
307                        List<BaseInvoker> globalInvokers = myGlobalInvokers.get(thePointcut);
308                        List<BaseInvoker> anonymousInvokers = myAnonymousInvokers.get(thePointcut);
309                        List<BaseInvoker> threadLocalInvokers = null;
310                        if (myThreadlocalInvokersEnabled) {
311                                ListMultimap<Pointcut, BaseInvoker> pointcutToInvokers = myThreadlocalInvokers.get();
312                                if (pointcutToInvokers != null) {
313                                        threadLocalInvokers = pointcutToInvokers.get(thePointcut);
314                                }
315                        }
316                        invokers = union(globalInvokers, anonymousInvokers, threadLocalInvokers);
317                }
318
319                return invokers;
320        }
321
322        /**
323         * First argument must be the global invoker list!!
324         */
325        @SafeVarargs
326        private final List<BaseInvoker> union(List<BaseInvoker>... theInvokersLists) {
327                List<BaseInvoker> haveOne = null;
328                boolean haveMultiple = false;
329                for (List<BaseInvoker> nextInvokerList : theInvokersLists) {
330                        if (nextInvokerList == null || nextInvokerList.isEmpty()) {
331                                continue;
332                        }
333
334                        if (haveOne == null) {
335                                haveOne = nextInvokerList;
336                        } else {
337                                haveMultiple = true;
338                        }
339                }
340
341                if (haveOne == null) {
342                        return Collections.emptyList();
343                }
344
345                List<BaseInvoker> retVal;
346
347                if (haveMultiple == false) {
348
349                        // The global list doesn't need to be sorted every time since it's sorted on
350                        // insertion each time. Doing so is a waste of cycles..
351                        if (haveOne == theInvokersLists[0]) {
352                                retVal = haveOne;
353                        } else {
354                                retVal = new ArrayList<>(haveOne);
355                                retVal.sort(Comparator.naturalOrder());
356                        }
357
358                } else {
359
360                        retVal = Arrays
361                                .stream(theInvokersLists)
362                                .filter(t -> t != null)
363                                .flatMap(t -> t.stream())
364                                .sorted()
365                                .collect(Collectors.toList());
366
367                }
368
369                return retVal;
370        }
371
372        /**
373         * Only call this when assertions are enabled, it's expensive
374         */
375        boolean haveAppropriateParams(Pointcut thePointcut, HookParams theParams) {
376                Validate.isTrue(theParams.getParamsForType().values().size() == thePointcut.getParameterTypes().size(), "Wrong number of params for pointcut %s - Wanted %s but found %s", thePointcut.name(), toErrorString(thePointcut.getParameterTypes()), theParams.getParamsForType().values().stream().map(t -> t != null ? t.getClass().getSimpleName() : "null").sorted().collect(Collectors.toList()));
377
378                List<String> wantedTypes = new ArrayList<>(thePointcut.getParameterTypes());
379
380                ListMultimap<Class<?>, Object> givenTypes = theParams.getParamsForType();
381                for (Class<?> nextTypeClass : givenTypes.keySet()) {
382                        String nextTypeName = nextTypeClass.getName();
383                        for (Object nextParamValue : givenTypes.get(nextTypeClass)) {
384                                Validate.isTrue(nextParamValue == null || nextTypeClass.isAssignableFrom(nextParamValue.getClass()), "Invalid params for pointcut %s - %s is not of type %s", thePointcut.name(), nextParamValue != null ? nextParamValue.getClass() : "null", nextTypeClass);
385                                Validate.isTrue(wantedTypes.remove(nextTypeName), "Invalid params for pointcut %s - Wanted %s but found %s", thePointcut.name(), toErrorString(thePointcut.getParameterTypes()), nextTypeName);
386                        }
387                }
388
389                return true;
390        }
391
392        private class AnonymousLambdaInvoker extends BaseInvoker {
393                private final IAnonymousInterceptor myHook;
394                private final Pointcut myPointcut;
395
396                public AnonymousLambdaInvoker(Pointcut thePointcut, IAnonymousInterceptor theHook, int theOrder) {
397                        super(theHook, theOrder);
398                        myHook = theHook;
399                        myPointcut = thePointcut;
400                }
401
402                @Override
403                Object invoke(HookParams theParams) {
404                        myHook.invoke(myPointcut, theParams);
405                        return true;
406                }
407        }
408
409        private abstract static class BaseInvoker implements Comparable<BaseInvoker> {
410
411                private final int myOrder;
412                private final Object myInterceptor;
413
414                BaseInvoker(Object theInterceptor, int theOrder) {
415                        myInterceptor = theInterceptor;
416                        myOrder = theOrder;
417                }
418
419                public Object getInterceptor() {
420                        return myInterceptor;
421                }
422
423                abstract Object invoke(HookParams theParams);
424
425                @Override
426                public int compareTo(BaseInvoker theInvoker) {
427                        return myOrder - theInvoker.myOrder;
428                }
429        }
430
431        private static class HookInvoker extends BaseInvoker {
432
433                private final Method myMethod;
434                private final Class<?>[] myParameterTypes;
435                private final int[] myParameterIndexes;
436                private final Pointcut myPointcut;
437
438                /**
439                 * Constructor
440                 */
441                private HookInvoker(Hook theHook, @Nonnull Object theInterceptor, @Nonnull Method theHookMethod, int theOrder) {
442                        super(theInterceptor, theOrder);
443                        myPointcut = theHook.value();
444                        myParameterTypes = theHookMethod.getParameterTypes();
445                        myMethod = theHookMethod;
446
447                        Class<?> returnType = theHookMethod.getReturnType();
448                        if (myPointcut.getReturnType().equals(boolean.class)) {
449                                Validate.isTrue(boolean.class.equals(returnType) || void.class.equals(returnType), "Method does not return boolean or void: %s", theHookMethod);
450                        } else if (myPointcut.getReturnType().equals(void.class)) {
451                                Validate.isTrue(void.class.equals(returnType), "Method does not return void: %s", theHookMethod);
452                        } else {
453                                Validate.isTrue(myPointcut.getReturnType().isAssignableFrom(returnType) || void.class.equals(returnType), "Method does not return %s or void: %s", myPointcut.getReturnType(), theHookMethod);
454                        }
455
456                        myParameterIndexes = new int[myParameterTypes.length];
457                        Map<Class<?>, AtomicInteger> typeToCount = new HashMap<>();
458                        for (int i = 0; i < myParameterTypes.length; i++) {
459                                AtomicInteger counter = typeToCount.computeIfAbsent(myParameterTypes[i], t -> new AtomicInteger(0));
460                                myParameterIndexes[i] = counter.getAndIncrement();
461                        }
462
463                        myMethod.setAccessible(true);
464                }
465
466                @Override
467                public String toString() {
468                        return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
469                                .append("method", myMethod)
470                                .toString();
471                }
472
473                public Pointcut getPointcut() {
474                        return myPointcut;
475                }
476
477                /**
478                 * @return Returns true/false if the hook method returns a boolean, returns true otherwise
479                 */
480                @Override
481                Object invoke(HookParams theParams) {
482
483                        Object[] args = new Object[myParameterTypes.length];
484                        for (int i = 0; i < myParameterTypes.length; i++) {
485                                Class<?> nextParamType = myParameterTypes[i];
486                                if (nextParamType.equals(Pointcut.class)) {
487                                        args[i] = myPointcut;
488                                } else {
489                                        int nextParamIndex = myParameterIndexes[i];
490                                        Object nextParamValue = theParams.get(nextParamType, nextParamIndex);
491                                        args[i] = nextParamValue;
492                                }
493                        }
494
495                        // Invoke the method
496                        try {
497                                return myMethod.invoke(getInterceptor(), args);
498                        } catch (InvocationTargetException e) {
499                                Throwable targetException = e.getTargetException();
500                                if (myPointcut.isShouldLogAndSwallowException(targetException)) {
501                                        ourLog.error("Exception thrown by interceptor: " + targetException.toString(), targetException);
502                                        return null;
503                                }
504
505                                if (targetException instanceof RuntimeException) {
506                                        throw ((RuntimeException) targetException);
507                                } else {
508                                        throw new InternalErrorException("Failure invoking interceptor for pointcut(s) " + getPointcut(), targetException);
509                                }
510                        } catch (Exception e) {
511                                throw new InternalErrorException(e);
512                        }
513
514                }
515
516        }
517
518        private static List<HookInvoker> scanInterceptorAndAddToInvokerMultimap(Object theInterceptor, ListMultimap<Pointcut, BaseInvoker> theInvokers) {
519                Class<?> interceptorClass = theInterceptor.getClass();
520                int typeOrder = determineOrder(interceptorClass);
521
522                List<HookInvoker> addedInvokers = scanInterceptorForHookMethods(theInterceptor, typeOrder);
523
524                // Invoke the REGISTERED pointcut for any added hooks
525                addedInvokers.stream()
526                        .filter(t -> Pointcut.INTERCEPTOR_REGISTERED.equals(t.getPointcut()))
527                        .forEach(t -> t.invoke(new HookParams()));
528
529                // Register the interceptor and its various hooks
530                for (HookInvoker nextAddedHook : addedInvokers) {
531                        Pointcut nextPointcut = nextAddedHook.getPointcut();
532                        if (nextPointcut.equals(Pointcut.INTERCEPTOR_REGISTERED)) {
533                                continue;
534                        }
535                        theInvokers.put(nextPointcut, nextAddedHook);
536                }
537
538                // Make sure we're always sorted according to the order declared in
539                // @Order
540                for (Pointcut nextPointcut : theInvokers.keys()) {
541                        List<BaseInvoker> nextInvokerList = theInvokers.get(nextPointcut);
542                        nextInvokerList.sort(Comparator.naturalOrder());
543                }
544
545                return addedInvokers;
546        }
547
548        /**
549         * @return Returns a list of any added invokers
550         */
551        private static List<HookInvoker> scanInterceptorForHookMethods(Object theInterceptor, int theTypeOrder) {
552                ArrayList<HookInvoker> retVal = new ArrayList<>();
553                for (Method nextMethod : theInterceptor.getClass().getMethods()) {
554                        Optional<Hook> hook = findAnnotation(nextMethod, Hook.class);
555
556                        if (hook.isPresent()) {
557                                int methodOrder = theTypeOrder;
558                                int methodOrderAnnotation = hook.get().order();
559                                if (methodOrderAnnotation != Interceptor.DEFAULT_ORDER) {
560                                        methodOrder = methodOrderAnnotation;
561                                }
562
563                                retVal.add(new HookInvoker(hook.get(), theInterceptor, nextMethod, methodOrder));
564                        }
565                }
566
567                return retVal;
568        }
569
570        private static <T extends Annotation> Optional<T> findAnnotation(AnnotatedElement theObject, Class<T> theHookClass) {
571                T annotation;
572                if (theObject instanceof Method) {
573                        annotation = MethodUtils.getAnnotation((Method) theObject, theHookClass, true, true);
574                } else {
575                        annotation = theObject.getAnnotation(theHookClass);
576                }
577                return Optional.ofNullable(annotation);
578        }
579
580        private static int determineOrder(Class<?> theInterceptorClass) {
581                int typeOrder = Interceptor.DEFAULT_ORDER;
582                Optional<Interceptor> typeOrderAnnotation = findAnnotation(theInterceptorClass, Interceptor.class);
583                if (typeOrderAnnotation.isPresent()) {
584                        typeOrder = typeOrderAnnotation.get().order();
585                }
586                return typeOrder;
587        }
588
589        private static String toErrorString(List<String> theParameterTypes) {
590                return theParameterTypes
591                        .stream()
592                        .sorted()
593                        .collect(Collectors.joining(","));
594        }
595
596}