Objective
You are given two DataFrames related to venture capital investments. The companies DataFrame lists startups and their associated industries, and the investments DataFrame logs individual funding rounds.
Task
Write PySpark code that calculates the total investment amount for each industry sector. Sort the final result by total_investment in descending order so that the highest-funded industries appear at the top.
Save your result as result_df.
File Paths
- Companies dataset: /home/interview/companies.csv
- Investments dataset: /home/interview/investments.csv
- Starter script: /home/interview/vc_analysis.py
Schema
companies.csv
| Column Name | Data Type |
|---|---|
| company_id | integer |
| company_name | string |
| industry | string |
investments.csv
| Column Name | Data Type |
|---|---|
| investment_id | integer |
| company_id | integer |
| amount | double |
Expected Output Schema
| Column Name | Data Type |
|---|---|
| industry | string |
| total_investment | double |
Example
Given this sample input:
companies
| company_id | company_name | industry |
|---|---|---|
| 1 | AlphaTech | Technology |
| 2 | BetaHealth | Healthcare |
| 3 | GammaEntertainment | Entertainment |
| 4 | DeltaGreen | Renewable Energy |
| 5 | EpsilonFinance | Finance |
investments
| investment_id | company_id | amount |
|---|---|---|
| 1 | 1 | 5000000.0 |
| 2 | 2 | 3000000.0 |
| 3 | 3 | 1000000.0 |
| 4 | 4 | 4000000.0 |
| 5 | 5 | 2000000.0 |
The output would be:
| industry | total_investment |
|---|---|
| Technology | 5000000.0 |
| Renewable Energy | 4000000.0 |
| Healthcare | 3000000.0 |
| Finance | 2000000.0 |
| Entertainment | 1000000.0 |
(Note: The output is sorted by total_investment in descending order).
Solution
```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName("PrepareshSpark").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

companies = spark.read.csv("/home/interview/companies.csv", header=True, inferSchema=True)
investments = spark.read.csv("/home/interview/investments.csv", header=True, inferSchema=True)

# Join the tables to link each investment with its company's industry
joined_df = companies.join(investments, on="company_id", how="inner")

# Group by industry, sum the amounts, and order by the total, highest first
result_df = (
    joined_df.groupBy("industry")
    .agg(F.sum("amount").alias("total_investment"))
    .orderBy(F.col("total_investment").desc())
)

# --- Do not edit below this line ---
result_df.coalesce(1).write.csv("/home/interview/output", header=True, mode="overwrite")
spark.stop()
```
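You can sanity-check the expected output without a Spark cluster by mirroring the same join, group, sum, and sort logic in plain Python over the sample rows from the Example section. This is a minimal sketch, not the Spark solution. It makes two assumptions: total_by_industry is a hypothetical helper, and a dictionary lookup stands in for the inner join.

```python
# Plain-Python mirror of the Spark pipeline (no Spark required).
# total_by_industry is a hypothetical helper used only for illustration.
from collections import defaultdict

# Sample rows copied from the Example section.
companies = [
    (1, "AlphaTech", "Technology"),
    (2, "BetaHealth", "Healthcare"),
    (3, "GammaEntertainment", "Entertainment"),
    (4, "DeltaGreen", "Renewable Energy"),
    (5, "EpsilonFinance", "Finance"),
]
investments = [
    (1, 1, 5000000.0),
    (2, 2, 3000000.0),
    (3, 3, 1000000.0),
    (4, 4, 4000000.0),
    (5, 5, 2000000.0),
]


def total_by_industry(companies, investments):
    """Return (industry, total_investment) pairs, highest total first."""
    # company_id -> industry lookup plays the role of the join key.
    industry_of = {company_id: industry for company_id, _, industry in companies}
    totals = defaultdict(float)
    for _, company_id, amount in investments:
        # Inner-join semantics: investments with an unknown company are skipped.
        if company_id in industry_of:
            totals[industry_of[company_id]] += amount
    # Sort by total, descending.
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)


for industry, total in total_by_industry(companies, investments):
    print(f"{industry:<18} {total}")
```

Running it prints the five industries in the same order as the expected output table.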
Explanation
Step 1: Joining the DataFrames
joined_df = companies.join(investments, on="company_id", how="inner")
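To calculate the total investment per industry, we first need to associate every funding round (amount) with the industry of the company that received it. The company_id column appears in both DataFrames. It identifies each company in companies and serves as a foreign key in investments, so it is the natural join key. Passing it by name through on= also leaves a single company_id column in the result instead of two. An inner join fits this task because it keeps only companies that actually received investments and only investments whose company_id matches a known company. Unmatched rows on either side are dropped.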
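The effect of the inner join can be modeled in plain Python. This sketch uses hypothetical rows that are not in the sample data. ZetaRobotics (company 6) has no investments, and investment 3 references a company_id (99) that does not exist. Neither one survives the join.

```python
# Hypothetical rows (not from the dataset) chosen to show what an inner join drops.
companies = {
    1: ("AlphaTech", "Technology"),
    6: ("ZetaRobotics", "Robotics"),  # hypothetical: never funded
}
investments = [
    (1, 1, 5000000.0),
    (2, 1, 2500000.0),
    (3, 99, 750000.0),  # hypothetical orphan: company 99 does not exist
]

# Inner join on company_id: one output row per investment with a matching company.
joined = []
for investment_id, company_id, amount in investments:
    if company_id in companies:  # unmatched investments are dropped
        name, industry = companies[company_id]
        joined.append((company_id, name, industry, investment_id, amount))
# Companies with no investments (ZetaRobotics) never appear, because the loop
# only emits rows driven by investments.
```

In Spark, joining with how="left" (companies on the left) would instead keep ZetaRobotics with a null amount. The sample data does not exercise that case.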
Step 2: Grouping the Data
joined_df.groupBy("industry")
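Once the data is merged, we call .groupBy("industry"). This tells PySpark to place all rows that share exactly the same industry value into one group. The comparison is case-sensitive, and all rows with a null industry form a single group of their own. groupBy() returns a GroupedData object rather than a DataFrame, so no totals exist yet; the aggregation is defined in the next step. Spark is also lazy: nothing is actually computed until an action such as the final write runs.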
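Because grouping uses exact matches, inconsistent spellings end up in separate groups. The plain-Python sketch below models that behavior with hypothetical rows (the sample data is clean). If real data were messy, one option would be to normalize the column with F.trim and F.lower before grouping. The task does not require that.

```python
from collections import defaultdict

# Hypothetical joined rows: (industry, amount). None stands in for a null industry.
rows = [
    ("Technology", 5000000.0),
    ("technology", 1000000.0),   # different case -> different group
    ("Technology ", 500000.0),   # trailing space -> different group
    (None, 250000.0),            # null industries share one group
    (None, 100000.0),
]

# Bucket rows by exact key match, like groupBy("industry").
groups = defaultdict(list)
for industry, amount in rows:
    groups[industry].append(amount)
```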
Step 3: Aggregating the Investments
.agg(F.sum("amount").alias("total_investment"))
After grouping, we define the aggregation. .agg() applies F.sum() to the amount column, adding up every funding round within each industry group; null amounts are ignored by the sum. Because amount is a double, the sum is also a double, which matches the Expected Output Schema. Chaining .alias("total_investment") onto the sum renames the result column. Without it, Spark would name the column sum(amount).
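The null handling of F.sum can be modeled in plain Python. This is a minimal sketch: spark_like_sum is a hypothetical helper, and the grouped amounts are made up to show a multi-round industry and null values. Non-null values are added up. A group whose values are all null sums to null (None here), not 0.0.

```python
# Hypothetical grouped amounts; None stands in for a null amount in Spark.
groups = {
    "Technology": [5000000.0, 2500000.0, None],
    "Healthcare": [3000000.0],
    "Robotics": [None],
}


def spark_like_sum(values):
    """Hypothetical model of F.sum: ignore nulls, return None if all are null."""
    present = [v for v in values if v is not None]
    return sum(present) if present else None


# One total per group, like .agg(F.sum("amount").alias("total_investment")).
totals = {industry: spark_like_sum(amounts) for industry, amounts in groups.items()}
```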
Step 4: Sorting the Results
.orderBy(F.col("total_investment").desc())
Finally, the task requires the output to be sorted from highest to lowest total investment. The .orderBy() method (an alias for .sort()) handles this. Calling .desc() on the column reference, as in F.col("total_investment").desc(), tells PySpark to sort in descending order. F.desc("total_investment") is an equivalent shorthand.
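One detail the task leaves open is how to order industries whose totals are equal. Spark makes no ordering promise among rows that tie on every sort key. A secondary key makes the order deterministic. The plain-Python sketch below uses hypothetical totals with a tie and adds industry name (ascending) as a tie-breaker. That tie-breaker is an assumption, not a stated requirement.

```python
# Hypothetical totals with a tie between Biotech and Finance.
totals = [
    ("Finance", 2000000.0),
    ("Technology", 5000000.0),
    ("Biotech", 2000000.0),
    ("Healthcare", 3000000.0),
]

# Primary key: total_investment descending; secondary key: industry ascending.
ranked = sorted(totals, key=lambda row: (-row[1], row[0]))
```

In PySpark, the same ordering could be written as .orderBy(F.col("total_investment").desc(), F.col("industry").asc()).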