diff --git a/.flake8 b/.flake8 index 5fabfeed91611982e96a84c038dbfe24b0708055..994815d8870e9822617c4578efdce0e121988c60 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ [flake8] -max-line-length = 80 +max-line-length = 88 select = B,C,E,F,W,T4,B9,B950 -ignore = E501, W503 +ignore = E501, W503, E203 diff --git a/beat/web/databases/migrations/0007_add_accessibility_date.py b/beat/web/databases/migrations/0007_add_accessibility_date.py new file mode 100644 index 0000000000000000000000000000000000000000..802aa1bcd0b7e903e6d51b88638784a6f78aff4d --- /dev/null +++ b/beat/web/databases/migrations/0007_add_accessibility_date.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.25 on 2020-02-12 11:22 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [("databases", "0006_databaseset_hash_unique")] + + operations = [ + migrations.AddField( + model_name="database", + name="accessibility_start_date", + field=models.DateTimeField(blank=True, null=True), + ), + migrations.AddField( + model_name="database", + name="accessibility_end_date", + field=models.DateTimeField(blank=True, null=True), + ), + ] diff --git a/beat/web/databases/models.py b/beat/web/databases/models.py index 6179e22843e317d83b4d15934f878ea10bde0017..3be769434a72aa3de052c7bb39101e3d13f93755 100755 --- a/beat/web/databases/models.py +++ b/beat/web/databases/models.py @@ -29,16 +29,18 @@ import os import simplejson +from datetime import datetime + from django.db import models from django.conf import settings from django.core.urlresolvers import reverse +from django.db.models import Q import beat.core.database from beat.backend.python.hash import hashDataset from ..dataformats.models import DataFormat -from ..common.models import Shareable from ..common.models import Versionable from ..common.models import VersionableManager from ..common.storage import OverwriteStorage @@ -57,7 +59,7 @@ from ..code.models import get_source_code from ..code.models import set_source_code -#---------------------------------------------------------- +# ---------------------------------------------------------- def validate_database(declaration): @@ -66,8 +68,10 @@ def validate_database(declaration): database = beat.core.database.Database(settings.PREFIX, declaration) if not database.valid: - errors = 'The database declaration is **invalid**. Errors:\n * ' + \ - '\n * '.join(database.errors) + errors = ( + "The database declaration is **invalid**. Errors:\n * " + + "\n * ".join(database.errors) + ) raise SyntaxError(errors) # TODO: for each protocol, if a template already exists, that the inserted @@ -77,37 +81,57 @@ def validate_database(declaration): return database -#---------------------------------------------------------- +# ---------------------------------------------------------- class DatabaseStorage(OverwriteStorage): - def __init__(self, *args, **kwargs): - super(DatabaseStorage, self).__init__(*args, location=settings.DATABASES_ROOT, **kwargs) - + super(DatabaseStorage, self).__init__( + *args, location=settings.DATABASES_ROOT, **kwargs + ) -#---------------------------------------------------------- +# ---------------------------------------------------------- +def _filter_accessibility_date(queryset): + now = datetime.now() -class DatabaseManager(VersionableManager): + return queryset.filter( + (Q(accessibility_start_date=None) | Q(accessibility_start_date__lte=now)) + & (Q(accessibility_end_date=None) | Q(accessibility_end_date__gte=now)) + ) +class DatabaseManager(VersionableManager): def get_by_natural_key(self, name, version): return self.get(name=name, version=version) - - def create_database(self, name, short_description='', description='', - declaration=None, code=None, version=1, - previous_version=None): + def for_user(self, user, add_public=False): + original_query = super(DatabaseManager, self).for_user(user, add_public) + return _filter_accessibility_date(original_query) + + def public(self): + original_query = super(DatabaseManager, self).public() + return _filter_accessibility_date(original_query) + + def create_database( + self, + name, + short_description="", + description="", + declaration=None, + code=None, + version=1, + previous_version=None, + ): """Convenience function to create a new database from its parts""" # Create the database representation database = self.model( - name = name, - version = version, - sharing = self.model.PRIVATE, - previous_version = previous_version, + name=name, + version=version, + sharing=self.model.PRIVATE, + previous_version=previous_version, ) # Makes sure we get a declaration in string format @@ -117,11 +141,11 @@ class DatabaseManager(VersionableManager): else: default = beat.core.database.Database(settings.PREFIX, data=None) declaration = default.data - elif not(isinstance(declaration, dict)): + elif not (isinstance(declaration, dict)): declaration = simplejson.loads(declaration) if len(short_description) > 0: - declaration['description'] = short_description + declaration["description"] = short_description database.declaration = declaration @@ -130,7 +154,7 @@ class DatabaseManager(VersionableManager): if previous_version is not None: code = previous_version.source_code else: - code = '' + code = "" database.source_code = code @@ -139,7 +163,7 @@ class DatabaseManager(VersionableManager): if previous_version is not None: description = previous_version.description else: - description = '' + description = "" database.description = description @@ -148,68 +172,111 @@ class DatabaseManager(VersionableManager): database.save() except Exception: import traceback + return (None, traceback.format_exc()) return (database, None) -#---------------------------------------------------------- +# ---------------------------------------------------------- class Database(Versionable): - #_____ Fields __________ + # _____ Fields __________ + + accessibility_start_date = models.DateTimeField(blank=True, null=True) + accessibility_end_date = models.DateTimeField(blank=True, null=True) declaration_file = models.FileField( storage=DatabaseStorage(), upload_to=get_contribution_declaration_filename, - blank=True, null=True, + blank=True, + null=True, max_length=200, - db_column='declaration' + db_column="declaration", ) description_file = models.FileField( storage=DatabaseStorage(), upload_to=get_contribution_description_filename, - blank=True, null=True, + blank=True, + null=True, max_length=200, - db_column='description' + db_column="description", ) source_code_file = models.FileField( storage=DatabaseStorage(), upload_to=get_contribution_source_code_filename, - blank=True, null=True, + blank=True, + null=True, max_length=200, - db_column='source_code' + db_column="source_code", ) - objects = DatabaseManager() - - #_____ Meta parameters __________ + # _____ Meta parameters __________ class Meta(Versionable.Meta): - unique_together = ('name', 'version') - + unique_together = ("name", "version") - #_____ Utilities __________ + # _____ Utilities __________ def get_absolute_url(self): - return reverse( - 'databases:view', - args=(self.name, self.version,), - ) - + return reverse("databases:view", args=(self.name, self.version)) def natural_key(self): return (self.name, self.version) - - #_____ Overrides __________ + # _____ Overrides __________ + + def accessibility_for(self, user_or_team, without_usable=False): + now = datetime.now() + accessible = True + + if self.accessibility_start_date and self.accessibility_end_date: + if now < self.accessibility_start_date or now > self.accessibility_end_date: + accessible = False + elif self.accessibility_start_date: + if now < self.accessibility_start_date: + accessible = False + elif self.accessibility_end_date: + if now > self.accessibility_end_date: + accessible = False + + if not accessible: + return (False, False, None) + + return super(Database, self).accessibility_for(user_or_team, without_usable) + + def is_accessible(self, users=None, teams=None): + errors = [] + now = datetime.now() + + if self.accessibility_start_date and self.accessibility_end_date: + if now < self.accessibility_start_date or now > self.accessibility_end_date: + errors.append( + "The database {} is currently not accessible".format(self.fullname) + ) + elif self.accessibility_start_date: + if now < self.accessibility_start_date: + errors.append( + "The database {} is not yet accessible".format(self.fullname) + ) + elif self.accessibility_end_date: + if now > self.accessibility_end_date: + errors.append( + "The database {} is not accessible anymore".format(self.fullname) + ) + + if errors: + return errors + + return super(Database, self).is_accessible(users, teams) def save(self, *args, **kwargs): @@ -220,7 +287,11 @@ class Database(Versionable): wrapper = validate_database(declaration) # reset the description - self.short_description = wrapper.description if (wrapper is not None) and (wrapper.description is not None) else '' + self.short_description = ( + wrapper.description + if (wrapper is not None) and (wrapper.description is not None) + else "" + ) # Save the changed files (if necessary) storage.save_files(self) @@ -229,27 +300,26 @@ class Database(Versionable): # if the filename has changed, move the declaration if self.declaration_filename() != self.declaration_file.name: - storage.rename_file(self, 'declaration_file', self.declaration_filename()) - storage.rename_file(self, 'description_file', self.description_filename()) - storage.rename_file(self, 'source_code_file', self.source_code_filename()) - + storage.rename_file(self, "declaration_file", self.declaration_filename()) + storage.rename_file(self, "description_file", self.description_filename()) + storage.rename_file(self, "source_code_file", self.source_code_filename()) - #_____ Methods __________ + # _____ Methods __________ def fullname(self): - return '%s/%d' % (self.name, self.version) + return "%s/%d" % (self.name, self.version) def fullpath(self, extension): return os.path.join(self.name, str(self.version) + extension) def declaration_filename(self): - return self.fullpath('.json') + return self.fullpath(".json") def description_filename(self): - return self.fullpath('.rst') + return self.fullpath(".rst") def source_code_filename(self): - return self.fullpath('.py') + return self.fullpath(".py") def all_referenced_dataformats(self): result = [] @@ -263,7 +333,7 @@ class Database(Versionable): result.extend(database_protocol.all_needed_dataformats()) return list(set(result)) - #_____ Properties __________ + # _____ Properties __________ description = property(get_description, set_description) declaration = property(get_declaration, set_declaration) @@ -271,16 +341,13 @@ class Database(Versionable): source_code = property(get_source_code, set_source_code) -#---------------------------------------------------------- +# ---------------------------------------------------------- class DatabaseProtocolManager(models.Manager): - def get_by_natural_key(self, database_name, database_version, name): return self.get( - database__name=database_name, - database__version=database_version, - name=name, + database__name=database_name, database__version=database_version, name=name ) @@ -288,24 +355,26 @@ class DatabaseProtocol(models.Model): objects = DatabaseProtocolManager() - database = models.ForeignKey(Database, related_name='protocols', - on_delete=models.CASCADE) + database = models.ForeignKey( + Database, related_name="protocols", on_delete=models.CASCADE + ) name = models.CharField(max_length=200, blank=True) class Meta: - unique_together = ('database', 'name') - ordering = ['name'] + unique_together = ("database", "name") + ordering = ["name"] def __str__(self): return self.fullname() def natural_key(self): return self.database.natural_key() + (self.name,) - natural_key.dependencies = ['databases.database'] + + natural_key.dependencies = ["databases.database"] def fullname(self): - if self.name != '': - return self.database.fullname() + '@' + self.name + if self.name != "": + return self.database.fullname() + "@" + self.name else: return self.database.fullname() @@ -322,16 +391,16 @@ class DatabaseProtocol(models.Model): return list(set(result)) def set_template_basename(self): - if not self.sets.count(): return "unknown" + if not self.sets.count(): + return "unknown" dbset = self.sets.all()[0] - return dbset.template.name.rsplit('__')[0] + return dbset.template.name.rsplit("__")[0] -#---------------------------------------------------------- +# ---------------------------------------------------------- class DatabaseSetTemplateManager(models.Manager): - def get_by_natural_key(self, name): return self.get(name=name) @@ -349,24 +418,24 @@ class DatabaseSetTemplate(models.Model): return (self.name,) -#---------------------------------------------------------- +# ---------------------------------------------------------- class DatabaseSetManager(models.Manager): - def create(self, protocol, template, name): dataset = DatabaseSet( - name = name, - template = template, - protocol = protocol, - hash = hashDataset(protocol.database.fullname(), protocol.name, name) + name=name, + template=template, + protocol=protocol, + hash=hashDataset(protocol.database.fullname(), protocol.name, name), ) dataset.save() return dataset - - def get_by_natural_key(self, database_name, database_version, protocol_name, name, template_name): + def get_by_natural_key( + self, database_name, database_version, protocol_name, name, template_name + ): return self.get( protocol__database__name=database_name, protocol__database__version=database_version, @@ -378,29 +447,34 @@ class DatabaseSetManager(models.Manager): class DatabaseSet(models.Model): - objects = DatabaseSetManager() + objects = DatabaseSetManager() - protocol = models.ForeignKey(DatabaseProtocol, related_name='sets', - on_delete=models.CASCADE) - name = models.CharField(max_length=200, blank=True) - template = models.ForeignKey(DatabaseSetTemplate, related_name='sets', - on_delete=models.CASCADE) + protocol = models.ForeignKey( + DatabaseProtocol, related_name="sets", on_delete=models.CASCADE + ) + name = models.CharField(max_length=200, blank=True) + template = models.ForeignKey( + DatabaseSetTemplate, related_name="sets", on_delete=models.CASCADE + ) hash = models.CharField(max_length=64, unique=True) class Meta: - unique_together = ('protocol', 'name', 'template') + unique_together = ("protocol", "name", "template") def __str__(self): return self.fullname() def natural_key(self): return self.protocol.natural_key() + (self.name,) + self.template.natural_key() - natural_key.dependencies = ['databases.databaseprotocol', - 'databases.databasesettemplate'] + + natural_key.dependencies = [ + "databases.databaseprotocol", + "databases.databasesettemplate", + ] def fullname(self): - if self.name != '': - return self.protocol.fullname() + '.' + self.name + if self.name != "": + return self.protocol.fullname() + "." + self.name else: return self.protocol.fullname() @@ -416,43 +490,44 @@ class DatabaseSet(models.Model): return list(set(result)) -#---------------------------------------------------------- +# ---------------------------------------------------------- class DatabaseSetTemplateOutput(models.Model): - template = models.ForeignKey(DatabaseSetTemplate, - related_name='outputs', on_delete=models.CASCADE) - name = models.CharField(max_length=200) - dataformat = models.ForeignKey(DataFormat, - related_name='database_outputs', on_delete=models.CASCADE) + template = models.ForeignKey( + DatabaseSetTemplate, related_name="outputs", on_delete=models.CASCADE + ) + name = models.CharField(max_length=200) + dataformat = models.ForeignKey( + DataFormat, related_name="database_outputs", on_delete=models.CASCADE + ) class Meta: - unique_together = ('template', 'name') + unique_together = ("template", "name") def __str__(self): return self.fullname() def fullname(self): - return self.template.name + '.' + self.name + return self.template.name + "." + self.name -#---------------------------------------------------------- +# ---------------------------------------------------------- class DatabaseSetOutputManager(models.Manager): def get_by_natural_key(self, set_natural_key, output_name): set_ = DatabaseSet.objects.get_by_natural_key(*set_natural_key) - return self.get( - set=set_, - template__name=output_name, - ) + return self.get(set=set_, template__name=output_name) class DatabaseSetOutput(models.Model): - template = models.ForeignKey(DatabaseSetTemplateOutput, - related_name='instances', on_delete=models.CASCADE) - set = models.ForeignKey(DatabaseSet, related_name='outputs', - on_delete=models.CASCADE) + template = models.ForeignKey( + DatabaseSetTemplateOutput, related_name="instances", on_delete=models.CASCADE + ) + set = models.ForeignKey( + DatabaseSet, related_name="outputs", on_delete=models.CASCADE + ) objects = DatabaseSetOutputManager() @@ -460,7 +535,7 @@ class DatabaseSetOutput(models.Model): return self.fullname() def fullname(self): - return '%s.%s.%s.%s' % ( + return "%s.%s.%s.%s" % ( self.set.protocol.database.fullname(), self.set.protocol.name, self.set.name, @@ -475,5 +550,8 @@ class DatabaseSetOutput(models.Model): def natural_key(self): return (self.set.natural_key(), self.template.name) - natural_key.dependencies = ['databases.databasesettemplateoutput', - 'databases.databaseset'] + + natural_key.dependencies = [ + "databases.databasesettemplateoutput", + "databases.databaseset", + ] diff --git a/beat/web/databases/tests.py b/beat/web/databases/tests.py index 92b35f1272a263cf6e7b22c3ee755ea3a53cc501..cd081b8ccd50221c519f97c301f292dc59a8d0d5 100644 --- a/beat/web/databases/tests.py +++ b/beat/web/databases/tests.py @@ -27,6 +27,8 @@ import json +from datetime import datetime, timedelta + from django.contrib.auth.models import User from django.conf import settings from django.core.urlresolvers import reverse @@ -39,7 +41,7 @@ from ..common.testutils import BaseTestCase from ..common.testutils import tearDownModule # noqa test runner will call it -TEST_PWD = "1234" +TEST_PWD = "1234" # nosec class DatabaseAPIBase(BaseTestCase): @@ -128,18 +130,57 @@ class DatabaseCreationAPI(DatabaseAPIBase): class DatabaseRetrievalAPI(DatabaseAPIBase): - def test_retrieve_database(self): + def setUp(self): + super(DatabaseRetrievalAPI, self).setUp() + (dataformat, errors) = DataFormat.objects.create_dataformat( self.system_user, "float", "" ) self.assertIsNotNone(dataformat, errors) dataformat.share() - (database, errors) = Database.objects.create_database( + (self.database, errors) = Database.objects.create_database( self.db_name, declaration=self.DATABASE ) - self.assertIsNotNone(database, errors) - database.share() + self.assertIsNotNone(self.database, errors) + + def tearDown(self): + self.database.delete() + + def __accessibility_test(self): + self.database.accessibility_start_date = datetime.now() + self.database.accessibility_end_date = ( + self.database.accessibility_start_date + timedelta(days=1) + ) + self.database.save() + + object_url = reverse( + "api_databases:object", kwargs={"database_name": self.db_name, "version": 1} + ) + all_url = reverse("api_databases:all") + + response = self.client.get(object_url, format="json") + self.checkResponse(response, 200, content_type="application/json") + + response = self.client.get(all_url, format="json") + self.checkResponse(response, 200, content_type="application/json") + self.assertEqual(len(response.json()), 1) + + self.database.accessibility_start_date = datetime.now() - timedelta(days=2) + self.database.accessibility_end_date = ( + self.database.accessibility_start_date + timedelta(days=1) + ) + self.database.save() + + response = self.client.get(object_url, format="json") + self.checkResponse(response, 404) + + response = self.client.get(all_url, format="json") + self.checkResponse(response, 200, content_type="application/json") + self.assertEqual(len(response.json()), 0) + + def test_retrieve_database(self): + self.database.share() self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD) @@ -153,4 +194,13 @@ class DatabaseRetrievalAPI(DatabaseAPIBase): declaration = json.loads(data["declaration"]) self.assertTrue(declaration["root_folder"].startswith("/path_to_db_folder")) - database.delete() + def test_dated_database_for_user(self): + self.database.share(users=[settings.SYSTEM_ACCOUNT]) + self.database.save() + self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD) + self.__accessibility_test() + + def test_dated_database_for_public(self): + self.database.share() + self.database.save() + self.__accessibility_test() diff --git a/beat/web/experiments/permissions.py b/beat/web/experiments/permissions.py index bbc04aa6549ac0f773b411991108512f115b89af..dbbe65ae2ef0432a026347ff1c38a5a34f2c6c05 100644 --- a/beat/web/experiments/permissions.py +++ b/beat/web/experiments/permissions.py @@ -29,14 +29,22 @@ from rest_framework import permissions from ..databases.models import Database + class IsDatabaseAccessible(permissions.BasePermission): """ The logged in user must have access to the database used by the experiment """ + message = "Database is not accessible" + def has_object_permission(self, request, view, obj): accessible_databases = Database.objects.for_user(request.user, True) - experiment_databases = Database.objects.filter(protocols__sets__in=obj.referenced_datasets.all()).distinct() + experiment_databases = Database.objects.filter( + protocols__sets__in=obj.referenced_datasets.all() + ).distinct() - return all(experiment_db in accessible_databases for experiment_db in experiment_databases) + return all( + experiment_db in accessible_databases + for experiment_db in experiment_databases + ) diff --git a/beat/web/experiments/tests/tests_api.py b/beat/web/experiments/tests/tests_api.py index d6b6311e9307be6e634cf88a6ff48126e88e0480..e294c96af581b82a50122220295fc06cbd0e8455 100755 --- a/beat/web/experiments/tests/tests_api.py +++ b/beat/web/experiments/tests/tests_api.py @@ -29,7 +29,9 @@ import os import simplejson as json import shutil import copy + from datetime import datetime +from datetime import timedelta from django.conf import settings from django.contrib.auth.models import User @@ -1091,6 +1093,45 @@ class ExperimentStartingAPI(ExperimentTestBase): self.checkResponse(response, 403, content_type="application/json") + def test_start_experiment_with_database_outside_accessibility(self): + database = Database.objects.get(name="integers") + + self.login_johndoe() + + now = datetime.now() + delta = timedelta(days=1) + + bad_cases = { + "before start": (now + delta, None), + "after end": (None, now - delta), + "out of range": (now + delta, now + timedelta(days=2)), + } + + for case, range_ in bad_cases.items(): + start, end = range_ + print("Testing", case, "...") + database.accessibility_start_date = start + database.accessibility_end_date = end + database.save() + + response = self.client.post(self.url) + self.checkResponse(response, 403, content_type="application/json") + + def test_start_experiment_with_database_within_accessibility_range(self): + database = Database.objects.get(name="integers") + + self.login_johndoe() + + now = datetime.now() + delta = timedelta(days=1) + + database.accessibility_start_date = now - delta + database.accessibility_end_date = now + delta + database.save() + + response = self.client.post(self.url) + self.checkResponse(response, 200, content_type="application/json") + # ----------------------------------------------------------