OAuth refresh
This commit is contained in:
parent
446c8c2a52
commit
faca12783e
2 changed files with 59 additions and 19 deletions
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import IntEnum
|
||||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
@ -69,7 +70,7 @@ class Session(LibrespotSession):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session_builder: LibrespotSession.Builder,
|
session_builder: LibrespotSession.Builder,
|
||||||
token: TokenProvider.StoredToken,
|
oauth: OAuth,
|
||||||
language: str = "en",
|
language: str = "en",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -89,7 +90,7 @@ class Session(LibrespotSession):
|
||||||
),
|
),
|
||||||
ApResolver.get_random_accesspoint(),
|
ApResolver.get_random_accesspoint(),
|
||||||
)
|
)
|
||||||
self.__token = token
|
self.__oauth = oauth
|
||||||
self.__language = language
|
self.__language = language
|
||||||
self.connect()
|
self.connect()
|
||||||
self.authenticate(session_builder.login_credentials)
|
self.authenticate(session_builder.login_credentials)
|
||||||
|
@ -112,8 +113,7 @@ class Session(LibrespotSession):
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
session = LibrespotSession.Builder(conf).stored_file(str(cred_file))
|
session = LibrespotSession.Builder(conf).stored_file(str(cred_file))
|
||||||
token = session.login_credentials.auth_data # TODO: this is wrong
|
return Session(session, OAuth(), language) # TODO
|
||||||
return Session(session, token, language)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_oauth(
|
def from_oauth(
|
||||||
|
@ -148,7 +148,7 @@ class Session(LibrespotSession):
|
||||||
typ=Authentication.AuthenticationType.values()[3],
|
typ=Authentication.AuthenticationType.values()[3],
|
||||||
auth_data=token.access_token.encode(),
|
auth_data=token.access_token.encode(),
|
||||||
)
|
)
|
||||||
return Session(session, token, language)
|
return Session(session, auth, language)
|
||||||
|
|
||||||
def __get_playable(
|
def __get_playable(
|
||||||
self, playable_id: PlayableId, quality: Quality
|
self, playable_id: PlayableId, quality: Quality
|
||||||
|
@ -188,9 +188,9 @@ class Session(LibrespotSession):
|
||||||
self.api(),
|
self.api(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def token(self) -> TokenProvider.StoredToken:
|
def oauth(self) -> OAuth:
|
||||||
"""Returns API token"""
|
"""Returns OAuth service"""
|
||||||
return self.__token
|
return self.__oauth
|
||||||
|
|
||||||
def language(self) -> str:
|
def language(self) -> str:
|
||||||
"""Returns session language"""
|
"""Returns session language"""
|
||||||
|
@ -288,7 +288,7 @@ class TokenProvider(LibrespotTokenProvider):
|
||||||
self._session = session
|
self._session = session
|
||||||
|
|
||||||
def get_token(self, *scopes) -> TokenProvider.StoredToken:
|
def get_token(self, *scopes) -> TokenProvider.StoredToken:
|
||||||
return self._session.token()
|
return self._session.oauth().get_token()
|
||||||
|
|
||||||
class StoredToken(LibrespotTokenProvider.StoredToken):
|
class StoredToken(LibrespotTokenProvider.StoredToken):
|
||||||
def __init__(self, obj):
|
def __init__(self, obj):
|
||||||
|
@ -309,6 +309,11 @@ class OAuth:
|
||||||
self.__server_thread.start()
|
self.__server_thread.start()
|
||||||
|
|
||||||
def get_authorization_url(self) -> str:
|
def get_authorization_url(self) -> str:
|
||||||
|
"""
|
||||||
|
Generate OAuth URL
|
||||||
|
Returns:
|
||||||
|
OAuth URL
|
||||||
|
"""
|
||||||
self.__code_verifier = generate_code_verifier()
|
self.__code_verifier = generate_code_verifier()
|
||||||
code_challenge = get_code_challenge(self.__code_verifier)
|
code_challenge = get_code_challenge(self.__code_verifier)
|
||||||
params = {
|
params = {
|
||||||
|
@ -322,19 +327,48 @@ class OAuth:
|
||||||
return f"{AUTH_URL}authorize?{urlencode(params)}"
|
return f"{AUTH_URL}authorize?{urlencode(params)}"
|
||||||
|
|
||||||
def await_token(self) -> TokenProvider.StoredToken:
|
def await_token(self) -> TokenProvider.StoredToken:
|
||||||
|
"""
|
||||||
|
Blocks until server thread gets token
|
||||||
|
Returns:
|
||||||
|
StoredToken
|
||||||
|
"""
|
||||||
self.__server_thread.join()
|
self.__server_thread.join()
|
||||||
return self.__token
|
return self.__token
|
||||||
|
|
||||||
def set_token(self, code: str) -> None:
|
def get_token(self) -> TokenProvider.StoredToken:
|
||||||
|
"""
|
||||||
|
Gets a valid token
|
||||||
|
Returns:
|
||||||
|
StoredToken
|
||||||
|
"""
|
||||||
|
if self.__token is None:
|
||||||
|
raise RuntimeError("Session isn't authenticated!")
|
||||||
|
elif self.__token.expired():
|
||||||
|
self.set_token(self.__token.refresh_token, OAuth.RequestType.REFRESH)
|
||||||
|
return self.__token
|
||||||
|
|
||||||
|
def set_token(self, code: str, request_type: RequestType) -> None:
|
||||||
|
"""
|
||||||
|
Fetches and sets stored token
|
||||||
|
Returns:
|
||||||
|
StoredToken
|
||||||
|
"""
|
||||||
token_url = f"{AUTH_URL}api/token"
|
token_url = f"{AUTH_URL}api/token"
|
||||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||||
body = {
|
if request_type == OAuth.RequestType.LOGIN:
|
||||||
"grant_type": "authorization_code",
|
body = {
|
||||||
"code": code,
|
"grant_type": "authorization_code",
|
||||||
"redirect_uri": REDIRECT_URI,
|
"code": code,
|
||||||
"client_id": CLIENT_ID,
|
"redirect_uri": REDIRECT_URI,
|
||||||
"code_verifier": self.__code_verifier,
|
"client_id": CLIENT_ID,
|
||||||
}
|
"code_verifier": self.__code_verifier,
|
||||||
|
}
|
||||||
|
elif request_type == OAuth.RequestType.REFRESH:
|
||||||
|
body = {
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": code,
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
}
|
||||||
response = post(token_url, headers=headers, data=body)
|
response = post(token_url, headers=headers, data=body)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise IOError(
|
raise IOError(
|
||||||
|
@ -348,6 +382,10 @@ class OAuth:
|
||||||
httpd.authenticator = self
|
httpd.authenticator = self
|
||||||
httpd.serve_forever()
|
httpd.serve_forever()
|
||||||
|
|
||||||
|
class RequestType(IntEnum):
|
||||||
|
LOGIN = 0
|
||||||
|
REFRESH = 1
|
||||||
|
|
||||||
class OAuthHTTPServer(HTTPServer):
|
class OAuthHTTPServer(HTTPServer):
|
||||||
authenticator: OAuth
|
authenticator: OAuth
|
||||||
|
|
||||||
|
@ -371,7 +409,9 @@ class OAuth:
|
||||||
|
|
||||||
if code:
|
if code:
|
||||||
if isinstance(self.server, OAuth.OAuthHTTPServer):
|
if isinstance(self.server, OAuth.OAuthHTTPServer):
|
||||||
self.server.authenticator.set_token(code[0])
|
self.server.authenticator.set_token(
|
||||||
|
code[0], OAuth.RequestType.LOGIN
|
||||||
|
)
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header("Content-type", "text/html")
|
self.send_header("Content-type", "text/html")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
|
@ -106,7 +106,7 @@ class Selection:
|
||||||
|
|
||||||
def __print(self, count: int, items: list[dict[str, Any]], *args: str) -> None:
|
def __print(self, count: int, items: list[dict[str, Any]], *args: str) -> None:
|
||||||
arg_range = range(len(args))
|
arg_range = range(len(args))
|
||||||
category_str = " # " + " ".join("{:<38}" for _ in arg_range)
|
category_str = "# " + " ".join("{:<38}" for _ in arg_range)
|
||||||
print(category_str.format(*[s.upper() for s in list(args)]))
|
print(category_str.format(*[s.upper() for s in list(args)]))
|
||||||
for item in items:
|
for item in items:
|
||||||
count += 1
|
count += 1
|
||||||
|
|
Loading…
Add table
Reference in a new issue