diff --git a/src/main/java/dev/kruhlmann/imgfloat/service/ChannelDirectoryService.java b/src/main/java/dev/kruhlmann/imgfloat/service/ChannelDirectoryService.java index 40e1113..ef86cbf 100644 --- a/src/main/java/dev/kruhlmann/imgfloat/service/ChannelDirectoryService.java +++ b/src/main/java/dev/kruhlmann/imgfloat/service/ChannelDirectoryService.java @@ -39,8 +39,8 @@ import dev.kruhlmann.imgfloat.service.media.MediaDetectionService; import dev.kruhlmann.imgfloat.service.media.MediaOptimizationService; import dev.kruhlmann.imgfloat.service.media.OptimizedAsset; import dev.kruhlmann.imgfloat.service.media.MediaTypeRegistry; +import dev.kruhlmann.imgfloat.util.AllowedDomainNormalizer; import java.io.IOException; -import java.net.URI; import java.nio.file.Files; import java.nio.file.Path; import java.nio.charset.StandardCharsets; @@ -51,7 +51,6 @@ import java.util.concurrent.atomic.AtomicReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.dao.DataAccessException; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -59,14 +58,11 @@ import org.springframework.web.multipart.MultipartFile; import org.springframework.web.server.ResponseStatusException; @Service -// TODO: Code smell God class; this service mixes admin management, asset CRUD, media processing, websocket publishing, and marketplace concerns. public class ChannelDirectoryService { private static final Logger logger = LoggerFactory.getLogger(ChannelDirectoryService.class); private static final Pattern SAFE_FILENAME = Pattern.compile("[^a-zA-Z0-9._ -]"); private static final String DEFAULT_CODE_MEDIA_TYPE = "application/javascript"; - private static final int MAX_ALLOWED_SCRIPT_DOMAINS = 32; - private static final Pattern ALLOWED_DOMAIN_PATTERN = Pattern.compile("^[a-z0-9.-]+(?::[0-9]{1,5})?$"); private static final EnumSet VISUAL_ASSET_TYPES = EnumSet.of( AssetType.IMAGE, AssetType.VIDEO, @@ -74,7 +70,6 @@ public class ChannelDirectoryService { AssetType.OTHER ); - // TODO: Code smell Constructor has too many dependencies, indicating high coupling and too many responsibilities. private final ChannelRepository channelRepository; private final AssetRepository assetRepository; private final VisualAssetRepository visualAssetRepository; @@ -92,7 +87,6 @@ public class ChannelDirectoryService { private final MarketplaceScriptSeedLoader marketplaceScriptSeedLoader; private final AuditLogService auditLogService; - @Autowired public ChannelDirectoryService( ChannelRepository channelRepository, AssetRepository assetRepository, @@ -187,13 +181,7 @@ public class ChannelDirectoryService { List assets = assetRepository.findByBroadcaster(normalized); List visualIds = assets .stream() - .filter( - (asset) -> - asset.getAssetType() == AssetType.IMAGE || - asset.getAssetType() == AssetType.VIDEO || - asset.getAssetType() == AssetType.MODEL || - asset.getAssetType() == AssetType.OTHER - ) + .filter((asset) -> VISUAL_ASSET_TYPES.contains(asset.getAssetType())) .map(Asset::getId) .toList(); Map assetById = assets.stream().collect(Collectors.toMap(Asset::getId, (asset) -> asset)); @@ -1851,58 +1839,11 @@ public class ChannelDirectoryService { } private List normalizeAllowedDomains(List requestedDomains) { - if (requestedDomains == null || requestedDomains.isEmpty()) { - return List.of(); - } - List normalized = new ArrayList<>(); - for (String raw : requestedDomains) { - if (raw == null) { - continue; - } - String candidate = raw.trim(); - if (candidate.isEmpty()) { - continue; - } - String withScheme = candidate.contains("://") ? candidate : "https://" + candidate; - URI uri; - try { - uri = URI.create(withScheme); - } catch (IllegalArgumentException ex) { - throw new ResponseStatusException(BAD_REQUEST, "Invalid allowed domain: " + candidate, ex); - } - String host = uri.getHost(); - if (host == null || host.isBlank()) { - throw new ResponseStatusException(BAD_REQUEST, "Invalid allowed domain: " + candidate); - } - String domain = host.toLowerCase(Locale.ROOT); - int port = uri.getPort(); - if (port > 0) { - domain = domain + ":" + port; - } - if (!ALLOWED_DOMAIN_PATTERN.matcher(domain).matches()) { - throw new ResponseStatusException(BAD_REQUEST, "Invalid allowed domain: " + candidate); - } - if (normalized.contains(domain)) { - continue; - } - if (normalized.size() >= MAX_ALLOWED_SCRIPT_DOMAINS) { - throw new ResponseStatusException( - BAD_REQUEST, - "A maximum of 32 allowed domains are supported per script asset" - ); - } - normalized.add(domain); - } - return new ArrayList<>(normalized); + return AllowedDomainNormalizer.normalize(requestedDomains); } private List normalizeAllowedDomainsLenient(List requestedDomains) { - try { - return normalizeAllowedDomains(requestedDomains); - } catch (ResponseStatusException ex) { - logger.warn("Ignoring invalid allowed domains: {}", ex.getReason()); - return List.of(); - } + return AllowedDomainNormalizer.normalizeLenient(requestedDomains); } private void removeScriptAssetFileIfOrphaned(String fileId) { diff --git a/src/main/java/dev/kruhlmann/imgfloat/service/MarketplaceScriptSeedLoader.java b/src/main/java/dev/kruhlmann/imgfloat/service/MarketplaceScriptSeedLoader.java index e113a3b..8d6e6ff 100644 --- a/src/main/java/dev/kruhlmann/imgfloat/service/MarketplaceScriptSeedLoader.java +++ b/src/main/java/dev/kruhlmann/imgfloat/service/MarketplaceScriptSeedLoader.java @@ -2,8 +2,8 @@ package dev.kruhlmann.imgfloat.service; import dev.kruhlmann.imgfloat.model.api.response.ScriptMarketplaceEntry; import dev.kruhlmann.imgfloat.service.media.AssetContent; +import dev.kruhlmann.imgfloat.util.AllowedDomainNormalizer; import java.io.IOException; -import java.net.URI; import java.nio.file.DirectoryStream; import java.nio.file.Files; import java.nio.file.Path; @@ -289,43 +289,7 @@ public class MarketplaceScriptSeedLoader { } private List normalizeAllowedDomains(List requestedDomains) { - if (requestedDomains == null || requestedDomains.isEmpty()) { - return List.of(); - } - List normalized = new ArrayList<>(); - for (String raw : requestedDomains) { - if (raw == null) { - continue; - } - String candidate = raw.trim(); - if (candidate.isEmpty()) { - continue; - } - String withScheme = candidate.contains("://") ? candidate : "https://" + candidate; - try { - URI uri = URI.create(withScheme); - String host = uri.getHost(); - if (host == null || host.isBlank()) { - logger.warn("Skipping invalid allowed domain {}", candidate); - continue; - } - String value = host.toLowerCase(Locale.ROOT); - if (uri.getPort() > 0) { - value = value + ":" + uri.getPort(); - } - if (normalized.contains(value)) { - continue; - } - if (normalized.size() >= 32) { - logger.warn("Trimming allowed domains for marketplace script {}, limit reached", candidate); - break; - } - normalized.add(value); - } catch (IllegalArgumentException ex) { - logger.warn("Skipping invalid allowed domain {}", candidate, ex); - } - } - return new ArrayList<>(normalized); + return AllowedDomainNormalizer.normalizeLenient(requestedDomains); } private static Optional readBytes(Path filePath) { diff --git a/src/main/java/dev/kruhlmann/imgfloat/util/AllowedDomainNormalizer.java b/src/main/java/dev/kruhlmann/imgfloat/util/AllowedDomainNormalizer.java new file mode 100644 index 0000000..d013d4f --- /dev/null +++ b/src/main/java/dev/kruhlmann/imgfloat/util/AllowedDomainNormalizer.java @@ -0,0 +1,124 @@ +package dev.kruhlmann.imgfloat.util; + +import static org.springframework.http.HttpStatus.BAD_REQUEST; + +import java.net.URI; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.regex.Pattern; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.web.server.ResponseStatusException; + +/** + * Shared utility for normalizing and validating the allowed-domain lists that + * gate {@code fetch()} calls inside broadcast script workers. + * + *

Two modes are provided: + *

    + *
  • {@link #normalize} – strict; throws {@link ResponseStatusException} (400) on bad input. + * Use this for user-submitted data (API requests).
  • + *
  • {@link #normalizeLenient} – lenient; silently skips invalid entries. + * Use this when reading seed/marketplace data from disk.
  • + *
+ */ +public final class AllowedDomainNormalizer { + + private static final Logger LOG = LoggerFactory.getLogger(AllowedDomainNormalizer.class); + private static final int MAX_DOMAINS = 32; + private static final Pattern VALID_DOMAIN = Pattern.compile("^[a-z0-9.-]+(?::[0-9]{1,5})?$"); + + private AllowedDomainNormalizer() {} + + /** + * Strict normalization: invalid entries cause a {@link ResponseStatusException} (400). + */ + public static List normalize(List requestedDomains) { + if (requestedDomains == null || requestedDomains.isEmpty()) { + return List.of(); + } + List result = new ArrayList<>(); + for (String raw : requestedDomains) { + if (raw == null) { + continue; + } + String candidate = raw.trim(); + if (candidate.isEmpty()) { + continue; + } + String normalized = parseAndNormalize(candidate); + if (normalized == null) { + throw new ResponseStatusException(BAD_REQUEST, "Invalid allowed domain: " + candidate); + } + if (!VALID_DOMAIN.matcher(normalized).matches()) { + throw new ResponseStatusException(BAD_REQUEST, "Invalid allowed domain: " + candidate); + } + if (result.contains(normalized)) { + continue; + } + if (result.size() >= MAX_DOMAINS) { + throw new ResponseStatusException( + BAD_REQUEST, + "A maximum of " + MAX_DOMAINS + " allowed domains are supported per script asset" + ); + } + result.add(normalized); + } + return List.copyOf(result); + } + + /** + * Lenient normalization: invalid entries are skipped with a warning. + */ + public static List normalizeLenient(List requestedDomains) { + if (requestedDomains == null || requestedDomains.isEmpty()) { + return List.of(); + } + List result = new ArrayList<>(); + for (String raw : requestedDomains) { + if (raw == null) { + continue; + } + String candidate = raw.trim(); + if (candidate.isEmpty()) { + continue; + } + String normalized = parseAndNormalize(candidate); + if (normalized == null || !VALID_DOMAIN.matcher(normalized).matches()) { + LOG.warn("Skipping invalid allowed domain {}", candidate); + continue; + } + if (result.contains(normalized)) { + continue; + } + if (result.size() >= MAX_DOMAINS) { + LOG.warn("Trimming allowed domains at limit of {}", MAX_DOMAINS); + break; + } + result.add(normalized); + } + return List.copyOf(result); + } + + private static String parseAndNormalize(String candidate) { + String withScheme = candidate.contains("://") ? candidate : "https://" + candidate; + URI uri; + try { + uri = URI.create(withScheme); + } catch (IllegalArgumentException ex) { + LOG.warn("Unable to parse allowed domain {}", candidate, ex); + return null; + } + String host = uri.getHost(); + if (host == null || host.isBlank()) { + return null; + } + String domain = host.toLowerCase(Locale.ROOT); + int port = uri.getPort(); + if (port > 0) { + domain = domain + ":" + port; + } + return domain; + } +} diff --git a/src/test/java/dev/kruhlmann/imgfloat/util/AllowedDomainNormalizerTest.java b/src/test/java/dev/kruhlmann/imgfloat/util/AllowedDomainNormalizerTest.java new file mode 100644 index 0000000..b15e302 --- /dev/null +++ b/src/test/java/dev/kruhlmann/imgfloat/util/AllowedDomainNormalizerTest.java @@ -0,0 +1,106 @@ +package dev.kruhlmann.imgfloat.util; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.springframework.web.server.ResponseStatusException; + +class AllowedDomainNormalizerTest { + + // --- normalize (strict) --- + + @Test + void returnsEmptyListWhenNullInput() { + assertThat(AllowedDomainNormalizer.normalize(null)).isEmpty(); + } + + @Test + void returnsEmptyListWhenEmptyInput() { + assertThat(AllowedDomainNormalizer.normalize(List.of())).isEmpty(); + } + + @Test + void normalizesHostToLowercase() { + assertThat(AllowedDomainNormalizer.normalize(List.of("EXAMPLE.COM"))) + .containsExactly("example.com"); + } + + @Test + void preservesPort() { + assertThat(AllowedDomainNormalizer.normalize(List.of("api.example.com:8080"))) + .containsExactly("api.example.com:8080"); + } + + @Test + void stripsSchemeWhenProvided() { + assertThat(AllowedDomainNormalizer.normalize(List.of("https://example.com"))) + .containsExactly("example.com"); + } + + @Test + void deduplicatesEntries() { + List result = AllowedDomainNormalizer.normalize(List.of("example.com", "EXAMPLE.COM", "example.com")); + assertThat(result).containsExactly("example.com"); + } + + @Test + void throwsOn400WhenDomainInvalid() { + assertThatThrownBy(() -> AllowedDomainNormalizer.normalize(List.of("not a domain!!!"))) + .isInstanceOf(ResponseStatusException.class); + } + + @Test + void throwsWhenExceedsMaxDomains() { + List many = new ArrayList<>(); + for (int i = 0; i < 33; i++) { + many.add("host" + i + ".example.com"); + } + assertThatThrownBy(() -> AllowedDomainNormalizer.normalize(many)) + .isInstanceOf(ResponseStatusException.class) + .hasMessageContaining("32"); + } + + @Test + void skipsNullAndBlankEntries() { + List input = new ArrayList<>(); + input.add(null); + input.add(" "); + input.add("example.com"); + assertThat(AllowedDomainNormalizer.normalize(input)).containsExactly("example.com"); + } + + @Test + void resultIsImmutable() { + List result = AllowedDomainNormalizer.normalize(List.of("example.com")); + assertThatThrownBy(() -> result.add("other.com")) + .isInstanceOf(UnsupportedOperationException.class); + } + + // --- normalizeLenient --- + + @Test + void lenientSkipsInvalidDomainsWithoutThrowing() { + List result = AllowedDomainNormalizer.normalizeLenient(List.of("valid.com", "not a domain!!!")); + assertThat(result).containsExactly("valid.com"); + } + + @Test + void lenientStopsAtMaxDomains() { + List many = new ArrayList<>(); + for (int i = 0; i < 40; i++) { + many.add("host" + i + ".example.com"); + } + List result = AllowedDomainNormalizer.normalizeLenient(many); + assertThat(result).hasSize(32); + } + + @Test + void lenientReturnsEmptyForNullInput() { + assertThat(AllowedDomainNormalizer.normalizeLenient(null)).isEmpty(); + } +}