initial commit
This commit is contained in:
169
main.py
Normal file
169
main.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# ============================================
|
||||
# 1. Imports
|
||||
# ============================================
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import yfinance as yf
|
||||
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import accuracy_score
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# ============================================
|
||||
# 2. Parameters
|
||||
# ============================================
|
||||
# Symbols that are downloaded and modelled together.
TICKERS = [
    "AAPL",
    "MSFT",
    "GOOGL",
    "AMZN",
    "META",
]

# Download window handed to yfinance as its start/end arguments.
START_DATE = "2015-01-01"
END_DATE = "2024-01-01"

# Rows dated on or before this day are used for training; later rows are
# held out for testing and the backtest.
TRAIN_END = "2020-12-31"

# Cost charged per unit of position change in the backtest.
TRANSACTION_COST = 0.001  # 0.1%
|
||||
|
||||
# ============================================
|
||||
# 3. Download Data
|
||||
# ============================================
|
||||
def download_data(tickers):
    """Download price history for *tickers* and stack it into one long frame.

    The data is requested through ``yf.download`` for the window
    ``START_DATE``..``END_DATE`` with ``group_by="ticker"``, then reshaped
    from one column group per ticker into one row per (date, ticker).

    Args:
        tickers: A single ticker symbol or an iterable of ticker symbols.

    Returns:
        A DataFrame with a ``date`` column, the columns yfinance returned for
        each ticker, and a ``ticker`` column naming the symbol of each row.

    Raises:
        ValueError: If no ticker is given.
    """
    # A bare string would otherwise be iterated character by character below.
    if isinstance(tickers, str):
        tickers = [tickers]
    tickers = list(tickers)

    # Fail before any download with a clear message; pd.concat would
    # otherwise raise a less helpful ValueError on an empty list.
    if not tickers:
        raise ValueError("download_data() requires at least one ticker")

    data = yf.download(tickers, start=START_DATE, end=END_DATE, group_by="ticker")

    # With group_by="ticker" the symbol is the outer column level, so
    # data[ticker] selects that symbol's own sub-frame.
    frames = [data[ticker].assign(ticker=ticker) for ticker in tickers]

    stacked = pd.concat(frames)
    stacked.index.name = "date"
    return stacked.reset_index()
|
||||
|
||||
# ============================================
# 4. Load and Sort (IMPORTANT)
# ============================================
# The returns, rolling windows and shifts computed further down work in row
# order within each ticker, so rows must be in date order per ticker first.
df = download_data(TICKERS).sort_values(by=["ticker", "date"])
|
||||
|
||||
# ============================================
|
||||
# 5. Feature Engineering (NO APPLY)
|
||||
# ============================================
|
||||
# Every feature is computed inside a single ticker's history, so no return
# or rolling window ever mixes rows from two symbols.
close_by_ticker = df["Close"].groupby(df["ticker"])

# Trailing percentage change of the close over one and five rows.
df["return_1d"] = close_by_ticker.pct_change(periods=1)
df["return_5d"] = close_by_ticker.pct_change(periods=5)

# Trailing moving averages of the close over five and ten rows.
df["ma_5"] = close_by_ticker.transform(
    lambda closes: closes.rolling(window=5).mean()
)
df["ma_10"] = close_by_ticker.transform(
    lambda closes: closes.rolling(window=10).mean()
)

# Rolling five-row standard deviation of the one-row return.
one_day_returns = df["return_1d"].groupby(df["ticker"])
df["volatility_5d"] = one_day_returns.transform(
    lambda returns: returns.rolling(window=5).std()
)

# Row-over-row percentage change in traded volume.
df["volume_change"] = df["Volume"].groupby(df["ticker"]).pct_change()

# Close relative to its own five-row moving average.
df["price_ma5_ratio"] = df["Close"].div(df["ma_5"])
|
||||
|
||||
# ============================================
|
||||
# 6. Labels (SAFE)
|
||||
# ============================================
|
||||
# Next-row return, aligned back onto the current row. The shift is applied
# per ticker so the last row of one symbol can never receive a value from the
# first row of the next symbol in the stacked frame; those last rows get NaN.
daily_return = df.groupby("ticker")["Close"].pct_change()
df["future_return"] = daily_return.groupby(df["ticker"]).shift(-1)

# 1 when the next return is strictly positive, otherwise 0. Rows whose
# future_return is NaN also evaluate to 0 here, so they must be removed
# (via the NaN in future_return) before the labels are used.
df["target"] = (df["future_return"] > 0).astype(int)
|
||||
|
||||
# ============================================
|
||||
# 7. Clean Data
|
||||
# ============================================
|
||||
# pct_change() produces +/-inf when the previous value is 0 (for example a
# zero-volume row). dropna() does not treat inf as missing, and the scaler
# used later rejects infinite inputs, so convert them to NaN first.
numeric_cols = df.select_dtypes(include="number").columns
df[numeric_cols] = df[numeric_cols].replace([np.inf, -np.inf], np.nan)

# Remove every row with a missing value (warm-up rows of the rolling windows,
# rows without a next-row return) and renumber the index.
df = df.dropna().reset_index(drop=True)
|
||||
|
||||
# ============================================
|
||||
# 8. Train/Test Split
|
||||
# ============================================
|
||||
# Chronological split: rows dated up to and including TRAIN_END are used for
# fitting, rows dated after it are held out.
in_training_window = df["date"] <= TRAIN_END
after_training_window = df["date"] > TRAIN_END
train = df.loc[in_training_window]
test = df.loc[after_training_window]

# Model inputs, in the column order passed to the scaler and the classifier.
FEATURES = [
    "return_1d",
    "return_5d",
    "ma_5",
    "ma_10",
    "volatility_5d",
    "volume_change",
    "price_ma5_ratio",
]

