diff --git a/beat/web/databases/api.py b/beat/web/databases/api.py index 7b817d837a33c54ce8d57c8e430e54e7b786b85d..ec93a01ee79c03828a45314d0c447d629afca129 100755 --- a/beat/web/databases/api.py +++ b/beat/web/databases/api.py @@ -27,13 +27,12 @@ import os import json - -from django.http import HttpResponse +import logging from rest_framework.response import Response from rest_framework import permissions +from rest_framework import exceptions as drf_exceptions from rest_framework import views -from rest_framework import status from .models import Database from .models import DatabaseSetTemplate @@ -44,10 +43,8 @@ from ..common import is_true from ..common.mixins import IsAdminOrReadOnlyMixin from ..common.api import ListCreateBaseView from ..common.utils import ensure_html -from ..dataformats.serializers import ReferencedDataFormatSerializer -import logging -import traceback +from ..dataformats.serializers import ReferencedDataFormatSerializer logger = logging.getLogger(__name__) @@ -55,7 +52,7 @@ logger = logging.getLogger(__name__) # ---------------------------------------------------------- -def database_to_json(database, request_user, fields_to_return, last_version=None): +def database_to_json(database, request_user, fields_to_return): # Prepare the response result = {} @@ -67,7 +64,14 @@ def database_to_json(database, request_user, fields_to_return, last_version=None result["version"] = database.version if "last_version" in fields_to_return: - result["last_version"] = last_version + latest = ( + Database.objects.for_user(request_user, True) + .filter(name=database.name) + .order_by("-version")[:1] + .first() + ) + + result["last_version"] = database.version == latest.version if "short_description" in fields_to_return: result["short_description"] = database.short_description @@ -204,35 +208,25 @@ class RetrieveDatabaseView(views.APIView): Returns the given database details """ + model = Database permission_classes = [permissions.AllowAny] - def get(self, request, database_name, version=None): - # Retrieve the database + def get_object(self): + version = self.kwargs["version"] + database_name = self.kwargs["database_name"] + user = self.request.user try: - if version is not None: - version = int(version) - - databases = ( - Database.objects.for_user(request.user, True) - .filter(name__iexact=database_name, version__gte=version) - .order_by("version") - ) - - database = databases[0] - - if database.version != version: - return HttpResponse(status=status.HTTP_404_NOT_FOUND) + obj = self.model.objects.for_user(user, True).get( + name__iexact=database_name, version=version + ) + except self.model.DoesNotExist: + raise drf_exceptions.NotFound() + return obj - last_version = len(databases) == 1 - else: - database = ( - Database.objects.for_user(request.user, True) - .filter(name__iexact=database_name) - .order_by("-version")[0] - ) - last_version = True - except Exception: - return HttpResponse(status=status.HTTP_404_NOT_FOUND) + def get(self, request, database_name, version): + # Retrieve the database + database = self.get_object() + self.check_object_permissions(request, database) # Process the query string if "fields" in request.GET: @@ -256,68 +250,50 @@ class RetrieveDatabaseView(views.APIView): "code", ] - try: - # Prepare the response - result = database_to_json( - database, request.user, fields_to_return, last_version=last_version - ) + # Prepare the response + result = database_to_json(database, request.user, fields_to_return) - # Retrieve the code - if "declaration" in fields_to_return: - try: - declaration = database.declaration - except Exception: - logger.error(traceback.format_exc()) - return HttpResponse(status=500) - cleaned_declaration = clean_paths(declaration) - result["declaration"] = json.dumps(cleaned_declaration) - - # Retrieve the source code - if "code" in fields_to_return: - try: - result["code"] = database.source_code - except Exception: - logger.error(traceback.format_exc()) - return HttpResponse(status=500) - - # Retrieve the description in HTML format - if "html_description" in fields_to_return: - description = database.description - if len(description) > 0: - result["html_description"] = ensure_html(description) - else: - result["html_description"] = "" - - # Retrieve the referenced data formats - if "referenced_dataformats" in fields_to_return: - dataformats = database.all_referenced_dataformats() - - referenced_dataformats = [] - for dataformat in dataformats: - (has_access, accessibility) = dataformat.accessibility_for( - request.user - ) - if has_access: - referenced_dataformats.append(dataformat) - serializer = ReferencedDataFormatSerializer(referenced_dataformats) - result["referenced_dataformats"] = serializer.data - - # Retrieve the needed data formats - if "needed_dataformats" in fields_to_return: - dataformats = database.all_needed_dataformats() - - needed_dataformats = [] - for dataformat in dataformats: - (has_access, accessibility) = dataformat.accessibility_for( - request.user - ) - if has_access: - needed_dataformats.append(dataformat) - serializer = ReferencedDataFormatSerializer(needed_dataformats) - result["needed_dataformats"] = serializer.data - - # Return the result - return Response(result) - except Exception: - logger.error(traceback.format_exc()) - return HttpResponse(status=500) + # Retrieve the code + if "declaration" in fields_to_return: + declaration = database.declaration + cleaned_declaration = clean_paths(declaration) + result["declaration"] = json.dumps(cleaned_declaration) + + # Retrieve the source code + if "code" in fields_to_return: + result["code"] = database.source_code + + # Retrieve the description in HTML format + if "html_description" in fields_to_return: + description = database.description + if len(description) > 0: + result["html_description"] = ensure_html(description) + else: + result["html_description"] = "" + + # Retrieve the referenced data formats + if "referenced_dataformats" in fields_to_return: + dataformats = database.all_referenced_dataformats() + + referenced_dataformats = [] + for dataformat in dataformats: + (has_access, accessibility) = dataformat.accessibility_for(request.user) + if has_access: + referenced_dataformats.append(dataformat) + serializer = ReferencedDataFormatSerializer(referenced_dataformats) + result["referenced_dataformats"] = serializer.data + + # Retrieve the needed data formats + if "needed_dataformats" in fields_to_return: + dataformats = database.all_needed_dataformats() + + needed_dataformats = [] + for dataformat in dataformats: + (has_access, accessibility) = dataformat.accessibility_for(request.user) + if has_access: + needed_dataformats.append(dataformat) + serializer = ReferencedDataFormatSerializer(needed_dataformats) + result["needed_dataformats"] = serializer.data + + # Return the result + return Response(result)