package ca.uhn.fhir.interceptor.executor;

/*-
 * #%L
 * HAPI FHIR - Core Library
 * %%
 * Copyright (C) 2014 - 2019 University Health Network
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import ca.uhn.fhir.interceptor.api.*;
import ca.uhn.fhir.rest.server.exceptions.InternalErrorException;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimaps;
import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
 * Registry of interceptor objects plus the machinery that broadcasts a
 * {@link Pointcut} to every hook registered for it.
 * <p>
 * Three kinds of hooks are tracked: hooks discovered on registered interceptor
 * objects ("global"), anonymous hooks registered directly against a pointcut,
 * and hooks registered only for the calling thread. The global and anonymous
 * collections are guarded by a single registry mutex; the thread-local
 * collection lives in a {@link ThreadLocal}.
 * </p>
 */
public class InterceptorService implements IInterceptorService, IInterceptorBroadcaster {
	private static final Logger ourLog = LoggerFactory.getLogger(InterceptorService.class);
	private final List<Object> myInterceptors = new ArrayList<>();
	private final ListMultimap<Pointcut, BaseInvoker> myGlobalInvokers = ArrayListMultimap.create();
	private final ListMultimap<Pointcut, BaseInvoker> myAnonymousInvokers = ArrayListMultimap.create();
	private final Object myRegistryMutex = new Object();
	private final ThreadLocal<ListMultimap<Pointcut, BaseInvoker>> myThreadlocalInvokers = new ThreadLocal<>();
	private String myName;
	private boolean myThreadlocalInvokersEnabled = true;

	/**
	 * Constructor which uses a default name of "default"
	 */
	public InterceptorService() {
		this("default");
	}

	/**
	 * Constructor
	 *
	 * @param theName The name for this registry (useful for troubleshooting)
	 */
	public InterceptorService(String theName) {
		super();
		myName = theName;
	}

	/**
	 * Are threadlocal interceptors enabled on this registry (defaults to true)
	 */
	public boolean isThreadlocalInvokersEnabled() {
		return myThreadlocalInvokersEnabled;
	}

	/**
	 * Enables or disables threadlocal interceptors on this registry (defaults to true).
	 * While disabled, thread-local registration is refused and thread-local hooks are
	 * not consulted when a pointcut is invoked.
	 *
	 * @param theThreadlocalInvokersEnabled <code>true</code> to enable threadlocal interceptors
	 */
	public void setThreadlocalInvokersEnabled(boolean theThreadlocalInvokersEnabled) {
		myThreadlocalInvokersEnabled = theThreadlocalInvokersEnabled;
	}

	/**
	 * Exposes the live (not copied) list of registered interceptors. Intended for unit tests only.
	 */
	@VisibleForTesting
	List<Object> getGlobalInterceptorsForUnitTest() {
		return myInterceptors;
	}

	/**
	 * Registers an anonymous hook against the given pointcut using {@link Interceptor#DEFAULT_ORDER}.
	 *
	 * @param thePointcut    The pointcut to attach to (must not be null)
	 * @param theInterceptor The hook to invoke (must not be null)
	 */
	@Override
	@VisibleForTesting
	public void registerAnonymousInterceptor(Pointcut thePointcut, IAnonymousInterceptor theInterceptor) {
		registerAnonymousInterceptor(thePointcut, Interceptor.DEFAULT_ORDER, theInterceptor);
	}

	/**
	 * @param theName The name for this registry (useful for troubleshooting)
	 */
	public void setName(String theName) {
		myName = theName;
	}

	/**
	 * Registers an anonymous hook against the given pointcut with an explicit order.
	 * The hook object is also added to the list of registered interceptors if it is
	 * not already present (compared by identity).
	 *
	 * @param thePointcut    The pointcut to attach to (must not be null)
	 * @param theOrder       The order value used when sorting this hook against other hooks
	 * @param theInterceptor The hook to invoke (must not be null)
	 */
	@Override
	public void registerAnonymousInterceptor(Pointcut thePointcut, int theOrder, IAnonymousInterceptor theInterceptor) {
		Validate.notNull(thePointcut);
		Validate.notNull(theInterceptor);
		synchronized (myRegistryMutex) {
			myAnonymousInvokers.put(thePointcut, new AnonymousLambdaInvoker(thePointcut, theInterceptor, theOrder));
			if (!isInterceptorAlreadyRegistered(theInterceptor)) {
				myInterceptors.add(theInterceptor);
			}
		}
	}

	/**
	 * @return Returns an unmodifiable snapshot of the currently registered interceptors
	 */
	@Override
	public List<Object> getAllRegisteredInterceptors() {
		synchronized (myRegistryMutex) {
			return Collections.unmodifiableList(new ArrayList<>(myInterceptors));
		}
	}

	/**
	 * Removes every global and anonymous hook along with all registered interceptors.
	 * Thread-local hooks are not touched.
	 */
	@Override
	@VisibleForTesting
	public void unregisterAllInterceptors() {
		synchronized (myRegistryMutex) {
			myAnonymousInvokers.clear();
			myGlobalInvokers.clear();
			myInterceptors.clear();
		}
	}

	/**
	 * Unregisters each interceptor in the given collection.
	 *
	 * @param theInterceptors The interceptors to remove; <code>null</code> is treated as empty
	 */
	@Override
	public void unregisterInterceptors(@Nullable Collection<?> theInterceptors) {
		if (theInterceptors != null) {
			theInterceptors.forEach(this::unregisterInterceptor);
		}
	}

	/**
	 * Registers each interceptor in the given collection.
	 *
	 * @param theInterceptors The interceptors to add; <code>null</code> is treated as empty
	 */
	@Override
	public void registerInterceptors(@Nullable Collection<?> theInterceptors) {
		if (theInterceptors != null) {
			theInterceptors.forEach(this::registerInterceptor);
		}
	}

	/**
	 * Registers an interceptor whose hooks will only be invoked for calls made on
	 * the current thread.
	 *
	 * @param theInterceptor The interceptor to scan for hook methods
	 * @return Returns <code>true</code> if at least one hook method was found on this interceptor.
	 * Returns <code>false</code> if none was found, or if threadlocal interceptors are disabled.
	 */
	@Override
	public boolean registerThreadLocalInterceptor(Object theInterceptor) {
		if (!myThreadlocalInvokersEnabled) {
			return false;
		}
		ListMultimap<Pointcut, BaseInvoker> invokers = getThreadLocalInvokerMultimap();
		List<HookInvoker> addedInvokers = scanInterceptorAndAddToInvokerMultimap(theInterceptor, invokers);

		// Don't leave an empty multimap attached to the thread if nothing ended up in it
		if (invokers.isEmpty()) {
			myThreadlocalInvokers.remove();
		}

		// Report on this interceptor specifically, consistent with registerInterceptor(Object),
		// rather than on whatever earlier registrations left in the multimap
		return !addedInvokers.isEmpty();
	}

	/**
	 * Removes all hooks belonging to the given interceptor (compared by identity) from
	 * the current thread's hooks. Does nothing if threadlocal interceptors are disabled.
	 *
	 * @param theInterceptor The interceptor to remove
	 */
	@Override
	public void unregisterThreadLocalInterceptor(Object theInterceptor) {
		if (myThreadlocalInvokersEnabled) {
			ListMultimap<Pointcut, BaseInvoker> invokers = getThreadLocalInvokerMultimap();
			invokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
			if (invokers.isEmpty()) {
				myThreadlocalInvokers.remove();
			}
		}
	}

	/**
	 * Returns the current thread's invoker multimap, creating and attaching one if
	 * the thread doesn't have one yet.
	 */
	private ListMultimap<Pointcut, BaseInvoker> getThreadLocalInvokerMultimap() {
		ListMultimap<Pointcut, BaseInvoker> invokers = myThreadlocalInvokers.get();
		if (invokers == null) {
			invokers = Multimaps.synchronizedListMultimap(ArrayListMultimap.create());
			myThreadlocalInvokers.set(invokers);
		}
		return invokers;
	}

	/**
	 * Registers an interceptor so that its hook methods are invoked for every thread.
	 *
	 * @param theInterceptor The interceptor to scan for hook methods
	 * @return Returns <code>true</code> if the interceptor was added. Returns <code>false</code>
	 * if this exact instance was already registered, or if it declares no hook methods.
	 */
	@Override
	public boolean registerInterceptor(Object theInterceptor) {
		synchronized (myRegistryMutex) {

			if (isInterceptorAlreadyRegistered(theInterceptor)) {
				return false;
			}

			List<HookInvoker> addedInvokers = scanInterceptorAndAddToInvokerMultimap(theInterceptor, myGlobalInvokers);
			if (addedInvokers.isEmpty()) {
				ourLog.warn("Interceptor registered with no valid hooks - Type was: {}", theInterceptor.getClass().getName());
				return false;
			}

			// Add to the global list
			myInterceptors.add(theInterceptor);
			sortByOrderAnnotation(myInterceptors);

			return true;
		}
	}

	/**
	 * Identity (not equals) check against the list of registered interceptors.
	 */
	private boolean isInterceptorAlreadyRegistered(Object theInterceptor) {
		for (Object next : myInterceptors) {
			if (next == theInterceptor) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Removes the given interceptor (compared by identity) together with any global or
	 * anonymous hooks that belong to it.
	 *
	 * @param theInterceptor The interceptor to remove
	 * @return Returns <code>true</code> if anything was removed
	 */
	@Override
	public boolean unregisterInterceptor(Object theInterceptor) {
		synchronized (myRegistryMutex) {
			boolean removed = myInterceptors.removeIf(t -> t == theInterceptor);
			removed |= myGlobalInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
			removed |= myAnonymousInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
			return removed;
		}
	}

	/**
	 * Sorts the given objects in place by the order declared on their class-level
	 * {@link Interceptor} annotation. Objects without the annotation sort as order 0.
	 */
	private void sortByOrderAnnotation(List<Object> theObjects) {
		// Resolve each order once up front; keyed by identity so equals() implementations don't interfere
		IdentityHashMap<Object, Integer> interceptorToOrder = new IdentityHashMap<>();
		for (Object next : theObjects) {
			Interceptor orderAnnotation = next.getClass().getAnnotation(Interceptor.class);
			int order = orderAnnotation != null ? orderAnnotation.order() : 0;
			interceptorToOrder.put(next, order);
		}

		// Integer.compare rather than subtraction: subtraction overflows for extreme order values
		theObjects.sort((a, b) -> Integer.compare(interceptorToOrder.get(a), interceptorToOrder.get(b)));
	}

	/**
	 * Invokes the hooks for a pointcut that has a non-void return type.
	 *
	 * @param thePointcut The pointcut to invoke
	 * @param theParams   The parameters to pass to each hook
	 * @return Returns the first non-null value returned by a hook, or <code>null</code> if no hook returned one
	 */
	@Override
	public Object callHooksAndReturnObject(Pointcut thePointcut, HookParams theParams) {
		assert haveAppropriateParams(thePointcut, theParams);
		assert thePointcut.getReturnType() != void.class;

		return doCallHooks(thePointcut, theParams, null);
	}

	/**
	 * @return Returns <code>true</code> if at least one global, anonymous, or (when enabled)
	 * thread-local hook is registered for the given pointcut
	 */
	@Override
	public boolean hasHooks(Pointcut thePointcut) {
		boolean hasRegistryHooks;
		// The registry multimaps are not thread-safe, so read them under the same mutex that guards writes
		synchronized (myRegistryMutex) {
			hasRegistryHooks = myGlobalInvokers.containsKey(thePointcut)
				|| myAnonymousInvokers.containsKey(thePointcut);
		}
		return hasRegistryHooks || hasThreadLocalHooks(thePointcut);
	}

	private boolean hasThreadLocalHooks(Pointcut thePointcut) {
		ListMultimap<Pointcut, BaseInvoker> hooks = myThreadlocalInvokersEnabled ? myThreadlocalInvokers.get() : null;
		return hooks != null && hooks.containsKey(thePointcut);
	}

	/**
	 * Invokes the hooks for a pointcut whose return type is void or boolean.
	 *
	 * @param thePointcut The pointcut to invoke
	 * @param theParams   The parameters to pass to each hook
	 * @return Returns <code>false</code> if a hook for a boolean pointcut returned <code>false</code>
	 * (remaining hooks are then skipped), and <code>true</code> otherwise
	 */
	@Override
	public boolean callHooks(Pointcut thePointcut, HookParams theParams) {
		assert haveAppropriateParams(thePointcut, theParams);
		assert thePointcut.getReturnType() == void.class || thePointcut.getReturnType() == boolean.class;

		Object retValObj = doCallHooks(thePointcut, theParams, true);
		return (Boolean) retValObj;
	}

	/**
	 * Calls each hook for the pointcut in order.
	 *
	 * @param theRetVal The value to return if no hook short-circuits the chain
	 */
	private Object doCallHooks(Pointcut thePointcut, HookParams theParams, Object theRetVal) {
		List<BaseInvoker> invokers = getInvokersForPointcut(thePointcut);
		Class<?> pointcutReturnType = thePointcut.getReturnType();

		/*
		 * Call each hook in order
		 */
		for (BaseInvoker nextInvoker : invokers) {
			Object nextOutcome = nextInvoker.invoke(theParams);
			if (pointcutReturnType.equals(boolean.class)) {
				// A null outcome (e.g. from a void hook method) is not treated as false
				Boolean nextOutcomeAsBoolean = (Boolean) nextOutcome;
				if (Boolean.FALSE.equals(nextOutcomeAsBoolean)) {
					ourLog.trace("callHooks({}) for invoker({}) returned false", thePointcut, nextInvoker);
					theRetVal = false;
					break;
				}
			} else if (!pointcutReturnType.equals(void.class)) {
				// For object-returning pointcuts the first non-null answer wins
				if (nextOutcome != null) {
					theRetVal = nextOutcome;
					break;
				}
			}
		}

		return theRetVal;
	}

	/**
	 * @return Returns the interceptor objects behind each invoker for the pointcut, in invocation order
	 */
	@VisibleForTesting
	List<Object> getInterceptorsWithInvokersForPointcut(Pointcut thePointcut) {
		return getInvokersForPointcut(thePointcut)
			.stream()
			.map(BaseInvoker::getInterceptor)
			.collect(Collectors.toList());
	}

	/**
	 * Returns an ordered list of invokers for the given pointcut. Note that
	 * a new and stable list is returned to.. do whatever you want with it.
	 */
	private List<BaseInvoker> getInvokersForPointcut(Pointcut thePointcut) {
		List<BaseInvoker> invokers;

		synchronized (myRegistryMutex) {
			List<BaseInvoker> globalInvokers = myGlobalInvokers.get(thePointcut);
			List<BaseInvoker> anonymousInvokers = myAnonymousInvokers.get(thePointcut);
			List<BaseInvoker> threadLocalInvokers = null;
			if (myThreadlocalInvokersEnabled) {
				ListMultimap<Pointcut, BaseInvoker> pointcutToInvokers = myThreadlocalInvokers.get();
				if (pointcutToInvokers != null) {
					threadLocalInvokers = pointcutToInvokers.get(thePointcut);
				}
			}
			invokers = union(globalInvokers, anonymousInvokers, threadLocalInvokers);
		}

		return invokers;
	}

	/**
	 * Merges the given invoker lists into a single list sorted by invoker order.
	 * Null and empty lists are ignored. The returned list is never one of the
	 * input lists, so callers may iterate it without holding the registry mutex.
	 * <p>
	 * First argument must be the global invoker list!!
	 * </p>
	 */
	@SafeVarargs
	private final List<BaseInvoker> union(List<BaseInvoker>... theInvokersLists) {
		List<BaseInvoker> haveOne = null;
		boolean haveMultiple = false;
		for (List<BaseInvoker> nextInvokerList : theInvokersLists) {
			if (nextInvokerList == null || nextInvokerList.isEmpty()) {
				continue;
			}

			if (haveOne == null) {
				haveOne = nextInvokerList;
			} else {
				haveMultiple = true;
			}
		}

		if (haveOne == null) {
			return Collections.emptyList();
		}

		List<BaseInvoker> retVal;

		if (!haveMultiple) {

			// Always copy: the input is a live view of the multimap, and handing it out
			// would let a registration that happens mid-iteration break the caller
			retVal = new ArrayList<>(haveOne);

			// The global list doesn't need to be sorted every time since it's sorted on
			// insertion each time. Doing so is a waste of cycles..
			if (haveOne != theInvokersLists[0]) {
				retVal.sort(Comparator.naturalOrder());
			}

		} else {

			retVal = Arrays
				.stream(theInvokersLists)
				.filter(Objects::nonNull)
				.flatMap(List::stream)
				.sorted()
				.collect(Collectors.toList());

		}

		return retVal;
	}

	/**
	 * Verifies that the supplied parameters match the parameter types declared by the pointcut.
	 * <p>
	 * Only call this when assertions are enabled, it's expensive
	 * </p>
	 *
	 * @return Always returns <code>true</code> (so that it can be used in an <code>assert</code>);
	 * a mismatch is reported by the exception thrown from {@link Validate#isTrue}
	 */
	boolean haveAppropriateParams(Pointcut thePointcut, HookParams theParams) {
		Validate.isTrue(
			theParams.getParamsForType().values().size() == thePointcut.getParameterTypes().size(),
			"Wrong number of params for pointcut %s - Wanted %s but found %s",
			thePointcut.name(),
			toErrorString(thePointcut.getParameterTypes()),
			theParams.getParamsForType().values().stream().map(t -> t != null ? t.getClass().getSimpleName() : "null").sorted().collect(Collectors.toList()));

		// Each supplied value consumes one entry from this list, so duplicates are accounted for
		List<String> wantedTypes = new ArrayList<>(thePointcut.getParameterTypes());

		ListMultimap<Class<?>, Object> givenTypes = theParams.getParamsForType();
		for (Class<?> nextTypeClass : givenTypes.keySet()) {
			String nextTypeName = nextTypeClass.getName();
			for (Object nextParamValue : givenTypes.get(nextTypeClass)) {
				Validate.isTrue(
					nextParamValue == null || nextTypeClass.isAssignableFrom(nextParamValue.getClass()),
					"Invalid params for pointcut %s - %s is not of type %s",
					thePointcut.name(),
					nextParamValue != null ? nextParamValue.getClass() : "null",
					nextTypeClass);
				Validate.isTrue(
					wantedTypes.remove(nextTypeName),
					"Invalid params for pointcut %s - Wanted %s but found %s",
					thePointcut.name(),
					toErrorString(thePointcut.getParameterTypes()),
					nextTypeName);
			}
		}

		return true;
	}

	/**
	 * Invoker that delegates to an {@link IAnonymousInterceptor} registered directly
	 * against a single pointcut.
	 */
	private static class AnonymousLambdaInvoker extends BaseInvoker {
		private final IAnonymousInterceptor myHook;
		private final Pointcut myPointcut;

		public AnonymousLambdaInvoker(Pointcut thePointcut, IAnonymousInterceptor theHook, int theOrder) {
			super(theHook, theOrder);
			myHook = theHook;
			myPointcut = thePointcut;
		}

		/**
		 * @return Always returns <code>true</code> after invoking the hook
		 */
		@Override
		Object invoke(HookParams theParams) {
			myHook.invoke(myPointcut, theParams);
			return true;
		}
	}

	/**
	 * Common base for hook invokers: pairs an interceptor object with the order
	 * value used to sort it against other invokers.
	 */
	private abstract static class BaseInvoker implements Comparable<BaseInvoker> {

		private final int myOrder;
		private final Object myInterceptor;

		BaseInvoker(Object theInterceptor, int theOrder) {
			myInterceptor = theInterceptor;
			myOrder = theOrder;
		}

		public Object getInterceptor() {
			return myInterceptor;
		}

		abstract Object invoke(HookParams theParams);

		/**
		 * Orders invokers by ascending order value.
		 */
		@Override
		public int compareTo(BaseInvoker theInvoker) {
			// Integer.compare rather than subtraction: subtraction overflows for extreme order values
			return Integer.compare(myOrder, theInvoker.myOrder);
		}
	}

	/**
	 * Invoker that reflectively calls a {@link Hook}-annotated method on an interceptor object.
	 */
	private static class HookInvoker extends BaseInvoker {

		private final Method myMethod;
		private final Class<?>[] myParameterTypes;
		private final int[] myParameterIndexes;
		private final Pointcut myPointcut;

		/**
		 * Constructor
		 *
		 * @throws IllegalArgumentException (via {@link Validate#isTrue}) if the method's return type
		 *                                  is not compatible with the pointcut's return type
		 */
		private HookInvoker(Hook theHook, @Nonnull Object theInterceptor, @Nonnull Method theHookMethod, int theOrder) {
			super(theInterceptor, theOrder);
			myPointcut = theHook.value();
			myParameterTypes = theHookMethod.getParameterTypes();
			myMethod = theHookMethod;

			Class<?> returnType = theHookMethod.getReturnType();
			if (myPointcut.getReturnType().equals(boolean.class)) {
				Validate.isTrue(boolean.class.equals(returnType) || void.class.equals(returnType), "Method does not return boolean or void: %s", theHookMethod);
			} else if (myPointcut.getReturnType().equals(void.class)) {
				Validate.isTrue(void.class.equals(returnType), "Method does not return void: %s", theHookMethod);
			} else {
				Validate.isTrue(myPointcut.getReturnType().isAssignableFrom(returnType) || void.class.equals(returnType), "Method does not return %s or void: %s", myPointcut.getReturnType(), theHookMethod);
			}

			// For each parameter, record how many earlier parameters share its type. This is
			// the index passed to HookParams.get(type, index) when the method is invoked.
			myParameterIndexes = new int[myParameterTypes.length];
			Map<Class<?>, AtomicInteger> typeToCount = new HashMap<>();
			for (int i = 0; i < myParameterTypes.length; i++) {
				AtomicInteger counter = typeToCount.computeIfAbsent(myParameterTypes[i], t -> new AtomicInteger(0));
				myParameterIndexes[i] = counter.getAndIncrement();
			}

			myMethod.setAccessible(true);
		}

		@Override
		public String toString() {
			return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
				.append("method", myMethod)
				.toString();
		}

		public Pointcut getPointcut() {
			return myPointcut;
		}

		/**
		 * Builds the argument array from the supplied params and reflectively invokes the hook method.
		 *
		 * @return Returns whatever the hook method returned (<code>null</code> for a void method), or
		 * <code>null</code> if the method threw an exception that the pointcut says should be logged
		 * and swallowed
		 */
		@Override
		Object invoke(HookParams theParams) {

			Object[] args = new Object[myParameterTypes.length];
			for (int i = 0; i < myParameterTypes.length; i++) {
				Class<?> nextParamType = myParameterTypes[i];
				if (nextParamType.equals(Pointcut.class)) {
					// A Pointcut parameter receives the pointcut itself rather than a value from the params
					args[i] = myPointcut;
				} else {
					int nextParamIndex = myParameterIndexes[i];
					Object nextParamValue = theParams.get(nextParamType, nextParamIndex);
					args[i] = nextParamValue;
				}
			}

			// Invoke the method
			try {
				return myMethod.invoke(getInterceptor(), args);
			} catch (InvocationTargetException e) {
				Throwable targetException = e.getTargetException();
				if (myPointcut.isShouldLogAndSwallowException(targetException)) {
					ourLog.error("Exception thrown by interceptor: {}", targetException.toString(), targetException);
					return null;
				}

				if (targetException instanceof RuntimeException) {
					throw ((RuntimeException) targetException);
				} else {
					throw new InternalErrorException("Failure invoking interceptor for pointcut(s) " + getPointcut(), targetException);
				}
			} catch (Exception e) {
				throw new InternalErrorException(e);
			}

		}

	}

	/**
	 * Scans the interceptor for hook methods, immediately invokes any hook bound to
	 * {@link Pointcut#INTERCEPTOR_REGISTERED}, and adds all other hooks to the given multimap,
	 * keeping each pointcut's list sorted by order.
	 *
	 * @return Returns every hook found on the interceptor, including INTERCEPTOR_REGISTERED
	 * hooks that were invoked but not added to the multimap
	 */
	private static List<HookInvoker> scanInterceptorAndAddToInvokerMultimap(Object theInterceptor, ListMultimap<Pointcut, BaseInvoker> theInvokers) {
		Class<?> interceptorClass = theInterceptor.getClass();
		int typeOrder = determineOrder(interceptorClass);

		List<HookInvoker> addedInvokers = scanInterceptorForHookMethods(theInterceptor, typeOrder);

		// Invoke the REGISTERED pointcut for any added hooks
		addedInvokers.stream()
			.filter(t -> Pointcut.INTERCEPTOR_REGISTERED.equals(t.getPointcut()))
			.forEach(t -> t.invoke(new HookParams()));

		// Register the interceptor and its various hooks
		for (HookInvoker nextAddedHook : addedInvokers) {
			Pointcut nextPointcut = nextAddedHook.getPointcut();
			if (nextPointcut.equals(Pointcut.INTERCEPTOR_REGISTERED)) {
				continue;
			}
			theInvokers.put(nextPointcut, nextAddedHook);
		}

		// Make sure we're always sorted according to the order declared in
		// @Order. Iterate keySet() (distinct keys) rather than keys(), which
		// repeats each key once per entry and would re-sort the same list needlessly.
		for (Pointcut nextPointcut : theInvokers.keySet()) {
			List<BaseInvoker> nextInvokerList = theInvokers.get(nextPointcut);
			nextInvokerList.sort(Comparator.naturalOrder());
		}

		return addedInvokers;
	}

	/**
	 * Creates an invoker for every public method of the interceptor that carries a {@link Hook} annotation.
	 *
	 * @param theTypeOrder The order to use for hooks that leave {@link Hook#order()} at {@link Interceptor#DEFAULT_ORDER}
	 * @return Returns a list of any added invokers
	 */
	private static List<HookInvoker> scanInterceptorForHookMethods(Object theInterceptor, int theTypeOrder) {
		List<HookInvoker> retVal = new ArrayList<>();
		for (Method nextMethod : theInterceptor.getClass().getMethods()) {
			Optional<Hook> hook = findAnnotation(nextMethod, Hook.class);

			if (hook.isPresent()) {
				// A method-level order overrides the type-level order unless it was left at the default
				int methodOrder = theTypeOrder;
				int methodOrderAnnotation = hook.get().order();
				if (methodOrderAnnotation != Interceptor.DEFAULT_ORDER) {
					methodOrder = methodOrderAnnotation;
				}

				retVal.add(new HookInvoker(hook.get(), theInterceptor, nextMethod, methodOrder));
			}
		}

		return retVal;
	}

	/**
	 * Looks up an annotation on the given element. Methods are resolved through
	 * {@link MethodUtils#getAnnotation(Method, Class, boolean, boolean)}; anything else
	 * uses {@link AnnotatedElement#getAnnotation(Class)}.
	 */
	private static <T extends Annotation> Optional<T> findAnnotation(AnnotatedElement theObject, Class<T> theHookClass) {
		T annotation;
		if (theObject instanceof Method) {
			annotation = MethodUtils.getAnnotation((Method) theObject, theHookClass, true, true);
		} else {
			annotation = theObject.getAnnotation(theHookClass);
		}
		return Optional.ofNullable(annotation);
	}

	/**
	 * @return Returns the order declared on the class's {@link Interceptor} annotation, or
	 * {@link Interceptor#DEFAULT_ORDER} if the class is not annotated
	 */
	private static int determineOrder(Class<?> theInterceptorClass) {
		return findAnnotation(theInterceptorClass, Interceptor.class)
			.map(Interceptor::order)
			.orElse(Interceptor.DEFAULT_ORDER);
	}

	/**
	 * @return Returns the given type names sorted and joined with commas, for use in error messages
	 */
	private static String toErrorString(List<String> theParameterTypes) {
		return theParameterTypes
			.stream()
			.sorted()
			.collect(Collectors.joining(","));
	}

}