diff --git a/beat/web/databases/api.py b/beat/web/databases/api.py index 209747d8d32c6b3592da8f8992efe5c836b31a2a..7b817d837a33c54ce8d57c8e430e54e7b786b85d 100755 --- a/beat/web/databases/api.py +++ b/beat/web/databases/api.py @@ -29,87 +29,88 @@ import os import json from django.http import HttpResponse -from django.core.urlresolvers import reverse from rest_framework.response import Response from rest_framework import permissions from rest_framework import views from rest_framework import status -from rest_framework import generics - from .models import Database from .models import DatabaseSetTemplate from .serializers import DatabaseSerializer, DatabaseCreationSerializer -from .exceptions import DatabaseCreationError from ..common import is_true from ..common.mixins import IsAdminOrReadOnlyMixin from ..common.api import ListCreateBaseView -from ..common.responses import BadRequestResponse from ..common.utils import ensure_html from ..dataformats.serializers import ReferencedDataFormatSerializer import logging import traceback + logger = logging.getLogger(__name__) -#---------------------------------------------------------- +# ---------------------------------------------------------- + -def database_to_json(database, request_user, fields_to_return, - last_version=None): +def database_to_json(database, request_user, fields_to_return, last_version=None): # Prepare the response result = {} - if 'name' in fields_to_return: - result['name'] = database.fullname() + if "name" in fields_to_return: + result["name"] = database.fullname() - if 'version' in fields_to_return: - result['version'] = database.version + if "version" in fields_to_return: + result["version"] = database.version - if 'last_version' in fields_to_return: - result['last_version'] = last_version + if "last_version" in fields_to_return: + result["last_version"] = last_version - if 'short_description' in fields_to_return: - result['short_description'] = database.short_description + if "short_description" in fields_to_return: + result["short_description"] = database.short_description - if 'description' in fields_to_return: - result['description'] = database.description + if "description" in fields_to_return: + result["description"] = database.description - if 'previous_version' in fields_to_return: - result['previous_version'] = \ - (database.previous_version.fullname() if \ - database.previous_version is not None else None) + if "previous_version" in fields_to_return: + result["previous_version"] = ( + database.previous_version.fullname() + if database.previous_version is not None + else None + ) - if 'creation_date' in fields_to_return: - result['creation_date'] = database.creation_date.isoformat(' ') + if "creation_date" in fields_to_return: + result["creation_date"] = database.creation_date.isoformat(" ") - if 'hash' in fields_to_return: - result['hash'] = database.hash + if "hash" in fields_to_return: + result["hash"] = database.hash - if 'accessibility' in fields_to_return: - result['accessibility'] = database.accessibility_for(request_user) + if "accessibility" in fields_to_return: + result["accessibility"] = database.accessibility_for(request_user) return result def clean_paths(declaration): - pseudo_path = '/path_to_db_folder' - root_folder = declaration['root_folder'] + pseudo_path = "/path_to_db_folder" + root_folder = declaration["root_folder"] cleaned_folder = os.path.basename(os.path.normpath(root_folder)) - declaration['root_folder'] = os.path.join(pseudo_path, cleaned_folder) - for protocol in declaration['protocols']: - for set_ in protocol['sets']: - if 'parameters' in set_ and 'annotations' in set_['parameters']: - annotations_folder = set_['parameters']['annotations'] - cleaned_folder = annotations_folder.split('/')[-2:] - set_['parameters']['annotations'] = os.path.join(pseudo_path, *cleaned_folder) + declaration["root_folder"] = os.path.join(pseudo_path, cleaned_folder) + for protocol in declaration["protocols"]: + for set_ in protocol["sets"]: + if "parameters" in set_ and "annotations" in set_["parameters"]: + annotations_folder = set_["parameters"]["annotations"] + cleaned_folder = annotations_folder.split("/")[-2:] + set_["parameters"]["annotations"] = os.path.join( + pseudo_path, *cleaned_folder + ) return declaration -#---------------------------------------------------------- + +# ---------------------------------------------------------- class ListCreateDatabasesView(IsAdminOrReadOnlyMixin, ListCreateBaseView): @@ -118,37 +119,42 @@ class ListCreateDatabasesView(IsAdminOrReadOnlyMixin, ListCreateBaseView): to a user and allows the creation of new databases only to platform administrator """ + model = Database serializer_class = DatabaseSerializer writing_serializer_class = DatabaseCreationSerializer - namespace = 'api_databases' + namespace = "api_databases" def get_queryset(self): user = self.request.user return self.model.objects.for_user(user, True) - def get(self, request, *args, **kwargs): fields_to_return = self.get_serializer_fields(request) - limit_to_latest_versions = is_true(request.query_params.get('latest_versions', False)) + limit_to_latest_versions = is_true( + request.query_params.get("latest_versions", False) + ) - all_databases = self.get_queryset().order_by('name') + all_databases = self.get_queryset().order_by("name") if limit_to_latest_versions: all_databases = self.model.filter_latest_versions(all_databases) all_databases.sort(key=lambda x: x.fullname()) - serializer = self.get_serializer(all_databases, many=True, fields=fields_to_return) + serializer = self.get_serializer( + all_databases, many=True, fields=fields_to_return + ) return Response(serializer.data) -#---------------------------------------------------------- +# ---------------------------------------------------------- class ListTemplatesView(views.APIView): """ List all templates available """ + permission_classes = [permissions.AllowAny] def get(self, request): @@ -158,39 +164,39 @@ class ListTemplatesView(views.APIView): databases = Database.objects.for_user(request.user, True) databases = Database.filter_latest_versions(databases) - for set_template in DatabaseSetTemplate.objects.filter(sets__protocol__database__in=databases).distinct().order_by('name'): - (db_template, dataset) = set_template.name.split('__') + for set_template in ( + DatabaseSetTemplate.objects.filter(sets__protocol__database__in=databases) + .distinct() + .order_by("name") + ): + (db_template, dataset) = set_template.name.split("__") if db_template not in result: - result[db_template] = { - 'templates': {}, - 'sets': [], - } + result[db_template] = {"templates": {}, "sets": []} - result[db_template]['templates'][dataset] = map(lambda x: x.name, set_template.outputs.order_by('name')) + result[db_template]["templates"][dataset] = map( + lambda x: x.name, set_template.outputs.order_by("name") + ) known_sets = [] for db_set in set_template.sets.iterator(): if db_set.name not in known_sets: - result[db_template]['sets'].append({ - 'name': db_set.name, - 'template': dataset, - 'id': db_set.id, - }) + result[db_template]["sets"].append( + {"name": db_set.name, "template": dataset, "id": db_set.id} + ) known_sets.append(db_set.name) for name, entry in result.items(): - entry['sets'].sort(key=lambda x: x['id']) - result[name]['sets'] = map(lambda x: { - 'name': x['name'], - 'template': x['template'], - }, entry['sets']) + entry["sets"].sort(key=lambda x: x["id"]) + result[name]["sets"] = map( + lambda x: {"name": x["name"], "template": x["template"]}, entry["sets"] + ) return Response(result) -#---------------------------------------------------------- +# ---------------------------------------------------------- class RetrieveDatabaseView(views.APIView): @@ -206,94 +212,112 @@ class RetrieveDatabaseView(views.APIView): if version is not None: version = int(version) - databases = Database.objects.for_user(request.user, True).filter(name__iexact=database_name, - version__gte=version).order_by('version') + databases = ( + Database.objects.for_user(request.user, True) + .filter(name__iexact=database_name, version__gte=version) + .order_by("version") + ) database = databases[0] if database.version != version: return HttpResponse(status=status.HTTP_404_NOT_FOUND) - last_version = (len(databases) == 1) + last_version = len(databases) == 1 else: - database = Database.objects.for_user(request.user, True).filter( - name__iexact=database_name).order_by('-version')[0] + database = ( + Database.objects.for_user(request.user, True) + .filter(name__iexact=database_name) + .order_by("-version")[0] + ) last_version = True - except: + except Exception: return HttpResponse(status=status.HTTP_404_NOT_FOUND) - # Process the query string - if 'fields' in request.GET: - fields_to_return = request.GET['fields'].split(',') + if "fields" in request.GET: + fields_to_return = request.GET["fields"].split(",") else: - fields_to_return = [ 'name', 'version', 'last_version', - 'short_description', 'description', 'fork_of', - 'previous_version', 'is_owner', 'accessibility', 'sharing', - 'opensource', 'hash', 'creation_date', - 'declaration', 'code' ] - + fields_to_return = [ + "name", + "version", + "last_version", + "short_description", + "description", + "fork_of", + "previous_version", + "is_owner", + "accessibility", + "sharing", + "opensource", + "hash", + "creation_date", + "declaration", + "code", + ] try: # Prepare the response - result = database_to_json(database, request.user, fields_to_return, - last_version=last_version) + result = database_to_json( + database, request.user, fields_to_return, last_version=last_version + ) # Retrieve the code - if 'declaration' in fields_to_return: + if "declaration" in fields_to_return: try: declaration = database.declaration - except: + except Exception: logger.error(traceback.format_exc()) return HttpResponse(status=500) cleaned_declaration = clean_paths(declaration) - result['declaration'] = json.dumps(cleaned_declaration) + result["declaration"] = json.dumps(cleaned_declaration) # Retrieve the source code - if 'code' in fields_to_return: + if "code" in fields_to_return: try: - result['code'] = database.source_code - except: + result["code"] = database.source_code + except Exception: logger.error(traceback.format_exc()) return HttpResponse(status=500) - # Retrieve the description in HTML format - if 'html_description' in fields_to_return: + if "html_description" in fields_to_return: description = database.description if len(description) > 0: - result['html_description'] = ensure_html(description) + result["html_description"] = ensure_html(description) else: - result['html_description'] = '' - + result["html_description"] = "" # Retrieve the referenced data formats - if 'referenced_dataformats' in fields_to_return: + if "referenced_dataformats" in fields_to_return: dataformats = database.all_referenced_dataformats() referenced_dataformats = [] for dataformat in dataformats: - (has_access, accessibility) = dataformat.accessibility_for(request.user) + (has_access, accessibility) = dataformat.accessibility_for( + request.user + ) if has_access: referenced_dataformats.append(dataformat) serializer = ReferencedDataFormatSerializer(referenced_dataformats) - result['referenced_dataformats'] = serializer.data - + result["referenced_dataformats"] = serializer.data # Retrieve the needed data formats - if 'needed_dataformats' in fields_to_return: + if "needed_dataformats" in fields_to_return: dataformats = database.all_needed_dataformats() needed_dataformats = [] for dataformat in dataformats: - (has_access, accessibility) = dataformat.accessibility_for(request.user) + (has_access, accessibility) = dataformat.accessibility_for( + request.user + ) if has_access: needed_dataformats.append(dataformat) serializer = ReferencedDataFormatSerializer(needed_dataformats) - result['needed_dataformats'] = serializer.data + result["needed_dataformats"] = serializer.data # Return the result return Response(result) - except: + except Exception: logger.error(traceback.format_exc()) return HttpResponse(status=500)