diff --git a/app/models.py b/app/models.py index f0b6f7d..26508d7 100644 --- a/app/models.py +++ b/app/models.py @@ -71,6 +71,15 @@ class Role(db.Model): return '' % self.name +class Follow(db.Model): + __tablename__ = 'follows' + follower_id = db.Column(db.Integer, db.ForeignKey('users.id'), + primary_key=True) + followed_id = db.Column(db.Integer, db.ForeignKey('users.id'), + primary_key=True) + timestamp = db.Column(db.DateTime, default=datetime.utcnow) + + class User(UserMixin, db.Model): __tablename__ = 'users' id = db.Column(db.Integer, primary_key=True) @@ -86,6 +95,16 @@ class User(UserMixin, db.Model): last_seen = db.Column(db.DateTime(), default=datetime.utcnow) avatar_hash = db.Column(db.String(32)) posts = db.relationship('Post', backref='author', lazy='dynamic') + followed = db.relationship('Follow', + foreign_keys=[Follow.follower_id], + backref=db.backref('follower', lazy='joined'), + lazy='dynamic', + cascade='all, delete-orphan') + followers = db.relationship('Follow', + foreign_keys=[Follow.followed_id], + backref=db.backref('followed', lazy='joined'), + lazy='dynamic', + cascade='all, delete-orphan') def __init__(self, **kwargs): super(User, self).__init__(**kwargs) @@ -185,6 +204,28 @@ class User(UserMixin, db.Model): return '{url}/{hash}?s={size}&d={default}&r={rating}'.format( url=url, hash=hash, size=size, default=default, rating=rating) + def follow(self, user): + if not self.is_following(user): + f = Follow(follower=self, followed=user) + db.session.add(f) + + def unfollow(self, user): + f = self.followed.filter_by(followed_id=user.id).first() + if f: + db.session.delete(f) + + def is_following(self, user): + if user.id is None: + return False + return self.followed.filter_by( + followed_id=user.id).first() is not None + + def is_followed_by(self, user): + if user.id is None: + return False + return self.followers.filter_by( + follower_id=user.id).first() is not None + def __repr__(self): return '' % self.username diff --git a/flasky.py b/flasky.py index ab40988..a379daf 100644 --- a/flasky.py +++ b/flasky.py @@ -1,7 +1,7 @@ import os from flask_migrate import Migrate from app import create_app, db -from app.models import User, Role, Permission, Post +from app.models import User, Role, Permission, Post, Follow app = create_app(os.getenv('FLASK_CONFIG') or 'default') @@ -10,7 +10,8 @@ migrate = Migrate(app, db) @app.shell_context_processor def make_shell_context(): - return dict(db=db, User=User, Role=Role, Permission=Permission, Post=Post) + return dict(db=db, User=User, Follow=Follow, Role=Role, + Permission=Permission, Post=Post) @app.cli.command() diff --git a/migrations/versions/d1ab608c102a_.py b/migrations/versions/d1ab608c102a_.py new file mode 100644 index 0000000..6f06b22 --- /dev/null +++ b/migrations/versions/d1ab608c102a_.py @@ -0,0 +1,35 @@ +"""empty message + +Revision ID: d1ab608c102a +Revises: 53a175284ebe +Create Date: 2018-11-23 23:40:00.232851 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd1ab608c102a' +down_revision = '53a175284ebe' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('follows', + sa.Column('follower_id', sa.Integer(), nullable=False), + sa.Column('followed_id', sa.Integer(), nullable=False), + sa.Column('timestamp', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['followed_id'], ['users.id'], ), + sa.ForeignKeyConstraint(['follower_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('follower_id', 'followed_id') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('follows') + # ### end Alembic commands ### diff --git a/tests/test_user_model.py b/tests/test_user_model.py index 0ee4fe5..b62786b 100644 --- a/tests/test_user_model.py +++ b/tests/test_user_model.py @@ -2,7 +2,7 @@ import unittest import time from datetime import datetime from app import create_app, db -from app.models import User, AnonymousUser, Role, Permission +from app.models import User, AnonymousUser, Role, Permission, Follow class UserModelTestCase(unittest.TestCase): @@ -169,3 +169,40 @@ class UserModelTestCase(unittest.TestCase): self.assertTrue('s=256' in gravatar_256) self.assertTrue('r=pg'in gravatar_pg) self.assertTrue('d=retro' in gravatar_retro) + + def test_follows(self): + u1 = User(email='john@example.com', password='cat') + u2 = User(email='susan@example.org', password='dog') + db.session.add(u1) + db.session.add(u2) + db.session.commit() + self.assertFalse(u1.is_following(u2)) + self.assertFalse(u1.is_followed_by(u2)) + timestamp_before = datetime.utcnow() + u1.follow(u2) + db.session.add(u1) + db.session.commit() + timestamp_after = datetime.utcnow() + self.assertTrue(u1.is_following(u2)) + self.assertFalse(u1.is_followed_by(u2)) + self.assertTrue(u2.is_followed_by(u1)) + self.assertTrue(u1.followed.count() == 1) + self.assertTrue(u2.followers.count() == 1) + f = u1.followed.all()[-1] + self.assertTrue(f.followed == u2) + self.assertTrue(timestamp_before <= f.timestamp <= timestamp_after) + f = u2.followers.all()[-1] + self.assertTrue(f.follower == u1) + u1.unfollow(u2) + db.session.add(u1) + db.session.commit() + self.assertTrue(u1.followed.count() == 0) + self.assertTrue(u2.followers.count() == 0) + self.assertTrue(Follow.query.count() == 0) + u2.follow(u1) + db.session.add(u1) + db.session.add(u2) + db.session.commit() + db.session.delete(u2) + db.session.commit() + self.assertTrue(Follow.query.count() == 0)