11package ai .giskard .config ;
22
3+ import org .slf4j .Logger ;
4+ import org .slf4j .LoggerFactory ;
35import org .springframework .context .annotation .Configuration ;
46import org .springframework .messaging .simp .config .MessageBrokerRegistry ;
7+ import org .springframework .stereotype .Component ;
58import org .springframework .util .CollectionUtils ;
69import org .springframework .web .cors .CorsConfiguration ;
10+ import org .springframework .web .filter .OncePerRequestFilter ;
711import org .springframework .web .socket .config .annotation .EnableWebSocketMessageBroker ;
812import org .springframework .web .socket .config .annotation .StompEndpointRegistry ;
13+ import org .springframework .web .socket .config .annotation .StompWebSocketEndpointRegistration ;
914import org .springframework .web .socket .config .annotation .WebSocketMessageBrokerConfigurer ;
1015import tech .jhipster .config .JHipsterProperties ;
1116
17+ import javax .servlet .FilterChain ;
18+ import javax .servlet .ServletException ;
19+ import javax .servlet .http .HttpServletRequest ;
20+ import javax .servlet .http .HttpServletRequestWrapper ;
21+ import javax .servlet .http .HttpServletResponse ;
22+ import java .io .IOException ;
23+ import java .util .Collections ;
24+ import java .util .Enumeration ;
1225import java .util .List ;
26+ import java .util .Map ;
1327
1428@ Configuration
1529@ EnableWebSocketMessageBroker
1630public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
1731
32+ public static final String WEBSOCKET_ENDPOINT = "/websocket" ;
33+
34+ /**
35+ * For some reason when running on HuggingFace the connection header is set replacement "UPGRADE" instead of "Upgrade"
36+ * it causes websocket handshake replacement fail. This filter fixes the header.
37+ */
38+ @ Component
39+ static class FixUpgradeHeadersFilter extends OncePerRequestFilter {
40+ private static final Logger log = LoggerFactory .getLogger (FixUpgradeHeadersFilter .class );
41+
42+ private record HeaderReplacement (String original , String replacement ) {}
43+
44+ private static final Map <String , HeaderReplacement > HEADER_REPLACEMENTS = Map .of (
45+ "connection" , new HeaderReplacement ("UPGRADE" , "Upgrade" )
46+ );
47+
48+ @ Override
49+ protected void doFilterInternal (HttpServletRequest request , HttpServletResponse response , FilterChain filterChain )
50+ throws ServletException , IOException {
51+
52+ final HttpServletRequestWrapper reqWrapper = new HttpServletRequestWrapper (request ) {
53+ @ Override
54+ public Enumeration <String > getHeaders (String name ) {
55+ String nameLowerCase = name .toLowerCase ();
56+ if (WEBSOCKET_ENDPOINT .equals (request .getServletPath ()) &&
57+ HEADER_REPLACEMENTS .containsKey (nameLowerCase ) &&
58+ HEADER_REPLACEMENTS .get (nameLowerCase ).original .equals (super .getHeaders (nameLowerCase ).nextElement ())) {
59+
60+ String replacement = HEADER_REPLACEMENTS .get (nameLowerCase ).replacement ;
61+ log .warn ("Replacing header {} with {}" , name , replacement );
62+ return Collections .enumeration (Collections .singleton (replacement ));
63+ }
64+ return super .getHeaders (name );
65+ }
66+
67+ @ Override
68+ public String getHeader (String name ) {
69+ String nameLowerCase = name .toLowerCase ();
70+ if (WEBSOCKET_ENDPOINT .equals (request .getServletPath ()) &&
71+ HEADER_REPLACEMENTS .containsKey (nameLowerCase ) &&
72+ HEADER_REPLACEMENTS .get (nameLowerCase ).original .equals (super .getHeader (name ))) {
73+
74+ String replacement = HEADER_REPLACEMENTS .get (nameLowerCase ).replacement ;
75+ log .warn ("Replacing header {} with {}" , name , replacement );
76+ return replacement ;
77+ }
78+ return super .getHeader (name .toLowerCase ());
79+ }
80+ };
81+
82+ filterChain .doFilter (reqWrapper , response );
83+ }
84+ }
85+
86+
1887 private final JHipsterProperties jHipsterProperties ;
1988
2089 public WebSocketConfig (JHipsterProperties jHipsterProperties ) {
@@ -32,13 +101,9 @@ public void configureMessageBroker(MessageBrokerRegistry config) {
32101 public void registerStompEndpoints (StompEndpointRegistry registry ) {
33102 CorsConfiguration config = jHipsterProperties .getCors ();
34103 List <String > allowedOrigins = config .getAllowedOrigins ();
35-
104+ StompWebSocketEndpointRegistration endpoint = registry . addEndpoint ( WEBSOCKET_ENDPOINT );
36105 if (!CollectionUtils .isEmpty (allowedOrigins )) {
37- registry
38- .addEndpoint ("/websocket" )
39- .setAllowedOrigins (allowedOrigins .toArray (String []::new ));
40- } else {
41- registry .addEndpoint ("/websocket" );
106+ endpoint .setAllowedOrigins (allowedOrigins .toArray (String []::new ));
42107 }
43108 }
44109}
0 commit comments