X_train, y_train = train[FEATURES], train["target"]
X_test, y_test = test[FEATURES], test["target"]
|
||||
|
||||
# ============================================
|
||||
# 9. Scaling
|
||||
# ============================================
|
||||
# Learn the scaling statistics from the training rows only, then apply the
# same fitted scaler to both sets so nothing is learned from the test rows.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
|
||||
|
||||
# ============================================
|
||||
# 10. Train Model
|
||||
# ============================================
|
||||
# Random forest of 100 trees limited to depth 5; the fixed seed keeps the
# fitted model reproducible between runs.
model = RandomForestClassifier(
    n_estimators=100,
    max_depth=5,
    random_state=42,
).fit(X_train, y_train)
|
||||
|
||||
# ============================================
|
||||
# 11. Predictions
|
||||
# ============================================
|
||||
# Predicted class (0 or 1) for every held-out row.
preds = model.predict(X_test)

# Share of held-out rows whose predicted class equals the label.
accuracy = accuracy_score(y_true=y_test, y_pred=preds)
print(f"Test Accuracy: {accuracy:.4f}")
|
||||
|
||||
# ============================================
|
||||
# 12. Backtest
|
||||
# ============================================
|
||||
# Work on a fresh frame that carries the raw model output.
test = test.assign(prediction=preds)
ticker_key = test["ticker"]

# Lag the signal by one row within each ticker, so the position held on a
# row comes from the prediction made on the previous row.
lagged_signal = test["prediction"].groupby(ticker_key).shift(1)

# Absolute change in position between consecutive rows of the same ticker,
# charged at TRANSACTION_COST per unit of change.
turnover = lagged_signal.groupby(ticker_key).diff().abs()
costs = turnover * TRANSACTION_COST

# Position (0 or 1) times the next-row return, net of trading costs.
test = test.assign(
    prediction=lagged_signal,
    strategy_return=test["future_return"] * lagged_signal - costs,
    position_change=turnover,
    transaction_cost=costs,
)

# The lag and the diff leave NaN in the leading rows of every ticker.
test = test.dropna()
|
||||
|
||||
# ============================================
|
||||
# 13. Performance
|
||||
# ============================================
|
||||
# `test` stacks the rows of every ticker, so compounding its columns directly
# would chain one ticker's history onto the next. Collapse to one return per
# date first: the equal-weighted mean across tickers, i.e. a basket that is
# rebalanced to equal weights on every row date.
portfolio = test.groupby("date")[["future_return", "strategy_return"]].mean()
portfolio["cum_market"] = (1 + portfolio["future_return"]).cumprod()
portfolio["cum_strategy"] = (1 + portfolio["strategy_return"]).cumprod()

# Keep the per-row columns on `test`; each row now carries the portfolio
# curve value of its own date.
test["cum_market"] = test["date"].map(portfolio["cum_market"])
test["cum_strategy"] = test["date"].map(portfolio["cum_strategy"])

# Sharpe ratio of the per-date portfolio return, annualised with 252 periods
# and no risk-free rate. It is undefined (NaN) when the returns do not vary.
daily_strategy = portfolio["strategy_return"]
volatility = daily_strategy.std()
if volatility > 0:
    sharpe = np.sqrt(252) * daily_strategy.mean() / volatility
else:
    sharpe = float("nan")
print(f"Sharpe Ratio: {sharpe:.2f}")

# ============================================
# 14. Plot
# ============================================
# One point per date, so each curve is drawn once from left to right.
plt.figure(figsize=(10, 6))
plt.plot(portfolio.index, portfolio["cum_market"], label="Market")
plt.plot(portfolio.index, portfolio["cum_strategy"], label="Strategy")
plt.legend()
plt.title("Strategy vs Market")
plt.xlabel("Date")
plt.ylabel("Cumulative Return")
plt.grid()

# The figure is written to disk instead of being shown interactively.
plt.savefig("strategy.png", dpi=150)
print("Plot saved as strategy.png")
|
||||
36
requirements.txt
Normal file
36
requirements.txt
Normal file
@@ -0,0 +1,36 @@
|
||||
beautifulsoup4==4.14.3
|
||||
certifi==2026.2.25
|
||||
cffi==2.0.0
|
||||
charset-normalizer==3.4.6
|
||||
contourpy==1.3.3
|
||||
curl_cffi==0.13.0
|
||||
cycler==0.12.1
|
||||
fonttools==4.62.1
|
||||
frozendict==2.4.7
|
||||
idna==3.11
|
||||
joblib==1.5.3
|
||||
kiwisolver==1.5.0
|
||||
lxml==6.0.2
|
||||
matplotlib==3.10.8
|
||||
multitasking==0.0.12
|
||||
numpy==2.4.3
|
||||
packaging==26.0
|
||||
pandas==3.0.1
|
||||
peewee==4.0.2
|
||||
pillow==12.1.1
|
||||
platformdirs==4.9.4
|
||||
protobuf==7.34.1
|
||||
pycparser==3.0
|
||||
pyparsing==3.3.2
|
||||
python-dateutil==2.9.0.post0
|
||||
pytz==2026.1.post1
|
||||
requests==2.32.5
|
||||
scikit-learn==1.8.0
|
||||
scipy==1.17.1
|
||||
six==1.17.0
|
||||
soupsieve==2.8.3
|
||||
threadpoolctl==3.6.0
|
||||
typing_extensions==4.15.0
|
||||
urllib3==2.6.3
|
||||
websockets==16.0
|
||||
yfinance==1.2.0
|
||||
Reference in New Issue
Block a user