diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bb2dfa1a6897347679b3da2cb1bb15548e155a8b..f1497effdbc1099618fb840b7a842bbea8c6985b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,4 +21,4 @@ repos: rev: 'master' # Update me! hooks: - id: bandit - exclude: beat/editor/test + exclude: (beat/web/.*/tests|eggs/) diff --git a/beat/web/accounts/api_urls.py b/beat/web/accounts/api_urls.py index 2d037b3dfbd30a22d2173fd2e90a06aa05f61299..b57b22e5c52ddcf7713a4bf4c81bba145e9b4bb5 100644 --- a/beat/web/accounts/api_urls.py +++ b/beat/web/accounts/api_urls.py @@ -25,62 +25,50 @@ # # ############################################################################### -from django.conf.urls import * +from django.conf.urls import url + from . import api urlpatterns = [ + url(r"^$", api.SupervisorListView.as_view(), name="list_supervisee"), url( - r'^$', - api.SupervisorListView.as_view(), - name='list_supervisee' - ), - - url( - r'^(?P<supervisee_name>[\w\W]+)/validate/$', + r"^(?P<supervisee_name>[\w\W]+)/validate/$", api.SupervisorAddSuperviseeView.as_view(), - name='validate_supervisee' + name="validate_supervisee", ), - url( - r'^(?P<supervisee_name>[\w\W]+)/remove/$', + r"^(?P<supervisee_name>[\w\W]+)/remove/$", api.SupervisorRemoveSuperviseeView.as_view(), - name='remove_supervisee' + name="remove_supervisee", ), - url( - r'^(?P<supervisor_name>[\w\W]+)/add/$', + r"^(?P<supervisor_name>[\w\W]+)/add/$", api.SuperviseeAddSupervisorView.as_view(), - name='add_supervisor' + name="add_supervisor", ), - url( - r'^revalidate/$', + r"^revalidate/$", api.SuperviseeReValidationView.as_view(), - name='revalidate_account' + name="revalidate_account", ), - url( - r'^set_supervisor_mode/$', + r"^set_supervisor_mode/$", api.SetSupervisorModeView.as_view(), - name='set_supervisor_mode' + name="set_supervisor_mode", ), - url( - r'^remove_supervisor_mode/$', + r"^remove_supervisor_mode/$", api.RemoveSupervisorModeView.as_view(), - name='remove_supervisor_mode' + name="remove_supervisor_mode", ), - url( - r'^list_supervisor_candidates/$', + r"^list_supervisor_candidates/$", api.ListSupervisorCandidatesView.as_view(), - name='list_supervisor_candidates' + name="list_supervisor_candidates", ), - url( - r'^grant_supervisor_access/$', + r"^grant_supervisor_access/$", api.UpdateSupervisorCandidateView.as_view(), - name='update_supervisor_candidate' + name="update_supervisor_candidate", ), - ] diff --git a/beat/web/accounts/serializers.py b/beat/web/accounts/serializers.py index 5d3cd2820b0c732a7ad40aeb775129c4d9b6858f..59679a3754d2e917cc8aeee30d320da82e7d26b4 100644 --- a/beat/web/accounts/serializers.py +++ b/beat/web/accounts/serializers.py @@ -25,20 +25,13 @@ # # ############################################################################### -from django.contrib.auth.models import User, AnonymousUser - +from django.contrib.auth.models import User from rest_framework import serializers -from .models import Profile, SupervisionTrack -from ..common.models import Contribution -from ..common.fields import JSONSerializerField -from ..ui.templatetags.markup import restructuredtext -from ..common.utils import validate_restructuredtext - -import simplejson as json +from .models import SupervisionTrack -#---------------------------------------------------------- +# ---------------------------------------------------------- class UserSerializer(serializers.ModelSerializer): @@ -47,7 +40,7 @@ class UserSerializer(serializers.ModelSerializer): class Meta: model = User - fields = ['username', 'email'] + fields = ["username", "email"] def get_username(self, obj): return obj.username @@ -56,7 +49,7 @@ class UserSerializer(serializers.ModelSerializer): return obj.email -#---------------------------------------------------------- +# ---------------------------------------------------------- class BasicSupervisionTrackSerializer(serializers.ModelSerializer): @@ -70,12 +63,12 @@ class BasicSupervisionTrackSerializer(serializers.ModelSerializer): class Meta: model = SupervisionTrack - fields = ['is_valid'] + fields = ["is_valid"] - #def get_supervisee(self, obj): + # def get_supervisee(self, obj): # return obj.supervisee - #def get_supervisor(self, obj): + # def get_supervisor(self, obj): # return obj.supervisor def get_is_valid(self, obj): @@ -94,20 +87,27 @@ class BasicSupervisionTrackSerializer(serializers.ModelSerializer): return obj.supervision_key -#---------------------------------------------------------- +# ---------------------------------------------------------- class FullSupervisionTrackSerializer(BasicSupervisionTrackSerializer): - class Meta(BasicSupervisionTrackSerializer.Meta): - fields = ['supervisee', 'supervisor', 'is_valid', 'start_date', 'expiration_date','last_validation_date', 'supervision_key'] + fields = [ + "supervisee", + "supervisor", + "is_valid", + "start_date", + "expiration_date", + "last_validation_date", + "supervision_key", + ] -#---------------------------------------------------------- +# ---------------------------------------------------------- class SupervisionTrackUpdateSerializer(BasicSupervisionTrackSerializer): pass -#---------------------------------------------------------- +# ---------------------------------------------------------- diff --git a/beat/web/algorithms/api.py b/beat/web/algorithms/api.py index d85922e0fdee7fc3c9789e893e9baf02106eac24..71b1f08c08ac23bf4abbba1b4d81d601bbc800cc 100755 --- a/beat/web/algorithms/api.py +++ b/beat/web/algorithms/api.py @@ -29,26 +29,28 @@ from django.http import Http404 from django.http import HttpResponse from django.http import HttpResponseForbidden from django.http import HttpResponseBadRequest +from django.http import HttpResponseNotAllowed from django.shortcuts import get_object_or_404 -from django.conf import settings - -import os from .models import Algorithm from .serializers import AlgorithmSerializer from .serializers import FullAlgorithmSerializer from .serializers import AlgorithmCreationSerializer +from .serializers import AlgorithmModSerializer from ..code.api import ShareCodeView, RetrieveUpdateDestroyCodeView from ..code.serializers import CodeDiffSerializer -from ..common.api import (CheckContributionNameView, ListContributionView, - ListCreateContributionView) +from ..common.api import ( + CheckContributionNameView, + ListContributionView, + ListCreateContributionView, +) from ..code.api import DiffView -#---------------------------------------------------------- +# ---------------------------------------------------------- class CheckAlgorithmNameView(CheckContributionNameView): @@ -56,10 +58,11 @@ class CheckAlgorithmNameView(CheckContributionNameView): This view sanitizes an algorithm name and checks whether it is already used. """ + model = Algorithm -#---------------------------------------------------------- +# ---------------------------------------------------------- class ShareAlgorithmView(ShareCodeView): @@ -67,21 +70,23 @@ class ShareAlgorithmView(ShareCodeView): This view allows to share an algorithm with other users and/or teams """ + model = Algorithm -#---------------------------------------------------------- +# ---------------------------------------------------------- class ListAlgorithmsView(ListContributionView): """ List all available algorithms """ + model = Algorithm serializer_class = AlgorithmSerializer -#---------------------------------------------------------- +# ---------------------------------------------------------- class ListCreateAlgorithmsView(ListCreateContributionView): @@ -89,13 +94,14 @@ class ListCreateAlgorithmsView(ListCreateContributionView): Read/Write end point that list the algorithms available from a given author and allows the creation of new algorithms """ + model = Algorithm serializer_class = AlgorithmSerializer writing_serializer_class = AlgorithmCreationSerializer - namespace = 'api_algorithms' + namespace = "api_algorithms" -#---------------------------------------------------------- +# ---------------------------------------------------------- class RetrieveUpdateDestroyAlgorithmsView(RetrieveUpdateDestroyCodeView): @@ -105,39 +111,30 @@ class RetrieveUpdateDestroyAlgorithmsView(RetrieveUpdateDestroyCodeView): model = Algorithm serializer_class = FullAlgorithmSerializer + writing_serializer_class = AlgorithmModSerializer - def do_update(self, request, author_name, object_name, version=None): - modified, algorithm = super(RetrieveUpdateDestroyAlgorithmsView, self).do_update(request, author_name, object_name, version) - - if modified: - # Delete existing experiments using the algorithm (code changed) - experiments = list(set(map(lambda x: x.experiment, - algorithm.blocks.iterator()))) - for experiment in experiments: experiment.delete() - return modified, algorithm - - -#---------------------------------------------------------- +# ---------------------------------------------------------- class DiffAlgorithmView(DiffView): """ This view shows the differences between two algorithms """ + model = Algorithm serializer_class = CodeDiffSerializer -#---------------------------------------------------------- +# ---------------------------------------------------------- def binary(request, author_name, object_name, version=None): """Returns the shared library of a binary algorithm """ - if request.method not in ['GET', 'POST']: - return HttpResponseNotAllowed(['GET', 'POST']) + if request.method not in ["GET", "POST"]: + return HttpResponseNotAllowed(["GET", "POST"]) # Retrieves the algorithm if version: @@ -148,8 +145,9 @@ def binary(request, author_name, object_name, version=None): version=int(version), ) else: - algorithm = Algorithm.objects.filter(author__username__iexact=author_name, - name__iexact=object_name).order_by('-version') + algorithm = Algorithm.objects.filter( + author__username__iexact=author_name, name__iexact=object_name + ).order_by("-version") if not algorithm: raise Http404() else: @@ -158,18 +156,22 @@ def binary(request, author_name, object_name, version=None): if not algorithm.is_binary(): raise Http404() - if request.method == 'GET': - (has_access, _, accessibility) = algorithm.accessibility_for(request.user, without_usable=True) + if request.method == "GET": + (has_access, _, accessibility) = algorithm.accessibility_for( + request.user, without_usable=True + ) if not has_access: raise Http404() binary_data = algorithm.source_code - response = HttpResponse(binary_data, content_type='application/octet-stream') + response = HttpResponse(binary_data, content_type="application/octet-stream") - response['Content-Length'] = len(binary_data) - response['Content-Disposition'] = 'attachment; filename=%d.so' % algorithm.version + response["Content-Length"] = len(binary_data) + response["Content-Disposition"] = ( + "attachment; filename=%d.so" % algorithm.version + ) return response @@ -177,12 +179,12 @@ def binary(request, author_name, object_name, version=None): if request.user.is_anonymous() or (request.user.username != author_name): return HttpResponseForbidden() - if 'binary' not in request.FILES: + if "binary" not in request.FILES: return HttpResponseBadRequest() - file = request.FILES['binary'] + file = request.FILES["binary"] - binary_data = b'' + binary_data = b"" for chunk in file.chunks(): binary_data += chunk diff --git a/beat/web/algorithms/api_urls.py b/beat/web/algorithms/api_urls.py index 56148227124bcecba42e5d2a0d22ef866f9c74f4..cbb3a26021f73b241bc5fb50ff090c330a88cffb 100755 --- a/beat/web/algorithms/api_urls.py +++ b/beat/web/algorithms/api_urls.py @@ -29,51 +29,33 @@ from django.conf.urls import url from . import api -urlpatterns = [ - - url(r'^$', - api.ListAlgorithmsView.as_view(), - name='all', - ), - url(r'^check_name/$', - api.CheckAlgorithmNameView.as_view(), - name='check_name', - ), - - url(r'^diff/(?P<author1>\w+)/(?P<name1>[-\w]+)/(?P<version1>\d+)/(?P<author2>\w+)/(?P<name2>[-\w]+)/(?P<version2>\d+)/$', +urlpatterns = [ + url(r"^$", api.ListAlgorithmsView.as_view(), name="all"), + url(r"^check_name/$", api.CheckAlgorithmNameView.as_view(), name="check_name"), + url( + r"^diff/(?P<author1>\w+)/(?P<name1>[-\w]+)/(?P<version1>\d+)/(?P<author2>\w+)/(?P<name2>[-\w]+)/(?P<version2>\d+)/$", api.DiffAlgorithmView.as_view(), - name='diff', - ), - - url(r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/share/$', + name="diff", + ), + url( + r"^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/share/$", api.ShareAlgorithmView.as_view(), - name='share', - ), - - url(r'^(?P<author_name>\w+)/$', + name="share", + ), + url( + r"^(?P<author_name>\w+)/$", api.ListCreateAlgorithmsView.as_view(), - name='list_create', - ), - - url(r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/$', + name="list_create", + ), + url( + r"^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/$", api.RetrieveUpdateDestroyAlgorithmsView.as_view(), - name='object', - ), - - url(r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/$', - api.RetrieveUpdateDestroyAlgorithmsView.as_view(), - name='object', - ), - - url(r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/binary/$', + name="object", + ), + url( + r"^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/binary/$", api.binary, - name='binary', - ), - - url(r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/binary/$', - api.binary, - name='binary', - ), - + name="binary", + ), ] diff --git a/beat/web/algorithms/serializers.py b/beat/web/algorithms/serializers.py index 3f40b5be298074e110c7e9d823fc58c9a13fb290..d44166586e4515fa8a850bd9b30865d80ae78b08 100755 --- a/beat/web/algorithms/serializers.py +++ b/beat/web/algorithms/serializers.py @@ -32,7 +32,7 @@ import beat.core.algorithm from rest_framework import serializers from operator import itemgetter -from ..code.serializers import CodeSerializer, CodeCreationSerializer +from ..code.serializers import CodeSerializer, CodeCreationSerializer, CodeModSerializer from ..libraries.serializers import LibraryReferenceSerializer from ..dataformats.serializers import ReferencedDataFormatSerializer from ..attestations.serializers import AttestationSerializer @@ -49,7 +49,13 @@ from .models import Algorithm class AlgorithmCreationSerializer(CodeCreationSerializer): class Meta(CodeCreationSerializer.Meta): model = Algorithm - beat_core_class = beat.core.algorithm + beat_core_class = beat.core.algorithm.Algorithm + + +class AlgorithmModSerializer(CodeModSerializer): + class Meta(CodeModSerializer.Meta): + model = Algorithm + beat_core_class = beat.core.algorithm.Algorithm # ---------------------------------------------------------- diff --git a/beat/web/algorithms/templates/algorithms/edition.html b/beat/web/algorithms/templates/algorithms/edition.html index 73667155ed846d514853715fddb0a533d598d75d..0c9cd5ed05807059105962898672a4a64d1160c7 100644 --- a/beat/web/algorithms/templates/algorithms/edition.html +++ b/beat/web/algorithms/templates/algorithms/edition.html @@ -329,13 +329,14 @@ function setupEditor(algorithm, dataformats, libraries) var data = { {% if not edition %} - {% if algorithm_version > 1 and not fork_of %} - name: '{{ algorithm_name }}', - version: '{{ algorithm_version }}', - previous_version:'{{ algorithm_author }}/{{ algorithm_name }}/{{ algorithm_version|add:-1 }}', - {% else %} - name: $('#algorithm_name')[0].value.trim(), - {% endif %} + {% if algorithm_version > 1 and not fork_of %} + name: '{{ algorithm_name }}', + version: {{ algorithm_version }}, + previous_version:'{{ algorithm_author }}/{{ algorithm_name }}/{{ algorithm_version|add:-1 }}', + {% else %} + name: $('#algorithm_name')[0].value.trim(), + version: 1, + {% endif %} description: (algorithm && algorithm.description) || '', {% endif %} short_description: (algorithm && algorithm.short_description) || '', @@ -710,10 +711,7 @@ jQuery(document).ready(function() { }); {% if original_author %} - var url = '{% url 'api_algorithms:object' original_author algorithm_name %}' - {% if algorithm_version %} - url += '{{ algorithm_version }}/'; - {% endif %} + var url = '{% url 'api_algorithms:object' original_author algorithm_name algorithm_version %}' var query = '?fields=html_description,description,short_description'; diff --git a/beat/web/algorithms/tests/tests_api.py b/beat/web/algorithms/tests/tests_api.py index 3c68aa8f43bb226cd06b5f0385f417d77ef4d926..d8c1c84926c4a7f92504e1cd7c5cd1bb81fd1083 100755 --- a/beat/web/algorithms/tests/tests_api.py +++ b/beat/web/algorithms/tests/tests_api.py @@ -1229,20 +1229,26 @@ class AlgorithmUpdate(AlgorithmsAPIBase): def test_no_update_without_content(self): self.login_jackdoe() response = self.client.put(self.url) - self.checkResponse(response, 400) + self.checkResponse(response, 200, content_type="application/json") - def test_successfull_update(self): + def test_successful_update(self): self.login_jackdoe() + code = b"""import numpy as np""" + response = self.client.put( self.url, json.dumps( - {"description": "blah", "declaration": AlgorithmsAPIBase.UPDATE} + { + "description": "blah", + "declaration": AlgorithmsAPIBase.UPDATE, + "code": code, + } ), content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1252,25 +1258,74 @@ class AlgorithmUpdate(AlgorithmsAPIBase): storage = beat.core.algorithm.Storage(settings.PREFIX, algorithm.fullname()) storage.language = "python" self.assertTrue(storage.exists()) - self.assertEqual(storage.json.load(), AlgorithmsAPIBase.UPDATE) + self.assertEqual( + json.loads(storage.json.load()), json.loads(AlgorithmsAPIBase.UPDATE) + ) + self.assertEqual(storage.code.load(), code) - def test_successfull_update_description_only(self): + def test_successful_update_with_specific_return_field(self): self.login_jackdoe() + code = b"""import numpy as np""" + response = self.client.put( self.url, - json.dumps({"description": "blah"}), + json.dumps( + { + "description": "blah", + "declaration": AlgorithmsAPIBase.UPDATE, + "code": code, + } + ), content_type="application/json", + QUERY_STRING="fields=code", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") + answer = response.json() + self.assertEqual(len(answer), 1) + self.assertTrue("code" in answer) + def test_successful_update_with_specific_return_several_fields(self): + self.login_jackdoe() + + code = b"""import numpy as np""" + + response = self.client.put( + self.url, + json.dumps( + { + "description": "blah", + "declaration": AlgorithmsAPIBase.UPDATE, + "code": code, + } + ), + content_type="application/json", + QUERY_STRING="fields=code,description", + ) + + self.checkResponse(response, 200, content_type="application/json") + answer = response.json() + self.assertEqual(len(answer), 2) + self.assertTrue("code" in answer) + self.assertTrue("description" in answer) + + def test_successful_update_description_only(self): + self.login_jackdoe() + + response = self.client.put( + self.url, + json.dumps({"description": "blah"}), + content_type="application/json", + ) + + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 ) self.assertEqual(algorithm.description, b"blah") - def test_successfull_update_code_only(self): + def test_successful_declaration_only(self): self.login_jackdoe() response = self.client.put( @@ -1279,7 +1334,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1288,9 +1343,31 @@ class AlgorithmUpdate(AlgorithmsAPIBase): storage = beat.core.algorithm.Storage(settings.PREFIX, algorithm.fullname()) storage.language = "python" self.assertTrue(storage.exists()) - self.assertEqual(storage.json.load(), AlgorithmsAPIBase.UPDATE) + self.assertEqual( + json.loads(storage.json.load()), json.loads(AlgorithmsAPIBase.UPDATE) + ) - def test_successfull_update_change_input_name(self): + def test_successful_update_code_only(self): + self.login_jackdoe() + + code = b"""import pandas""" + + response = self.client.put( + self.url, json.dumps({"code": code}), content_type="application/json" + ) + + self.checkResponse(response, 200, content_type="application/json") + + algorithm = Algorithm.objects.get( + author__username="jackdoe", name="personal", version=1 + ) + + storage = beat.core.algorithm.Storage(settings.PREFIX, algorithm.fullname()) + storage.language = "python" + self.assertTrue(storage.exists()) + self.assertEqual(storage.code.load(), code) + + def test_successful_update_change_input_name(self): declaration = """{ "language": "python", "splittable": false, @@ -1317,7 +1394,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1337,7 +1414,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): self.assertFalse(endpoint.input) self.assertEqual(endpoint.dataformat.fullname(), "johndoe/single_integer/1") - def test_successfull_update_change_input_format(self): + def test_successful_update_change_input_format(self): declaration = """{ "language": "python", "splittable": false, @@ -1369,7 +1446,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1389,7 +1466,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): self.assertFalse(endpoint.input) self.assertEqual(endpoint.dataformat.fullname(), "johndoe/single_integer/1") - def test_successfull_update_change_output_name(self): + def test_successful_update_change_output_name(self): declaration = """{ "language": "python", "splittable": false, @@ -1407,7 +1484,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): "parameters": { } }""" - code = """class Algorithm: + code = b"""class Algorithm: def process(self, inputs, outputs): return True @@ -1422,7 +1499,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1442,7 +1519,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): self.assertFalse(endpoint.input) self.assertEqual(endpoint.dataformat.fullname(), "johndoe/single_integer/1") - def test_successfull_update_change_output_format(self): + def test_successful_update_change_output_format(self): declaration = """{ "language": "python", "splittable": false, @@ -1474,7 +1551,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1494,7 +1571,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): self.assertFalse(endpoint.input) self.assertEqual(endpoint.dataformat.fullname(), "jackdoe/single_float/1") - def test_successfull_update_add_input(self): + def test_successful_update_add_input(self): declaration = """{ "language": "python", "splittable": false, @@ -1522,7 +1599,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1546,7 +1623,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): self.assertFalse(endpoint.input) self.assertEqual(endpoint.dataformat.fullname(), "johndoe/single_integer/1") - def test_successfull_update_remove_input(self): + def test_successful_update_remove_input(self): declaration = """{ "language": "python", "splittable": false, @@ -1572,7 +1649,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1588,7 +1665,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): self.assertFalse(endpoint.input) self.assertEqual(endpoint.dataformat.fullname(), "johndoe/single_integer/1") - def test_successfull_update_add_output(self): + def test_successful_update_add_output(self): declaration = """{ "language": "python", "splittable": false, @@ -1616,7 +1693,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1640,7 +1717,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): self.assertFalse(endpoint.input) self.assertEqual(endpoint.dataformat.fullname(), "johndoe/single_integer/1") - def test_successfull_update_remove_output(self): + def test_successful_update_remove_output(self): declaration1 = """{ "language": "python", "splittable": false, @@ -1695,7 +1772,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") response = self.client.put( self.url, @@ -1703,7 +1780,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1723,7 +1800,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): self.assertFalse(endpoint.input) self.assertEqual(endpoint.dataformat.fullname(), "johndoe/single_integer/1") - def test_successfull_update_change_input_to_output(self): + def test_successful_update_change_input_to_output(self): declaration = """{ "language": "python", "splittable": false, @@ -1751,7 +1828,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1771,7 +1848,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): self.assertFalse(endpoint.input) self.assertEqual(endpoint.dataformat.fullname(), "johndoe/single_integer/1") - def test_successfull_update_change_output_to_input(self): + def test_successful_update_change_output_to_input(self): declaration1 = """{ "language": "python", "splittable": false, @@ -1827,7 +1904,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") response = self.client.put( self.url, @@ -1835,7 +1912,7 @@ class AlgorithmUpdate(AlgorithmsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") algorithm = Algorithm.objects.get( author__username="jackdoe", name="personal", version=1 @@ -1909,7 +1986,7 @@ class AlgorithmBinaryUpdate(AlgorithmsAPIBase): response = self.client.post(self.url, {"binary": None}) self.checkResponse(response, 400) - def test_successfull_update(self): + def test_successful_update(self): self.login_jackdoe() response = self.client.post(self.url, {"binary": self.updated_binary_file}) @@ -1925,33 +2002,33 @@ class AlgorithmBinaryUpdate(AlgorithmsAPIBase): class AlgorithmRetrieval(AlgorithmsAPIBase): def test_no_retrieval_of_confidential_algorithm_for_anonymous_user(self): - url = reverse("api_algorithms:object", args=["johndoe", "forked_algo"]) + url = reverse("api_algorithms:object", args=["johndoe", "forked_algo", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_fail_to_retrieve_with_invalid_username(self): self.login_johndoe() - url = reverse("api_algorithms:object", args=["unknown", "forked_algo"]) + url = reverse("api_algorithms:object", args=["unknown", "forked_algo", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_fail_to_retrieve_with_invalid_algorithm_name(self): self.login_johndoe() - url = reverse("api_algorithms:object", args=["johndoe", "unknown"]) + url = reverse("api_algorithms:object", args=["johndoe", "unknown", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_no_retrieval_of_confidential_algorithm(self): self.login_jackdoe() - url = reverse("api_algorithms:object", args=["johndoe", "forked_algo"]) + url = reverse("api_algorithms:object", args=["johndoe", "forked_algo", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_successful_retrieval_of_public_algorithm_for_anonymous_user(self): - url = reverse("api_algorithms:object", args=["jackdoe", "public_for_all"]) + url = reverse("api_algorithms:object", args=["jackdoe", "public_for_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -1965,7 +2042,7 @@ class AlgorithmRetrieval(AlgorithmsAPIBase): self.assertEqual(data["code"].encode("utf-8"), AlgorithmsAPIBase.CODE) def test_successful_retrieval_of_usable_algorithm_for_anonymous_user(self): - url = reverse("api_algorithms:object", args=["jackdoe", "usable_by_all"]) + url = reverse("api_algorithms:object", args=["jackdoe", "usable_by_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -1979,7 +2056,7 @@ class AlgorithmRetrieval(AlgorithmsAPIBase): def test_successful_retrieval_of_public_algorithm(self): self.login_johndoe() - url = reverse("api_algorithms:object", args=["jackdoe", "public_for_all"]) + url = reverse("api_algorithms:object", args=["jackdoe", "public_for_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -1995,7 +2072,9 @@ class AlgorithmRetrieval(AlgorithmsAPIBase): def test_successful_retrieval_of_usable_algorithm(self): self.login_johndoe() - url = reverse("api_algorithms:object", args=["jackdoe", "usable_by_one_user"]) + url = reverse( + "api_algorithms:object", args=["jackdoe", "usable_by_one_user", 1] + ) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -2009,7 +2088,7 @@ class AlgorithmRetrieval(AlgorithmsAPIBase): def test_successful_retrieval_of_publicly_usable_algorithm(self): self.login_johndoe() - url = reverse("api_algorithms:object", args=["jackdoe", "usable_by_all"]) + url = reverse("api_algorithms:object", args=["jackdoe", "usable_by_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -2023,7 +2102,9 @@ class AlgorithmRetrieval(AlgorithmsAPIBase): def test_successful_retrieval_of_confidential_algorithm(self): self.login_johndoe() - url = reverse("api_algorithms:object", args=["jackdoe", "public_for_one_user"]) + url = reverse( + "api_algorithms:object", args=["jackdoe", "public_for_one_user", 1] + ) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -2039,7 +2120,7 @@ class AlgorithmRetrieval(AlgorithmsAPIBase): def test_successful_retrieval_of_own_public_algorithm(self): self.login_jackdoe() - url = reverse("api_algorithms:object", args=["jackdoe", "public_for_all"]) + url = reverse("api_algorithms:object", args=["jackdoe", "public_for_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -2060,7 +2141,9 @@ class AlgorithmRetrieval(AlgorithmsAPIBase): def test_successful_retrieval_of_own_confidential_algorithm(self): self.login_jackdoe() - url = reverse("api_algorithms:object", args=["jackdoe", "usable_by_one_user"]) + url = reverse( + "api_algorithms:object", args=["jackdoe", "usable_by_one_user", 1] + ) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -2081,7 +2164,7 @@ class AlgorithmRetrieval(AlgorithmsAPIBase): def test_successful_retrieval_of_own_usable_algorithm(self): self.login_jackdoe() - url = reverse("api_algorithms:object", args=["jackdoe", "usable_by_all"]) + url = reverse("api_algorithms:object", args=["jackdoe", "usable_by_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -2102,7 +2185,9 @@ class AlgorithmRetrieval(AlgorithmsAPIBase): def test_successful_retrieval_of_own_shared_algorithm(self): self.login_jackdoe() - url = reverse("api_algorithms:object", args=["jackdoe", "public_for_one_user"]) + url = reverse( + "api_algorithms:object", args=["jackdoe", "public_for_one_user", 1] + ) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -2123,7 +2208,7 @@ class AlgorithmRetrieval(AlgorithmsAPIBase): def test_successful_retrieval_of_binary_algorithm(self): self.login_jackdoe() - url = reverse("api_algorithms:object", args=["jackdoe", "binary_personal"]) + url = reverse("api_algorithms:object", args=["jackdoe", "binary_personal", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -2139,34 +2224,34 @@ class AlgorithmRetrieval(AlgorithmsAPIBase): class AlgorithmBinaryRetrieval(AlgorithmsAPIBase): def test_no_retrieval_of_confidential_algorithm_for_anonymous_user(self): - url = reverse("api_algorithms:binary", args=["jackdoe", "binary_personal"]) + url = reverse("api_algorithms:binary", args=["jackdoe", "binary_personal", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_fail_to_retrieve_with_invalid_username(self): self.login_johndoe() - url = reverse("api_algorithms:binary", args=["unknown", "binary_personal"]) + url = reverse("api_algorithms:binary", args=["unknown", "binary_personal", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_fail_to_retrieve_with_invalid_algorithm_name(self): self.login_johndoe() - url = reverse("api_algorithms:binary", args=["johndoe", "unknown"]) + url = reverse("api_algorithms:binary", args=["johndoe", "unknown", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_no_retrieval_of_confidential_algorithm(self): self.login_jackdoe() - url = reverse("api_algorithms:binary", args=["johndoe", "binary_personal"]) + url = reverse("api_algorithms:binary", args=["johndoe", "binary_personal", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_successful_retrieval_of_public_algorithm_for_anonymous_user(self): url = reverse( - "api_algorithms:binary", args=["jackdoe", "binary_public_for_all"] + "api_algorithms:binary", args=["jackdoe", "binary_public_for_all", 1] ) response = self.client.get(url) data = self.checkResponse( @@ -2176,7 +2261,7 @@ class AlgorithmBinaryRetrieval(AlgorithmsAPIBase): self.assertEqual(data, AlgorithmsAPIBase.BINARY) def test_no_retrieval_of_usable_algorithm_for_anonymous_user(self): - url = reverse("api_algorithms:binary", args=["jackdoe", "usable_by_all"]) + url = reverse("api_algorithms:binary", args=["jackdoe", "usable_by_all", 1]) response = self.client.get(url) self.checkResponse(response, 404) @@ -2184,7 +2269,7 @@ class AlgorithmBinaryRetrieval(AlgorithmsAPIBase): self.login_johndoe() url = reverse( - "api_algorithms:binary", args=["jackdoe", "binary_public_for_all"] + "api_algorithms:binary", args=["jackdoe", "binary_public_for_all", 1] ) response = self.client.get(url) data = self.checkResponse( @@ -2197,7 +2282,7 @@ class AlgorithmBinaryRetrieval(AlgorithmsAPIBase): self.login_johndoe() url = reverse( - "api_algorithms:binary", args=["jackdoe", "binary_usable_by_one_user"] + "api_algorithms:binary", args=["jackdoe", "binary_usable_by_one_user", 1] ) response = self.client.get(url) self.checkResponse(response, 404) @@ -2205,7 +2290,9 @@ class AlgorithmBinaryRetrieval(AlgorithmsAPIBase): def test_no_retrieval_of_publicly_usable_algorithm(self): self.login_johndoe() - url = reverse("api_algorithms:binary", args=["jackdoe", "binary_usable_by_all"]) + url = reverse( + "api_algorithms:binary", args=["jackdoe", "binary_usable_by_all", 1] + ) response = self.client.get(url) self.checkResponse(response, 404) @@ -2213,7 +2300,7 @@ class AlgorithmBinaryRetrieval(AlgorithmsAPIBase): self.login_johndoe() url = reverse( - "api_algorithms:binary", args=["jackdoe", "binary_public_for_one_user"] + "api_algorithms:binary", args=["jackdoe", "binary_public_for_one_user", 1] ) response = self.client.get(url) data = self.checkResponse( @@ -2226,7 +2313,7 @@ class AlgorithmBinaryRetrieval(AlgorithmsAPIBase): self.login_jackdoe() url = reverse( - "api_algorithms:binary", args=["jackdoe", "binary_public_for_all"] + "api_algorithms:binary", args=["jackdoe", "binary_public_for_all", 1] ) response = self.client.get(url) data = self.checkResponse( @@ -2238,7 +2325,7 @@ class AlgorithmBinaryRetrieval(AlgorithmsAPIBase): def test_successful_retrieval_of_own_confidential_algorithm(self): self.login_jackdoe() - url = reverse("api_algorithms:binary", args=["jackdoe", "binary_personal"]) + url = reverse("api_algorithms:binary", args=["jackdoe", "binary_personal", 1]) response = self.client.get(url) data = self.checkResponse( response, 200, content_type="application/octet-stream" @@ -2249,7 +2336,9 @@ class AlgorithmBinaryRetrieval(AlgorithmsAPIBase): def test_successful_retrieval_of_own_usable_algorithm(self): self.login_jackdoe() - url = reverse("api_algorithms:binary", args=["jackdoe", "binary_usable_by_all"]) + url = reverse( + "api_algorithms:binary", args=["jackdoe", "binary_usable_by_all", 1] + ) response = self.client.get(url) data = self.checkResponse( response, 200, content_type="application/octet-stream" @@ -2261,7 +2350,7 @@ class AlgorithmBinaryRetrieval(AlgorithmsAPIBase): self.login_jackdoe() url = reverse( - "api_algorithms:binary", args=["jackdoe", "binary_public_for_one_user"] + "api_algorithms:binary", args=["jackdoe", "binary_public_for_one_user", 1] ) response = self.client.get(url) data = self.checkResponse( diff --git a/beat/web/attestations/api_urls.py b/beat/web/attestations/api_urls.py index 7fe694b0a688f64ae93be149cf0177b41fc2e244..cdb369bd9145160afad71e03c816b39703797de9 100644 --- a/beat/web/attestations/api_urls.py +++ b/beat/web/attestations/api_urls.py @@ -26,30 +26,15 @@ ############################################################################### from django.conf.urls import url -from . import api - -urlpatterns = [ - url( - r'^$', - api.CreateAttestationView.as_view(), - name="create" - ), - url( - r'^unlock/(?P<number>\d+)/$', - api.UnlockAttestationView.as_view(), - name="unlock" - ), +from . import api - url( - r'^(?P<number>\d+)/$', - api.DeleteAttestationView.as_view(), - name="delete" - ), +urlpatterns = [ + url(r"^$", api.CreateAttestationView.as_view(), name="create"), url( - r'^(?P<username>\w+)/$', - api.ListUserAttestationView.as_view(), - name="all" + r"^unlock/(?P<number>\d+)/$", api.UnlockAttestationView.as_view(), name="unlock" ), + url(r"^(?P<number>\d+)/$", api.DeleteAttestationView.as_view(), name="delete"), + url(r"^(?P<username>\w+)/$", api.ListUserAttestationView.as_view(), name="all"), ] diff --git a/beat/web/backend/api_urls.py b/beat/web/backend/api_urls.py index cbf2f3724b7003200960eb8343534e3af9eb0f14..b96631e848e73d1412a6b13f3e283d19ab20a1bd 100755 --- a/beat/web/backend/api_urls.py +++ b/beat/web/backend/api_urls.py @@ -26,26 +26,24 @@ ############################################################################### from django.conf.urls import url + from . import api -urlpatterns = [ +urlpatterns = [ url( - r'^environments/$', + r"^environments/$", api.accessible_environments_list, - name='backend-api-environments', + name="backend-api-environments", ), - url( - r'^local_scheduler/start/$', + r"^local_scheduler/start/$", api.start_local_scheduler, - name='local_scheduler-start', + name="local_scheduler-start", ), - url( - r'^local_scheduler/stop/$', + r"^local_scheduler/stop/$", api.stop_local_scheduler, - name='local_scheduler-stop', + name="local_scheduler-stop", ), - ] diff --git a/beat/web/code/api.py b/beat/web/code/api.py index 696cd3aee67e3d01dbbbfa1b39f7acacafe083da..13ec6332210d435a3c555db39b618efa32d7daaf 100755 --- a/beat/web/code/api.py +++ b/beat/web/code/api.py @@ -25,33 +25,26 @@ # # ############################################################################### -from django.utils import six from django.shortcuts import get_object_or_404 -from django.core.exceptions import ValidationError from rest_framework import generics -from rest_framework import permissions from rest_framework.response import Response -from rest_framework.exceptions import PermissionDenied, ParseError -from rest_framework import serializers +from rest_framework import exceptions as drf_exceptions -from ..common.responses import ForbiddenResponse from ..common.api import ShareView, RetrieveUpdateDestroyContributionView -from ..common.utils import validate_restructuredtext, ensure_html from ..common.serializers import DiffSerializer from ..code.models import Code from .serializers import CodeSharingSerializer, CodeSerializer -import simplejson as json class ShareCodeView(ShareView): serializer_class = CodeSharingSerializer def do_share(self, obj, data): - users = data.get('users', None) - teams = data.get('teams', None) - public = data.get('status') == 'public' + users = data.get("users", None) + teams = data.get("teams", None) + public = data.get("status") == "public" obj.share(public=public, users=users, teams=teams) @@ -62,35 +55,36 @@ class DiffView(generics.RetrieveAPIView): def get(self, request, author1, name1, version1, author2, name2, version2): # Retrieve the objects - try: - object1 = self.model.objects.get(author__username__iexact=author1, - name__iexact=name1, - version=int(version1)) - except: - return Response('%s/%s/%s' % (author1, name1, version1), status=404) - try: - object2 = self.model.objects.get(author__username__iexact=author2, - name__iexact=name2, - version=int(version2)) - except: - return Response('%s/%s/%s' % (author2, name2, version2), status=404) + object1 = get_object_or_404( + self.model, + author__username__iexact=author1, + name__iexact=name1, + version=int(version1), + ) + object2 = get_object_or_404( + self.model, + author__username__iexact=author2, + name__iexact=name2, + version=int(version2), + ) # Check that the user can access them has_access, open_source, _ = object1.accessibility_for(request.user) - if not ((request.user == object1.author) or \ - (has_access and open_source)): - return ForbiddenResponse("You cannot access the source-code of \"%s\"" % object1.fullname()) + if not ((request.user == object1.author) or (has_access and open_source)): + raise drf_exceptions.PermissionDenied( + 'You cannot access the source-code of "%s"' % object1.fullname() + ) has_access, open_source, _ = object2.accessibility_for(request.user) - if not ((request.user == object2.author) or \ - (has_access and open_source)): - return ForbiddenResponse("You cannot access the source-code of \"%s\"" % object2.fullname()) + if not ((request.user == object2.author) or (has_access and open_source)): + raise drf_exceptions.PermissionDenied( + 'You cannot access the source-code of "%s"' % object2.fullname() + ) # Compute the diff - serializer = self.get_serializer({'object1': object1, - 'object2': object2}) + serializer = self.get_serializer({"object1": object1, "object2": object2}) return Response(serializer.data) @@ -98,125 +92,13 @@ class RetrieveUpdateDestroyCodeView(RetrieveUpdateDestroyContributionView): model = Code serializer_class = CodeSerializer - def do_update(self, request, author_name, object_name, version=None): - if version is None: - raise ValidationError({'version': 'A version number must be provided'}) - - try: - data = request.data - except ParseError as e: - raise serializers.ValidationError({'data': str(e)}) - else: - if not data: - raise serializers.ValidationError({'data': 'Empty'}) - - if 'short_description' in data: - if not(isinstance(data['short_description'], six.string_types)): - raise ValidationError({'short_description': 'Invalid short_description data'}) - short_description = data['short_description'] - else: - short_description = None - - if 'description' in data: - if not(isinstance(data['description'], six.string_types)): - raise serializers.ValidationError({'description': 'Invalid description data'}) - description = data['description'] - try: - validate_restructuredtext(description) - except ValidationError as errors: - raise serializers.ValidationError({'description': [error for error in errors]}) - else: - description = None - - if 'declaration' in data: - if isinstance(data['declaration'], dict): - json_declaration = data['declaration'] - declaration = json.dumps(json_declaration, indent=4) - elif isinstance(data['declaration'], six.string_types): - declaration = data['declaration'] - try: - json_declaration = json.loads(declaration) - except: - raise serializers.ValidationError({'declaration': 'Invalid declaration data'}) - else: - raise serializers.ValidationError({'declaration': 'Invalid declaration data'}) - - if 'description' in json_declaration: - if short_description is not None: - raise serializers.ValidationError({'short_description': 'A short description is already provided in the declaration'}) - - short_description = json_declaration['description'] - elif short_description is not None: - json_declaration['description'] = short_description - declaration = json.dumps(json_declaration, indent=4) - else: - declaration = None - json_declaration = None - - if (short_description is not None) and (len(short_description) > self.model._meta.get_field('short_description').max_length): - raise ValidationError({'short_description': 'Short description too long'}) - - if 'code' in data: - if not(isinstance(data['code'], six.string_types)): - raise ValidationError({'code': 'Invalid code data'}) - code = data['code'] - else: - code = None - - # Retrieve the object - db_object = get_object_or_404(self.model, - author__username__iexact=author_name, - name__iexact=object_name, - version=version) - - # Check that the object can still be modified (if applicable, the - # documentation can always be modified) - if ((declaration is not None) or (code is not None)) and \ - not(db_object.modifiable()): - raise PermissionDenied("The {} isn't modifiable anymore (either shared with someone else, or needed by an attestation)".format(db_object.model_name())) - - - # Modification of the documentation - if (short_description is not None) and (declaration is None): - tmp_declaration = db_object.declaration - tmp_declaration['description'] = short_description - db_object.declaration = tmp_declaration - - if description is not None: - db_object.description = description - - - # Modification of the declaration - modified = False - if declaration is not None: - db_object.declaration = declaration - modified = True - - - # Modification of the source code - if code is not None: - db_object.source_code = code - modified = True - - db_object.save() - - return modified, db_object - - def get(self, request, *args, **kwargs): - db_objects = self.get_queryset() - - if db_objects.count() == 0: - return Response(status=404) - - db_object = db_objects[0] - version = int(self.kwargs.get('version', -1)) - - if version != -1 and db_object.version != version: - return Response(status=404) + db_object = self.get_object() # Check that the user can access it - (has_access, open_source, accessibility) = db_object.accessibility_for(request.user) + (has_access, open_source, accessibility) = db_object.accessibility_for( + request.user + ) # Process the query string # Other available fields (not returned by default): @@ -226,34 +108,16 @@ class RetrieveUpdateDestroyCodeView(RetrieveUpdateDestroyContributionView): # - needed_dataformats # - attestations fields_to_remove = [] - if ((request.user != db_object.author) and not(open_source)) or db_object.is_binary(): - fields_to_remove = ['code'] + if ( + (request.user != db_object.author) and not (open_source) + ) or db_object.is_binary(): + fields_to_remove = ["code"] - fields_to_return = self.get_serializer_fields(request, allow_sharing=(request.user == db_object.author), - exclude_fields=fields_to_remove) + fields_to_return = self.get_serializer_fields( + request, + allow_sharing=(request.user == db_object.author), + exclude_fields=fields_to_remove, + ) serializer = self.get_serializer(db_object, fields=fields_to_return) return Response(serializer.data) - - - def put(self, request, author_name, object_name, version=None): - (modified, db_object) = self.do_update(request, author_name, object_name, version) - - # Available fields (not returned by default): - # - html_description - if 'fields' in request.GET: - fields_to_return = request.GET['fields'].split(',') - else: - return Response(status=204) - - result = {} - - # Retrieve the description in HTML format - if 'html_description' in fields_to_return: - description = db_object.description - if len(description) > 0: - result['html_description'] = ensure_html(description) - else: - result['html_description'] = '' - - return Response(result) diff --git a/beat/web/code/serializers.py b/beat/web/code/serializers.py index 4f25aef869692005089f669cda10869e52bcde21..1b585461f67bb48bc3f9584bbf6c06afe7ed114d 100755 --- a/beat/web/code/serializers.py +++ b/beat/web/code/serializers.py @@ -25,42 +25,78 @@ # # ############################################################################### + +import difflib + from rest_framework import serializers from ..common.serializers import ContributionCreationSerializer +from ..common.serializers import ContributionModSerializer from ..common.serializers import SharingSerializer from ..common.serializers import ContributionSerializer from ..common.serializers import DiffSerializer from .models import Code -import simplejson as json -import difflib - -#---------------------------------------------------------- +# ---------------------------------------------------------- class CodeCreationSerializer(ContributionCreationSerializer): - code = serializers.CharField(required=False, allow_blank=True, trim_whitespace=False) + code = serializers.CharField( + required=False, allow_blank=True, trim_whitespace=False + ) class Meta(ContributionCreationSerializer.Meta): - fields = ContributionCreationSerializer.Meta.fields + ['code', 'language'] + fields = ContributionCreationSerializer.Meta.fields + ["code", "language"] + + +class CodeModSerializer(ContributionModSerializer): + code = serializers.CharField( + required=False, allow_blank=True, trim_whitespace=False + ) + + class Meta(ContributionModSerializer.Meta): + fields = ContributionModSerializer.Meta.fields + ["code", "language"] + + def save(self, **kwargs): + code = self.validated_data.pop("code", None) + if code is not None: + self.validated_data["source_code"] = code + + return super().save(**kwargs) + + def filter_representation(self, representation): + def add_code(representation): + # This is inspired from the source of Django REST Framework + field = self._declared_fields["code"] + representation["code"] = field.to_representation(self.instance.source_code) + + request = self.context["request"] + fields = request.query_params.get("fields", None) + if fields is not None: + fields = fields.split(",") + if "code" in fields: + add_code(representation) + else: + add_code(representation) + + return super().filter_representation(representation) -#---------------------------------------------------------- +# ---------------------------------------------------------- class CodeSharingSerializer(SharingSerializer): status = serializers.CharField() def validate_status(self, value): - if value not in ['public', 'usable']: - raise serializers.ValidationError('Invalid status value') + if value not in ["public", "usable"]: + raise serializers.ValidationError("Invalid status value") return value -#---------------------------------------------------------- +# ---------------------------------------------------------- class CodeSerializer(ContributionSerializer): @@ -71,19 +107,23 @@ class CodeSerializer(ContributionSerializer): class Meta(ContributionSerializer.Meta): model = Code - default_fields = ContributionSerializer.Meta.default_fields + ['opensource', 'language', 'valid'] - extra_fields = ContributionSerializer.Meta.extra_fields + ['code'] - exclude = ContributionSerializer.Meta.exclude + ['source_code_file'] + default_fields = ContributionSerializer.Meta.default_fields + [ + "opensource", + "language", + "valid", + ] + extra_fields = ContributionSerializer.Meta.extra_fields + ["code"] + exclude = ContributionSerializer.Meta.exclude + ["source_code_file"] def __init__(self, *args, **kwargs): # Don't pass the 'opensource' arg up to the superclass - self.opensource = kwargs.pop('opensource', False) + self.opensource = kwargs.pop("opensource", False) # Instantiate the superclass normally super(ContributionSerializer, self).__init__(*args, **kwargs) def get_opensource(self, obj): - user = self.context.get('user') + user = self.context.get("user") (has_access, open_source, accessibility) = obj.accessibility_for(user) return open_source @@ -92,46 +132,48 @@ class CodeSerializer(ContributionSerializer): def get_accessibility(self, obj): if obj.sharing == Code.PUBLIC: - return 'public' + return "public" elif obj.sharing == Code.SHARED or obj.sharing == Code.USABLE: - return 'confidential' + return "confidential" else: - return 'private' + return "private" def get_sharing(self, obj): - user = self.context.get('user') + user = self.context.get("user") sharing = super(CodeSerializer, self).get_sharing(obj) if user == obj.author: if obj.usable_by.count() > 0: - sharing['usable_by'] = map(lambda x: x.username, - obj.usable_by.iterator()) + sharing["usable_by"] = map( + lambda x: x.username, obj.usable_by.iterator() + ) if obj.usable_by_team.count() > 0: - sharing['usable_by_team'] = map(lambda x: x.name, - obj.usable_by_team.iterator()) + sharing["usable_by_team"] = map( + lambda x: x.name, obj.usable_by_team.iterator() + ) return sharing def get_code(self, obj): - user = self.context.get('user') + user = self.context.get("user") if obj.author != user: (has_access, open_source, accessibility) = obj.accessibility_for(user) - if not(has_access) or not(open_source): + if not (has_access) or not (open_source): return None return obj.source_code -#---------------------------------------------------------- +# ---------------------------------------------------------- class CodeDiffSerializer(DiffSerializer): source_code_diff = serializers.SerializerMethodField() def get_source_code_diff(self, obj): - source1 = obj['object1'].source_code - source2 = obj['object2'].source_code + source1 = obj["object1"].source_code + source2 = obj["object2"].source_code diff = difflib.ndiff(source1.splitlines(), source2.splitlines()) - return '\n'.join(filter(lambda x: x[0] != '?', list(diff))) + return "\n".join(filter(lambda x: x[0] != "?", list(diff))) diff --git a/beat/web/common/api.py b/beat/web/common/api.py index 40e151bcb22869f6ff1d8a902b0d466248f628db..a31f8415f3da8d4074ea5a50c13eb203c153b94f 100644 --- a/beat/web/common/api.py +++ b/beat/web/common/api.py @@ -29,27 +29,33 @@ from django.shortcuts import get_object_or_404 from rest_framework import status from rest_framework import generics -from rest_framework import permissions +from rest_framework import permissions as drf_permissions +from rest_framework import exceptions as drf_exceptions from rest_framework.response import Response from rest_framework.reverse import reverse -from .responses import BadRequestResponse, ForbiddenResponse from .models import Contribution, Versionable -from .permissions import IsAuthor from .exceptions import ShareError, BaseCreationError -from .serializers import SharingSerializer, ContributionSerializer, CheckNameSerializer, DiffSerializer -from .mixins import CommonContextMixin, SerializerFieldsMixin, IsAuthorOrReadOnlyMixin - +from .serializers import ( + SharingSerializer, + ContributionSerializer, + CheckNameSerializer, + DiffSerializer, +) +from .mixins import CommonContextMixin, SerializerFieldsMixin +from .utils import py3_cmp + +from . import permissions as beat_permissions from . import is_true class CheckContributionNameView(CommonContextMixin, generics.CreateAPIView): serializer_class = CheckNameSerializer - permission_classes = [permissions.IsAuthenticated] + permission_classes = [drf_permissions.IsAuthenticated] def get_serializer_context(self): context = super(CheckContributionNameView, self).get_serializer_context() - context['model'] = self.model + context["model"] = self.model return context def post(self, request): @@ -59,28 +65,28 @@ class CheckContributionNameView(CommonContextMixin, generics.CreateAPIView): class ShareView(CommonContextMixin, generics.CreateAPIView): - permission_classes = [permissions.IsAuthenticated, IsAuthor] + permission_classes = [beat_permissions.IsAuthor] serializer_class = SharingSerializer def get_queryset(self): - author_name = self.kwargs.get('author_name') - object_name = self.kwargs.get('object_name') - version = self.kwargs.get('version') - return get_object_or_404(self.model, - author__username__iexact=author_name, - name__iexact=object_name, - version=version) + author_name = self.kwargs.get("author_name") + object_name = self.kwargs.get("object_name") + version = self.kwargs.get("version") + return get_object_or_404( + self.model, + author__username__iexact=author_name, + name__iexact=object_name, + version=version, + ) def do_share(self, obj, data): - users = data.get('users', None) - teams = data.get('teams', None) + users = data.get("users", None) + teams = data.get("teams", None) obj.share(users=users, teams=teams) def post(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) - if not serializer.is_valid(): - return BadRequestResponse(serializer.errors) - + serializer.is_valid(raise_exception=True) data = serializer.data object_db = self.get_queryset() @@ -88,48 +94,61 @@ class ShareView(CommonContextMixin, generics.CreateAPIView): try: self.do_share(object_db, data) except ShareError as e: - return BadRequestResponse(e.errors) + # ShareError happens when someone does not have access to + # Â all dependencies that should be shared. + raise drf_exceptions.PermissionDenied(e.errors) return Response(object_db.sharing_preferences()) -class ListContributionView(CommonContextMixin, SerializerFieldsMixin, generics.ListAPIView): +class ListContributionView( + CommonContextMixin, SerializerFieldsMixin, generics.ListAPIView +): model = Contribution serializer_class = ContributionSerializer - permission_classes = [permissions.AllowAny] + permission_classes = [drf_permissions.AllowAny] def get_queryset(self): return self.model.objects.for_user(self.request.user, True) def get(self, request, *args, **kwargs): fields_to_return = self.get_serializer_fields(request) - limit_to_latest_versions = is_true(request.query_params.get('latest_versions', False)) + limit_to_latest_versions = is_true( + request.query_params.get("latest_versions", False) + ) all_contributions = self.get_queryset().select_related() - if hasattr(self.model, 'author'): - all_contributions = all_contributions.order_by('author__username', 'name', '-version') + if hasattr(self.model, "author"): + all_contributions = all_contributions.order_by( + "author__username", "name", "-version" + ) else: - all_contributions = all_contributions.order_by('name', '-version') + all_contributions = all_contributions.order_by("name", "-version") if limit_to_latest_versions: all_contributions = self.model.filter_latest_versions(all_contributions) # Sort the data formats and sends the response - all_contributions.sort(lambda x, y: cmp(x.fullname(), y.fullname())) + all_contributions.sort(lambda x, y: py3_cmp(x.fullname(), y.fullname())) - serializer = self.get_serializer(all_contributions, many=True, fields=fields_to_return) + serializer = self.get_serializer( + all_contributions, many=True, fields=fields_to_return + ) return Response(serializer.data) -class ListCreateBaseView(CommonContextMixin, SerializerFieldsMixin, generics.ListCreateAPIView): - +class ListCreateBaseView( + CommonContextMixin, SerializerFieldsMixin, generics.ListCreateAPIView +): def get_serializer(self, *args, **kwargs): - if self.request.method == 'POST': + if self.request.method == "POST": self.serializer_class = self.writing_serializer_class return super(ListCreateBaseView, self).get_serializer(*args, **kwargs) def get(self, request, *args, **kwargs): fields_to_return = self.get_serializer_fields(request) - limit_to_latest_versions = is_true(request.query_params.get('latest_versions', False)) + limit_to_latest_versions = is_true( + request.query_params.get("latest_versions", False) + ) objects = self.get_queryset().select_related() @@ -141,38 +160,38 @@ class ListCreateBaseView(CommonContextMixin, SerializerFieldsMixin, generics.Lis def post(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) - if serializer.is_valid(): - try: - if hasattr(self.model, 'author'): - db_object = serializer.save(author=request.user) - else: - db_object = serializer.save() - except BaseCreationError as e: - return BadRequestResponse(e.errors) - else: - return BadRequestResponse(serializer.errors) - - html_view_args = [db_object.name, db_object.version] - if hasattr(db_object, "author"): - html_view_args.insert(0, db_object.author.username) + serializer.is_valid(raise_exception=True) + try: + if hasattr(self.model, "author"): + db_object = serializer.save(author=request.user) + else: + db_object = serializer.save() + except BaseCreationError as e: + raise drf_exceptions.APIException(e.errors) result = { - 'name': db_object.name, - 'full_name': db_object.fullname(), - 'url': reverse('{}:all'.format(self.namespace)) + db_object.fullname() + '/', - 'object_view': reverse('{}:view'.format(self.namespace.split('_')[1]), args=html_view_args), + "name": db_object.name, + "full_name": db_object.fullname(), + "url": reverse("{}:all".format(self.namespace)) + + db_object.fullname() + + "/", + "object_view": reverse( + "{}:view".format(self.namespace.split("_")[1]), + args=db_object.fullname().split("/"), + ), } response = Response(result, status=201) - response['Location'] = result['url'] + response["Location"] = result["url"] return response -class ListCreateContributionView(IsAuthorOrReadOnlyMixin, ListCreateBaseView): +class ListCreateContributionView(ListCreateBaseView): + permission_classes = [beat_permissions.IsAuthorOrReadOnly] def get_queryset(self): user = self.request.user - author_name = self.kwargs.get('author_name') + author_name = self.kwargs.get("author_name") return self.model.objects.from_author_and_public(user, author_name) @@ -182,96 +201,83 @@ class DiffView(generics.RetrieveAPIView): def get(self, request, author1, name1, version1, author2, name2, version2): # Retrieve the objects - try: - object1 = self.model.objects.get(author__username__iexact=author1, - name__iexact=name1, - version=int(version1)) - except: - return Response('%s/%s/%s' % (author1, name1, version1), status=404) - - try: - object2 = self.model.objects.get(author__username__iexact=author2, - name__iexact=name2, - version=int(version2)) - except: - return Response('%s/%s/%s' % (author2, name2, version2), status=404) - + object1 = get_object_or_404( + self.model, + author__username__iexact=author1, + name__iexact=name1, + version=int(version1), + ) + + object2 = get_object_or_404( + self.model, + author__username__iexact=author2, + name__iexact=name2, + version=int(version2), + ) # Check that the user can access them accessibility = object1.accessibility_for(request.user) if not accessibility[0]: - return ForbiddenResponse(object1.fullname()) + raise drf_exceptions.PermissionDenied(object1.fullname()) accessibility = object2.accessibility_for(request.user) if not accessibility[0]: - return ForbiddenResponse(object2.fullname()) + raise drf_exceptions.PermissionDenied(object2.fullname()) # Compute the diff - serializer = self.get_serializer({'object1': object1, - 'object2': object2}) + serializer = self.get_serializer({"object1": object1, "object2": object2}) return Response(serializer.data) -class RetrieveUpdateDestroyContributionView(CommonContextMixin, SerializerFieldsMixin, IsAuthorOrReadOnlyMixin, generics.RetrieveUpdateDestroyAPIView): +class RetrieveUpdateDestroyContributionView( + CommonContextMixin, SerializerFieldsMixin, generics.RetrieveUpdateDestroyAPIView +): model = Contribution + permission_classes = [ + beat_permissions.IsAuthorOrReadOnly, + beat_permissions.IsModifiableOrRead, + ] - def get_queryset(self): - version = self.kwargs.get('version', None) - author_name = self.kwargs.get('author_name') - object_name = self.kwargs.get('object_name') - user = self.request.user - if version is not None: - queryset = self.model.objects.for_user(user, True).filter(author__username__iexact=author_name, - name__iexact=object_name, - version__gte=version)\ - .order_by('version') - else: - queryset = self.model.objects.for_user(user, True).filter(author__username__iexact=author_name, - name__iexact=object_name).order_by('-version') + def get_serializer(self, *args, **kwargs): + if self.request.method == "PUT": + self.serializer_class = self.writing_serializer_class + return super().get_serializer(*args, **kwargs) - return queryset + def get_object(self): + version = self.kwargs["version"] + author_name = self.kwargs["author_name"] + object_name = self.kwargs["object_name"] + user = self.request.user + try: + obj = self.model.objects.for_user(user, True).get( + author__username__iexact=author_name, + name__iexact=object_name, + version=version, + ) + except self.model.DoesNotExist: + raise drf_exceptions.NotFound() + return obj def get(self, request, *args, **kwargs): - db_objects = self.get_queryset() - - if db_objects.count() == 0: - return Response(status=404) - - - db_object = db_objects[0] - version = int(self.kwargs.get('version', -1)) - - if version != -1 and db_object.version != version: - return Response(status=404) + db_object = self.get_object() + self.check_object_permissions(request, db_object) # Process the query string - allow_sharing = hasattr(db_object, 'author') and (request.user == db_object.author) + allow_sharing = request.user == db_object.author - fields_to_return = self.get_serializer_fields(request, allow_sharing=allow_sharing) + fields_to_return = self.get_serializer_fields( + request, allow_sharing=allow_sharing + ) serializer = self.get_serializer(db_object, fields=fields_to_return) return Response(serializer.data) - - def delete(self, request, *args, **kwargs): - author_name = self.kwargs.get('author_name') - object_name = self.kwargs.get('object_name') - version = self.kwargs.get('version', None) - - # Retrieve the object - if version is None: - return BadRequestResponse('A version number must be provided') - - db_object = get_object_or_404(self.model, - author__username__iexact=author_name, - name__iexact=object_name, - version=version) - + def perform_destroy(self, instance): # Check that the object can be deleted - if not(db_object.deletable()): - return ForbiddenResponse("The {} can't be deleted anymore (needed by an attestation, an algorithm or another data format)".format(db_object.model_name())) - - - # Deletion of the object - db_object.delete() - return Response(status=204) + if not (instance.deletable()): + raise drf_exceptions.MethodNotAllowed( + "The {} can't be deleted anymore (needed by an attestation, an algorithm or another data format)".format( + instance.model_name() + ) + ) + return super().perform_destroy(instance) diff --git a/beat/web/common/fields.py b/beat/web/common/fields.py index fc932fbbaebdbff1322394d4a89134f72b04c2fc..ffa5e5e7e39956d6e21999ce9afa46198e879bb2 100644 --- a/beat/web/common/fields.py +++ b/beat/web/common/fields.py @@ -25,22 +25,22 @@ # # ############################################################################### -from rest_framework import serializers -from django.utils import six +import simplejson as json -import simplejson +from rest_framework import serializers +from rest_framework.fields import JSONField as drf_JSONField class StringListField(serializers.ListField): child = serializers.CharField() -class JSONSerializerField(serializers.Field): - """ Serializer for JSONField -- required to make field writable""" +class JSONField(drf_JSONField): def to_internal_value(self, data): - if isinstance(data, six.string_types): - return simplejson.loads(data) - return data + if isinstance(data, str): + try: + data = json.loads(data) + except json.errors.JSONDecodeError: + self.fail("invalid") - def to_representation(self, value): - return value + return super().to_internal_value(data) diff --git a/beat/web/common/mixins.py b/beat/web/common/mixins.py index b02307999649c7e42c11613175add7a01ed94a50..24bfd2b6b398ed9c4a035e95ee4e9d41a8896ba4 100644 --- a/beat/web/common/mixins.py +++ b/beat/web/common/mixins.py @@ -25,8 +25,6 @@ # # ############################################################################### -from rest_framework import permissions -from .permissions import IsAuthor class CommonContextMixin(object): """ @@ -34,16 +32,17 @@ class CommonContextMixin(object): request user in the serializer context optionnaly the request object format """ + def get_serializer_context(self): context = super(CommonContextMixin, self).get_serializer_context() - context['user'] = self.request.user + context["user"] = self.request.user - object_format = self.request.GET.get('object_format', None) + object_format = self.request.GET.get("object_format", None) if object_format is not None: - if object_format not in ['json', 'string']: - object_format = 'json' + if object_format not in ["json", "string"]: + object_format = "json" - context['object_format'] = object_format + context["object_format"] = object_format return context @@ -53,52 +52,27 @@ class SerializerFieldsMixin(object): Apply this mixin to any view or viewset to get the list of fields to return """ + def get_serializer_fields(self, request, allow_sharing=False, exclude_fields=[]): # Process the query string fields = None query_params = request.query_params - if 'fields' in query_params: - fields = query_params['fields'].split(',') + if "fields" in query_params: + fields = query_params["fields"].split(",") else: fields = self.get_serializer_class().Meta.default_fields - if 'include_fields' in query_params: - include_fields = query_params['include_fields'].split(',') + if "include_fields" in query_params: + include_fields = query_params["include_fields"].split(",") fields.extend(include_fields) - if not(allow_sharing): - exclude_fields = ['sharing'] + exclude_fields + if not (allow_sharing): + exclude_fields = ["sharing"] + exclude_fields if request.user.is_anonymous(): - exclude_fields = ['is_owner'] + exclude_fields + exclude_fields = ["is_owner"] + exclude_fields fields = [field for field in fields if field not in exclude_fields] return fields - - -class IsAuthorOrReadOnlyMixin(object): - """ - Apply this mixin to any view or viewset. Allows read for - all and modification only by author - """ - def get_permissions(self): - if self.request.method == 'GET': - self.permission_classes = [permissions.AllowAny] - else: - self.permission_classes = [permissions.IsAuthenticated, IsAuthor] - return super(IsAuthorOrReadOnlyMixin, self).get_permissions() - - -class IsAdminOrReadOnlyMixin(object): - """ - Apply this mixin to any view or viewset. Allows read for - all and modification only by admin - """ - def get_permissions(self): - if self.request.method == 'GET': - self.permission_classes = [permissions.AllowAny] - else: - self.permission_classes = [permissions.IsAuthenticated, permissions.IsAdminUser] - return super(IsAdminOrReadOnlyMixin, self).get_permissions() diff --git a/beat/web/common/permissions.py b/beat/web/common/permissions.py index 05081b8e8558a11255e2575040a5f623c1ee17f0..9ca392fe28632acf8bd6c7ee2f9ed9af662e125c 100644 --- a/beat/web/common/permissions.py +++ b/beat/web/common/permissions.py @@ -27,21 +27,64 @@ from rest_framework import permissions + class IsSuperuser(permissions.BasePermission): """ Global permission check for super user """ + def has_permission(self, request, view): return request.user.is_superuser -class IsAuthor(permissions.BasePermission): +class IsAuthor(permissions.IsAuthenticated): """ Global permission check that verify if the user is also the onwer of the asked data """ def has_permission(self, request, view): - kwargs = request.parser_context.get('kwargs') - author_name = kwargs.get('author_name') - return request.user.username == author_name + allowed = super().has_permission(request, view) + if allowed: + kwargs = request.parser_context.get("kwargs") + author_name = kwargs.get("author_name") + allowed = request.user.username == author_name + return allowed + + +class IsAuthorOrReadOnly(IsAuthor): + """ + Either allow access if using a read method or + check that the user is also the author. + """ + + def has_permission(self, request, view): + if request.method in permissions.SAFE_METHODS: + return True + else: + return super().has_permission(request, view) + + +class IsAdminOrReadOnly(permissions.IsAdminUser): + """ + Either allow access if using a read method or + check that the user is an admin. + """ + + def has_permission(self, request, view): + if request.method in permissions.SAFE_METHODS: + return True + else: + return super().has_permission(request, view) + + +class IsModifiableOrRead(permissions.BasePermission): + """ + Check for modifiable flag if there's a modification that is tried + """ + + def has_object_permission(self, request, view, obj): + if request.method in permissions.SAFE_METHODS: + return True + else: + return obj.modifiable() diff --git a/beat/web/common/serializers.py b/beat/web/common/serializers.py index 78b547c47b1baa70568e28a6bea02dc8d4f7e5ab..571ae87ba4874ea7db8770526a90944aa719cd82 100644 --- a/beat/web/common/serializers.py +++ b/beat/web/common/serializers.py @@ -25,24 +25,28 @@ # # ############################################################################### +import copy +import simplejson as json +import difflib + +from django.conf import settings from django.contrib.auth.models import User from django.utils import six -from django.db.models import CharField, Value as V -from django.db.models.functions import Concat - from rest_framework import serializers +from rest_framework import exceptions as drf_exceptions from ..team.models import Team from ..common.utils import ensure_html +from ..common.utils import annotate_full_name +from ..common.utils import validate_restructuredtext +from ..common import fields as beat_fields from .models import Shareable, Versionable, Contribution from .exceptions import ContributionCreationError -from .fields import JSONSerializerField, StringListField -import simplejson as json -import difflib +from . import fields as serializer_fields # ---------------------------------------------------------- @@ -83,8 +87,8 @@ class CheckNameSerializer(serializers.Serializer): class SharingSerializer(serializers.Serializer): - users = StringListField(required=False) - teams = StringListField(required=False) + users = serializer_fields.StringListField(required=False) + teams = serializer_fields.StringListField(required=False) def validate_users(self, users): user_accounts = User.objects.filter(username__in=users).values_list( @@ -297,24 +301,80 @@ class ContributionSerializer(VersionableSerializer): # ---------------------------------------------------------- -class ContributionCreationSerializer(serializers.ModelSerializer): - declaration = JSONSerializerField(required=False) +class ContributionModSerializer(serializers.ModelSerializer): + declaration = beat_fields.JSONField(required=False) description = serializers.CharField(required=False, allow_blank=True) - fork_of = serializers.JSONField(required=False) + + class Meta: + fields = ["short_description", "description", "declaration"] + beat_core_class = None + + def validate_description(self, description): + if description.find("\\") >= 0: # was escaped, unescape + description = description.decode("string_escape") + validate_restructuredtext(description) + return description + + def validate_declaration(self, declaration): + decl = copy.deepcopy(declaration) + obj = self.Meta.beat_core_class(prefix=settings.PREFIX, data=decl) + if not obj.valid: + raise drf_exceptions.ValidationError(obj.errors) + return declaration + + def update(self, instance, validated_data): + declaration = validated_data.get("declaration") + + if declaration is not None and not instance.modifiable(): + raise drf_exceptions.PermissionDenied( + "The {} isn't modifiable anymore (either shared with someone else, or needed by an attestation)".format( + self.Meta.model.__name__.lower() + ) + ) + + return super().update(instance, validated_data) + + def filter_representation(self, representation): + """Filter out fields if given in query parameters""" + + request = self.context["request"] + fields = request.query_params.get("fields", None) + if fields is not None: + fields = fields.split(",") + to_remove = [key for key in representation.keys() if key not in fields] + for key in to_remove: + representation.pop(key) + + # Retrieve the description in HTML format + if "html_description" in fields: + description = self.instance.description + if len(description) > 0: + representation["html_description"] = ensure_html(description) + else: + representation["html_description"] = "" + + return representation + + def to_representation(self, instance): + representation = super().to_representation(instance) + return self.filter_representation(representation) + + +# ---------------------------------------------------------- + + +class ContributionCreationSerializer(ContributionModSerializer): + fork_of = serializers.CharField(required=False) previous_version = serializers.CharField(required=False) version = serializers.IntegerField(min_value=1) - class Meta: - fields = [ + class Meta(ContributionModSerializer.Meta): + fields = ContributionModSerializer.Meta.fields + [ "name", - "short_description", - "description", - "declaration", "previous_version", "fork_of", "version", ] - beat_core_class = None def validate_fork_of(self, fork_of): if "previous_version" in self.initial_data: @@ -330,25 +390,23 @@ class ContributionCreationSerializer(serializers.ModelSerializer): ) return previous_version - def validate_description(self, description): - if description.find("\\") >= 0: # was escaped, unescape - description = description.decode("string_escape") - return description - - def validate(self, data): - user = self.context.get("user") - name = self.Meta.model.sanitize_name(data["name"]) - data["name"] = name - version = data.get("version") - + def validate_version(self, version): # If version is not one then it's necessarily a new version # forks start at one - if version > 1 and "previous_version" not in data: + if version > 1 and "previous_version" not in self.initial_data: + name = self.initial_data["name"] raise serializers.ValidationError( "{} {} version {} incomplete history data posted".format( self.Meta.model.__name__.lower(), name, version ) ) + return version + + def validate(self, data): + user = self.context.get("user") + name = self.Meta.model.sanitize_name(data["name"]) + data["name"] = name + version = data.get("version") if self.Meta.model.objects.filter( author=user, name=name, version=version @@ -364,16 +422,9 @@ class ContributionCreationSerializer(serializers.ModelSerializer): if previous_version is not None: try: - previous_object = self.Meta.model.objects.annotate( - fullname=Concat( - "author__username", - V("/"), - "name", - V("/"), - "version", - output_field=CharField(), - ) - ).get(fullname=previous_version) + previous_object = annotate_full_name(self.Meta.model.objects).get( + full_name=previous_version + ) except self.Meta.model.DoesNotExist: raise serializers.ValidationError( "{} '{}' not found".format( @@ -398,16 +449,9 @@ class ContributionCreationSerializer(serializers.ModelSerializer): raise serializers.ValidationError("A fork starts at 1") try: - forked_of_object = self.Meta.model.objects.annotate( - fullname=Concat( - "author__username", - V("/"), - "name", - V("/"), - "version", - output_field=CharField(), - ) - ).get(fullname=fork_of) + forked_of_object = annotate_full_name(self.Meta.model.objects).get( + full_name=fork_of + ) except self.Meta.model.DoesNotExist: raise serializers.ValidationError( "{} '{}' fork origin not found".format( diff --git a/beat/web/common/utils.py b/beat/web/common/utils.py index a2c5886e3123c01d0eb73e0d3d7e8e8531e85245..785a93f4699263e77727418e6573badc1752f5cb 100644 --- a/beat/web/common/utils.py +++ b/beat/web/common/utils.py @@ -31,22 +31,22 @@ Reusable help functions from django.core.exceptions import ValidationError from django.utils.encoding import force_text from django.utils import six +from django.db.models import CharField, Value as V +from django.db.models.functions import Concat + +from docutils import utils +from docutils.nodes import Element +from docutils.core import Publisher from ..ui.templatetags.markup import restructuredtext + def validate_restructuredtext(value): """Validates a piece of restructuredtext for strict conformance""" - try: - from docutils import utils - from docutils.nodes import Element - from docutils.core import Publisher - except ImportError: - raise forms.ValidationError("Error in 'reStructuredText' validator: The Python docutils package isn't installed.") - # Generate a new parser (copying `rst2html.py` flow) pub = Publisher(None, None, None, settings=None) - pub.set_components('standalone', 'restructuredtext', 'pseudoxml') + pub.set_components("standalone", "restructuredtext", "pseudoxml") # Configure publisher settings = pub.get_settings(halt_level=5) @@ -63,12 +63,13 @@ def validate_restructuredtext(value): # Collect errors via an observer errors = [] + def error_collector(data): # Mutate the data since it was just generated - data.line = data['line'] - data.source = data['source'] - data.level = data['level'] - data.type = data['type'] + data.line = data["line"] + data.source = data["source"] + data.level = data["level"] + data.type = data["type"] data.message = Element.astext(data.children[0]) data.full_message = Element.astext(data) @@ -82,8 +83,8 @@ def validate_restructuredtext(value): # Apply transforms (and more collect errors) document.transformer.populate_from_components( - (pub.source, pub.reader, pub.reader.parser, pub.writer, - pub.destination)) + (pub.source, pub.reader, pub.reader.parser, pub.writer, pub.destination) + ) transformer = document.transformer while transformer.transforms: if not transformer.sorted: @@ -99,25 +100,33 @@ def validate_restructuredtext(value): # errors should be ready if errors: validation_list = [] - msg = 'Line %(line)d (severity %(level)d): %(message)s' + msg = "Line %(line)d (severity %(level)d): %(message)s" for error in errors: - #from docutils.parsers.rst module: - #debug(0), info(1), warning(2), error(3), severe(4) + # from docutils.parsers.rst module: + # debug(0), info(1), warning(2), error(3), severe(4) if error.level > 1: - validation_list.append(ValidationError( - msg, code=error.type.lower(), - params=dict(line=error.line, level=error.level, message=error.message))) + validation_list.append( + ValidationError( + msg, + code=error.type.lower(), + params=dict( + line=error.line, level=error.level, message=error.message + ), + ) + ) if validation_list: raise ValidationError(validation_list) + def ensure_html(text): try: validate_restructuredtext(text) - except ValidationError as e: - return '<pre>{}</pre>'.format(text) + except ValidationError: + return "<pre>{}</pre>".format(text) else: return restructuredtext(text) + def ensure_string(data): """ Ensure that we have a str object from data which can be either a str in @@ -125,6 +134,34 @@ def ensure_string(data): """ if data is not None: if isinstance(data, six.binary_type): - return data.decode('utf-8') + return data.decode("utf-8") return data - return '' \ No newline at end of file + return "" + + +def py3_cmp(a, b): + """ + cmp is not available anymore in Python 3. This method is implemetend + as recommended in the documentation for this kind of use case. + Based on: + https://docs.python.org/3.0/whatsnew/3.0.html#ordering-comparisons + """ + return (a > b) - (a < b) + + +def annotate_full_name(query): + """ + Annotate a query with the asset full name so that it can be more easily + filtered. + """ + + return query.annotate( + full_name=Concat( + "author__username", + V("/"), + "name", + V("/"), + "version", + output_field=CharField(), + ) + ) diff --git a/beat/web/databases/api.py b/beat/web/databases/api.py index 209747d8d32c6b3592da8f8992efe5c836b31a2a..b076cb2ec4d35daf45c3dd1722907f2c4ff810f3 100755 --- a/beat/web/databases/api.py +++ b/beat/web/databases/api.py @@ -27,129 +27,140 @@ import os import json - -from django.http import HttpResponse -from django.core.urlresolvers import reverse +import logging from rest_framework.response import Response -from rest_framework import permissions +from rest_framework import permissions as drf_permissions +from rest_framework import exceptions as drf_exceptions from rest_framework import views -from rest_framework import status -from rest_framework import generics - from .models import Database from .models import DatabaseSetTemplate from .serializers import DatabaseSerializer, DatabaseCreationSerializer -from .exceptions import DatabaseCreationError from ..common import is_true -from ..common.mixins import IsAdminOrReadOnlyMixin +from ..common import permissions as beat_permissions from ..common.api import ListCreateBaseView -from ..common.responses import BadRequestResponse from ..common.utils import ensure_html + from ..dataformats.serializers import ReferencedDataFormatSerializer -import logging -import traceback logger = logging.getLogger(__name__) -#---------------------------------------------------------- +# ---------------------------------------------------------- + -def database_to_json(database, request_user, fields_to_return, - last_version=None): +def database_to_json(database, request_user, fields_to_return): # Prepare the response result = {} - if 'name' in fields_to_return: - result['name'] = database.fullname() + if "name" in fields_to_return: + result["name"] = database.fullname() + + if "version" in fields_to_return: + result["version"] = database.version - if 'version' in fields_to_return: - result['version'] = database.version + if "last_version" in fields_to_return: + latest = ( + Database.objects.for_user(request_user, True) + .filter(name=database.name) + .order_by("-version")[:1] + .first() + ) - if 'last_version' in fields_to_return: - result['last_version'] = last_version + result["last_version"] = database.version == latest.version - if 'short_description' in fields_to_return: - result['short_description'] = database.short_description + if "short_description" in fields_to_return: + result["short_description"] = database.short_description - if 'description' in fields_to_return: - result['description'] = database.description + if "description" in fields_to_return: + result["description"] = database.description - if 'previous_version' in fields_to_return: - result['previous_version'] = \ - (database.previous_version.fullname() if \ - database.previous_version is not None else None) + if "previous_version" in fields_to_return: + result["previous_version"] = ( + database.previous_version.fullname() + if database.previous_version is not None + else None + ) - if 'creation_date' in fields_to_return: - result['creation_date'] = database.creation_date.isoformat(' ') + if "creation_date" in fields_to_return: + result["creation_date"] = database.creation_date.isoformat(" ") - if 'hash' in fields_to_return: - result['hash'] = database.hash + if "hash" in fields_to_return: + result["hash"] = database.hash - if 'accessibility' in fields_to_return: - result['accessibility'] = database.accessibility_for(request_user) + if "accessibility" in fields_to_return: + result["accessibility"] = database.accessibility_for(request_user) return result def clean_paths(declaration): - pseudo_path = '/path_to_db_folder' - root_folder = declaration['root_folder'] + pseudo_path = "/path_to_db_folder" + root_folder = declaration["root_folder"] cleaned_folder = os.path.basename(os.path.normpath(root_folder)) - declaration['root_folder'] = os.path.join(pseudo_path, cleaned_folder) - for protocol in declaration['protocols']: - for set_ in protocol['sets']: - if 'parameters' in set_ and 'annotations' in set_['parameters']: - annotations_folder = set_['parameters']['annotations'] - cleaned_folder = annotations_folder.split('/')[-2:] - set_['parameters']['annotations'] = os.path.join(pseudo_path, *cleaned_folder) + declaration["root_folder"] = os.path.join(pseudo_path, cleaned_folder) + for protocol in declaration["protocols"]: + for set_ in protocol["sets"]: + if "parameters" in set_ and "annotations" in set_["parameters"]: + annotations_folder = set_["parameters"]["annotations"] + cleaned_folder = annotations_folder.split("/")[-2:] + set_["parameters"]["annotations"] = os.path.join( + pseudo_path, *cleaned_folder + ) return declaration -#---------------------------------------------------------- +# ---------------------------------------------------------- -class ListCreateDatabasesView(IsAdminOrReadOnlyMixin, ListCreateBaseView): + +class ListCreateDatabasesView(ListCreateBaseView): """ Read/Write end point that list the database available to a user and allows the creation of new databases only to platform administrator """ + model = Database + permission_classes = [beat_permissions.IsAdminOrReadOnly] serializer_class = DatabaseSerializer writing_serializer_class = DatabaseCreationSerializer - namespace = 'api_databases' + namespace = "api_databases" def get_queryset(self): user = self.request.user return self.model.objects.for_user(user, True) - def get(self, request, *args, **kwargs): fields_to_return = self.get_serializer_fields(request) - limit_to_latest_versions = is_true(request.query_params.get('latest_versions', False)) + limit_to_latest_versions = is_true( + request.query_params.get("latest_versions", False) + ) - all_databases = self.get_queryset().order_by('name') + all_databases = self.get_queryset().order_by("name") if limit_to_latest_versions: all_databases = self.model.filter_latest_versions(all_databases) all_databases.sort(key=lambda x: x.fullname()) - serializer = self.get_serializer(all_databases, many=True, fields=fields_to_return) + serializer = self.get_serializer( + all_databases, many=True, fields=fields_to_return + ) return Response(serializer.data) -#---------------------------------------------------------- +# ---------------------------------------------------------- class ListTemplatesView(views.APIView): """ List all templates available """ - permission_classes = [permissions.AllowAny] + + permission_classes = [drf_permissions.AllowAny] def get(self, request): result = {} @@ -158,39 +169,39 @@ class ListTemplatesView(views.APIView): databases = Database.objects.for_user(request.user, True) databases = Database.filter_latest_versions(databases) - for set_template in DatabaseSetTemplate.objects.filter(sets__protocol__database__in=databases).distinct().order_by('name'): - (db_template, dataset) = set_template.name.split('__') + for set_template in ( + DatabaseSetTemplate.objects.filter(sets__protocol__database__in=databases) + .distinct() + .order_by("name") + ): + (db_template, dataset) = set_template.name.split("__") if db_template not in result: - result[db_template] = { - 'templates': {}, - 'sets': [], - } + result[db_template] = {"templates": {}, "sets": []} - result[db_template]['templates'][dataset] = map(lambda x: x.name, set_template.outputs.order_by('name')) + result[db_template]["templates"][dataset] = map( + lambda x: x.name, set_template.outputs.order_by("name") + ) known_sets = [] for db_set in set_template.sets.iterator(): if db_set.name not in known_sets: - result[db_template]['sets'].append({ - 'name': db_set.name, - 'template': dataset, - 'id': db_set.id, - }) + result[db_template]["sets"].append( + {"name": db_set.name, "template": dataset, "id": db_set.id} + ) known_sets.append(db_set.name) for name, entry in result.items(): - entry['sets'].sort(key=lambda x: x['id']) - result[name]['sets'] = map(lambda x: { - 'name': x['name'], - 'template': x['template'], - }, entry['sets']) + entry["sets"].sort(key=lambda x: x["id"]) + result[name]["sets"] = map( + lambda x: {"name": x["name"], "template": x["template"]}, entry["sets"] + ) return Response(result) -#---------------------------------------------------------- +# ---------------------------------------------------------- class RetrieveDatabaseView(views.APIView): @@ -198,102 +209,92 @@ class RetrieveDatabaseView(views.APIView): Returns the given database details """ - permission_classes = [permissions.AllowAny] + model = Database + permission_classes = [drf_permissions.AllowAny] - def get(self, request, database_name, version=None): - # Retrieve the database + def get_object(self): + version = self.kwargs["version"] + database_name = self.kwargs["database_name"] + user = self.request.user try: - if version is not None: - version = int(version) - - databases = Database.objects.for_user(request.user, True).filter(name__iexact=database_name, - version__gte=version).order_by('version') - - database = databases[0] - - if database.version != version: - return HttpResponse(status=status.HTTP_404_NOT_FOUND) - - last_version = (len(databases) == 1) - else: - database = Database.objects.for_user(request.user, True).filter( - name__iexact=database_name).order_by('-version')[0] - last_version = True - except: - return HttpResponse(status=status.HTTP_404_NOT_FOUND) - + obj = self.model.objects.for_user(user, True).get( + name__iexact=database_name, version=version + ) + except self.model.DoesNotExist: + raise drf_exceptions.NotFound() + return obj + + def get(self, request, database_name, version): + # Retrieve the database + database = self.get_object() + self.check_object_permissions(request, database) # Process the query string - if 'fields' in request.GET: - fields_to_return = request.GET['fields'].split(',') + if "fields" in request.GET: + fields_to_return = request.GET["fields"].split(",") else: - fields_to_return = [ 'name', 'version', 'last_version', - 'short_description', 'description', 'fork_of', - 'previous_version', 'is_owner', 'accessibility', 'sharing', - 'opensource', 'hash', 'creation_date', - 'declaration', 'code' ] - - - try: - # Prepare the response - result = database_to_json(database, request.user, fields_to_return, - last_version=last_version) - - # Retrieve the code - if 'declaration' in fields_to_return: - try: - declaration = database.declaration - except: - logger.error(traceback.format_exc()) - return HttpResponse(status=500) - cleaned_declaration = clean_paths(declaration) - result['declaration'] = json.dumps(cleaned_declaration) - - # Retrieve the source code - if 'code' in fields_to_return: - try: - result['code'] = database.source_code - except: - logger.error(traceback.format_exc()) - return HttpResponse(status=500) - - - # Retrieve the description in HTML format - if 'html_description' in fields_to_return: - description = database.description - if len(description) > 0: - result['html_description'] = ensure_html(description) - else: - result['html_description'] = '' - - - # Retrieve the referenced data formats - if 'referenced_dataformats' in fields_to_return: - dataformats = database.all_referenced_dataformats() - - referenced_dataformats = [] - for dataformat in dataformats: - (has_access, accessibility) = dataformat.accessibility_for(request.user) - if has_access: - referenced_dataformats.append(dataformat) - serializer = ReferencedDataFormatSerializer(referenced_dataformats) - result['referenced_dataformats'] = serializer.data - - - # Retrieve the needed data formats - if 'needed_dataformats' in fields_to_return: - dataformats = database.all_needed_dataformats() - - needed_dataformats = [] - for dataformat in dataformats: - (has_access, accessibility) = dataformat.accessibility_for(request.user) - if has_access: - needed_dataformats.append(dataformat) - serializer = ReferencedDataFormatSerializer(needed_dataformats) - result['needed_dataformats'] = serializer.data - - # Return the result - return Response(result) - except: - logger.error(traceback.format_exc()) - return HttpResponse(status=500) + fields_to_return = [ + "name", + "version", + "last_version", + "short_description", + "description", + "fork_of", + "previous_version", + "is_owner", + "accessibility", + "sharing", + "opensource", + "hash", + "creation_date", + "declaration", + "code", + ] + + # Prepare the response + result = database_to_json(database, request.user, fields_to_return) + + # Retrieve the code + if "declaration" in fields_to_return: + declaration = database.declaration + cleaned_declaration = clean_paths(declaration) + result["declaration"] = json.dumps(cleaned_declaration) + + # Retrieve the source code + if "code" in fields_to_return: + result["code"] = database.source_code + + # Retrieve the description in HTML format + if "html_description" in fields_to_return: + description = database.description + if len(description) > 0: + result["html_description"] = ensure_html(description) + else: + result["html_description"] = "" + + # Retrieve the referenced data formats + if "referenced_dataformats" in fields_to_return: + dataformats = database.all_referenced_dataformats() + + referenced_dataformats = [] + for dataformat in dataformats: + (has_access, accessibility) = dataformat.accessibility_for(request.user) + if has_access: + referenced_dataformats.append(dataformat) + serializer = ReferencedDataFormatSerializer(referenced_dataformats) + result["referenced_dataformats"] = serializer.data + + # Retrieve the needed data formats + if "needed_dataformats" in fields_to_return: + dataformats = database.all_needed_dataformats() + + needed_dataformats = [] + for dataformat in dataformats: + (has_access, accessibility) = dataformat.accessibility_for(request.user) + if has_access: + needed_dataformats.append(dataformat) + serializer = ReferencedDataFormatSerializer(needed_dataformats) + result["needed_dataformats"] = serializer.data + + # Return the result + return Response(result) diff --git a/beat/web/databases/api_urls.py b/beat/web/databases/api_urls.py index ef0737b5df3bfff455b2a49d7b93f152ab875b3a..b84db394a1ed73c278e519034148e9be8f8259d8 100644 --- a/beat/web/databases/api_urls.py +++ b/beat/web/databases/api_urls.py @@ -29,9 +29,13 @@ from django.conf.urls import url from . import api + urlpatterns = [ - url(r'^$', api.ListCreateDatabasesView.as_view(), name='all'), - url(r'^templates/$', api.ListTemplatesView.as_view(), name='templates'), - url(r'^(?P<database_name>[-\w]+)/(?P<version>\d+)/$', api.RetrieveDatabaseView.as_view(), name='object'), - url(r'^(?P<database_name>[-\w]+)/$', api.RetrieveDatabaseView.as_view(), name='object'), + url(r"^$", api.ListCreateDatabasesView.as_view(), name="all"), + url(r"^templates/$", api.ListTemplatesView.as_view(), name="templates"), + url( + r"^(?P<database_name>[-\w]+)/(?P<version>\d+)/$", + api.RetrieveDatabaseView.as_view(), + name="object", + ), ] diff --git a/beat/web/databases/serializers.py b/beat/web/databases/serializers.py index fe0d2aacd7b8d55e5d5f313c830ac0d58e1721f5..56d35a078887d1bb3633ce1830f33aaafa062f7a 100644 --- a/beat/web/databases/serializers.py +++ b/beat/web/databases/serializers.py @@ -30,55 +30,57 @@ from django.conf import settings from rest_framework import serializers from ..common.serializers import VersionableSerializer -from ..common.fields import JSONSerializerField +from ..common import fields as beat_fields from .models import Database from .exceptions import DatabaseCreationError import beat.core.database + class DatabaseSerializer(VersionableSerializer): protocols = serializers.SerializerMethodField() description = serializers.SerializerMethodField() - declaration = JSONSerializerField() + declaration = beat_fields.JSONField() class Meta(VersionableSerializer.Meta): model = Database - exclude = ['description_file', 'declaration_file', 'source_code_file'] + exclude = ["description_file", "declaration_file", "source_code_file"] def __init__(self, *args, **kwargs): super(DatabaseSerializer, self).__init__(*args, **kwargs) self.experiments_stats = {} def get_description(self, obj): - return obj.description.decode('utf8') + return obj.description.decode("utf8") def get_protocols(self, obj): - protocols_options = self.context['request'].query_params.get('protocols_options', []) - include_datasets = 'datasets' in protocols_options - include_outputs = 'outputs' in protocols_options + protocols_options = self.context["request"].query_params.get( + "protocols_options", [] + ) + include_datasets = "datasets" in protocols_options + include_outputs = "outputs" in protocols_options entries = [] for protocol in obj.protocols.all(): - protocol_entry = { - 'name': protocol.name, - } + protocol_entry = {"name": protocol.name} if include_datasets: - protocol_entry['datasets'] = [] + protocol_entry["datasets"] = [] - for dataset in protocol.sets.order_by('name'): - dataset_entry = { - 'name': dataset.name, - } + for dataset in protocol.sets.order_by("name"): + dataset_entry = {"name": dataset.name} if include_outputs: - dataset_entry['outputs'] = map(lambda x: { - 'name': x.name, - 'dataformat': x.dataformat.fullname(), - }, dataset.template.outputs.order_by('name')) + dataset_entry["outputs"] = map( + lambda x: { + "name": x.name, + "dataformat": x.dataformat.fullname(), + }, + dataset.template.outputs.order_by("name"), + ) - protocol_entry['datasets'].append(dataset_entry) + protocol_entry["datasets"].append(dataset_entry) entries.append(protocol_entry) return entries @@ -86,58 +88,76 @@ class DatabaseSerializer(VersionableSerializer): class DatabaseCreationSerializer(serializers.ModelSerializer): code = serializers.CharField(required=False) - declaration = JSONSerializerField(required=False) + declaration = beat_fields.JSONField(required=False) description = serializers.CharField(required=False, allow_blank=True) previous_version = serializers.CharField(required=False) + version = serializers.IntegerField(min_value=1) class Meta: model = Database - fields = ['name', 'short_description', 'description', - 'declaration', 'code', 'previous_version'] - beat_core_class = beat.core.database + fields = [ + "name", + "short_description", + "description", + "declaration", + "code", + "previous_version", + "version", + ] + beat_core_class = beat.core.database.Database + + def validate_declaration(self, declaration): + obj = self.Meta.beat_core_class(prefix=settings.PREFIX, data=declaration) + if not obj.valid: + raise serializers.ValidationError(obj.errors) + return declaration def validate(self, data): - user = self.context.get('user') - name = self.Meta.model.sanitize_name(data['name']) - data['name'] = name - - if 'previous_version' in data: - previous_version_id = self.Meta.beat_core_class.Storage(settings.PREFIX, - data['previous_version']) - else: - previous_version_id = None - - # Retrieve the previous version (if applicable) - if previous_version_id is not None: - try: - previous_version = self.Meta.model.objects.get( - name=previous_version_id.name, - version=previous_version_id.version) - except: - raise serializers.ValidationError("Database '%s' not found" % \ - previous_version_id.fullname) - - is_accessible = previous_version.accessibility_for(user) - if not is_accessible[0]: - raise serializers.ValidationError('No access allowed') - data['previous_version'] = previous_version + user = self.context.get("user") + name = self.Meta.model.sanitize_name(data["name"]) + data["name"] = name + version = data["version"] - # Determine the version number - last_version = None + if self.Meta.model.objects.filter(name=name, version=version).exists(): + raise serializers.ValidationError( + "{} {} version {} already exists".format( + self.Meta.model.__name__.lower(), name, version + ) + ) - if previous_version_id is not None: - if (previous_version_id.name == name): - last_version = self.Meta.model.objects.filter(name=name).order_by('-version')[0] + previous_version = data.get("previous_version") - if last_version is None: - if self.Meta.model.objects.filter(name=name).count() > 0: - raise serializers.ValidationError('This {} name already exists'.format(self.Meta.model.__name__.lower())) + if previous_version is not None: + try: + previous_object = self.Meta.model.objects.get( + name=name, version=previous_version + ) + except self.Meta.model.DoesNotExist: + raise serializers.ValidationError( + { + "previous_version", + "Database '{}' version '{}' not found".format( + name, previous_version + ), + } + ) - data['version'] = (last_version.version + 1 if last_version is not None else 1) + is_accessible = previous_object.accessibility_for(user) + if not is_accessible[0]: + raise serializers.ValidationError("No access allowed") + data["previous_version"] = previous_object + + if version - previous_object.version != 1: + raise serializers.ValidationError( + "The requested version ({}) for this {} does not match" + "the standard increment with {}".format( + version, self.Meta.model.__name__, previous_object.version + ) + ) + data["previous_version"] = previous_object return data - def create(self, validated_data): (db_object, errors) = self.Meta.model.objects.create_database(**validated_data) if errors: diff --git a/beat/web/databases/tests.py b/beat/web/databases/tests.py index cd081b8ccd50221c519f97c301f292dc59a8d0d5..25f13a4929411b5a92b72cf48c0b83f0f36c603f 100644 --- a/beat/web/databases/tests.py +++ b/beat/web/databases/tests.py @@ -98,7 +98,9 @@ class DatabaseCreationAPI(DatabaseAPIBase): self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD) response = self.client.post( self.url, - json.dumps({"name": self.db_name, "declaration": self.DATABASE}), + json.dumps( + {"name": self.db_name, "version": 1, "declaration": self.DATABASE} + ), content_type="application/json", ) @@ -116,7 +118,9 @@ class DatabaseCreationAPI(DatabaseAPIBase): response = self.client.post( self.url, - json.dumps({"name": self.db_name, "declaration": self.DATABASE}), + json.dumps( + {"name": self.db_name, "version": 1, "declaration": self.DATABASE} + ), content_type="application/json", ) @@ -127,6 +131,7 @@ class DatabaseCreationAPI(DatabaseAPIBase): databases = Database.objects.all() self.assertEqual(databases.count(), 1) databases.delete() + dataformat.delete() class DatabaseRetrievalAPI(DatabaseAPIBase): diff --git a/beat/web/dataformats/api.py b/beat/web/dataformats/api.py index 3f84f4d4afc3548d5a0dca2833a76cf1e4f196f8..a1a8f8ed4ecde111dd6cfdf164ef37685dcf68b7 100644 --- a/beat/web/dataformats/api.py +++ b/beat/web/dataformats/api.py @@ -25,29 +25,24 @@ # # ############################################################################### -from django.shortcuts import get_object_or_404 -from django.utils import six -from django.core.exceptions import ValidationError - -from rest_framework.response import Response -from rest_framework.exceptions import ParseError -from rest_framework import serializers from .models import DataFormat from .serializers import DataFormatSerializer from .serializers import FullDataFormatSerializer from .serializers import DataFormatCreationSerializer +from .serializers import DataFormatModSerializer -from ..common.responses import BadRequestResponse, ForbiddenResponse -from ..common.api import (CheckContributionNameView, ShareView, ListContributionView, - ListCreateContributionView, DiffView, RetrieveUpdateDestroyContributionView) - -from ..common.utils import validate_restructuredtext, ensure_html +from ..common.api import ( + CheckContributionNameView, + ShareView, + ListContributionView, + ListCreateContributionView, + DiffView, + RetrieveUpdateDestroyContributionView, +) -import simplejson as json - -#---------------------------------------------------------- +# ---------------------------------------------------------- class CheckDataFormatNameView(CheckContributionNameView): @@ -55,10 +50,11 @@ class CheckDataFormatNameView(CheckContributionNameView): This view sanitizes a data format name and checks whether it is already used. """ + model = DataFormat -#---------------------------------------------------------- +# ---------------------------------------------------------- class ShareDataFormatView(ShareView): @@ -66,166 +62,56 @@ class ShareDataFormatView(ShareView): This view allows to share a data format with other users and/or teams """ + model = DataFormat -#---------------------------------------------------------- +# ---------------------------------------------------------- class ListDataFormatView(ListContributionView): """ List all available data formats """ + model = DataFormat serializer_class = DataFormatSerializer -#---------------------------------------------------------- +# ---------------------------------------------------------- + class ListCreateDataFormatsView(ListCreateContributionView): """ Read/Write end point that list the data formats available from a given author and allows the creation of new data formats """ + model = DataFormat serializer_class = DataFormatSerializer writing_serializer_class = DataFormatCreationSerializer - namespace = 'api_dataformats' + namespace = "api_dataformats" + -#---------------------------------------------------------- +# ---------------------------------------------------------- class RetrieveUpdateDestroyDataFormatsView(RetrieveUpdateDestroyContributionView): """ Read/Write/Delete endpoint for a given data format """ + model = DataFormat serializer_class = FullDataFormatSerializer + writing_serializer_class = DataFormatModSerializer - def put(self, request, author_name, object_name, version=None): - if version is None: - raise serializers.ValidationError({'version': 'A version number must be provided'}) - - try: - data = request.data - except ParseError as e: - raise serializers.ValidationError({'data': str(e)}) - else: - if not data: - raise serializers.ValidationError({'data': 'Empty'}) - - if 'short_description' in data: - if not(isinstance(data['short_description'], six.string_types)): - raise serializers.ValidationError({'short_description': 'Invalid short_description data'}) - short_description = data['short_description'] - else: - short_description = None - - if 'description' in data: - if not(isinstance(data['description'], six.string_types)): - raise serializers.ValidationError({'description': 'Invalid description data'}) - description = data['description'] - try: - validate_restructuredtext(description) - except ValidationError as errors: - raise serializers.ValidationError({'description': [error for error in errors]}) - else: - description = None - - if 'declaration' in data: - if isinstance(data['declaration'], dict): - json_declaration = data['declaration'] - declaration = json.dumps(json_declaration, indent=4) - elif isinstance(data['declaration'], six.string_types): - declaration = data['declaration'] - try: - json_declaration = json.loads(declaration) - except: - raise serializers.ValidationError({'declaration': 'Invalid declaration data'}) - else: - raise serializers.ValidationError({'declaration': 'Invalid declaration data'}) - - if '#description' in json_declaration: - if short_description is not None: - raise serializers.ValidationError({'short_description': 'A short description is already provided in the data format declaration'}) - - short_description = json_declaration['#description'] - elif short_description is not None: - json_declaration['#description'] = short_description - declaration = json.dumps(json_declaration, indent=4) - else: - declaration = None - json_declaration = None - - if (short_description is not None) and (len(short_description) > DataFormat._meta.get_field('short_description').max_length): - raise serializers.ValidationError({'short_description': 'Short description too long'}) - - - # Process the query string - if 'fields' in request.GET: - fields_to_return = request.GET['fields'].split(',') - else: - # Available fields (not returned by default): - # - html_description - fields_to_return = [] - - - # Retrieve the data format - dataformat = get_object_or_404(DataFormat, - author__username__iexact=author_name, - name__iexact=object_name, - version=version) - - # Check that the data format can still be modified (if applicable, the documentation - # can always be modified) - if (declaration is not None) and not(dataformat.modifiable()): - return ForbiddenResponse("The data format isn't modifiable anymore (either shared with someone else, or needed by an attestation)") - - - # Modification of the documentation - if (short_description is not None) and (declaration is None): - tmp_declaration = dataformat.declaration - tmp_declaration['#description'] = short_description - dataformat.declaration = tmp_declaration - - if description is not None: - dataformat.description = description - - - # Modification of the declaration - if declaration is not None: - dataformat.declaration = declaration - - - # Save the data format model - try: - dataformat.save() - except Exception as e: - return BadRequestResponse(str(e)) - - - # Nothing to return? - if len(fields_to_return) == 0: - return Response(status=204) - - result = {} - - # Retrieve the description in HTML format - if 'html_description' in fields_to_return: - description = dataformat.description - if len(description) > 0: - result['html_description'] = ensure_html(description) - else: - result['html_description'] = '' - return Response(result) - - -#---------------------------------------------------------- +# ---------------------------------------------------------- class DiffDataFormatView(DiffView): """ This view shows the differences between two data formats """ + model = DataFormat diff --git a/beat/web/dataformats/api_urls.py b/beat/web/dataformats/api_urls.py index 23d1700232718bce8e332faf6a6ffc7c43de0966..6673f46aae6d5d4fb191c3a624d4b239cca83409 100644 --- a/beat/web/dataformats/api_urls.py +++ b/beat/web/dataformats/api_urls.py @@ -29,12 +29,28 @@ from django.conf.urls import url from . import api + urlpatterns = [ - url(r'^$', api.ListDataFormatView.as_view(), name='all'), - url(r'^check_name/$', api.CheckDataFormatNameView.as_view(), name='check_name'), - url(r'^diff/(?P<author1>\w+)/(?P<name1>[-\w]+)/(?P<version1>\d+)/(?P<author2>\w+)/(?P<name2>[-\w]+)/(?P<version2>\d+)/$', api.DiffDataFormatView.as_view(), name='diff'), - url(r'^(?P<author_name>\w+)/(?P<object_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/share/$', api.ShareDataFormatView.as_view(), name='share'), - url(r'^(?P<author_name>\w+)/$', api.ListCreateDataFormatsView.as_view(), name='list_create'), - url(r'^(?P<author_name>\w+)/(?P<object_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/$', api.RetrieveUpdateDestroyDataFormatsView.as_view(), name='object'), - url(r'^(?P<author_name>\w+)/(?P<object_name>[a-zA-Z0-9_\-]+)/$', api.RetrieveUpdateDestroyDataFormatsView.as_view(), name='object'), + url(r"^$", api.ListDataFormatView.as_view(), name="all"), + url(r"^check_name/$", api.CheckDataFormatNameView.as_view(), name="check_name"), + url( + r"^diff/(?P<author1>\w+)/(?P<name1>[-\w]+)/(?P<version1>\d+)/(?P<author2>\w+)/(?P<name2>[-\w]+)/(?P<version2>\d+)/$", + api.DiffDataFormatView.as_view(), + name="diff", + ), + url( + r"^(?P<author_name>\w+)/(?P<object_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/share/$", + api.ShareDataFormatView.as_view(), + name="share", + ), + url( + r"^(?P<author_name>\w+)/$", + api.ListCreateDataFormatsView.as_view(), + name="list_create", + ), + url( + r"^(?P<author_name>\w+)/(?P<object_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/$", + api.RetrieveUpdateDestroyDataFormatsView.as_view(), + name="object", + ), ] diff --git a/beat/web/dataformats/serializers.py b/beat/web/dataformats/serializers.py index 68fee96f423c4a9288da3ab288e3bf5adf046f9e..2062e9c49660d795702e2d7b6cf8b019a74b012a 100644 --- a/beat/web/dataformats/serializers.py +++ b/beat/web/dataformats/serializers.py @@ -27,7 +27,13 @@ from rest_framework import serializers -from ..common.serializers import ContributionSerializer, VersionableSerializer, DynamicFieldsSerializer, ContributionCreationSerializer +from ..common.serializers import ( + ContributionSerializer, + VersionableSerializer, + DynamicFieldsSerializer, + ContributionCreationSerializer, + ContributionModSerializer, +) from ..algorithms.models import Algorithm from ..databases.models import Database, DatabaseSet @@ -36,16 +42,22 @@ from .models import DataFormat import beat.core.dataformat -#---------------------------------------------------------- +# ---------------------------------------------------------- class DataFormatCreationSerializer(ContributionCreationSerializer): class Meta(ContributionCreationSerializer.Meta): model = DataFormat - beat_core_class = beat.core.dataformat + beat_core_class = beat.core.dataformat.DataFormat -#---------------------------------------------------------- +class DataFormatModSerializer(ContributionModSerializer): + class Meta(ContributionModSerializer.Meta): + model = DataFormat + beat_core_class = beat.core.dataformat.DataFormat + + +# ---------------------------------------------------------- class BaseSerializer(DynamicFieldsSerializer): @@ -53,37 +65,37 @@ class BaseSerializer(DynamicFieldsSerializer): accessibility = serializers.SerializerMethodField() class Meta(DynamicFieldsSerializer.Meta): - fields = '__all__' - default_fields = ['name', 'short_description', 'accessibility'] + fields = "__all__" + default_fields = ["name", "short_description", "accessibility"] def get_accessibility(self, obj): return obj.get_sharing_display().lower() -#---------------------------------------------------------- +# ---------------------------------------------------------- class ReferencedDataFormatSerializer(BaseSerializer): class Meta(BaseSerializer.Meta): model = DataFormat - default_fields = BaseSerializer.Meta.default_fields + ['creation_date'] + default_fields = BaseSerializer.Meta.default_fields + ["creation_date"] -#---------------------------------------------------------- +# ---------------------------------------------------------- class ReferencingSerializer(BaseSerializer): is_owner = serializers.SerializerMethodField() class Meta(BaseSerializer.Meta): - default_fields = BaseSerializer.Meta.default_fields + ['is_owner', 'version',] + default_fields = BaseSerializer.Meta.default_fields + ["is_owner", "version"] def get_is_owner(self, obj): - if hasattr(obj, 'author'): - return obj.author == self.context.get('user') + if hasattr(obj, "author"): + return obj.author == self.context.get("user") -#---------------------------------------------------------- +# ---------------------------------------------------------- class ReferencingDataFormatSerializer(ReferencingSerializer): @@ -91,7 +103,7 @@ class ReferencingDataFormatSerializer(ReferencingSerializer): model = DataFormat -#---------------------------------------------------------- +# ---------------------------------------------------------- class ReferencingAlgorithmSerializer(ReferencingSerializer): @@ -99,59 +111,78 @@ class ReferencingAlgorithmSerializer(ReferencingSerializer): model = Algorithm -#---------------------------------------------------------- +# ---------------------------------------------------------- class ReferencingDatabaseSerializer(VersionableSerializer): class Meta: model = Database - default_fields = ['name', 'short_description'] + default_fields = ["name", "short_description"] -#---------------------------------------------------------- +# ---------------------------------------------------------- class DataFormatSerializer(ContributionSerializer): extend = serializers.SerializerMethodField() - referenced_dataformats = ReferencedDataFormatSerializer(many=True, source="referenced_formats") - needed_dataformats = ReferencedDataFormatSerializer(many=True, source='all_needed_dataformats') + referenced_dataformats = ReferencedDataFormatSerializer( + many=True, source="referenced_formats" + ) + needed_dataformats = ReferencedDataFormatSerializer( + many=True, source="all_needed_dataformats" + ) referencing_dataformats = serializers.SerializerMethodField() referencing_algorithms = serializers.SerializerMethodField() referencing_databases = serializers.SerializerMethodField() class Meta(ContributionSerializer.Meta): model = DataFormat - default_fields = ContributionSerializer.Meta.default_fields + ['extend'] + default_fields = ContributionSerializer.Meta.default_fields + ["extend"] def get_extend(self, obj): return obj.extend.fullname() if obj.extend else None - def get_referencing_dataformats(self, obj): - user = self.context.get('user') + user = self.context.get("user") dataformats = obj.referencing.for_user(user, True) - serializer = ReferencingDataFormatSerializer(dataformats, many=True, context=self.context) + serializer = ReferencingDataFormatSerializer( + dataformats, many=True, context=self.context + ) return serializer.data - def get_referencing_algorithms(self, obj): - user = self.context.get('user') - algorithms = Algorithm.objects.for_user(user, True).filter(endpoints=obj.algorithm_endpoints.all()).distinct() - serializer = ReferencingAlgorithmSerializer(algorithms, many=True, context=self.context) + user = self.context.get("user") + algorithms = ( + Algorithm.objects.for_user(user, True) + .filter(endpoints=obj.algorithm_endpoints.all()) + .distinct() + ) + serializer = ReferencingAlgorithmSerializer( + algorithms, many=True, context=self.context + ) return serializer.data - def get_referencing_databases(self, obj): - databaseset = DatabaseSet.objects.filter(template__outputs=obj.database_outputs.distinct()) - databases = Database.objects.filter(protocols__sets=databaseset).distinct().order_by('name') - serializer = ReferencingDatabaseSerializer(databases, many=True, context=self.context) + databaseset = DatabaseSet.objects.filter( + template__outputs=obj.database_outputs.distinct() + ) + databases = ( + Database.objects.filter(protocols__sets=databaseset) + .distinct() + .order_by("name") + ) + serializer = ReferencingDatabaseSerializer( + databases, many=True, context=self.context + ) return serializer.data -#---------------------------------------------------------- +# ---------------------------------------------------------- class FullDataFormatSerializer(DataFormatSerializer): - class Meta(DataFormatSerializer.Meta): - default_fields = DataFormatSerializer.Meta.default_fields + DataFormatSerializer.Meta.extra_fields + default_fields = ( + DataFormatSerializer.Meta.default_fields + + DataFormatSerializer.Meta.extra_fields + ) diff --git a/beat/web/dataformats/tests/tests_api.py b/beat/web/dataformats/tests/tests_api.py index 79f93fd7bd9657e7397e0a752c1e013927e1b34f..2fd7908346c74b44ddad7a948598f2fd78be1609 100644 --- a/beat/web/dataformats/tests/tests_api.py +++ b/beat/web/dataformats/tests/tests_api.py @@ -519,9 +519,10 @@ class DataFormatUpdate(DataFormatsAPIBase): def test_no_update_without_content(self): self.login_jackdoe() - url = reverse("api_dataformats:object", args=["jackdoe", "format2"]) + url = reverse("api_dataformats:object", args=["jackdoe", "format2", 1]) response = self.client.put(url) - self.checkResponse(response, 400) + + self.checkResponse(response, 200, content_type="application/json") def test_fail_to_update_a_shared_dataformat(self): self.login_johndoe() @@ -536,7 +537,7 @@ class DataFormatUpdate(DataFormatsAPIBase): self.checkResponse(response, 403) - def test_successfull_update(self): + def test_successful_update(self): self.login_johndoe() response = self.client.put( @@ -545,7 +546,7 @@ class DataFormatUpdate(DataFormatsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") dataformat = DataFormat.objects.get( author__username="johndoe", name="format_private" @@ -558,7 +559,38 @@ class DataFormatUpdate(DataFormatsAPIBase): declaration = json.load(f) self.assertEqual(declaration["value"], "float64") - def test_successfull_update_of_description_only(self): + def test_successful_update_with_specific_return_field(self): + self.login_johndoe() + + response = self.client.put( + self.url, + json.dumps({"description": "blah", "declaration": {"value": "float64"}}), + content_type="application/json", + QUERY_STRING="fields=description", + ) + + self.checkResponse(response, 200, content_type="application/json") + answer = response.json() + self.assertEqual(len(answer), 1) + self.assertTrue("description" in answer) + + def test_successful_update_with_specific_return_several_fields(self): + self.login_johndoe() + + response = self.client.put( + self.url, + json.dumps({"description": "blah", "declaration": {"value": "float64"}}), + content_type="application/json", + QUERY_STRING="fields=declaration,description", + ) + + self.checkResponse(response, 200, content_type="application/json") + answer = response.json() + self.assertEqual(len(answer), 2) + self.assertTrue("declaration" in answer) + self.assertTrue("description" in answer) + + def test_successful_update_of_description_only(self): self.login_johndoe() response = self.client.put( @@ -567,12 +599,12 @@ class DataFormatUpdate(DataFormatsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") dataformat = DataFormat.objects.get(author__username="johndoe", name="format1") self.assertEqual(dataformat.description, b"blah") - def test_successfull_update_of_declaration_only(self): + def test_successful_update_of_declaration_only(self): self.login_johndoe() response = self.client.put( @@ -581,7 +613,7 @@ class DataFormatUpdate(DataFormatsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") dataformat = DataFormat.objects.get( author__username="johndoe", name="format_private" @@ -593,7 +625,7 @@ class DataFormatUpdate(DataFormatsAPIBase): declaration = json.load(f) self.assertEqual(declaration["value"], "float64") - def test_successfull_update_of_declaration__extension_addition(self): + def test_successful_update_of_declaration__extension_addition(self): self.login_johndoe() (dataformat, errors) = DataFormat.objects.create_dataformat( @@ -612,7 +644,7 @@ class DataFormatUpdate(DataFormatsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") dataformat = DataFormat.objects.get( author__username="johndoe", name="format_private", version=1 @@ -629,7 +661,7 @@ class DataFormatUpdate(DataFormatsAPIBase): self.assertEqual(declaration["value"], "float64") self.assertEqual(declaration["#extends"], "johndoe/format4/1") - def test_successfull_update_of_declaration__extension_removal(self): + def test_successful_update_of_declaration__extension_removal(self): self.login_johndoe() (dataformat, errors) = DataFormat.objects.create_dataformat( @@ -646,7 +678,7 @@ class DataFormatUpdate(DataFormatsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") dataformat = DataFormat.objects.get( author__username="johndoe", name="format4", version=1 @@ -655,7 +687,7 @@ class DataFormatUpdate(DataFormatsAPIBase): self.assertTrue(dataformat.extend is None) self.assertEqual(dataformat.referenced_formats.count(), 0) - def test_successfull_update_of_declaration__extension_keeping(self): + def test_successful_update_of_declaration__extension_keeping(self): self.login_johndoe() (dataformat, errors) = DataFormat.objects.create_dataformat( @@ -674,7 +706,7 @@ class DataFormatUpdate(DataFormatsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") dataformat = DataFormat.objects.get( author__username="johndoe", name="format4", version=1 @@ -684,7 +716,7 @@ class DataFormatUpdate(DataFormatsAPIBase): self.assertEqual(dataformat.extend.name, "format1") self.assertEqual(dataformat.referenced_formats.count(), 0) - def test_successfull_update_of_declaration__composition_addition(self): + def test_successful_update_of_declaration__composition_addition(self): self.login_johndoe() (dataformat, errors) = DataFormat.objects.create_dataformat( @@ -703,7 +735,7 @@ class DataFormatUpdate(DataFormatsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") dataformat = DataFormat.objects.get( author__username="johndoe", name="format_private", version=1 @@ -723,7 +755,7 @@ class DataFormatUpdate(DataFormatsAPIBase): self.assertEqual(declaration["value"], "float64") self.assertEqual(declaration["field1"], "johndoe/format4/1") - def test_successfull_update_of_declaration__composition_removal(self): + def test_successful_update_of_declaration__composition_removal(self): self.login_johndoe() (dataformat, errors) = DataFormat.objects.create_dataformat( @@ -740,7 +772,7 @@ class DataFormatUpdate(DataFormatsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") dataformat = DataFormat.objects.get( author__username="johndoe", name="format4", version=1 @@ -749,7 +781,7 @@ class DataFormatUpdate(DataFormatsAPIBase): self.assertTrue(dataformat.extend is None) self.assertEqual(dataformat.referenced_formats.count(), 0) - def test_successfull_update_of_declaration__composition_keeping(self): + def test_successful_update_of_declaration__composition_keeping(self): self.login_johndoe() (dataformat, errors) = DataFormat.objects.create_dataformat( @@ -768,7 +800,7 @@ class DataFormatUpdate(DataFormatsAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") dataformat = DataFormat.objects.get( author__username="johndoe", name="format4", version=1 @@ -960,7 +992,7 @@ class DataFormatDeletion(DataFormatsAPIBase): self.login_johndoe() url = reverse("api_dataformats:object", args=["johndoe", "format1", 1]) response = self.client.delete(url) - self.checkResponse(response, 403) + self.checkResponse(response, 405) def test_no_deletion_of_not_owned_format(self): self.login_jackdoe() @@ -1665,7 +1697,7 @@ class NotSharedDataFormat_ExtensionOfPublicForOneUserFormat_SharingAPI( self.url, json.dumps({}), content_type="application/json" ) - data = self.checkResponse(response, 400, content_type="application/json") + data = self.checkResponse(response, 403, content_type="application/json") self.assertTrue(isinstance(data, list)) self.assertTrue(len(data) > 0) @@ -1699,7 +1731,7 @@ class NotSharedDataFormat_ExtensionOfPublicForOneUserFormat_SharingAPI( content_type="application/json", ) - data = self.checkResponse(response, 400, content_type="application/json") + data = self.checkResponse(response, 403, content_type="application/json") self.assertTrue(isinstance(data, list)) self.assertTrue(len(data) > 0) @@ -1713,7 +1745,7 @@ class NotSharedDataFormat_ExtensionOfPublicForOneUserFormat_SharingAPI( content_type="application/json", ) - data = self.checkResponse(response, 400, content_type="application/json") + data = self.checkResponse(response, 403, content_type="application/json") self.assertTrue(isinstance(data, list)) self.assertTrue(len(data) > 0) @@ -1727,7 +1759,7 @@ class NotSharedDataFormat_ExtensionOfPublicForOneUserFormat_SharingAPI( content_type="application/json", ) - data = self.checkResponse(response, 400, content_type="application/json") + data = self.checkResponse(response, 403, content_type="application/json") self.assertTrue(isinstance(data, list)) self.assertTrue(len(data) > 0) @@ -1741,7 +1773,7 @@ class NotSharedDataFormat_ExtensionOfPublicForOneUserFormat_SharingAPI( content_type="application/json", ) - data = self.checkResponse(response, 400, content_type="application/json") + data = self.checkResponse(response, 403, content_type="application/json") self.assertTrue(isinstance(data, list)) self.assertTrue(len(data) > 0) diff --git a/beat/web/experiments/api.py b/beat/web/experiments/api.py index cdd76b830432d6dc25cbe03a377bc0f49e8f5adb..e82b7c1f77c83ef2edaf82a72ae51208d2360450 100755 --- a/beat/web/experiments/api.py +++ b/beat/web/experiments/api.py @@ -25,9 +25,6 @@ # # ############################################################################### -import re -import uuid - import simplejson import functools @@ -49,7 +46,11 @@ import beat.core.algorithm import beat.core.toolchain from .models import Experiment -from .serializers import ExperimentSerializer, ExperimentResultsSerializer +from .serializers import ( + ExperimentSerializer, + ExperimentResultsSerializer, + ExperimentCreationSerializer, +) from .permissions import IsDatabaseAccessible from ..common.responses import BadRequestResponse, ForbiddenResponse @@ -62,7 +63,7 @@ from ..common.api import ( from ..common.mixins import CommonContextMixin from ..common.exceptions import ShareError from ..common.serializers import SharingSerializer -from ..common.utils import validate_restructuredtext, ensure_html +from ..common.utils import validate_restructuredtext, ensure_html, py3_cmp from ..toolchains.models import Toolchain @@ -102,24 +103,13 @@ class ListCreateExperimentsView(ListCreateContributionView): model = Experiment serializer_class = ExperimentSerializer - - # TODO: no serializer is used by the POST method, but the rest-framework isn't - # happy if we don't declare one - writing_serializer_class = ExperimentSerializer + writing_serializer_class = ExperimentCreationSerializer + namespace = "api_experiments" def get(self, request, author_name): def _getStatusLabel(status): return [s for s in Experiment.STATUS if s[0] == status][0][1] - def _cmp(a, b): - """ - cmp is not available anymore in Python 3. This method is implemetend - as recommended in the documentation for this kind of use case. - Based on: - https://docs.python.org/3.0/whatsnew/3.0.html#ordering-comparisons - """ - return (a > b) - (a < b) - def _custom_compare(exp1, exp2): PENDING = _getStatusLabel(Experiment.PENDING) @@ -129,7 +119,7 @@ class ListCreateExperimentsView(ListCreateContributionView): if exp2["status"] != PENDING: return -1 elif exp1["creation_date"] != exp2["creation_date"]: - return _cmp(exp1["creation_date"], exp2["creation_date"]) + return py3_cmp(exp1["creation_date"], exp2["creation_date"]) else: return 1 elif exp2["status"] == PENDING: @@ -142,7 +132,7 @@ class ListCreateExperimentsView(ListCreateContributionView): ] if (exp1["status"] in tier3_states) and (exp2["status"] in tier3_states): - return _cmp(exp2["end_date"], exp1["end_date"]) + return py3_cmp(exp2["end_date"], exp1["end_date"]) elif (exp1["status"] in tier3_states) and ( exp2["status"] not in tier3_states ): @@ -152,7 +142,7 @@ class ListCreateExperimentsView(ListCreateContributionView): ): return 1 - return _cmp(exp1["start_date"], exp2["start_date"]) + return py3_cmp(exp1["start_date"], exp2["start_date"]) # Retrieve the experiments fields_to_return_original = self.get_serializer_fields(request) @@ -183,134 +173,6 @@ class ListCreateExperimentsView(ListCreateContributionView): # Returns the results return Response(processed_result) - def post(self, request, author_name): - data = request.data - - if "name" in data: - if not (isinstance(data["name"], six.string_types)): - return BadRequestResponse("Invalid name") - - max_length = self.model._meta.get_field("name").max_length - current_length = len(data["name"]) - if current_length > max_length: - return BadRequestResponse( - "name too long, max is {}, current is {}".format( - max_length, current_length - ) - ) - - name = re.sub(r"[\W]", "-", data["name"]) - else: - name = None - - if "short_description" in data: - if not (isinstance(data["short_description"], six.string_types)): - return BadRequestResponse("Invalid short_description") - - max_length = self.model._meta.get_field("short_description").max_length - current_length = len(data["short_description"]) - if current_length > max_length: - return BadRequestResponse( - "Short description too long, max is {}, current is {}".format( - max_length, current_length - ) - ) - - short_description = data["short_description"] - else: - short_description = "" - - if "description" in data: - if not (isinstance(data["description"], six.string_types)): - raise serializers.ValidationError( - {"description": "Invalid description data"} - ) - description = data["description"] - try: - validate_restructuredtext(description) - except ValidationError as errors: - raise serializers.ValidationError( - {"description": [error for error in errors]} - ) - else: - description = None - - if "toolchain" not in data: - return BadRequestResponse("Must indicate a toolchain name") - - if not (isinstance(data["toolchain"], six.string_types)): - return BadRequestResponse("Invalid toolchain name") - - if "declaration" not in data: - return BadRequestResponse("Must indicate a declaration") - - if not (isinstance(data["declaration"], dict)) and not ( - isinstance(data["declaration"], six.string_types) - ): - return BadRequestResponse("Invalid declaration") - - if isinstance(data["declaration"], dict): - declaration = data["declaration"] - else: - declaration_string = data["declaration"] - try: - declaration = simplejson.loads(declaration_string) - except simplejson.errors.JSONDecodeError: - return BadRequestResponse("Invalid declaration data") - - # Retrieve the toolchain - core_toolchain = beat.core.toolchain.Storage(settings.PREFIX, data["toolchain"]) - - if core_toolchain.version == "unknown": - return BadRequestResponse("Unknown toolchain version") - - if core_toolchain.username is None: - core_toolchain.username = request.user.username - - try: - db_toolchain = Toolchain.objects.for_user(request.user, True).get( - author__username=core_toolchain.username, - name=core_toolchain.name, - version=core_toolchain.version, - ) - except Toolchain.DoesNotExist: - return Response("Toolchain %s not found" % data["toolchain"], status=404) - - # Create the experiment object in the database - existing_names = map( - lambda x: x.name, - Experiment.objects.filter(author=request.user, toolchain=db_toolchain), - ) - if name is None: - name = str(uuid.uuid4()) - while name in existing_names: - name = str(uuid.uuid4()) - - elif name in existing_names: - return BadRequestResponse("The name '" + name + "' is already used") - - experiment, toolchain, errors = Experiment.objects.create_experiment( - author=request.user, - toolchain=db_toolchain, - name=name, - short_description=short_description, - description=description, - declaration=declaration, - ) - - if not experiment: - return BadRequestResponse(errors) - - # Send the result - result = { - "name": experiment.fullname(), - "url": reverse("api_experiments:all") + experiment.fullname() + "/", - } - - response = Response(result, status=201) - response["Location"] = result["url"] - return response - # ---------------------------------------------------------- diff --git a/beat/web/experiments/api_urls.py b/beat/web/experiments/api_urls.py index c0f08a197cfc9fbca71f0dcf4303f057f57bbc3f..fcd876f0d58d969b05cfb7028b0eb133240d266b 100644 --- a/beat/web/experiments/api_urls.py +++ b/beat/web/experiments/api_urls.py @@ -26,6 +26,7 @@ ############################################################################### from django.conf.urls import url + from . import api @@ -37,48 +38,24 @@ urlpatterns = [ api.ShareExperimentView.as_view(), name="share", ), - url( - r"^(?P<author_name>\w+)/(?P<toolchain_name>[-\w]+)/(?P<version>\d+)/(?P<name>[-\w]+)/share/$", - api.ShareExperimentView.as_view(), - {"toolchain_author_name": None}, - name="share", - ), # Start url( r"^(?P<author_name>\w+)/(?P<toolchain_author_name>\w+)/(?P<toolchain_name>[-\w]+)/(?P<version>\d+)/(?P<name>[-\w]+)/start/$", api.StartExperimentView.as_view(), name="start", ), - url( - r"^(?P<author_name>\w+)/(?P<toolchain_name>[-\w]+)/(?P<version>\d+)/(?P<name>[-\w]+)/start/$", - api.StartExperimentView.as_view(), - {"toolchain_author_name": None}, - name="start", - ), # Cancelling url( r"^(?P<author_name>\w+)/(?P<toolchain_author_name>\w+)/(?P<toolchain_name>[-\w]+)/(?P<version>\d+)/(?P<name>[-\w]+)/cancel/$", api.CancelExperimentView.as_view(), name="cancel", ), - url( - r"^(?P<author_name>\w+)/(?P<toolchain_name>[-\w]+)/(?P<version>\d+)/(?P<name>[-\w]+)/cancel/$", - api.CancelExperimentView.as_view(), - {"toolchain_author_name": None}, - name="cancel", - ), # Reseting url( r"^(?P<author_name>\w+)/(?P<toolchain_author_name>\w+)/(?P<toolchain_name>[-\w]+)/(?P<version>\d+)/(?P<name>[-\w]+)/reset/$", api.ResetExperimentView.as_view(), name="reset", ), - url( - r"^(?P<author_name>\w+)/(?P<toolchain_name>[-\w]+)/(?P<version>\d+)/(?P<name>[-\w]+)/reset/$", - api.ResetExperimentView.as_view(), - {"toolchain_author_name": None}, - name="reset", - ), # Attestations url( r"^(?P<attestation_number>\d+)/", @@ -96,10 +73,4 @@ urlpatterns = [ api.RetrieveUpdateDestroyExperimentView.as_view(), name="object", ), - url( - r"^(?P<author_name>\w+)/(?P<toolchain_name>[-\w]+)/(?P<version>\d+)/(?P<name>[-\w]+)/$", - api.RetrieveUpdateDestroyExperimentView.as_view(), - {"toolchain_author_name": None}, - name="object", - ), ] diff --git a/beat/web/experiments/serializers.py b/beat/web/experiments/serializers.py index fb6e31ef238339d88b454e15c2178a38ba25aae9..fef9b9fbe1705bc3494ca23835a73e5443d62907 100755 --- a/beat/web/experiments/serializers.py +++ b/beat/web/experiments/serializers.py @@ -25,21 +25,107 @@ # # ############################################################################### +import simplejson as json + +import beat.core + +from datetime import datetime + from rest_framework import serializers + from django.contrib.humanize.templatetags.humanize import naturaltime from ..common.serializers import ShareableSerializer -from ..common.fields import JSONSerializerField +from ..common.utils import validate_restructuredtext +from ..common.utils import annotate_full_name +from ..common.exceptions import ContributionCreationError +from ..common import fields as beat_fields + from ..ui.templatetags.markup import restructuredtext +from ..toolchains.models import Toolchain + from .models import Experiment, Block -from datetime import datetime -import simplejson as json +# ---------------------------------------------------------- -#---------------------------------------------------------- +class ExperimentCreationSerializer(serializers.ModelSerializer): + declaration = beat_fields.JSONField() + description = serializers.CharField(required=False, allow_blank=True) + fork_of = serializers.CharField(required=False) + toolchain = serializers.CharField() + + class Meta: + model = Experiment + beat_core_class = beat.core.experiment.Experiment + fields = [ + "name", + "short_description", + "description", + "declaration", + "toolchain", + "fork_of", + ] + + def toolchain_queryset(self): + user = self.context.get("user") + return annotate_full_name(Toolchain.objects.for_user(user, True)) + + def validate_name(self, name): + # sanitize_name is a static method of Versionable models but Experiment + # is currently not such a model + return Toolchain.sanitize_name(name) + + def validate_description(self, description): + if description.find("\\") >= 0: # was escaped, unescape + description = description.decode("string_escape") + validate_restructuredtext(description) + return description + + def validate_toolchain(self, toolchain): + if not self.toolchain_queryset().filter(full_name=toolchain).exists(): + raise serializers.ValidationError("Invalid toolchain: {}".format(toolchain)) + + return toolchain + + def validate(self, data): + user = self.context.get("user") + name = data["name"] + toolchain = data["toolchain"] + + if self.Meta.model.objects.filter( + author=user, + name=name, + toolchain=self.toolchain_queryset().get(full_name=toolchain), + ).exists(): + raise serializers.ValidationError( + "{} {} with toolchain {} already exists on this account".format( + self.Meta.model.__name__.lower(), name, toolchain + ) + ) + + return data + + def create(self, validated_data): + toolchain = self.toolchain_queryset().get(full_name=validated_data["toolchain"]) + + experiment, _, errors = Experiment.objects.create_experiment( + author=self.context.get("user"), + toolchain=toolchain, + name=validated_data["name"], + short_description=validated_data.get("short_description", ""), + description=validated_data.get("description", ""), + declaration=validated_data["declaration"], + ) + + if errors: + raise ContributionCreationError(errors) + return experiment + + +# ---------------------------------------------------------- class ExperimentSerializer(ShareableSerializer): @@ -55,37 +141,41 @@ class ExperimentSerializer(ShareableSerializer): data_read = serializers.SerializerMethodField() data_written = serializers.SerializerMethodField() description = serializers.CharField() - declaration = JSONSerializerField() + declaration = beat_fields.JSONField() class Meta: model = Experiment - fields = '__all__' - default_fields = ShareableSerializer.Meta.default_fields + ['name', 'toolchain', - 'datasets', 'short_description', - 'creation_date', 'start_date', 'end_date', - 'duration', 'status', - 'accessibility', - 'attestation_number', 'attestation_locked', - 'cpu_time', 'data_read', 'data_written', - 'is_owner'] - + fields = "__all__" + default_fields = ShareableSerializer.Meta.default_fields + [ + "name", + "toolchain", + "datasets", + "short_description", + "creation_date", + "start_date", + "end_date", + "duration", + "status", + "accessibility", + "attestation_number", + "attestation_locked", + "cpu_time", + "data_read", + "data_written", + "is_owner", + ] def __init__(self, *args, **kwargs): super(ExperimentSerializer, self).__init__(*args, **kwargs) self.statistics = {} def __get_statistics(self, obj, key): - if not obj in self.statistics: - self.statistics[obj] = { - 'cpu_time': 0, - 'data_read': 0, - 'data_written': 0 - } + if obj not in self.statistics: + self.statistics[obj] = {"cpu_time": 0, "data_read": 0, "data_written": 0} for block in obj.blocks.all(): - self.statistics[obj]['cpu_time'] += block.cpu_time() - self.statistics[obj]['data_read'] += block.data_read_size() - self.statistics[obj]['data_written'] += \ - block.data_written_size() + self.statistics[obj]["cpu_time"] += block.cpu_time() + self.statistics[obj]["data_read"] += block.data_read_size() + self.statistics[obj]["data_written"] += block.data_written_size() return self.statistics[obj][key] def get_datasets(self, obj): @@ -112,11 +202,11 @@ class ExperimentSerializer(ShareableSerializer): end_date = self.get_end_date(obj) if (start_date is None) or (end_date is None): - return '-' + return "-" duration = end_date - start_date - duration_seconds = max(duration.seconds, 1) # At least one second + duration_seconds = max(duration.seconds, 1) # At least one second seconds = duration_seconds % 60 minutes = ((duration_seconds - seconds) % 3600) / 60 @@ -129,19 +219,19 @@ class ExperimentSerializer(ShareableSerializer): minutes = 0 hours += 1 - return '%dh%02d' % (hours, minutes) + return "%dh%02d" % (hours, minutes) def get_cpu_time(self, obj): - return self.__get_statistics(obj, 'cpu_time') + return self.__get_statistics(obj, "cpu_time") def get_data_read(self, obj): - return self.__get_statistics(obj, 'data_read') + return self.__get_statistics(obj, "data_read") def get_data_written(self, obj): - return self.__get_statistics(obj, 'data_written') + return self.__get_statistics(obj, "data_written") -#---------------------------------------------------------- +# ---------------------------------------------------------- class AnalyzerSerializer(serializers.ModelSerializer): @@ -149,7 +239,7 @@ class AnalyzerSerializer(serializers.ModelSerializer): class Meta: model = Block - fields = '__all__' + fields = "__all__" def get_results(self, obj): results = {} @@ -158,48 +248,48 @@ class AnalyzerSerializer(serializers.ModelSerializer): if db_results.count() > 0: for result_entry in db_results: results[result_entry.name] = { - 'type': result_entry.type, - 'primary': result_entry.primary, - 'value': result_entry.value(), + "type": result_entry.type, + "primary": result_entry.primary, + "value": result_entry.value(), } else: dataformat_declaration = json.loads(obj.algorithm.result_dataformat) for field, type in dataformat_declaration.items(): - primary = (field[0] == '+') + primary = field[0] == "+" results[field[1:] if primary else field] = { - 'type': type, - 'primary': primary, - 'value': None, + "type": type, + "primary": primary, + "value": None, } return results -#---------------------------------------------------------- +# ---------------------------------------------------------- class BlockErrorSerializer(serializers.ModelSerializer): - block = serializers.CharField(source='name') - algorithm = serializers.CharField(source='algorithm.fullname') + block = serializers.CharField(source="name") + algorithm = serializers.CharField(source="algorithm.fullname") stdout = serializers.CharField() stderr = serializers.CharField() - details = serializers.CharField(source='error_report') + details = serializers.CharField(source="error_report") class Meta: model = Block - fields = '__all__' + fields = "__all__" -#---------------------------------------------------------- +# ---------------------------------------------------------- class ExperimentResultsSerializer(ShareableSerializer): started = serializers.SerializerMethodField() done = serializers.SerializerMethodField() failed = serializers.SerializerMethodField() - attestation = serializers.IntegerField(source='attestation.number') + attestation = serializers.IntegerField(source="attestation.number") blocks_status = serializers.SerializerMethodField() execution_info = serializers.SerializerMethodField() execution_order = serializers.SerializerMethodField() @@ -208,20 +298,35 @@ class ExperimentResultsSerializer(ShareableSerializer): errors = serializers.SerializerMethodField() html_description = serializers.SerializerMethodField() description = serializers.SerializerMethodField() - declaration = JSONSerializerField() + declaration = beat_fields.JSONField() display_start_date = serializers.SerializerMethodField() display_end_date = serializers.SerializerMethodField() class Meta(ShareableSerializer.Meta): model = Experiment - fields = '__all__' - default_fields = ['started', 'done', 'failed', 'status', 'blocks_status', - 'results', 'attestation', 'declaration', - 'errors', 'sharing', 'accessibility', - 'execution_info', 'execution_order'] + fields = "__all__" + default_fields = [ + "started", + "done", + "failed", + "status", + "blocks_status", + "results", + "attestation", + "declaration", + "errors", + "sharing", + "accessibility", + "execution_info", + "execution_order", + ] def get_started(self, obj): - return obj.status not in [Experiment.PENDING, Experiment.SCHEDULED, Experiment.CANCELLING] + return obj.status not in [ + Experiment.PENDING, + Experiment.SCHEDULED, + Experiment.CANCELLING, + ] def get_done(self, obj): return obj.status in [Experiment.DONE, Experiment.FAILED] @@ -233,33 +338,33 @@ class ExperimentResultsSerializer(ShareableSerializer): results = {} for block in obj.blocks.iterator(): if block.status == Block.DONE: - results[block.name] = 'generated' + results[block.name] = "generated" elif block.status == Block.PENDING: - results[block.name] = 'pending' + results[block.name] = "pending" elif block.status == Block.FAILED: - results[block.name] = 'failed' + results[block.name] = "failed" elif block.status == Block.CANCELLED: - results[block.name] = 'cancelled' + results[block.name] = "cancelled" elif obj.status == Experiment.CANCELLING: - results[block.name] = 'cancelling' + results[block.name] = "cancelling" else: - results[block.name] = 'processing' + results[block.name] = "processing" return results def get_execution_info(self, obj): results = {} for block in obj.blocks.iterator(): results[block.name] = { - 'linear_execution_time': block.linear_execution_time(), - 'speed_up_real': block.speed_up_real(), - 'speed_up_maximal': block.speed_up_maximal(), - 'queuing_time': block.queuing_time(), + "linear_execution_time": block.linear_execution_time(), + "speed_up_real": block.speed_up_real(), + "speed_up_maximal": block.speed_up_maximal(), + "queuing_time": block.queuing_time(), } return results def get_execution_order(self, obj): results = [] - for block in obj.blocks.order_by('execution_order').iterator(): + for block in obj.blocks.order_by("execution_order").iterator(): results.append(block.name) return results @@ -267,18 +372,20 @@ class ExperimentResultsSerializer(ShareableSerializer): results = {} for k in obj.blocks.filter(analyzer=True).all(): serializer = AnalyzerSerializer(k) - results[k.name] = serializer.data['results'] + results[k.name] = serializer.data["results"] return results def get_errors(self, obj): - serializer = BlockErrorSerializer(obj.blocks.filter(outputs__error_report__isnull=False), many=True) + serializer = BlockErrorSerializer( + obj.blocks.filter(outputs__error_report__isnull=False), many=True + ) return serializer.data def get_html_description(self, obj): d = obj.description if len(d) > 0: return restructuredtext(d) - return '' + return "" def get_description(self, obj): return obj.description @@ -288,8 +395,8 @@ class ExperimentResultsSerializer(ShareableSerializer): return None return { - 'date': obj.start_date.strftime("%b %d, %Y, %-H:%M"), - 'natural': naturaltime(obj.start_date), + "date": obj.start_date.strftime("%b %d, %Y, %-H:%M"), + "natural": naturaltime(obj.start_date), } def get_display_end_date(self, obj): @@ -297,6 +404,6 @@ class ExperimentResultsSerializer(ShareableSerializer): return None return { - 'date': obj.end_date.strftime("%b %d, %Y, %-H:%M"), - 'natural': naturaltime(obj.end_date), + "date": obj.end_date.strftime("%b %d, %Y, %-H:%M"), + "natural": naturaltime(obj.end_date), } diff --git a/beat/web/experiments/static/experiments/js/panels.js b/beat/web/experiments/static/experiments/js/panels.js index b2281934a5a14341d8367f9a0229c52845c1284f..45359c3200455ad083a6a03ae792d4718e707b80 100644 --- a/beat/web/experiments/static/experiments/js/panels.js +++ b/beat/web/experiments/static/experiments/js/panels.js @@ -108,7 +108,7 @@ beat.experiments.panels.Settings = function(panel_id, toolchain_name, d.done(function(data) { if (queue) { - window.location = url_prefix + '/experiments/' + data.name.split('/')[0] + '/'; + window.location = url_prefix + '/experiments/' + username + '/'; return; } diff --git a/beat/web/libraries/api.py b/beat/web/libraries/api.py index a46b212d55c0f2e02876760a28dbdbe063dd45c1..1e40b3c7b6a269996b04668e3066233b4670b4f7 100644 --- a/beat/web/libraries/api.py +++ b/beat/web/libraries/api.py @@ -29,16 +29,20 @@ from .models import Library from .serializers import LibrarySerializer from .serializers import FullLibrarySerializer from .serializers import LibraryCreationSerializer +from .serializers import LibraryModSerializer from ..code.api import ShareCodeView, RetrieveUpdateDestroyCodeView from ..code.serializers import CodeDiffSerializer -from ..common.api import (CheckContributionNameView, ListContributionView, - ListCreateContributionView) +from ..common.api import ( + CheckContributionNameView, + ListContributionView, + ListCreateContributionView, +) from ..code.api import DiffView -#---------------------------------------------------------- +# ---------------------------------------------------------- class CheckLibraryNameView(CheckContributionNameView): @@ -46,10 +50,11 @@ class CheckLibraryNameView(CheckContributionNameView): This view sanitizes a library name and checks whether it is already used. """ + model = Library -#---------------------------------------------------------- +# ---------------------------------------------------------- class ShareLibraryView(ShareCodeView): @@ -57,21 +62,23 @@ class ShareLibraryView(ShareCodeView): This view allows to share a library with other users and/or teams """ + model = Library -#---------------------------------------------------------- +# ---------------------------------------------------------- class ListLibrariesView(ListContributionView): """ List all available libraries """ + model = Library serializer_class = LibrarySerializer -#---------------------------------------------------------- +# ---------------------------------------------------------- class ListCreateLibrariesView(ListCreateContributionView): @@ -79,44 +86,33 @@ class ListCreateLibrariesView(ListCreateContributionView): Read/Write end point that list the libraries available from a given author and allows the creation of new libraries """ + model = Library serializer_class = LibrarySerializer writing_serializer_class = LibraryCreationSerializer - namespace = 'api_libraries' + namespace = "api_libraries" -#---------------------------------------------------------- +# ---------------------------------------------------------- class RetrieveUpdateDestroyLibrariesView(RetrieveUpdateDestroyCodeView): """ Read/Write/Delete endpoint for a given library """ + model = Library serializer_class = FullLibrarySerializer + writing_serializer_class = LibraryModSerializer - def do_update(self, request, author_name, object_name, version=None): - modified, library = super(RetrieveUpdateDestroyLibrariesView, self).do_update(request, author_name, object_name, version) - - if modified: - # Delete existing experiments using the library (code changed) - experiments = [] - for item in library.referencing.all(): - for algorithm in item.used_by_algorithms.all(): - experiments.append(list(set(map(lambda x: x.experiment, - algorithm.blocks.iterator())))) - for experiment in set(experiments): experiment.delete() - - return modified, library - - -#---------------------------------------------------------- +# ---------------------------------------------------------- class DiffLibraryView(DiffView): """ This view shows the differences between two libraries """ + model = Library serializer_class = CodeDiffSerializer diff --git a/beat/web/libraries/api_urls.py b/beat/web/libraries/api_urls.py index a4f30298b87e1d727b58322e64e178961aba1ef2..8b1248fc1ce5c8ed7ecf648df9490ce5caa41794 100644 --- a/beat/web/libraries/api_urls.py +++ b/beat/web/libraries/api_urls.py @@ -29,41 +29,28 @@ from django.conf.urls import url from . import api -urlpatterns = [ - - url(r'^$', - api.ListLibrariesView.as_view(), - name='all', - ), - - url(r'^check_name/$', - api.CheckLibraryNameView.as_view(), - name='check_name', - ), - url(r'^diff/(?P<author1>\w+)/(?P<name1>[-\w]+)/(?P<version1>\d+)/(?P<author2>\w+)/(?P<name2>[-\w]+)/(?P<version2>\d+)/$', +urlpatterns = [ + url(r"^$", api.ListLibrariesView.as_view(), name="all"), + url(r"^check_name/$", api.CheckLibraryNameView.as_view(), name="check_name"), + url( + r"^diff/(?P<author1>\w+)/(?P<name1>[-\w]+)/(?P<version1>\d+)/(?P<author2>\w+)/(?P<name2>[-\w]+)/(?P<version2>\d+)/$", api.DiffLibraryView.as_view(), - name='diff', - ), - - url(r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/share/$', + name="diff", + ), + url( + r"^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/share/$", api.ShareLibraryView.as_view(), - name='share', - ), - - url(r'^(?P<author_name>\w+)/$', + name="share", + ), + url( + r"^(?P<author_name>\w+)/$", api.ListCreateLibrariesView.as_view(), - name='list_create', - ), - - url(r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/$', + name="list_create", + ), + url( + r"^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/$", api.RetrieveUpdateDestroyLibrariesView.as_view(), - name='object', - ), - - url(r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/$', - api.RetrieveUpdateDestroyLibrariesView.as_view(), - name='object', - ), - + name="object", + ), ] diff --git a/beat/web/libraries/serializers.py b/beat/web/libraries/serializers.py index 4fac5d1c88da2d6abc049b3ddbde4ddd3401a4d9..e5a93c04498c9ecec2e3083bfd42677ab1d2e454 100644 --- a/beat/web/libraries/serializers.py +++ b/beat/web/libraries/serializers.py @@ -29,7 +29,9 @@ import beat.core.library from rest_framework import serializers -from ..code.serializers import CodeSerializer, CodeCreationSerializer +from ..code.serializers import CodeSerializer +from ..code.serializers import CodeCreationSerializer +from ..code.serializers import CodeModSerializer from ..algorithms.models import Algorithm from ..backend.serializers import EnvironmentInfoSerializer @@ -43,7 +45,16 @@ from .models import Library class LibraryCreationSerializer(CodeCreationSerializer): class Meta(CodeCreationSerializer.Meta): model = Library - beat_core_class = beat.core.library + beat_core_class = beat.core.library.Library + + +# ---------------------------------------------------------- + + +class LibraryModSerializer(CodeModSerializer): + class Meta(CodeModSerializer.Meta): + model = Library + beat_core_class = beat.core.library.Library # ---------------------------------------------------------- diff --git a/beat/web/libraries/tests/tests_api.py b/beat/web/libraries/tests/tests_api.py index 7b6e62b4de4ee501e50285169c9505d9e51ac6df..1a26fa1eeed7acf0667b12cb45ef336ce42d0696 100644 --- a/beat/web/libraries/tests/tests_api.py +++ b/beat/web/libraries/tests/tests_api.py @@ -833,18 +833,25 @@ class LibraryUpdate(LibrariesAPIBase): def test_no_update_without_content(self): self.login_jackdoe() response = self.client.put(self.url) - self.checkResponse(response, 400) + self.checkResponse(response, 200, content_type="application/json") - def test_successfull_update(self): + def test_successful_update(self): self.login_jackdoe() + code = b"import numpy as np" response = self.client.put( self.url, - json.dumps({"description": "blah", "declaration": LibrariesAPIBase.UPDATE}), + json.dumps( + { + "description": "blah", + "declaration": LibrariesAPIBase.UPDATE, + "code": code, + } + ), content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") library = Library.objects.get( author__username="jackdoe", name="personal", version=1 @@ -854,9 +861,60 @@ class LibraryUpdate(LibrariesAPIBase): storage = beat.core.library.Storage(settings.PREFIX, library.fullname()) storage.language = "python" self.assertTrue(storage.exists()) - self.assertEqual(storage.json.load(), LibrariesAPIBase.UPDATE) + self.assertEqual(library.description, b"blah") + self.assertEqual( + json.loads(storage.json.load()), json.loads(LibrariesAPIBase.UPDATE) + ) + self.assertEqual(storage.code.load(), code) + + def test_successful_update_with_specific_return_field(self): + self.login_jackdoe() + + code = b"""import numpy as np""" + + response = self.client.put( + self.url, + json.dumps( + { + "description": "blah", + "declaration": LibrariesAPIBase.UPDATE, + "code": code, + } + ), + content_type="application/json", + QUERY_STRING="fields=code", + ) - def test_successfull_update_description_only(self): + self.checkResponse(response, 200, content_type="application/json") + answer = response.json() + self.assertEqual(len(answer), 1) + self.assertTrue("code" in answer) + + def test_successful_update_with_specific_return_several_fields(self): + self.login_jackdoe() + + code = b"""import numpy as np""" + + response = self.client.put( + self.url, + json.dumps( + { + "description": "blah", + "declaration": LibrariesAPIBase.UPDATE, + "code": code, + } + ), + content_type="application/json", + QUERY_STRING="fields=code,description", + ) + + self.checkResponse(response, 200, content_type="application/json") + answer = response.json() + self.assertEqual(len(answer), 2) + self.assertTrue("code" in answer) + self.assertTrue("description" in answer) + + def test_successful_update_description_only(self): self.login_jackdoe() response = self.client.put( @@ -865,14 +923,14 @@ class LibraryUpdate(LibrariesAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") library = Library.objects.get( author__username="jackdoe", name="personal", version=1 ) self.assertEqual(library.description, b"blah") - def test_successfull_update_code_only(self): + def test_successful_update_declaration_only(self): self.login_jackdoe() response = self.client.put( @@ -881,7 +939,7 @@ class LibraryUpdate(LibrariesAPIBase): content_type="application/json", ) - self.checkResponse(response, 204) + self.checkResponse(response, 200, content_type="application/json") library = Library.objects.get( author__username="jackdoe", name="personal", version=1 @@ -890,37 +948,59 @@ class LibraryUpdate(LibrariesAPIBase): storage = beat.core.library.Storage(settings.PREFIX, library.fullname()) storage.language = "python" self.assertTrue(storage.exists()) - self.assertEqual(storage.json.load(), LibrariesAPIBase.UPDATE) + self.assertEqual( + json.loads(storage.json.load()), json.loads(LibrariesAPIBase.UPDATE) + ) + + def test_successful_update_code_only(self): + self.login_jackdoe() + + code = b"import numpy as np" + + response = self.client.put( + self.url, json.dumps({"code": code}), content_type="application/json" + ) + + self.checkResponse(response, 200, content_type="application/json") + + library = Library.objects.get( + author__username="jackdoe", name="personal", version=1 + ) + + storage = beat.core.library.Storage(settings.PREFIX, library.fullname()) + storage.language = "python" + self.assertTrue(storage.exists()) + self.assertEqual(storage.code.load(), code) class LibraryRetrieval(LibrariesAPIBase): def test_no_retrieval_of_confidential_library_for_anonymous_user(self): - url = reverse("api_libraries:object", args=["johndoe", "forked_algo"]) + url = reverse("api_libraries:object", args=["johndoe", "forked_algo", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_fail_to_retrieve_with_invalid_username(self): self.login_johndoe() - url = reverse("api_libraries:object", args=["unknown", "forked_algo"]) + url = reverse("api_libraries:object", args=["unknown", "forked_algo", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_fail_to_retrieve_with_invalid_library_name(self): self.login_johndoe() - url = reverse("api_libraries:object", args=["johndoe", "unknown"]) + url = reverse("api_libraries:object", args=["johndoe", "unknown", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_no_retrieval_of_confidential_library(self): self.login_jackdoe() - url = reverse("api_libraries:object", args=["johndoe", "forked_algo"]) + url = reverse("api_libraries:object", args=["johndoe", "forked_algo", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_successful_retrieval_of_public_library_for_anonymous_user(self): - url = reverse("api_libraries:object", args=["jackdoe", "public_for_all"]) + url = reverse("api_libraries:object", args=["jackdoe", "public_for_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -933,7 +1013,7 @@ class LibraryRetrieval(LibrariesAPIBase): self.assertEqual(data["code"].encode("utf-8"), LibrariesAPIBase.CODE) def test_successful_retrieval_of_usable_library_for_anonymous_user(self): - url = reverse("api_libraries:object", args=["jackdoe", "usable_by_all"]) + url = reverse("api_libraries:object", args=["jackdoe", "usable_by_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -946,7 +1026,7 @@ class LibraryRetrieval(LibrariesAPIBase): def test_successful_retrieval_of_public_library(self): self.login_johndoe() - url = reverse("api_libraries:object", args=["jackdoe", "public_for_all"]) + url = reverse("api_libraries:object", args=["jackdoe", "public_for_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -961,7 +1041,7 @@ class LibraryRetrieval(LibrariesAPIBase): def test_successful_retrieval_of_usable_library(self): self.login_johndoe() - url = reverse("api_libraries:object", args=["jackdoe", "usable_by_one_user"]) + url = reverse("api_libraries:object", args=["jackdoe", "usable_by_one_user", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -974,7 +1054,7 @@ class LibraryRetrieval(LibrariesAPIBase): def test_successful_retrieval_of_publicly_usable_library(self): self.login_johndoe() - url = reverse("api_libraries:object", args=["jackdoe", "usable_by_all"]) + url = reverse("api_libraries:object", args=["jackdoe", "usable_by_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -987,7 +1067,9 @@ class LibraryRetrieval(LibrariesAPIBase): def test_successful_retrieval_of_confidential_library(self): self.login_johndoe() - url = reverse("api_libraries:object", args=["jackdoe", "public_for_one_user"]) + url = reverse( + "api_libraries:object", args=["jackdoe", "public_for_one_user", 1] + ) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -1002,7 +1084,7 @@ class LibraryRetrieval(LibrariesAPIBase): def test_successful_retrieval_of_own_public_library(self): self.login_jackdoe() - url = reverse("api_libraries:object", args=["jackdoe", "public_for_all"]) + url = reverse("api_libraries:object", args=["jackdoe", "public_for_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -1022,7 +1104,7 @@ class LibraryRetrieval(LibrariesAPIBase): def test_successful_retrieval_of_own_confidential_library(self): self.login_jackdoe() - url = reverse("api_libraries:object", args=["jackdoe", "usable_by_one_user"]) + url = reverse("api_libraries:object", args=["jackdoe", "usable_by_one_user", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -1042,7 +1124,7 @@ class LibraryRetrieval(LibrariesAPIBase): def test_successful_retrieval_of_own_usable_library(self): self.login_jackdoe() - url = reverse("api_libraries:object", args=["jackdoe", "usable_by_all"]) + url = reverse("api_libraries:object", args=["jackdoe", "usable_by_all", 1]) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") @@ -1062,7 +1144,9 @@ class LibraryRetrieval(LibrariesAPIBase): def test_successful_retrieval_of_own_shared_library(self): self.login_jackdoe() - url = reverse("api_libraries:object", args=["jackdoe", "public_for_one_user"]) + url = reverse( + "api_libraries:object", args=["jackdoe", "public_for_one_user", 1] + ) response = self.client.get(url) data = self.checkResponse(response, 200, content_type="application/json") diff --git a/beat/web/plotters/api_urls.py b/beat/web/plotters/api_urls.py index 4f954893948159b685b3270b1d3b813b294cc9a1..8a2c447aaa1ca5ddcd4854c42f4ba371708f4afc 100644 --- a/beat/web/plotters/api_urls.py +++ b/beat/web/plotters/api_urls.py @@ -29,34 +29,53 @@ from django.conf.urls import url from . import api -urlpatterns = [ - url(r'^$', api.ListPlotterView.as_view(), name='all'), - url(r'^format/(?P<author_name>\w+)/(?P<dataformat_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/$', api.ListFormatPlotterView.as_view(), name='object'), - url(r'^plotterparameters/(?P<author_name>\w+)/(?P<object_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/share/$', +urlpatterns = [ + url(r"^$", api.ListPlotterView.as_view(), name="all"), + url( + r"^format/(?P<author_name>\w+)/(?P<dataformat_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/$", + api.ListFormatPlotterView.as_view(), + name="object", + ), + url( + r"^plotterparameters/(?P<author_name>\w+)/(?P<object_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/share/$", api.SharePlotterParameterView.as_view(), - name='share', - ), - - url(r'^plotterparameters/(?P<author_name>\w+)/(?P<object_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/$', api.RetrieveUpdateDestroyPlotterParametersView.as_view(), name='view'), - url(r'^plotterparameters/(?P<author_name>\w+)/$', api.ListPlotterParametersView.as_view(), name='view'), - url(r'^plotterparameters/$', api.ListPlotterParameterView.as_view(), name='all_plotterparameter'), - url(r'^plotterparameter/(?P<author_name>\w+)/(?P<dataformat_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/$', api.ListPlotterParameterView.as_view(), name='plotterparameter'), - url(r'^defaultplotters/$', api.ListDefaultPlotterView.as_view(), name='all_defaultplotters'), - - url(r'^check_name/$', - api.CheckPlotterNameView.as_view(), - name='check_name', - ), - - - url(r'^(?P<author_name>\w+)/$', + name="share", + ), + url( + r"^plotterparameters/(?P<author_name>\w+)/(?P<object_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/$", + api.RetrieveUpdateDestroyPlotterParametersView.as_view(), + name="view", + ), + url( + r"^plotterparameters/(?P<author_name>\w+)/$", + api.ListPlotterParametersView.as_view(), + name="view", + ), + url( + r"^plotterparameters/$", + api.ListPlotterParameterView.as_view(), + name="all_plotterparameter", + ), + url( + r"^plotterparameter/(?P<author_name>\w+)/(?P<dataformat_name>[a-zA-Z0-9_\-]+)/(?P<version>\d+)/$", + api.ListPlotterParameterView.as_view(), + name="plotterparameter", + ), + url( + r"^defaultplotters/$", + api.ListDefaultPlotterView.as_view(), + name="all_defaultplotters", + ), + url(r"^check_name/$", api.CheckPlotterNameView.as_view(), name="check_name"), + url( + r"^(?P<author_name>\w+)/$", api.ListCreatePlottersView.as_view(), - name='list_create', - ), - - url(r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/$', + name="list_create", + ), + url( + r"^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/$", api.RetrieveUpdateDestroyPlottersView.as_view(), - name='object', - ), + name="object", + ), ] diff --git a/beat/web/plotters/serializers.py b/beat/web/plotters/serializers.py index 6a8d5cb260f16ad0e16a29c6febd1b2482d31dc8..cd8431f834befa9f26069f971ea50971378231b1 100644 --- a/beat/web/plotters/serializers.py +++ b/beat/web/plotters/serializers.py @@ -25,19 +25,27 @@ # # ############################################################################### -from ..common.serializers import DynamicFieldsSerializer, ContributionSerializer, ContributionCreationSerializer -from .models import Plotter, PlotterParameter, DefaultPlotter + +import simplejson as json + +import beat.core.plotter +import beat.core.plotterparameter + +from django.utils.encoding import smart_text from rest_framework import serializers +from ..common import fields as beat_fields +from ..common.serializers import ( + DynamicFieldsSerializer, + ContributionSerializer, + ContributionCreationSerializer, +) + from ..code.serializers import CodeSerializer, CodeCreationSerializer from ..libraries.serializers import LibraryReferenceSerializer -from ..dataformats.serializers import ReferencedDataFormatSerializer -from django.utils.encoding import smart_str -from django.utils.encoding import smart_text +from .models import Plotter, PlotterParameter, DefaultPlotter -import beat.core.plotter -import simplejson as json class PlotterSerializer(ContributionSerializer): @@ -46,141 +54,175 @@ class PlotterSerializer(ContributionSerializer): class Meta(ContributionSerializer.Meta): model = Plotter default_fields = [ - #'name', 'dataformat', - 'id', 'accessibility', 'modifiable', 'deletable', 'is_owner', 'name', 'dataformat', 'fork_of', 'last_version', 'previous_version', 'short_description', 'description', 'version', 'creation_date', 'data', 'sample_data', 'declaration', + # 'name', 'dataformat', + "id", + "accessibility", + "modifiable", + "deletable", + "is_owner", + "name", + "dataformat", + "fork_of", + "last_version", + "previous_version", + "short_description", + "description", + "version", + "creation_date", + "data", + "sample_data", + "declaration", ] -class PlotterParameterSerializer(ContributionSerializer): +class PlotterParameterSerializer(ContributionSerializer): class Meta(ContributionSerializer.Meta): model = PlotterParameter exclude = [] - default_fields = [ - 'name', 'plotter', - ] + default_fields = ["name", "plotter"] + class DefaultPlotterSerializer(DynamicFieldsSerializer): dataformat = serializers.CharField(source="dataformat.fullname") - plotter = serializers.CharField(source="plotter.fullname") - parameter = serializers.CharField(source="parameter.fullname") + plotter = serializers.CharField(source="plotter.fullname") + parameter = serializers.CharField(source="parameter.fullname") class Meta(DynamicFieldsSerializer.Meta): model = DefaultPlotter exclude = [] - default_fields = [ - 'dataformat', 'plotter', 'parameter', - ] + default_fields = ["dataformat", "plotter", "parameter"] -#---------------------------------------------------------- +# ---------------------------------------------------------- class PlotterCreationSerializer(CodeCreationSerializer): class Meta(CodeCreationSerializer.Meta): model = Plotter - beat_core_class = beat.core.plotter + beat_core_class = beat.core.plotter.Plotter -#---------------------------------------------------------- +# ---------------------------------------------------------- class PlotterAllSerializer(CodeSerializer): - dataformat = serializers.SerializerMethodField() - declaration_file = serializers.SerializerMethodField() - description_file = serializers.SerializerMethodField() - source_code_file = serializers.SerializerMethodField() - referenced_libraries = LibraryReferenceSerializer(many=True) + dataformat = serializers.SerializerMethodField() + declaration_file = serializers.SerializerMethodField() + description_file = serializers.SerializerMethodField() + source_code_file = serializers.SerializerMethodField() + referenced_libraries = LibraryReferenceSerializer(many=True) class Meta(CodeSerializer.Meta): model = Plotter -#---------------------------------------------------------- +# ---------------------------------------------------------- class FullPlotterSerializer(PlotterAllSerializer): - class Meta(PlotterAllSerializer.Meta): - default_fields = PlotterAllSerializer.Meta.default_fields + PlotterAllSerializer.Meta.extra_fields + default_fields = ( + PlotterAllSerializer.Meta.default_fields + + PlotterAllSerializer.Meta.extra_fields + ) exclude = [] -#---------------------------------------------------------- +# ---------------------------------------------------------- + class PlotterParameterCreationFailedException(Exception): pass + class PlotterParameterCreationSerializer(ContributionCreationSerializer): - data = serializers.JSONField(required=False) + data = beat_fields.JSONField(required=False) class Meta(ContributionCreationSerializer.Meta): model = PlotterParameter - fields = ['name', 'plotter', 'data', 'version', 'previous_version', 'short_description', 'description', 'fork_of'] - #beat_core_class = beat.core.PlotterParameter + fields = [ + "name", + "plotter", + "data", + "version", + "previous_version", + "short_description", + "description", + "fork_of", + ] + beat_core_class = beat.core.plotterparameter.Plotterparameter def create(self, validated_data): plotterparameter = None if "name" not in validated_data: - raise serializers.ValidationError('No name provided') + raise serializers.ValidationError("No name provided") - try: - plotterparameter = PlotterParameter.objects.get(author=self.context['request'].user, name=validated_data['name'], version=validated_data['version']) - except: - pass - - if plotterparameter is not None: - raise serializers.ValidationError('A plotterparameter with this name already exists') + if PlotterParameter.objects.filter( + author=self.context["request"].user, + name=validated_data["name"], + version=validated_data["version"], + ).exists(): + raise serializers.ValidationError( + "A plotterparameter with this name already exists" + ) if "plotter" not in self.data: - raise serializers.ValidationError('No plotter provided') + raise serializers.ValidationError("No plotter provided") - plotter = None try: - plotter = Plotter.objects.get(id=self.data['plotter']) - except: - pass - - if plotter is None: - raise serializers.ValidationError('Required plotter does not exist') + Plotter.objects.get(id=self.data["plotter"]) + except Exception: + raise serializers.ValidationError("Required plotter does not exist") if "data" not in validated_data: - validated_data['data'] = {} + validated_data["data"] = {} - #Only create new version for latest version + # Only create new version for latest version if "previous_version" in validated_data: - if validated_data['previous_version'].version < validated_data['version'] - 1: - raise serializers.ValidationError('A new version for this plotterparameter version already exist') - - #add description/short_description to new version - validated_data['short_description'] = validated_data['previous_version'].short_description - validated_data['description'] = validated_data['previous_version'].description - - #Create fork + if ( + validated_data["previous_version"].version + < validated_data["version"] - 1 + ): + raise serializers.ValidationError( + "A new version for this plotterparameter version already exist" + ) + + # add description/short_description to new version + validated_data["short_description"] = validated_data[ + "previous_version" + ].short_description + validated_data["description"] = validated_data[ + "previous_version" + ].description + + # Create fork if "fork_of" in validated_data: - #add description/short_description to new version - validated_data['short_description'] = validated_data['fork_of'].short_description - validated_data['description'] = validated_data['fork_of'].description - + # add description/short_description to new version + validated_data["short_description"] = validated_data[ + "fork_of" + ].short_description + validated_data["description"] = validated_data["fork_of"].description plotterparameter = PlotterParameter.objects.create(**validated_data) if plotterparameter is None: raise PlotterParameterCreationFailedException() return plotterparameter -#---------------------------------------------------------- + +# ---------------------------------------------------------- class PlotterParameterAllSerializer(ContributionSerializer): - data = serializers.SerializerMethodField() - description = serializers.SerializerMethodField() - plotter = serializers.SerializerMethodField() + data = serializers.SerializerMethodField() + description = serializers.SerializerMethodField() + plotter = serializers.SerializerMethodField() class Meta(ContributionSerializer.Meta): model = PlotterParameter - #def get_referencing_experiments(self, obj): + # def get_referencing_experiments(self, obj): # user = self.context.get('user') # experiments = obj.experiments.for_user(user, True).order_by('-creation_date') @@ -193,25 +235,48 @@ class PlotterParameterAllSerializer(ContributionSerializer): # return ordered_result - #def get_new_experiment_url(self, obj): + # def get_new_experiment_url(self, obj): # return obj.get_new_experiment_url() -#---------------------------------------------------------- + +# ---------------------------------------------------------- + class FullPlotterParameterSerializer(PlotterParameterAllSerializer): plotters = serializers.SerializerMethodField() class Meta(PlotterParameterAllSerializer.Meta): - #exclude = ['declaration'] + # exclude = ['declaration'] exclude = [] - #default_fields = PlotterParameterAllSerializer.Meta.default_fields + PlotterParameterAllSerializer.Meta.extra_fields - default_fields = ['id', 'accessibility', 'modifiable', 'deletable', 'is_owner', 'name', 'fork_of', 'last_version', 'previous_version', 'short_description', 'description', 'version', 'creation_date', 'data', 'plotter', 'plotters'] + # default_fields = PlotterParameterAllSerializer.Meta.default_fields + PlotterParameterAllSerializer.Meta.extra_fields + default_fields = [ + "id", + "accessibility", + "modifiable", + "deletable", + "is_owner", + "name", + "fork_of", + "last_version", + "previous_version", + "short_description", + "description", + "version", + "creation_date", + "data", + "plotter", + "plotters", + ] def get_description(self, obj): - return smart_text(obj.description, encoding='utf-8', strings_only=False, errors='strict') + return smart_text( + obj.description, encoding="utf-8", strings_only=False, errors="strict" + ) def get_short_description(self, obj): - return smart_text(obj.short_description, encoding='utf-8', strings_only=False, errors='strict') + return smart_text( + obj.short_description, encoding="utf-8", strings_only=False, errors="strict" + ) def get_data(self, obj): return json.loads(obj.data) @@ -226,24 +291,44 @@ class FullPlotterParameterSerializer(PlotterParameterAllSerializer): all_plotters = Plotter.objects.all() results = {} for plotter in all_plotters.iterator(): - serializer = FullPlotterSerializer(plotter, context=self.context, fields=['id', 'accessibility', 'modifiable', 'deletable', 'is_owner', 'name', 'fork_of', 'last_version', 'previous_version', 'short_description', 'description', 'version', 'creation_date', 'data', 'sample_data', 'declaration']) + serializer = FullPlotterSerializer( + plotter, + context=self.context, + fields=[ + "id", + "accessibility", + "modifiable", + "deletable", + "is_owner", + "name", + "fork_of", + "last_version", + "previous_version", + "short_description", + "description", + "version", + "creation_date", + "data", + "sample_data", + "declaration", + ], + ) results[plotter.fullname()] = serializer.data return results - - #def get_plotter(self, obj): + # def get_plotter(self, obj): # return obj.author.username - #"accessibility": "public", - #"modifiable": true, - #"deletable": true, - #"is_owner": false, - #"name": "plot/bar/1", - #"fork_of": null, - #"last_version": true, - #"previous_version": null, - #"short_description": "Default parameters for bar plots", - #"description": "Raw content", - #"version": 1, - #"creation_date": "2015-09-03T16:55:47.620000", + # "accessibility": "public", + # "modifiable": true, + # "deletable": true, + # "is_owner": false, + # "name": "plot/bar/1", + # "fork_of": null, + # "last_version": true, + # "previous_version": null, + # "short_description": "Default parameters for bar plots", + # "description": "Raw content", + # "version": 1, + # "creation_date": "2015-09-03T16:55:47.620000", diff --git a/beat/web/reports/api_urls.py b/beat/web/reports/api_urls.py index 1d0440b7a6e7d04424699196c42d5bca94ff4ae6..41b1d78a2058bc03ca747ccfee9ce5da9b2418c8 100644 --- a/beat/web/reports/api_urls.py +++ b/beat/web/reports/api_urls.py @@ -25,86 +25,61 @@ # # ############################################################################### -from django.conf.urls import * +from django.conf.urls import url + from . import api urlpatterns = [ url( - r'^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/rst/$', + r"^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/rst/$", api.ReportRSTCompileView.as_view(), - name='rst_compiler' + name="rst_compiler", ), - url( - r'^(?P<number>\d+)/rst/$', + r"^(?P<number>\d+)/rst/$", api.ReportRSTCompileAnonView.as_view(), - name='rst_compiler' - ), - - url( - r'^$', - api.ReportListView.as_view(), - name='all' - ), - - url( - r'^(?P<number>\d+)/$', - api.ReportDetailView.as_view(), - name='object_report' - ), - - url( - r'^(?P<number>\d+)/results/$', - api.ReportResultsView.as_view(), - name='results' + name="rst_compiler", ), - + url(r"^$", api.ReportListView.as_view(), name="all"), + url(r"^(?P<number>\d+)/$", api.ReportDetailView.as_view(), name="object_report"), + url(r"^(?P<number>\d+)/results/$", api.ReportResultsView.as_view(), name="results"), url( - r'^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/results_author/$', + r"^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/results_author/$", api.ReportResultsAllExperimentsView.as_view(), - name='results' + name="results", ), - url( - r'^(?P<owner_name>\w+)/$', - api.UserReportListView.as_view(), - name='list_create' + r"^(?P<owner_name>\w+)/$", api.UserReportListView.as_view(), name="list_create" ), - url( - r'^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/add/$', + r"^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/add/$", api.ReportAddExperimentsView.as_view(), - name='add_experiments' + name="add_experiments", ), - url( - r'^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/remove/$', + r"^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/remove/$", api.ReportRemoveExperimentsView.as_view(), - name='remove_experiments' + name="remove_experiments", ), - url( - r'^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/lock/$', + r"^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/lock/$", api.LockReportView.as_view(), - name='lock' + name="lock", ), - url( - r'^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/publish/$', + r"^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/publish/$", api.PublishReportView.as_view(), - name='publish' + name="publish", ), - url( - r'^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/algorithms/$', + r"^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/algorithms/$", api.ReportAlgorithmsView.as_view(), - name='algorithms' + name="algorithms", ), - url( - r'^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/$', + r"^(?P<owner_name>\w+)/(?P<report_name>[\w\W]+)/$", api.ReportDetailView.as_view(), - name='object' + name="object", ), ] diff --git a/beat/web/reports/serializers.py b/beat/web/reports/serializers.py index 583667fcfdd6f0b476744aa615bc11fd42b197e9..e954479daecf4ad621b5fe97311ccb6a32bee5fb 100644 --- a/beat/web/reports/serializers.py +++ b/beat/web/reports/serializers.py @@ -25,21 +25,19 @@ # # ############################################################################### -from django.contrib.auth.models import User, AnonymousUser from rest_framework import serializers from .models import Report from ..common.models import Contribution -from ..common.fields import JSONSerializerField -from ..experiments.models import Experiment -from ..ui.templatetags.markup import restructuredtext from ..common.utils import validate_restructuredtext +from ..common import fields as beat_fields +from ..ui.templatetags.markup import restructuredtext import simplejson as json -#---------------------------------------------------------- +# ---------------------------------------------------------- class BasicReportSerializer(serializers.ModelSerializer): @@ -58,10 +56,10 @@ class BasicReportSerializer(serializers.ModelSerializer): class Meta: model = Report - fields = ['short_description', 'status', 'is_owner', 'accessibility'] + fields = ["short_description", "status", "is_owner", "accessibility"] def get_name(self, obj): - return '{}/{}'.format(obj.author.username, obj.name) + return "{}/{}".format(obj.author.username, obj.name) def get_author(self, obj): return obj.author.username @@ -77,25 +75,26 @@ class BasicReportSerializer(serializers.ModelSerializer): def get_status(self, obj): if obj.status == Report.EDITABLE: - return 'editable' + return "editable" elif obj.status == Report.LOCKED: - return 'locked' + return "locked" elif obj.status == Report.PUBLISHED: - return 'published' + return "published" else: - return 'editable' + return "editable" def get_experiments(self, obj): return map(lambda x: x.fullname(), obj.experiments.iterator()) def get_experiment_access_map(self, obj): - user = self.context['request'].user - access_map = list(map(lambda x: x.accessibility_for(user)[0], - obj.experiments.iterator())) + user = self.context["request"].user + access_map = list( + map(lambda x: x.accessibility_for(user)[0], obj.experiments.iterator()) + ) return access_map def get_analyzers_access_map(self, obj): - user = self.context['request'].user + user = self.context["request"].user access_map = list() for exp in obj.experiments.iterator(): # find analyzer @@ -114,99 +113,136 @@ class BasicReportSerializer(serializers.ModelSerializer): d = obj.description if len(d) > 0: return restructuredtext(d) - return '' + return "" def get_add_url(self, obj): return obj.get_api_add_url() -#---------------------------------------------------------- +# ---------------------------------------------------------- class SimpleReportSerializer(BasicReportSerializer): - class Meta(BasicReportSerializer.Meta): - fields = ['name', 'number', 'short_description', 'is_owner', 'author','status', 'description', 'creation_date', 'html_description', 'add_url', 'content'] + fields = [ + "name", + "number", + "short_description", + "is_owner", + "author", + "status", + "description", + "creation_date", + "html_description", + "add_url", + "content", + ] -#---------------------------------------------------------- +# ---------------------------------------------------------- class FullReportSerializer(BasicReportSerializer): - - class Meta(BasicReportSerializer.Meta): - fields = ['name', 'number', 'short_description', 'description', 'is_owner', 'author','status', 'creation_date', 'publication_date', 'experiments', 'analyzer', 'content', 'html_description', 'experiment_access_map', 'analyzers_access_map'] - - -#---------------------------------------------------------- + fields = [ + "name", + "number", + "short_description", + "description", + "is_owner", + "author", + "status", + "creation_date", + "publication_date", + "experiments", + "analyzer", + "content", + "html_description", + "experiment_access_map", + "analyzers_access_map", + ] + + +# ---------------------------------------------------------- class CreatedReportSerializer(BasicReportSerializer): - class Meta(BasicReportSerializer.Meta): - fields = ['name', 'short_description', 'experiments'] + fields = ["name", "short_description", "experiments"] -#---------------------------------------------------------- +# ---------------------------------------------------------- class UpdatedReportSerializer(BasicReportSerializer): - class Meta(BasicReportSerializer.Meta): - fields = ['short_description', 'description', 'experiments', 'content'] + fields = ["short_description", "description", "experiments", "content"] -#---------------------------------------------------------- +# ---------------------------------------------------------- class ReportCreationFailedException(Exception): pass -#---------------------------------------------------------- +# ---------------------------------------------------------- class ReportCreationSerializer(serializers.ModelSerializer): - content = JSONSerializerField() + content = beat_fields.JSONField() experiments = serializers.ListField(child=serializers.CharField()) class Meta: model = Report - fields = ['name', 'short_description', 'description', 'content', 'experiments'] + fields = ["name", "short_description", "description", "content", "experiments"] def create(self, validated_data): report = None - if 'name' not in validated_data: - raise serializers.ValidationError('No name provided') + if "name" not in validated_data: + raise serializers.ValidationError("No name provided") try: - report = Report.objects.get(author=self.context['request'].user, name=validated_data['name']) - except: + report = Report.objects.get( + author=self.context["request"].user, name=validated_data["name"] + ) + except Report.DoesNotExist: pass if report is not None: - raise serializers.ValidationError('A report with this name already exists') + raise serializers.ValidationError("A report with this name already exists") - (report, self.details) = self.Meta.model.objects.create_object(author=self.context['request'].user, **validated_data) + (report, self.details) = self.Meta.model.objects.create_object( + author=self.context["request"].user, **validated_data + ) if report is None: raise ReportCreationFailedException() return report def update(self, instance, validated_data): - instance.short_description = validated_data.get('short_description', instance.short_description) - instance.description = validated_data.get('description', instance.description) - instance.content = validated_data.get('content', instance.content) - instance.publication_date = validated_data.get('publication_date', instance.publication_date) + instance.short_description = validated_data.get( + "short_description", instance.short_description + ) + instance.description = validated_data.get("description", instance.description) + instance.content = validated_data.get("content", instance.content) + instance.publication_date = validated_data.get( + "publication_date", instance.publication_date + ) instance.save() - if 'experiments' in validated_data: - experiments = validated_data.get('experiments', []) - current_experiments = map(lambda x: x.fullname(), instance.experiments.iterator()) + if "experiments" in validated_data: + experiments = validated_data.get("experiments", []) + current_experiments = map( + lambda x: x.fullname(), instance.experiments.iterator() + ) - experiments_to_remove = filter(lambda x: x not in experiments, current_experiments) - experiments_to_add = filter(lambda x: x not in current_experiments, experiments) + experiments_to_remove = filter( + lambda x: x not in experiments, current_experiments + ) + experiments_to_add = filter( + lambda x: x not in current_experiments, experiments + ) instance.remove_experiments(experiments_to_remove) @@ -234,13 +270,12 @@ class ReportCreationSerializer(serializers.ModelSerializer): return serializer.data -#---------------------------------------------------------- +# ---------------------------------------------------------- class ReportUpdateSerializer(ReportCreationSerializer): - class Meta(ReportCreationSerializer.Meta): - fields = ['short_description', 'description', 'content', 'experiments'] + fields = ["short_description", "description", "content", "experiments"] def to_representation(self, obj): serializer = UpdatedReportSerializer(obj) diff --git a/beat/web/search/api.py b/beat/web/search/api.py index 3f1c6f5039ea654863e82d43840d80fb85201437..5953f2b2caf2d1c45029a67922934b95ac6428bf 100644 --- a/beat/web/search/api.py +++ b/beat/web/search/api.py @@ -25,6 +25,11 @@ # # ############################################################################### + +import simplejson as json + +from functools import reduce + from django.conf import settings from django.contrib.auth.models import User from django.db.models import Q @@ -33,37 +38,36 @@ from django.utils import six from rest_framework.response import Response from rest_framework.views import APIView -from rest_framework import permissions +from rest_framework import permissions as drf_permissions from rest_framework import generics from rest_framework import status -from .utils import apply_filter -from .utils import FilterGenerator -from .utils import OR - from ..algorithms.models import Algorithm from ..databases.models import Database from ..dataformats.models import DataFormat from ..experiments.models import Experiment from ..toolchains.models import Toolchain + from ..common.models import Shareable -from ..common.mixins import IsAuthorOrReadOnlyMixin from ..common.api import ShareView from ..common.utils import ensure_html - -from .models import Search - from ..common.responses import BadRequestResponse from ..common.mixins import CommonContextMixin, SerializerFieldsMixin +from ..common.utils import py3_cmp +from ..common import permissions as beat_permissions from ..ui.templatetags.gravatar import gravatar_hash -from .serializers import SearchResultSerializer, SearchSerializer, SearchWriteSerializer +from .utils import apply_filter +from .utils import FilterGenerator +from .utils import OR -import simplejson as json +from .models import Search + +from .serializers import SearchResultSerializer, SearchSerializer, SearchWriteSerializer -#------------------------------------------------ +# ------------------------------------------------ class SearchView(APIView): @@ -81,19 +85,24 @@ class SearchView(APIView): 'order-by' """ - permission_classes = [permissions.AllowAny] + permission_classes = [drf_permissions.AllowAny] - FILTER_IEXACT = 0 - FILTER_ICONTAINS = 1 + FILTER_IEXACT = 0 + FILTER_ICONTAINS = 1 FILTER_ISTARTSWITH = 2 - FILTER_IENDSWITH = 3 - + FILTER_IENDSWITH = 3 @staticmethod def build_name_and_description_query(keywords): - return reduce(lambda a, b: a & b, map(lambda keyword: Q(name__icontains=keyword) | - Q(short_description__icontains=keyword), keywords)) + return reduce( + lambda a, b: a & b, + map( + lambda keyword: Q(name__icontains=keyword) + | Q(short_description__icontains=keyword), + keywords, + ), + ) def post(self, request): data = request.data @@ -102,225 +111,289 @@ class SearchView(APIView): filters = None display_settings = None - if 'query' in data: - if not(isinstance(data['query'], six.string_types)) or \ - (len(data['query']) == 0): - return BadRequestResponse('Invalid query data') + if "query" in data: + if not (isinstance(data["query"], six.string_types)) or ( + len(data["query"]) == 0 + ): + return BadRequestResponse("Invalid query data") - query = data['query'] + query = data["query"] else: - if not(isinstance(data['filters'], list)) or (len(data['filters']) == 0): - return BadRequestResponse('Invalid filter data') - - filters = data['filters'] + if not (isinstance(data["filters"], list)) or (len(data["filters"]) == 0): + return BadRequestResponse("Invalid filter data") - if 'settings' in data: - display_settings = data['settings'] + filters = data["filters"] + if "settings" in data: + display_settings = data["settings"] # Process the query - scope_database = None - scope_type = None + scope_database = None + scope_type = None scope_toolchain = None scope_algorithm = None - scope_analyzer = None - keywords = [] + scope_analyzer = None + keywords = [] if filters is None: - for keyword in map(lambda x: x.strip(), query.split(' ')): - offset = keyword.find(':') + for keyword in map(lambda x: x.strip(), query.split(" ")): + offset = keyword.find(":") if offset != -1: command = keyword[:offset] - keyword = keyword[offset+1:] - - if command in ['db', 'database']: - scope_database = keyword.split(',') - elif command in ['tc', 'toolchain']: - scope_toolchain = keyword.split(',') - elif command in ['algo', 'algorithm']: - scope_algorithm = keyword.split(',') - elif command == 'analyzer': - scope_analyzer = keyword.split(',') - elif command == 'type': - if keyword in ['results', 'toolchains', 'algorithms', 'analyzers', - 'dataformats', 'databases', 'users']: + keyword = keyword[offset + 1 :] + + if command in ["db", "database"]: + scope_database = keyword.split(",") + elif command in ["tc", "toolchain"]: + scope_toolchain = keyword.split(",") + elif command in ["algo", "algorithm"]: + scope_algorithm = keyword.split(",") + elif command == "analyzer": + scope_analyzer = keyword.split(",") + elif command == "type": + if keyword in [ + "results", + "toolchains", + "algorithms", + "analyzers", + "dataformats", + "databases", + "users", + ]: scope_type = keyword else: keywords.append(keyword) - if (scope_type is None) or (scope_type == 'results'): + if (scope_type is None) or (scope_type == "results"): filters = [] if scope_toolchain is not None: if len(scope_toolchain) > 1: - filters.append({ - 'context': 'toolchain', - 'name': None, - 'operator': 'contains-any-of', - 'value': scope_toolchain, - }) + filters.append( + { + "context": "toolchain", + "name": None, + "operator": "contains-any-of", + "value": scope_toolchain, + } + ) elif len(scope_toolchain) == 1: - filters.append({ - 'context': 'toolchain', - 'name': None, - 'operator': 'contains', - 'value': scope_toolchain[0], - }) + filters.append( + { + "context": "toolchain", + "name": None, + "operator": "contains", + "value": scope_toolchain[0], + } + ) if scope_algorithm is not None: if len(scope_algorithm) > 1: - filters.append({ - 'context': 'algorithm', - 'name': None, - 'operator': 'contains-any-of', - 'value': scope_algorithm, - }) + filters.append( + { + "context": "algorithm", + "name": None, + "operator": "contains-any-of", + "value": scope_algorithm, + } + ) elif len(scope_algorithm) == 1: - filters.append({ - 'context': 'algorithm', - 'name': None, - 'operator': 'contains', - 'value': scope_algorithm[0], - }) + filters.append( + { + "context": "algorithm", + "name": None, + "operator": "contains", + "value": scope_algorithm[0], + } + ) if scope_analyzer is not None: if len(scope_analyzer) > 1: - filters.append({ - 'context': 'analyzer', - 'name': None, - 'operator': 'contains-any-of', - 'value': scope_analyzer, - }) + filters.append( + { + "context": "analyzer", + "name": None, + "operator": "contains-any-of", + "value": scope_analyzer, + } + ) elif len(scope_analyzer) == 1: - filters.append({ - 'context': 'analyzer', - 'name': None, - 'operator': 'contains', - 'value': scope_analyzer[0], - }) + filters.append( + { + "context": "analyzer", + "name": None, + "operator": "contains", + "value": scope_analyzer[0], + } + ) if scope_database is not None: if len(scope_database) > 1: - filters.append({ - 'context': 'database-name', - 'name': None, - 'operator': 'contains-any-of', - 'value': scope_database, - }) + filters.append( + { + "context": "database-name", + "name": None, + "operator": "contains-any-of", + "value": scope_database, + } + ) elif len(scope_database) == 1: - filters.append({ - 'context': 'database-name', - 'name': None, - 'operator': 'contains', - 'value': scope_database[0], - }) + filters.append( + { + "context": "database-name", + "name": None, + "operator": "contains", + "value": scope_database[0], + } + ) if len(keywords) > 0: - filters.append({ - 'context': 'any-field', - 'name': None, - 'operator': 'contains-any-of', - 'value': keywords, - }) + filters.append( + { + "context": "any-field", + "name": None, + "operator": "contains-any-of", + "value": keywords, + } + ) else: - scope_type = 'results' - + scope_type = "results" result = { - 'users': [], - 'toolchains': [], - 'algorithms': [], - 'analyzers': [], - 'dataformats': [], - 'databases': [], - 'results': [], - 'filters': filters, - 'settings': display_settings, - 'query': { - 'type': scope_type, - }, + "users": [], + "toolchains": [], + "algorithms": [], + "analyzers": [], + "dataformats": [], + "databases": [], + "results": [], + "filters": filters, + "settings": display_settings, + "query": {"type": scope_type}, } - # Search for users matching the query - if (scope_database is None) and (scope_toolchain is None) and \ - (scope_algorithm is None) and (scope_analyzer is None) and \ - ((scope_type is None) or (scope_type == 'users')): - result['users'] = [] + if ( + (scope_database is None) + and (scope_toolchain is None) + and (scope_algorithm is None) + and (scope_analyzer is None) + and ((scope_type is None) or (scope_type == "users")) + ): + result["users"] = [] if len(keywords) > 0: - q = reduce(lambda a, b: a & b, map(lambda keyword: Q(username__icontains=keyword), keywords)) - users = User.objects.filter(q).exclude(username__in=settings.ACCOUNTS_TO_EXCLUDE_FROM_SEARCH).order_by('username') - - result['users'] = map(lambda u: { 'username': u.username, - 'gravatar_hash': gravatar_hash(u.email), - 'join_date': u.date_joined.strftime('%b %d, %Y') - }, users) + q = reduce( + lambda a, b: a & b, + map(lambda keyword: Q(username__icontains=keyword), keywords), + ) + users = ( + User.objects.filter(q) + .exclude(username__in=settings.ACCOUNTS_TO_EXCLUDE_FROM_SEARCH) + .order_by("username") + ) + + result["users"] = map( + lambda u: { + "username": u.username, + "gravatar_hash": gravatar_hash(u.email), + "join_date": u.date_joined.strftime("%b %d, %Y"), + }, + users, + ) query = None if len(keywords) > 0: query = self.build_name_and_description_query(keywords) # Search for toolchains matching the query - if (scope_database is None) and (scope_algorithm is None) and \ - (scope_analyzer is None) and ((scope_type is None) or (scope_type == 'toolchains')): - result['toolchains'] = self._retrieve_contributions( - Toolchain.objects.for_user(request.user, True), - scope_toolchain, query + if ( + (scope_database is None) + and (scope_algorithm is None) + and (scope_analyzer is None) + and ((scope_type is None) or (scope_type == "toolchains")) + ): + result["toolchains"] = self._retrieve_contributions( + Toolchain.objects.for_user(request.user, True), scope_toolchain, query ) # Search for algorithms matching the query - if (scope_database is None) and (scope_toolchain is None) and \ - (scope_analyzer is None) and ((scope_type is None) or (scope_type == 'algorithms')): - result['algorithms'] = self._retrieve_contributions( - Algorithm.objects.for_user(request.user, True).filter(result_dataformat__isnull=True), - scope_algorithm, query + if ( + (scope_database is None) + and (scope_toolchain is None) + and (scope_analyzer is None) + and ((scope_type is None) or (scope_type == "algorithms")) + ): + result["algorithms"] = self._retrieve_contributions( + Algorithm.objects.for_user(request.user, True).filter( + result_dataformat__isnull=True + ), + scope_algorithm, + query, ) # Search for analyzers matching the query - if (scope_database is None) and (scope_toolchain is None) and \ - (scope_algorithm is None) and ((scope_type is None) or (scope_type == 'analyzers')): - result['analyzers'] = self._retrieve_contributions( - Algorithm.objects.for_user(request.user, True).filter(result_dataformat__isnull=False), - scope_analyzer, query + if ( + (scope_database is None) + and (scope_toolchain is None) + and (scope_algorithm is None) + and ((scope_type is None) or (scope_type == "analyzers")) + ): + result["analyzers"] = self._retrieve_contributions( + Algorithm.objects.for_user(request.user, True).filter( + result_dataformat__isnull=False + ), + scope_analyzer, + query, ) # Search for data formats matching the query - if (scope_database is None) and (scope_toolchain is None) and \ - (scope_algorithm is None) and (scope_analyzer is None) and \ - ((scope_type is None) or (scope_type == 'dataformats')): + if ( + (scope_database is None) + and (scope_toolchain is None) + and (scope_algorithm is None) + and (scope_analyzer is None) + and ((scope_type is None) or (scope_type == "dataformats")) + ): dataformats = DataFormat.objects.for_user(request.user, True) if query: dataformats = dataformats.filter(query) serializer = SearchResultSerializer(dataformats, many=True) - result['dataformats'] = serializer.data + result["dataformats"] = serializer.data # Search for databases matching the query - if (scope_toolchain is None) and (scope_algorithm is None) and \ - (scope_analyzer is None) and ((scope_type is None) or (scope_type == 'databases')): - result['databases'] = self._retrieve_databases(Database.objects.for_user(request.user, True), scope_database, query) + if ( + (scope_toolchain is None) + and (scope_algorithm is None) + and (scope_analyzer is None) + and ((scope_type is None) or (scope_type == "databases")) + ): + result["databases"] = self._retrieve_databases( + Database.objects.for_user(request.user, True), scope_database, query + ) # Search for experiments matching the query - if ((scope_type is None) or (scope_type == 'results')): - result['results'] = self._retrieve_experiment_results(request.user, filters) + if (scope_type is None) or (scope_type == "results"): + result["results"] = self._retrieve_experiment_results(request.user, filters) # Sort the results - result['toolchains'].sort(lambda x, y: cmp(x['name'], y['name'])) - result['algorithms'].sort(lambda x, y: cmp(x['name'], y['name'])) - result['analyzers'].sort(lambda x, y: cmp(x['name'], y['name'])) - result['dataformats'].sort(lambda x, y: cmp(x['name'], y['name'])) - result['databases'].sort(lambda x, y: cmp(x['name'], y['name'])) + result["toolchains"].sort(lambda x, y: py3_cmp(x["name"], y["name"])) + result["algorithms"].sort(lambda x, y: py3_cmp(x["name"], y["name"])) + result["analyzers"].sort(lambda x, y: py3_cmp(x["name"], y["name"])) + result["dataformats"].sort(lambda x, y: py3_cmp(x["name"], y["name"])) + result["databases"].sort(lambda x, y: py3_cmp(x["name"], y["name"])) return Response(result) - def _retrieve_contributions(self, queryset, scope, query): generator = FilterGenerator() scope_filters = [] if scope is not None: for contribution_name in scope: - scope_filters.append(generator.process_contribution_name(contribution_name)) + scope_filters.append( + generator.process_contribution_name(contribution_name) + ) if len(scope_filters): queryset = queryset.filter(OR(scope_filters)) @@ -331,7 +404,6 @@ class SearchView(APIView): serializer = SearchResultSerializer(queryset, many=True) return serializer.data - def _retrieve_databases(self, queryset, scope, query): generator = FilterGenerator() @@ -346,26 +418,26 @@ class SearchView(APIView): if query: queryset = queryset.filter(query) - queryset= queryset.distinct() + queryset = queryset.distinct() - serializer = SearchResultSerializer(queryset, many=True, name_field='name') + serializer = SearchResultSerializer(queryset, many=True, name_field="name") return serializer.data - def _retrieve_experiment_results(self, user, filters): results = { - 'experiments': [], - 'dataformats': {}, - 'common_analyzers': [], - 'common_protocols': [], + "experiments": [], + "dataformats": {}, + "common_analyzers": [], + "common_protocols": [], } if len(filters) == 0: return results - # Use the experiment filters - experiments = Experiment.objects.for_user(user, True).filter(status=Experiment.DONE) + experiments = Experiment.objects.for_user(user, True).filter( + status=Experiment.DONE + ) for filter_entry in filters: experiments = apply_filter(experiments, filter_entry) @@ -375,7 +447,6 @@ class SearchView(APIView): if experiments.count() == 0: return results - # Retrieve informations about each experiment and determine if there is at least # one common analyzer common_protocols = None @@ -383,77 +454,95 @@ class SearchView(APIView): for experiment in experiments.iterator(): experiment_entry = { - 'name': experiment.fullname(), - 'toolchain': experiment.toolchain.fullname(), - 'description': experiment.short_description, - 'public': (experiment.sharing == Shareable.PUBLIC), - 'attestation_number': None, - 'attestation_locked': False, - 'end_date': experiment.end_date, - 'protocols': list(set(map(lambda x: x.protocol.fullname(), experiment.referenced_datasets.iterator()))), - 'analyzers': [], + "name": experiment.fullname(), + "toolchain": experiment.toolchain.fullname(), + "description": experiment.short_description, + "public": (experiment.sharing == Shareable.PUBLIC), + "attestation_number": None, + "attestation_locked": False, + "end_date": experiment.end_date, + "protocols": list( + set( + map( + lambda x: x.protocol.fullname(), + experiment.referenced_datasets.iterator(), + ) + ) + ), + "analyzers": [], } if experiment.has_attestation(): - experiment_entry['attestation_number'] = experiment.attestation.number - experiment_entry['attestation_locked'] = experiment.attestation.locked + experiment_entry["attestation_number"] = experiment.attestation.number + experiment_entry["attestation_locked"] = experiment.attestation.locked experiment_analyzers = [] for analyzer_block in experiment.blocks.filter(analyzer=True).iterator(): analyzer_entry = { - 'name': analyzer_block.algorithm.fullname(), - 'block': analyzer_block.name, - 'results': {}, + "name": analyzer_block.algorithm.fullname(), + "block": analyzer_block.name, + "results": {}, } - experiment_entry['analyzers'].append(analyzer_entry) - experiment_analyzers.append(analyzer_entry['name']) + experiment_entry["analyzers"].append(analyzer_entry) + experiment_analyzers.append(analyzer_entry["name"]) - if analyzer_entry['name'] not in results['dataformats']: - results['dataformats'][analyzer_entry['name']] = json.loads(analyzer_block.algorithm.result_dataformat) + if analyzer_entry["name"] not in results["dataformats"]: + results["dataformats"][analyzer_entry["name"]] = json.loads( + analyzer_block.algorithm.result_dataformat + ) if common_analyzers is None: common_analyzers = experiment_analyzers elif len(common_analyzers) > 0: - common_analyzers = filter(lambda x: x in experiment_analyzers, common_analyzers) + common_analyzers = filter( + lambda x: x in experiment_analyzers, common_analyzers + ) if common_protocols is None: - common_protocols = experiment_entry['protocols'] + common_protocols = experiment_entry["protocols"] elif len(common_protocols) > 0: - common_protocols = filter(lambda x: x in experiment_entry['protocols'], common_protocols) - - results['experiments'].append(experiment_entry) + common_protocols = filter( + lambda x: x in experiment_entry["protocols"], common_protocols + ) - results['common_analyzers'] = common_analyzers - results['common_protocols'] = common_protocols + results["experiments"].append(experiment_entry) + results["common_analyzers"] = common_analyzers + results["common_protocols"] = common_protocols # No common analyzer found, don't retrieve any result if len(common_analyzers) == 0: - results['dataformats'] = {} + results["dataformats"] = {} return results - # Retrieve the results of each experiment for index, experiment in enumerate(experiments.iterator()): for analyzer_block in experiment.blocks.filter(analyzer=True).iterator(): - analyzer_entry = filter(lambda x: x['block'] == analyzer_block.name, - results['experiments'][index]['analyzers'])[0] + analyzer_entry = filter( + lambda x: x["block"] == analyzer_block.name, + results["experiments"][index]["analyzers"], + )[0] for analyzer_result in analyzer_block.results.iterator(): - analyzer_entry['results'][analyzer_result.name] = { - 'type': analyzer_result.type, - 'primary': analyzer_result.primary, - 'value': analyzer_result.value() + analyzer_entry["results"][analyzer_result.name] = { + "type": analyzer_result.type, + "primary": analyzer_result.primary, + "value": analyzer_result.value(), } return results -#------------------------------------------------ +# ------------------------------------------------ -class SearchSaveView(CommonContextMixin, SerializerFieldsMixin, generics.CreateAPIView, generics.UpdateAPIView): +class SearchSaveView( + CommonContextMixin, + SerializerFieldsMixin, + generics.CreateAPIView, + generics.UpdateAPIView, +): """ This endpoint allows to save and update a search query @@ -466,7 +555,7 @@ class SearchSaveView(CommonContextMixin, SerializerFieldsMixin, generics.CreateA """ model = Search - permission_classes = [permissions.IsAuthenticated] + permission_classes = [drf_permissions.IsAuthenticated] serializer_class = SearchWriteSerializer def build_results(self, request, search): @@ -474,12 +563,12 @@ class SearchSaveView(CommonContextMixin, SerializerFieldsMixin, generics.CreateA fields_to_return = self.get_serializer_fields(request) # Retrieve the description in HTML format - if 'html_description' in fields_to_return: + if "html_description" in fields_to_return: description = search.description if len(description) > 0: - result['html_description'] = ensure_html(description) + result["html_description"] = ensure_html(description) else: - result['html_description'] = '' + result["html_description"] = "" return result def post(self, request): @@ -487,73 +576,91 @@ class SearchSaveView(CommonContextMixin, SerializerFieldsMixin, generics.CreateA serializer.is_valid(raise_exception=True) search = serializer.save() result = self.build_results(request, search) - result['fullname'] = search.fullname() - result['url'] = search.get_absolute_url() + result["fullname"] = search.fullname() + result["url"] = search.get_absolute_url() return Response(result, status=status.HTTP_201_CREATED) def put(self, request, author_name, name): search = get_object_or_404(Search, author__username=author_name, name=name) - serializer = self.get_serializer(instance=search, data=request.data, partial=True) + serializer = self.get_serializer( + instance=search, data=request.data, partial=True + ) serializer.is_valid(raise_exception=True) serializer.save() result = self.build_results(request, search) return Response(result) -#------------------------------------------------ +# ------------------------------------------------ class ListSearchView(CommonContextMixin, generics.ListAPIView): """ Lists all available search from a user """ - permission_classes = [permissions.AllowAny] + + permission_classes = [drf_permissions.AllowAny] serializer_class = SearchSerializer def get_queryset(self): - author_name = self.kwargs['author_name'] - return Search.objects.for_user(self.request.user, True).select_related().filter(author__username=author_name) + author_name = self.kwargs["author_name"] + return ( + Search.objects.for_user(self.request.user, True) + .select_related() + .filter(author__username=author_name) + ) -#---------------------------------------------------------- +# ---------------------------------------------------------- -class RetrieveDestroySearchAPIView(CommonContextMixin, SerializerFieldsMixin, IsAuthorOrReadOnlyMixin, generics.RetrieveDestroyAPIView): +class RetrieveDestroySearchAPIView( + CommonContextMixin, SerializerFieldsMixin, generics.RetrieveDestroyAPIView +): """ Delete the given search """ + model = Search serializer_class = SearchSerializer - + permission_classes = [beat_permissions.IsAuthorOrReadOnly] def get_object(self): - author_name = self.kwargs.get('author_name') - name = self.kwargs.get('object_name') + author_name = self.kwargs.get("author_name") + name = self.kwargs.get("object_name") user = self.request.user - return get_object_or_404(self.model.objects.for_user(user, True), - author__username=author_name, - name=name) + return get_object_or_404( + self.model.objects.for_user(user, True), + author__username=author_name, + name=name, + ) def get(self, request, *args, **kwargs): search = self.get_object() + self.check_object_permissions(request, search) + # Process the query string allow_sharing = request.user == search.author - fields_to_return = self.get_serializer_fields(request, allow_sharing=allow_sharing) + fields_to_return = self.get_serializer_fields( + request, allow_sharing=allow_sharing + ) serializer = self.get_serializer(search, fields=fields_to_return) return Response(serializer.data) -#------------------------------------------------ + +# ------------------------------------------------ class ShareSearchView(ShareView): """ Share the given search with other users/teams """ + model = Search - permission_classes = [permissions.AllowAny] + permission_classes = [drf_permissions.AllowAny] def get_queryset(self): - self.kwargs['version'] = 1 + self.kwargs["version"] = 1 return super(ShareSearchView, self).get_queryset() diff --git a/beat/web/search/api_urls.py b/beat/web/search/api_urls.py index c12777868b9dac1cdcda287bd8be01caac0c9b3e..0d604080aac1c1c1487ebc1d5e6ee44d8fa827ee 100644 --- a/beat/web/search/api_urls.py +++ b/beat/web/search/api_urls.py @@ -26,43 +26,31 @@ ############################################################################### from django.conf.urls import url + from . import api urlpatterns = [ + url(r"^$", api.SearchView.as_view(), name="all"), url( - r'^$', - api.SearchView.as_view(), - name='all' - ), - - url( - r'^save/(?P<author_name>\w+)/(?P<name>[\w\-]+)/$', - api.SearchSaveView.as_view(), - name='save' - ), - - url( - r'^save/$', + r"^save/(?P<author_name>\w+)/(?P<name>[\w\-]+)/$", api.SearchSaveView.as_view(), - name='save' + name="save", ), - + url(r"^save/$", api.SearchSaveView.as_view(), name="save"), url( - r'^list/(?P<author_name>\w+)/$', + r"^list/(?P<author_name>\w+)/$", api.ListSearchView.as_view(), - name='list_for_author' + name="list_for_author", ), - url( - r'^(?P<author_name>\w+)/(?P<object_name>[\w\-]+)/$', + r"^(?P<author_name>\w+)/(?P<object_name>[\w\-]+)/$", api.RetrieveDestroySearchAPIView.as_view(), - name='object' + name="object", ), - url( - r'^share/(?P<author_name>\w+)/(?P<object_name>[\w\-]+)/$', + r"^share/(?P<author_name>\w+)/(?P<object_name>[\w\-]+)/$", api.ShareSearchView.as_view(), - name='share' - ) + name="share", + ), ] diff --git a/beat/web/team/api.py b/beat/web/team/api.py index a4f4a02d89fe9748b88acf35411c8f75f96b5d58..6570ec26f45c6265b8953be8ca278deb86292a61 100644 --- a/beat/web/team/api.py +++ b/beat/web/team/api.py @@ -32,8 +32,8 @@ from django.db.models import Q from rest_framework import generics from rest_framework import permissions from rest_framework.response import Response -from rest_framework import status from rest_framework.reverse import reverse +from rest_framework import exceptions as drf_exceptions from .serializers import FullTeamSerializer from .serializers import SimpleTeamSerializer @@ -42,26 +42,26 @@ from .serializers import TeamUpdateSerializer from .models import Team from .permissions import IsOwner, HasPrivacyLevel -from ..common.responses import BadRequestResponse, ForbiddenResponse from ..common.mixins import CommonContextMixin -#---------------------------------------------------------- +# ---------------------------------------------------------- class UserTeamListView(CommonContextMixin, generics.ListCreateAPIView): """ Lists the team from a user and create new teams """ + model = Team serializer_class = SimpleTeamSerializer writing_serializer_class = TeamCreationSerializer permission_classes = [permissions.IsAuthenticated] def get_serializer(self, *args, **kwargs): - if self.request.method == 'POST': + if self.request.method == "POST": self.serializer_class = self.writing_serializer_class - return super(UserTeamListView, self).get_serializer(*args, **kwargs) + return super().get_serializer(*args, **kwargs) def list(self, request, owner_name): owner = get_object_or_404(User, username=owner_name) @@ -69,13 +69,19 @@ class UserTeamListView(CommonContextMixin, generics.ListCreateAPIView): if request.user == owner: queryset = Team.objects.filter(owner=owner) else: - queryset = Team.objects.filter(Q(owner=owner), Q(privacy_level=Team.PUBLIC)|Q(privacy_level=Team.MEMBERS, members=request.user)).distinct() - - serializer = self.get_serializer(queryset, many=True, context={'user': request.user}) + queryset = Team.objects.filter( + Q(owner=owner), + Q(privacy_level=Team.PUBLIC) + | Q(privacy_level=Team.MEMBERS, members=request.user), + ).distinct() + + serializer = self.get_serializer( + queryset, many=True, context={"user": request.user} + ) return Response(serializer.data) -#---------------------------------------------------------- +# ---------------------------------------------------------- class TeamDetailView(CommonContextMixin, generics.RetrieveUpdateDestroyAPIView): @@ -89,15 +95,15 @@ class TeamDetailView(CommonContextMixin, generics.RetrieveUpdateDestroyAPIView): def get_permissions(self): self.permission_classes = [permissions.IsAuthenticatedOrReadOnly] - if self.request.method == 'GET': + if self.request.method == "GET": self.permission_classes.append(HasPrivacyLevel) else: self.permission_classes.append(IsOwner) - return super(TeamDetailView, self).get_permissions() + return super().get_permissions() - def get_queryset(self): - owner_name = self.kwargs.get('owner_name') - team_name = self.kwargs.get('team_name') + def get_object(self): + owner_name = self.kwargs.get("owner_name") + team_name = self.kwargs.get("team_name") team = get_object_or_404(Team, owner__username=owner_name, name=team_name) @@ -106,52 +112,54 @@ class TeamDetailView(CommonContextMixin, generics.RetrieveUpdateDestroyAPIView): return team def get_serializer(self, *args, **kwargs): - if self.request.method in ['PUT', 'PATCH']: + if self.request.method in ["PUT", "PATCH"]: self.serializer_class = self.writing_serializer_class - return super(TeamDetailView, self).get_serializer(*args, **kwargs) + return super().get_serializer(*args, **kwargs) - def get(self, request, owner_name, team_name): - team = self.get_queryset() - serializer = self.serializer_class(team, context={'user': request.user}) - return Response(serializer.data) + def get_serializer_context(self): + context = super().get_serializer_context() + context["user"] = self.request.user + return context - def delete(self, request, owner_name, team_name): - team = self.get_queryset() - - # Check that the team can still be deleted - if not(team.deletable()): - return ForbiddenResponse("The team isn't deletable (it has been used to share %d objects with its members)" % team.total_shares()) - - team.delete() - return Response(status=status.HTTP_204_NO_CONTENT) + def perform_destroy(self, instance): + if not instance.deletable(): + raise drf_exceptions.PermissionDenied( + "The team isn't deletable (it has been used to share %d objects with its members)" + % instance.total_shares() + ) + return super().perform_destroy(instance) def update(self, request, owner_name, team_name): - team = self.get_queryset() + team = self.get_object() self.check_object_permissions(request, team) - serializer = self.writing_serializer_class(team, data=request.data, partial=True) - if not(serializer.is_valid()): - return BadRequestResponse(serializer.errors) + serializer = self.writing_serializer_class( + team, data=request.data, partial=True + ) + serializer.is_valid(raise_exception=True) db_object = serializer.save() result = { - 'name': db_object.name, - 'url': reverse('api_teams:team_info', args=[db_object.owner.username, db_object.name]) + "name": db_object.name, + "url": reverse( + "api_teams:team_info", args=[db_object.owner.username, db_object.name] + ), } response = Response(result, status=200) - response['Location'] = result['url'] + response["Location"] = result["url"] return response -#---------------------------------------------------------- +# ---------------------------------------------------------- class TeamListView(CommonContextMixin, generics.ListAPIView): """ List all teams of the caller """ + model = Team serializer_class = SimpleTeamSerializer permission_classes = [permissions.IsAuthenticated] diff --git a/beat/web/team/api_urls.py b/beat/web/team/api_urls.py index 047e5b356efafddc552569504c33699acb140e11..7bd5ef8f0d301972f826cedfe49f905b017f7beb 100644 --- a/beat/web/team/api_urls.py +++ b/beat/web/team/api_urls.py @@ -26,25 +26,18 @@ ############################################################################### from django.conf.urls import url + from . import api urlpatterns = [ + url(r"^$", api.TeamListView.as_view(), name="teamlist"), url( - r'^$', - api.TeamListView.as_view(), - name='teamlist' + r"^(?P<owner_name>\w+)/$", api.UserTeamListView.as_view(), name="user_teamlist" ), - - url( - r'^(?P<owner_name>\w+)/$', - api.UserTeamListView.as_view(), - name='user_teamlist' - ), - url( - r'^(?P<owner_name>\w+)/(?P<team_name>[\w\W]+)/$', + r"^(?P<owner_name>\w+)/(?P<team_name>[\w\W]+)/$", api.TeamDetailView.as_view(), - name='team_info' + name="team_info", ), ] diff --git a/beat/web/toolchains/api.py b/beat/web/toolchains/api.py index 693af428ea884c94216bfadff6cc296d1e9b7e93..1d861e1b985cb1f2649a73c4cf8c2ac1925dc2fe 100644 --- a/beat/web/toolchains/api.py +++ b/beat/web/toolchains/api.py @@ -25,34 +25,24 @@ # # ############################################################################### -import simplejson as json -from django.conf import settings -from django.shortcuts import get_object_or_404 -from django.utils import six -from django.core.exceptions import ValidationError - -from rest_framework.response import Response -from rest_framework import serializers -from rest_framework.exceptions import PermissionDenied, ParseError - -from ..common.responses import BadRequestResponse -from ..common.api import (CheckContributionNameView, ShareView, - ListCreateContributionView, RetrieveUpdateDestroyContributionView) +from ..common.api import ( + CheckContributionNameView, + ShareView, + ListCreateContributionView, + RetrieveUpdateDestroyContributionView, +) from .models import Toolchain from .serializers import ToolchainSerializer from .serializers import FullToolchainSerializer from .serializers import ToolchainCreationSerializer +from .serializers import ToolchainModSerializer from ..common.api import ListContributionView -from ..common.utils import validate_restructuredtext, ensure_html -from ..experiments.models import Experiment - -import beat.core.toolchain -#---------------------------------------------------------- +# ---------------------------------------------------------- class CheckToolchainNameView(CheckContributionNameView): @@ -60,10 +50,11 @@ class CheckToolchainNameView(CheckContributionNameView): This view sanitizes a toolchain name and checks whether it is already used. """ + model = Toolchain -#---------------------------------------------------------- +# ---------------------------------------------------------- class ShareToolchainView(ShareView): @@ -71,21 +62,23 @@ class ShareToolchainView(ShareView): This view allows to share a toolchain with other users and/or teams """ + model = Toolchain -#---------------------------------------------------------- +# ---------------------------------------------------------- class ListToolchainView(ListContributionView): """ List all available toolchains """ + model = Toolchain serializer_class = ToolchainSerializer -#---------------------------------------------------------- +# ---------------------------------------------------------- class ListCreateToolchainsView(ListCreateContributionView): @@ -93,166 +86,21 @@ class ListCreateToolchainsView(ListCreateContributionView): Read/Write end point that list the toolchains available from a given author and allows the creation of new toolchains """ + model = Toolchain serializer_class = ToolchainSerializer writing_serializer_class = ToolchainCreationSerializer - namespace = 'api_toolchains' + namespace = "api_toolchains" -#---------------------------------------------------------- +# ---------------------------------------------------------- class RetrieveUpdateDestroyToolchainsView(RetrieveUpdateDestroyContributionView): """ Read/Write/Delete endpoint for a given toolchain """ + model = Toolchain serializer_class = FullToolchainSerializer - - - def put(self, request, author_name, object_name, version=None): - if version is None: - return BadRequestResponse('A version number must be provided') - - try: - data = request.data - except ParseError as e: - raise serializers.ValidationError({'data': str(e)}) - else: - if not data: - raise serializers.ValidationError({'data': 'Empty'}) - - - if 'short_description' in data: - if not(isinstance(data['short_description'], six.string_types)): - raise serializers.ValidationError({'short_description', 'Invalid short_description data'}) - short_description = data['short_description'] - else: - short_description = None - - if 'description' in data: - if not(isinstance(data['description'], six.string_types)): - raise serializers.ValidationError({'description': 'Invalid description data'}) - description = data['description'] - try: - validate_restructuredtext(description) - except ValidationError as errors: - raise serializers.ValidationError({'description': [error for error in errors]}) - else: - description = None - - if 'strict' in data: - strict = data['strict'] - else: - strict = True - - if 'declaration' in data: - if isinstance(data['declaration'], dict): - json_declaration = data['declaration'] - declaration = json.dumps(json_declaration, indent=4) - elif isinstance(data['declaration'], six.string_types): - declaration = data['declaration'] - try: - json_declaration = json.loads(declaration) - except: - raise serializers.ValidationError({'declaration': 'Invalid declaration data'}) - else: - raise serializers.ValidationError({'declaration': 'Invalid declaration data'}) - - if 'description' in json_declaration: - if short_description is not None: - raise serializers.ValidationError({'short_description': 'A short description is already provided in the toolchain declaration'}) - - short_description = json_declaration['description'] - elif short_description is not None: - json_declaration['description'] = short_description - declaration = json.dumps(json_declaration, indent=4) - - toolchain_declaration = beat.core.toolchain.Toolchain( - settings.PREFIX, json_declaration) - - if not toolchain_declaration.valid: - if strict: - raise serializers.ValidationError({'declaration': toolchain_declaration.errors}) - - else: - declaration = None - - if (short_description is not None) and (len(short_description) > self.model._meta.get_field('short_description').max_length): - raise serializers.ValidationError({'short_description': 'Short description too long'}) - - - # Process the query string - if 'fields' in request.GET: - fields_to_return = request.GET['fields'].split(',') - else: - # Available fields (not returned by default): - # - html_description - fields_to_return = ['errors'] - - - # Retrieve the toolchain - dbtoolchain = get_object_or_404(Toolchain, - author__username__iexact=author_name, - name__iexact=object_name, - version=version) - - # Check that the object can still be modified (if applicable, the - # documentation can always be modified) - if declaration is not None and not dbtoolchain.modifiable(): - raise PermissionDenied("The {} isn't modifiable anymore (either shared with someone else, or needed by an attestation)".format(dbtoolchain.model_name())) - - errors = None - - - # Modification of the short_description - if (short_description is not None) and (declaration is None): - tmp_declaration = dbtoolchain.declaration - tmp_declaration['description'] = short_description - dbtoolchain.declaration = tmp_declaration - - # Modification of the description - if description is not None: - dbtoolchain.description = description - - # Modification of the declaration - if declaration is not None: - errors = '' - if not toolchain_declaration.valid: - errors = ' * %s' % '\n * '.join(toolchain_declaration.errors) - - dbtoolchain.errors = errors - dbtoolchain.declaration = declaration - - experiments = Experiment.objects.filter(toolchain=dbtoolchain) - experiments.delete() - - # Save the toolchain model - try: - dbtoolchain.save() - except Exception as e: - return BadRequestResponse(str(e)) - - # Nothing to return? - if len(fields_to_return) == 0: - return Response(status=204) - - result = {} - - # Retrieve the errors (if necessary) - if 'errors' in fields_to_return: - if errors: - result['errors'] = errors - else: - result['errors'] = '' - - - # Retrieve the description in HTML format (if necessary) - if 'html_description' in fields_to_return: - description = dbtoolchain.description - if len(description) > 0: - result['html_description'] = ensure_html(description) - else: - result['html_description'] = '' - - return Response(result) + writing_serializer_class = ToolchainModSerializer diff --git a/beat/web/toolchains/api_urls.py b/beat/web/toolchains/api_urls.py index 20a5e11955408e3bdd4f16ef2d176d797655dc8d..1182d977c1a630e4dd51164283810d5a1afbbb1f 100644 --- a/beat/web/toolchains/api_urls.py +++ b/beat/web/toolchains/api_urls.py @@ -26,43 +26,26 @@ ############################################################################### from django.conf.urls import url + from . import api urlpatterns = [ + url(r"^$", api.ListToolchainView.as_view(), name="all"), + url(r"^check_name/$", api.CheckToolchainNameView.as_view(), name="check_name"), url( - r'^$', - api.ListToolchainView.as_view(), - name='all' - ), - - url( - r'^check_name/$', - api.CheckToolchainNameView.as_view(), - name='check_name' - ), - - url( - r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/share/$', + r"^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/share/$", api.ShareToolchainView.as_view(), - name='share' + name="share", ), - url( - r'^(?P<author_name>\w+)/$', + r"^(?P<author_name>\w+)/$", api.ListCreateToolchainsView.as_view(), - name='list_create' - ), - - url( - r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/$', - api.RetrieveUpdateDestroyToolchainsView.as_view(), - name='object' + name="list_create", ), - url( - r'^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/$', + r"^(?P<author_name>\w+)/(?P<object_name>[-\w]+)/(?P<version>\d+)/$", api.RetrieveUpdateDestroyToolchainsView.as_view(), - name='object' + name="object", ), ] diff --git a/beat/web/toolchains/serializers.py b/beat/web/toolchains/serializers.py index 1feaadca3df483fcc64cf55af60520b6f720452f..247cf1399c24672be7e1e9508ece03a6c1d76531 100644 --- a/beat/web/toolchains/serializers.py +++ b/beat/web/toolchains/serializers.py @@ -27,7 +27,11 @@ from rest_framework import serializers -from ..common.serializers import ContributionSerializer, ContributionCreationSerializer +from ..common.serializers import ( + ContributionSerializer, + ContributionCreationSerializer, + ContributionModSerializer, +) from ..attestations.serializers import AttestationSerializer from ..experiments.serializers import ExperimentSerializer @@ -36,16 +40,25 @@ from .models import Toolchain import beat.core.toolchain -#---------------------------------------------------------- +# ---------------------------------------------------------- class ToolchainCreationSerializer(ContributionCreationSerializer): class Meta(ContributionCreationSerializer.Meta): model = Toolchain - beat_core_class = beat.core.toolchain + beat_core_class = beat.core.toolchain.Toolchain -#---------------------------------------------------------- +# ---------------------------------------------------------- + + +class ToolchainModSerializer(ContributionModSerializer): + class Meta(ContributionModSerializer.Meta): + model = Toolchain + beat_core_class = beat.core.toolchain.Toolchain + + +# ---------------------------------------------------------- class ToolchainSerializer(ContributionSerializer): @@ -58,25 +71,32 @@ class ToolchainSerializer(ContributionSerializer): model = Toolchain def get_referencing_experiments(self, obj): - user = self.context.get('user') + user = self.context.get("user") - experiments = obj.experiments.for_user(user, True).order_by('-creation_date') + experiments = obj.experiments.for_user(user, True).order_by("-creation_date") serializer = ExperimentSerializer(experiments, many=True) referencing_experiments = serializer.data # Put the pending experiments first - ordered_result = filter(lambda x: x['creation_date'] is None, referencing_experiments) - ordered_result += filter(lambda x: x['creation_date'] is not None, referencing_experiments) + ordered_result = filter( + lambda x: x["creation_date"] is None, referencing_experiments + ) + ordered_result += filter( + lambda x: x["creation_date"] is not None, referencing_experiments + ) return ordered_result def get_new_experiment_url(self, obj): return obj.get_new_experiment_url() -#---------------------------------------------------------- +# ---------------------------------------------------------- -class FullToolchainSerializer(ToolchainSerializer): +class FullToolchainSerializer(ToolchainSerializer): class Meta(ToolchainSerializer.Meta): - default_fields = ToolchainSerializer.Meta.default_fields + ToolchainSerializer.Meta.extra_fields + default_fields = ( + ToolchainSerializer.Meta.default_fields + + ToolchainSerializer.Meta.extra_fields + ) diff --git a/beat/web/toolchains/tests.py b/beat/web/toolchains/tests.py index 3de38a63412b3d69a93de77ea8db61edfdd7f218..904c94fe70f2c753273f87469389716589c2aac3 100644 --- a/beat/web/toolchains/tests.py +++ b/beat/web/toolchains/tests.py @@ -853,9 +853,7 @@ class ToolchainCreation(ToolchainsAPIBase): ) content = self.checkResponse(response, 400, content_type="application/json") - self.assertNotEqual( - content.find("The toolchain declaration is **invalid**"), -1 - ) + self.assertTrue("declaration" in content) def test_no_forking_for_anonymous_user(self): response = self.client.post( @@ -1195,25 +1193,15 @@ class ToolchainUpdate(ToolchainsAPIBase): self.checkResponse(response, 404) - def test_fail_to_update_without_version_number(self): - self.login_jackdoe() - - url = reverse("api_toolchains:object", args=["jackdoe", "personal"]) - response = self.client.put( - url, json.dumps({"description": "blah"}), content_type="application/json" - ) - - self.checkResponse(response, 400) - def test_fail_to_update_without_content_not_json_request(self): self.login_jackdoe() response = self.client.put(self.url) - self.checkResponse(response, 400) + self.checkResponse(response, 200, content_type="application/json") - def test_fail_to_update_without_content(self): + def test_successful_update_without_content(self): self.login_jackdoe() response = self.client.put(self.url, "{}", content_type="application/json") - self.checkResponse(response, 400, content_type="application/json") + self.checkResponse(response, 200, content_type="application/json") def test_successful_update_description_only(self): self.login_jackdoe() @@ -1325,14 +1313,14 @@ class ToolchainRetrieval(ToolchainsAPIBase): def test_fail_to_retrieve_with_invalid_username(self): self.login_jackdoe() - url = reverse("api_toolchains:object", args=["unknown", "personal"]) + url = reverse("api_toolchains:object", args=["unknown", "personal", 1]) response = self.client.get(url) self.checkResponse(response, 404) def test_fail_to_retrieve_with_invalid_toolchain_name(self): self.login_jackdoe() - url = reverse("api_toolchains:object", args=["jackdoe", "unknown"]) + url = reverse("api_toolchains:object", args=["jackdoe", "unknown", 1]) response = self.client.get(url) self.checkResponse(response, 404) @@ -1431,13 +1419,6 @@ class ToolchainDeletion(ToolchainsAPIBase): response = self.client.delete(url) self.checkResponse(response, 404) - def test_fail_to_delete_without_version_number(self): - self.login_jackdoe() - - url = reverse("api_toolchains:object", args=["jackdoe", "personal"]) - response = self.client.delete(url) - self.checkResponse(response, 400) - def test_no_deletion_of_not_owned_toolchain(self): self.login_johndoe() response = self.client.delete(self.url) diff --git a/beat/web/utils/drf.py b/beat/web/utils/drf.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6f63d2dd275d91772c0b3801efeec3c271a109 --- /dev/null +++ b/beat/web/utils/drf.py @@ -0,0 +1,59 @@ +# vim: set fileencoding=utf-8 : + +############################################################################### +# # +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# This file is part of the beat.web module of the BEAT platform. # +# # +# Commercial License Usage # +# Licensees holding valid commercial BEAT licenses may use this file in # +# accordance with the terms contained in a written agreement between you # +# and Idiap. For further information contact tto@idiap.ch # +# # +# Alternatively, this file may be used under the terms of the GNU Affero # +# Public License version 3 as published by the Free Software and appearing # +# in the file LICENSE.AGPL included in the packaging of this file. # +# The BEAT platform is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # +# or FITNESS FOR A PARTICULAR PURPOSE. # +# # +# You should have received a copy of the GNU Affero Public License along # +# with the BEAT platform. If not, see http://www.gnu.org/licenses/. # +# # +############################################################################### + +""" +Django REST framework helpers +""" + +import logging + +from rest_framework.views import exception_handler +from rest_framework.exceptions import APIException +from rest_framework.status import is_client_error + +logger = logging.getLogger("beat.drf_exceptions") + + +def custom_exception_handler(exc, context): + # Call REST framework's default exception handler first, + # to get the standard error response. + response = exception_handler(exc, context) + + # check that a ValidationError exception is raised + if isinstance(exc, APIException): + # Log all client errors + if is_client_error(exc.status_code): + view = context["view"] + request = context["request"] + detail = { + "user": request.user.username, + "view": view.__class__.__name__, + "kwargs": view.kwargs, + "error": exc.detail, + } + logger.warning("Error occured: {}".format(detail)) + + return response