Skip to content

Commit d314f0d

Browse files
authored
Merge pull request #1230 from Giskard-AI/feature/gsk-1349-make-it-possible-to-upload-artifacts-to-a-huggingface-spaces
Feature/gsk 1349 make it possible to upload artifacts to a huggingface spaces
2 parents 83334be + fe8e603 commit d314f0d

19 files changed

Lines changed: 314 additions & 113 deletions

File tree

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ ENV PATH="$VENV_PATH/bin:/usr/lib/postgresql/${POSTGRES_VERSION}/bin:$PATH" \
6767

6868
WORKDIR $GSK_DIST_PATH
6969

70+
COPY --from=build /app/python-client/dist $GSK_DIST_PATH/python-client
7071
COPY --from=build /app/python-client/.venv-prod $VENV_PATH
7172
COPY --from=build /app/backend/build/libs/backend*.jar $GSK_DIST_PATH/backend/giskard.jar
7273
COPY --from=build /app/frontend/dist $GSK_DIST_PATH/frontend/dist

backend/build.gradle.kts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,8 @@ dependencies {
190190
implementation("org.springframework.security:spring-security-messaging")
191191
testImplementation("org.testcontainers:postgresql")
192192
implementation("org.springframework.boot:spring-boot-starter-security")
193-
implementation("org.springframework.boot:spring-boot-starter-web") {
194-
exclude(module = "spring-boot-starter-tomcat")
195-
}
196-
implementation("org.springframework.boot:spring-boot-starter-undertow")
193+
implementation("org.springframework.boot:spring-boot-starter-web")
194+
implementation("org.springframework.boot:spring-boot-starter-tomcat")
197195
implementation("org.springframework.boot:spring-boot-starter-thymeleaf")
198196
implementation("org.zalando:problem-spring-web")
199197
implementation("org.springframework.cloud:spring-cloud-starter-bootstrap")

backend/src/main/java/ai/giskard/config/WebSocketConfig.java

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,89 @@
11
package ai.giskard.config;
22

3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
35
import org.springframework.context.annotation.Configuration;
46
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
7+
import org.springframework.stereotype.Component;
58
import org.springframework.util.CollectionUtils;
69
import org.springframework.web.cors.CorsConfiguration;
10+
import org.springframework.web.filter.OncePerRequestFilter;
711
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
812
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
13+
import org.springframework.web.socket.config.annotation.StompWebSocketEndpointRegistration;
914
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
1015
import tech.jhipster.config.JHipsterProperties;
1116

17+
import javax.servlet.FilterChain;
18+
import javax.servlet.ServletException;
19+
import javax.servlet.http.HttpServletRequest;
20+
import javax.servlet.http.HttpServletRequestWrapper;
21+
import javax.servlet.http.HttpServletResponse;
22+
import java.io.IOException;
23+
import java.util.Collections;
24+
import java.util.Enumeration;
1225
import java.util.List;
26+
import java.util.Map;
1327

1428
@Configuration
1529
@EnableWebSocketMessageBroker
1630
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
1731

32+
public static final String WEBSOCKET_ENDPOINT = "/websocket";
33+
34+
/**
35+
* For some reason when running on HuggingFace the connection header is set replacement "UPGRADE" instead of "Upgrade"
36+
* it causes websocket handshake replacement fail. This filter fixes the header.
37+
*/
38+
@Component
39+
static class FixUpgradeHeadersFilter extends OncePerRequestFilter {
40+
private static final Logger log = LoggerFactory.getLogger(FixUpgradeHeadersFilter.class);
41+
42+
private record HeaderReplacement(String original, String replacement) {}
43+
44+
private static final Map<String, HeaderReplacement> HEADER_REPLACEMENTS = Map.of(
45+
"connection", new HeaderReplacement("UPGRADE", "Upgrade")
46+
);
47+
48+
@Override
49+
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
50+
throws ServletException, IOException {
51+
52+
final HttpServletRequestWrapper reqWrapper = new HttpServletRequestWrapper(request) {
53+
@Override
54+
public Enumeration<String> getHeaders(String name) {
55+
String nameLowerCase = name.toLowerCase();
56+
if (WEBSOCKET_ENDPOINT.equals(request.getServletPath()) &&
57+
HEADER_REPLACEMENTS.containsKey(nameLowerCase) &&
58+
HEADER_REPLACEMENTS.get(nameLowerCase).original.equals(super.getHeaders(nameLowerCase).nextElement())) {
59+
60+
String replacement = HEADER_REPLACEMENTS.get(nameLowerCase).replacement;
61+
log.warn("Replacing header {} with {}", name, replacement);
62+
return Collections.enumeration(Collections.singleton(replacement));
63+
}
64+
return super.getHeaders(name);
65+
}
66+
67+
@Override
68+
public String getHeader(String name) {
69+
String nameLowerCase = name.toLowerCase();
70+
if (WEBSOCKET_ENDPOINT.equals(request.getServletPath()) &&
71+
HEADER_REPLACEMENTS.containsKey(nameLowerCase) &&
72+
HEADER_REPLACEMENTS.get(nameLowerCase).original.equals(super.getHeader(name))) {
73+
74+
String replacement = HEADER_REPLACEMENTS.get(nameLowerCase).replacement;
75+
log.warn("Replacing header {} with {}", name, replacement);
76+
return replacement;
77+
}
78+
return super.getHeader(name.toLowerCase());
79+
}
80+
};
81+
82+
filterChain.doFilter(reqWrapper, response);
83+
}
84+
}
85+
86+
1887
private final JHipsterProperties jHipsterProperties;
1988

2089
public WebSocketConfig(JHipsterProperties jHipsterProperties) {
@@ -32,13 +101,9 @@ public void configureMessageBroker(MessageBrokerRegistry config) {
32101
public void registerStompEndpoints(StompEndpointRegistry registry) {
33102
CorsConfiguration config = jHipsterProperties.getCors();
34103
List<String> allowedOrigins = config.getAllowedOrigins();
35-
104+
StompWebSocketEndpointRegistration endpoint = registry.addEndpoint(WEBSOCKET_ENDPOINT);
36105
if (!CollectionUtils.isEmpty(allowedOrigins)) {
37-
registry
38-
.addEndpoint("/websocket")
39-
.setAllowedOrigins(allowedOrigins.toArray(String[]::new));
40-
} else {
41-
registry.addEndpoint("/websocket");
106+
endpoint.setAllowedOrigins(allowedOrigins.toArray(String[]::new));
42107
}
43108
}
44109
}

backend/src/main/java/ai/giskard/domain/Project.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import ai.giskard.domain.ml.Dataset;
44
import ai.giskard.domain.ml.ProjectModel;
55
import ai.giskard.domain.ml.TestSuite;
6+
import ai.giskard.service.GeneralSettingsService;
67
import com.fasterxml.jackson.annotation.JsonIdentityInfo;
78
import com.fasterxml.jackson.annotation.JsonIgnore;
89
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -97,7 +98,9 @@ public class Project extends AbstractAuditingEntity {
9798
private MLWorkerType mlWorkerType = MLWorkerType.EXTERNAL;
9899

99100
public boolean isUsingInternalWorker() {
100-
return mlWorkerType == MLWorkerType.INTERNAL;
101+
return mlWorkerType == MLWorkerType.INTERNAL ||
102+
// In HF Spaces, we always use the internal worker
103+
GeneralSettingsService.isRunningInHFSpaces;
101104
}
102105

103106
public void addGuest(User user) {

backend/src/main/java/ai/giskard/service/GeneralSettingsService.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.springframework.stereotype.Service;
1313

1414
import java.util.Optional;
15+
import java.util.stream.Stream;
1516

1617
@Service
1718
@RequiredArgsConstructor
@@ -20,6 +21,10 @@ public class GeneralSettingsService {
2021

2122
private final GeneralSettingsRepository settingsRepository;
2223

24+
public static final boolean isRunningInHFSpaces = Stream.of("SPACE_REPO_NAME", "SPACE_ID", "SPACE_HOST").allMatch(System.getenv()::containsKey);
25+
26+
public static final String hfSpaceId = System.getenv().get("SPACE_ID");
27+
2328
public GeneralSettings getSettings() {
2429
return deserializeSettings(settingsRepository.getMandatoryById(SerializedGiskardGeneralSettings.SINGLE_ID).getSettings());
2530
}

backend/src/main/java/ai/giskard/web/dto/config/AppConfigDTO.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import ai.giskard.web.dto.user.AdminUserDTO;
55
import ai.giskard.web.dto.user.RoleDTO;
66
import com.dataiku.j2ts.annotations.UIModel;
7+
import com.fasterxml.jackson.annotation.JsonProperty;
78
import lombok.*;
89

910
import java.time.Instant;
@@ -43,5 +44,8 @@ public static class AppInfoDTO {
4344
private GeneralSettings generalSettings;
4445
private int externalMlWorkerEntrypointPort;
4546
private String externalMlWorkerEntrypointHost;
47+
private String hfSpaceId;
48+
@JsonProperty(value = "isRunningOnHfSpaces")
49+
private boolean isRunningOnHfSpaces;
4650
}
4751
}

backend/src/main/java/ai/giskard/web/rest/controllers/SettingsController.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ public class SettingsController {
5252
private String gitBuildCommitId;
5353
@Value("${git.commit.time:-}")
5454
private String gitCommitTime;
55+
5556
private final GeneralSettingsService settingsService;
5657
private final ApplicationProperties applicationProperties;
5758

@@ -99,6 +100,8 @@ public AppConfigDTO getApplicationSettings(@AuthenticationPrincipal final UserDe
99100
.planName(currentLicense.getPlanName())
100101
.externalMlWorkerEntrypointPort(applicationProperties.getExternalMlWorkerEntrypointPort())
101102
.externalMlWorkerEntrypointHost(applicationProperties.getExternalMlWorkerEntrypointHost())
103+
.hfSpaceId(GeneralSettingsService.hfSpaceId)
104+
.isRunningOnHfSpaces(GeneralSettingsService.isRunningInHFSpaces)
102105
.roles(roles)
103106
.build())
104107
.user(userDTO)

backend/src/main/java/ai/giskard/web/socket/WorkerStatusSocketService.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package ai.giskard.web.socket;
22

33
import ai.giskard.event.UpdateWorkerStatusEvent;
4+
import ai.giskard.service.GeneralSettingsService;
45
import ai.giskard.service.ml.MLWorkerService;
56
import lombok.RequiredArgsConstructor;
67
import org.springframework.context.event.EventListener;
78
import org.springframework.messaging.simp.SimpMessagingTemplate;
89
import org.springframework.stereotype.Service;
910
import org.springframework.web.socket.messaging.SessionSubscribeEvent;
11+
1012
import java.util.HashMap;
1113
import java.util.Map;
1214

@@ -32,7 +34,10 @@ public void handleWorkerStatusChangeEvent(UpdateWorkerStatusEvent event) {
3234

3335
public void sendCurrentStatus() {
3436
Map<String, Boolean> data = new HashMap<>();
35-
data.put("connected", mlWorkerService.isExternalWorkerConnected());
37+
data.put("connected",
38+
// HF Space uses internal worker that always connected
39+
mlWorkerService.isExternalWorkerConnected() || GeneralSettingsService.isRunningInHFSpaces
40+
);
3641
sendData(data);
3742
}
3843

frontend/src/api.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ function downloadURL(urlString) {
190190
}
191191

192192
export const api = {
193+
async getHuggingFaceToken(spaceId: string) {
194+
return await axios.get<unknown, any>(`https://huggingface.co/api/spaces/${spaceId}/jwt`);
195+
},
193196
async logInGetToken(username: string, password: string) {
194197
return apiV2.post<unknown, JWTToken>(`/authenticate`, { username, password });
195198
},
@@ -268,7 +271,6 @@ export const api = {
268271
async getApiAccessToken() {
269272
return apiV2.get<unknown, JWTToken>(`/api-access-token`);
270273
},
271-
272274
// Projects
273275
async getProjects() {
274276
return apiV2.get<unknown, ProjectDTO[]>(`projects`);

frontend/src/generated-sources/ai/giskard/web/dto/config/app-config-dto.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ export namespace AppConfigDTO {
2323
externalMlWorkerEntrypointHost: string;
2424
externalMlWorkerEntrypointPort: number;
2525
generalSettings: GeneralSettings;
26+
hfSpaceId: string;
27+
isRunningOnHfSpaces: boolean;
2628
planCode: string;
2729
planName: string;
2830
roles: RoleDTO[];

0 commit comments

Comments
 (0)