diff --git a/contentcuration/automation/settings.py b/contentcuration/automation/settings.py
new file mode 100644
index 0000000000..e5053b3020
--- /dev/null
+++ b/contentcuration/automation/settings.py
@@ -0,0 +1,27 @@
+from enum import Enum
+from torch.cuda import is_available as is_gpu_available
+
+DEVICE = "cuda:0" if is_gpu_available() else "cpu"
+
+
+# [TRANSCRIPTION GENERATION]
+WHISPER_MODELS = dict(
+    TINY="openai/whisper-tiny",
+    BASE="openai/whisper-base",
+    SMALL="openai/whisper-small",
+    MEDIUM="openai/whisper-medium",
+    LARGE="openai/whisper-large",
+    LARGEV2="openai/whisper-large-v2",
+)
+
+
+DEV_TRANSCRIPTION_MODEL = WHISPER_MODELS['TINY']
+TRANSCRIPTION_MODEL = WHISPER_MODELS['TINY']
+
+class WhisperTask(Enum):
+    TRANSLATE = "translate"
+    TRANSCRIBE = "transcribe"
+
+# https://huggingface.co/docs/transformers/v4.29.1/en/generation_strategies#customize-text-generation
+MAX_TOKEN_LENGTH = 448
+CHUNK_LENGTH = 10
diff --git a/contentcuration/automation/urls.py b/contentcuration/automation/urls.py
new file mode 100644
index 0000000000..fdc6929b3a
--- /dev/null
+++ b/contentcuration/automation/urls.py
@@ -0,0 +1,10 @@
+from automation.views import TranscriptionsViewSet
+from django.urls import include, path
+from rest_framework import routers
+
+automation_router = routers.DefaultRouter()
+automation_router.register(r'transcribe', TranscriptionsViewSet, basename="transcribe")
+
+urlpatterns = [
+    path("api/automation/", include(automation_router.urls), name='automation'),
+]
diff --git a/contentcuration/automation/utils/__init__.py b/contentcuration/automation/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/contentcuration/automation/views.py b/contentcuration/automation/views.py
index fd0e044955..948a5a7c19 100644
--- a/contentcuration/automation/views.py
+++ b/contentcuration/automation/views.py
@@ -1,3 +1,11 @@
-# from django.shortcuts import render
+from rest_framework.viewsets import ViewSet
+from rest_framework.response import Response
+from rest_framework.permissions import AllowAny
 
-# Create your views here.
+class TranscriptionsViewSet(ViewSet):
+    permission_classes = [AllowAny]
+
+    def create(self, request):
+        return Response({
+            "OK": "OK"
+        })
diff --git a/contentcuration/contentcuration/constants/transcription_languages.py b/contentcuration/contentcuration/constants/transcription_languages.py
new file mode 100644
index 0000000000..a20c422cfe
--- /dev/null
+++ b/contentcuration/contentcuration/constants/transcription_languages.py
@@ -0,0 +1,51 @@
+# The list of supported AI languages is loaded from 'ai_supported_languages.json';
+# update that file to change which languages are supported for transcription.
+# This module then takes the intersection of the languages supported by the Kolibri
+# project and by the Whisper speech-to-text model, and stores the resulting list of
+# language codes in CAPTION_LANGUAGES for creating captions.
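# A minimal illustration of the intersection computed below (hypothetical subsets of both files):
#
#   kolibri_codes = {"en", "hi", "sw", "zu"}   # keys of languagelookup.json
#   whisper_codes = {"en", "hi", "fr"}         # keys of ai_supported_languages.json
#   CAPTION_LANGUAGES == ["en", "hi"]          # order of the resulting list is not guaranteed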
+ +import os +import json +from typing import List, Dict + +import le_utils.resources as resources + + +def _ai_supported_languages() -> Dict: + """Loads JSON of supported AI languages""" + file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "../static/ai_supported_languages.json", + ) + with open(file) as f: + data = json.load(f) + return data + + +WHISPER_LANGUAGES = _ai_supported_languages() + + +def _load_kolibri_languages() -> List[str]: + """Loads the language codes from languagelookup.json and returns them as a list.""" + filepath = resources.__path__[0] + kolibri_languages = [] + with open(f"{filepath}/languagelookup.json") as f: + kolibri_languages = list(json.load(f).keys()) + return kolibri_languages + + +def _load_model_languages(languages: Dict[str, str]) -> List[str]: + """Load languages supported by the speech-to-text model. + :param: languages: dict mapping language codes to language names""" + return list(languages.keys()) + + +def create_captions_languages() -> List[str]: + """Returns the intersection of Kolibri languages and model languages""" + kolibri_set = set(_load_kolibri_languages()) + model_set = set(_load_model_languages(languages=WHISPER_LANGUAGES)) + return list(kolibri_set.intersection(model_set)) + + +# list of language id's ['en', 'hi', ...] +CAPTION_LANGUAGES = create_captions_languages() diff --git a/contentcuration/contentcuration/dev_urls.py b/contentcuration/contentcuration/dev_urls.py index afbb7a83f8..8e54367747 100644 --- a/contentcuration/contentcuration/dev_urls.py +++ b/contentcuration/contentcuration/dev_urls.py @@ -1,5 +1,6 @@ import urllib.parse +from automation.urls import urlpatterns as automation_urlpatterns from django.conf import settings from django.contrib import admin from django.core.files.storage import default_storage @@ -76,6 +77,8 @@ def file_server(request, storage_path=None): re_path(r"^content/(?P.+)$", file_server), ] +urlpatterns += automation_urlpatterns + if getattr(settings, "DEBUG_PANEL_ACTIVE", False): import debug_toolbar diff --git a/contentcuration/contentcuration/frontend/channelEdit/components/CaptionsTab/CaptionsTab.vue b/contentcuration/contentcuration/frontend/channelEdit/components/CaptionsTab/CaptionsTab.vue new file mode 100644 index 0000000000..74125fe6e7 --- /dev/null +++ b/contentcuration/contentcuration/frontend/channelEdit/components/CaptionsTab/CaptionsTab.vue @@ -0,0 +1,162 @@ + + + + + diff --git a/contentcuration/contentcuration/frontend/channelEdit/components/edit/EditModal.vue b/contentcuration/contentcuration/frontend/channelEdit/components/edit/EditModal.vue index 995072d190..7a1e7823b2 100644 --- a/contentcuration/contentcuration/frontend/channelEdit/components/edit/EditModal.vue +++ b/contentcuration/contentcuration/frontend/channelEdit/components/edit/EditModal.vue @@ -340,6 +340,7 @@ // (especially marking nodes as (in)complete) vm.loadFiles({ contentnode__in: childrenNodesIds }), vm.loadAssessmentItems({ contentnode__in: childrenNodesIds }), + vm.loadCaptions({ contentnode__in: childrenNodesIds }), ]; } else { // no need to load assessment items or files as topics have none @@ -394,6 +395,7 @@ ]), ...mapActions('file', ['loadFiles', 'updateFile']), ...mapActions('assessmentItem', ['loadAssessmentItems', 'updateAssessmentItems']), + ...mapActions('caption', ['loadCaptions']), ...mapMutations('contentNode', { enableValidation: 'ENABLE_VALIDATION_ON_NODES' }), closeModal() { this.promptUploading = false; diff --git 
a/contentcuration/contentcuration/frontend/channelEdit/components/edit/EditView.vue b/contentcuration/contentcuration/frontend/channelEdit/components/edit/EditView.vue index 4a1e0cebb5..e153e49eeb 100644 --- a/contentcuration/contentcuration/frontend/channelEdit/components/edit/EditView.vue +++ b/contentcuration/contentcuration/frontend/channelEdit/components/edit/EditView.vue @@ -62,6 +62,16 @@ {{ relatedResourcesCount }} + + + + {{ $tr(tabs.CAPTIONS) }} + @@ -82,6 +92,7 @@ + + + + + @@ -104,6 +120,7 @@ import { TabNames } from '../../constants'; import AssessmentTab from '../../components/AssessmentTab/AssessmentTab'; + import CaptionsTab from '../../components/CaptionsTab/CaptionsTab' import RelatedResourcesTab from '../../components/RelatedResourcesTab/RelatedResourcesTab'; import DetailsTabView from './DetailsTabView'; import { ContentKindsNames } from 'shared/leUtils/ContentKinds'; @@ -113,11 +130,12 @@ export default { name: 'EditView', components: { - DetailsTabView, - AssessmentTab, - RelatedResourcesTab, - Tabs, - ToolBar, + AssessmentTab, + CaptionsTab, + DetailsTabView, + RelatedResourcesTab, + Tabs, + ToolBar, }, props: { nodeIds: { @@ -143,6 +161,7 @@ 'getImmediateRelatedResourcesCount', ]), ...mapGetters('assessmentItem', ['getAssessmentItemsAreValid', 'getAssessmentItemsCount']), + ...mapGetters(['isAIFeatureEnabled']), firstNode() { return this.nodes.length ? this.nodes[0] : null; }, @@ -167,6 +186,14 @@ showRelatedResourcesTab() { return this.oneSelected && this.firstNode && this.firstNode.kind !== 'topic'; }, + showCaptions() { + return ( + this.oneSelected && + this.firstNode && + (this.firstNode.kind === 'video' || this.firstNode.kind === 'audio') && + this.isAIFeatureEnabled + ) + }, countText() { const totals = reduce( this.nodes, @@ -260,6 +287,8 @@ questions: 'Questions', /** @see TabNames.RELATED */ related: 'Related', + /** @see TabNames.CAPTIONS */ + captions: 'Captions', /* eslint-enable kolibri/vue-no-unused-translations */ noItemsToEditText: 'Please select resources or folders to edit', invalidFieldsToolTip: 'Some required information is missing', diff --git a/contentcuration/contentcuration/frontend/channelEdit/constants.js b/contentcuration/contentcuration/frontend/channelEdit/constants.js index 6512e9e9b4..cca0274c62 100644 --- a/contentcuration/contentcuration/frontend/channelEdit/constants.js +++ b/contentcuration/contentcuration/frontend/channelEdit/constants.js @@ -55,6 +55,7 @@ export const TabNames = { PREVIEW: 'preview', QUESTIONS: 'questions', RELATED: 'related', + CAPTIONS: 'captions', }; export const modes = { diff --git a/contentcuration/contentcuration/frontend/channelEdit/store.js b/contentcuration/contentcuration/frontend/channelEdit/store.js index 068563d111..497a99ddef 100644 --- a/contentcuration/contentcuration/frontend/channelEdit/store.js +++ b/contentcuration/contentcuration/frontend/channelEdit/store.js @@ -1,5 +1,6 @@ import template from './vuex/template'; import assessmentItem from './vuex/assessmentItem'; +import caption from './vuex/caption'; import clipboard from './vuex/clipboard'; import contentNode from './vuex/contentNode'; import currentChannel from './vuex/currentChannel'; @@ -45,6 +46,7 @@ export const STORE_CONFIG = { task, template, assessmentItem, + caption, clipboard, contentNode, currentChannel, diff --git a/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/actions.js b/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/actions.js new file mode 100644 index 0000000000..a8d63a2a58 --- 
/dev/null +++ b/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/actions.js @@ -0,0 +1,73 @@ +import { CaptionFile, CaptionCues } from 'shared/data/resources'; +import { GENERATING } from 'shared/data/constants'; + +async function loadCaptionFiles(commit, params) { + const captionFiles = await CaptionFile.where(params); + commit('ADD_CAPTIONFILES', { captionFiles, nodeIds: params.contentnode__in }); + return captionFiles; +} + +async function loadCaptionCues(commit, { caption_file_id }) { + const cues = await CaptionCues.where({ caption_file_id }); + commit('ADD_CAPTIONCUES', { cues }); + return cues; +} + +export async function loadCaptions({ commit, rootGetters }, params) { + const isAIFeatureEnabled = rootGetters['isAIFeatureEnabled']; + if (!isAIFeatureEnabled) return; + + // If a new file is uploaded, the contentnode_id will be string + if (typeof params.contentnode__in === 'string') { + params.contentnode__in = [params.contentnode__in]; + } + const nodeIdsToLoad = []; + for (const nodeId of params.contentnode__in) { + const node = rootGetters['contentNode/getContentNode'](nodeId); + if (node && (node.kind === 'video' || node.kind === 'audio')) { + nodeIdsToLoad.push(nodeId); // already in vuex + } else if (!node) { + nodeIdsToLoad.push(nodeId); // Assume that its audio/video + } + } + + const captionFiles = await loadCaptionFiles(commit, { + contentnode__in: nodeIdsToLoad, + }); + + // If there is no Caption File for this contentnode don't request for the cues + if (captionFiles.length === 0) return; + + captionFiles.forEach((file) => { + // Load all the cues associated with the file_id + const caption_file_id = file.id; + loadCaptionCues(commit, { caption_file_id }); + }); +} + +export async function addCaptionFile({ state, commit }, { id, file_id, language, nodeId }) { + const captionFile = { + id: id, + file_id: file_id, + language: language, + }; + // The file_id and language should be unique together in the vuex state. + // This check avoids duplicating existing caption data already loaded into vuex. + const existingCaptionFile = state.captionFilesMap[nodeId] + ? 
Object.values(state.captionFilesMap[nodeId]).find( + (file) => file.language === captionFile.language && file.file_id === captionFile.file_id + ) + : null; + + if (!existingCaptionFile) { + // new created file will enqueue generate caption celery task + captionFile[GENERATING] = true; + return CaptionFile.add(captionFile).then((id) => { + commit('ADD_CAPTIONFILE', { + id, + captionFile, + nodeId, + }); + }); + } +} diff --git a/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/getters.js b/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/getters.js new file mode 100644 index 0000000000..cf3f282ec8 --- /dev/null +++ b/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/getters.js @@ -0,0 +1,8 @@ +import { GENERATING } from 'shared/data/constants'; + +export function isGeneratingGetter(state) { + return contentnode_id => { + const captionFiles = Object.values(state.captionFilesMap[contentnode_id] || {}); + return captionFiles.some(file => file[GENERATING] === true); + }; +} diff --git a/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/index.js b/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/index.js new file mode 100644 index 0000000000..f8e12ff254 --- /dev/null +++ b/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/index.js @@ -0,0 +1,49 @@ +import * as getters from './getters'; +import * as mutations from './mutations'; +import * as actions from './actions'; +import { TABLE_NAMES, CHANGE_TYPES } from 'shared/data'; + +export default { + namespaced: true, + state: () => ({ + /* List of caption files for a contentnode + * [ + * contentnode_id: { + * pk: { + * id: pk + * file_id: file_id + * language: language + * __generating_captions: boolean + * } + * }, + * ] + */ + captionFilesMap: [], + /* Caption Cues for a contentnode + * [ + * caption_file_id: { + * id: id + * starttime: starttime + * endtime: endtime + * text: text + * } + * ] + */ + captionCuesMap: [], + }), + getters, + mutations, + actions, + listeners: { + [TABLE_NAMES.CAPTION_FILE]: { + [CHANGE_TYPES.CREATED]: 'ADD_CAPTIONFILE', + [CHANGE_TYPES.UPDATED]: 'UPDATE_CAPTIONFILE_FROM_INDEXEDDB', + [CHANGE_TYPES.DELETED]: 'DELETE_CAPTIONFILE', + }, + [TABLE_NAMES.CAPTION_CUES]: { + [CHANGE_TYPES.CREATED]: 'ADD_CAPTIONCUES', + [CHANGE_TYPES.UPDATED]: 'UPDATE_CAPTIONCUE', + [CHANGE_TYPES.DELETED]: 'DELETE_CAPTIONCUE', + }, + }, +}; diff --git a/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/mutations.js b/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/mutations.js new file mode 100644 index 0000000000..2432a083a3 --- /dev/null +++ b/contentcuration/contentcuration/frontend/channelEdit/vuex/caption/mutations.js @@ -0,0 +1,77 @@ +import Vue from 'vue'; +import { GENERATING } from 'shared/data/constants'; + +/* Mutations for Caption File */ +export function ADD_CAPTIONFILE(state, { captionFile, nodeId }) { + if (!nodeId || !captionFile) return; + // Check if there is Map for the current nodeId + if (!state.captionFilesMap[nodeId]) { + Vue.set(state.captionFilesMap, nodeId, {}); + } + + // Spread the new data into the state + Vue.set(state.captionFilesMap[nodeId], captionFile.id, { + ...state.captionFilesMap[nodeId][captionFile.id], + id: captionFile.id, + file_id: captionFile.file_id, + language: captionFile.language, + [GENERATING]: captionFile[GENERATING] || false, + }); +} + +export function ADD_CAPTIONFILES(state, { captionFiles, nodeIds }) { + // TODO: this causes to not update Vuex state 
correctly on the initial loading of the component + let currentIndex = 0; // pointer to nodeIds + if (Array.isArray(captionFiles)) { + captionFiles.forEach((captionFile, index) => { + if (index > 0) { + const prevCaptionFile = captionFiles[index - 1]; + if (captionFile.file_id !== prevCaptionFile.file_id) currentIndex++; + } + + const nodeId = currentIndex < nodeIds.length ? nodeIds[currentIndex] : null; + if (nodeId !== null) { + ADD_CAPTIONFILE(state, { captionFile, nodeId }); + } + }); + } +} + +/* Mutations for Caption Cues */ +export function ADD_CUE(state, { cue }) { + if (!cue) return; + + if (!state.captionCuesMap[cue.caption_file_id]) { + Vue.set(state.captionCuesMap, cue.caption_file_id, {}); + } + + const fileMap = state.captionCuesMap[cue.caption_file_id]; + + Vue.set(state.captionCuesMap, cue.caption_file_id, { + ...fileMap, + [cue.id]: { + id: cue.id, + text: cue.text, + starttime: cue.starttime, + endtime: cue.endtime, + }, + }); +} + +export function ADD_CAPTIONCUES(state, { cues }) { + if (Array.isArray(cues)) { + cues.forEach((cue) => { + ADD_CUE(state, { cue }); + }); + } +} + +export function UPDATE_CAPTIONFILE_FROM_INDEXEDDB(state, { id, ...mods }) { + if (!id) return; + for (const nodeId in state.captionFilesMap) { + if (state.captionFilesMap[nodeId][id]) { + Vue.set(state.captionFilesMap[nodeId][id], GENERATING, mods[GENERATING]); + break; + } + } +} diff --git a/contentcuration/contentcuration/frontend/shared/data/constants.js b/contentcuration/contentcuration/frontend/shared/data/constants.js index 709b251d1f..24c93a98ef 100644 --- a/contentcuration/contentcuration/frontend/shared/data/constants.js +++ b/contentcuration/contentcuration/frontend/shared/data/constants.js @@ -46,6 +46,8 @@ export const TABLE_NAMES = { TASK: 'task', CHANGES_TABLE, BOOKMARK: 'bookmark', + CAPTION_FILE: 'caption_file', + CAPTION_CUES: 'caption_cues', }; /** @@ -68,6 +70,7 @@ export const RELATIVE_TREE_POSITIONS_LOOKUP = invert(RELATIVE_TREE_POSITIONS); export const COPYING_FLAG = '__COPYING'; export const TASK_ID = '__TASK_ID'; export const LAST_FETCHED = '__last_fetch'; +export const GENERATING = '__generating_captions'; // This constant is used for saving/retrieving a current // user object from the session table diff --git a/contentcuration/contentcuration/frontend/shared/data/resources.js b/contentcuration/contentcuration/frontend/shared/data/resources.js index 4a7101f307..32f5d79e53 100644 --- a/contentcuration/contentcuration/frontend/shared/data/resources.js +++ b/contentcuration/contentcuration/frontend/shared/data/resources.js @@ -19,6 +19,7 @@ import { RELATIVE_TREE_POSITIONS, TABLE_NAMES, COPYING_FLAG, + GENERATING, TASK_ID, CURRENT_USER, MAX_REV_KEY, @@ -1017,6 +1018,85 @@ export const Bookmark = new Resource({ getUserId: getUserIdFromStore, }); +export const CaptionFile = new Resource({ + tableName: TABLE_NAMES.CAPTION_FILE, + urlName: 'captions', + idField: 'id', + indexFields: ['file_id', 'language'], + syncable: true, + getChannelId: getChannelFromChannelScope, + + waitForCaptionCueGeneration(id) { + const observable = Dexie.liveQuery(() => { + return this.table + .where('id') + .equals(id) + .filter(f => !f[GENERATING]) + .toArray(); + }); + return new Promise((resolve, reject) => { + const subscription = observable.subscribe({ + next(result) { + if (result.length > 0 && result[0][GENERATING] === false) { + subscription.unsubscribe(); + resolve(false); + } + }, + error() { + subscription.unsubscribe(); + reject(); + }, + }); + }); + }, +}); + +export const CaptionCues = 
new Resource({ + tableName: TABLE_NAMES.CAPTION_CUES, + urlName: 'captioncues', + idField: 'id', + indexFields: ['starttime', 'endtime', 'caption_file_id'], + syncable: true, + getChannelId: getChannelFromChannelScope, + filterCuesByFileId(caption_file_id) { + return this.table + .where('id') + .equals(caption_file_id) + .toArray(); + }, + collectionUrl(caption_file_id) { + return this.getUrlFunction('list')(caption_file_id); + }, + fetchCollection({ caption_file_id }) { + const now = Date.now(); + const generatedUrl = this.collectionUrl(caption_file_id); + const cachedRequest = this._requests[generatedUrl]; + if ( + cachedRequest && + cachedRequest[LAST_FETCHED] && + cachedRequest[LAST_FETCHED] + REFRESH_INTERVAL * 1000 > now && + cachedRequest.promise + ) { + return cachedRequest.promise; + } + const promise = client.get(generatedUrl).then(response => { + let itemData; + if (Array.isArray(response.data)) { + itemData = response.data; + } else { + console.error(`Unexpected response from ${this.urlName}`, response); + itemData = []; + } + return this.setData(itemData); + }); + this._requests[generatedUrl] = { + [LAST_FETCHED]: now, + promise, + }; + return promise; + }, +}); + export const Channel = new Resource({ tableName: TABLE_NAMES.CHANNEL, urlName: 'channel', diff --git a/contentcuration/contentcuration/frontend/shared/leUtils/TranscriptionLanguages.js b/contentcuration/contentcuration/frontend/shared/leUtils/TranscriptionLanguages.js new file mode 100644 index 0000000000..f926adba49 --- /dev/null +++ b/contentcuration/contentcuration/frontend/shared/leUtils/TranscriptionLanguages.js @@ -0,0 +1,21 @@ +/** + * This file generates the list of supported caption languages by + * filtering the full list of languages against the whisperLanguages object. + * To switch to a new model for supported languages, you can update the + * ai_supported_languages.json. 
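 *
 * For example (hypothetical entries): if LanguagesList contains languages whose
 * lang_code values are "en", "hi" and "zu", and ai_supported_languages.json only
 * has the keys "en" and "hi", then supportedCaptionLanguages keeps the "en" and
 * "hi" entries while notSupportedCaptionLanguages keeps "zu".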
+ */ + +import { LanguagesList } from 'shared/leUtils/Languages'; +import aiSupportedLanguages from 'static/ai_supported_languages.json'; + +const whisperLanguages = aiSupportedLanguages + +export const supportedCaptionLanguages = LanguagesList.filter( + language => language.lang_code in whisperLanguages +); + +export const notSupportedCaptionLanguages = LanguagesList.filter( + language => !(language.lang_code in whisperLanguages) +); + +export default supportedCaptionLanguages; diff --git a/contentcuration/contentcuration/migrations/0147_captioncue_captionfile.py b/contentcuration/contentcuration/migrations/0147_captioncue_captionfile.py new file mode 100644 index 0000000000..a694501a74 --- /dev/null +++ b/contentcuration/contentcuration/migrations/0147_captioncue_captionfile.py @@ -0,0 +1,41 @@ +# Generated by Django 3.2.19 on 2023-11-01 19:54 +import uuid + +import django.db.models.deletion +from django.db import migrations +from django.db import models + +import contentcuration.models + + +class Migration(migrations.Migration): + + dependencies = [ + ('contentcuration', '0146_drop_taskresult_fields'), + ] + + operations = [ + migrations.CreateModel( + name='CaptionFile', + fields=[ + ('id', contentcuration.models.UUIDField(default=uuid.uuid4, max_length=32, primary_key=True, serialize=False)), + ('file_id', contentcuration.models.UUIDField(default=uuid.uuid4, max_length=32)), + ('modified', models.DateTimeField(auto_now=True, verbose_name='modified')), + ('language', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='caption_file', to='contentcuration.language')), + ('output_file', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='contentcuration.file')), + ], + options={ + 'unique_together': {('file_id', 'language')}, + }, + ), + migrations.CreateModel( + name='CaptionCue', + fields=[ + ('id', contentcuration.models.UUIDField(default=uuid.uuid4, max_length=32, primary_key=True, serialize=False)), + ('text', models.TextField()), + ('starttime', models.FloatField()), + ('endtime', models.FloatField()), + ('caption_file', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='caption_cue', to='contentcuration.captionfile')), + ], + ), + ] diff --git a/contentcuration/contentcuration/models.py b/contentcuration/contentcuration/models.py index 0d73096bfa..b0a550ca79 100644 --- a/contentcuration/contentcuration/models.py +++ b/contentcuration/contentcuration/models.py @@ -67,6 +67,7 @@ from contentcuration.constants import completion_criteria from contentcuration.constants import user_history from contentcuration.constants.contentnode import kind_activity_map +from contentcuration.constants.transcription_languages import CAPTION_LANGUAGES from contentcuration.db.models.expressions import Array from contentcuration.db.models.functions import ArrayRemove from contentcuration.db.models.functions import Unnest @@ -2058,6 +2059,53 @@ def __str__(self): return self.ietf_name() +class CaptionFile(models.Model): + """ + Represents a caption file record. + + - file_id: The identifier of related Video/Audio File object. + - language: The language of the caption file. + - output_file: The FK to the associated generated VTT File object. 
+ """ + id = UUIDField(primary_key=True, default=uuid.uuid4) + file_id = UUIDField(default=uuid.uuid4, max_length=36) + language = models.ForeignKey(Language, related_name="caption_file", on_delete=models.CASCADE) + modified = models.DateTimeField(auto_now=True, verbose_name="modified") + output_file = models.ForeignKey('File', null=True, blank=True, + on_delete=models.SET_NULL) + + class Meta: + unique_together = ['file_id', 'language'] + + def __str__(self): + return "file_id: {file_id}, language: {language}".format(file_id=self.file_id, language=self.language) + + def save(self, *args, **kwargs): + # Check if the language is supported by speech-to-text AI model. + if self.language and self.language.lang_code not in CAPTION_LANGUAGES: + raise ValueError("The language is currently not supported by speech-to-text model.") + super(CaptionFile, self).save(*args, **kwargs) + + +class CaptionCue(models.Model): + """ + Represents a caption cue in a VTT file. + + - text: The caption text. + - starttime: The start time of the cue in seconds. + - endtime: The end time of the cue in seconds. + - caption_file (Foreign Key): The related caption file. + """ + id = UUIDField(primary_key=True, default=uuid.uuid4) + text = models.TextField(null=False) + starttime = models.FloatField(null=False) + endtime = models.FloatField(null=False) + caption_file = models.ForeignKey(CaptionFile, related_name="caption_cue", on_delete=models.CASCADE) + + def __str__(self): + return "text: {text}, start_time: {starttime}, end_time: {endtime}".format(text=self.text, starttime=self.starttime, endtime=self.endtime) + + ASSESSMENT_ID_INDEX_NAME = "assessment_id_idx" diff --git a/contentcuration/contentcuration/not_production_settings.py b/contentcuration/contentcuration/not_production_settings.py index e98410433d..96cfaf1987 100644 --- a/contentcuration/contentcuration/not_production_settings.py +++ b/contentcuration/contentcuration/not_production_settings.py @@ -4,6 +4,7 @@ ACCOUNT_ACTIVATION_DAYS = 7 EMAIL_BACKEND = 'postmark.django_backend.EmailBackend' +WHISPER_BACKEND = 'contentcuration.utils.transcription.LocalWhisper' POSTMARK_API_KEY = 'POSTMARK_API_TEST' POSTMARK_TEST_MODE = True diff --git a/contentcuration/contentcuration/static/ai_supported_languages.json b/contentcuration/contentcuration/static/ai_supported_languages.json new file mode 100644 index 0000000000..e03c95c674 --- /dev/null +++ b/contentcuration/contentcuration/static/ai_supported_languages.json @@ -0,0 +1,101 @@ +{ + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": 
"mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese" +} \ No newline at end of file diff --git a/contentcuration/contentcuration/tasks.py b/contentcuration/contentcuration/tasks.py index 39f89805ce..07f04da337 100644 --- a/contentcuration/contentcuration/tasks.py +++ b/contentcuration/contentcuration/tasks.py @@ -137,3 +137,47 @@ def sendcustomemails_task(subject, message, query): text = message.format(current_date=time.strftime("%A, %B %d"), current_time=time.strftime("%H:%M %Z"), **recipient.__dict__) text = render_to_string('registration/custom_email.txt', {'message': text}) recipient.email_user(subject, text, settings.DEFAULT_FROM_EMAIL, ) + +@app.task(name="generatecaptioncues_task") +def generatecaptioncues_task(caption_file_id: str, channel_id, user_id) -> None: + """Start generating the Captions Cues for requested the Caption File""" + + from contentcuration.viewsets.caption import CaptionCueSerializer + from contentcuration.viewsets.sync.constants import CAPTION_FILE + from contentcuration.viewsets.sync.constants import CAPTION_CUES + from contentcuration.viewsets.sync.utils import generate_update_event + from contentcuration.viewsets.sync.utils import generate_create_event + from contentcuration.utils.transcription import WhisperAdapter + from contentcuration.utils.transcription import WhisperBackendFactory + + + backend = WhisperBackendFactory().create_backend() + adapter = WhisperAdapter(backend=backend) + + cues = adapter.transcribe(caption_file_id=caption_file_id).get_cues(caption_file_id) + + for cue in cues: + serializer = CaptionCueSerializer(data=cue) + if serializer.is_valid(): + serializer.save() + Change.create_change(generate_create_event( + cue["id"], + CAPTION_CUES, + { + "id": cue["id"], + "text": cue["text"], + "starttime": cue["starttime"], + "endtime": cue["endtime"], + "caption_file_id": cue["caption_file_id"], + }, + channel_id=channel_id, + ), applied=True, created_by_id=user_id) + else: + raise ValueError(f"Error in cue serialization: {serializer.errors}") + + Change.create_change(generate_update_event( + caption_file_id, + CAPTION_FILE, + {"__generating_captions": False}, + channel_id=channel_id, + ), applied=True, created_by_id=user_id) diff --git a/contentcuration/contentcuration/tests/test_exportchannel.py b/contentcuration/contentcuration/tests/test_exportchannel.py index 36e331c713..4fb1698ae4 100644 --- a/contentcuration/contentcuration/tests/test_exportchannel.py +++ b/contentcuration/contentcuration/tests/test_exportchannel.py @@ -429,6 +429,47 @@ def test_publish_no_modify_legacy_exercise_extra_fields(self): 'n': 2 }) + def test_vtt_on_publish(self): + from contentcuration.utils.publish import process_webvtt_file_publishing + # Set up a video node with captions + new_video = 
create_node({'kind_id': 'video', 'title': 'caption creation test'}) + new_video.complete = True + new_video.parent = self.content_channel.main_tree + new_video.save() + + # create a CaptionFile associated with contentnode + video_files = new_video.files.all() + caption_file_data = { + "file_id": video_files[0].id, + "language": cc.Language.objects.get(pk="en"), + } + caption_file = cc.CaptionFile(**caption_file_data) + caption_file.save() + + # create a CaptionCue associated with CaptionFile + cues = cc.CaptionCue(text='a test string', starttime=0, endtime=3, caption_file=caption_file) + cues.save() + + assert caption_file.output_file is None + process_webvtt_file_publishing('create', new_video, caption_file) + assert caption_file.output_file is not None + + expected_webvtt = 'WEBVTT\n\n0:00:00.000 --> 0:00:03.000\na test string\n\n'.encode('utf-8') + webvtt = caption_file.output_file.file_on_disk.read() # output_file is the VTT file + assert webvtt == expected_webvtt + + # Update caption text + caption_cue = caption_file.caption_cue.first() + caption_cue.text = "Updated text" + caption_cue.save() + + # Publish again to update VTT file + process_webvtt_file_publishing('update', new_video, caption_file) + updated_vtt = caption_file.output_file.file_on_disk.read() + + # Assert VTT files are different + assert webvtt != updated_vtt + assert updated_vtt == 'WEBVTT\n\n0:00:00.000 --> 0:00:03.000\nUpdated text\n\n'.encode('utf-8') class EmptyChannelTestCase(StudioTestCase): diff --git a/contentcuration/contentcuration/tests/utils/test_transcription.py b/contentcuration/contentcuration/tests/utils/test_transcription.py new file mode 100644 index 0000000000..1105cf8a9d --- /dev/null +++ b/contentcuration/contentcuration/tests/utils/test_transcription.py @@ -0,0 +1,15 @@ +from django.test import TestCase + +from contentcuration.utils.transcription import LocalWhisper +from contentcuration.utils.transcription import WhisperBackendFactory + + +class TranscriptionTestCase(TestCase): + def test_backend_initialization(self): + backend = LocalWhisper() + self.assertIsNotNone(backend) + self.assertIsInstance(backend.get_instance(), LocalWhisper) + + def test_backend_factory(self): + backend = WhisperBackendFactory().create_backend() + assert isinstance(backend, LocalWhisper) diff --git a/contentcuration/contentcuration/tests/viewsets/test_caption.py b/contentcuration/contentcuration/tests/viewsets/test_caption.py new file mode 100644 index 0000000000..603cc01f76 --- /dev/null +++ b/contentcuration/contentcuration/tests/viewsets/test_caption.py @@ -0,0 +1,299 @@ +from __future__ import absolute_import + +import json +import uuid + +from contentcuration.models import CaptionCue, CaptionFile, Language +from contentcuration.tests import testdata +from contentcuration.tests.base import StudioAPITestCase +from contentcuration.tests.viewsets.base import ( + SyncTestMixin, + generate_create_event, + generate_delete_event, + generate_update_event, +) +from contentcuration.viewsets.caption import CaptionCueSerializer, CaptionFileSerializer +from contentcuration.viewsets.sync.constants import CAPTION_CUES, CAPTION_FILE + + +class SyncTestCase(SyncTestMixin, StudioAPITestCase): + @property + def caption_file_metadata(self): + return { + "file_id": uuid.uuid4().hex, + "language": Language.objects.get(pk="en").pk, + } + + @property + def same_file_different_language_metadata(self): + id = uuid.uuid4().hex + return [ + { + "file_id": id, + "language": Language.objects.get(pk="en"), + }, + { + "file_id": id, + 
"language": Language.objects.get(pk="ru"), + }, + ] + + @property + def caption_cue_metadata(self): + return { + "file": { + "file_id": uuid.uuid4().hex, + "language": Language.objects.get(pk="en").pk, + }, + "cue": { + "text": "This is the beginning!", + "starttime": 0.0, + "endtime": 12.0, + }, + } + + def setUp(self): + super(SyncTestCase, self).setUp() + self.channel = testdata.channel() + self.user = testdata.user() + self.channel.editors.add(self.user) + + # Test for CaptionFile model + def test_create_caption(self): + self.client.force_authenticate(user=self.user) + caption_file = self.caption_file_metadata + + response = self.sync_changes( + [ + generate_create_event( + uuid.uuid4().hex, + CAPTION_FILE, + caption_file, + channel_id=self.channel.id, + ) + ], + ) + self.assertEqual(response.status_code, 200, response.content) + + try: + caption_file_db = CaptionFile.objects.get( + file_id=caption_file["file_id"], + language_id=caption_file["language"], + ) + except CaptionFile.DoesNotExist: + self.fail("caption file was not created") + + # Check the values of the object in the PostgreSQL + self.assertEqual(caption_file_db.file_id, caption_file["file_id"]) + self.assertEqual(caption_file_db.language_id, caption_file["language"]) + + def test_delete_caption_file(self): + self.client.force_authenticate(user=self.user) + caption_file = self.caption_file_metadata + # Explicitly set language to model object to follow Django ORM conventions + caption_file["language"] = Language.objects.get(pk="en") + caption_file_1 = CaptionFile(**caption_file) + pk = caption_file_1.pk + + # Delete the caption file + response = self.sync_changes( + [generate_delete_event(pk, CAPTION_FILE, channel_id=self.channel.id)] + ) + self.assertEqual(response.status_code, 200, response.content) + + with self.assertRaises(CaptionFile.DoesNotExist): + caption_file_db = CaptionFile.objects.get( + file_id=caption_file["file_id"], language_id=caption_file["language"] + ) + + def test_delete_file_with_same_file_id_different_language(self): + self.client.force_authenticate(user=self.user) + obj = self.same_file_different_language_metadata + + caption_file_1 = CaptionFile.objects.create(**obj[0]) + caption_file_2 = CaptionFile.objects.create(**obj[1]) + + response = self.sync_changes( + [ + generate_delete_event( + caption_file_2.pk, + CAPTION_FILE, + channel_id=self.channel.id, + ) + ] + ) + + self.assertEqual(response.status_code, 200, response.content) + + with self.assertRaises(CaptionFile.DoesNotExist): + CaptionFile.objects.get( + file_id=caption_file_2.file_id, language_id=caption_file_2.language + ) + + def test_caption_file_serialization(self): + metadata = self.caption_file_metadata + metadata["language"] = Language.objects.get(pk="en") + caption_file = CaptionFile.objects.create(**metadata) + serializer = CaptionFileSerializer(instance=caption_file) + try: + json.dumps(serializer.data) # Try to serialize the data to JSON + except Exception as e: + self.fail(f"CaptionFile serialization failed. 
Error: {str(e)}") + + def test_caption_cue_serialization(self): + metadata = self.caption_cue_metadata + metadata["file"]["language"] = Language.objects.get(pk="en") + caption_file = CaptionFile.objects.create(**metadata["file"]) + caption_cue = metadata["cue"] + caption_cue.update( + { + "caption_file": caption_file, + } + ) + CaptionCue.objects.create(**caption_cue) + CaptionCue.objects.create( + text="How are you?", starttime=2.0, endtime=3.0, caption_file=caption_file + ) + serializer = CaptionFileSerializer(instance=caption_file) + try: + json.dumps(serializer.data) # Try to serialize the data to JSON + except Exception as e: + self.fail(f"CaptionFile serialization failed. Error: {str(e)}") + + def test_create_caption_cue(self): + self.client.force_authenticate(user=self.user) + metadata = self.caption_cue_metadata + + metadata["file"]["language"] = Language.objects.get(pk="en") + + caption_file_1 = CaptionFile.objects.create(**metadata["file"]) + caption_cue = metadata["cue"] + caption_cue["caption_file_id"] = caption_file_1.pk + + response = self.sync_changes( + [ + generate_create_event( + uuid.uuid4(), + CAPTION_CUES, + caption_cue, + channel_id=self.channel.id, + ) + ], + ) + self.assertEqual(response.status_code, 200, response.content) + + try: + CaptionCue.objects.get( + text=caption_cue["text"], + starttime=caption_cue["starttime"], + endtime=caption_cue["endtime"], + ) + except CaptionCue.DoesNotExist: + self.fail("Caption cue not found!") + + def test_delete_caption_cue(self): + self.client.force_authenticate(user=self.user) + metadata = self.caption_cue_metadata + metadata["file"]["language"] = Language.objects.get(pk="en") + caption_file_1 = CaptionFile.objects.create(**metadata["file"]) + caption_cue = metadata["cue"] + caption_cue.update({"caption_file": caption_file_1}) + CaptionCue.objects.create(**caption_cue) + try: + caption_cue_db = CaptionCue.objects.get( + text=caption_cue["text"], + starttime=caption_cue["starttime"], + endtime=caption_cue["endtime"], + ) + except CaptionCue.DoesNotExist: + self.fail("Caption cue not found!") + + # Delete the caption Cue that we just created + response = self.sync_changes( + [ + generate_delete_event( + caption_cue_db.pk, CAPTION_CUES, channel_id=self.channel.id + ) + ] + ) + self.assertEqual(response.status_code, 200, response.content) + + caption_cue_db_exists = CaptionCue.objects.filter( + text=caption_cue["text"], + starttime=caption_cue["starttime"], + endtime=caption_cue["endtime"], + ).exists() + if caption_cue_db_exists: + self.fail("Caption Cue still exists!") + + def test_update_caption_cue(self): + self.client.force_authenticate(user=self.user) + metadata = self.caption_cue_metadata + metadata["file"]["language"] = Language.objects.get(pk="en") + caption_file_1 = CaptionFile.objects.create(**metadata["file"]) + + caption_cue = metadata["cue"] + caption_cue.update({"caption_file": caption_file_1}) + + caption_cue_1 = CaptionCue.objects.create(**caption_cue) + try: + CaptionCue.objects.get( + text=caption_cue["text"], + starttime=caption_cue["starttime"], + endtime=caption_cue["endtime"], + ) + except CaptionCue.DoesNotExist: + self.fail("Caption cue not found!") + + # Update the cue + pk = caption_cue_1.pk + new_text = "Yo" + new_starttime = 10 + new_endtime = 20 + + response = self.sync_changes( + [ + generate_update_event( + pk, + CAPTION_CUES, + { + "text": new_text, + "starttime": new_starttime, + "endtime": new_endtime, + "caption_file_id": caption_file_1.pk, + }, + channel_id=self.channel.id, + ) + ] + ) + 
self.assertEqual(response.status_code, 200, response.content) + self.assertEqual( + CaptionCue.objects.get(id=pk).text, + new_text, + ) + self.assertEqual( + CaptionCue.objects.get(id=pk).starttime, + new_starttime, + ) + self.assertEqual( + CaptionCue.objects.get(id=pk).endtime, + new_endtime, + ) + + def test_invalid_caption_cue_data_serialization(self): + metadata = self.caption_cue_metadata + metadata["file"]["language"] = Language.objects.get(pk="en") + caption_file = CaptionFile.objects.create(**metadata["file"]) + caption_cue = metadata["cue"] + caption_cue.update( + { + "starttime": float(20), + "endtime": float(10), + "caption_file": caption_file, + } + ) + serializer = CaptionCueSerializer(data=caption_cue) + assert not serializer.is_valid() + errors = serializer.errors + assert "non_field_errors" in errors + assert str(errors["non_field_errors"][0]) == "The cue must finish after start." diff --git a/contentcuration/contentcuration/urls.py b/contentcuration/contentcuration/urls.py index bb03f3876e..692c8cc938 100644 --- a/contentcuration/contentcuration/urls.py +++ b/contentcuration/contentcuration/urls.py @@ -33,6 +33,7 @@ from contentcuration.views import pwa from contentcuration.viewsets.assessmentitem import AssessmentItemViewSet from contentcuration.viewsets.bookmark import BookmarkViewSet +from contentcuration.viewsets.caption import CaptionViewSet, CaptionCueViewSet from contentcuration.viewsets.channel import AdminChannelViewSet from contentcuration.viewsets.channel import CatalogViewSet from contentcuration.viewsets.channel import ChannelViewSet @@ -55,6 +56,8 @@ def get_redirect_url(self, *args, **kwargs): router = routers.DefaultRouter(trailing_slash=False) router.register(r'bookmark', BookmarkViewSet, basename="bookmark") +router.register(r'captions', CaptionViewSet, basename="captions") +router.register(r'captions/(?P[^/]*)/cues', CaptionCueViewSet, basename="captioncues") router.register(r'channel', ChannelViewSet) router.register(r'channelset', ChannelSetViewSet) router.register(r'catalog', CatalogViewSet, basename='catalog') diff --git a/contentcuration/contentcuration/utils/publish.py b/contentcuration/contentcuration/utils/publish.py index 5bfddf4b7b..8c95394a1f 100644 --- a/contentcuration/contentcuration/utils/publish.py +++ b/contentcuration/contentcuration/utils/publish.py @@ -12,7 +12,9 @@ import zipfile from builtins import str from copy import deepcopy +from datetime import timedelta from itertools import chain +from typing import Literal from django.conf import settings from django.contrib.sites.models import Site @@ -26,6 +28,7 @@ from django.db.models import Q from django.db.models import Subquery from django.db.models import Sum +from django.db.models.query import QuerySet from django.db.utils import IntegrityError from django.template.loader import render_to_string from django.utils import timezone @@ -167,6 +170,79 @@ def create_kolibri_license_object(ccnode): ) +def process_webvtt_file_publishing( + action: Literal["create", "update"], + ccnode: ccmodels.ContentNode, + caption_file: ccmodels.CaptionFile, + user_id: int = None + ) -> None: + """Create or Update a WebVTT file and associate it with a CaptionFile. + + :param action: 'create' to create a new WebVTT file and 'update' to update an existing WebVTT file + :param ccnode: The ContentNode associated with the WebVTT file. + :param caption_file: A CaptionFile to associate with the WebVTT file. + :param user_id: The ID of the user creating the WebVTT file (optional). 
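
    For a caption file whose only cue is the text "a test string" with starttime 0.0 and
    endtime 3.0, generate_webvtt_file() below produces the UTF-8 encoded content:

        WEBVTT

        0:00:00.000 --> 0:00:03.000
        a test string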
+ """ + logging.debug(f"{action[:-1]}ing WebVTT for Node {ccnode.title}") + vtt_content = generate_webvtt_file(caption_cues=caption_file.caption_cue.all()) + filename = "{name}_{lang}.{ext}".format(name=ccnode.title, lang=caption_file.language, ext=file_formats.VTT) + temppath = None + try: + with tempfile.NamedTemporaryFile(suffix="vtt", delete=False) as tempf: + temppath = tempf.name + tempf.write(vtt_content) + file_size = tempf.tell() + tempf.flush() + + new_vtt_file = ccmodels.File.objects.create( + file_on_disk=File(open(temppath, 'rb'), name=filename), + contentnode=ccnode, + file_format_id=file_formats.VTT, + preset_id=format_presets.VIDEO_SUBTITLE, + original_filename=filename, + file_size=file_size, + uploaded_by_id=user_id, + language=caption_file.language, + ) + logging.debug("Created VTT for {0} with checksum {1}".format(ccnode.title, new_vtt_file.checksum)) + + if action == 'update' and caption_file.output_file: + caption_file.output_file.contentnode = None + caption_file.output_file.save(update_fields=['contentnode']) + + caption_file.output_file = new_vtt_file + # specifying output_field to be updated because by default the addition of FK updates + # the modified of CaptionFile obj results in always vtt_file.modified > caption_file.modified + caption_file.save(update_fields=['output_file']) + except Exception as e: + logging.error(f"Error creating VTT file for {ccnode.title}: {str(e)}") + finally: + temppath and os.unlink(temppath) + + +def generate_webvtt_file(caption_cues: QuerySet[ccmodels.CaptionCue]) -> str: + """ Generate the content of a WebVTT file based on CaptionCue's. + + :param: caption_cues: QuerySet of CaptionCues to include in the WebVTT. + :returns: The WebVTT content as a UTF-8 encoded string. + """ + webvtt_content = "WEBVTT\n\n" + for cue in caption_cues.order_by('starttime'): + st = float_to_timedelta(seconds=cue.starttime) + et = float_to_timedelta(seconds=cue.endtime) + webvtt_content += f"{st} --> {et}\n" + webvtt_content += f"{cue.text}\n\n" + return webvtt_content.encode('utf-8') + + +def float_to_timedelta(seconds: float) -> str: + s = int(seconds) + ms = int((seconds-s)*1000) + if ms == 0: + return f"{timedelta(seconds=s)}.000" + return f"{timedelta(seconds=s)}.{ms}" + + def increment_channel_version(channel): channel.version += 1 channel.save() @@ -264,6 +340,16 @@ def recurse_nodes(self, node, inherited_fields): # noqa C901 create_perseus_exercise(node, kolibrinode, exercise_data, user_id=self.user_id) elif node.kind_id == content_kinds.SLIDESHOW: create_slideshow_manifest(node, user_id=self.user_id) + elif node.kind_id in [content_kinds.AUDIO, content_kinds.VIDEO]: + if node.changed: + file_ids = node.files.all().values_list('id') + caption_files = ccmodels.CaptionFile.objects.filter(file_id__in=file_ids) + for cf in caption_files: + vtt_file = cf.output_file + if vtt_file and vtt_file.modified < cf.modified: + process_webvtt_file_publishing('update', node, cf, self.user_id) + elif vtt_file is None: + process_webvtt_file_publishing('create', node, cf, self.user_id) elif node.kind_id == content_kinds.TOPIC: for child in node.children.all(): self.recurse_nodes(child, metadata) @@ -473,7 +559,6 @@ def create_associated_file_objects(kolibrinode, ccnode): local_file=kolibrilocalfilemodel, ) - def create_perseus_exercise(ccnode, kolibrinode, exercise_data, user_id=None): logging.debug("Creating Perseus Exercise for Node {}".format(ccnode.title)) filename = "{0}.{ext}".format(ccnode.title, ext=file_formats.PERSEUS) diff --git 
a/contentcuration/contentcuration/utils/transcription.py b/contentcuration/contentcuration/utils/transcription.py index 105b1b0608..9555a708f0 100644 --- a/contentcuration/contentcuration/utils/transcription.py +++ b/contentcuration/contentcuration/utils/transcription.py @@ -1,44 +1,110 @@ +import uuid +from typing import Optional +from importlib import import_module + +import requests +from automation.settings import CHUNK_LENGTH +from automation.settings import DEVICE +from automation.settings import DEV_TRANSCRIPTION_MODEL +from automation.settings import MAX_TOKEN_LENGTH from automation.utils.appnexus.base import Adapter from automation.utils.appnexus.base import Backend from automation.utils.appnexus.base import BackendFactory from automation.utils.appnexus.base import BackendRequest from automation.utils.appnexus.base import BackendResponse +from contentcuration.constants.transcription_languages import WHISPER_LANGUAGES as LANGS +from contentcuration.models import CaptionFile +from contentcuration.models import File +from contentcuration.not_production_settings import WHISPER_BACKEND class WhisperRequest(BackendRequest): - def __init__(self) -> None: - super().__init__() + """ Create a WhisperRequest object to make request to 'WhisperBackend' + :param: url (str): The URL of the resource to retrieve. + :param: binary (optional): Provide the binary data directly, default is 'None'. + """ + def __init__(self, url: str, language: str, binary: Optional[bytes] = None) -> None: + self.url, self.language = url, language + self.binary = binary if binary else self._get_binary() + + def _get_binary(self) -> bytes: + # if not url.startswith('http'): raise TypeError(f'url:{url} must be start with http.') + res = requests.get(self.url) + return res.content if res.status_code == 200 else None + + def get_binary_data(self) -> bytes: return self.binary + def get_file_url(self) -> str: return self.url + def get_language(self) -> str: return self.language + class WhisperResponse(BackendResponse): - def __init__(self) -> None: - super().__init__() + def __init__(self, response) -> None: + self.result = response + def get_cues(self, caption_file_id: str) -> list: + cues = [] + for transcription in self.result['chunks']: + start_time, end_time = transcription["timestamp"] + text = transcription["text"] + cue = { + "id": uuid.uuid4().hex, + "text": text, + "starttime": start_time, + "endtime": end_time, + "caption_file_id": caption_file_id, + } + cues.append(cue) + return cues -class Whisper(Backend): - def connect(self) -> None: - raise NotImplementedError("The 'connect' method is not implemented for the 'Whisper' backend.") +class Whisper(Backend): def make_request(self, request: WhisperRequest) -> WhisperResponse: # Implement production backend here. pass - @classmethod - def _create_instance(cls) -> 'Whisper': - return cls() + def connect(self) -> None: + raise NotImplementedError("The 'connect' method is not implemented for the 'Whisper' backend.") + class LocalWhisper(Backend): + def __init__(self) -> None: + self.pipe = None + def make_request(self, request: WhisperRequest) -> WhisperResponse: - # Implement your local backend here. 
- pass + self.connect() + media_url = request.get_file_url() + result = self.pipe(media_url, max_new_tokens=MAX_TOKEN_LENGTH) + return WhisperResponse(response=result) + + def connect(self) -> None: + if self.pipe is None: + from transformers import pipeline + self.pipe = pipeline( + model=DEV_TRANSCRIPTION_MODEL, + chunk_length_s=CHUNK_LENGTH, + device=DEVICE, + return_timestamps=True, + ) class WhisperBackendFactory(BackendFactory): def create_backend(self) -> Backend: - # Return backend based on some setting. - return super().create_backend() + mod, backend = WHISPER_BACKEND.rsplit('.', 1) # module, backend + try: + mod = import_module(mod) + backend = getattr(mod, backend) # get `backend` from `module` + return backend() + except ModuleNotFoundError: + raise ImportError(f'Failed to import `{mod}`') + except AttributeError: + raise ImportError(f"Failed to find attribute `{backend}` in module `{backend}`") class WhisperAdapter(Adapter): def transcribe(self, caption_file_id: str) -> WhisperResponse: - request = WhisperRequest() + f = CaptionFile.objects.get(pk=caption_file_id) + file_id, language = f.file_id, LANGS[f.language.lang_code] + media_file = File.objects.get(pk=file_id).file_on_disk.url + + request = WhisperRequest(url=media_file, language=language) return self.backend.make_request(request) diff --git a/contentcuration/contentcuration/viewsets/caption.py b/contentcuration/contentcuration/viewsets/caption.py new file mode 100644 index 0000000000..ef895906ad --- /dev/null +++ b/contentcuration/contentcuration/viewsets/caption.py @@ -0,0 +1,162 @@ +import logging + +from le_utils.constants.format_presets import AUDIO, VIDEO_HIGH_RES, VIDEO_LOW_RES +from rest_framework import serializers +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response + +from contentcuration.models import CaptionCue, CaptionFile, Change, File, ContentNode +from contentcuration.tasks import generatecaptioncues_task +from contentcuration.viewsets.base import BulkModelSerializer, ValuesViewset +from contentcuration.viewsets.sync.constants import CAPTION_FILE, CONTENTNODE +from contentcuration.viewsets.sync.utils import generate_update_event + + +class CaptionCueSerializer(BulkModelSerializer): + class Meta: + model = CaptionCue + fields = ["text", "starttime", "endtime", "caption_file_id"] + + def validate(self, attrs): + """Check that the cue's starttime is before the endtime.""" + attrs = super().validate(attrs) + if attrs["starttime"] > attrs["endtime"]: + raise serializers.ValidationError("The cue must finish after start.") + return attrs + + def to_internal_value(self, data): + """ + Copies the caption_file_id from the request data + to the internal representation before validation. + Without this, the caption_file_id would be lost + if validation fails, leading to errors. + """ + caption_file_id = data.get("caption_file_id") + value = super().to_internal_value(data) + + if "caption_file_id" not in value: + value["caption_file_id"] = caption_file_id + return value + + + +class CaptionFileSerializer(BulkModelSerializer): + caption_cue = CaptionCueSerializer(many=True, required=False) + + class Meta: + model = CaptionFile + fields = ["id", "file_id", "language", "caption_cue"] + + +class CaptionViewSet(ValuesViewset): + # Handles operations for the CaptionFile model. 
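    # Routed as r'captions' on the API router in urls.py. A sketch of the list filters
    # handled by get_queryset() below (URL prefix omitted, ids hypothetical):
    #   GET .../captions?contentnode__in=<node_id_1>,<node_id_2>
    #   GET .../captions?file_id=<file_id>&language=en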
+ queryset = CaptionFile.objects.prefetch_related("caption_cue") + permission_classes = [IsAuthenticated] + serializer_class = CaptionFileSerializer + values = ("id", "file_id", "language") + + field_map = { + "file_id": "file_id", + "language": "language", + } + + def get_queryset(self): + queryset = super().get_queryset() + + contentnode_ids = self.request.GET.get("contentnode__in") + file_id = self.request.GET.get("file_id") + language = self.request.GET.get("language") + + if contentnode_ids: + allowed_contentnodes = set( + ContentNode.filter_edit_queryset( + ContentNode.objects.all(), self.request.user + ) + .filter(id__in=contentnode_ids.split(",")) + .values_list("id", flat=True) + ) + file_ids = File.objects.filter( + preset_id__in=[AUDIO, VIDEO_HIGH_RES, VIDEO_LOW_RES], + contentnode_id__in=allowed_contentnodes, + ).values_list("pk", flat=True) + queryset = queryset.filter(file_id__in=file_ids) + + if file_id: + allowed_file_id = set( + File.filter_edit_queryset( + File.objects.get(pk=file_id), self.request.user + ) + .values_list("id", flat=True) + ) + queryset = queryset.filter(file_id=allowed_file_id) + if language: + queryset = queryset.filter(language=language) + + return queryset + + def perform_create(self, serializer, change=None): + instance = serializer.save() + Change.create_change( + generate_update_event( + instance.pk, + CAPTION_FILE, + { + "__generating_captions": True, + }, + channel_id=change["channel_id"], + ), + applied=True, + created_by_id=self.request.user.id, + ) + + # Set the contentnode's changed to True + try: + file_id = instance.file_id + cnn_id = File.objects.get(pk=file_id).contentnode_id + if cnn_id: + cnn = ContentNode.objects.get(pk=cnn_id) + cnn.changed = True + cnn.save() + + Change.create_change(generate_update_event( + cnn_id, + CONTENTNODE, + {'changed': True}, + channel_id=change["channel_id"], + ), applied=True, created_by_id=self.request.user.id) + except Exception as e: + print(e) + + # enqueue task of generating captions for the saved CaptionFile instance + try: + # Also sets the generating flag to false <<< Generating Completed + generatecaptioncues_task.enqueue( + self.request.user, + caption_file_id=instance.pk, + channel_id=change["channel_id"], + user_id=self.request.user.id, + ) + + except Exception as e: + logging.error(f"Failed to queue celery task.\nWith the error: {e}") + + +class CaptionCueViewSet(ValuesViewset): + # Handles operations for the CaptionCue model. 
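    # Served through the nested 'captions/<caption_file_id>/cues' route registered in
    # urls.py; the list() override below answers e.g. (URL prefix omitted, id hypothetical):
    #   GET .../captions/<caption_file_id>/cues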
+ queryset = CaptionCue.objects.all() + permission_classes = [IsAuthenticated] + serializer_class = CaptionCueSerializer + values = ("id", "text", "starttime", "endtime", "caption_file_id") + + field_map = { + "id": "id", + "text": "text", + "starttime": "starttime", + "endtime": "endtime", + "caption_file": "caption_file_id", + } + + def list(self, request, *args, **kwargs): + caption_file_id = kwargs["caption_file_id"] + queryset = CaptionCue.objects.filter(caption_file_id=caption_file_id) + return Response(self.serialize(queryset)) diff --git a/contentcuration/contentcuration/viewsets/sync/base.py b/contentcuration/contentcuration/viewsets/sync/base.py index f11a8f4729..7606853bcc 100644 --- a/contentcuration/contentcuration/viewsets/sync/base.py +++ b/contentcuration/contentcuration/viewsets/sync/base.py @@ -5,6 +5,7 @@ from contentcuration.decorators import delay_user_storage_calculation from contentcuration.viewsets.assessmentitem import AssessmentItemViewSet from contentcuration.viewsets.bookmark import BookmarkViewSet +from contentcuration.viewsets.caption import CaptionViewSet, CaptionCueViewSet from contentcuration.viewsets.channel import ChannelViewSet from contentcuration.viewsets.channelset import ChannelSetViewSet from contentcuration.viewsets.clipboard import ClipboardViewSet @@ -14,6 +15,8 @@ from contentcuration.viewsets.invitation import InvitationViewSet from contentcuration.viewsets.sync.constants import ASSESSMENTITEM from contentcuration.viewsets.sync.constants import BOOKMARK +from contentcuration.viewsets.sync.constants import CAPTION_CUES +from contentcuration.viewsets.sync.constants import CAPTION_FILE from contentcuration.viewsets.sync.constants import CHANNEL from contentcuration.viewsets.sync.constants import CHANNELSET from contentcuration.viewsets.sync.constants import CLIPBOARD @@ -73,6 +76,8 @@ def __init__(self, change_type, viewset_class): (EDITOR_M2M, ChannelUserViewSet), (VIEWER_M2M, ChannelUserViewSet), (SAVEDSEARCH, SavedSearchViewSet), + (CAPTION_FILE, CaptionViewSet), + (CAPTION_CUES, CaptionCueViewSet), ] ) diff --git a/contentcuration/contentcuration/viewsets/sync/constants.py b/contentcuration/contentcuration/viewsets/sync/constants.py index 84c2b5aad7..6ad7305c6c 100644 --- a/contentcuration/contentcuration/viewsets/sync/constants.py +++ b/contentcuration/contentcuration/viewsets/sync/constants.py @@ -22,6 +22,8 @@ # Client-side table constants BOOKMARK = "bookmark" +CAPTION_FILE = "caption_file" +CAPTION_CUES = "caption_cues" CHANNEL = "channel" CONTENTNODE = "contentnode" CONTENTNODE_PREREQUISITE = "contentnode_prerequisite" @@ -39,6 +41,8 @@ ALL_TABLES = set( [ BOOKMARK, + CAPTION_FILE, + CAPTION_CUES, CHANNEL, CLIPBOARD, CONTENTNODE, diff --git a/requirements-dev.in b/requirements-dev.in index 59157d10d7..149e562dcd 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -26,3 +26,5 @@ git+https://github.com/someshchaturvedi/customizable-django-profiler.git#customi tabulate==0.8.2 fonttools minio==7.1.1 +torch==2.0.1 +transformers==4.29.2 \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index c2e52887be..9d28f20408 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -35,6 +35,10 @@ click==8.1.3 # -c requirements.txt # flask # pip-tools +cmake==3.27.7 + # via + # -c requirements.txt + # triton colorama==0.4.4 # via pytest-watch configargparse==1.5.3 @@ -77,8 +81,14 @@ drf-yasg==1.20.0 # via -r requirements-dev.in faker==0.9.1 # via mixer -filelock==3.4.1 - # via virtualenv 
+filelock==3.12.4 + # via + # -c requirements.txt + # huggingface-hub + # torch + # transformers + # triton + # virtualenv flake8==3.4.1 # via -r requirements-dev.in flask==2.0.3 @@ -92,6 +102,10 @@ flask-cors==3.0.10 # via locust fonttools==4.40.0 # via -r requirements-dev.in +fsspec==2023.9.2 + # via + # -c requirements.txt + # huggingface-hub gevent==23.9.1 # via # geventhttpclient @@ -100,6 +114,10 @@ geventhttpclient==2.0.9 # via locust greenlet==2.0.2 # via gevent +huggingface-hub==0.17.3 + # via + # -c requirements.txt + # transformers identify==2.4.4 # via pre-commit idna==2.10 @@ -118,14 +136,21 @@ itsdangerous==2.0.1 # via flask itypes==1.2.0 # via coreapi -jinja2==3.0.3 +jinja2==3.1.2 # via + # -c requirements.txt # coreschema # flask + # torch +lit==17.0.3 + # via + # -c requirements.txt + # triton locust==2.15.1 # via -r requirements-dev.in -markupsafe==2.1.2 +markupsafe==2.1.3 # via + # -c requirements.txt # jinja2 # werkzeug mccabe==0.6.1 @@ -138,18 +163,78 @@ mock==4.0.3 # via # -r requirements-dev.in # django-concurrent-test-helper +mpmath==1.3.0 + # via + # -c requirements.txt + # sympy msgpack==1.0.4 # via locust +networkx==3.1 + # via + # -c requirements.txt + # torch nodeenv==1.6.0 # via # -r requirements-dev.in # pre-commit +numpy==1.26.0 + # via + # -c requirements.txt + # transformers +nvidia-cublas-cu11==11.10.3.66 + # via + # -c requirements.txt + # nvidia-cudnn-cu11 + # nvidia-cusolver-cu11 + # torch +nvidia-cuda-cupti-cu11==11.7.101 + # via + # -c requirements.txt + # torch +nvidia-cuda-nvrtc-cu11==11.7.99 + # via + # -c requirements.txt + # torch +nvidia-cuda-runtime-cu11==11.7.99 + # via + # -c requirements.txt + # torch +nvidia-cudnn-cu11==8.5.0.96 + # via + # -c requirements.txt + # torch +nvidia-cufft-cu11==10.9.0.58 + # via + # -c requirements.txt + # torch +nvidia-curand-cu11==10.2.10.91 + # via + # -c requirements.txt + # torch +nvidia-cusolver-cu11==11.4.0.1 + # via + # -c requirements.txt + # torch +nvidia-cusparse-cu11==11.7.4.91 + # via + # -c requirements.txt + # torch +nvidia-nccl-cu11==2.14.3 + # via + # -c requirements.txt + # torch +nvidia-nvtx-cu11==11.7.91 + # via + # -c requirements.txt + # torch packaging==20.9 # via # -c requirements.txt # build # drf-yasg + # huggingface-hub # pytest + # transformers pep517==0.12.0 # via build pip-tools==6.8.0 @@ -209,17 +294,26 @@ pytz==2022.1 # via # -c requirements.txt # django -pyyaml==6.0 +pyyaml==6.0.1 # via + # -c requirements.txt # aspy-yaml + # huggingface-hub # pre-commit + # transformers pyzmq==23.1.0 # via locust +regex==2023.10.3 + # via + # -c requirements.txt + # transformers requests==2.25.1 # via # -c requirements.txt # coreapi + # huggingface-hub # locust + # transformers roundrobin==0.0.2 # via locust ruamel-yaml==0.17.21 @@ -241,12 +335,20 @@ sqlparse==0.4.1 # -c requirements.txt # django # django-debug-toolbar +sympy==1.12 + # via + # -c requirements.txt + # torch tabulate==0.8.2 # via -r requirements-dev.in tblib==1.7.0 # via django-concurrent-test-helper text-unidecode==1.2 # via faker +tokenizers==0.13.3 + # via + # -c requirements.txt + # transformers toml==0.10.2 # via # pre-commit @@ -256,8 +358,30 @@ tomli==1.2.3 # build # coverage # pep517 -typing-extensions==4.5.0 - # via locust +torch==2.0.1 + # via + # -c requirements.txt + # -r requirements-dev.in + # triton +tqdm==4.66.1 + # via + # -c requirements.txt + # huggingface-hub + # transformers +transformers==4.29.2 + # via + # -c requirements.txt + # -r requirements-dev.in +triton==2.0.0 + # via + # -c requirements.txt + # 
torch +typing-extensions==4.8.0 + # via + # -c requirements.txt + # huggingface-hub + # locust + # torch uritemplate==3.0.1 # via # coreapi @@ -275,8 +399,16 @@ werkzeug==2.2.3 # via # flask # locust -wheel==0.38.1 - # via pip-tools +wheel==0.41.2 + # via + # -c requirements.txt + # nvidia-cublas-cu11 + # nvidia-cuda-cupti-cu11 + # nvidia-cuda-runtime-cu11 + # nvidia-curand-cu11 + # nvidia-cusparse-cu11 + # nvidia-nvtx-cu11 + # pip-tools whitenoise==5.2.0 # via -r requirements-dev.in zipp==3.4.1 diff --git a/requirements.txt b/requirements.txt index 65081fa4a3..14669ec1cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: # # pip-compile requirements.in #
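A note on how the pieces above are expected to fit together: CaptionViewSet.perform_create enqueues generatecaptioncues_task, which runs the Whisper adapter and turns the transcription result into CaptionCue rows. The sketch below is illustrative only and is not part of this changeset; it assumes the Hugging Face pipeline's return_timestamps=True output shape (a dict whose "chunks" entries carry a (start, end) timestamp tuple and a text string) and that CaptionCue exposes text, starttime, endtime, and caption_file fields matching the serializer above. The helper name is hypothetical.

    # Illustrative sketch only; not the task implementation from this diff.
    from contentcuration.models import CaptionCue, CaptionFile


    def save_cues_from_whisper(caption_file_id, whisper_result):
        """Persist Hugging Face pipeline chunks as CaptionCue rows."""
        caption_file = CaptionFile.objects.get(pk=caption_file_id)
        cues = [
            CaptionCue(
                caption_file=caption_file,
                text=chunk["text"].strip(),
                starttime=chunk["timestamp"][0],
                endtime=chunk["timestamp"][1],
            )
            for chunk in whisper_result.get("chunks", [])
        ]
        CaptionCue.objects.bulk_create(cues)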