Skip to content
Snippets Groups Projects

Refactor update creation api

Merged Samuel GAIST requested to merge refactor_update_creation_api into master
All threads resolved!
1 file
+ 73
97
Compare changes
  • Side-by-side
  • Inline
+ 166
165
@@ -27,129 +27,140 @@
import os
import json
from django.http import HttpResponse
from django.core.urlresolvers import reverse
import logging
from rest_framework.response import Response
from rest_framework import permissions
from rest_framework import permissions as drf_permissions
from rest_framework import exceptions as drf_exceptions
from rest_framework import views
from rest_framework import status
from rest_framework import generics
from .models import Database
from .models import DatabaseSetTemplate
from .serializers import DatabaseSerializer, DatabaseCreationSerializer
from .exceptions import DatabaseCreationError
from ..common import is_true
from ..common.mixins import IsAdminOrReadOnlyMixin
from ..common import permissions as beat_permissions
from ..common.api import ListCreateBaseView
from ..common.responses import BadRequestResponse
from ..common.utils import ensure_html
from ..dataformats.serializers import ReferencedDataFormatSerializer
import logging
import traceback
logger = logging.getLogger(__name__)
#----------------------------------------------------------
# ----------------------------------------------------------
def database_to_json(database, request_user, fields_to_return,
last_version=None):
def database_to_json(database, request_user, fields_to_return):
# Prepare the response
result = {}
if 'name' in fields_to_return:
result['name'] = database.fullname()
if "name" in fields_to_return:
result["name"] = database.fullname()
if "version" in fields_to_return:
result["version"] = database.version
if 'version' in fields_to_return:
result['version'] = database.version
if "last_version" in fields_to_return:
latest = (
Database.objects.for_user(request_user, True)
.filter(name=database.name)
.order_by("-version")[:1]
.first()
)
if 'last_version' in fields_to_return:
result['last_version'] = last_version
result["last_version"] = database.version == latest.version
if 'short_description' in fields_to_return:
result['short_description'] = database.short_description
if "short_description" in fields_to_return:
result["short_description"] = database.short_description
if 'description' in fields_to_return:
result['description'] = database.description
if "description" in fields_to_return:
result["description"] = database.description
if 'previous_version' in fields_to_return:
result['previous_version'] = \
(database.previous_version.fullname() if \
database.previous_version is not None else None)
if "previous_version" in fields_to_return:
result["previous_version"] = (
database.previous_version.fullname()
if database.previous_version is not None
else None
)
if 'creation_date' in fields_to_return:
result['creation_date'] = database.creation_date.isoformat(' ')
if "creation_date" in fields_to_return:
result["creation_date"] = database.creation_date.isoformat(" ")
if 'hash' in fields_to_return:
result['hash'] = database.hash
if "hash" in fields_to_return:
result["hash"] = database.hash
if 'accessibility' in fields_to_return:
result['accessibility'] = database.accessibility_for(request_user)
if "accessibility" in fields_to_return:
result["accessibility"] = database.accessibility_for(request_user)
return result
def clean_paths(declaration):
pseudo_path = '/path_to_db_folder'
root_folder = declaration['root_folder']
pseudo_path = "/path_to_db_folder"
root_folder = declaration["root_folder"]
cleaned_folder = os.path.basename(os.path.normpath(root_folder))
declaration['root_folder'] = os.path.join(pseudo_path, cleaned_folder)
for protocol in declaration['protocols']:
for set_ in protocol['sets']:
if 'parameters' in set_ and 'annotations' in set_['parameters']:
annotations_folder = set_['parameters']['annotations']
cleaned_folder = annotations_folder.split('/')[-2:]
set_['parameters']['annotations'] = os.path.join(pseudo_path, *cleaned_folder)
declaration["root_folder"] = os.path.join(pseudo_path, cleaned_folder)
for protocol in declaration["protocols"]:
for set_ in protocol["sets"]:
if "parameters" in set_ and "annotations" in set_["parameters"]:
annotations_folder = set_["parameters"]["annotations"]
cleaned_folder = annotations_folder.split("/")[-2:]
set_["parameters"]["annotations"] = os.path.join(
pseudo_path, *cleaned_folder
)
return declaration
#----------------------------------------------------------
# ----------------------------------------------------------
class ListCreateDatabasesView(IsAdminOrReadOnlyMixin, ListCreateBaseView):
class ListCreateDatabasesView(ListCreateBaseView):
"""
Read/Write end point that list the database available
to a user and allows the creation of new databases only to
platform administrator
"""
model = Database
permission_classes = [beat_permissions.IsAdminOrReadOnly]
serializer_class = DatabaseSerializer
writing_serializer_class = DatabaseCreationSerializer
namespace = 'api_databases'
namespace = "api_databases"
def get_queryset(self):
user = self.request.user
return self.model.objects.for_user(user, True)
def get(self, request, *args, **kwargs):
fields_to_return = self.get_serializer_fields(request)
limit_to_latest_versions = is_true(request.query_params.get('latest_versions', False))
limit_to_latest_versions = is_true(
request.query_params.get("latest_versions", False)
)
all_databases = self.get_queryset().order_by('name')
all_databases = self.get_queryset().order_by("name")
if limit_to_latest_versions:
all_databases = self.model.filter_latest_versions(all_databases)
all_databases.sort(key=lambda x: x.fullname())
serializer = self.get_serializer(all_databases, many=True, fields=fields_to_return)
serializer = self.get_serializer(
all_databases, many=True, fields=fields_to_return
)
return Response(serializer.data)
#----------------------------------------------------------
# ----------------------------------------------------------
class ListTemplatesView(views.APIView):
"""
List all templates available
"""
permission_classes = [permissions.AllowAny]
permission_classes = [drf_permissions.AllowAny]
def get(self, request):
result = {}
@@ -158,39 +169,39 @@ class ListTemplatesView(views.APIView):
databases = Database.objects.for_user(request.user, True)
databases = Database.filter_latest_versions(databases)
for set_template in DatabaseSetTemplate.objects.filter(sets__protocol__database__in=databases).distinct().order_by('name'):
(db_template, dataset) = set_template.name.split('__')
for set_template in (
DatabaseSetTemplate.objects.filter(sets__protocol__database__in=databases)
.distinct()
.order_by("name")
):
(db_template, dataset) = set_template.name.split("__")
if db_template not in result:
result[db_template] = {
'templates': {},
'sets': [],
}
result[db_template] = {"templates": {}, "sets": []}
result[db_template]['templates'][dataset] = map(lambda x: x.name, set_template.outputs.order_by('name'))
result[db_template]["templates"][dataset] = map(
lambda x: x.name, set_template.outputs.order_by("name")
)
known_sets = []
for db_set in set_template.sets.iterator():
if db_set.name not in known_sets:
result[db_template]['sets'].append({
'name': db_set.name,
'template': dataset,
'id': db_set.id,
})
result[db_template]["sets"].append(
{"name": db_set.name, "template": dataset, "id": db_set.id}
)
known_sets.append(db_set.name)
for name, entry in result.items():
entry['sets'].sort(key=lambda x: x['id'])
result[name]['sets'] = map(lambda x: {
'name': x['name'],
'template': x['template'],
}, entry['sets'])
entry["sets"].sort(key=lambda x: x["id"])
result[name]["sets"] = map(
lambda x: {"name": x["name"], "template": x["template"]}, entry["sets"]
)
return Response(result)
#----------------------------------------------------------
# ----------------------------------------------------------
class RetrieveDatabaseView(views.APIView):
@@ -198,102 +209,92 @@ class RetrieveDatabaseView(views.APIView):
Returns the given database details
"""
permission_classes = [permissions.AllowAny]
model = Database
permission_classes = [drf_permissions.AllowAny]
def get(self, request, database_name, version=None):
# Retrieve the database
def get_object(self):
version = self.kwargs["version"]
database_name = self.kwargs["database_name"]
user = self.request.user
try:
if version is not None:
version = int(version)
databases = Database.objects.for_user(request.user, True).filter(name__iexact=database_name,
version__gte=version).order_by('version')
database = databases[0]
if database.version != version:
return HttpResponse(status=status.HTTP_404_NOT_FOUND)
last_version = (len(databases) == 1)
else:
database = Database.objects.for_user(request.user, True).filter(
name__iexact=database_name).order_by('-version')[0]
last_version = True
except:
return HttpResponse(status=status.HTTP_404_NOT_FOUND)
obj = self.model.objects.for_user(user, True).get(
name__iexact=database_name, version=version
)
except self.model.DoesNotExist:
raise drf_exceptions.NotFound()
return obj
def get(self, request, database_name, version):
# Retrieve the database
database = self.get_object()
self.check_object_permissions(request, database)
# Process the query string
if 'fields' in request.GET:
fields_to_return = request.GET['fields'].split(',')
if "fields" in request.GET:
fields_to_return = request.GET["fields"].split(",")
else:
fields_to_return = [ 'name', 'version', 'last_version',
'short_description', 'description', 'fork_of',
'previous_version', 'is_owner', 'accessibility', 'sharing',
'opensource', 'hash', 'creation_date',
'declaration', 'code' ]
try:
# Prepare the response
result = database_to_json(database, request.user, fields_to_return,
last_version=last_version)
# Retrieve the code
if 'declaration' in fields_to_return:
try:
declaration = database.declaration
except:
logger.error(traceback.format_exc())
return HttpResponse(status=500)
cleaned_declaration = clean_paths(declaration)
result['declaration'] = json.dumps(cleaned_declaration)
# Retrieve the source code
if 'code' in fields_to_return:
try:
result['code'] = database.source_code
except:
logger.error(traceback.format_exc())
return HttpResponse(status=500)
# Retrieve the description in HTML format
if 'html_description' in fields_to_return:
description = database.description
if len(description) > 0:
result['html_description'] = ensure_html(description)
else:
result['html_description'] = ''
# Retrieve the referenced data formats
if 'referenced_dataformats' in fields_to_return:
dataformats = database.all_referenced_dataformats()
referenced_dataformats = []
for dataformat in dataformats:
(has_access, accessibility) = dataformat.accessibility_for(request.user)
if has_access:
referenced_dataformats.append(dataformat)
serializer = ReferencedDataFormatSerializer(referenced_dataformats)
result['referenced_dataformats'] = serializer.data
# Retrieve the needed data formats
if 'needed_dataformats' in fields_to_return:
dataformats = database.all_needed_dataformats()
needed_dataformats = []
for dataformat in dataformats:
(has_access, accessibility) = dataformat.accessibility_for(request.user)
if has_access:
needed_dataformats.append(dataformat)
serializer = ReferencedDataFormatSerializer(needed_dataformats)
result['needed_dataformats'] = serializer.data
# Return the result
return Response(result)
except:
logger.error(traceback.format_exc())
return HttpResponse(status=500)
fields_to_return = [
"name",
"version",
"last_version",
"short_description",
"description",
"fork_of",
"previous_version",
"is_owner",
"accessibility",
"sharing",
"opensource",
"hash",
"creation_date",
"declaration",
"code",
]
# Prepare the response
result = database_to_json(database, request.user, fields_to_return)
# Retrieve the code
if "declaration" in fields_to_return:
declaration = database.declaration
cleaned_declaration = clean_paths(declaration)
result["declaration"] = json.dumps(cleaned_declaration)
# Retrieve the source code
if "code" in fields_to_return:
result["code"] = database.source_code
# Retrieve the description in HTML format
if "html_description" in fields_to_return:
description = database.description
if len(description) > 0:
result["html_description"] = ensure_html(description)
else:
result["html_description"] = ""
# Retrieve the referenced data formats
if "referenced_dataformats" in fields_to_return:
dataformats = database.all_referenced_dataformats()
referenced_dataformats = []
for dataformat in dataformats:
(has_access, accessibility) = dataformat.accessibility_for(request.user)
if has_access:
referenced_dataformats.append(dataformat)
serializer = ReferencedDataFormatSerializer(referenced_dataformats)
result["referenced_dataformats"] = serializer.data
# Retrieve the needed data formats
if "needed_dataformats" in fields_to_return:
dataformats = database.all_needed_dataformats()
needed_dataformats = []
for dataformat in dataformats:
(has_access, accessibility) = dataformat.accessibility_for(request.user)
if has_access:
needed_dataformats.append(dataformat)
serializer = ReferencedDataFormatSerializer(needed_dataformats)
result["needed_dataformats"] = serializer.data
# Return the result
return Response(result)
Loading