1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.logging.log4j.audit.rest;
18
19 import javax.servlet.http.HttpServletRequest;
20 import javax.servlet.http.HttpServletResponse;
21 import java.util.Enumeration;
22
23 import org.apache.logging.log4j.LogManager;
24 import org.apache.logging.log4j.Logger;
25 import org.apache.logging.log4j.ThreadContext;
26 import org.apache.logging.log4j.audit.request.ChainedMapping;
27 import org.apache.logging.log4j.audit.request.RequestContextMapping;
28 import org.apache.logging.log4j.audit.request.RequestContextMappings;
29 import org.springframework.web.servlet.HandlerInterceptor;
30 import org.springframework.web.servlet.ModelAndView;
31
32
33
34
35 public class RequestContextHandlerInterceptor implements HandlerInterceptor {
36
37 private static final Logger logger = LogManager.getLogger(RequestContextHandlerInterceptor.class);
38 private RequestContextMappings mappings;
39 private ThreadLocal<Long> startTime = new ThreadLocal<>();
40
41 public RequestContextHandlerInterceptor(Class<?> clazz) {
42 mappings = new RequestContextMappings(clazz);
43 }
44
45 @Override
46 public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object o) throws Exception {
47 logger.trace("Starting request {}", request.getRequestURI());
48 Enumeration<String> headers = request.getHeaderNames();
49 while (headers.hasMoreElements()) {
50 String name = headers.nextElement();
51 RequestContextMapping mapping = mappings.getMappingByHeader(name);
52 logger.debug("Got Mapping:{} for Header:{}", mapping, name);
53 if (mapping != null) {
54 if (mapping.isChained()) {
55 ThreadContext.put(mapping.getChainKey(), request.getHeader(name));
56 logger.debug("Setting Context Key:{} with value:{}", mapping.getChainKey(), request.getHeader(name));
57 String value = ((ChainedMapping) mapping).getSupplier().get();
58 ThreadContext.put(mapping.getFieldName(), value);
59 logger.debug("Setting Context Key:{} with value:{}", mapping.getFieldName(), value);
60 } else {
61 ThreadContext.put(mapping.getFieldName(), request.getHeader(name));
62 logger.debug("Setting Context Key:{} with value:{}", mapping.getFieldName(), request.getHeader(name));
63 }
64 }
65 }
66 if (logger.isTraceEnabled()) {
67 startTime.set(System.nanoTime());
68 }
69 return true;
70 }
71
72 @Override
73 public void postHandle(HttpServletRequest request, HttpServletResponse response, Object o, ModelAndView modelAndView) throws Exception {
74 if (logger.isTraceEnabled()) {
75 long elapsed = System.nanoTime() - startTime.get();
76 StringBuilder sb = new StringBuilder("Request ").append(request.getRequestURI()).append(" completed in ");
77 ElapsedUtil.addElapsed(elapsed, sb);
78 logger.trace(sb.toString());
79 startTime.remove();
80 }
81 }
82
83 @Override
84 public void afterCompletion(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Object o, Exception e) throws Exception {
85 ThreadContext.clearMap();
86 }
87 }