Commit befb3537 authored by Samuel GAIST's avatar Samuel GAIST
Browse files

[databases][api] Code cleanup

parent dbbed637
......@@ -27,13 +27,12 @@
import os
import json
from django.http import HttpResponse
import logging
from rest_framework.response import Response
from rest_framework import permissions
from rest_framework import exceptions as drf_exceptions
from rest_framework import views
from rest_framework import status
from .models import Database
from .models import DatabaseSetTemplate
......@@ -44,10 +43,8 @@ from ..common import is_true
from ..common.mixins import IsAdminOrReadOnlyMixin
from ..common.api import ListCreateBaseView
from ..common.utils import ensure_html
from ..dataformats.serializers import ReferencedDataFormatSerializer
import logging
import traceback
from ..dataformats.serializers import ReferencedDataFormatSerializer
logger = logging.getLogger(__name__)
......@@ -55,7 +52,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------
def database_to_json(database, request_user, fields_to_return, last_version=None):
def database_to_json(database, request_user, fields_to_return):
# Prepare the response
result = {}
......@@ -67,7 +64,14 @@ def database_to_json(database, request_user, fields_to_return, last_version=None
result["version"] = database.version
if "last_version" in fields_to_return:
result["last_version"] = last_version
latest = (
Database.objects.for_user(request_user, True)
.filter(name=database.name)
.order_by("-version")[:1]
.first()
)
result["last_version"] = database.version == latest.version
if "short_description" in fields_to_return:
result["short_description"] = database.short_description
......@@ -204,35 +208,25 @@ class RetrieveDatabaseView(views.APIView):
Returns the given database details
"""
model = Database
permission_classes = [permissions.AllowAny]
def get(self, request, database_name, version=None):
# Retrieve the database
def get_object(self):
version = self.kwargs["version"]
database_name = self.kwargs["database_name"]
user = self.request.user
try:
if version is not None:
version = int(version)
databases = (
Database.objects.for_user(request.user, True)
.filter(name__iexact=database_name, version__gte=version)
.order_by("version")
obj = self.model.objects.for_user(user, True).get(
name__iexact=database_name, version=version
)
except self.model.DoesNotExist:
raise drf_exceptions.NotFound()
return obj
database = databases[0]
if database.version != version:
return HttpResponse(status=status.HTTP_404_NOT_FOUND)
last_version = len(databases) == 1
else:
database = (
Database.objects.for_user(request.user, True)
.filter(name__iexact=database_name)
.order_by("-version")[0]
)
last_version = True
except Exception:
return HttpResponse(status=status.HTTP_404_NOT_FOUND)
def get(self, request, database_name, version):
# Retrieve the database
database = self.get_object()
self.check_object_permissions(request, database)
# Process the query string
if "fields" in request.GET:
......@@ -256,29 +250,18 @@ class RetrieveDatabaseView(views.APIView):
"code",
]
try:
# Prepare the response
result = database_to_json(
database, request.user, fields_to_return, last_version=last_version
)
result = database_to_json(database, request.user, fields_to_return)
# Retrieve the code
if "declaration" in fields_to_return:
try:
declaration = database.declaration
except Exception:
logger.error(traceback.format_exc())
return HttpResponse(status=500)
cleaned_declaration = clean_paths(declaration)
result["declaration"] = json.dumps(cleaned_declaration)
# Retrieve the source code
if "code" in fields_to_return:
try:
result["code"] = database.source_code
except Exception:
logger.error(traceback.format_exc())
return HttpResponse(status=500)
# Retrieve the description in HTML format
if "html_description" in fields_to_return:
......@@ -294,9 +277,7 @@ class RetrieveDatabaseView(views.APIView):
referenced_dataformats = []
for dataformat in dataformats:
(has_access, accessibility) = dataformat.accessibility_for(
request.user
)
(has_access, accessibility) = dataformat.accessibility_for(request.user)
if has_access:
referenced_dataformats.append(dataformat)
serializer = ReferencedDataFormatSerializer(referenced_dataformats)
......@@ -308,9 +289,7 @@ class RetrieveDatabaseView(views.APIView):
needed_dataformats = []
for dataformat in dataformats:
(has_access, accessibility) = dataformat.accessibility_for(
request.user
)
(has_access, accessibility) = dataformat.accessibility_for(request.user)
if has_access:
needed_dataformats.append(dataformat)
serializer = ReferencedDataFormatSerializer(needed_dataformats)
......@@ -318,6 +297,3 @@ class RetrieveDatabaseView(views.APIView):
# Return the result
return Response(result)
except Exception:
logger.error(traceback.format_exc())
return HttpResponse(status=500)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment