Skip to content
Snippets Groups Projects
Commit 88828988 authored by Flavio TARSETTI's avatar Flavio TARSETTI
Browse files

Merge branch '541_add_usage_time_to_database' into 'master'

Add usage time to database

See merge request !319
parents e0796e0f bfce16b6
No related branches found
No related tags found
1 merge request!319Add usage time to database
Pipeline #36953 passed
[flake8]
max-line-length = 80
max-line-length = 88
select = B,C,E,F,W,T4,B9,B950
ignore = E501, W503
ignore = E501, W503, E203
# -*- coding: utf-8 -*-
# Generated by Django 1.11.25 on 2020-02-12 11:22
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [("databases", "0006_databaseset_hash_unique")]
operations = [
migrations.AddField(
model_name="database",
name="accessibility_start_date",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="database",
name="accessibility_end_date",
field=models.DateTimeField(blank=True, null=True),
),
]
......@@ -29,16 +29,18 @@ import os
import simplejson
from datetime import datetime
from django.db import models
from django.conf import settings
from django.core.urlresolvers import reverse
from django.db.models import Q
import beat.core.database
from beat.backend.python.hash import hashDataset
from ..dataformats.models import DataFormat
from ..common.models import Shareable
from ..common.models import Versionable
from ..common.models import VersionableManager
from ..common.storage import OverwriteStorage
......@@ -57,7 +59,7 @@ from ..code.models import get_source_code
from ..code.models import set_source_code
#----------------------------------------------------------
# ----------------------------------------------------------
def validate_database(declaration):
......@@ -66,8 +68,10 @@ def validate_database(declaration):
database = beat.core.database.Database(settings.PREFIX, declaration)
if not database.valid:
errors = 'The database declaration is **invalid**. Errors:\n * ' + \
'\n * '.join(database.errors)
errors = (
"The database declaration is **invalid**. Errors:\n * "
+ "\n * ".join(database.errors)
)
raise SyntaxError(errors)
# TODO: for each protocol, if a template already exists, that the inserted
......@@ -77,37 +81,57 @@ def validate_database(declaration):
return database
#----------------------------------------------------------
# ----------------------------------------------------------
class DatabaseStorage(OverwriteStorage):
def __init__(self, *args, **kwargs):
super(DatabaseStorage, self).__init__(*args, location=settings.DATABASES_ROOT, **kwargs)
super(DatabaseStorage, self).__init__(
*args, location=settings.DATABASES_ROOT, **kwargs
)
#----------------------------------------------------------
# ----------------------------------------------------------
def _filter_accessibility_date(queryset):
now = datetime.now()
class DatabaseManager(VersionableManager):
return queryset.filter(
(Q(accessibility_start_date=None) | Q(accessibility_start_date__lte=now))
& (Q(accessibility_end_date=None) | Q(accessibility_end_date__gte=now))
)
class DatabaseManager(VersionableManager):
def get_by_natural_key(self, name, version):
return self.get(name=name, version=version)
def create_database(self, name, short_description='', description='',
declaration=None, code=None, version=1,
previous_version=None):
def for_user(self, user, add_public=False):
original_query = super(DatabaseManager, self).for_user(user, add_public)
return _filter_accessibility_date(original_query)
def public(self):
original_query = super(DatabaseManager, self).public()
return _filter_accessibility_date(original_query)
def create_database(
self,
name,
short_description="",
description="",
declaration=None,
code=None,
version=1,
previous_version=None,
):
"""Convenience function to create a new database from its parts"""
# Create the database representation
database = self.model(
name = name,
version = version,
sharing = self.model.PRIVATE,
previous_version = previous_version,
name=name,
version=version,
sharing=self.model.PRIVATE,
previous_version=previous_version,
)
# Makes sure we get a declaration in string format
......@@ -117,11 +141,11 @@ class DatabaseManager(VersionableManager):
else:
default = beat.core.database.Database(settings.PREFIX, data=None)
declaration = default.data
elif not(isinstance(declaration, dict)):
elif not (isinstance(declaration, dict)):
declaration = simplejson.loads(declaration)
if len(short_description) > 0:
declaration['description'] = short_description
declaration["description"] = short_description
database.declaration = declaration
......@@ -130,7 +154,7 @@ class DatabaseManager(VersionableManager):
if previous_version is not None:
code = previous_version.source_code
else:
code = ''
code = ""
database.source_code = code
......@@ -139,7 +163,7 @@ class DatabaseManager(VersionableManager):
if previous_version is not None:
description = previous_version.description
else:
description = ''
description = ""
database.description = description
......@@ -148,68 +172,111 @@ class DatabaseManager(VersionableManager):
database.save()
except Exception:
import traceback
return (None, traceback.format_exc())
return (database, None)
#----------------------------------------------------------
# ----------------------------------------------------------
class Database(Versionable):
#_____ Fields __________
# _____ Fields __________
accessibility_start_date = models.DateTimeField(blank=True, null=True)
accessibility_end_date = models.DateTimeField(blank=True, null=True)
declaration_file = models.FileField(
storage=DatabaseStorage(),
upload_to=get_contribution_declaration_filename,
blank=True, null=True,
blank=True,
null=True,
max_length=200,
db_column='declaration'
db_column="declaration",
)
description_file = models.FileField(
storage=DatabaseStorage(),
upload_to=get_contribution_description_filename,
blank=True, null=True,
blank=True,
null=True,
max_length=200,
db_column='description'
db_column="description",
)
source_code_file = models.FileField(
storage=DatabaseStorage(),
upload_to=get_contribution_source_code_filename,
blank=True, null=True,
blank=True,
null=True,
max_length=200,
db_column='source_code'
db_column="source_code",
)
objects = DatabaseManager()
#_____ Meta parameters __________
# _____ Meta parameters __________
class Meta(Versionable.Meta):
unique_together = ('name', 'version')
unique_together = ("name", "version")
#_____ Utilities __________
# _____ Utilities __________
def get_absolute_url(self):
return reverse(
'databases:view',
args=(self.name, self.version,),
)
return reverse("databases:view", args=(self.name, self.version))
def natural_key(self):
return (self.name, self.version)
#_____ Overrides __________
# _____ Overrides __________
def accessibility_for(self, user_or_team, without_usable=False):
now = datetime.now()
accessible = True
if self.accessibility_start_date and self.accessibility_end_date:
if now < self.accessibility_start_date or now > self.accessibility_end_date:
accessible = False
elif self.accessibility_start_date:
if now < self.accessibility_start_date:
accessible = False
elif self.accessibility_end_date:
if now > self.accessibility_end_date:
accessible = False
if not accessible:
return (False, False, None)
return super(Database, self).accessibility_for(user_or_team, without_usable)
def is_accessible(self, users=None, teams=None):
errors = []
now = datetime.now()
if self.accessibility_start_date and self.accessibility_end_date:
if now < self.accessibility_start_date or now > self.accessibility_end_date:
errors.append(
"The database {} is currently not accessible".format(self.fullname)
)
elif self.accessibility_start_date:
if now < self.accessibility_start_date:
errors.append(
"The database {} is not yet accessible".format(self.fullname)
)
elif self.accessibility_end_date:
if now > self.accessibility_end_date:
errors.append(
"The database {} is not accessible anymore".format(self.fullname)
)
if errors:
return errors
return super(Database, self).is_accessible(users, teams)
def save(self, *args, **kwargs):
......@@ -220,7 +287,11 @@ class Database(Versionable):
wrapper = validate_database(declaration)
# reset the description
self.short_description = wrapper.description if (wrapper is not None) and (wrapper.description is not None) else ''
self.short_description = (
wrapper.description
if (wrapper is not None) and (wrapper.description is not None)
else ""
)
# Save the changed files (if necessary)
storage.save_files(self)
......@@ -229,27 +300,26 @@ class Database(Versionable):
# if the filename has changed, move the declaration
if self.declaration_filename() != self.declaration_file.name:
storage.rename_file(self, 'declaration_file', self.declaration_filename())
storage.rename_file(self, 'description_file', self.description_filename())
storage.rename_file(self, 'source_code_file', self.source_code_filename())
storage.rename_file(self, "declaration_file", self.declaration_filename())
storage.rename_file(self, "description_file", self.description_filename())
storage.rename_file(self, "source_code_file", self.source_code_filename())
#_____ Methods __________
# _____ Methods __________
def fullname(self):
return '%s/%d' % (self.name, self.version)
return "%s/%d" % (self.name, self.version)
def fullpath(self, extension):
return os.path.join(self.name, str(self.version) + extension)
def declaration_filename(self):
return self.fullpath('.json')
return self.fullpath(".json")
def description_filename(self):
return self.fullpath('.rst')
return self.fullpath(".rst")
def source_code_filename(self):
return self.fullpath('.py')
return self.fullpath(".py")
def all_referenced_dataformats(self):
result = []
......@@ -263,7 +333,7 @@ class Database(Versionable):
result.extend(database_protocol.all_needed_dataformats())
return list(set(result))
#_____ Properties __________
# _____ Properties __________
description = property(get_description, set_description)
declaration = property(get_declaration, set_declaration)
......@@ -271,16 +341,13 @@ class Database(Versionable):
source_code = property(get_source_code, set_source_code)
#----------------------------------------------------------
# ----------------------------------------------------------
class DatabaseProtocolManager(models.Manager):
def get_by_natural_key(self, database_name, database_version, name):
return self.get(
database__name=database_name,
database__version=database_version,
name=name,
database__name=database_name, database__version=database_version, name=name
)
......@@ -288,24 +355,26 @@ class DatabaseProtocol(models.Model):
objects = DatabaseProtocolManager()
database = models.ForeignKey(Database, related_name='protocols',
on_delete=models.CASCADE)
database = models.ForeignKey(
Database, related_name="protocols", on_delete=models.CASCADE
)
name = models.CharField(max_length=200, blank=True)
class Meta:
unique_together = ('database', 'name')
ordering = ['name']
unique_together = ("database", "name")
ordering = ["name"]
def __str__(self):
return self.fullname()
def natural_key(self):
return self.database.natural_key() + (self.name,)
natural_key.dependencies = ['databases.database']
natural_key.dependencies = ["databases.database"]
def fullname(self):
if self.name != '':
return self.database.fullname() + '@' + self.name
if self.name != "":
return self.database.fullname() + "@" + self.name
else:
return self.database.fullname()
......@@ -322,16 +391,16 @@ class DatabaseProtocol(models.Model):
return list(set(result))
def set_template_basename(self):
if not self.sets.count(): return "unknown"
if not self.sets.count():
return "unknown"
dbset = self.sets.all()[0]
return dbset.template.name.rsplit('__')[0]
return dbset.template.name.rsplit("__")[0]
#----------------------------------------------------------
# ----------------------------------------------------------
class DatabaseSetTemplateManager(models.Manager):
def get_by_natural_key(self, name):
return self.get(name=name)
......@@ -349,24 +418,24 @@ class DatabaseSetTemplate(models.Model):
return (self.name,)
#----------------------------------------------------------
# ----------------------------------------------------------
class DatabaseSetManager(models.Manager):
def create(self, protocol, template, name):
dataset = DatabaseSet(
name = name,
template = template,
protocol = protocol,
hash = hashDataset(protocol.database.fullname(), protocol.name, name)
name=name,
template=template,
protocol=protocol,
hash=hashDataset(protocol.database.fullname(), protocol.name, name),
)
dataset.save()
return dataset
def get_by_natural_key(self, database_name, database_version, protocol_name, name, template_name):
def get_by_natural_key(
self, database_name, database_version, protocol_name, name, template_name
):
return self.get(
protocol__database__name=database_name,
protocol__database__version=database_version,
......@@ -378,29 +447,34 @@ class DatabaseSetManager(models.Manager):
class DatabaseSet(models.Model):
objects = DatabaseSetManager()
objects = DatabaseSetManager()
protocol = models.ForeignKey(DatabaseProtocol, related_name='sets',
on_delete=models.CASCADE)
name = models.CharField(max_length=200, blank=True)
template = models.ForeignKey(DatabaseSetTemplate, related_name='sets',
on_delete=models.CASCADE)
protocol = models.ForeignKey(
DatabaseProtocol, related_name="sets", on_delete=models.CASCADE
)
name = models.CharField(max_length=200, blank=True)
template = models.ForeignKey(
DatabaseSetTemplate, related_name="sets", on_delete=models.CASCADE
)
hash = models.CharField(max_length=64, unique=True)
class Meta:
unique_together = ('protocol', 'name', 'template')
unique_together = ("protocol", "name", "template")
def __str__(self):
return self.fullname()
def natural_key(self):
return self.protocol.natural_key() + (self.name,) + self.template.natural_key()
natural_key.dependencies = ['databases.databaseprotocol',
'databases.databasesettemplate']
natural_key.dependencies = [
"databases.databaseprotocol",
"databases.databasesettemplate",
]
def fullname(self):
if self.name != '':
return self.protocol.fullname() + '.' + self.name
if self.name != "":
return self.protocol.fullname() + "." + self.name
else:
return self.protocol.fullname()
......@@ -416,43 +490,44 @@ class DatabaseSet(models.Model):
return list(set(result))
#----------------------------------------------------------
# ----------------------------------------------------------
class DatabaseSetTemplateOutput(models.Model):
template = models.ForeignKey(DatabaseSetTemplate,
related_name='outputs', on_delete=models.CASCADE)
name = models.CharField(max_length=200)
dataformat = models.ForeignKey(DataFormat,
related_name='database_outputs', on_delete=models.CASCADE)
template = models.ForeignKey(
DatabaseSetTemplate, related_name="outputs", on_delete=models.CASCADE
)
name = models.CharField(max_length=200)
dataformat = models.ForeignKey(
DataFormat, related_name="database_outputs", on_delete=models.CASCADE
)
class Meta:
unique_together = ('template', 'name')
unique_together = ("template", "name")
def __str__(self):
return self.fullname()
def fullname(self):
return self.template.name + '.' + self.name
return self.template.name + "." + self.name
#----------------------------------------------------------
# ----------------------------------------------------------
class DatabaseSetOutputManager(models.Manager):
def get_by_natural_key(self, set_natural_key, output_name):
set_ = DatabaseSet.objects.get_by_natural_key(*set_natural_key)
return self.get(
set=set_,
template__name=output_name,
)
return self.get(set=set_, template__name=output_name)
class DatabaseSetOutput(models.Model):
template = models.ForeignKey(DatabaseSetTemplateOutput,
related_name='instances', on_delete=models.CASCADE)
set = models.ForeignKey(DatabaseSet, related_name='outputs',
on_delete=models.CASCADE)
template = models.ForeignKey(
DatabaseSetTemplateOutput, related_name="instances", on_delete=models.CASCADE
)
set = models.ForeignKey(
DatabaseSet, related_name="outputs", on_delete=models.CASCADE
)
objects = DatabaseSetOutputManager()
......@@ -460,7 +535,7 @@ class DatabaseSetOutput(models.Model):
return self.fullname()
def fullname(self):
return '%s.%s.%s.%s' % (
return "%s.%s.%s.%s" % (
self.set.protocol.database.fullname(),
self.set.protocol.name,
self.set.name,
......@@ -475,5 +550,8 @@ class DatabaseSetOutput(models.Model):
def natural_key(self):
return (self.set.natural_key(), self.template.name)
natural_key.dependencies = ['databases.databasesettemplateoutput',
'databases.databaseset']
natural_key.dependencies = [
"databases.databasesettemplateoutput",
"databases.databaseset",
]
......@@ -27,6 +27,8 @@
import json
from datetime import datetime, timedelta
from django.contrib.auth.models import User
from django.conf import settings
from django.core.urlresolvers import reverse
......@@ -39,7 +41,7 @@ from ..common.testutils import BaseTestCase
from ..common.testutils import tearDownModule # noqa test runner will call it
TEST_PWD = "1234"
TEST_PWD = "1234" # nosec
class DatabaseAPIBase(BaseTestCase):
......@@ -128,18 +130,57 @@ class DatabaseCreationAPI(DatabaseAPIBase):
class DatabaseRetrievalAPI(DatabaseAPIBase):
def test_retrieve_database(self):
def setUp(self):
super(DatabaseRetrievalAPI, self).setUp()
(dataformat, errors) = DataFormat.objects.create_dataformat(
self.system_user, "float", ""
)
self.assertIsNotNone(dataformat, errors)
dataformat.share()
(database, errors) = Database.objects.create_database(
(self.database, errors) = Database.objects.create_database(
self.db_name, declaration=self.DATABASE
)
self.assertIsNotNone(database, errors)
database.share()
self.assertIsNotNone(self.database, errors)
def tearDown(self):
self.database.delete()
def __accessibility_test(self):
self.database.accessibility_start_date = datetime.now()
self.database.accessibility_end_date = (
self.database.accessibility_start_date + timedelta(days=1)
)
self.database.save()
object_url = reverse(
"api_databases:object", kwargs={"database_name": self.db_name, "version": 1}
)
all_url = reverse("api_databases:all")
response = self.client.get(object_url, format="json")
self.checkResponse(response, 200, content_type="application/json")
response = self.client.get(all_url, format="json")
self.checkResponse(response, 200, content_type="application/json")
self.assertEqual(len(response.json()), 1)
self.database.accessibility_start_date = datetime.now() - timedelta(days=2)
self.database.accessibility_end_date = (
self.database.accessibility_start_date + timedelta(days=1)
)
self.database.save()
response = self.client.get(object_url, format="json")
self.checkResponse(response, 404)
response = self.client.get(all_url, format="json")
self.checkResponse(response, 200, content_type="application/json")
self.assertEqual(len(response.json()), 0)
def test_retrieve_database(self):
self.database.share()
self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD)
......@@ -153,4 +194,13 @@ class DatabaseRetrievalAPI(DatabaseAPIBase):
declaration = json.loads(data["declaration"])
self.assertTrue(declaration["root_folder"].startswith("/path_to_db_folder"))
database.delete()
def test_dated_database_for_user(self):
self.database.share(users=[settings.SYSTEM_ACCOUNT])
self.database.save()
self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD)
self.__accessibility_test()
def test_dated_database_for_public(self):
self.database.share()
self.database.save()
self.__accessibility_test()
......@@ -29,14 +29,22 @@ from rest_framework import permissions
from ..databases.models import Database
class IsDatabaseAccessible(permissions.BasePermission):
"""
The logged in user must have access to the database
used by the experiment
"""
message = "Database is not accessible"
def has_object_permission(self, request, view, obj):
accessible_databases = Database.objects.for_user(request.user, True)
experiment_databases = Database.objects.filter(protocols__sets__in=obj.referenced_datasets.all()).distinct()
experiment_databases = Database.objects.filter(
protocols__sets__in=obj.referenced_datasets.all()
).distinct()
return all(experiment_db in accessible_databases for experiment_db in experiment_databases)
return all(
experiment_db in accessible_databases
for experiment_db in experiment_databases
)
......@@ -29,7 +29,9 @@ import os
import simplejson as json
import shutil
import copy
from datetime import datetime
from datetime import timedelta
from django.conf import settings
from django.contrib.auth.models import User
......@@ -1091,6 +1093,45 @@ class ExperimentStartingAPI(ExperimentTestBase):
self.checkResponse(response, 403, content_type="application/json")
def test_start_experiment_with_database_outside_accessibility(self):
database = Database.objects.get(name="integers")
self.login_johndoe()
now = datetime.now()
delta = timedelta(days=1)
bad_cases = {
"before start": (now + delta, None),
"after end": (None, now - delta),
"out of range": (now + delta, now + timedelta(days=2)),
}
for case, range_ in bad_cases.items():
start, end = range_
print("Testing", case, "...")
database.accessibility_start_date = start
database.accessibility_end_date = end
database.save()
response = self.client.post(self.url)
self.checkResponse(response, 403, content_type="application/json")
def test_start_experiment_with_database_within_accessibility_range(self):
database = Database.objects.get(name="integers")
self.login_johndoe()
now = datetime.now()
delta = timedelta(days=1)
database.accessibility_start_date = now - delta
database.accessibility_end_date = now + delta
database.save()
response = self.client.post(self.url)
self.checkResponse(response, 200, content_type="application/json")
# ----------------------------------------------------------
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment