diff --git a/bob/db/cuhk_cufsf/query.py b/bob/db/cuhk_cufsf/query.py
index efaffb4510e443c45db25d22c243b652aebac915..023f60f5685ace7485057268f7bfa8c2ec3c110b 100755
--- a/bob/db/cuhk_cufsf/query.py
+++ b/bob/db/cuhk_cufsf/query.py
@@ -119,7 +119,7 @@ class Database(bob.db.base.SQLiteDatabase):
     return file.annotations(annotation_type=annotation_type)
 
 
-  def objects(self, groups = None, protocol = None, purposes = None, model_ids = None, **kwargs):
+  def objects(self, groups = None, protocol = None, purposes = None, model_ids = None, modality=None, **kwargs):
     """
       This function returns lists of File objects, which fulfill the given restrictions.
 
@@ -129,6 +129,9 @@ class Database(bob.db.base.SQLiteDatabase):
     groups    = self.check_parameters_for_validity(groups, "group", GROUPS)
     protocols = self.check_parameters_for_validity(protocol, "protocol", PROTOCOLS) 
     purposes  = self.check_parameters_for_validity(purposes, "purpose", PURPOSES)
+    modality = self.check_parameters_for_validity(
+        modality, "modality", self.modalities)
+    
 
     #You need to select only one protocol
     if (len(protocols) > 1):
@@ -141,6 +144,9 @@ class Database(bob.db.base.SQLiteDatabase):
     query = query.filter(bob.db.cuhk_cufsf.Protocol_File_Association.group.in_(groups))
     query = query.filter(bob.db.cuhk_cufsf.Protocol_File_Association.protocol.in_(protocols))
     query = query.filter(bob.db.cuhk_cufsf.Protocol_File_Association.purpose.in_(purposes))
+    query = query.filter(
+        bob.db.cuhk_cufsf.File.modality.in_(modality))
+    
 
     if model_ids is not None and not 'probe' in purposes:
       if type(model_ids) is not list and type(model_ids) is not tuple:
diff --git a/bob/db/cuhk_cufsf/test.py b/bob/db/cuhk_cufsf/test.py
index 793af61260faadc762c929ce33737d19fd501470..0cd79fee3067c6b550c99a1db07d282e75aba12d 100644
--- a/bob/db/cuhk_cufsf/test.py
+++ b/bob/db/cuhk_cufsf/test.py
@@ -76,6 +76,10 @@ def test02_search_files_protocols():
     
       assert len(bob.db.cuhk_cufsf.Database().objects(protocol=p, groups="eval")) == 0
 
+      # Checking the modalities
+      assert len(bob.db.cuhk_cufsf.Database().objects(protocol=p, groups="world", modality=["photo"])) == world//2
+      assert len(bob.db.cuhk_cufsf.Database().objects(protocol=p, groups="world", modality=["sketch"])) == world//2
+      
 
 def test03_verification_files_protocols():